diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py
index 450f8608..48556382 100644
--- a/letta/server/rest_api/routers/v1/llms.py
+++ b/letta/server/rest_api/routers/v1/llms.py
@@ -14,30 +14,35 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
 
 
 @router.get("/", response_model=List[LLMConfig], operation_id="list_models")
-def list_llm_models(
+async def list_llm_models(
     provider_category: Optional[List[ProviderCategory]] = Query(None),
     provider_name: Optional[str] = Query(None),
     provider_type: Optional[ProviderType] = Query(None),
     server: "SyncServer" = Depends(get_letta_server),
-    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+    # Extract user_id from header, default to None if not present
 ):
+    """List available LLM models using the asynchronous implementation for improved performance"""
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
-    models = server.list_llm_models(
+
+    models = await server.list_llm_models_async(
         provider_category=provider_category,
         provider_name=provider_name,
         provider_type=provider_type,
         actor=actor,
     )
-    # print(models)
+
     return models
 
 
 @router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
-def list_embedding_models(
+async def list_embedding_models(
     server: "SyncServer" = Depends(get_letta_server),
-    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+    # Extract user_id from header, default to None if not present
 ):
+    """List available embedding models using the asynchronous implementation for improved performance"""
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
-    models = server.list_embedding_models(actor=actor)
-    # print(models)
+    models = await server.list_embedding_models_async(actor=actor)
+
     return models
diff --git a/letta/server/server.py b/letta/server/server.py
index 42808850..04434d37 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -1365,6 +1365,48 @@ class SyncServer(Server):
 
         return llm_models
 
+    @trace_method
+    async def list_llm_models_async(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[LLMConfig]:
+        """Asynchronously list available models with maximum concurrency"""
+        import asyncio
+
+        providers = self.get_enabled_providers(
+            provider_category=provider_category,
+            provider_name=provider_name,
+            provider_type=provider_type,
+            actor=actor,
+        )
+
+        async def get_provider_models(provider):
+            try:
+                return await provider.list_llm_models_async()
+            except Exception as e:
+                import traceback
+
+                traceback.print_exc()
+                warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
+                return []
+
+        # Execute all provider model listing tasks concurrently
+        provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
+
+        # Flatten the results
+        llm_models = []
+        for models in provider_results:
+            llm_models.extend(models)
+
+        # Get local configs - if this is potentially slow, consider making it async too
+        local_configs = self.get_local_llm_configs()
+        llm_models.extend(local_configs)
+
+        return llm_models
+
     def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
         """List available embedding models"""
         embedding_models = []
@@ -1375,6 +1417,35 @@ class SyncServer(Server):
                 warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models
 
+    async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
+        """Asynchronously list available embedding models with maximum concurrency"""
+        import asyncio
+
+        # Get all eligible providers first
+        providers = self.get_enabled_providers(actor=actor)
+
+        # Fetch embedding models from each provider concurrently
+        async def get_provider_embedding_models(provider):
+            try:
+                # All providers now have list_embedding_models_async
+                return await provider.list_embedding_models_async()
+            except Exception as e:
+                import traceback
+
+                traceback.print_exc()
+                warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
+                return []
+
+        # Execute all provider model listing tasks concurrently
+        provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
+
+        # Flatten the results
+        embedding_models = []
+        for models in provider_results:
+            embedding_models.extend(models)
+
+        return embedding_models
+
     def get_enabled_providers(
         self,
         actor: User,
diff --git a/tests/test_server.py b/tests/test_server.py
index a832bc92..200ff54e 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -566,7 +566,8 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
         server.agent_manager.delete_agent(agent_state.id, actor=another_user)
 
 
-def test_read_local_llm_configs(server: SyncServer, user: User):
+@pytest.mark.asyncio
+async def test_read_local_llm_configs(server: SyncServer, user: User):
     configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
     clean_up_dir = False
     if not os.path.exists(configs_base_dir):
@@ -589,7 +590,7 @@
 
     # Call list_llm_models
     assert os.path.exists(configs_base_dir)
-    llm_models = server.list_llm_models(actor=user)
+    llm_models = await server.list_llm_models_async(actor=user)
 
     # Assert that the config is in the returned models
     assert any(
@@ -1224,7 +1225,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
     assert len(agent_state.tools) == len(base_tools) - 2
 
 
-def test_messages_with_provider_override(server: SyncServer, user_id: str):
+@pytest.mark.asyncio
+async def test_messages_with_provider_override(server: SyncServer, user_id: str):
     actor = server.user_manager.get_user_or_default(user_id)
     provider = server.provider_manager.create_provider(
         request=ProviderCreate(
@@ -1234,10 +1236,10 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
         ),
         actor=actor,
     )
-    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok])
+    models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.byok])
     assert provider.name in [model.provider_name for model in models]
 
-    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.base])
+    models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])
     assert provider.name not in [model.provider_name for model in models]
 
     agent = server.create_agent(
@@ -1303,11 +1305,12 @@
     assert total_tokens == usage.total_tokens
 
 
-def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
-    models = server.list_llm_models(actor=user)
+@pytest.mark.asyncio
+async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
+    models = await server.list_llm_models_async(actor=user)
     model_handles = [model.handle for model in models]
     assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
 
-    embeddings = server.list_embedding_models(actor=user)
+    embeddings = await server.list_embedding_models_async(actor=user)
     embedding_handles = [embedding.handle for embedding in embeddings]
     assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"