feat: enable provider models persistence (#6193)

* Revert "fix test"

This reverts commit 5126815f23cefb4edad3e3bf9e7083209dcc7bf1.

* fix server and better test

* test fix, get api key for base and byok?

* set letta default endpoint

* try to fix timeout for test

* fix for letta api key

* Delete apps/core/tests/sdk_v1/conftest.py

* Update utils.py

* clean up a few issues

* fix filtering on list_llm_models

* soft delete models with provider

* add one more test

* fix ci

* add timeout

* band-aid for letta embedding provider

* info instead of error logs when creating models
Authored by Ari Webb on 2025-12-09 14:33:06 -08:00; committed by Caren Thomas
parent b4af037c19
commit 848a73125c
8 changed files with 754 additions and 205 deletions
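In short: providers and their model lists are now persisted to the database at server startup, and the list/lookup paths read from that catalog instead of calling each provider's API per request (the in-memory config caches are deleted outright). A rough sketch of the new startup flow, using only method names visible in the diff below; the wrapper itself is illustrative, not the literal code:

# Illustrative sketch of the startup sequence; see the server.py diff below
# for the real implementation (the method names are real, the wrapper is not).
async def startup(server: "SyncServer") -> None:
    # 1. Persist environment-configured ("base") providers; idempotent, so
    #    it is safe when several pods start at once.
    await server.provider_manager.sync_base_providers(
        base_providers=server._enabled_providers, actor=server.default_user
    )
    # 2. Fetch each provider's model list once and write it to the database.
    await server._sync_provider_models_async()
    # 3. Requests now read models from the database catalog, e.g.:
    models = await server.list_llm_models_async(actor=server.default_user)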


@@ -109,7 +109,7 @@ from letta.services.tool_manager import ToolManager
 from letta.services.user_manager import UserManager
 from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
 from letta.streaming_interface import AgentChunkStreamingInterface
-from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task
+from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task
 
 config = LettaConfig.load()
 logger = get_logger(__name__)
@@ -203,12 +203,10 @@ class SyncServer(object):
         """Initialize the MCP clients (there may be multiple)"""
         self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {}
 
-        # TODO: Remove these in memory caches
-        self._llm_config_cache = {}
-        self._embedding_config_cache = {}
-
         # collect providers (always has Letta as a default)
-        self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
+        from letta.constants import LETTA_MODEL_ENDPOINT
+
+        self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)]
         if model_settings.openai_api_key:
             self._enabled_providers.append(
                 OpenAIProvider(
@@ -342,6 +340,12 @@ class SyncServer(object):
         print(f"Default user: {self.default_user} and org: {self.default_org}")
         await self.tool_manager.upsert_base_tools_async(actor=self.default_user)
 
+        # Sync environment-based providers to database (idempotent, safe for multi-pod startup)
+        await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user)
+
+        # Sync provider models to database
+        await self._sync_provider_models_async()
+
         # For OSS users, create a local sandbox config
         oss_default_user = await self.user_manager.get_default_actor_async()
         use_venv = False if not tool_settings.tool_exec_venv_name else True
@@ -378,6 +382,65 @@ class SyncServer(object):
             force_recreate=True,
         )
 
+    def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]:
+        """Find and return an enabled provider by name.
+
+        Args:
+            provider_name: The name of the provider to find
+
+        Returns:
+            The matching enabled provider, or None if not found
+        """
+        for provider in self._enabled_providers:
+            if provider.name == provider_name:
+                return provider
+        return None
+
+    async def _sync_provider_models_async(self):
+        """Sync all provider models to database at startup."""
+        logger.info("Syncing provider models to database")
+
+        # Get persisted providers from database (they now have IDs)
+        persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user)
+
+        for persisted_provider in persisted_providers:
+            try:
+                # Find the matching enabled provider instance to call list_models on
+                enabled_provider = self._get_enabled_provider(persisted_provider.name)
+                if not enabled_provider:
+                    # Only delete base providers that are no longer enabled
+                    # BYOK providers are user-created and should not be automatically deleted
+                    if persisted_provider.provider_category == ProviderCategory.base:
+                        logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database")
+                        try:
+                            await self.provider_manager.delete_provider_by_id_async(
+                                provider_id=persisted_provider.id, actor=self.default_user
+                            )
+                        except NoResultFound:
+                            # Provider was already deleted (race condition in multi-pod startup)
+                            logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping")
+                    else:
+                        logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync")
+                    continue
+
+                # Fetch models from provider
+                llm_models = await enabled_provider.list_llm_models_async()
+                embedding_models = await enabled_provider.list_embedding_models_async()
+
+                # Save to database with the persisted provider (which has an ID)
+                await self.provider_manager.sync_provider_models_async(
+                    provider=persisted_provider,
+                    llm_models=llm_models,
+                    embedding_models=embedding_models,
+                    organization_id=None,  # Global models
+                )
+                logger.info(
+                    f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
+                )
+            except Exception as e:
+                logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True)
+
     async def init_mcp_clients(self):
         # TODO: remove this
         mcp_server_configs = self.get_mcp_servers()
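The manager-side `sync_provider_models_async` is not part of this file's diff; given the call site above and the "soft delete models with provider" commit message, its contract is plausibly upsert-plus-soft-delete. A self-contained toy version of that rule (an in-memory dict stands in for the database table; all names here are hypothetical):

from dataclasses import dataclass

@dataclass
class CatalogEntry:
    handle: str
    enabled: bool = True

def sync_models(catalog: dict[str, CatalogEntry], fetched_handles: set[str]) -> None:
    """Upsert fetched models; soft-delete entries the provider no longer returns."""
    for handle in fetched_handles:
        if handle in catalog:
            catalog[handle].enabled = True          # re-enable if previously soft-deleted
        else:
            catalog[handle] = CatalogEntry(handle)  # insert newly discovered model
    for handle, entry in catalog.items():
        if handle not in fetched_handles:
            entry.enabled = False                   # soft delete: flag off, keep the row

# Example: the provider stops serving one model.
catalog = {h: CatalogEntry(h) for h in ("openai/gpt-4", "openai/gpt-3.5-turbo")}
sync_models(catalog, {"openai/gpt-4"})
assert catalog["openai/gpt-3.5-turbo"].enabled is False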
@@ -405,39 +468,6 @@ class SyncServer(object):
         logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
         logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
 
-    @trace_method
-    def get_cached_llm_config(self, actor: User, **kwargs):
-        key = make_key(**kwargs)
-        if key not in self._llm_config_cache:
-            self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
-        logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
-        return self._llm_config_cache[key]
-
-    @trace_method
-    async def get_cached_llm_config_async(self, actor: User, **kwargs):
-        key = make_key(**kwargs)
-        if key not in self._llm_config_cache:
-            self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs)
-        logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
-        return self._llm_config_cache[key]
-
-    @trace_method
-    def get_cached_embedding_config(self, actor: User, **kwargs):
-        key = make_key(**kwargs)
-        if key not in self._embedding_config_cache:
-            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
-        logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
-        return self._embedding_config_cache[key]
-
-    # @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs))
-    @trace_method
-    async def get_cached_embedding_config_async(self, actor: User, **kwargs):
-        key = make_key(**kwargs)
-        if key not in self._embedding_config_cache:
-            self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs)
-        logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
-        return self._embedding_config_cache[key]
-
     @trace_method
     async def create_agent_async(
         self,
@@ -471,10 +501,9 @@ class SyncServer(object):
                 "max_reasoning_tokens": request.max_reasoning_tokens,
                 "enable_reasoner": request.enable_reasoner,
             }
             config_params.update(additional_config_params)
-            log_event(name="start get_cached_llm_config", attributes=config_params)
-            request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
-            log_event(name="end get_cached_llm_config", attributes=config_params)
+            log_event(name="start get_llm_config_from_handle", attributes=config_params)
+            request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
+            log_event(name="end get_llm_config_from_handle", attributes=config_params)
 
         if request.model and isinstance(request.model, str):
             assert request.llm_config.handle == request.model, (
                 f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
@@ -504,9 +533,9 @@ class SyncServer(object):
                 "handle": request.embedding,
                 "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
             }
-            log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
-            request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params)
-            log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
+            log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params)
+            request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params)
+            log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params)
 
         log_event(name="start create_agent db")
         main_agent = await self.agent_manager.create_agent_async(
@@ -555,9 +584,9 @@ class SyncServer(object):
                 "context_window_limit": request.context_window_limit,
                 "max_tokens": request.max_tokens,
             }
-            log_event(name="start get_cached_llm_config", attributes=config_params)
-            request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
-            log_event(name="end get_cached_llm_config", attributes=config_params)
+            log_event(name="start get_llm_config_from_handle", attributes=config_params)
+            request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
+            log_event(name="end get_llm_config_from_handle", attributes=config_params)
 
         # update with model_settings
         if request.model_settings is not None:
@@ -1061,73 +1090,85 @@ class SyncServer(object):
         provider_name: Optional[str] = None,
         provider_type: Optional[ProviderType] = None,
     ) -> List[LLMConfig]:
-        """Asynchronously list available models with maximum concurrency"""
-        import asyncio
-
-        providers = await self.get_enabled_providers_async(
-            provider_category=provider_category,
-            provider_name=provider_name,
-            provider_type=provider_type,
-            actor=actor,
-        )
+        """List available LLM models from database cache"""
+        # Get provider IDs if filtering by provider
+        provider_ids = None
+        if provider_name or provider_type:
+            providers = await self.get_enabled_providers_async(
+                provider_category=provider_category,
+                provider_name=provider_name,
+                provider_type=provider_type,
+                actor=actor,
+            )
+            provider_ids = [p.id for p in providers]
+
+            # If filtering was requested but no providers matched, return empty list
+            if not provider_ids:
+                return []
+
+        # Get models from database
+        provider_models = await self.provider_manager.list_models_async(
+            actor=actor,
+            model_type="llm",
+            provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None,
+            enabled=True,
+        )
 
-        async def get_provider_models(provider: Provider) -> list[LLMConfig]:
-            try:
-                async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS):
-                    return await provider.list_llm_models_async()
-            except asyncio.TimeoutError:
-                logger.warning(f"Timeout while listing LLM models for provider {provider}")
-                return []
-            except Exception as e:
-                logger.exception(f"Error while listing LLM models for provider {provider}: {e}")
-                return []
-
-        # Execute all provider model listing tasks concurrently
-        provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
-
-        # Flatten the results
+        # Build LLMConfig objects from cached data
+        # Cache providers to avoid N+1 queries
+        provider_cache: Dict[str, Provider] = {}
         llm_models = []
-        for models in provider_results:
-            llm_models.extend(models)
-
-        # Get local configs - if this is potentially slow, consider making it async too
-        local_configs = self.get_local_llm_configs()
-        llm_models.extend(local_configs)
-
-        # dedupe by handle for uniqueness
-        # Seems like this is required from the tests?
-        seen_handles = set()
-        unique_models = []
-        for model in llm_models:
-            if model.handle not in seen_handles:
-                seen_handles.add(model.handle)
-                unique_models.append(model)
-
-        return unique_models
+        for model in provider_models:
+            # Skip if filtering by provider and model doesn't match
+            if provider_ids and model.provider_id not in provider_ids:
+                continue
+
+            # Get provider details (with caching to avoid N+1 queries)
+            if model.provider_id not in provider_cache:
+                provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
+            provider = provider_cache[model.provider_id]
+
+            llm_config = LLMConfig(
+                model=model.name,
+                model_endpoint_type=model.model_endpoint_type,
+                model_endpoint=provider.base_url or model.model_endpoint_type,
+                context_window=model.max_context_window or 16384,
+                handle=model.handle,
+                provider_name=provider.name,
+                provider_category=provider.provider_category,
+            )
+            llm_models.append(llm_config)
+
+        return llm_models
 
     async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
-        """Asynchronously list available embedding models with maximum concurrency"""
-        import asyncio
-
-        # Get all eligible providers first
-        providers = await self.get_enabled_providers_async(actor=actor)
-
-        # Fetch embedding models from each provider concurrently
-        async def get_provider_embedding_models(provider):
-            try:
-                # All providers now have list_embedding_models_async
-                return await provider.list_embedding_models_async()
-            except Exception as e:
-                logger.exception(f"An error occurred while listing embedding models for provider {provider}: {e}")
-                return []
-
-        # Execute all provider model listing tasks concurrently
-        provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
-
-        # Flatten the results
+        """List available embedding models from database cache"""
+        # Get models from database
+        provider_models = await self.provider_manager.list_models_async(
+            actor=actor,
+            model_type="embedding",
+            enabled=True,
+        )
+
+        # Build EmbeddingConfig objects from cached data
+        # Cache providers to avoid N+1 queries
+        provider_cache: Dict[str, Provider] = {}
         embedding_models = []
-        for models in provider_results:
-            embedding_models.extend(models)
+        for model in provider_models:
+            # Get provider details (with caching to avoid N+1 queries)
+            if model.provider_id not in provider_cache:
+                provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
+            provider = provider_cache[model.provider_id]
+
+            embedding_config = EmbeddingConfig(
+                embedding_model=model.name,
+                embedding_endpoint_type=model.model_endpoint_type,
+                embedding_endpoint=provider.base_url or model.model_endpoint_type,
+                embedding_dim=model.embedding_dim or 1536,  # Use model's dimension or default
+                embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
+                handle=model.handle,
+            )
+            embedding_models.append(embedding_config)
 
         return embedding_models
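With the catalog in place, listing models is a database read plus config assembly. A hypothetical call against a running server (`server` and `actor` are stand-ins, not defined in this diff; signatures match the code above):

# Hypothetical usage; `server` is a running SyncServer and `actor` a User.
llm_configs = await server.list_llm_models_async(
    actor=actor,
    provider_category=[ProviderCategory.base],  # env-configured providers only
    provider_name="openai",                     # optional name filter
)
embedding_configs = await server.list_embedding_models_async(actor=actor)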
@@ -1140,17 +1181,22 @@ class SyncServer(object):
     ) -> List[Provider]:
         providers = []
         if not provider_category or ProviderCategory.base in provider_category:
-            providers_from_env = [p for p in self._enabled_providers]
-            providers.extend(providers_from_env)
+            # Add enabled providers (base providers from environment)
+            enabled_providers = [p for p in self._enabled_providers]
+            providers.extend(enabled_providers)
         if not provider_category or ProviderCategory.byok in provider_category:
-            providers_from_db = await self.provider_manager.list_providers_async(
+            # Add persisted BYOK providers from database
+            # Note: list_providers_async returns both org-specific and global providers,
+            # so we filter to only include BYOK providers to avoid duplicating base providers
+            persisted_providers = await self.provider_manager.list_providers_async(
                 name=provider_name,
                 provider_type=provider_type,
                 actor=actor,
             )
-            providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
-            providers.extend(providers_from_db)
+            # Filter to only BYOK providers (base providers are already in self._enabled_providers)
+            persisted_byok_providers = [p.cast_to_subtype() for p in persisted_providers if p.provider_category == ProviderCategory.byok]
+            providers.extend(persisted_byok_providers)
 
         if provider_name is not None:
             providers = [p for p in providers if p.name == provider_name]
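The merge rule above (environment providers plus database rows, with persisted base providers filtered out so they are not double-counted) reduces to a toy example like this, where tuples stand in for Provider objects:

# Toy illustration of the provider merge; tuples stand in for Provider objects.
enabled = [("letta", "base"), ("openai", "base")]      # from environment
persisted = [("openai", "base"), ("my-vllm", "byok")]  # from database
merged = enabled + [p for p in persisted if p[1] == "byok"]
assert merged == [("letta", "base"), ("openai", "base"), ("my-vllm", "byok")]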
@@ -1170,32 +1216,19 @@ class SyncServer(object):
         max_reasoning_tokens: Optional[int] = None,
         enable_reasoner: Optional[bool] = None,
     ) -> LLMConfig:
-        try:
-            provider_name, model_name = handle.split("/", 1)
-            provider = await self.get_provider_from_name_async(provider_name, actor)
-
-            all_llm_configs = await provider.list_llm_models_async()
-            llm_configs = [config for config in all_llm_configs if config.handle == handle]
-            if not llm_configs:
-                llm_configs = [config for config in all_llm_configs if config.model == model_name]
-            if not llm_configs:
-                available_handles = [config.handle for config in all_llm_configs]
-                raise HandleNotFoundError(handle, available_handles)
-        except ValueError as e:
-            llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
-            if not llm_configs:
-                llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name]
-            if not llm_configs:
-                raise e
-
-        if len(llm_configs) == 1:
-            llm_config = llm_configs[0]
-        elif len(llm_configs) > 1:
-            raise LettaInvalidArgumentError(
-                f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name"
-            )
-        else:
-            llm_config = llm_configs[0]
+        # Use provider_manager to get LLMConfig from handle
+        try:
+            llm_config = await self.provider_manager.get_llm_config_from_handle(
+                handle=handle,
+                actor=actor,
+            )
+        except Exception as e:
+            # Convert to HandleNotFoundError for backwards compatibility
+            from letta.orm.errors import NoResultFound
+
+            if isinstance(e, NoResultFound):
+                raise HandleNotFoundError(handle, [])
+            raise
 
         if context_window_limit is not None:
             if context_window_limit > llm_config.context_window:
@@ -1227,33 +1260,22 @@ class SyncServer(object):
     async def get_embedding_config_from_handle_async(
         self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
     ) -> EmbeddingConfig:
-        try:
-            provider_name, model_name = handle.split("/", 1)
-            provider = await self.get_provider_from_name_async(provider_name, actor)
-
-            all_embedding_configs = await provider.list_embedding_models_async()
-            embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
-            if not embedding_configs:
-                raise LettaInvalidArgumentError(
-                    f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name"
-                )
-        except LettaInvalidArgumentError as e:
-            # search local configs
-            embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
-            if not embedding_configs:
-                raise e
-
-        if len(embedding_configs) == 1:
-            embedding_config = embedding_configs[0]
-        elif len(embedding_configs) > 1:
-            raise LettaInvalidArgumentError(
-                f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name"
-            )
-        else:
-            embedding_config = embedding_configs[0]
-
-        if embedding_chunk_size:
-            embedding_config.embedding_chunk_size = embedding_chunk_size
+        # Use provider_manager to get EmbeddingConfig from handle
+        try:
+            embedding_config = await self.provider_manager.get_embedding_config_from_handle(
+                handle=handle,
+                actor=actor,
+            )
+        except Exception as e:
+            # Convert to LettaInvalidArgumentError for backwards compatibility
+            from letta.orm.errors import NoResultFound
+
+            if isinstance(e, NoResultFound):
+                raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle")
+            raise
+
+        # Override chunk size if provided
+        embedding_config.embedding_chunk_size = embedding_chunk_size
 
         return embedding_config
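Handle resolution is delegated the same way, with database misses translated back into the pre-existing exception types. A hypothetical lookup (`server` and `actor` assumed; handles keep the `provider_name/model_name` format seen in the old code above):

# Hypothetical usage; an unknown handle now raises HandleNotFoundError (LLM)
# or LettaInvalidArgumentError (embedding) instead of leaking NoResultFound.
llm_config = await server.get_llm_config_from_handle_async(actor=actor, handle="openai/gpt-4")
emb_config = await server.get_embedding_config_from_handle_async(
    actor=actor, handle="openai/text-embedding-3-small"
)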
@@ -1272,46 +1294,6 @@ class SyncServer(object):
         return provider
 
-    def get_local_llm_configs(self):
-        llm_models = []
-        # NOTE: deprecated
-        # try:
-        #     llm_configs_dir = os.path.expanduser("~/.letta/llm_configs")
-        #     if os.path.exists(llm_configs_dir):
-        #         for filename in os.listdir(llm_configs_dir):
-        #             if filename.endswith(".json"):
-        #                 filepath = os.path.join(llm_configs_dir, filename)
-        #                 try:
-        #                     with open(filepath, "r") as f:
-        #                         config_data = json.load(f)
-        #                     llm_config = LLMConfig(**config_data)
-        #                     llm_models.append(llm_config)
-        #                 except (json.JSONDecodeError, ValueError) as e:
-        #                     logger.warning(f"Error parsing LLM config file {filename}: {e}")
-        # except Exception as e:
-        #     logger.warning(f"Error reading LLM configs directory: {e}")
-        return llm_models
-
-    def get_local_embedding_configs(self):
-        embedding_models = []
-        # NOTE: deprecated
-        # try:
-        #     embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs")
-        #     if os.path.exists(embedding_configs_dir):
-        #         for filename in os.listdir(embedding_configs_dir):
-        #             if filename.endswith(".json"):
-        #                 filepath = os.path.join(embedding_configs_dir, filename)
-        #                 try:
-        #                     with open(filepath, "r") as f:
-        #                         config_data = json.load(f)
-        #                     embedding_config = EmbeddingConfig(**config_data)
-        #                     embedding_models.append(embedding_config)
-        #                 except (json.JSONDecodeError, ValueError) as e:
-        #                     logger.warning(f"Error parsing embedding config file {filename}: {e}")
-        # except Exception as e:
-        #     logger.warning(f"Error reading embedding configs directory: {e}")
-        return embedding_models
-
     def add_llm_model(self, request: LLMConfig) -> LLMConfig:
         """Add a new LLM model"""