feat(asyncify): list model stragglers (#2362)
This commit is contained in:
@@ -23,7 +23,7 @@ async def list_llm_models(
|
||||
# Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""List available LLM models using the asynchronous implementation for improved performance"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
models = await server.list_llm_models_async(
|
||||
provider_category=provider_category,
|
||||
@@ -42,7 +42,7 @@ async def list_embedding_models(
|
||||
# Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""List available embedding models using the asynchronous implementation for improved performance"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
models = await server.list_embedding_models_async(actor=actor)
|
||||
|
||||
return models
|
||||
|
||||
@@ -1381,7 +1381,7 @@ class SyncServer(Server):
|
||||
"""Asynchronously list available models with maximum concurrency"""
|
||||
import asyncio
|
||||
|
||||
providers = self.get_enabled_providers(
|
||||
providers = await self.get_enabled_providers_async(
|
||||
provider_category=provider_category,
|
||||
provider_name=provider_name,
|
||||
provider_type=provider_type,
|
||||
@@ -1427,7 +1427,7 @@ class SyncServer(Server):
|
||||
import asyncio
|
||||
|
||||
# Get all eligible providers first
|
||||
providers = self.get_enabled_providers(actor=actor)
|
||||
providers = await self.get_enabled_providers_async(actor=actor)
|
||||
|
||||
# Fetch embedding models from each provider concurrently
|
||||
async def get_provider_embedding_models(provider):
|
||||
@@ -1480,6 +1480,35 @@ class SyncServer(Server):
|
||||
|
||||
return providers
|
||||
|
||||
async def get_enabled_providers_async(
    self,
    actor: User,
    provider_category: Optional[List[ProviderCategory]] = None,
    provider_name: Optional[str] = None,
    provider_type: Optional[ProviderType] = None,
) -> List[Provider]:
    """Gather the enabled providers, awaiting the provider manager for stored ones.

    Args:
        actor: User forwarded to ``provider_manager.list_providers_async``.
        provider_category: Categories to include. ``None`` or an empty list
            selects both the ``base`` and the ``byok`` sources.
        provider_name: When not ``None``, only providers whose ``name`` equals
            this value are kept.
        provider_type: When not ``None``, only providers whose
            ``provider_type`` equals this value are kept.

    Returns:
        The entries of ``self._enabled_providers`` (if ``base`` is selected)
        followed by the provider manager's results cast to their subtype
        (if ``byok`` is selected), after the name/type filters are applied.
    """
    include_base = not provider_category or ProviderCategory.base in provider_category
    include_byok = not provider_category or ProviderCategory.byok in provider_category

    collected = []

    if include_base:
        collected.extend(self._enabled_providers)

    if include_byok:
        stored = await self.provider_manager.list_providers_async(
            name=provider_name,
            provider_type=provider_type,
            actor=actor,
        )
        collected.extend(entry.cast_to_subtype() for entry in stored)

    # The name/type filters are applied to both sources, not only the stored ones.
    if provider_name is not None:
        collected = [entry for entry in collected if entry.name == provider_name]

    if provider_type is not None:
        collected = [entry for entry in collected if entry.provider_type == provider_type]

    return collected
|
||||
|
||||
@trace_method
|
||||
def get_llm_config_from_handle(
|
||||
self,
|
||||
|
||||
@@ -91,6 +91,32 @@ class ProviderManager:
|
||||
)
|
||||
return [provider.to_pydantic() for provider in providers]
|
||||
|
||||
@enforce_types
@trace_method
async def list_providers_async(
    self,
    actor: PydanticUser,
    name: Optional[str] = None,
    provider_type: Optional[ProviderType] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
) -> List[PydanticProvider]:
    """List providers through an async database session.

    Args:
        actor: User forwarded to ``ProviderModel.list_async``.
        name: If truthy, passed as a ``name`` filter.
        provider_type: If truthy, passed as a ``provider_type`` filter.
        after: Forwarded unchanged to ``ProviderModel.list_async``.
        limit: Forwarded unchanged to ``ProviderModel.list_async``.

    Returns:
        The listed providers, each converted with ``to_pydantic()``.
    """
    # Only truthy filters are forwarded; falsy ones are omitted entirely.
    requested_filters = {"name": name, "provider_type": provider_type}
    active_filters = {field: value for field, value in requested_filters.items() if value}

    async with db_registry.async_session() as db_session:
        rows = await ProviderModel.list_async(
            db_session=db_session,
            after=after,
            limit=limit,
            actor=actor,
            **active_filters,
        )
        # Conversion happens while the session is still open.
        return [row.to_pydantic() for row in rows]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
||||
|
||||
@@ -6,13 +6,11 @@ from letta.schemas.step import Step as PydanticStep
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.helpers.noop_helper import singleton
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class TelemetryManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_provider_trace_by_step_id_async(
|
||||
self,
|
||||
step_id: str,
|
||||
@@ -23,7 +21,6 @@ class TelemetryManager:
|
||||
return provider_trace.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
|
||||
async with db_registry.async_session() as session:
|
||||
provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
|
||||
@@ -38,7 +35,6 @@ class TelemetryManager:
|
||||
return provider_trace.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
|
||||
with db_registry.session() as session:
|
||||
provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
|
||||
@@ -52,14 +48,11 @@ class NoopTelemetryManager(TelemetryManager):
|
||||
Noop implementation of TelemetryManager.
|
||||
"""
|
||||
|
||||
@trace_method
|
||||
async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
|
||||
return
|
||||
|
||||
@trace_method
|
||||
async def get_provider_trace_by_step_id_async(self, step_id: str, actor: PydanticUser) -> PydanticStep:
|
||||
return
|
||||
|
||||
@trace_method
|
||||
def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
|
||||
return
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
@@ -359,6 +360,14 @@ def other_agent_id(server, user_id, base_tools):
|
||||
server.agent_manager.delete_agent(agent_state.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop(request):
    """Provide a new event loop shared across the test session; close it on teardown."""
    policy = asyncio.get_event_loop_policy()
    session_loop = policy.new_event_loop()
    yield session_loop
    session_loop.close()
|
||||
|
||||
|
||||
def test_error_on_nonexistent_agent(server, user, agent_id):
|
||||
try:
|
||||
fake_agent_id = str(uuid.uuid4())
|
||||
@@ -527,7 +536,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_local_llm_configs(server: SyncServer, user: User):
|
||||
async def test_read_local_llm_configs(server: SyncServer, user: User, event_loop):
|
||||
configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
|
||||
clean_up_dir = False
|
||||
if not os.path.exists(configs_base_dir):
|
||||
@@ -564,7 +573,7 @@ async def test_read_local_llm_configs(server: SyncServer, user: User):
|
||||
|
||||
# Try to use in agent creation
|
||||
context_window_override = 4000
|
||||
agent = server.create_agent(
|
||||
agent = await server.create_agent_async(
|
||||
request=CreateAgent(
|
||||
model="caren/my-custom-model",
|
||||
context_window_limit=context_window_override,
|
||||
@@ -1061,8 +1070,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
async def test_messages_with_provider_override(server: SyncServer, user_id: str, event_loop):
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
|
||||
provider = server.provider_manager.create_provider(
|
||||
request=ProviderCreate(
|
||||
name="caren-anthropic",
|
||||
@@ -1077,7 +1086,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str)
|
||||
models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])
|
||||
assert provider.name not in [model.provider_name for model in models]
|
||||
|
||||
agent = server.create_agent(
|
||||
agent = await server.create_agent_async(
|
||||
request=CreateAgent(
|
||||
memory_blocks=[],
|
||||
model="caren-anthropic/claude-3-5-sonnet-20240620",
|
||||
@@ -1141,7 +1150,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
|
||||
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User, event_loop):
|
||||
models = await server.list_llm_models_async(actor=user)
|
||||
model_handles = [model.handle for model in models]
|
||||
assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
|
||||
|
||||
Reference in New Issue
Block a user