feat: add provider_category field to distinguish byok (#2038)
alembic/versions/878607e41ca4_add_provider_category.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+"""add provider category
+
+Revision ID: 878607e41ca4
+Revises: 0335b1eb9c40
+Create Date: 2025-05-06 12:10:25.751536
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "878607e41ca4"
+down_revision: Union[str, None] = "0335b1eb9c40"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("providers", sa.Column("provider_category", sa.String(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("providers", "provider_category")
+    # ### end Alembic commands ###
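Annotation (not part of the diff): the column is added as nullable, so pre-existing provider rows keep NULL until application code populates them. If a backfill were ever wanted, a follow-up data migration could look like the sketch below; the 'byok' default is an assumption based on how create_provider stamps rows later in this diff.

def upgrade() -> None:
    # Hypothetical follow-up migration, not part of this commit: rows created
    # through the provider API before this change are treated as BYOK.
    op.execute("UPDATE providers SET provider_category = 'byok' WHERE provider_category IS NULL")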
@@ -331,10 +331,9 @@ class Agent(BaseAgent):
         log_telemetry(self.logger, "_get_ai_reply create start")
         # New LLM client flow
         llm_client = LLMClient.create(
-            provider_name=self.agent_state.llm_config.provider_name,
             provider_type=self.agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=put_inner_thoughts_first,
-            actor_id=self.user.id,
+            actor=self.user,
         )

         if llm_client and not stream:
@@ -943,7 +942,10 @@ class Agent(BaseAgent):
                     model_endpoint=self.agent_state.llm_config.model_endpoint,
                     context_window_limit=self.agent_state.llm_config.context_window,
                     usage=response.usage,
-                    provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name),
+                    provider_id=self.provider_manager.get_provider_id_from_name(
+                        self.agent_state.llm_config.provider_name,
+                        actor=self.user,
+                    ),
                     job_id=job_id,
                 )
                 for message in all_new_messages:
@@ -1087,7 +1089,9 @@ class Agent(BaseAgent):
             LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"]
         )

-        summary = summarize_messages(agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize)
+        summary = summarize_messages(
+            agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize, actor=self.user
+        )
         logger.info(f"Got summary: {summary}")

         # Metadata that's useful for the agent to see
@@ -75,10 +75,9 @@ class LettaAgent(BaseAgent):
         )
         tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
         llm_client = LLMClient.create(
-            provider_name=agent_state.llm_config.provider_name,
             provider_type=agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
-            actor_id=self.actor.id,
+            actor=self.actor,
         )
         for _ in range(max_steps):
             response = await self._get_ai_reply(
@@ -120,10 +119,9 @@ class LettaAgent(BaseAgent):
         )
         tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
         llm_client = LLMClient.create(
-            provider_name=agent_state.llm_config.provider_name,
             provider_type=agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
-            actor_id=self.actor.id,
+            actor=self.actor,
         )

         for _ in range(max_steps):
@@ -172,10 +172,9 @@ class LettaAgentBatch:

         log_event(name="init_llm_client")
         llm_client = LLMClient.create(
-            provider_name=agent_states[0].llm_config.provider_name,
             provider_type=agent_states[0].llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
-            actor_id=self.actor.id,
+            actor=self.actor,
         )
         agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}

@@ -284,10 +283,9 @@ class LettaAgentBatch:

         # translate provider-specific response → OpenAI-style tool call (unchanged)
         llm_client = LLMClient.create(
-            provider_name=item.llm_config.provider_name,
             provider_type=item.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
-            actor_id=self.actor.id,
+            actor=self.actor,
         )
         tool_call = (
             llm_client.convert_response_to_chat_completion(
@@ -3455,7 +3455,7 @@ class LocalClient(AbstractClient):
         Returns:
             configs (List[LLMConfig]): List of LLM configurations
         """
-        return self.server.list_llm_models()
+        return self.server.list_llm_models(actor=self.user)

     def list_embedding_configs(self) -> List[EmbeddingConfig]:
         """
@@ -3464,7 +3464,7 @@ class LocalClient(AbstractClient):
         Returns:
             configs (List[EmbeddingConfig]): List of embedding configurations
         """
-        return self.server.list_embedding_models()
+        return self.server.list_embedding_models(actor=self.user)

     def create_org(self, name: Optional[str] = None) -> Organization:
         return self.server.organization_manager.create_organization(pydantic_org=Organization(name=name))
@@ -26,7 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
 from letta.log import get_logger
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory
 from letta.schemas.message import Message as _Message
 from letta.schemas.message import MessageRole as _MessageRole
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
@@ -42,6 +42,7 @@ from letta.schemas.openai.chat_completion_response import Message
 from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
 from letta.schemas.openai.chat_completion_response import MessageDelta, ToolCall, ToolCallDelta, UsageStatistics
 from letta.services.provider_manager import ProviderManager
+from letta.services.user_manager import UserManager
 from letta.settings import model_settings
 from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
 from letta.tracing import log_event
@@ -744,12 +745,15 @@ def anthropic_chat_completions_request(
     extended_thinking: bool = False,
     max_reasoning_tokens: Optional[int] = None,
     provider_name: Optional[str] = None,
+    provider_category: Optional[ProviderCategory] = None,
     betas: List[str] = ["tools-2024-04-04"],
+    user_id: Optional[str] = None,
 ) -> ChatCompletionResponse:
     """https://docs.anthropic.com/claude/docs/tool-use"""
     anthropic_client = None
-    if provider_name and provider_name != ProviderType.anthropic.value:
-        api_key = ProviderManager().get_override_key(provider_name)
+    if provider_category == ProviderCategory.byok:
+        actor = UserManager().get_user_or_default(user_id=user_id)
+        api_key = ProviderManager().get_override_key(provider_name, actor=actor)
         anthropic_client = anthropic.Anthropic(api_key=api_key)
     elif model_settings.anthropic_api_key:
         anthropic_client = anthropic.Anthropic()
@@ -803,7 +807,9 @@ def anthropic_chat_completions_request_stream(
     extended_thinking: bool = False,
     max_reasoning_tokens: Optional[int] = None,
     provider_name: Optional[str] = None,
+    provider_category: Optional[ProviderCategory] = None,
     betas: List[str] = ["tools-2024-04-04"],
+    user_id: Optional[str] = None,
 ) -> Generator[ChatCompletionChunkResponse, None, None]:
     """Stream chat completions from Anthropic API.

@@ -817,8 +823,9 @@ def anthropic_chat_completions_request_stream(
         extended_thinking=extended_thinking,
         max_reasoning_tokens=max_reasoning_tokens,
     )
-    if provider_name and provider_name != ProviderType.anthropic.value:
-        api_key = ProviderManager().get_override_key(provider_name)
+    if provider_category == ProviderCategory.byok:
+        actor = UserManager().get_user_or_default(user_id=user_id)
+        api_key = ProviderManager().get_override_key(provider_name, actor=actor)
         anthropic_client = anthropic.Anthropic(api_key=api_key)
     elif model_settings.anthropic_api_key:
         anthropic_client = anthropic.Anthropic()
@@ -867,10 +874,12 @@ def anthropic_chat_completions_process_stream(
     extended_thinking: bool = False,
     max_reasoning_tokens: Optional[int] = None,
     provider_name: Optional[str] = None,
+    provider_category: Optional[ProviderCategory] = None,
     create_message_id: bool = True,
     create_message_datetime: bool = True,
     betas: List[str] = ["tools-2024-04-04"],
     name: Optional[str] = None,
+    user_id: Optional[str] = None,
 ) -> ChatCompletionResponse:
     """Process a streaming completion response from Anthropic, similar to OpenAI's streaming.

@@ -952,7 +961,9 @@ def anthropic_chat_completions_process_stream(
             extended_thinking=extended_thinking,
             max_reasoning_tokens=max_reasoning_tokens,
             provider_name=provider_name,
+            provider_category=provider_category,
             betas=betas,
+            user_id=user_id,
         )
     ):
         assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
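Annotation (not part of the diff): all three Anthropic entry points now share the same key-resolution order — a BYOK provider's stored key, resolved per actor, wins; otherwise the server falls back to the key from model settings. A minimal sketch of that logic as a hypothetical standalone helper (resolve_anthropic_api_key is invented here for illustration):

def resolve_anthropic_api_key(provider_name, provider_category, user_id):
    # Hypothetical helper, for illustration only.
    if provider_category == ProviderCategory.byok:
        # BYOK: look up the key stored for this provider, scoped to the actor
        actor = UserManager().get_user_or_default(user_id=user_id)
        return ProviderManager().get_override_key(provider_name, actor=actor)
    # base providers fall back to the server-configured key (may be None)
    return model_settings.anthropic_api_key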
@@ -27,7 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in
 from letta.llm_api.llm_client_base import LLMClientBase
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.log import get_logger
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import Tool
@@ -45,18 +45,18 @@ logger = get_logger(__name__)
 class AnthropicClient(LLMClientBase):

     def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
-        client = self._get_anthropic_client(async_client=False)
+        client = self._get_anthropic_client(llm_config, async_client=False)
         response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
         return response.model_dump()

     async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
-        client = self._get_anthropic_client(async_client=True)
+        client = self._get_anthropic_client(llm_config, async_client=True)
         response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
         return response.model_dump()

     @trace_method
     async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
-        client = self._get_anthropic_client(async_client=True)
+        client = self._get_anthropic_client(llm_config, async_client=True)
         request_data["stream"] = True
         return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])

@@ -96,7 +96,7 @@ class AnthropicClient(LLMClientBase):
             for agent_id in agent_messages_mapping
         }

-        client = self._get_anthropic_client(async_client=True)
+        client = self._get_anthropic_client(list(agent_llm_config_mapping.values())[0], async_client=True)

         anthropic_requests = [
             Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
@@ -112,10 +112,12 @@ class AnthropicClient(LLMClientBase):
             raise self.handle_llm_error(e)

     @trace_method
-    def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
+    def _get_anthropic_client(
+        self, llm_config: LLMConfig, async_client: bool = False
+    ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
         override_key = None
-        if self.provider_name and self.provider_name != ProviderType.anthropic.value:
-            override_key = ProviderManager().get_override_key(self.provider_name)
+        if llm_config.provider_category == ProviderCategory.byok:
+            override_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)

         if async_client:
             return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
@@ -13,7 +13,7 @@ from letta.llm_api.llm_client_base import LLMClientBase
 from letta.local_llm.json_parser import clean_json_string_extra_backslash
 from letta.local_llm.utils import count_tokens
 from letta.log import get_logger
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import Tool
@@ -31,10 +31,10 @@ class GoogleAIClient(LLMClientBase):
         Performs underlying request to llm and returns raw response.
         """
         api_key = None
-        if llm_config.provider_name and llm_config.provider_name != ProviderType.google_ai.value:
+        if llm_config.provider_category == ProviderCategory.byok:
             from letta.services.provider_manager import ProviderManager

-            api_key = ProviderManager().get_override_key(llm_config.provider_name)
+            api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)

         if not api_key:
             api_key = model_settings.gemini_api_key
@@ -24,7 +24,7 @@ from letta.llm_api.openai import (
 from letta.local_llm.chat_completion_proxy import get_chat_completion
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
@@ -172,10 +172,12 @@ def create(
         if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
             # only is a problem if we are *not* using an openai proxy
             raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
-        elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
+        elif llm_config.provider_category == ProviderCategory.byok:
             from letta.services.provider_manager import ProviderManager
+            from letta.services.user_manager import UserManager

-            api_key = ProviderManager().get_override_key(llm_config.provider_name)
+            actor = UserManager().get_user_or_default(user_id=user_id)
+            api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=actor)
         elif model_settings.openai_api_key is None:
             # the openai python client requires a dummy API key
             api_key = "DUMMY_API_KEY"
@@ -379,7 +381,9 @@ def create(
             extended_thinking=llm_config.enable_reasoner,
             max_reasoning_tokens=llm_config.max_reasoning_tokens,
             provider_name=llm_config.provider_name,
+            provider_category=llm_config.provider_category,
             name=name,
+            user_id=user_id,
         )

     else:
@@ -390,6 +394,8 @@ def create(
             extended_thinking=llm_config.enable_reasoner,
             max_reasoning_tokens=llm_config.max_reasoning_tokens,
             provider_name=llm_config.provider_name,
+            provider_category=llm_config.provider_category,
+            user_id=user_id,
         )

     if llm_config.put_inner_thoughts_in_kwargs:
@@ -1,8 +1,11 @@
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 from letta.llm_api.llm_client_base import LLMClientBase
 from letta.schemas.enums import ProviderType

+if TYPE_CHECKING:
+    from letta.orm import User
+

 class LLMClient:
     """Factory class for creating LLM clients based on the model endpoint type."""
@@ -10,9 +13,8 @@ class LLMClient:
     @staticmethod
     def create(
         provider_type: ProviderType,
-        provider_name: Optional[str] = None,
         put_inner_thoughts_first: bool = True,
-        actor_id: Optional[str] = None,
+        actor: Optional["User"] = None,
     ) -> Optional[LLMClientBase]:
         """
         Create an LLM client based on the model endpoint type.
@@ -32,33 +34,29 @@ class LLMClient:
                 from letta.llm_api.google_ai_client import GoogleAIClient

                 return GoogleAIClient(
-                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
-                    actor_id=actor_id,
+                    actor=actor,
                 )
             case ProviderType.google_vertex:
                 from letta.llm_api.google_vertex_client import GoogleVertexClient

                 return GoogleVertexClient(
-                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
-                    actor_id=actor_id,
+                    actor=actor,
                 )
             case ProviderType.anthropic:
                 from letta.llm_api.anthropic_client import AnthropicClient

                 return AnthropicClient(
-                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
-                    actor_id=actor_id,
+                    actor=actor,
                 )
             case ProviderType.openai:
                 from letta.llm_api.openai_client import OpenAIClient

                 return OpenAIClient(
-                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
-                    actor_id=actor_id,
+                    actor=actor,
                 )
             case _:
                 return None
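Annotation (not part of the diff): call sites no longer pass provider_name or actor_id; the User object itself travels into the client so downstream BYOK key lookups can be scoped to it. A minimal usage sketch, assuming agent_state and user are in scope:

llm_client = LLMClient.create(
    provider_type=agent_state.llm_config.model_endpoint_type,  # e.g. ProviderType.anthropic
    put_inner_thoughts_first=True,
    actor=user,  # replaces the old provider_name/actor_id pair
)
if llm_client is None:
    # unsupported endpoint type: callers fall back to the legacy create() flow
    ...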
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 from anthropic.types.beta.messages import BetaMessageBatch
 from openai import AsyncStream, Stream
@@ -11,6 +11,9 @@ from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.tracing import log_event

+if TYPE_CHECKING:
+    from letta.orm import User
+

 class LLMClientBase:
     """
@@ -20,13 +23,11 @@ class LLMClientBase:

     def __init__(
         self,
-        provider_name: Optional[str] = None,
         put_inner_thoughts_first: Optional[bool] = True,
         use_tool_naming: bool = True,
-        actor_id: Optional[str] = None,
+        actor: Optional["User"] = None,
     ):
-        self.actor_id = actor_id
-        self.provider_name = provider_name
+        self.actor = actor
         self.put_inner_thoughts_first = put_inner_thoughts_first
         self.use_tool_naming = use_tool_naming
@@ -22,7 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
 from letta.llm_api.llm_client_base import LLMClientBase
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
 from letta.log import get_logger
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@@ -78,10 +78,10 @@ def supports_parallel_tool_calling(model: str) -> bool:
 class OpenAIClient(LLMClientBase):
     def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
         api_key = None
-        if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
+        if llm_config.provider_category == ProviderCategory.byok:
             from letta.services.provider_manager import ProviderManager

-            api_key = ProviderManager().get_override_key(llm_config.provider_name)
+            api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)

         if not api_key:
             api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@@ -156,11 +156,11 @@ class OpenAIClient(LLMClientBase):
         )

         # always set user id for openai requests
-        if self.actor_id:
-            data.user = self.actor_id
+        if self.actor:
+            data.user = self.actor.id

         if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
-            if not self.actor_id:
+            if not self.actor:
                 # override user id for inference.letta.com
                 import uuid
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List
+from typing import TYPE_CHECKING, Callable, Dict, List

 from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK
 from letta.llm_api.llm_api_tools import create
@@ -13,6 +13,9 @@ from letta.settings import summarizer_settings
 from letta.tracing import trace_method
 from letta.utils import count_tokens, printd

+if TYPE_CHECKING:
+    from letta.orm import User
+

 def get_memory_functions(cls: Memory) -> Dict[str, Callable]:
     """Get memory functions for a memory class"""
@@ -51,6 +54,7 @@ def _format_summary_history(message_history: List[Message]):
 def summarize_messages(
     agent_state: AgentState,
     message_sequence_to_summarize: List[Message],
+    actor: "User",
 ):
     """Summarize a message sequence using GPT"""
     # we need the context_window
@@ -63,7 +67,7 @@
         trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8  # For good measure...
         cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
         summary_input = str(
-            [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff])]
+            [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff], actor=actor)]
             + message_sequence_to_summarize[cutoff:]
         )
@@ -79,10 +83,9 @@ def summarize_messages(
     llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False

     llm_client = LLMClient.create(
-        provider_name=llm_config_no_inner_thoughts.provider_name,
-        provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
+        provider_type=agent_state.llm_config.model_endpoint_type,
         put_inner_thoughts_first=False,
-        actor_id=agent_state.created_by_id,
+        actor=actor,
     )
     # try to use new client, otherwise fallback to old flow
     # TODO: we can just directly call the LLM here?
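Annotation (not part of the diff): summarize_messages now requires the acting user, which it threads into LLMClient.create and its own recursive calls. A minimal call sketch, assuming agent_state, messages, and user are in scope (the import path is an assumption):

from letta.memory import summarize_messages  # import path assumed

summary = summarize_messages(
    agent_state=agent_state,
    message_sequence_to_summarize=messages,
    actor=user,  # new required argument
)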
@@ -26,6 +26,7 @@ class Provider(SqlalchemyBase, OrganizationMixin):

     name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
     provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
+    provider_category: Mapped[str] = mapped_column(nullable=True, doc="The category of the provider (base or byok)")
     api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
     base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")
@@ -19,6 +19,11 @@ class ProviderType(str, Enum):
     bedrock = "bedrock"


+class ProviderCategory(str, Enum):
+    base = "base"
+    byok = "byok"
+
+
 class MessageRole(str, Enum):
     assistant = "assistant"
     user = "user"
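Annotation (not part of the diff): the enum is deliberately two-valued — "base" for server-configured providers, "byok" for user-created ones. A small sketch of how it rides on an LLMConfig (the field is added later in this diff); the concrete values here are invented for illustration:

from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig

config = LLMConfig(
    model="claude-3-5-sonnet-20240620",        # example value
    model_endpoint_type="anthropic",
    model_endpoint="https://api.anthropic.com/v1",
    context_window=200000,
    provider_name="caren-anthropic",           # a BYOK provider's unique name
    provider_category=ProviderCategory.byok,   # routes key lookup through ProviderManager
)
assert config.provider_category == ProviderCategory.byok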
@@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator

 from letta.constants import LETTA_MODEL_ENDPOINT
 from letta.log import get_logger
+from letta.schemas.enums import ProviderCategory

 logger = get_logger(__name__)

@@ -51,6 +52,7 @@ class LLMConfig(BaseModel):
     ] = Field(..., description="The endpoint type for the model.")
     model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
     provider_name: Optional[str] = Field(None, description="The provider name for the model.")
+    provider_category: Optional[ProviderCategory] = Field(None, description="The provider category for the model.")
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
     context_window: int = Field(..., description="The context window size for the model.")
     put_inner_thoughts_in_kwargs: Optional[bool] = Field(
@@ -9,7 +9,7 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
 from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.letta_base import LettaBase
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
@@ -24,6 +24,7 @@ class Provider(ProviderBase):
     id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
     name: str = Field(..., description="The name of the provider")
     provider_type: ProviderType = Field(..., description="The type of the provider")
+    provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
     api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
     base_url: Optional[str] = Field(None, description="Base URL for the provider.")
     organization_id: Optional[str] = Field(None, description="The organization id of the user")
@@ -113,6 +114,7 @@ class ProviderUpdate(ProviderBase):

 class LettaProvider(Provider):
     provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")

     def list_llm_models(self) -> List[LLMConfig]:
         return [
@@ -123,6 +125,7 @@ class LettaProvider(Provider):
                 context_window=8192,
                 handle=self.get_handle("letta-free"),
                 provider_name=self.name,
+                provider_category=self.provider_category,
             )
         ]

@@ -141,6 +144,7 @@ class LettaProvider(Provider):

 class OpenAIProvider(Provider):
     provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the OpenAI API.")
     base_url: str = Field(..., description="Base URL for the OpenAI API.")

@@ -225,6 +229,7 @@ class OpenAIProvider(Provider):
                         context_window=context_window_size,
                         handle=self.get_handle(model_name),
                         provider_name=self.name,
+                        provider_category=self.provider_category,
                     )
                 )
@@ -281,6 +286,7 @@ class DeepSeekProvider(OpenAIProvider):
     """

     provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
     api_key: str = Field(..., description="API key for the DeepSeek API.")

@@ -332,6 +338,7 @@ class DeepSeekProvider(OpenAIProvider):
                     handle=self.get_handle(model_name),
                     put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )

@@ -344,6 +351,7 @@ class DeepSeekProvider(OpenAIProvider):

 class LMStudioOpenAIProvider(OpenAIProvider):
     provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
     api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")

@@ -470,6 +478,7 @@ class XAIProvider(OpenAIProvider):
     """https://docs.x.ai/docs/api-reference"""

     provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the xAI/Grok API.")
     base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")

@@ -523,6 +532,7 @@ class XAIProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )

@@ -535,6 +545,7 @@ class XAIProvider(OpenAIProvider):

 class AnthropicProvider(Provider):
     provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Anthropic API.")
     base_url: str = "https://api.anthropic.com/v1"

@@ -611,6 +622,7 @@ class AnthropicProvider(Provider):
                     put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                     max_tokens=max_tokens,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -621,6 +633,7 @@ class AnthropicProvider(Provider):

 class MistralProvider(Provider):
     provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Mistral API.")
     base_url: str = "https://api.mistral.ai/v1"

@@ -645,6 +658,7 @@ class MistralProvider(Provider):
                     context_window=model["max_context_length"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )

@@ -672,6 +686,7 @@ class OllamaProvider(OpenAIProvider):
     """

     provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the Ollama API.")
     api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
     default_prompt_formatter: str = Field(
@@ -702,6 +717,7 @@ class OllamaProvider(OpenAIProvider):
                     context_window=context_window,
                     handle=self.get_handle(model["name"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -785,6 +801,7 @@ class OllamaProvider(OpenAIProvider):

 class GroqProvider(OpenAIProvider):
     provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = "https://api.groq.com/openai/v1"
     api_key: str = Field(..., description="API key for the Groq API.")

@@ -804,6 +821,7 @@ class GroqProvider(OpenAIProvider):
                     context_window=model["context_window"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs

@@ -825,6 +843,7 @@ class TogetherProvider(OpenAIProvider):
     """

     provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = "https://api.together.ai/v1"
     api_key: str = Field(..., description="API key for the TogetherAI API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")

@@ -873,6 +892,7 @@ class TogetherProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )

@@ -927,6 +947,7 @@ class TogetherProvider(OpenAIProvider):
 class GoogleAIProvider(Provider):
     # gemini
     provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Google AI API.")
     base_url: str = "https://generativelanguage.googleapis.com"

@@ -955,6 +976,7 @@ class GoogleAIProvider(Provider):
                     handle=self.get_handle(model),
                     max_tokens=8192,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs

@@ -991,6 +1013,7 @@ class GoogleAIProvider(Provider):

 class GoogleVertexProvider(Provider):
     provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
     google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")

@@ -1008,6 +1031,7 @@ class GoogleVertexProvider(Provider):
                     handle=self.get_handle(model),
                     max_tokens=8192,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs

@@ -1032,6 +1056,7 @@ class GoogleVertexProvider(Provider):

 class AzureProvider(Provider):
     provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
     base_url: str = Field(
         ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@@ -1065,6 +1090,7 @@ class AzureProvider(Provider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 ),
             )
         return configs

@@ -1106,6 +1132,7 @@ class VLLMChatCompletionsProvider(Provider):

     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
     provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the vLLM API.")

     def list_llm_models(self) -> List[LLMConfig]:
@@ -1125,6 +1152,7 @@ class VLLMChatCompletionsProvider(Provider):
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs

@@ -1139,6 +1167,7 @@ class VLLMCompletionsProvider(Provider):

     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
     provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")

@@ -1159,6 +1188,7 @@ class VLLMCompletionsProvider(Provider):
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs

@@ -1174,6 +1204,7 @@ class CohereProvider(OpenAIProvider):

 class AnthropicBedrockProvider(Provider):
     provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     aws_region: str = Field(..., description="AWS region for Bedrock")

     def list_llm_models(self):
@@ -1192,6 +1223,7 @@ class AnthropicBedrockProvider(Provider):
                     context_window=self.get_model_context_window(model_arn),
                     handle=self.get_handle(model_arn),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1,8 +1,9 @@
 from typing import TYPE_CHECKING, List, Optional

-from fastapi import APIRouter, Depends, Query
+from fastapi import APIRouter, Depends, Header, Query

 from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.server.rest_api.utils import get_letta_server
@@ -14,12 +15,19 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])

 @router.get("/", response_model=List[LLMConfig], operation_id="list_models")
 def list_llm_models(
-    byok_only: Optional[bool] = Query(None),
-    default_only: Optional[bool] = Query(None),
+    provider_category: Optional[List[ProviderCategory]] = Query(None),
+    provider_name: Optional[str] = Query(None),
+    provider_type: Optional[ProviderType] = Query(None),
     server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):

-    models = server.list_llm_models(byok_only=byok_only, default_only=default_only)
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    models = server.list_llm_models(
+        provider_category=provider_category,
+        provider_name=provider_name,
+        provider_type=provider_type,
+        actor=actor,
+    )
     # print(models)
     return models

@@ -27,8 +35,9 @@ def list_llm_models(
 @router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
 def list_embedding_models(
     server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):

-    models = server.list_embedding_models()
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    models = server.list_embedding_models(actor=actor)
     # print(models)
     return models
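Annotation (not part of the diff): the old byok_only/default_only flags are replaced by explicit category/name/type filters, and the actor is resolved from the user_id header. A client-side sketch of the new query surface; the base URL, port, and header value are assumptions:

import requests

resp = requests.get(
    "http://localhost:8283/v1/models/",  # assumed deployment URL
    params={"provider_category": "byok", "provider_name": "caren-anthropic"},
    headers={"user_id": "user-123"},     # actor is resolved server-side from this header
)
for model in resp.json():
    print(model["handle"], model.get("provider_category"))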
@@ -95,6 +95,7 @@ def create_source(
     source_create.embedding_config = server.get_embedding_config_from_handle(
         handle=source_create.embedding,
         embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
+        actor=actor,
     )
     source = Source(
         name=source_create.name,
@@ -42,7 +42,7 @@ from letta.schemas.block import Block, BlockUpdate, CreateBlock
 from letta.schemas.embedding_config import EmbeddingConfig

 # openai schemas
-from letta.schemas.enums import JobStatus, MessageStreamStatus
+from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType
 from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
 from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager
 from letta.schemas.job import Job, JobUpdate
@@ -734,17 +734,17 @@ class SyncServer(Server):
         return self._command(user_id=user_id, agent_id=agent_id, command=command)

     @trace_method
-    def get_cached_llm_config(self, **kwargs):
+    def get_cached_llm_config(self, actor: User, **kwargs):
         key = make_key(**kwargs)
         if key not in self._llm_config_cache:
-            self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
+            self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
         return self._llm_config_cache[key]

     @trace_method
-    def get_cached_embedding_config(self, **kwargs):
+    def get_cached_embedding_config(self, actor: User, **kwargs):
         key = make_key(**kwargs)
         if key not in self._embedding_config_cache:
-            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs)
+            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
         return self._embedding_config_cache[key]

     @trace_method
@@ -766,7 +766,7 @@ class SyncServer(Server):
                 "enable_reasoner": request.enable_reasoner,
             }
             log_event(name="start get_cached_llm_config", attributes=config_params)
-            request.llm_config = self.get_cached_llm_config(**config_params)
+            request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
             log_event(name="end get_cached_llm_config", attributes=config_params)

         if request.embedding_config is None:
@@ -777,7 +777,7 @@ class SyncServer(Server):
                 "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
             }
             log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
-            request.embedding_config = self.get_cached_embedding_config(**embedding_config_params)
+            request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
             log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)

             log_event(name="start create_agent db")
@@ -802,10 +802,10 @@ class SyncServer(Server):
         actor: User,
     ) -> AgentState:
         if request.model is not None:
-            request.llm_config = self.get_llm_config_from_handle(handle=request.model)
+            request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)

         if request.embedding is not None:
-            request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding)
+            request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)

         if request.enable_sleeptime:
             agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -1201,10 +1201,21 @@ class SyncServer(Server):
         except NoResultFound:
             raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")

-    def list_llm_models(self, byok_only: bool = False, default_only: bool = False) -> List[LLMConfig]:
+    def list_llm_models(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[LLMConfig]:
         """List available models"""
         llm_models = []
-        for provider in self.get_enabled_providers(byok_only=byok_only, default_only=default_only):
+        for provider in self.get_enabled_providers(
+            provider_category=provider_category,
+            provider_name=provider_name,
+            provider_type=provider_type,
+            actor=actor,
+        ):
             try:
                 llm_models.extend(provider.list_llm_models())
             except Exception as e:
@@ -1214,32 +1225,49 @@
         return llm_models

-    def list_embedding_models(self) -> List[EmbeddingConfig]:
+    def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
         """List available embedding models"""
         embedding_models = []
-        for provider in self.get_enabled_providers():
+        for provider in self.get_enabled_providers(actor):
             try:
                 embedding_models.extend(provider.list_embedding_models())
             except Exception as e:
                 warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models

-    def get_enabled_providers(self, byok_only: bool = False, default_only: bool = False):
-        providers_from_env = {p.name: p for p in self._enabled_providers}
-
-        if default_only:
-            return list(providers_from_env.values())
-
-        providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
-
-        if byok_only:
-            return list(providers_from_db.values())
-
-        return list(providers_from_env.values()) + list(providers_from_db.values())
+    def get_enabled_providers(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[Provider]:
+        providers = []
+        if not provider_category or ProviderCategory.base in provider_category:
+            providers_from_env = [p for p in self._enabled_providers]
+            providers.extend(providers_from_env)
+
+        if not provider_category or ProviderCategory.byok in provider_category:
+            providers_from_db = self.provider_manager.list_providers(
+                name=provider_name,
+                provider_type=provider_type,
+                actor=actor,
+            )
+            providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
+            providers.extend(providers_from_db)
+
+        if provider_name is not None:
+            providers = [p for p in providers if p.name == provider_name]
+
+        if provider_type is not None:
+            providers = [p for p in providers if p.provider_type == provider_type]
+
+        return providers

     @trace_method
     def get_llm_config_from_handle(
         self,
+        actor: User,
         handle: str,
         context_window_limit: Optional[int] = None,
         max_tokens: Optional[int] = None,
@@ -1248,7 +1276,7 @@ class SyncServer(Server):
     ) -> LLMConfig:
         try:
             provider_name, model_name = handle.split("/", 1)
-            provider = self.get_provider_from_name(provider_name)
+            provider = self.get_provider_from_name(provider_name, actor)

             llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
             if not llm_configs:
@@ -1292,11 +1320,11 @@ class SyncServer(Server):

     @trace_method
     def get_embedding_config_from_handle(
-        self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+        self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
     ) -> EmbeddingConfig:
         try:
             provider_name, model_name = handle.split("/", 1)
-            provider = self.get_provider_from_name(provider_name)
+            provider = self.get_provider_from_name(provider_name, actor)

             embedding_configs = [config for config in provider.list_embedding_models() if config.handle == handle]
             if not embedding_configs:
@@ -1319,8 +1347,8 @@ class SyncServer(Server):

         return embedding_config

-    def get_provider_from_name(self, provider_name: str) -> Provider:
-        providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
+    def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
+        providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
         if not providers:
             raise ValueError(f"Provider {provider_name} is not supported")
         elif len(providers) > 1:
@@ -1,9 +1,9 @@
 from typing import List, Optional, Union

 from letta.orm.provider import Provider as ProviderModel
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.providers import Provider as PydanticProvider
-from letta.schemas.providers import ProviderUpdate
+from letta.schemas.providers import ProviderCreate, ProviderUpdate
 from letta.schemas.user import User as PydanticUser
 from letta.utils import enforce_types

@@ -16,9 +16,12 @@ class ProviderManager:
         self.session_maker = db_context

     @enforce_types
-    def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider:
+    def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
         """Create a new provider if it doesn't already exist."""
         with self.session_maker() as session:
+            provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok}
+            provider = PydanticProvider(**provider_create_args)
+
             if provider.name == provider.provider_type.value:
                 raise ValueError("Provider name must be unique and different from provider type")

@@ -65,11 +68,11 @@ class ProviderManager:
     @enforce_types
     def list_providers(
         self,
+        actor: PydanticUser,
         name: Optional[str] = None,
         provider_type: Optional[ProviderType] = None,
         after: Optional[str] = None,
         limit: Optional[int] = 50,
-        actor: PydanticUser = None,
     ) -> List[PydanticProvider]:
         """List all providers with optional pagination."""
         filter_kwargs = {}
@@ -88,11 +91,11 @@ class ProviderManager:
         return [provider.to_pydantic() for provider in providers]

     @enforce_types
-    def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
-        providers = self.list_providers(name=provider_name)
+    def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
+        providers = self.list_providers(name=provider_name, actor=actor)
         return providers[0].id if providers else None

     @enforce_types
-    def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
-        providers = self.list_providers(name=provider_name)
+    def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
+        providers = self.list_providers(name=provider_name, actor=actor)
         return providers[0].api_key if providers else None
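Annotation (not part of the diff): create_provider now stamps every API-created provider as byok, and every read path is actor-scoped. A usage sketch, assuming a ProviderManager instance and an actor are in scope; the name and key values are placeholders:

from letta.schemas.providers import ProviderCreate

provider = manager.create_provider(
    request=ProviderCreate(
        name="caren-anthropic",        # must differ from the provider type name
        provider_type=ProviderType.anthropic,
        api_key="sk-ant-...",          # placeholder key
    ),
    actor=actor,
)
# Later lookups resolve the stored key per actor; None means no BYOK match,
# and callers fall back to server-configured keys.
key = manager.get_override_key("caren-anthropic", actor=actor)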
@@ -105,9 +105,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
     agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)

     llm_client = LLMClient.create(
-        provider_name=agent_state.llm_config.provider_name,
         provider_type=agent_state.llm_config.model_endpoint_type,
-        actor_id=client.user.id,
+        actor=client.user,
     )
     if llm_client:
         response = llm_client.send_llm_request(
@@ -13,10 +13,10 @@ import letta.utils as utils
 from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
 from letta.orm import Provider, Step
 from letta.schemas.block import CreateBlock
-from letta.schemas.enums import MessageRole, ProviderType
+from letta.schemas.enums import MessageRole, ProviderCategory, ProviderType
 from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
 from letta.schemas.llm_config import LLMConfig
-from letta.schemas.providers import Provider as PydanticProvider
+from letta.schemas.providers import ProviderCreate
 from letta.schemas.sandbox_config import SandboxType
 from letta.schemas.user import User
@@ -587,7 +587,7 @@ def test_read_local_llm_configs(server: SyncServer, user: User):

     # Call list_llm_models
     assert os.path.exists(configs_base_dir)
-    llm_models = server.list_llm_models()
+    llm_models = server.list_llm_models(actor=user)

     # Assert that the config is in the returned models
     assert any(
@@ -1225,17 +1225,23 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
 def test_messages_with_provider_override(server: SyncServer, user_id: str):
     actor = server.user_manager.get_user_or_default(user_id)
     provider = server.provider_manager.create_provider(
-        provider=PydanticProvider(
+        request=ProviderCreate(
             name="caren-anthropic",
             provider_type=ProviderType.anthropic,
             api_key=os.getenv("ANTHROPIC_API_KEY"),
         ),
         actor=actor,
     )
+    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok])
+    assert provider.name in [model.provider_name for model in models]
+
+    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.base])
+    assert provider.name not in [model.provider_name for model in models]
+
     agent = server.create_agent(
         request=CreateAgent(
             memory_blocks=[],
-            model="caren-anthropic/claude-3-opus-20240229",
+            model="caren-anthropic/claude-3-5-sonnet-20240620",
             context_window_limit=100000,
             embedding="openai/text-embedding-ada-002",
         ),
@@ -1295,11 +1301,11 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
     assert total_tokens == usage.total_tokens


-def test_unique_handles_for_provider_configs(server: SyncServer):
-    models = server.list_llm_models()
+def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
+    models = server.list_llm_models(actor=user)
     model_handles = [model.handle for model in models]
     assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
-    embeddings = server.list_embedding_models()
+    embeddings = server.list_embedding_models(actor=user)
     embedding_handles = [embedding.handle for embedding in embeddings]
     assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"