Revert "Revert "feat: enable provider models persistence" (#6590)" (#6595)

This commit is contained in:
Ari Webb
2026-01-05 13:19:29 -08:00
committed by Caren Thomas
parent 64a1a8b14e
commit cc825b4f5c
8 changed files with 977 additions and 279 deletions

View File

@@ -235,24 +235,38 @@ class LLMClientBase:
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Returns the override key for the given llm config.
For both base and BYOK providers, fetch the API key from the database.
"""
api_key = None
if llm_config.provider_category == ProviderCategory.byok:
# Fetch API key from database for both base and BYOK providers
# This ensures that base providers (from environment) also have their keys persisted and accessible
if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
# If we got an empty string from the database (e.g., Letta provider), treat it as None
# so the client can fall back to environment variables or default behavior
if api_key == "":
api_key = None
return api_key, None, None
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Returns the override key for the given llm config.
For both base and BYOK providers, fetch the API key from the database.
"""
api_key = None
if llm_config.provider_category == ProviderCategory.byok:
# Fetch API key from database for both base and BYOK providers
# This ensures that base providers (from environment) also have their keys persisted and accessible
if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
from letta.services.provider_manager import ProviderManager
api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
# If we got an empty string from the database (e.g., Letta provider), treat it as None
# so the client can fall back to environment variables or default behavior
if api_key == "":
api_key = None
return api_key, None, None

View File

@@ -8,10 +8,13 @@ from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.base import Provider
LETTA_EMBEDDING_ENDPOINT = "https://embeddings.letta.com/"
class LettaProvider(Provider):
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = Field(LETTA_EMBEDDING_ENDPOINT, description="Base URL for the Letta API (used for embeddings).")
async def list_llm_models_async(self) -> list[LLMConfig]:
return [
@@ -32,7 +35,7 @@ class LettaProvider(Provider):
EmbeddingConfig(
embedding_model="letta-free", # NOTE: renamed
embedding_endpoint_type="openai",
embedding_endpoint="https://embeddings.letta.com/",
embedding_endpoint=self.base_url,
embedding_dim=1536,
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=self.get_handle("letta-free", is_embedding=True),

View File

@@ -116,7 +116,7 @@ from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task
from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task
config = LettaConfig.load()
logger = get_logger(__name__)
@@ -210,12 +210,10 @@ class SyncServer(object):
"""Initialize the MCP clients (there may be multiple)"""
self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {}
# TODO: Remove these in memory caches
self._llm_config_cache = {}
self._embedding_config_cache = {}
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
from letta.constants import LETTA_MODEL_ENDPOINT
self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)]
if model_settings.openai_api_key:
self._enabled_providers.append(
OpenAIProvider(
@@ -347,6 +345,12 @@ class SyncServer(object):
print(f"Default user: {self.default_user} and org: {self.default_org}")
await self.tool_manager.upsert_base_tools_async(actor=self.default_user)
# Sync environment-based providers to database (idempotent, safe for multi-pod startup)
await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user)
# Sync provider models to database
await self._sync_provider_models_async()
# For OSS users, create a local sandbox config
oss_default_user = await self.user_manager.get_default_actor_async()
use_venv = False if not tool_settings.tool_exec_venv_name else True
@@ -383,6 +387,65 @@ class SyncServer(object):
force_recreate=True,
)
def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]:
"""Find and return an enabled provider by name.
Args:
provider_name: The name of the provider to find
Returns:
The matching enabled provider, or None if not found
"""
for provider in self._enabled_providers:
if provider.name == provider_name:
return provider
return None
async def _sync_provider_models_async(self):
"""Sync all provider models to database at startup."""
logger.info("Syncing provider models to database")
# Get persisted providers from database (they now have IDs)
persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user)
for persisted_provider in persisted_providers:
try:
# Find the matching enabled provider instance to call list_models on
enabled_provider = self._get_enabled_provider(persisted_provider.name)
if not enabled_provider:
# Only delete base providers that are no longer enabled
# BYOK providers are user-created and should not be automatically deleted
if persisted_provider.provider_category == ProviderCategory.base:
logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database")
try:
await self.provider_manager.delete_provider_by_id_async(
provider_id=persisted_provider.id, actor=self.default_user
)
except NoResultFound:
# Provider was already deleted (race condition in multi-pod startup)
logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping")
else:
logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync")
continue
# Fetch models from provider
llm_models = await enabled_provider.list_llm_models_async()
embedding_models = await enabled_provider.list_embedding_models_async()
# Save to database with the persisted provider (which has an ID)
await self.provider_manager.sync_provider_models_async(
provider=persisted_provider,
llm_models=llm_models,
embedding_models=embedding_models,
organization_id=None, # Global models
)
logger.info(
f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
)
except Exception as e:
logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True)
async def init_mcp_clients(self):
# TODO: remove this
mcp_server_configs = self.get_mcp_servers()
@@ -410,39 +473,6 @@ class SyncServer(object):
logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
@trace_method
def get_cached_llm_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._llm_config_cache:
self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
return self._llm_config_cache[key]
@trace_method
async def get_cached_llm_config_async(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._llm_config_cache:
self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs)
logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
return self._llm_config_cache[key]
@trace_method
def get_cached_embedding_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._embedding_config_cache:
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
return self._embedding_config_cache[key]
# @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs))
@trace_method
async def get_cached_embedding_config_async(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._embedding_config_cache:
self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs)
logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
return self._embedding_config_cache[key]
@trace_method
async def create_agent_async(
self,
@@ -476,10 +506,9 @@ class SyncServer(object):
"max_reasoning_tokens": request.max_reasoning_tokens,
"enable_reasoner": request.enable_reasoner,
}
config_params.update(additional_config_params)
log_event(name="start get_cached_llm_config", attributes=config_params)
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
log_event(name="end get_cached_llm_config", attributes=config_params)
log_event(name="start get_llm_config_from_handle", attributes=config_params)
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
log_event(name="end get_llm_config_from_handle", attributes=config_params)
if request.model and isinstance(request.model, str):
assert request.llm_config.handle == request.model, (
f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
@@ -507,9 +536,9 @@ class SyncServer(object):
"handle": request.embedding,
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
}
log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params)
log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params)
request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params)
log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params)
log_event(name="start create_agent db")
main_agent = await self.agent_manager.create_agent_async(
@@ -558,9 +587,9 @@ class SyncServer(object):
"context_window_limit": request.context_window_limit,
"max_tokens": request.max_tokens,
}
log_event(name="start get_cached_llm_config", attributes=config_params)
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
log_event(name="end get_cached_llm_config", attributes=config_params)
log_event(name="start get_llm_config_from_handle", attributes=config_params)
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
log_event(name="end get_llm_config_from_handle", attributes=config_params)
# update with model_settings
if request.model_settings is not None:
@@ -1070,73 +1099,85 @@ class SyncServer(object):
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[LLMConfig]:
"""Asynchronously list available models with maximum concurrency"""
import asyncio
"""List available LLM models from database cache"""
# Get provider IDs if filtering by provider
provider_ids = None
if provider_name or provider_type:
providers = await self.get_enabled_providers_async(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
)
provider_ids = [p.id for p in providers]
providers = await self.get_enabled_providers_async(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
# If filtering was requested but no providers matched, return empty list
if not provider_ids:
return []
# Get models from database
provider_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="llm",
provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None,
enabled=True,
)
async def get_provider_models(provider: Provider) -> list[LLMConfig]:
try:
async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS):
return await provider.list_llm_models_async()
except asyncio.TimeoutError:
logger.warning(f"Timeout while listing LLM models for provider {provider}")
return []
except Exception as e:
logger.exception(f"Error while listing LLM models for provider {provider}: {e}")
return []
# Execute all provider model listing tasks concurrently
provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
# Flatten the results
# Build LLMConfig objects from cached data
# Cache providers to avoid N+1 queries
provider_cache: Dict[str, Provider] = {}
llm_models = []
for models in provider_results:
llm_models.extend(models)
for model in provider_models:
# Skip if filtering by provider and model doesn't match
if provider_ids and model.provider_id not in provider_ids:
continue
# Get local configs - if this is potentially slow, consider making it async too
local_configs = self.get_local_llm_configs()
llm_models.extend(local_configs)
# Get provider details (with caching to avoid N+1 queries)
if model.provider_id not in provider_cache:
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
provider = provider_cache[model.provider_id]
# dedupe by handle for uniqueness
# Seems like this is required from the tests?
seen_handles = set()
unique_models = []
for model in llm_models:
if model.handle not in seen_handles:
seen_handles.add(model.handle)
unique_models.append(model)
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url or model.model_endpoint_type,
context_window=model.max_context_window or 16384,
handle=model.handle,
provider_name=provider.name,
provider_category=provider.provider_category,
)
llm_models.append(llm_config)
return unique_models
return llm_models
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
"""Asynchronously list available embedding models with maximum concurrency"""
import asyncio
"""List available embedding models from database cache"""
# Get models from database
provider_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="embedding",
enabled=True,
)
# Get all eligible providers first
providers = await self.get_enabled_providers_async(actor=actor)
# Fetch embedding models from each provider concurrently
async def get_provider_embedding_models(provider):
try:
# All providers now have list_embedding_models_async
return await provider.list_embedding_models_async()
except Exception as e:
logger.exception(f"An error occurred while listing embedding models for provider {provider}: {e}")
return []
# Execute all provider model listing tasks concurrently
provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
# Flatten the results
# Build EmbeddingConfig objects from cached data
# Cache providers to avoid N+1 queries
provider_cache: Dict[str, Provider] = {}
embedding_models = []
for models in provider_results:
embedding_models.extend(models)
for model in provider_models:
# Get provider details (with caching to avoid N+1 queries)
if model.provider_id not in provider_cache:
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
provider = provider_cache[model.provider_id]
embedding_config = EmbeddingConfig(
embedding_model=model.name,
embedding_endpoint_type=model.model_endpoint_type,
embedding_endpoint=provider.base_url or model.model_endpoint_type,
embedding_dim=model.embedding_dim or 1536, # Use model's dimension or default
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=model.handle,
)
embedding_models.append(embedding_config)
return embedding_models
@@ -1147,25 +1188,17 @@ class SyncServer(object):
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[Provider]:
providers = []
if not provider_category or ProviderCategory.base in provider_category:
providers_from_env = [p for p in self._enabled_providers]
providers.extend(providers_from_env)
# Query all persisted providers from database
persisted_providers = await self.provider_manager.list_providers_async(
name=provider_name,
provider_type=provider_type,
actor=actor,
)
providers = [p.cast_to_subtype() for p in persisted_providers]
if not provider_category or ProviderCategory.byok in provider_category:
providers_from_db = await self.provider_manager.list_providers_async(
name=provider_name,
provider_type=provider_type,
actor=actor,
)
providers_from_db = [p.cast_to_subtype() for p in providers_from_db if p.provider_category == ProviderCategory.byok]
providers.extend(providers_from_db)
if provider_name is not None:
providers = [p for p in providers if p.name == provider_name]
if provider_type is not None:
providers = [p for p in providers if p.provider_type == provider_type]
# Filter by category if specified
if provider_category:
providers = [p for p in providers if p.provider_category in provider_category]
return providers
@@ -1179,32 +1212,19 @@ class SyncServer(object):
max_reasoning_tokens: Optional[int] = None,
enable_reasoner: Optional[bool] = None,
) -> LLMConfig:
# Use provider_manager to get LLMConfig from handle
try:
provider_name, model_name = handle.split("/", 1)
provider = await self.get_provider_from_name_async(provider_name, actor)
all_llm_configs = await provider.list_llm_models_async()
llm_configs = [config for config in all_llm_configs if config.handle == handle]
if not llm_configs:
llm_configs = [config for config in all_llm_configs if config.model == model_name]
if not llm_configs:
available_handles = [config.handle for config in all_llm_configs]
raise HandleNotFoundError(handle, available_handles)
except ValueError as e:
llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
if not llm_configs:
llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name]
if not llm_configs:
raise e
if len(llm_configs) == 1:
llm_config = llm_configs[0]
elif len(llm_configs) > 1:
raise LettaInvalidArgumentError(
f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name"
llm_config = await self.provider_manager.get_llm_config_from_handle(
handle=handle,
actor=actor,
)
else:
llm_config = llm_configs[0]
except Exception as e:
# Convert to HandleNotFoundError for backwards compatibility
from letta.orm.errors import NoResultFound
if isinstance(e, NoResultFound):
raise HandleNotFoundError(handle, [])
raise
if context_window_limit is not None:
if context_window_limit > llm_config.context_window:
@@ -1236,33 +1256,22 @@ class SyncServer(object):
async def get_embedding_config_from_handle_async(
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
) -> EmbeddingConfig:
# Use provider_manager to get EmbeddingConfig from handle
try:
provider_name, model_name = handle.split("/", 1)
provider = await self.get_provider_from_name_async(provider_name, actor)
all_embedding_configs = await provider.list_embedding_models_async()
embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
if not embedding_configs:
raise LettaInvalidArgumentError(
f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name"
)
except LettaInvalidArgumentError as e:
# search local configs
embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
if not embedding_configs:
raise e
if len(embedding_configs) == 1:
embedding_config = embedding_configs[0]
elif len(embedding_configs) > 1:
raise LettaInvalidArgumentError(
f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name"
embedding_config = await self.provider_manager.get_embedding_config_from_handle(
handle=handle,
actor=actor,
)
else:
embedding_config = embedding_configs[0]
except Exception as e:
# Convert to LettaInvalidArgumentError for backwards compatibility
from letta.orm.errors import NoResultFound
if embedding_chunk_size:
embedding_config.embedding_chunk_size = embedding_chunk_size
if isinstance(e, NoResultFound):
raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle")
raise
# Override chunk size if provided
embedding_config.embedding_chunk_size = embedding_chunk_size
return embedding_config
@@ -1271,7 +1280,7 @@ class SyncServer(object):
providers = [provider for provider in all_providers if provider.name == provider_name]
if not providers:
raise LettaInvalidArgumentError(
f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in self._enabled_providers])})",
f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in all_providers])})",
argument_name="provider_name",
)
elif len(providers) > 1:
@@ -1282,46 +1291,6 @@ class SyncServer(object):
return provider
def get_local_llm_configs(self):
llm_models = []
# NOTE: deprecated
# try:
# llm_configs_dir = os.path.expanduser("~/.letta/llm_configs")
# if os.path.exists(llm_configs_dir):
# for filename in os.listdir(llm_configs_dir):
# if filename.endswith(".json"):
# filepath = os.path.join(llm_configs_dir, filename)
# try:
# with open(filepath, "r") as f:
# config_data = json.load(f)
# llm_config = LLMConfig(**config_data)
# llm_models.append(llm_config)
# except (json.JSONDecodeError, ValueError) as e:
# logger.warning(f"Error parsing LLM config file {filename}: {e}")
# except Exception as e:
# logger.warning(f"Error reading LLM configs directory: {e}")
return llm_models
def get_local_embedding_configs(self):
embedding_models = []
# NOTE: deprecated
# try:
# embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs")
# if os.path.exists(embedding_configs_dir):
# for filename in os.listdir(embedding_configs_dir):
# if filename.endswith(".json"):
# filepath = os.path.join(embedding_configs_dir, filename)
# try:
# with open(filepath, "r") as f:
# config_data = json.load(f)
# embedding_config = EmbeddingConfig(**config_data)
# embedding_models.append(embedding_config)
# except (json.JSONDecodeError, ValueError) as e:
# logger.warning(f"Error parsing embedding config file {filename}: {e}")
# except Exception as e:
# logger.warning(f"Error reading embedding configs directory: {e}")
return embedding_models
def add_llm_model(self, request: LLMConfig) -> LLMConfig:
"""Add a new LLM model"""

