feat: add provider_category field to distinguish byok (#2038)

This commit is contained in:
cthomas
2025-05-06 17:31:36 -07:00
committed by GitHub
parent 230eb944d1
commit db6982a4bc
23 changed files with 250 additions and 112 deletions

View File

@@ -0,0 +1,31 @@
"""add provider category
Revision ID: 878607e41ca4
Revises: 0335b1eb9c40
Create Date: 2025-05-06 12:10:25.751536
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "878607e41ca4"
down_revision: Union[str, None] = "0335b1eb9c40"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("providers", sa.Column("provider_category", sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("providers", "provider_category")
# ### end Alembic commands ###
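
The column is added as nullable with no backfill, so provider rows created before this migration keep a NULL provider_category. Since the updated ProviderManager treats every database-stored provider as BYOK, a backfill could be appended inside upgrade() — a hedged sketch, not part of this commit:

    # Assumed backfill: every pre-existing row in `providers` is a user-created (BYOK) provider
    op.execute("UPDATE providers SET provider_category = 'byok' WHERE provider_category IS NULL")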

View File

@@ -331,10 +331,9 @@ class Agent(BaseAgent):
log_telemetry(self.logger, "_get_ai_reply create start")
# New LLM client flow
llm_client = LLMClient.create(
provider_name=self.agent_state.llm_config.provider_name,
provider_type=self.agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=self.user.id,
actor=self.user,
)
if llm_client and not stream:
@@ -943,7 +942,10 @@ class Agent(BaseAgent):
model_endpoint=self.agent_state.llm_config.model_endpoint,
context_window_limit=self.agent_state.llm_config.context_window,
usage=response.usage,
provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name),
provider_id=self.provider_manager.get_provider_id_from_name(
self.agent_state.llm_config.provider_name,
actor=self.user,
),
job_id=job_id,
)
for message in all_new_messages:
@@ -1087,7 +1089,9 @@ class Agent(BaseAgent):
LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"]
)
summary = summarize_messages(agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize)
summary = summarize_messages(
agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize, actor=self.user
)
logger.info(f"Got summary: {summary}")
# Metadata that's useful for the agent to see

View File

@@ -75,10 +75,9 @@ class LettaAgent(BaseAgent):
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
provider_name=agent_state.llm_config.provider_name,
provider_type=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor_id=self.actor.id,
actor=self.actor,
)
for _ in range(max_steps):
response = await self._get_ai_reply(
@@ -120,10 +119,9 @@ class LettaAgent(BaseAgent):
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
provider_name=agent_state.llm_config.provider_name,
provider_type=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor_id=self.actor.id,
actor=self.actor,
)
for _ in range(max_steps):

View File

@@ -172,10 +172,9 @@ class LettaAgentBatch:
log_event(name="init_llm_client")
llm_client = LLMClient.create(
provider_name=agent_states[0].llm_config.provider_name,
provider_type=agent_states[0].llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor_id=self.actor.id,
actor=self.actor,
)
agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
@@ -284,10 +283,9 @@ class LettaAgentBatch:
# translate provider-specific response → OpenAI-style tool call (unchanged)
llm_client = LLMClient.create(
provider_name=item.llm_config.provider_name,
provider_type=item.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor_id=self.actor.id,
actor=self.actor,
)
tool_call = (
llm_client.convert_response_to_chat_completion(

View File

@@ -3455,7 +3455,7 @@ class LocalClient(AbstractClient):
Returns:
configs (List[LLMConfig]): List of LLM configurations
"""
return self.server.list_llm_models()
return self.server.list_llm_models(actor=self.user)
def list_embedding_configs(self) -> List[EmbeddingConfig]:
"""
@@ -3464,7 +3464,7 @@ class LocalClient(AbstractClient):
Returns:
configs (List[EmbeddingConfig]): List of embedding configurations
"""
return self.server.list_embedding_models()
return self.server.list_embedding_models(actor=self.user)
def create_org(self, name: Optional[str] = None) -> Organization:
return self.server.organization_manager.create_organization(pydantic_org=Organization(name=name))

View File

@@ -26,7 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.enums import ProviderCategory
from letta.schemas.message import Message as _Message
from letta.schemas.message import MessageRole as _MessageRole
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
@@ -42,6 +42,7 @@ from letta.schemas.openai.chat_completion_response import Message
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
from letta.schemas.openai.chat_completion_response import MessageDelta, ToolCall, ToolCallDelta, UsageStatistics
from letta.services.provider_manager import ProviderManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
from letta.tracing import log_event
@@ -744,12 +745,15 @@ def anthropic_chat_completions_request(
extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
provider_category: Optional[ProviderCategory] = None,
betas: List[str] = ["tools-2024-04-04"],
user_id: Optional[str] = None,
) -> ChatCompletionResponse:
"""https://docs.anthropic.com/claude/docs/tool-use"""
anthropic_client = None
if provider_name and provider_name != ProviderType.anthropic.value:
api_key = ProviderManager().get_override_key(provider_name)
if provider_category == ProviderCategory.byok:
actor = UserManager().get_user_or_default(user_id=user_id)
api_key = ProviderManager().get_override_key(provider_name, actor=actor)
anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
@@ -803,7 +807,9 @@ def anthropic_chat_completions_request_stream(
extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
provider_category: Optional[ProviderCategory] = None,
betas: List[str] = ["tools-2024-04-04"],
user_id: Optional[str] = None,
) -> Generator[ChatCompletionChunkResponse, None, None]:
"""Stream chat completions from Anthropic API.
@@ -817,8 +823,9 @@ def anthropic_chat_completions_request_stream(
extended_thinking=extended_thinking,
max_reasoning_tokens=max_reasoning_tokens,
)
if provider_name and provider_name != ProviderType.anthropic.value:
api_key = ProviderManager().get_override_key(provider_name)
if provider_category == ProviderCategory.byok:
actor = UserManager().get_user_or_default(user_id=user_id)
api_key = ProviderManager().get_override_key(provider_name, actor=actor)
anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
@@ -867,10 +874,12 @@ def anthropic_chat_completions_process_stream(
extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
provider_category: Optional[ProviderCategory] = None,
create_message_id: bool = True,
create_message_datetime: bool = True,
betas: List[str] = ["tools-2024-04-04"],
name: Optional[str] = None,
user_id: Optional[str] = None,
) -> ChatCompletionResponse:
"""Process a streaming completion response from Anthropic, similar to OpenAI's streaming.
@@ -952,7 +961,9 @@ def anthropic_chat_completions_process_stream(
extended_thinking=extended_thinking,
max_reasoning_tokens=max_reasoning_tokens,
provider_name=provider_name,
provider_category=provider_category,
betas=betas,
user_id=user_id,
)
):
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)

View File

@@ -27,7 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
@@ -45,18 +45,18 @@ logger = get_logger(__name__)
class AnthropicClient(LLMClientBase):
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
client = self._get_anthropic_client(async_client=False)
client = self._get_anthropic_client(llm_config, async_client=False)
response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
return response.model_dump()
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
client = self._get_anthropic_client(async_client=True)
client = self._get_anthropic_client(llm_config, async_client=True)
response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
return response.model_dump()
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
client = self._get_anthropic_client(async_client=True)
client = self._get_anthropic_client(llm_config, async_client=True)
request_data["stream"] = True
return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
@@ -96,7 +96,7 @@ class AnthropicClient(LLMClientBase):
for agent_id in agent_messages_mapping
}
client = self._get_anthropic_client(async_client=True)
client = self._get_anthropic_client(list(agent_llm_config_mapping.values())[0], async_client=True)
anthropic_requests = [
Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
@@ -112,10 +112,12 @@ class AnthropicClient(LLMClientBase):
raise self.handle_llm_error(e)
@trace_method
def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
def _get_anthropic_client(
self, llm_config: LLMConfig, async_client: bool = False
) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
override_key = None
if self.provider_name and self.provider_name != ProviderType.anthropic.value:
override_key = ProviderManager().get_override_key(self.provider_name)
if llm_config.provider_category == ProviderCategory.byok:
override_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
if async_client:
return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()

View File

@@ -13,7 +13,7 @@ from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.local_llm.utils import count_tokens
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
@@ -31,10 +31,10 @@ class GoogleAIClient(LLMClientBase):
Performs underlying request to llm and returns raw response.
"""
api_key = None
if llm_config.provider_name and llm_config.provider_name != ProviderType.google_ai.value:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name)
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
if not api_key:
api_key = model_settings.gemini_api_key

View File

@@ -24,7 +24,7 @@ from letta.llm_api.openai import (
from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.schemas.enums import ProviderType
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
@@ -172,10 +172,12 @@ def create(
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy
raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
elif llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
from letta.services.user_manager import UserManager
api_key = ProviderManager().get_override_key(llm_config.provider_name)
actor = UserManager().get_user_or_default(user_id=user_id)
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=actor)
elif model_settings.openai_api_key is None:
# the openai python client requires a dummy API key
api_key = "DUMMY_API_KEY"
@@ -379,7 +381,9 @@ def create(
extended_thinking=llm_config.enable_reasoner,
max_reasoning_tokens=llm_config.max_reasoning_tokens,
provider_name=llm_config.provider_name,
provider_category=llm_config.provider_category,
name=name,
user_id=user_id,
)
else:
@@ -390,6 +394,8 @@ def create(
extended_thinking=llm_config.enable_reasoner,
max_reasoning_tokens=llm_config.max_reasoning_tokens,
provider_name=llm_config.provider_name,
provider_category=llm_config.provider_category,
user_id=user_id,
)
if llm_config.put_inner_thoughts_in_kwargs:

View File

@@ -1,8 +1,11 @@
from typing import Optional
from typing import TYPE_CHECKING, Optional
from letta.llm_api.llm_client_base import LLMClientBase
from letta.schemas.enums import ProviderType
if TYPE_CHECKING:
from letta.orm import User
class LLMClient:
"""Factory class for creating LLM clients based on the model endpoint type."""
@@ -10,9 +13,8 @@ class LLMClient:
@staticmethod
def create(
provider_type: ProviderType,
provider_name: Optional[str] = None,
put_inner_thoughts_first: bool = True,
actor_id: Optional[str] = None,
actor: Optional["User"] = None,
) -> Optional[LLMClientBase]:
"""
Create an LLM client based on the model endpoint type.
@@ -32,33 +34,29 @@ class LLMClient:
from letta.llm_api.google_ai_client import GoogleAIClient
return GoogleAIClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
actor=actor,
)
case ProviderType.google_vertex:
from letta.llm_api.google_vertex_client import GoogleVertexClient
return GoogleVertexClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
actor=actor,
)
case ProviderType.anthropic:
from letta.llm_api.anthropic_client import AnthropicClient
return AnthropicClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
actor=actor,
)
case ProviderType.openai:
from letta.llm_api.openai_client import OpenAIClient
return OpenAIClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
actor=actor,
)
case _:
return None
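
After this refactor the factory is keyed on provider_type alone and receives the acting User; provider_name is no longer passed, and the BYOK decision moves into the client, which reads llm_config.provider_category per request. A minimal call sketch (import path assumed; agent_state and user are assumed to be in scope, as in the agent code above):

    from letta.llm_api.llm_client import LLMClient

    # The client resolves BYOK override keys from llm_config.provider_category
    # (plus self.actor) at request time, so only the endpoint type is needed here.
    llm_client = LLMClient.create(
        provider_type=agent_state.llm_config.model_endpoint_type,
        put_inner_thoughts_first=True,
        actor=user,
    )
    # create() returns None for endpoint types without a dedicated client,
    # so callers still fall back to the legacy llm_api_tools flow when unset.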

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from anthropic.types.beta.messages import BetaMessageBatch
from openai import AsyncStream, Stream
@@ -11,6 +11,9 @@ from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.tracing import log_event
if TYPE_CHECKING:
from letta.orm import User
class LLMClientBase:
"""
@@ -20,13 +23,11 @@ class LLMClientBase:
def __init__(
self,
provider_name: Optional[str] = None,
put_inner_thoughts_first: Optional[bool] = True,
use_tool_naming: bool = True,
actor_id: Optional[str] = None,
actor: Optional["User"] = None,
):
self.actor_id = actor_id
self.provider_name = provider_name
self.actor = actor
self.put_inner_thoughts_first = put_inner_thoughts_first
self.use_tool_naming = use_tool_naming

View File

@@ -22,7 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@@ -78,10 +78,10 @@ def supports_parallel_tool_calling(model: str) -> bool:
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
api_key = None
if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name)
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
if not api_key:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@@ -156,11 +156,11 @@ class OpenAIClient(LLMClientBase):
)
# always set user id for openai requests
if self.actor_id:
data.user = self.actor_id
if self.actor:
data.user = self.actor.id
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
if not self.actor_id:
if not self.actor:
# override user id for inference.letta.com
import uuid

View File

@@ -1,4 +1,4 @@
from typing import Callable, Dict, List
from typing import TYPE_CHECKING, Callable, Dict, List
from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK
from letta.llm_api.llm_api_tools import create
@@ -13,6 +13,9 @@ from letta.settings import summarizer_settings
from letta.tracing import trace_method
from letta.utils import count_tokens, printd
if TYPE_CHECKING:
from letta.orm import User
def get_memory_functions(cls: Memory) -> Dict[str, Callable]:
"""Get memory functions for a memory class"""
@@ -51,6 +54,7 @@ def _format_summary_history(message_history: List[Message]):
def summarize_messages(
agent_state: AgentState,
message_sequence_to_summarize: List[Message],
actor: "User",
):
"""Summarize a message sequence using GPT"""
# we need the context_window
@@ -63,7 +67,7 @@ def summarize_messages(
trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8 # For good measure...
cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
summary_input = str(
[summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff])]
[summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff], actor=actor)]
+ message_sequence_to_summarize[cutoff:]
)
@@ -79,10 +83,9 @@ def summarize_messages(
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
llm_client = LLMClient.create(
provider_name=llm_config_no_inner_thoughts.provider_name,
provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
provider_type=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=False,
actor_id=agent_state.created_by_id,
actor=actor,
)
# try to use new client, otherwise fallback to old flow
# TODO: we can just directly call the LLM here?

View File

@@ -26,6 +26,7 @@ class Provider(SqlalchemyBase, OrganizationMixin):
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
provider_category: Mapped[str] = mapped_column(nullable=True, doc="The category of the provider (base or byok)")
api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")

View File

@@ -19,6 +19,11 @@ class ProviderType(str, Enum):
bedrock = "bedrock"
class ProviderCategory(str, Enum):
base = "base"
byok = "byok"
class MessageRole(str, Enum):
assistant = "assistant"
user = "user"

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from letta.constants import LETTA_MODEL_ENDPOINT
from letta.log import get_logger
from letta.schemas.enums import ProviderCategory
logger = get_logger(__name__)
@@ -51,6 +52,7 @@ class LLMConfig(BaseModel):
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
provider_category: Optional[ProviderCategory] = Field(None, description="The provider category for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
context_window: int = Field(..., description="The context window size for the model.")
put_inner_thoughts_in_kwargs: Optional[bool] = Field(

View File

@@ -9,7 +9,7 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
from letta.schemas.enums import ProviderType
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.letta_base import LettaBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
@@ -24,6 +24,7 @@ class Provider(ProviderBase):
id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
name: str = Field(..., description="The name of the provider")
provider_type: ProviderType = Field(..., description="The type of the provider")
provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
base_url: Optional[str] = Field(None, description="Base URL for the provider.")
organization_id: Optional[str] = Field(None, description="The organization id of the user")
@@ -113,6 +114,7 @@ class ProviderUpdate(ProviderBase):
class LettaProvider(Provider):
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
def list_llm_models(self) -> List[LLMConfig]:
return [
@@ -123,6 +125,7 @@ class LettaProvider(Provider):
context_window=8192,
handle=self.get_handle("letta-free"),
provider_name=self.name,
provider_category=self.provider_category,
)
]
@@ -141,6 +144,7 @@ class LettaProvider(Provider):
class OpenAIProvider(Provider):
provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
api_key: str = Field(..., description="API key for the OpenAI API.")
base_url: str = Field(..., description="Base URL for the OpenAI API.")
@@ -225,6 +229,7 @@ class OpenAIProvider(Provider):
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
provider_category=self.provider_category,
)
)
@@ -281,6 +286,7 @@ class DeepSeekProvider(OpenAIProvider):
"""
provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
api_key: str = Field(..., description="API key for the DeepSeek API.")
@@ -332,6 +338,7 @@ class DeepSeekProvider(OpenAIProvider):
handle=self.get_handle(model_name),
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
provider_name=self.name,
provider_category=self.provider_category,
)
)
@@ -344,6 +351,7 @@ class DeepSeekProvider(OpenAIProvider):
class LMStudioOpenAIProvider(OpenAIProvider):
provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
@@ -470,6 +478,7 @@ class XAIProvider(OpenAIProvider):
"""https://docs.x.ai/docs/api-reference"""
provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
api_key: str = Field(..., description="API key for the xAI/Grok API.")
base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
@@ -523,6 +532,7 @@ class XAIProvider(OpenAIProvider):
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
provider_category=self.provider_category,
)
)
@@ -535,6 +545,7 @@ class XAIProvider(OpenAIProvider):
class AnthropicProvider(Provider):
provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
api_key: str = Field(..., description="API key for the Anthropic API.")
base_url: str = "https://api.anthropic.com/v1"
@@ -611,6 +622,7 @@ class AnthropicProvider(Provider):
put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
max_tokens=max_tokens,
provider_name=self.name,
provider_category=self.provider_category,
)
)
return configs
@@ -621,6 +633,7 @@ class AnthropicProvider(Provider):
class MistralProvider(Provider):
provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
api_key: str = Field(..., description="API key for the Mistral API.")
base_url: str = "https://api.mistral.ai/v1"
@@ -645,6 +658,7 @@ class MistralProvider(Provider):
context_window=model["max_context_length"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
provider_category=self.provider_category,
)
)
@@ -672,6 +686,7 @@ class OllamaProvider(OpenAIProvider):
"""
provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = Field(..., description="Base URL for the Ollama API.")
api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
default_prompt_formatter: str = Field(
@@ -702,6 +717,7 @@ class OllamaProvider(OpenAIProvider):
context_window=context_window,
handle=self.get_handle(model["name"]),
provider_name=self.name,
provider_category=self.provider_category,
)
)
return configs
@@ -785,6 +801,7 @@ class OllamaProvider(OpenAIProvider):
class GroqProvider(OpenAIProvider):
provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = "https://api.groq.com/openai/v1"
api_key: str = Field(..., description="API key for the Groq API.")
@@ -804,6 +821,7 @@ class GroqProvider(OpenAIProvider):
context_window=model["context_window"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
provider_category=self.provider_category,
)
)
return configs
@@ -825,6 +843,7 @@ class TogetherProvider(OpenAIProvider):
"""
provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = "https://api.together.ai/v1"
api_key: str = Field(..., description="API key for the TogetherAI API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@@ -873,6 +892,7 @@ class TogetherProvider(OpenAIProvider):
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
provider_category=self.provider_category,
)
)
@@ -927,6 +947,7 @@ class TogetherProvider(OpenAIProvider):
class GoogleAIProvider(Provider):
# gemini
provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
api_key: str = Field(..., description="API key for the Google AI API.")
base_url: str = "https://generativelanguage.googleapis.com"
@@ -955,6 +976,7 @@ class GoogleAIProvider(Provider):
handle=self.get_handle(model),
max_tokens=8192,
provider_name=self.name,
provider_category=self.provider_category,
)
)
return configs
@@ -991,6 +1013,7 @@ class GoogleAIProvider(Provider):
class GoogleVertexProvider(Provider):
provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
@@ -1008,6 +1031,7 @@ class GoogleVertexProvider(Provider):
handle=self.get_handle(model),
max_tokens=8192,
provider_name=self.name,
provider_category=self.provider_category,
)
)
return configs
@@ -1032,6 +1056,7 @@ class GoogleVertexProvider(Provider):
class AzureProvider(Provider):
provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
base_url: str = Field(
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@@ -1065,6 +1090,7 @@ class AzureProvider(Provider):
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
provider_category=self.provider_category,
),
)
return configs
@@ -1106,6 +1132,7 @@ class VLLMChatCompletionsProvider(Provider):
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = Field(..., description="Base URL for the vLLM API.")
def list_llm_models(self) -> List[LLMConfig]:
@@ -1125,6 +1152,7 @@ class VLLMChatCompletionsProvider(Provider):
context_window=model["max_model_len"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
provider_category=self.provider_category,
)
)
return configs
@@ -1139,6 +1167,7 @@ class VLLMCompletionsProvider(Provider):
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = Field(..., description="Base URL for the vLLM API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@@ -1159,6 +1188,7 @@ class VLLMCompletionsProvider(Provider):
context_window=model["max_model_len"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
provider_category=self.provider_category,
)
)
return configs
@@ -1174,6 +1204,7 @@ class CohereProvider(OpenAIProvider):
class AnthropicBedrockProvider(Provider):
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
aws_region: str = Field(..., description="AWS region for Bedrock")
def list_llm_models(self):
@@ -1192,6 +1223,7 @@ class AnthropicBedrockProvider(Provider):
context_window=self.get_model_context_window(model_arn),
handle=self.get_handle(model_arn),
provider_name=self.name,
provider_category=self.provider_category,
)
)
return configs
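
Each built-in provider schema now carries provider_category defaulting to ProviderCategory.base and copies it onto every LLMConfig it lists, so a config is self-describing about where its API key comes from. A sketch of the resulting config shape for a BYOK provider (field values are illustrative, drawn from the test further down):

    from letta.schemas.enums import ProviderCategory
    from letta.schemas.llm_config import LLMConfig

    config = LLMConfig(
        model="claude-3-5-sonnet-20240620",
        model_endpoint_type="anthropic",
        model_endpoint="https://api.anthropic.com/v1",
        provider_name="caren-anthropic",          # user-created provider name
        provider_category=ProviderCategory.byok,  # key resolved via ProviderManager, not model_settings
        context_window=100000,
    )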

View File

@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, Header, Query
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.server.rest_api.utils import get_letta_server
@@ -14,12 +15,19 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
@router.get("/", response_model=List[LLMConfig], operation_id="list_models")
def list_llm_models(
byok_only: Optional[bool] = Query(None),
default_only: Optional[bool] = Query(None),
provider_category: Optional[List[ProviderCategory]] = Query(None),
provider_name: Optional[str] = Query(None),
provider_type: Optional[ProviderType] = Query(None),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
models = server.list_llm_models(byok_only=byok_only, default_only=default_only)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
models = server.list_llm_models(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
)
# print(models)
return models
@@ -27,8 +35,9 @@ def list_llm_models(
@router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
def list_embedding_models(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
models = server.list_embedding_models()
actor = server.user_manager.get_user_or_default(user_id=actor_id)
models = server.list_embedding_models(actor=actor)
# print(models)
return models
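
The route replaces the old byok_only/default_only flags with explicit category, name, and type filters, and resolves the acting user from the user_id header. A request sketch (the base URL and /v1 prefix are assumptions about a local deployment; the header and query parameter names come from the route above):

    import requests

    resp = requests.get(
        "http://localhost:8283/v1/models/",    # assumed local server address and API prefix
        params={"provider_category": "byok"},  # only models from user-created (BYOK) providers
        headers={"user_id": "user-00000000"},  # actor resolved server-side via get_user_or_default
    )
    byok_models = resp.json()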

View File

@@ -95,6 +95,7 @@ def create_source(
source_create.embedding_config = server.get_embedding_config_from_handle(
handle=source_create.embedding,
embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
actor=actor,
)
source = Source(
name=source_create.name,

View File

@@ -42,7 +42,7 @@ from letta.schemas.block import Block, BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager
from letta.schemas.job import Job, JobUpdate
@@ -734,17 +734,17 @@ class SyncServer(Server):
return self._command(user_id=user_id, agent_id=agent_id, command=command)
@trace_method
def get_cached_llm_config(self, **kwargs):
def get_cached_llm_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._llm_config_cache:
self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
return self._llm_config_cache[key]
@trace_method
def get_cached_embedding_config(self, **kwargs):
def get_cached_embedding_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._embedding_config_cache:
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs)
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
return self._embedding_config_cache[key]
@trace_method
@@ -766,7 +766,7 @@ class SyncServer(Server):
"enable_reasoner": request.enable_reasoner,
}
log_event(name="start get_cached_llm_config", attributes=config_params)
request.llm_config = self.get_cached_llm_config(**config_params)
request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
log_event(name="end get_cached_llm_config", attributes=config_params)
if request.embedding_config is None:
@@ -777,7 +777,7 @@ class SyncServer(Server):
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
}
log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
request.embedding_config = self.get_cached_embedding_config(**embedding_config_params)
request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
log_event(name="start create_agent db")
@@ -802,10 +802,10 @@ class SyncServer(Server):
actor: User,
) -> AgentState:
if request.model is not None:
request.llm_config = self.get_llm_config_from_handle(handle=request.model)
request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)
if request.embedding is not None:
request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding)
request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)
if request.enable_sleeptime:
agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -1201,10 +1201,21 @@ class SyncServer(Server):
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
def list_llm_models(self, byok_only: bool = False, default_only: bool = False) -> List[LLMConfig]:
def list_llm_models(
self,
actor: User,
provider_category: Optional[List[ProviderCategory]] = None,
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[LLMConfig]:
"""List available models"""
llm_models = []
for provider in self.get_enabled_providers(byok_only=byok_only, default_only=default_only):
for provider in self.get_enabled_providers(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
):
try:
llm_models.extend(provider.list_llm_models())
except Exception as e:
@@ -1214,32 +1225,49 @@ class SyncServer(Server):
return llm_models
def list_embedding_models(self) -> List[EmbeddingConfig]:
def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
"""List available embedding models"""
embedding_models = []
for provider in self.get_enabled_providers():
for provider in self.get_enabled_providers(actor):
try:
embedding_models.extend(provider.list_embedding_models())
except Exception as e:
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models
def get_enabled_providers(self, byok_only: bool = False, default_only: bool = False):
providers_from_env = {p.name: p for p in self._enabled_providers}
def get_enabled_providers(
self,
actor: User,
provider_category: Optional[List[ProviderCategory]] = None,
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[Provider]:
providers = []
if not provider_category or ProviderCategory.base in provider_category:
providers_from_env = [p for p in self._enabled_providers]
providers.extend(providers_from_env)
if default_only:
return list(providers_from_env.values())
if not provider_category or ProviderCategory.byok in provider_category:
providers_from_db = self.provider_manager.list_providers(
name=provider_name,
provider_type=provider_type,
actor=actor,
)
providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
providers.extend(providers_from_db)
providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
if provider_name is not None:
providers = [p for p in providers if p.name == provider_name]
if byok_only:
return list(providers_from_db.values())
if provider_type is not None:
providers = [p for p in providers if p.provider_type == provider_type]
return list(providers_from_env.values()) + list(providers_from_db.values())
return providers
@trace_method
def get_llm_config_from_handle(
self,
actor: User,
handle: str,
context_window_limit: Optional[int] = None,
max_tokens: Optional[int] = None,
@@ -1248,7 +1276,7 @@ class SyncServer(Server):
) -> LLMConfig:
try:
provider_name, model_name = handle.split("/", 1)
provider = self.get_provider_from_name(provider_name)
provider = self.get_provider_from_name(provider_name, actor)
llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
if not llm_configs:
@@ -1292,11 +1320,11 @@ class SyncServer(Server):
@trace_method
def get_embedding_config_from_handle(
self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
) -> EmbeddingConfig:
try:
provider_name, model_name = handle.split("/", 1)
provider = self.get_provider_from_name(provider_name)
provider = self.get_provider_from_name(provider_name, actor)
embedding_configs = [config for config in provider.list_embedding_models() if config.handle == handle]
if not embedding_configs:
@@ -1319,8 +1347,8 @@ class SyncServer(Server):
return embedding_config
def get_provider_from_name(self, provider_name: str) -> Provider:
providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
if not providers:
raise ValueError(f"Provider {provider_name} is not supported")
elif len(providers) > 1:
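
Provider discovery is now actor-scoped: environment-configured providers count as base, database providers (fetched per organization via ProviderManager) count as byok, and optional name/type filters apply to the merged list. A sketch of the new server-side calls, mirroring the updated test further down (server and user assumed to be already constructed):

    from letta.schemas.enums import ProviderCategory, ProviderType

    # Models from user-created (BYOK) providers only
    byok_models = server.list_llm_models(actor=user, provider_category=[ProviderCategory.byok])

    # Models from environment-configured (base) providers only, restricted to Anthropic
    base_anthropic = server.list_llm_models(
        actor=user,
        provider_category=[ProviderCategory.base],
        provider_type=ProviderType.anthropic,
    )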

View File

@@ -1,9 +1,9 @@
from typing import List, Optional, Union
from letta.orm.provider import Provider as ProviderModel
from letta.schemas.enums import ProviderType
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderUpdate
from letta.schemas.providers import ProviderCreate, ProviderUpdate
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
@@ -16,9 +16,12 @@ class ProviderManager:
self.session_maker = db_context
@enforce_types
def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider:
def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
"""Create a new provider if it doesn't already exist."""
with self.session_maker() as session:
provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok}
provider = PydanticProvider(**provider_create_args)
if provider.name == provider.provider_type.value:
raise ValueError("Provider name must be unique and different from provider type")
@@ -65,11 +68,11 @@ class ProviderManager:
@enforce_types
def list_providers(
self,
actor: PydanticUser,
name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> List[PydanticProvider]:
"""List all providers with optional pagination."""
filter_kwargs = {}
@@ -88,11 +91,11 @@ class ProviderManager:
return [provider.to_pydantic() for provider in providers]
@enforce_types
def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
providers = self.list_providers(name=provider_name)
def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
providers = self.list_providers(name=provider_name, actor=actor)
return providers[0].id if providers else None
@enforce_types
def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
providers = self.list_providers(name=provider_name)
def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
providers = self.list_providers(name=provider_name, actor=actor)
return providers[0].api_key if providers else None
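
create_provider now takes a ProviderCreate request and stamps provider_category = byok on the stored provider, and both get_provider_id_from_name and get_override_key require the acting user so lookups stay organization-scoped. A usage sketch (provider name and key value are illustrative; the pattern mirrors the updated test below):

    from letta.schemas.enums import ProviderType
    from letta.schemas.providers import ProviderCreate
    from letta.services.provider_manager import ProviderManager

    manager = ProviderManager()
    provider = manager.create_provider(
        request=ProviderCreate(
            name="my-anthropic",                   # must differ from the provider type name
            provider_type=ProviderType.anthropic,
            api_key="sk-ant-...",                  # illustrative placeholder
        ),
        actor=actor,
    )
    api_key = manager.get_override_key("my-anthropic", actor=actor)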

View File

@@ -105,9 +105,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
llm_client = LLMClient.create(
provider_name=agent_state.llm_config.provider_name,
provider_type=agent_state.llm_config.model_endpoint_type,
actor_id=client.user.id,
actor=client.user,
)
if llm_client:
response = llm_client.send_llm_request(

View File

@@ -13,10 +13,10 @@ import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.enums import MessageRole, ProviderCategory, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderCreate
from letta.schemas.sandbox_config import SandboxType
from letta.schemas.user import User
@@ -587,7 +587,7 @@ def test_read_local_llm_configs(server: SyncServer, user: User):
# Call list_llm_models
assert os.path.exists(configs_base_dir)
llm_models = server.list_llm_models()
llm_models = server.list_llm_models(actor=user)
# Assert that the config is in the returned models
assert any(
@@ -1225,17 +1225,23 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
def test_messages_with_provider_override(server: SyncServer, user_id: str):
actor = server.user_manager.get_user_or_default(user_id)
provider = server.provider_manager.create_provider(
provider=PydanticProvider(
request=ProviderCreate(
name="caren-anthropic",
provider_type=ProviderType.anthropic,
api_key=os.getenv("ANTHROPIC_API_KEY"),
),
actor=actor,
)
models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok])
assert provider.name in [model.provider_name for model in models]
models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.base])
assert provider.name not in [model.provider_name for model in models]
agent = server.create_agent(
request=CreateAgent(
memory_blocks=[],
model="caren-anthropic/claude-3-opus-20240229",
model="caren-anthropic/claude-3-5-sonnet-20240620",
context_window_limit=100000,
embedding="openai/text-embedding-ada-002",
),
@@ -1295,11 +1301,11 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
assert total_tokens == usage.total_tokens
def test_unique_handles_for_provider_configs(server: SyncServer):
models = server.list_llm_models()
def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
models = server.list_llm_models(actor=user)
model_handles = [model.handle for model in models]
assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
embeddings = server.list_embedding_models()
embeddings = server.list_embedding_models(actor=user)
embedding_handles = [embedding.handle for embedding in embeddings]
assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"