feat: Concurrently gather llm models in /models endpoint (#2288)

Author: Matthew Zhou
Date: 2025-05-21 11:19:13 -07:00
Committed by: GitHub
Parent: b0367cf814
Commit: fdc2d8ec22

3 changed files with 95 additions and 16 deletions
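
The heart of the change: both /models endpoints used to walk the enabled providers one at a time, so response time grew with the sum of every provider's round trip. The new *_async server methods fan the per-provider calls out with asyncio.gather, so response time is bounded by the slowest single provider. A minimal, self-contained sketch of the pattern (FakeProvider is hypothetical and stands in for letta's provider objects):

import asyncio
import warnings
from typing import List


class FakeProvider:
    """Stand-in for a real provider with an async model-listing call."""

    def __init__(self, name: str, models: List[str], delay: float):
        self.name = name
        self.models = models
        self.delay = delay

    async def list_llm_models_async(self) -> List[str]:
        await asyncio.sleep(self.delay)  # stands in for a network round trip
        return self.models


async def list_all_models(providers: List[FakeProvider]) -> List[str]:
    async def fetch(provider: FakeProvider) -> List[str]:
        try:
            return await provider.list_llm_models_async()
        except Exception as e:
            # One failing provider should not sink the whole request.
            warnings.warn(f"provider {provider.name} failed: {e}")
            return []

    # All providers are queried concurrently; total latency is roughly
    # the slowest provider's, not the sum of all of them.
    results = await asyncio.gather(*[fetch(p) for p in providers])
    return [model for models in results for model in models]


if __name__ == "__main__":
    providers = [
        FakeProvider("openai", ["gpt-4"], 0.3),
        FakeProvider("anthropic", ["claude-3"], 0.3),
    ]
    print(asyncio.run(list_all_models(providers)))  # ~0.3 s total, not ~0.6 s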


@@ -14,30 +14,35 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
 @router.get("/", response_model=List[LLMConfig], operation_id="list_models")
-def list_llm_models(
+async def list_llm_models(
     provider_category: Optional[List[ProviderCategory]] = Query(None),
     provider_name: Optional[str] = Query(None),
     provider_type: Optional[ProviderType] = Query(None),
     server: "SyncServer" = Depends(get_letta_server),
-    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+    # Extract user_id from header, default to None if not present
 ):
+    """List available LLM models using the asynchronous implementation for improved performance"""
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
-    models = server.list_llm_models(
+    models = await server.list_llm_models_async(
         provider_category=provider_category,
         provider_name=provider_name,
         provider_type=provider_type,
         actor=actor,
     )
-    # print(models)
     return models


 @router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
-def list_embedding_models(
+async def list_embedding_models(
     server: "SyncServer" = Depends(get_letta_server),
-    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+    # Extract user_id from header, default to None if not present
 ):
+    """List available embedding models using the asynchronous implementation for improved performance"""
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
-    models = server.list_embedding_models(actor=actor)
-    # print(models)
+    models = await server.list_embedding_models_async(actor=actor)
     return models
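
Because FastAPI awaits async def endpoints directly on the event loop, the handler no longer occupies a threadpool worker while providers are queried. A hypothetical client call against the updated endpoint; the base URL, port, /v1 prefix, and user_id value are placeholder assumptions, not part of the commit:

import asyncio

import httpx


async def main() -> None:
    # Placeholder host/port; the user_id header is what the endpoint
    # reads into actor_id via Header(None, alias="user_id").
    async with httpx.AsyncClient(base_url="http://localhost:8283") as client:
        resp = await client.get("/v1/models/", headers={"user_id": "user-123"})
        resp.raise_for_status()
        for model in resp.json():  # each entry is a serialized LLMConfig
            print(model.get("handle"), model.get("provider_name"))


asyncio.run(main())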


@@ -1365,6 +1365,48 @@ class SyncServer(Server):
         return llm_models

+    @trace_method
+    async def list_llm_models_async(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[LLMConfig]:
+        """Asynchronously list available models with maximum concurrency"""
+        import asyncio
+
+        providers = self.get_enabled_providers(
+            provider_category=provider_category,
+            provider_name=provider_name,
+            provider_type=provider_type,
+            actor=actor,
+        )
+
+        async def get_provider_models(provider):
+            try:
+                return await provider.list_llm_models_async()
+            except Exception as e:
+                import traceback
+
+                traceback.print_exc()
+                warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
+                return []
+
+        # Execute all provider model listing tasks concurrently
+        provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
+
+        # Flatten the results
+        llm_models = []
+        for models in provider_results:
+            llm_models.extend(models)
+
+        # Get local configs - if this is potentially slow, consider making it async too
+        local_configs = self.get_local_llm_configs()
+        llm_models.extend(local_configs)
+
+        return llm_models
+
     def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
         """List available embedding models"""
         embedding_models = []
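
Note that get_provider_models catches exceptions inside the task, so asyncio.gather itself never sees a failure and one misbehaving provider cannot take down the whole listing. An equivalent shape (not what this commit does) is to let exceptions propagate and have gather collect them via return_exceptions=True; a sketch:

import asyncio
import warnings


async def gather_tolerant(coros):
    # With return_exceptions=True, raised exceptions come back as list
    # entries instead of re-raising the first one encountered.
    results = await asyncio.gather(*coros, return_exceptions=True)
    models = []
    for result in results:
        if isinstance(result, BaseException):
            warnings.warn(f"provider task failed: {result!r}")
        else:
            models.extend(result)
    return models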
@@ -1375,6 +1417,35 @@ class SyncServer(Server):
             warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models

+    async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
+        """Asynchronously list available embedding models with maximum concurrency"""
+        import asyncio
+
+        # Get all eligible providers first
+        providers = self.get_enabled_providers(actor=actor)
+
+        # Fetch embedding models from each provider concurrently
+        async def get_provider_embedding_models(provider):
+            try:
+                # All providers now have list_embedding_models_async
+                return await provider.list_embedding_models_async()
+            except Exception as e:
+                import traceback
+
+                traceback.print_exc()
+                warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
+                return []
+
+        # Execute all provider model listing tasks concurrently
+        provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
+
+        # Flatten the results
+        embedding_models = []
+        for models in provider_results:
+            embedding_models.extend(models)
+
+        return embedding_models
+
     def get_enabled_providers(
         self,
         actor: User,
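
A back-of-envelope check of the speedup the commit message claims, using fake half-second providers (the numbers are illustrative, not a benchmark of letta itself):

import asyncio
import time


async def fake_fetch() -> list:
    await asyncio.sleep(0.5)  # stands in for one provider's network call
    return ["model"]


async def main() -> None:
    start = time.perf_counter()
    for _ in range(4):
        await fake_fetch()  # sequential: one provider after another
    print(f"sequential: {time.perf_counter() - start:.2f}s")  # ~2.00s

    start = time.perf_counter()
    await asyncio.gather(*[fake_fetch() for _ in range(4)])  # concurrent
    print(f"concurrent: {time.perf_counter() - start:.2f}s")  # ~0.50s


asyncio.run(main())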


@@ -566,7 +566,8 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
         server.agent_manager.delete_agent(agent_state.id, actor=another_user)


-def test_read_local_llm_configs(server: SyncServer, user: User):
+@pytest.mark.asyncio
+async def test_read_local_llm_configs(server: SyncServer, user: User):
     configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
     clean_up_dir = False
     if not os.path.exists(configs_base_dir):
@@ -589,7 +590,7 @@ def test_read_local_llm_configs(server: SyncServer, user: User):
     # Call list_llm_models
     assert os.path.exists(configs_base_dir)
-    llm_models = server.list_llm_models(actor=user)
+    llm_models = await server.list_llm_models_async(actor=user)

     # Assert that the config is in the returned models
     assert any(
@@ -1224,7 +1225,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools
     assert len(agent_state.tools) == len(base_tools) - 2


-def test_messages_with_provider_override(server: SyncServer, user_id: str):
+@pytest.mark.asyncio
+async def test_messages_with_provider_override(server: SyncServer, user_id: str):
     actor = server.user_manager.get_user_or_default(user_id)
     provider = server.provider_manager.create_provider(
         request=ProviderCreate(
@@ -1234,10 +1236,10 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
         ),
         actor=actor,
     )
-    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok])
+    models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.byok])
     assert provider.name in [model.provider_name for model in models]

-    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.base])
+    models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])
     assert provider.name not in [model.provider_name for model in models]

     agent = server.create_agent(
@@ -1303,11 +1305,12 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
     assert total_tokens == usage.total_tokens


-def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
-    models = server.list_llm_models(actor=user)
+@pytest.mark.asyncio
+async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
+    models = await server.list_llm_models_async(actor=user)
     model_handles = [model.handle for model in models]
     assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"

-    embeddings = server.list_embedding_models(actor=user)
+    embeddings = await server.list_embedding_models_async(actor=user)
     embedding_handles = [embedding.handle for embedding in embeddings]
     assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"
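
The converted tests depend on the pytest-asyncio plugin, which is what makes the @pytest.mark.asyncio-decorated async def tests runnable. One property the server code quietly relies on is that asyncio.gather returns results in the order the awaitables were passed in, which is why flattening provider_results stays deterministic; a hypothetical standalone test of that invariant:

import asyncio

import pytest


@pytest.mark.asyncio
async def test_gather_preserves_argument_order() -> None:
    async def fetch(value: int) -> int:
        # Reverse the delays so later tasks finish first.
        await asyncio.sleep(0.03 - 0.01 * value)
        return value

    # gather yields results in submission order, not completion order.
    results = await asyncio.gather(*[fetch(i) for i in range(3)])
    assert results == [0, 1, 2]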