feat: Concurrently gather llm models in /models endpoint (#2288)
This commit is contained in:
@@ -14,30 +14,35 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[LLMConfig], operation_id="list_models")
|
||||
def list_llm_models(
|
||||
async def list_llm_models(
|
||||
provider_category: Optional[List[ProviderCategory]] = Query(None),
|
||||
provider_name: Optional[str] = Query(None),
|
||||
provider_type: Optional[ProviderType] = Query(None),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
# Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""List available LLM models using the asynchronous implementation for improved performance"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
models = server.list_llm_models(
|
||||
|
||||
models = await server.list_llm_models_async(
|
||||
provider_category=provider_category,
|
||||
provider_name=provider_name,
|
||||
provider_type=provider_type,
|
||||
actor=actor,
|
||||
)
|
||||
# print(models)
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
|
||||
def list_embedding_models(
|
||||
async def list_embedding_models(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
# Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""List available embedding models using the asynchronous implementation for improved performance"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
models = server.list_embedding_models(actor=actor)
|
||||
# print(models)
|
||||
models = await server.list_embedding_models_async(actor=actor)
|
||||
|
||||
return models
|
||||
|
||||
@@ -1365,6 +1365,48 @@ class SyncServer(Server):
|
||||
|
||||
return llm_models
|
||||
|
||||
@trace_method
|
||||
async def list_llm_models_async(
|
||||
self,
|
||||
actor: User,
|
||||
provider_category: Optional[List[ProviderCategory]] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
provider_type: Optional[ProviderType] = None,
|
||||
) -> List[LLMConfig]:
|
||||
"""Asynchronously list available models with maximum concurrency"""
|
||||
import asyncio
|
||||
|
||||
providers = self.get_enabled_providers(
|
||||
provider_category=provider_category,
|
||||
provider_name=provider_name,
|
||||
provider_type=provider_type,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
async def get_provider_models(provider):
|
||||
try:
|
||||
return await provider.list_llm_models_async()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
|
||||
return []
|
||||
|
||||
# Execute all provider model listing tasks concurrently
|
||||
provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
|
||||
|
||||
# Flatten the results
|
||||
llm_models = []
|
||||
for models in provider_results:
|
||||
llm_models.extend(models)
|
||||
|
||||
# Get local configs - if this is potentially slow, consider making it async too
|
||||
local_configs = self.get_local_llm_configs()
|
||||
llm_models.extend(local_configs)
|
||||
|
||||
return llm_models
|
||||
|
||||
def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
|
||||
"""List available embedding models"""
|
||||
embedding_models = []
|
||||
@@ -1375,6 +1417,35 @@ class SyncServer(Server):
|
||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return embedding_models
|
||||
|
||||
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
|
||||
"""Asynchronously list available embedding models with maximum concurrency"""
|
||||
import asyncio
|
||||
|
||||
# Get all eligible providers first
|
||||
providers = self.get_enabled_providers(actor=actor)
|
||||
|
||||
# Fetch embedding models from each provider concurrently
|
||||
async def get_provider_embedding_models(provider):
|
||||
try:
|
||||
# All providers now have list_embedding_models_async
|
||||
return await provider.list_embedding_models_async()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return []
|
||||
|
||||
# Execute all provider model listing tasks concurrently
|
||||
provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
|
||||
|
||||
# Flatten the results
|
||||
embedding_models = []
|
||||
for models in provider_results:
|
||||
embedding_models.extend(models)
|
||||
|
||||
return embedding_models
|
||||
|
||||
def get_enabled_providers(
|
||||
self,
|
||||
actor: User,
|
||||
|
||||
@@ -566,7 +566,8 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
|
||||
server.agent_manager.delete_agent(agent_state.id, actor=another_user)
|
||||
|
||||
|
||||
def test_read_local_llm_configs(server: SyncServer, user: User):
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_local_llm_configs(server: SyncServer, user: User):
|
||||
configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
|
||||
clean_up_dir = False
|
||||
if not os.path.exists(configs_base_dir):
|
||||
@@ -589,7 +590,7 @@ def test_read_local_llm_configs(server: SyncServer, user: User):
|
||||
|
||||
# Call list_llm_models
|
||||
assert os.path.exists(configs_base_dir)
|
||||
llm_models = server.list_llm_models(actor=user)
|
||||
llm_models = await server.list_llm_models_async(actor=user)
|
||||
|
||||
# Assert that the config is in the returned models
|
||||
assert any(
|
||||
@@ -1224,7 +1225,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
|
||||
assert len(agent_state.tools) == len(base_tools) - 2
|
||||
|
||||
|
||||
def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
provider = server.provider_manager.create_provider(
|
||||
request=ProviderCreate(
|
||||
@@ -1234,10 +1236,10 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok])
|
||||
models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.byok])
|
||||
assert provider.name in [model.provider_name for model in models]
|
||||
|
||||
models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.base])
|
||||
models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])
|
||||
assert provider.name not in [model.provider_name for model in models]
|
||||
|
||||
agent = server.create_agent(
|
||||
@@ -1303,11 +1305,12 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
assert total_tokens == usage.total_tokens
|
||||
|
||||
|
||||
def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
|
||||
models = server.list_llm_models(actor=user)
|
||||
@pytest.mark.asyncio
|
||||
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
|
||||
models = await server.list_llm_models_async(actor=user)
|
||||
model_handles = [model.handle for model in models]
|
||||
assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
|
||||
embeddings = server.list_embedding_models(actor=user)
|
||||
embeddings = await server.list_embedding_models_async(actor=user)
|
||||
embedding_handles = [embedding.handle for embedding in embeddings]
|
||||
assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user