feat(asyncify): list model stragglers (#2362)

This commit is contained in:
cthomas
2025-05-23 00:42:05 -07:00
committed by GitHub
parent 378f22087a
commit d1b0756657
5 changed files with 74 additions and 17 deletions

View File

@@ -23,7 +23,7 @@ async def list_llm_models(
# Extract user_id from header, default to None if not present
):
"""List available LLM models using the asynchronous implementation for improved performance"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
models = await server.list_llm_models_async(
provider_category=provider_category,
@@ -42,7 +42,7 @@ async def list_embedding_models(
# Extract user_id from header, default to None if not present
):
"""List available embedding models using the asynchronous implementation for improved performance"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
models = await server.list_embedding_models_async(actor=actor)
return models

View File

@@ -1381,7 +1381,7 @@ class SyncServer(Server):
"""Asynchronously list available models with maximum concurrency"""
import asyncio
providers = self.get_enabled_providers(
providers = await self.get_enabled_providers_async(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
@@ -1427,7 +1427,7 @@ class SyncServer(Server):
import asyncio
# Get all eligible providers first
providers = self.get_enabled_providers(actor=actor)
providers = await self.get_enabled_providers_async(actor=actor)
# Fetch embedding models from each provider concurrently
async def get_provider_embedding_models(provider):
@@ -1480,6 +1480,35 @@ class SyncServer(Server):
return providers
async def get_enabled_providers_async(
    self,
    actor: User,
    provider_category: Optional[List[ProviderCategory]] = None,
    provider_name: Optional[str] = None,
    provider_type: Optional[ProviderType] = None,
) -> List[Provider]:
    """Asynchronously collect the enabled providers visible to ``actor``.

    Args:
        actor: User on whose behalf the stored providers are listed.
        provider_category: Categories to include. ``None`` or an empty list
            includes both ``ProviderCategory.base`` and ``ProviderCategory.byok``.
        provider_name: When not ``None``, keep only providers with this name.
        provider_type: When not ``None``, keep only providers of this type.

    Returns:
        The providers held in ``self._enabled_providers`` (base category)
        followed by the providers returned by the provider manager (byok
        category, each cast to its subtype), narrowed by the name/type filters.
    """
    include_all = not provider_category
    providers = []

    if include_all or ProviderCategory.base in provider_category:
        # Extend directly from the source list; an intermediate copy adds nothing.
        providers.extend(self._enabled_providers)

    if include_all or ProviderCategory.byok in provider_category:
        # NOTE(review): no `after`/`limit` is passed, so the manager's default
        # page size applies -- confirm that a single page is always sufficient.
        stored_providers = await self.provider_manager.list_providers_async(
            name=provider_name,
            provider_type=provider_type,
            actor=actor,
        )
        providers.extend(p.cast_to_subtype() for p in stored_providers)

    # The base providers were not filtered by the manager query above, so the
    # name/type filters are applied to the combined list.
    if provider_name is not None:
        providers = [p for p in providers if p.name == provider_name]

    if provider_type is not None:
        providers = [p for p in providers if p.provider_type == provider_type]

    return providers
@trace_method
def get_llm_config_from_handle(
self,

View File

@@ -91,6 +91,32 @@ class ProviderManager:
)
return [provider.to_pydantic() for provider in providers]
@enforce_types
@trace_method
async def list_providers_async(
    self,
    actor: PydanticUser,
    name: Optional[str] = None,
    provider_type: Optional[ProviderType] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
) -> List[PydanticProvider]:
    """Asynchronously list providers, optionally filtered and paginated.

    Args:
        actor: User passed through to the model-level listing call.
        name: If truthy, only providers with this name are requested.
        provider_type: If truthy, only providers of this type are requested.
        after: Pagination cursor forwarded to ``ProviderModel.list_async``.
        limit: Maximum page size forwarded to ``ProviderModel.list_async``.

    Returns:
        The matching providers converted with ``to_pydantic()``.
    """
    requested_filters = {"name": name, "provider_type": provider_type}
    # Forward only the filters that were actually supplied (truthy values).
    active_filters = {field: value for field, value in requested_filters.items() if value}

    async with db_registry.async_session() as session:
        rows = await ProviderModel.list_async(
            db_session=session,
            after=after,
            limit=limit,
            actor=actor,
            **active_filters,
        )
        return [row.to_pydantic() for row in rows]
@enforce_types
@trace_method
def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:

View File

@@ -6,13 +6,11 @@ from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.noop_helper import singleton
from letta.tracing import trace_method
from letta.utils import enforce_types
class TelemetryManager:
@enforce_types
@trace_method
async def get_provider_trace_by_step_id_async(
self,
step_id: str,
@@ -23,7 +21,6 @@ class TelemetryManager:
return provider_trace.to_pydantic()
@enforce_types
@trace_method
async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
async with db_registry.async_session() as session:
provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
@@ -38,7 +35,6 @@ class TelemetryManager:
return provider_trace.to_pydantic()
@enforce_types
@trace_method
def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
with db_registry.session() as session:
provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
@@ -52,14 +48,11 @@ class NoopTelemetryManager(TelemetryManager):
Noop implementation of TelemetryManager.
"""
@trace_method
async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
    """No-op variant: ignores both arguments and returns ``None``."""
    return None
@trace_method
async def get_provider_trace_by_step_id_async(self, step_id: str, actor: PydanticUser) -> PydanticStep:
    """No-op variant: performs no lookup and returns ``None``."""
    return None
@trace_method
def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
    """No-op variant: ignores both arguments and returns ``None``."""
    return None

View File

@@ -1,3 +1,4 @@
import asyncio
import json
import os
import shutil
@@ -359,6 +360,14 @@ def other_agent_id(server, user_id, base_tools):
server.agent_manager.delete_agent(agent_state.id, actor=actor)
@pytest.fixture(scope="session")
def event_loop(request):
    """Yield a single event loop for the whole test session and close it at teardown."""
    policy = asyncio.get_event_loop_policy()
    session_loop = policy.new_event_loop()
    yield session_loop
    # Reached when the session-scoped fixture is torn down.
    session_loop.close()
def test_error_on_nonexistent_agent(server, user, agent_id):
try:
fake_agent_id = str(uuid.uuid4())
@@ -527,7 +536,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
@pytest.mark.asyncio
async def test_read_local_llm_configs(server: SyncServer, user: User):
async def test_read_local_llm_configs(server: SyncServer, user: User, event_loop):
configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
clean_up_dir = False
if not os.path.exists(configs_base_dir):
@@ -564,7 +573,7 @@ async def test_read_local_llm_configs(server: SyncServer, user: User):
# Try to use in agent creation
context_window_override = 4000
agent = server.create_agent(
agent = await server.create_agent_async(
request=CreateAgent(
model="caren/my-custom-model",
context_window_limit=context_window_override,
@@ -1061,8 +1070,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
@pytest.mark.asyncio
async def test_messages_with_provider_override(server: SyncServer, user_id: str):
actor = server.user_manager.get_user_or_default(user_id)
async def test_messages_with_provider_override(server: SyncServer, user_id: str, event_loop):
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
provider = server.provider_manager.create_provider(
request=ProviderCreate(
name="caren-anthropic",
@@ -1077,7 +1086,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str)
models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])
assert provider.name not in [model.provider_name for model in models]
agent = server.create_agent(
agent = await server.create_agent_async(
request=CreateAgent(
memory_blocks=[],
model="caren-anthropic/claude-3-5-sonnet-20240620",
@@ -1141,7 +1150,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str)
@pytest.mark.asyncio
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User, event_loop):
models = await server.list_llm_models_async(actor=user)
model_handles = [model.handle for model in models]
assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"