View File

@@ -201,7 +201,7 @@ class ProviderManager:
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@trace_method
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
"""Delete a provider."""
"""Delete a provider and its associated models."""
async with db_registry.async_session() as session:
# Clear api key field
existing_provider = await ProviderModel.read_async(
@@ -218,6 +218,15 @@ class ProviderManager:
await existing_provider.update_async(session, actor=actor)
# Soft delete all models associated with this provider
provider_models = await ProviderModelORM.list_async(
db_session=session,
provider_id=provider_id,
check_is_deleted=True,
)
for model in provider_models:
await model.delete_async(session, actor=actor)
# Soft delete in provider table
await existing_provider.delete_async(session, actor=actor)
@@ -701,11 +710,11 @@ class ProviderManager:
await model.create_async(session)
logger.info(f" ✓ Successfully created LLM model {llm_config.handle} with ID {model.id}")
except Exception as e:
logger.error(f" ✗ Failed to create LLM model {llm_config.handle}: {e}")
logger.info(f" ✗ Failed to create LLM model {llm_config.handle}: {e}")
# Log the full error details
import traceback
logger.error(f" Full traceback: {traceback.format_exc()}")
logger.info(f" Full traceback: {traceback.format_exc()}")
# Roll back the session to clear the failed transaction
await session.rollback()
else:
@@ -899,8 +908,12 @@ class ProviderManager:
if not model:
raise NoResultFound(f"LLM model not found with handle='{handle}'")
# Get the provider for this model
# Get the provider for this model and cast to subtype to access provider-specific methods
provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
typed_provider = provider.cast_to_subtype()
# Get the default max_output_tokens from the provider (provider-specific logic)
max_tokens = typed_provider.get_default_max_output_tokens(model.name)
# Construct the LLMConfig from the model and provider data
llm_config = LLMConfig(
@@ -911,6 +924,7 @@ class ProviderManager:
handle=model.handle,
provider_name=provider.name,
provider_category=provider.provider_category,
max_tokens=max_tokens,
)
return llm_config

View File

@@ -258,14 +258,14 @@ async def test_create_agent_with_model_handle_uses_correct_llm_config(server: Sy
"""When CreateAgent.model is provided, ensure the correct handle is used to resolve llm_config.
This verifies that the model handle passed by the client is forwarded into
SyncServer.get_cached_llm_config_async and that the resulting AgentState
SyncServer.get_llm_config_from_handle_async and that the resulting AgentState
carries an llm_config with the same handle.
"""
# Track the arguments used to resolve the LLM config
captured_kwargs: dict = {}
async def fake_get_cached_llm_config_async(self, actor, **kwargs): # type: ignore[override]
async def fake_get_llm_config_from_handle_async(self, actor, **kwargs): # type: ignore[override]
from letta.schemas.llm_config import LLMConfig as PydanticLLMConfig
captured_kwargs.update(kwargs)
@@ -282,8 +282,8 @@ async def test_create_agent_with_model_handle_uses_correct_llm_config(server: Sy
model_handle = "openai/gpt-4o-mini"
# Patch SyncServer.get_cached_llm_config_async so we don't depend on provider DB state
with patch.object(SyncServer, "get_cached_llm_config_async", new=fake_get_cached_llm_config_async):
# Patch SyncServer.get_llm_config_from_handle_async so we don't depend on provider DB state
with patch.object(SyncServer, "get_llm_config_from_handle_async", new=fake_get_llm_config_from_handle_async):
created_agent = await server.create_agent_async(
request=CreateAgent(
name="agent_with_model_handle",

View File

@@ -487,87 +487,435 @@ async def test_byok_provider_auto_syncs_models(provider_manager, default_user, m
# ======================================================================================================================
# No Encryption Key Tests
# Server Startup Provider Sync Tests
# ======================================================================================================================
@pytest.fixture
def no_encryption_key():
"""Fixture to ensure NO encryption key is set for tests."""
original_key = settings.encryption_key
settings.encryption_key = None
yield None
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_server_startup_syncs_base_providers(default_user, default_organization, monkeypatch):
"""Test that server startup properly syncs base provider models from environment.
This test simulates the server startup process and verifies that:
1. Base providers from environment variables are synced to database
2. Provider models are fetched from mocked API endpoints
3. Models are properly persisted to the database with correct metadata
4. Models can be retrieved using handles
"""
from unittest.mock import AsyncMock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import AnthropicProvider, OpenAIProvider
from letta.server.server import SyncServer
# Mock OpenAI API responses
mock_openai_models = {
"data": [
{
"id": "gpt-4",
"object": "model",
"created": 1687882411,
"owned_by": "openai",
"max_model_len": 8192,
},
{
"id": "gpt-4-turbo",
"object": "model",
"created": 1712361441,
"owned_by": "system",
"max_model_len": 128000,
},
{
"id": "text-embedding-ada-002",
"object": "model",
"created": 1671217299,
"owned_by": "openai-internal",
},
{
"id": "gpt-4-vision", # Should be filtered out by OpenAI provider logic (has disallowed keyword)
"object": "model",
"created": 1698959748,
"owned_by": "system",
"max_model_len": 8192,
},
]
}
# Mock Anthropic API responses
mock_anthropic_models = {
"data": [
{
"id": "claude-3-5-sonnet-20241022",
"type": "model",
"display_name": "Claude 3.5 Sonnet",
"created_at": "2024-10-22T00:00:00Z",
},
{
"id": "claude-3-opus-20240229",
"type": "model",
"display_name": "Claude 3 Opus",
"created_at": "2024-02-29T00:00:00Z",
},
]
}
# Mock the API calls for OpenAI
async def mock_openai_get_model_list_async(*args, **kwargs):
return mock_openai_models
# Mock Anthropic models.list() response
from unittest.mock import MagicMock
mock_anthropic_response = MagicMock()
mock_anthropic_response.model_dump.return_value = mock_anthropic_models
# Mock the Anthropic AsyncAnthropic client
class MockAnthropicModels:
async def list(self):
return mock_anthropic_response
class MockAsyncAnthropic:
def __init__(self, *args, **kwargs):
self.models = MockAnthropicModels()
# Patch the actual API calling functions
monkeypatch.setattr(
"letta.llm_api.openai.openai_get_model_list_async",
mock_openai_get_model_list_async,
)
monkeypatch.setattr(
"anthropic.AsyncAnthropic",
MockAsyncAnthropic,
)
# Clear ALL provider-related env vars first to ensure clean state
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
monkeypatch.delenv("AZURE_API_KEY", raising=False)
monkeypatch.delenv("GROQ_API_KEY", raising=False)
monkeypatch.delenv("TOGETHER_API_KEY", raising=False)
monkeypatch.delenv("VLLM_API_BASE", raising=False)
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
monkeypatch.delenv("LMSTUDIO_BASE_URL", raising=False)
monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
monkeypatch.delenv("XAI_API_KEY", raising=False)
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.delenv("ZAI_API_KEY", raising=False)
# Set environment variables to enable only OpenAI and Anthropic
monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key-12345")
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key-67890")
# Reload model_settings to pick up new env vars
from letta.settings import model_settings
monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key-12345")
monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key-67890")
monkeypatch.setattr(model_settings, "gemini_api_key", None)
monkeypatch.setattr(model_settings, "google_cloud_project", None)
monkeypatch.setattr(model_settings, "google_cloud_location", None)
monkeypatch.setattr(model_settings, "azure_api_key", None)
monkeypatch.setattr(model_settings, "groq_api_key", None)
monkeypatch.setattr(model_settings, "together_api_key", None)
monkeypatch.setattr(model_settings, "vllm_api_base", None)
monkeypatch.setattr(model_settings, "aws_access_key_id", None)
monkeypatch.setattr(model_settings, "aws_secret_access_key", None)
monkeypatch.setattr(model_settings, "lmstudio_base_url", None)
monkeypatch.setattr(model_settings, "deepseek_api_key", None)
monkeypatch.setattr(model_settings, "xai_api_key", None)
monkeypatch.setattr(model_settings, "openrouter_api_key", None)
monkeypatch.setattr(model_settings, "zai_api_key", None)
# Create server instance (this will load enabled providers from environment)
server = SyncServer(init_with_default_org_and_user=False)
# Manually set up the default user/org (since we disabled auto-init)
server.default_user = default_user
server.default_org = default_organization
# Verify enabled providers were loaded
assert len(server._enabled_providers) == 3 # Exactly: letta, openai, anthropic
enabled_provider_names = [p.name for p in server._enabled_providers]
assert "letta" in enabled_provider_names
assert "openai" in enabled_provider_names
assert "anthropic" in enabled_provider_names
# First, sync base providers to database (this is what init_async does)
await server.provider_manager.sync_base_providers(
base_providers=server._enabled_providers,
actor=default_user,
)
# Now call the actual _sync_provider_models_async method
# This simulates what happens during server startup
await server._sync_provider_models_async()
# Verify OpenAI models were synced
openai_providers = await server.provider_manager.list_providers_async(
name="openai",
actor=default_user,
)
assert len(openai_providers) == 1, "OpenAI provider should exist"
openai_provider = openai_providers[0]
# Check OpenAI LLM models
openai_llm_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=openai_provider.id,
model_type="llm",
)
# Should have gpt-4 and gpt-4-turbo (gpt-4-vision filtered out due to "vision" keyword)
assert len(openai_llm_models) >= 2, f"Expected at least 2 OpenAI LLM models, got {len(openai_llm_models)}"
openai_model_names = [m.name for m in openai_llm_models]
assert "gpt-4" in openai_model_names
assert "gpt-4-turbo" in openai_model_names
# Check OpenAI embedding models
openai_embedding_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=openai_provider.id,
model_type="embedding",
)
assert len(openai_embedding_models) >= 1, "Expected at least 1 OpenAI embedding model"
embedding_model_names = [m.name for m in openai_embedding_models]
assert "text-embedding-ada-002" in embedding_model_names
# Verify model metadata is correct
gpt4_models = [m for m in openai_llm_models if m.name == "gpt-4"]
assert len(gpt4_models) > 0, "gpt-4 model should exist"
gpt4_model = gpt4_models[0]
assert gpt4_model.handle == "openai/gpt-4"
assert gpt4_model.model_endpoint_type == "openai"
assert gpt4_model.max_context_window == 8192
assert gpt4_model.enabled is True
# Verify Anthropic models were synced
anthropic_providers = await server.provider_manager.list_providers_async(
name="anthropic",
actor=default_user,
)
assert len(anthropic_providers) == 1, "Anthropic provider should exist"
anthropic_provider = anthropic_providers[0]
anthropic_llm_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=anthropic_provider.id,
model_type="llm",
)
# Should have Claude models
assert len(anthropic_llm_models) >= 2, f"Expected at least 2 Anthropic models, got {len(anthropic_llm_models)}"
anthropic_model_names = [m.name for m in anthropic_llm_models]
assert "claude-3-5-sonnet-20241022" in anthropic_model_names
assert "claude-3-opus-20240229" in anthropic_model_names
# Test that we can retrieve LLMConfig from handle
llm_config = await server.provider_manager.get_llm_config_from_handle(
handle="openai/gpt-4",
actor=default_user,
)
assert llm_config.model == "gpt-4"
assert llm_config.handle == "openai/gpt-4"
assert llm_config.provider_name == "openai"
assert llm_config.context_window == 8192
# Test that we can retrieve EmbeddingConfig from handle
embedding_config = await server.provider_manager.get_embedding_config_from_handle(
handle="openai/text-embedding-ada-002",
actor=default_user,
)
assert embedding_config.embedding_model == "text-embedding-ada-002"
assert embedding_config.handle == "openai/text-embedding-ada-002"
assert embedding_config.embedding_dim == 1536
@pytest.mark.asyncio
async def test_provider_works_without_encryption_key(provider_manager, default_user, no_encryption_key):
"""Test that providers can be created and read when no encryption key is configured.
async def test_server_startup_handles_disabled_providers(default_user, default_organization, monkeypatch):
"""Test that server startup properly handles providers that are no longer enabled.
When LETTA_ENCRYPTION_KEY is not set, the Secret class should store values as
plaintext in the _enc column and successfully retrieve them.
This test verifies that:
1. Base providers that are no longer enabled (env vars removed) are deleted
2. BYOK providers that are no longer enabled are NOT deleted (user-created)
3. The sync process handles providers gracefully when API calls fail
"""
# Create a provider without encryption key configured
provider_create = ProviderCreate(
name="test-no-encryption-provider",
from letta.schemas.providers import OpenAIProvider, ProviderCreate
from letta.server.server import SyncServer
# First, manually create providers in the database
provider_manager = ProviderManager()
# Create a base OpenAI provider (simulating it was synced before)
base_openai_create = ProviderCreate(
name="openai",
provider_type=ProviderType.openai,
api_key="sk-plaintext-key-12345",
api_key="sk-old-key",
base_url="https://api.openai.com/v1",
)
base_openai = await provider_manager.create_provider_async(
base_openai_create,
actor=default_user,
is_byok=False, # This is a base provider
)
# Create provider - should work even without encryption
created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
# Create a BYOK provider (user-created)
byok_provider_create = ProviderCreate(
name="my-custom-openai",
provider_type=ProviderType.openai,
api_key="sk-my-key",
base_url="https://api.openai.com/v1",
)
byok_provider = await provider_manager.create_provider_async(
byok_provider_create,
actor=default_user,
is_byok=True,
)
assert byok_provider.provider_category == ProviderCategory.byok
# Verify provider was created
assert created_provider is not None
assert created_provider.name == "test-no-encryption-provider"
# Now create server with NO environment variables set (all base providers disabled)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
# Verify api_key can be retrieved (stored as plaintext in _enc column)
assert created_provider.api_key_enc.get_plaintext() == "sk-plaintext-key-12345"
from letta.settings import model_settings
# Read the provider back from database
retrieved_provider = await provider_manager.get_provider_async(created_provider.id, actor=default_user)
monkeypatch.setattr(model_settings, "openai_api_key", None)
monkeypatch.setattr(model_settings, "anthropic_api_key", None)
# Verify round-trip works
assert retrieved_provider.api_key_enc.get_plaintext() == "sk-plaintext-key-12345"
# Create server instance
server = SyncServer(init_with_default_org_and_user=False)
server.default_user = default_user
server.default_org = default_organization
# Verify the value in _enc column is actually plaintext (not encrypted)
async with db_registry.async_session() as session:
provider_orm = await ProviderModel.read_async(
db_session=session,
identifier=created_provider.id,
actor=default_user,
)
# Verify only letta provider is enabled (no openai)
enabled_names = [p.name for p in server._enabled_providers]
assert "letta" in enabled_names
assert "openai" not in enabled_names
# The value should be stored as plaintext since no encryption key was available
assert provider_orm.api_key_enc is not None
# When no encryption key is set, the plaintext is stored directly
# so from_encrypted + get_plaintext should return the original value
assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == "sk-plaintext-key-12345"
# Sync base providers (should not include openai anymore)
await server.provider_manager.sync_base_providers(
base_providers=server._enabled_providers,
actor=default_user,
)
# Call _sync_provider_models_async
await server._sync_provider_models_async()
# Verify base OpenAI provider was deleted (no longer enabled)
try:
await server.provider_manager.get_provider_async(base_openai.id, actor=default_user)
assert False, "Base OpenAI provider should have been deleted"
except Exception:
# Expected - provider should not exist
pass
# Verify BYOK provider still exists (should NOT be deleted)
byok_still_exists = await server.provider_manager.get_provider_async(
byok_provider.id,
actor=default_user,
)
assert byok_still_exists is not None
assert byok_still_exists.name == "my-custom-openai"
assert byok_still_exists.provider_category == ProviderCategory.byok
@pytest.mark.asyncio
async def test_provider_update_works_without_encryption_key(provider_manager, default_user, no_encryption_key):
"""Test that provider updates work when no encryption key is configured."""
# Create initial provider
provider_create = ProviderCreate(
name="test-no-enc-update-provider",
provider_type=ProviderType.anthropic,
api_key="sk-ant-initial-key",
async def test_server_startup_handles_api_errors_gracefully(default_user, default_organization, monkeypatch):
"""Test that server startup handles API errors gracefully without crashing.
This test verifies that:
1. If a provider's API call fails during sync, it logs an error but continues
2. Other providers can still sync successfully
3. The server startup completes without crashing
"""
from letta.schemas.providers import AnthropicProvider, OpenAIProvider
from letta.server.server import SyncServer
# Mock OpenAI to fail
async def mock_openai_fail(*args, **kwargs):
raise Exception("OpenAI API is down")
# Mock Anthropic to succeed
from unittest.mock import MagicMock
mock_anthropic_response = MagicMock()
mock_anthropic_response.model_dump.return_value = {
"data": [
{
"id": "claude-3-5-sonnet-20241022",
"type": "model",
"display_name": "Claude 3.5 Sonnet",
"created_at": "2024-10-22T00:00:00Z",
}
]
}
class MockAnthropicModels:
async def list(self):
return mock_anthropic_response
class MockAsyncAnthropic:
def __init__(self, *args, **kwargs):
self.models = MockAnthropicModels()
monkeypatch.setattr(
"letta.llm_api.openai.openai_get_model_list_async",
mock_openai_fail,
)
monkeypatch.setattr(
"anthropic.AsyncAnthropic",
MockAsyncAnthropic,
)
created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
# Set environment variables
monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key")
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key")
# Update the api_key
provider_update = ProviderUpdate(
api_key="sk-ant-updated-key",
from letta.settings import model_settings
monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key")
monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key")
# Create server
server = SyncServer(init_with_default_org_and_user=False)
server.default_user = default_user
server.default_org = default_organization
# Sync base providers
await server.provider_manager.sync_base_providers(
base_providers=server._enabled_providers,
actor=default_user,
)
updated_provider = await provider_manager.update_provider_async(created_provider.id, provider_update, actor=default_user)
# This should NOT crash even though OpenAI fails
await server._sync_provider_models_async()
# Verify the updated key is accessible
assert updated_provider.api_key_enc.get_plaintext() == "sk-ant-updated-key"
# Verify Anthropic still synced successfully
anthropic_providers = await server.provider_manager.list_providers_async(
name="anthropic",
actor=default_user,
)
assert len(anthropic_providers) == 1
# Verify via database read
retrieved_provider = await provider_manager.get_provider_async(created_provider.id, actor=default_user)
assert retrieved_provider.api_key_enc.get_plaintext() == "sk-ant-updated-key"
anthropic_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=anthropic_providers[0].id,
model_type="llm",
)
assert len(anthropic_models) >= 1, "Anthropic models should have synced despite OpenAI failure"
# OpenAI should have no models (sync failed)
openai_providers = await server.provider_manager.list_providers_async(
name="openai",
actor=default_user,
)
if len(openai_providers) > 0:
openai_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=openai_providers[0].id,
)
# Models might exist from previous runs, but the sync attempt should have been logged as failed
# The key is that the server didn't crash

View File

@@ -32,6 +32,7 @@ from letta.config import LettaConfig
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.server.server import SyncServer
from tests.helpers.utils import upload_file_and_wait
from tests.utils import wait_for_server
# Constants
SERVER_PORT = 8283
@@ -106,7 +107,7 @@ def client() -> LettaSDKClient:
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
wait_for_server(server_url, timeout=60)
print("Running client tests with server:", server_url)
client = LettaSDKClient(base_url=server_url)

View File

@@ -1740,3 +1740,352 @@ async def test_handle_uniqueness_per_org(default_user, provider_manager):
assert model is not None
assert model.provider_id == provider_1.id # Still original provider
assert model.max_context_window == 8192 # Still original
@pytest.mark.asyncio
async def test_delete_provider_cascades_to_models(default_user, provider_manager, monkeypatch):
"""Test that deleting a provider also soft-deletes its associated models."""
test_id = generate_test_id()
# Mock _sync_default_models_for_provider to avoid external API calls
async def mock_sync(provider, actor):
pass # Don't actually sync - we'll manually create models below
monkeypatch.setattr(provider_manager, "_sync_default_models_for_provider", mock_sync)
# 1. Create a BYOK provider (org-scoped, so the actor can delete it)
provider_create = ProviderCreate(
name=f"test-cascade-{test_id}",
provider_type=ProviderType.openai,
api_key="sk-test-key",
)
provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=True)
# 2. Manually sync models to the provider
llm_models = [
LLMConfig(
model=f"gpt-4o-{test_id}",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
context_window=128000,
handle=f"test-{test_id}/gpt-4o",
provider_name=provider.name,
provider_category=ProviderCategory.byok,
),
LLMConfig(
model=f"gpt-4o-mini-{test_id}",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
context_window=16384,
handle=f"test-{test_id}/gpt-4o-mini",
provider_name=provider.name,
provider_category=ProviderCategory.byok,
),
]
embedding_models = [
EmbeddingConfig(
embedding_model=f"text-embedding-3-small-{test_id}",
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
embedding_chunk_size=300,
handle=f"test-{test_id}/text-embedding-3-small",
),
]
await provider_manager.sync_provider_models_async(
provider=provider,
llm_models=llm_models,
embedding_models=embedding_models,
organization_id=default_user.organization_id, # Org-scoped for BYOK provider
)
# 3. Verify models exist before deletion
llm_models_before = await provider_manager.list_models_async(
actor=default_user,
model_type="llm",
provider_id=provider.id,
)
embedding_models_before = await provider_manager.list_models_async(
actor=default_user,
model_type="embedding",
provider_id=provider.id,
)
llm_handles_before = {m.handle for m in llm_models_before}
embedding_handles_before = {m.handle for m in embedding_models_before}
assert f"test-{test_id}/gpt-4o" in llm_handles_before
assert f"test-{test_id}/gpt-4o-mini" in llm_handles_before
assert f"test-{test_id}/text-embedding-3-small" in embedding_handles_before
# 4. Delete the provider
await provider_manager.delete_provider_by_id_async(provider.id, actor=default_user)
# 5. Verify models are soft-deleted (no longer returned in list)
all_llm_models_after = await provider_manager.list_models_async(
actor=default_user,
model_type="llm",
)
all_embedding_models_after = await provider_manager.list_models_async(
actor=default_user,
model_type="embedding",
)
all_llm_handles_after = {m.handle for m in all_llm_models_after}
all_embedding_handles_after = {m.handle for m in all_embedding_models_after}
# All models from the deleted provider should be gone
assert f"test-{test_id}/gpt-4o" not in all_llm_handles_after
assert f"test-{test_id}/gpt-4o-mini" not in all_llm_handles_after
assert f"test-{test_id}/text-embedding-3-small" not in all_embedding_handles_after
# 6. Verify provider is also deleted
providers_after = await provider_manager.list_providers_async(
actor=default_user,
name=f"test-cascade-{test_id}",
)
assert len(providers_after) == 0
@pytest.mark.asyncio
async def test_get_llm_config_from_handle_includes_max_tokens(default_user, provider_manager):
"""Test that get_llm_config_from_handle includes max_tokens from provider's get_default_max_output_tokens.
This test verifies that:
1. The max_tokens field is populated when retrieving LLMConfig from a handle
2. The max_tokens value comes from the provider's get_default_max_output_tokens method
3. Different providers return different default max_tokens values (e.g., OpenAI returns 16384)
"""
test_id = generate_test_id()
# Create an OpenAI provider
provider_create = ProviderCreate(
name=f"test-openai-{test_id}",
provider_type=ProviderType.openai,
api_key="sk-test-key",
base_url="https://api.openai.com/v1",
)
provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False)
# Sync a model with the provider
llm_models = [
LLMConfig(
model=f"gpt-4o-{test_id}",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
context_window=128000,
handle=f"test-{test_id}/gpt-4o",
provider_name=provider.name,
provider_category=ProviderCategory.base,
),
]
await provider_manager.sync_provider_models_async(
provider=provider,
llm_models=llm_models,
embedding_models=[],
organization_id=None, # Global model
)
# Get LLMConfig from handle
llm_config = await provider_manager.get_llm_config_from_handle(
handle=f"test-{test_id}/gpt-4o",
actor=default_user,
)
# Verify max_tokens is set and comes from OpenAI provider's default (16384 for non-o1/o3 models)
assert llm_config.max_tokens is not None, "max_tokens should be set"
assert llm_config.max_tokens == 16384, f"Expected max_tokens=16384 for OpenAI gpt-4o, got {llm_config.max_tokens}"
# Test with a gpt-5 model (should have 16384)
llm_models_gpt5 = [
LLMConfig(
model=f"gpt-5-{test_id}",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
context_window=200000,
handle=f"test-{test_id}/gpt-5",
provider_name=provider.name,
provider_category=ProviderCategory.base,
),
]
await provider_manager.sync_provider_models_async(
provider=provider,
llm_models=llm_models_gpt5,
embedding_models=[],
organization_id=None,
)
llm_config_gpt5 = await provider_manager.get_llm_config_from_handle(
handle=f"test-{test_id}/gpt-5",
actor=default_user,
)
# gpt-5 models also have 16384 max_tokens
assert llm_config_gpt5.max_tokens == 16384, f"Expected max_tokens=16384 for gpt-5, got {llm_config_gpt5.max_tokens}"
@pytest.mark.asyncio
async def test_server_list_llm_models_async_reads_from_database(default_user, provider_manager):
"""Test that the server's list_llm_models_async reads models from database, not in-memory.
This test verifies that:
1. Models synced to the database are returned by list_llm_models_async
2. The LLMConfig objects are correctly constructed from database-cached models
3. Provider filtering works correctly when reading from database
"""
from letta.server.server import SyncServer
test_id = generate_test_id()
# Create a provider in the database
provider_create = ProviderCreate(
name=f"test-db-provider-{test_id}",
provider_type=ProviderType.openai,
api_key="sk-test-key",
base_url="https://custom.openai.com/v1",
)
provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False)
# Sync models to database
llm_models = [
LLMConfig(
model=f"custom-model-1-{test_id}",
model_endpoint_type="openai",
model_endpoint="https://custom.openai.com/v1",
context_window=32000,
handle=f"test-{test_id}/custom-model-1",
provider_name=provider.name,
provider_category=ProviderCategory.base,
),
LLMConfig(
model=f"custom-model-2-{test_id}",
model_endpoint_type="openai",
model_endpoint="https://custom.openai.com/v1",
context_window=64000,
handle=f"test-{test_id}/custom-model-2",
provider_name=provider.name,
provider_category=ProviderCategory.base,
),
]
await provider_manager.sync_provider_models_async(
provider=provider,
llm_models=llm_models,
embedding_models=[],
organization_id=None,
)
# Create server instance
server = SyncServer(init_with_default_org_and_user=False)
server.default_user = default_user
server.provider_manager = provider_manager
# List LLM models via server
models = await server.list_llm_models_async(
actor=default_user,
provider_name=f"test-db-provider-{test_id}",
)
# Verify models were read from database
handles = {m.handle for m in models}
assert f"test-{test_id}/custom-model-1" in handles, "custom-model-1 should be in database"
assert f"test-{test_id}/custom-model-2" in handles, "custom-model-2 should be in database"
# Verify LLMConfig properties are correctly populated from database
model_1 = next(m for m in models if m.handle == f"test-{test_id}/custom-model-1")
assert model_1.model == f"custom-model-1-{test_id}"
assert model_1.context_window == 32000
assert model_1.model_endpoint == "https://custom.openai.com/v1"
assert model_1.provider_name == f"test-db-provider-{test_id}"
model_2 = next(m for m in models if m.handle == f"test-{test_id}/custom-model-2")
assert model_2.model == f"custom-model-2-{test_id}"
assert model_2.context_window == 64000
@pytest.mark.asyncio
async def test_get_enabled_providers_async_queries_database(default_user, provider_manager):
"""Test that get_enabled_providers_async queries providers from database, not in-memory list.
This test verifies that:
1. Providers created in the database are returned by get_enabled_providers_async
2. The method queries the database, not an in-memory _enabled_providers list
3. Provider filtering by category works correctly from database
"""
from letta.server.server import SyncServer
test_id = generate_test_id()
# Create providers in the database
base_provider_create = ProviderCreate(
name=f"test-base-provider-{test_id}",
provider_type=ProviderType.openai,
api_key="sk-test-key",
base_url="https://api.openai.com/v1",
)
base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
byok_provider_create = ProviderCreate(
name=f"test-byok-provider-{test_id}",
provider_type=ProviderType.anthropic,
api_key="sk-test-byok-key",
)
byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
# Create server instance - importantly, don't set _enabled_providers
# This ensures we're testing database queries, not in-memory list
server = SyncServer(init_with_default_org_and_user=False)
server.default_user = default_user
server.provider_manager = provider_manager
# Clear in-memory providers to prove we're querying database
server._enabled_providers = []
# Get all providers - should query database
all_providers = await server.get_enabled_providers_async(actor=default_user)
provider_names = [p.name for p in all_providers]
assert f"test-base-provider-{test_id}" in provider_names, "Base provider should be in database"
assert f"test-byok-provider-{test_id}" in provider_names, "BYOK provider should be in database"
# Filter by provider category
base_only = await server.get_enabled_providers_async(
actor=default_user,
provider_category=[ProviderCategory.base],
)
base_only_names = [p.name for p in base_only]
assert f"test-base-provider-{test_id}" in base_only_names, "Base provider should be in base-only list"
assert f"test-byok-provider-{test_id}" not in base_only_names, "BYOK provider should NOT be in base-only list"
byok_only = await server.get_enabled_providers_async(
actor=default_user,
provider_category=[ProviderCategory.byok],
)
byok_only_names = [p.name for p in byok_only]
assert f"test-byok-provider-{test_id}" in byok_only_names, "BYOK provider should be in byok-only list"
assert f"test-base-provider-{test_id}" not in byok_only_names, "Base provider should NOT be in byok-only list"
# Filter by provider name
specific_provider = await server.get_enabled_providers_async(
actor=default_user,
provider_name=f"test-base-provider-{test_id}",
)
assert len(specific_provider) == 1
assert specific_provider[0].name == f"test-base-provider-{test_id}"
assert specific_provider[0].provider_type == ProviderType.openai
# Filter by provider type
openai_providers = await server.get_enabled_providers_async(
actor=default_user,
provider_type=ProviderType.openai,
)
openai_names = [p.name for p in openai_providers]
assert f"test-base-provider-{test_id}" in openai_names
assert f"test-byok-provider-{test_id}" not in openai_names # This is anthropic type