diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py index 48556382..c98c2a11 100644 --- a/letta/server/rest_api/routers/v1/llms.py +++ b/letta/server/rest_api/routers/v1/llms.py @@ -23,7 +23,7 @@ async def list_llm_models( # Extract user_id from header, default to None if not present ): """List available LLM models using the asynchronous implementation for improved performance""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) models = await server.list_llm_models_async( provider_category=provider_category, @@ -42,7 +42,7 @@ async def list_embedding_models( # Extract user_id from header, default to None if not present ): """List available embedding models using the asynchronous implementation for improved performance""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) models = await server.list_embedding_models_async(actor=actor) return models diff --git a/letta/server/server.py b/letta/server/server.py index 02419ee9..cac12d69 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1381,7 +1381,7 @@ class SyncServer(Server): """Asynchronously list available models with maximum concurrency""" import asyncio - providers = self.get_enabled_providers( + providers = await self.get_enabled_providers_async( provider_category=provider_category, provider_name=provider_name, provider_type=provider_type, @@ -1427,7 +1427,7 @@ class SyncServer(Server): import asyncio # Get all eligible providers first - providers = self.get_enabled_providers(actor=actor) + providers = await self.get_enabled_providers_async(actor=actor) # Fetch embedding models from each provider concurrently async def get_provider_embedding_models(provider): @@ -1480,6 +1480,35 @@ class SyncServer(Server): return providers + async def get_enabled_providers_async( + self, + actor: User, + provider_category: Optional[List[ProviderCategory]] = None, + provider_name: Optional[str] = None, + provider_type: Optional[ProviderType] = None, + ) -> List[Provider]: + providers = [] + if not provider_category or ProviderCategory.base in provider_category: + providers_from_env = [p for p in self._enabled_providers] + providers.extend(providers_from_env) + + if not provider_category or ProviderCategory.byok in provider_category: + providers_from_db = await self.provider_manager.list_providers_async( + name=provider_name, + provider_type=provider_type, + actor=actor, + ) + providers_from_db = [p.cast_to_subtype() for p in providers_from_db] + providers.extend(providers_from_db) + + if provider_name is not None: + providers = [p for p in providers if p.name == provider_name] + + if provider_type is not None: + providers = [p for p in providers if p.provider_type == provider_type] + + return providers + @trace_method def get_llm_config_from_handle( self, diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 1b8bb4b8..6b2bab01 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -91,6 +91,32 @@ class ProviderManager: ) return [provider.to_pydantic() for provider in providers] + @enforce_types + @trace_method + async def list_providers_async( + self, + actor: PydanticUser, + name: Optional[str] = None, + provider_type: Optional[ProviderType] = None, + after: Optional[str] = None, + limit: Optional[int] = 50, + ) -> List[PydanticProvider]: + """List all providers with optional pagination.""" + filter_kwargs = {} + if name: + filter_kwargs["name"] = name + if provider_type: + filter_kwargs["provider_type"] = provider_type + async with db_registry.async_session() as session: + providers = await ProviderModel.list_async( + db_session=session, + after=after, + limit=limit, + actor=actor, + **filter_kwargs, + ) + return [provider.to_pydantic() for provider in providers] + @enforce_types @trace_method def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]: diff --git a/letta/services/telemetry_manager.py b/letta/services/telemetry_manager.py index 5bb7c0d2..a57474b1 100644 --- a/letta/services/telemetry_manager.py +++ b/letta/services/telemetry_manager.py @@ -6,13 +6,11 @@ from letta.schemas.step import Step as PydanticStep from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.helpers.noop_helper import singleton -from letta.tracing import trace_method from letta.utils import enforce_types class TelemetryManager: @enforce_types - @trace_method async def get_provider_trace_by_step_id_async( self, step_id: str, @@ -23,7 +21,6 @@ class TelemetryManager: return provider_trace.to_pydantic() @enforce_types - @trace_method async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: async with db_registry.async_session() as session: provider_trace = ProviderTraceModel(**provider_trace_create.model_dump()) @@ -38,7 +35,6 @@ class TelemetryManager: return provider_trace.to_pydantic() @enforce_types - @trace_method def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: with db_registry.session() as session: provider_trace = ProviderTraceModel(**provider_trace_create.model_dump()) @@ -52,14 +48,11 @@ class NoopTelemetryManager(TelemetryManager): Noop implementation of TelemetryManager. """ - @trace_method async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: return - @trace_method async def get_provider_trace_by_step_id_async(self, step_id: str, actor: PydanticUser) -> PydanticStep: return - @trace_method def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: return diff --git a/tests/test_server.py b/tests/test_server.py index 519d95d6..cc1c6b65 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,3 +1,4 @@ +import asyncio import json import os import shutil @@ -359,6 +360,14 @@ def other_agent_id(server, user_id, base_tools): server.agent_manager.delete_agent(agent_state.id, actor=actor) +@pytest.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + def test_error_on_nonexistent_agent(server, user, agent_id): try: fake_agent_id = str(uuid.uuid4()) @@ -527,7 +536,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User): @pytest.mark.asyncio -async def test_read_local_llm_configs(server: SyncServer, user: User): +async def test_read_local_llm_configs(server: SyncServer, user: User, event_loop): configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs") clean_up_dir = False if not os.path.exists(configs_base_dir): @@ -564,7 +573,7 @@ async def test_read_local_llm_configs(server: SyncServer, user: User): # Try to use in agent creation context_window_override = 4000 - agent = server.create_agent( + agent = await server.create_agent_async( request=CreateAgent( model="caren/my-custom-model", context_window_limit=context_window_override, @@ -1061,8 +1070,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to @pytest.mark.asyncio -async def test_messages_with_provider_override(server: SyncServer, user_id: str): - actor = server.user_manager.get_user_or_default(user_id) +async def test_messages_with_provider_override(server: SyncServer, user_id: str, event_loop): + actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id) provider = server.provider_manager.create_provider( request=ProviderCreate( name="caren-anthropic", @@ -1077,7 +1086,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str) models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base]) assert provider.name not in [model.provider_name for model in models] - agent = server.create_agent( + agent = await server.create_agent_async( request=CreateAgent( memory_blocks=[], model="caren-anthropic/claude-3-5-sonnet-20240620", @@ -1141,7 +1150,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str) @pytest.mark.asyncio -async def test_unique_handles_for_provider_configs(server: SyncServer, user: User): +async def test_unique_handles_for_provider_configs(server: SyncServer, user: User, event_loop): models = await server.list_llm_models_async(actor=user) model_handles = [model.handle for model in models] assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"