From db6982a4bc3133576bac688b102435d75d07def6 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 6 May 2025 17:31:36 -0700 Subject: [PATCH] feat: add provider_category field to distinguish byok (#2038) --- .../878607e41ca4_add_provider_category.py | 31 +++++++ letta/agent.py | 12 ++- letta/agents/letta_agent.py | 6 +- letta/agents/letta_agent_batch.py | 6 +- letta/client/client.py | 4 +- letta/llm_api/anthropic.py | 21 +++-- letta/llm_api/anthropic_client.py | 18 +++-- letta/llm_api/google_ai_client.py | 6 +- letta/llm_api/llm_api_tools.py | 12 ++- letta/llm_api/llm_client.py | 20 +++-- letta/llm_api/llm_client_base.py | 11 +-- letta/llm_api/openai_client.py | 12 +-- letta/memory.py | 13 +-- letta/orm/provider.py | 1 + letta/schemas/enums.py | 5 ++ letta/schemas/llm_config.py | 2 + letta/schemas/providers.py | 34 +++++++- letta/server/rest_api/routers/v1/llms.py | 23 ++++-- letta/server/rest_api/routers/v1/sources.py | 1 + letta/server/server.py | 80 +++++++++++++------ letta/services/provider_manager.py | 19 +++-- tests/helpers/endpoints_helper.py | 3 +- tests/test_server.py | 22 +++-- 23 files changed, 250 insertions(+), 112 deletions(-) create mode 100644 alembic/versions/878607e41ca4_add_provider_category.py diff --git a/alembic/versions/878607e41ca4_add_provider_category.py b/alembic/versions/878607e41ca4_add_provider_category.py new file mode 100644 index 00000000..fb914c67 --- /dev/null +++ b/alembic/versions/878607e41ca4_add_provider_category.py @@ -0,0 +1,31 @@ +"""add provider category + +Revision ID: 878607e41ca4 +Revises: 0335b1eb9c40 +Create Date: 2025-05-06 12:10:25.751536 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "878607e41ca4" +down_revision: Union[str, None] = "0335b1eb9c40" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("providers", sa.Column("provider_category", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("providers", "provider_category") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 4e76bf12..d0c9ac0f 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -331,10 +331,9 @@ class Agent(BaseAgent): log_telemetry(self.logger, "_get_ai_reply create start") # New LLM client flow llm_client = LLMClient.create( - provider_name=self.agent_state.llm_config.provider_name, provider_type=self.agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=put_inner_thoughts_first, - actor_id=self.user.id, + actor=self.user, ) if llm_client and not stream: @@ -943,7 +942,10 @@ class Agent(BaseAgent): model_endpoint=self.agent_state.llm_config.model_endpoint, context_window_limit=self.agent_state.llm_config.context_window, usage=response.usage, - provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name), + provider_id=self.provider_manager.get_provider_id_from_name( + self.agent_state.llm_config.provider_name, + actor=self.user, + ), job_id=job_id, ) for message in all_new_messages: @@ -1087,7 +1089,9 @@ class Agent(BaseAgent): LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"] ) - summary = summarize_messages(agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize) + summary = summarize_messages( + agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize, actor=self.user + ) logger.info(f"Got summary: {summary}") # Metadata that's useful for the agent to see diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 90997c5c..4a194457 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -75,10 +75,9 @@ class LettaAgent(BaseAgent): ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( - provider_name=agent_state.llm_config.provider_name, provider_type=agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=True, - actor_id=self.actor.id, + actor=self.actor, ) for _ in range(max_steps): response = await self._get_ai_reply( @@ -120,10 +119,9 @@ class LettaAgent(BaseAgent): ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( - provider_name=agent_state.llm_config.provider_name, provider_type=agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=True, - actor_id=self.actor.id, + actor=self.actor, ) for _ in range(max_steps): diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 1c63c079..58cb5be7 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -172,10 +172,9 @@ class LettaAgentBatch: log_event(name="init_llm_client") llm_client = LLMClient.create( - provider_name=agent_states[0].llm_config.provider_name, provider_type=agent_states[0].llm_config.model_endpoint_type, put_inner_thoughts_first=True, - actor_id=self.actor.id, + actor=self.actor, ) agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states} @@ -284,10 +283,9 @@ class LettaAgentBatch: # translate provider‑specific response → OpenAI‑style tool call (unchanged) llm_client = LLMClient.create( - provider_name=item.llm_config.provider_name, provider_type=item.llm_config.model_endpoint_type, put_inner_thoughts_first=True, - actor_id=self.actor.id, + actor=self.actor, ) tool_call = ( llm_client.convert_response_to_chat_completion( diff --git a/letta/client/client.py b/letta/client/client.py index 90dd2823..14fdc009 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -3455,7 +3455,7 @@ class LocalClient(AbstractClient): Returns: configs (List[LLMConfig]): List of LLM configurations """ - return self.server.list_llm_models() + return self.server.list_llm_models(actor=self.user) def list_embedding_configs(self) -> List[EmbeddingConfig]: """ @@ -3464,7 +3464,7 @@ class LocalClient(AbstractClient): Returns: configs (List[EmbeddingConfig]): List of embedding configurations """ - return self.server.list_embedding_models() + return self.server.list_embedding_models(actor=self.user) def create_org(self, name: Optional[str] = None) -> Organization: return self.server.organization_manager.create_organization(pydantic_org=Organization(name=name)) diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 6a2a6e55..aada2259 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -26,7 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.log import get_logger -from letta.schemas.enums import ProviderType +from letta.schemas.enums import ProviderCategory from letta.schemas.message import Message as _Message from letta.schemas.message import MessageRole as _MessageRole from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool @@ -42,6 +42,7 @@ from letta.schemas.openai.chat_completion_response import Message from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage from letta.schemas.openai.chat_completion_response import MessageDelta, ToolCall, ToolCallDelta, UsageStatistics from letta.services.provider_manager import ProviderManager +from letta.services.user_manager import UserManager from letta.settings import model_settings from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface from letta.tracing import log_event @@ -744,12 +745,15 @@ def anthropic_chat_completions_request( extended_thinking: bool = False, max_reasoning_tokens: Optional[int] = None, provider_name: Optional[str] = None, + provider_category: Optional[ProviderCategory] = None, betas: List[str] = ["tools-2024-04-04"], + user_id: Optional[str] = None, ) -> ChatCompletionResponse: """https://docs.anthropic.com/claude/docs/tool-use""" anthropic_client = None - if provider_name and provider_name != ProviderType.anthropic.value: - api_key = ProviderManager().get_override_key(provider_name) + if provider_category == ProviderCategory.byok: + actor = UserManager().get_user_or_default(user_id=user_id) + api_key = ProviderManager().get_override_key(provider_name, actor=actor) anthropic_client = anthropic.Anthropic(api_key=api_key) elif model_settings.anthropic_api_key: anthropic_client = anthropic.Anthropic() @@ -803,7 +807,9 @@ def anthropic_chat_completions_request_stream( extended_thinking: bool = False, max_reasoning_tokens: Optional[int] = None, provider_name: Optional[str] = None, + provider_category: Optional[ProviderCategory] = None, betas: List[str] = ["tools-2024-04-04"], + user_id: Optional[str] = None, ) -> Generator[ChatCompletionChunkResponse, None, None]: """Stream chat completions from Anthropic API. @@ -817,8 +823,9 @@ def anthropic_chat_completions_request_stream( extended_thinking=extended_thinking, max_reasoning_tokens=max_reasoning_tokens, ) - if provider_name and provider_name != ProviderType.anthropic.value: - api_key = ProviderManager().get_override_key(provider_name) + if provider_category == ProviderCategory.byok: + actor = UserManager().get_user_or_default(user_id=user_id) + api_key = ProviderManager().get_override_key(provider_name, actor=actor) anthropic_client = anthropic.Anthropic(api_key=api_key) elif model_settings.anthropic_api_key: anthropic_client = anthropic.Anthropic() @@ -867,10 +874,12 @@ def anthropic_chat_completions_process_stream( extended_thinking: bool = False, max_reasoning_tokens: Optional[int] = None, provider_name: Optional[str] = None, + provider_category: Optional[ProviderCategory] = None, create_message_id: bool = True, create_message_datetime: bool = True, betas: List[str] = ["tools-2024-04-04"], name: Optional[str] = None, + user_id: Optional[str] = None, ) -> ChatCompletionResponse: """Process a streaming completion response from Anthropic, similar to OpenAI's streaming. @@ -952,7 +961,9 @@ def anthropic_chat_completions_process_stream( extended_thinking=extended_thinking, max_reasoning_tokens=max_reasoning_tokens, provider_name=provider_name, + provider_category=provider_category, betas=betas, + user_id=user_id, ) ): assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 35317dd8..f26d58eb 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -27,7 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.log import get_logger -from letta.schemas.enums import ProviderType +from letta.schemas.enums import ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool @@ -45,18 +45,18 @@ logger = get_logger(__name__) class AnthropicClient(LLMClientBase): def request(self, request_data: dict, llm_config: LLMConfig) -> dict: - client = self._get_anthropic_client(async_client=False) + client = self._get_anthropic_client(llm_config, async_client=False) response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) return response.model_dump() async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: - client = self._get_anthropic_client(async_client=True) + client = self._get_anthropic_client(llm_config, async_client=True) response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) return response.model_dump() @trace_method async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]: - client = self._get_anthropic_client(async_client=True) + client = self._get_anthropic_client(llm_config, async_client=True) request_data["stream"] = True return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) @@ -96,7 +96,7 @@ class AnthropicClient(LLMClientBase): for agent_id in agent_messages_mapping } - client = self._get_anthropic_client(async_client=True) + client = self._get_anthropic_client(list(agent_llm_config_mapping.values())[0], async_client=True) anthropic_requests = [ Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items() @@ -112,10 +112,12 @@ class AnthropicClient(LLMClientBase): raise self.handle_llm_error(e) @trace_method - def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]: + def _get_anthropic_client( + self, llm_config: LLMConfig, async_client: bool = False + ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]: override_key = None - if self.provider_name and self.provider_name != ProviderType.anthropic.value: - override_key = ProviderManager().get_override_key(self.provider_name) + if llm_config.provider_category == ProviderCategory.byok: + override_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor) if async_client: return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic() diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index 2d82c911..92577195 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -13,7 +13,7 @@ from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.json_parser import clean_json_string_extra_backslash from letta.local_llm.utils import count_tokens from letta.log import get_logger -from letta.schemas.enums import ProviderType +from letta.schemas.enums import ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool @@ -31,10 +31,10 @@ class GoogleAIClient(LLMClientBase): Performs underlying request to llm and returns raw response. """ api_key = None - if llm_config.provider_name and llm_config.provider_name != ProviderType.google_ai.value: + if llm_config.provider_category == ProviderCategory.byok: from letta.services.provider_manager import ProviderManager - api_key = ProviderManager().get_override_key(llm_config.provider_name) + api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor) if not api_key: api_key = model_settings.gemini_api_key diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index b1112290..7a778cda 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -24,7 +24,7 @@ from letta.llm_api.openai import ( from letta.local_llm.chat_completion_proxy import get_chat_completion from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages -from letta.schemas.enums import ProviderType +from letta.schemas.enums import ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype @@ -172,10 +172,12 @@ def create( if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1": # only is a problem if we are *not* using an openai proxy raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"]) - elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value: + elif llm_config.provider_category == ProviderCategory.byok: from letta.services.provider_manager import ProviderManager + from letta.services.user_manager import UserManager - api_key = ProviderManager().get_override_key(llm_config.provider_name) + actor = UserManager().get_user_or_default(user_id=user_id) + api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=actor) elif model_settings.openai_api_key is None: # the openai python client requires a dummy API key api_key = "DUMMY_API_KEY" @@ -379,7 +381,9 @@ def create( extended_thinking=llm_config.enable_reasoner, max_reasoning_tokens=llm_config.max_reasoning_tokens, provider_name=llm_config.provider_name, + provider_category=llm_config.provider_category, name=name, + user_id=user_id, ) else: @@ -390,6 +394,8 @@ def create( extended_thinking=llm_config.enable_reasoner, max_reasoning_tokens=llm_config.max_reasoning_tokens, provider_name=llm_config.provider_name, + provider_category=llm_config.provider_category, + user_id=user_id, ) if llm_config.put_inner_thoughts_in_kwargs: diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index a63913a4..63adbcc2 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -1,8 +1,11 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional from letta.llm_api.llm_client_base import LLMClientBase from letta.schemas.enums import ProviderType +if TYPE_CHECKING: + from letta.orm import User + class LLMClient: """Factory class for creating LLM clients based on the model endpoint type.""" @@ -10,9 +13,8 @@ class LLMClient: @staticmethod def create( provider_type: ProviderType, - provider_name: Optional[str] = None, put_inner_thoughts_first: bool = True, - actor_id: Optional[str] = None, + actor: Optional["User"] = None, ) -> Optional[LLMClientBase]: """ Create an LLM client based on the model endpoint type. @@ -32,33 +34,29 @@ class LLMClient: from letta.llm_api.google_ai_client import GoogleAIClient return GoogleAIClient( - provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, - actor_id=actor_id, + actor=actor, ) case ProviderType.google_vertex: from letta.llm_api.google_vertex_client import GoogleVertexClient return GoogleVertexClient( - provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, - actor_id=actor_id, + actor=actor, ) case ProviderType.anthropic: from letta.llm_api.anthropic_client import AnthropicClient return AnthropicClient( - provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, - actor_id=actor_id, + actor=actor, ) case ProviderType.openai: from letta.llm_api.openai_client import OpenAIClient return OpenAIClient( - provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, - actor_id=actor_id, + actor=actor, ) case _: return None diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 223921f9..f56601ee 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union from anthropic.types.beta.messages import BetaMessageBatch from openai import AsyncStream, Stream @@ -11,6 +11,9 @@ from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.tracing import log_event +if TYPE_CHECKING: + from letta.orm import User + class LLMClientBase: """ @@ -20,13 +23,11 @@ class LLMClientBase: def __init__( self, - provider_name: Optional[str] = None, put_inner_thoughts_first: Optional[bool] = True, use_tool_naming: bool = True, - actor_id: Optional[str] = None, + actor: Optional["User"] = None, ): - self.actor_id = actor_id - self.provider_name = provider_name + self.actor = actor self.put_inner_thoughts_first = put_inner_thoughts_first self.use_tool_naming = use_tool_naming diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index cf464b2c..c641f5e1 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -22,7 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST from letta.log import get_logger -from letta.schemas.enums import ProviderType +from letta.schemas.enums import ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import ChatCompletionRequest @@ -78,10 +78,10 @@ def supports_parallel_tool_calling(model: str) -> bool: class OpenAIClient(LLMClientBase): def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict: api_key = None - if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value: + if llm_config.provider_category == ProviderCategory.byok: from letta.services.provider_manager import ProviderManager - api_key = ProviderManager().get_override_key(llm_config.provider_name) + api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor) if not api_key: api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") @@ -156,11 +156,11 @@ class OpenAIClient(LLMClientBase): ) # always set user id for openai requests - if self.actor_id: - data.user = self.actor_id + if self.actor: + data.user = self.actor.id if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: - if not self.actor_id: + if not self.actor: # override user id for inference.letta.com import uuid diff --git a/letta/memory.py b/letta/memory.py index 100d3966..939e0874 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List +from typing import TYPE_CHECKING, Callable, Dict, List from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK from letta.llm_api.llm_api_tools import create @@ -13,6 +13,9 @@ from letta.settings import summarizer_settings from letta.tracing import trace_method from letta.utils import count_tokens, printd +if TYPE_CHECKING: + from letta.orm import User + def get_memory_functions(cls: Memory) -> Dict[str, Callable]: """Get memory functions for a memory class""" @@ -51,6 +54,7 @@ def _format_summary_history(message_history: List[Message]): def summarize_messages( agent_state: AgentState, message_sequence_to_summarize: List[Message], + actor: "User", ): """Summarize a message sequence using GPT""" # we need the context_window @@ -63,7 +67,7 @@ def summarize_messages( trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8 # For good measure... cutoff = int(len(message_sequence_to_summarize) * trunc_ratio) summary_input = str( - [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff])] + [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff], actor=actor)] + message_sequence_to_summarize[cutoff:] ) @@ -79,10 +83,9 @@ def summarize_messages( llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False llm_client = LLMClient.create( - provider_name=llm_config_no_inner_thoughts.provider_name, - provider_type=llm_config_no_inner_thoughts.model_endpoint_type, + provider_type=agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=False, - actor_id=agent_state.created_by_id, + actor=actor, ) # try to use new client, otherwise fallback to old flow # TODO: we can just directly call the LLM here? diff --git a/letta/orm/provider.py b/letta/orm/provider.py index d85e5ef2..803b4110 100644 --- a/letta/orm/provider.py +++ b/letta/orm/provider.py @@ -26,6 +26,7 @@ class Provider(SqlalchemyBase, OrganizationMixin): name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider") provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider") + provider_category: Mapped[str] = mapped_column(nullable=True, doc="The category of the provider (base or byok)") api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.") base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.") diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 6258e1e5..2a3de409 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -19,6 +19,11 @@ class ProviderType(str, Enum): bedrock = "bedrock" +class ProviderCategory(str, Enum): + base = "base" + byok = "byok" + + class MessageRole(str, Enum): assistant = "assistant" user = "user" diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 27795121..903d9a7e 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator from letta.constants import LETTA_MODEL_ENDPOINT from letta.log import get_logger +from letta.schemas.enums import ProviderCategory logger = get_logger(__name__) @@ -51,6 +52,7 @@ class LLMConfig(BaseModel): ] = Field(..., description="The endpoint type for the model.") model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.") provider_name: Optional[str] = Field(None, description="The provider name for the model.") + provider_category: Optional[ProviderCategory] = Field(None, description="The provider category for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") context_window: int = Field(..., description="The context window size for the model.") put_inner_thoughts_in_kwargs: Optional[bool] = Field( diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 0b9dc2b3..291271e3 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -9,7 +9,7 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_ from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES -from letta.schemas.enums import ProviderType +from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.letta_base import LettaBase from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES @@ -24,6 +24,7 @@ class Provider(ProviderBase): id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.") name: str = Field(..., description="The name of the provider") provider_type: ProviderType = Field(..., description="The type of the provider") + provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)") api_key: Optional[str] = Field(None, description="API key used for requests to the provider.") base_url: Optional[str] = Field(None, description="Base URL for the provider.") organization_id: Optional[str] = Field(None, description="The organization id of the user") @@ -113,6 +114,7 @@ class ProviderUpdate(ProviderBase): class LettaProvider(Provider): provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") def list_llm_models(self) -> List[LLMConfig]: return [ @@ -123,6 +125,7 @@ class LettaProvider(Provider): context_window=8192, handle=self.get_handle("letta-free"), provider_name=self.name, + provider_category=self.provider_category, ) ] @@ -141,6 +144,7 @@ class LettaProvider(Provider): class OpenAIProvider(Provider): provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") api_key: str = Field(..., description="API key for the OpenAI API.") base_url: str = Field(..., description="Base URL for the OpenAI API.") @@ -225,6 +229,7 @@ class OpenAIProvider(Provider): context_window=context_window_size, handle=self.get_handle(model_name), provider_name=self.name, + provider_category=self.provider_category, ) ) @@ -281,6 +286,7 @@ class DeepSeekProvider(OpenAIProvider): """ provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.") api_key: str = Field(..., description="API key for the DeepSeek API.") @@ -332,6 +338,7 @@ class DeepSeekProvider(OpenAIProvider): handle=self.get_handle(model_name), put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, provider_name=self.name, + provider_category=self.provider_category, ) ) @@ -344,6 +351,7 @@ class DeepSeekProvider(OpenAIProvider): class LMStudioOpenAIProvider(OpenAIProvider): provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.") api_key: Optional[str] = Field(None, description="API key for the LMStudio API.") @@ -470,6 +478,7 @@ class XAIProvider(OpenAIProvider): """https://docs.x.ai/docs/api-reference""" provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") api_key: str = Field(..., description="API key for the xAI/Grok API.") base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") @@ -523,6 +532,7 @@ class XAIProvider(OpenAIProvider): context_window=context_window_size, handle=self.get_handle(model_name), provider_name=self.name, + provider_category=self.provider_category, ) ) @@ -535,6 +545,7 @@ class XAIProvider(OpenAIProvider): class AnthropicProvider(Provider): provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") api_key: str = Field(..., description="API key for the Anthropic API.") base_url: str = "https://api.anthropic.com/v1" @@ -611,6 +622,7 @@ class AnthropicProvider(Provider): put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, max_tokens=max_tokens, provider_name=self.name, + provider_category=self.provider_category, ) ) return configs @@ -621,6 +633,7 @@ class AnthropicProvider(Provider): class MistralProvider(Provider): provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") api_key: str = Field(..., description="API key for the Mistral API.") base_url: str = "https://api.mistral.ai/v1" @@ -645,6 +658,7 @@ class MistralProvider(Provider): context_window=model["max_context_length"], handle=self.get_handle(model["id"]), provider_name=self.name, + provider_category=self.provider_category, ) ) @@ -672,6 +686,7 @@ class OllamaProvider(OpenAIProvider): """ provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field(..., description="Base URL for the Ollama API.") api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).") default_prompt_formatter: str = Field( @@ -702,6 +717,7 @@ class OllamaProvider(OpenAIProvider): context_window=context_window, handle=self.get_handle(model["name"]), provider_name=self.name, + provider_category=self.provider_category, ) ) return configs @@ -785,6 +801,7 @@ class OllamaProvider(OpenAIProvider): class GroqProvider(OpenAIProvider): provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = "https://api.groq.com/openai/v1" api_key: str = Field(..., description="API key for the Groq API.") @@ -804,6 +821,7 @@ class GroqProvider(OpenAIProvider): context_window=model["context_window"], handle=self.get_handle(model["id"]), provider_name=self.name, + provider_category=self.provider_category, ) ) return configs @@ -825,6 +843,7 @@ class TogetherProvider(OpenAIProvider): """ provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = "https://api.together.ai/v1" api_key: str = Field(..., description="API key for the TogetherAI API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") @@ -873,6 +892,7 @@ class TogetherProvider(OpenAIProvider): context_window=context_window_size, handle=self.get_handle(model_name), provider_name=self.name, + provider_category=self.provider_category, ) ) @@ -927,6 +947,7 @@ class TogetherProvider(OpenAIProvider): class GoogleAIProvider(Provider): # gemini provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") api_key: str = Field(..., description="API key for the Google AI API.") base_url: str = "https://generativelanguage.googleapis.com" @@ -955,6 +976,7 @@ class GoogleAIProvider(Provider): handle=self.get_handle(model), max_tokens=8192, provider_name=self.name, + provider_category=self.provider_category, ) ) return configs @@ -991,6 +1013,7 @@ class GoogleAIProvider(Provider): class GoogleVertexProvider(Provider): provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.") google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.") @@ -1008,6 +1031,7 @@ class GoogleVertexProvider(Provider): handle=self.get_handle(model), max_tokens=8192, provider_name=self.name, + provider_category=self.provider_category, ) ) return configs @@ -1032,6 +1056,7 @@ class GoogleVertexProvider(Provider): class AzureProvider(Provider): provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation base_url: str = Field( ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`." @@ -1065,6 +1090,7 @@ class AzureProvider(Provider): context_window=context_window_size, handle=self.get_handle(model_name), provider_name=self.name, + provider_category=self.provider_category, ), ) return configs @@ -1106,6 +1132,7 @@ class VLLMChatCompletionsProvider(Provider): # NOTE: vLLM only serves one model at a time (so could configure that through env variables) provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field(..., description="Base URL for the vLLM API.") def list_llm_models(self) -> List[LLMConfig]: @@ -1125,6 +1152,7 @@ class VLLMChatCompletionsProvider(Provider): context_window=model["max_model_len"], handle=self.get_handle(model["id"]), provider_name=self.name, + provider_category=self.provider_category, ) ) return configs @@ -1139,6 +1167,7 @@ class VLLMCompletionsProvider(Provider): # NOTE: vLLM only serves one model at a time (so could configure that through env variables) provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field(..., description="Base URL for the vLLM API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") @@ -1159,6 +1188,7 @@ class VLLMCompletionsProvider(Provider): context_window=model["max_model_len"], handle=self.get_handle(model["id"]), provider_name=self.name, + provider_category=self.provider_category, ) ) return configs @@ -1174,6 +1204,7 @@ class CohereProvider(OpenAIProvider): class AnthropicBedrockProvider(Provider): provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") aws_region: str = Field(..., description="AWS region for Bedrock") def list_llm_models(self): @@ -1192,6 +1223,7 @@ class AnthropicBedrockProvider(Provider): context_window=self.get_model_context_window(model_arn), handle=self.get_handle(model_arn), provider_name=self.name, + provider_category=self.provider_category, ) ) return configs diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py index f050cf7d..450f8608 100644 --- a/letta/server/rest_api/routers/v1/llms.py +++ b/letta/server/rest_api/routers/v1/llms.py @@ -1,8 +1,9 @@ from typing import TYPE_CHECKING, List, Optional -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, Header, Query from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.llm_config import LLMConfig from letta.server.rest_api.utils import get_letta_server @@ -14,12 +15,19 @@ router = APIRouter(prefix="/models", tags=["models", "llms"]) @router.get("/", response_model=List[LLMConfig], operation_id="list_models") def list_llm_models( - byok_only: Optional[bool] = Query(None), - default_only: Optional[bool] = Query(None), + provider_category: Optional[List[ProviderCategory]] = Query(None), + provider_name: Optional[str] = Query(None), + provider_type: Optional[ProviderType] = Query(None), server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - - models = server.list_llm_models(byok_only=byok_only, default_only=default_only) + actor = server.user_manager.get_user_or_default(user_id=actor_id) + models = server.list_llm_models( + provider_category=provider_category, + provider_name=provider_name, + provider_type=provider_type, + actor=actor, + ) # print(models) return models @@ -27,8 +35,9 @@ def list_llm_models( @router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models") def list_embedding_models( server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - - models = server.list_embedding_models() + actor = server.user_manager.get_user_or_default(user_id=actor_id) + models = server.list_embedding_models(actor=actor) # print(models) return models diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 97a76eb3..bc4b8086 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -95,6 +95,7 @@ def create_source( source_create.embedding_config = server.get_embedding_config_from_handle( handle=source_create.embedding, embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE, + actor=actor, ) source = Source( name=source_create.name, diff --git a/letta/server/server.py b/letta/server/server.py index d99fd74d..5bd8a4b9 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -42,7 +42,7 @@ from letta.schemas.block import Block, BlockUpdate, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig # openai schemas -from letta.schemas.enums import JobStatus, MessageStreamStatus +from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager from letta.schemas.job import Job, JobUpdate @@ -734,17 +734,17 @@ class SyncServer(Server): return self._command(user_id=user_id, agent_id=agent_id, command=command) @trace_method - def get_cached_llm_config(self, **kwargs): + def get_cached_llm_config(self, actor: User, **kwargs): key = make_key(**kwargs) if key not in self._llm_config_cache: - self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs) + self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs) return self._llm_config_cache[key] @trace_method - def get_cached_embedding_config(self, **kwargs): + def get_cached_embedding_config(self, actor: User, **kwargs): key = make_key(**kwargs) if key not in self._embedding_config_cache: - self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs) + self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs) return self._embedding_config_cache[key] @trace_method @@ -766,7 +766,7 @@ class SyncServer(Server): "enable_reasoner": request.enable_reasoner, } log_event(name="start get_cached_llm_config", attributes=config_params) - request.llm_config = self.get_cached_llm_config(**config_params) + request.llm_config = self.get_cached_llm_config(actor=actor, **config_params) log_event(name="end get_cached_llm_config", attributes=config_params) if request.embedding_config is None: @@ -777,7 +777,7 @@ class SyncServer(Server): "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE, } log_event(name="start get_cached_embedding_config", attributes=embedding_config_params) - request.embedding_config = self.get_cached_embedding_config(**embedding_config_params) + request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params) log_event(name="end get_cached_embedding_config", attributes=embedding_config_params) log_event(name="start create_agent db") @@ -802,10 +802,10 @@ class SyncServer(Server): actor: User, ) -> AgentState: if request.model is not None: - request.llm_config = self.get_llm_config_from_handle(handle=request.model) + request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor) if request.embedding is not None: - request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding) + request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor) if request.enable_sleeptime: agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) @@ -1201,10 +1201,21 @@ class SyncServer(Server): except NoResultFound: raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found") - def list_llm_models(self, byok_only: bool = False, default_only: bool = False) -> List[LLMConfig]: + def list_llm_models( + self, + actor: User, + provider_category: Optional[List[ProviderCategory]] = None, + provider_name: Optional[str] = None, + provider_type: Optional[ProviderType] = None, + ) -> List[LLMConfig]: """List available models""" llm_models = [] - for provider in self.get_enabled_providers(byok_only=byok_only, default_only=default_only): + for provider in self.get_enabled_providers( + provider_category=provider_category, + provider_name=provider_name, + provider_type=provider_type, + actor=actor, + ): try: llm_models.extend(provider.list_llm_models()) except Exception as e: @@ -1214,32 +1225,49 @@ class SyncServer(Server): return llm_models - def list_embedding_models(self) -> List[EmbeddingConfig]: + def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]: """List available embedding models""" embedding_models = [] - for provider in self.get_enabled_providers(): + for provider in self.get_enabled_providers(actor): try: embedding_models.extend(provider.list_embedding_models()) except Exception as e: warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") return embedding_models - def get_enabled_providers(self, byok_only: bool = False, default_only: bool = False): - providers_from_env = {p.name: p for p in self._enabled_providers} + def get_enabled_providers( + self, + actor: User, + provider_category: Optional[List[ProviderCategory]] = None, + provider_name: Optional[str] = None, + provider_type: Optional[ProviderType] = None, + ) -> List[Provider]: + providers = [] + if not provider_category or ProviderCategory.base in provider_category: + providers_from_env = [p for p in self._enabled_providers] + providers.extend(providers_from_env) - if default_only: - return list(providers_from_env.values()) + if not provider_category or ProviderCategory.byok in provider_category: + providers_from_db = self.provider_manager.list_providers( + name=provider_name, + provider_type=provider_type, + actor=actor, + ) + providers_from_db = [p.cast_to_subtype() for p in providers_from_db] + providers.extend(providers_from_db) - providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()} + if provider_name is not None: + providers = [p for p in providers if p.name == provider_name] - if byok_only: - return list(providers_from_db.values()) + if provider_type is not None: + providers = [p for p in providers if p.provider_type == provider_type] - return list(providers_from_env.values()) + list(providers_from_db.values()) + return providers @trace_method def get_llm_config_from_handle( self, + actor: User, handle: str, context_window_limit: Optional[int] = None, max_tokens: Optional[int] = None, @@ -1248,7 +1276,7 @@ class SyncServer(Server): ) -> LLMConfig: try: provider_name, model_name = handle.split("/", 1) - provider = self.get_provider_from_name(provider_name) + provider = self.get_provider_from_name(provider_name, actor) llm_configs = [config for config in provider.list_llm_models() if config.handle == handle] if not llm_configs: @@ -1292,11 +1320,11 @@ class SyncServer(Server): @trace_method def get_embedding_config_from_handle( - self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE + self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE ) -> EmbeddingConfig: try: provider_name, model_name = handle.split("/", 1) - provider = self.get_provider_from_name(provider_name) + provider = self.get_provider_from_name(provider_name, actor) embedding_configs = [config for config in provider.list_embedding_models() if config.handle == handle] if not embedding_configs: @@ -1319,8 +1347,8 @@ class SyncServer(Server): return embedding_config - def get_provider_from_name(self, provider_name: str) -> Provider: - providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name] + def get_provider_from_name(self, provider_name: str, actor: User) -> Provider: + providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name] if not providers: raise ValueError(f"Provider {provider_name} is not supported") elif len(providers) > 1: diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index d012171d..49ec99f4 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -1,9 +1,9 @@ from typing import List, Optional, Union from letta.orm.provider import Provider as ProviderModel -from letta.schemas.enums import ProviderType +from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.providers import Provider as PydanticProvider -from letta.schemas.providers import ProviderUpdate +from letta.schemas.providers import ProviderCreate, ProviderUpdate from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types @@ -16,9 +16,12 @@ class ProviderManager: self.session_maker = db_context @enforce_types - def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider: + def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider: """Create a new provider if it doesn't already exist.""" with self.session_maker() as session: + provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok} + provider = PydanticProvider(**provider_create_args) + if provider.name == provider.provider_type.value: raise ValueError("Provider name must be unique and different from provider type") @@ -65,11 +68,11 @@ class ProviderManager: @enforce_types def list_providers( self, + actor: PydanticUser, name: Optional[str] = None, provider_type: Optional[ProviderType] = None, after: Optional[str] = None, limit: Optional[int] = 50, - actor: PydanticUser = None, ) -> List[PydanticProvider]: """List all providers with optional pagination.""" filter_kwargs = {} @@ -88,11 +91,11 @@ class ProviderManager: return [provider.to_pydantic() for provider in providers] @enforce_types - def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]: - providers = self.list_providers(name=provider_name) + def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]: + providers = self.list_providers(name=provider_name, actor=actor) return providers[0].id if providers else None @enforce_types - def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]: - providers = self.list_providers(name=provider_name) + def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]: + providers = self.list_providers(name=provider_name, actor=actor) return providers[0].api_key if providers else None diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index b0cb2802..7774a752 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -105,9 +105,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) llm_client = LLMClient.create( - provider_name=agent_state.llm_config.provider_name, provider_type=agent_state.llm_config.model_endpoint_type, - actor_id=client.user.id, + actor=client.user, ) if llm_client: response = llm_client.send_llm_request( diff --git a/tests/test_server.py b/tests/test_server.py index 7d6d73e6..b6440c42 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -13,10 +13,10 @@ import letta.utils as utils from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR from letta.orm import Provider, Step from letta.schemas.block import CreateBlock -from letta.schemas.enums import MessageRole, ProviderType +from letta.schemas.enums import MessageRole, ProviderCategory, ProviderType from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.llm_config import LLMConfig -from letta.schemas.providers import Provider as PydanticProvider +from letta.schemas.providers import ProviderCreate from letta.schemas.sandbox_config import SandboxType from letta.schemas.user import User @@ -587,7 +587,7 @@ def test_read_local_llm_configs(server: SyncServer, user: User): # Call list_llm_models assert os.path.exists(configs_base_dir) - llm_models = server.list_llm_models() + llm_models = server.list_llm_models(actor=user) # Assert that the config is in the returned models assert any( @@ -1225,17 +1225,23 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to def test_messages_with_provider_override(server: SyncServer, user_id: str): actor = server.user_manager.get_user_or_default(user_id) provider = server.provider_manager.create_provider( - provider=PydanticProvider( + request=ProviderCreate( name="caren-anthropic", provider_type=ProviderType.anthropic, api_key=os.getenv("ANTHROPIC_API_KEY"), ), actor=actor, ) + models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok]) + assert provider.name in [model.provider_name for model in models] + + models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.base]) + assert provider.name not in [model.provider_name for model in models] + agent = server.create_agent( request=CreateAgent( memory_blocks=[], - model="caren-anthropic/claude-3-opus-20240229", + model="caren-anthropic/claude-3-5-sonnet-20240620", context_window_limit=100000, embedding="openai/text-embedding-ada-002", ), @@ -1295,11 +1301,11 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str): assert total_tokens == usage.total_tokens -def test_unique_handles_for_provider_configs(server: SyncServer): - models = server.list_llm_models() +def test_unique_handles_for_provider_configs(server: SyncServer, user: User): + models = server.list_llm_models(actor=user) model_handles = [model.handle for model in models] assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles" - embeddings = server.list_embedding_models() + embeddings = server.list_embedding_models(actor=user) embedding_handles = [embedding.handle for embedding in embeddings] assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"