Revert "feat: enable provider models persistence" (#6590)

Revert "feat: enable provider models persistence (#6193)"

This reverts commit 9682aff32640a6ee8cf71a6f18c9fa7cda25c40e.
Sarah Wooders
2025-12-09 16:46:26 -08:00
committed by Caren Thomas
parent bbd52e291c
commit 8440e319e2
8 changed files with 205 additions and 754 deletions

View File

@@ -220,38 +220,24 @@ class LLMClientBase:
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Returns the override key for the given llm config.
For both base and BYOK providers, fetch the API key from the database.
"""
api_key = None
# Fetch API key from database for both base and BYOK providers
# This ensures that base providers (from environment) also have their keys persisted and accessible
if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
# If we got an empty string from the database (e.g., Letta provider), treat it as None
# so the client can fall back to environment variables or default behavior
if api_key == "":
api_key = None
return api_key, None, None
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Returns the override key for the given llm config.
For both base and BYOK providers, fetch the API key from the database.
"""
api_key = None
# Fetch API key from database for both base and BYOK providers
# This ensures that base providers (from environment) also have their keys persisted and accessible
if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
# If we got an empty string from the database (e.g., Letta provider), treat it as None
# so the client can fall back to environment variables or default behavior
if api_key == "":
api_key = None
return api_key, None, None
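
As restored by the revert, only BYOK providers trigger a database lookup; base providers return None so callers fall back to environment variables. A minimal standalone sketch of that flow (lookup is a stand-in for ProviderManager().get_override_key, not the actual method):

from typing import Callable, Optional, Tuple

def resolve_override_key(
    provider_category: str,
    provider_name: str,
    lookup: Callable[[str], Optional[str]],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    # Only BYOK providers carry a user-supplied key in the database;
    # base providers return None so the client falls back to env vars.
    api_key = None
    if provider_category == "byok":
        api_key = lookup(provider_name)
    return api_key, None, None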

View File

@@ -8,13 +8,10 @@ from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.base import Provider
LETTA_EMBEDDING_ENDPOINT = "https://embeddings.letta.com/"
class LettaProvider(Provider):
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
base_url: str = Field(LETTA_EMBEDDING_ENDPOINT, description="Base URL for the Letta API (used for embeddings).")
async def list_llm_models_async(self) -> list[LLMConfig]:
return [
@@ -34,7 +31,7 @@ class LettaProvider(Provider):
EmbeddingConfig(
embedding_model="letta-free", # NOTE: renamed
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_endpoint="https://embeddings.letta.com/",
embedding_dim=1536,
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=self.get_handle("letta-free", is_embedding=True),
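
With self.base_url gone, the restored letta-free embedding config is fully literal. Assembled with concrete values it is equivalent to the following sketch (the 300 chunk size and the "letta/letta-free" handle are assumptions, inferred from the constants and the "<provider>/<model>" handle format used elsewhere in this diff):

from letta.schemas.embedding_config import EmbeddingConfig

letta_free_embedding = EmbeddingConfig(
    embedding_model="letta-free",
    embedding_endpoint_type="openai",
    embedding_endpoint="https://embeddings.letta.com/",
    embedding_dim=1536,
    embedding_chunk_size=300,  # assumed value of DEFAULT_EMBEDDING_CHUNK_SIZE
    handle="letta/letta-free",  # assumed "<provider>/<model>" handle format
)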

View File

@@ -113,7 +113,7 @@ class Secret(BaseModel):
"MIGRATION_NEEDED: Reading from plaintext column instead of encrypted column. "
"This indicates data that hasn't been migrated to the _enc column yet. "
"Please run migrate data to _enc columns as plaintext columns will be deprecated.",
# stack_info=True,
stack_info=True,
)
return cls.from_plaintext(plaintext_value)
return cls.from_plaintext(None)
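
The revert re-enables stack_info=True on the migration warning, so the standard-library logger appends the current call stack and the offending plaintext read can be traced to its caller. A self-contained illustration of the flag's effect (logger name is arbitrary):

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("letta.schemas.secret")

def read_legacy_value():
    # stack_info=True appends "Stack (most recent call last): ..." to the
    # record, pointing at whichever code path hit the plaintext column.
    logger.warning("MIGRATION_NEEDED: reading plaintext column", stack_info=True)

read_legacy_value()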

View File

@@ -109,7 +109,7 @@ from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task
from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task
config = LettaConfig.load()
logger = get_logger(__name__)
@@ -203,10 +203,12 @@ class SyncServer(object):
"""Initialize the MCP clients (there may be multiple)"""
self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {}
# collect providers (always has Letta as a default)
from letta.constants import LETTA_MODEL_ENDPOINT
# TODO: Remove these in-memory caches
self._llm_config_cache = {}
self._embedding_config_cache = {}
self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)]
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
if model_settings.openai_api_key:
self._enabled_providers.append(
OpenAIProvider(
@@ -340,12 +342,6 @@ class SyncServer(object):
print(f"Default user: {self.default_user} and org: {self.default_org}")
await self.tool_manager.upsert_base_tools_async(actor=self.default_user)
# Sync environment-based providers to database (idempotent, safe for multi-pod startup)
await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user)
# Sync provider models to database
await self._sync_provider_models_async()
# For OSS users, create a local sandbox config
oss_default_user = await self.user_manager.get_default_actor_async()
use_venv = False if not tool_settings.tool_exec_venv_name else True
@@ -382,65 +378,6 @@ class SyncServer(object):
force_recreate=True,
)
def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]:
"""Find and return an enabled provider by name.
Args:
provider_name: The name of the provider to find
Returns:
The matching enabled provider, or None if not found
"""
for provider in self._enabled_providers:
if provider.name == provider_name:
return provider
return None
async def _sync_provider_models_async(self):
"""Sync all provider models to database at startup."""
logger.info("Syncing provider models to database")
# Get persisted providers from database (they now have IDs)
persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user)
for persisted_provider in persisted_providers:
try:
# Find the matching enabled provider instance to call list_models on
enabled_provider = self._get_enabled_provider(persisted_provider.name)
if not enabled_provider:
# Only delete base providers that are no longer enabled
# BYOK providers are user-created and should not be automatically deleted
if persisted_provider.provider_category == ProviderCategory.base:
logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database")
try:
await self.provider_manager.delete_provider_by_id_async(
provider_id=persisted_provider.id, actor=self.default_user
)
except NoResultFound:
# Provider was already deleted (race condition in multi-pod startup)
logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping")
else:
logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync")
continue
# Fetch models from provider
llm_models = await enabled_provider.list_llm_models_async()
embedding_models = await enabled_provider.list_embedding_models_async()
# Save to database with the persisted provider (which has an ID)
await self.provider_manager.sync_provider_models_async(
provider=persisted_provider,
llm_models=llm_models,
embedding_models=embedding_models,
organization_id=None, # Global models
)
logger.info(
f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
)
except Exception as e:
logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True)
async def init_mcp_clients(self):
# TODO: remove this
mcp_server_configs = self.get_mcp_servers()
@@ -468,6 +405,39 @@ class SyncServer(object):
logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
@trace_method
def get_cached_llm_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._llm_config_cache:
self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
return self._llm_config_cache[key]
@trace_method
async def get_cached_llm_config_async(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._llm_config_cache:
self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs)
logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
return self._llm_config_cache[key]
@trace_method
def get_cached_embedding_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._embedding_config_cache:
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
return self._embedding_config_cache[key]
# @async_redis_cache(key_func=lambda actor, **kwargs: actor.id + hash(kwargs))
@trace_method
async def get_cached_embedding_config_async(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._embedding_config_cache:
self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs)
logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
return self._embedding_config_cache[key]
@trace_method
async def create_agent_async(
self,
@@ -501,9 +471,10 @@ class SyncServer(object):
"max_reasoning_tokens": request.max_reasoning_tokens,
"enable_reasoner": request.enable_reasoner,
}
log_event(name="start get_llm_config_from_handle", attributes=config_params)
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
log_event(name="end get_llm_config_from_handle", attributes=config_params)
config_params.update(additional_config_params)
log_event(name="start get_cached_llm_config", attributes=config_params)
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
log_event(name="end get_cached_llm_config", attributes=config_params)
if request.model and isinstance(request.model, str):
assert request.llm_config.handle == request.model, (
f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
@@ -533,9 +504,9 @@ class SyncServer(object):
"handle": request.embedding,
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
}
log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params)
request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params)
log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params)
log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params)
log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
log_event(name="start create_agent db")
main_agent = await self.agent_manager.create_agent_async(
@@ -584,9 +555,9 @@ class SyncServer(object):
"context_window_limit": request.context_window_limit,
"max_tokens": request.max_tokens,
}
log_event(name="start get_llm_config_from_handle", attributes=config_params)
request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
log_event(name="end get_llm_config_from_handle", attributes=config_params)
log_event(name="start get_cached_llm_config", attributes=config_params)
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
log_event(name="end get_cached_llm_config", attributes=config_params)
# update with model_settings
if request.model_settings is not None:
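
The get_cached_* helpers restored above key their in-memory dicts with make_key(**kwargs) from letta.utils, whose implementation is not part of this diff. A plausible stand-in (an assumption, not the actual letta.utils code) is an order-independent tuple of the keyword arguments:

def make_key(**kwargs):
    # Hypothetical stand-in: hashable and deterministic regardless of the
    # order in which keyword arguments were passed.
    return tuple(sorted(kwargs.items(), key=lambda kv: kv[0]))

_cache = {}
key = make_key(handle="openai/gpt-4", context_window_limit=None)
if key not in _cache:
    _cache[key] = "expensive config lookup result"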
@@ -1090,85 +1061,73 @@ class SyncServer(object):
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[LLMConfig]:
"""List available LLM models from database cache"""
# Get provider IDs if filtering by provider
provider_ids = None
if provider_name or provider_type:
providers = await self.get_enabled_providers_async(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
)
provider_ids = [p.id for p in providers]
"""Asynchronously list available models with maximum concurrency"""
import asyncio
# If filtering was requested but no providers matched, return empty list
if not provider_ids:
providers = await self.get_enabled_providers_async(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
)
async def get_provider_models(provider: Provider) -> list[LLMConfig]:
try:
async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS):
return await provider.list_llm_models_async()
except asyncio.TimeoutError:
logger.warning(f"Timeout while listing LLM models for provider {provider}")
return []
except Exception as e:
logger.exception(f"Error while listing LLM models for provider {provider}: {e}")
return []
# Get models from database
provider_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="llm",
provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None,
enabled=True,
)
# Execute all provider model listing tasks concurrently
provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
# Build LLMConfig objects from cached data
# Cache providers to avoid N+1 queries
provider_cache: Dict[str, Provider] = {}
# Flatten the results
llm_models = []
for model in provider_models:
# Skip if filtering by provider and model doesn't match
if provider_ids and model.provider_id not in provider_ids:
continue
for models in provider_results:
llm_models.extend(models)
# Get provider details (with caching to avoid N+1 queries)
if model.provider_id not in provider_cache:
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
provider = provider_cache[model.provider_id]
# Get local configs; if this turns out to be slow, consider making it async too
local_configs = self.get_local_llm_configs()
llm_models.extend(local_configs)
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url or model.model_endpoint_type,
context_window=model.max_context_window or 16384,
handle=model.handle,
provider_name=provider.name,
provider_category=provider.provider_category,
)
llm_models.append(llm_config)
# dedupe by handle for uniqueness
# Seems like this is required by the tests
seen_handles = set()
unique_models = []
for model in llm_models:
if model.handle not in seen_handles:
seen_handles.add(model.handle)
unique_models.append(model)
return llm_models
return unique_models
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
"""List available embedding models from database cache"""
# Get models from database
provider_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="embedding",
enabled=True,
)
"""Asynchronously list available embedding models with maximum concurrency"""
import asyncio
# Build EmbeddingConfig objects from cached data
# Cache providers to avoid N+1 queries
provider_cache: Dict[str, Provider] = {}
# Get all eligible providers first
providers = await self.get_enabled_providers_async(actor=actor)
# Fetch embedding models from each provider concurrently
async def get_provider_embedding_models(provider):
try:
# All providers now have list_embedding_models_async
return await provider.list_embedding_models_async()
except Exception as e:
logger.exception(f"An error occurred while listing embedding models for provider {provider}: {e}")
return []
# Execute all provider model listing tasks concurrently
provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
# Flatten the results
embedding_models = []
for model in provider_models:
# Get provider details (with caching to avoid N+1 queries)
if model.provider_id not in provider_cache:
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
provider = provider_cache[model.provider_id]
embedding_config = EmbeddingConfig(
embedding_model=model.name,
embedding_endpoint_type=model.model_endpoint_type,
embedding_endpoint=provider.base_url or model.model_endpoint_type,
embedding_dim=model.embedding_dim or 1536, # Use model's dimension or default
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=model.handle,
)
embedding_models.append(embedding_config)
for models in provider_results:
embedding_models.extend(models)
return embedding_models
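
Both restored listing paths follow the same fan-out pattern: one task per provider, each bounded (for LLM models) by asyncio.timeout, with failures degrading to an empty list so a single slow or broken provider cannot sink the aggregate call. The pattern in isolation, as a sketch (asyncio.timeout requires Python 3.11+):

import asyncio

async def fetch_models(fetch, timeout_s: float = 10.0) -> list:
    try:
        async with asyncio.timeout(timeout_s):  # Python 3.11+
            return await fetch()
    except asyncio.TimeoutError:
        return []  # slow provider contributes nothing
    except Exception:
        return []  # broken provider contributes nothing

async def list_all(fetchers) -> list:
    batches = await asyncio.gather(*(fetch_models(f) for f in fetchers))
    return [model for batch in batches for model in batch]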
@@ -1181,22 +1140,17 @@ class SyncServer(object):
) -> List[Provider]:
providers = []
if not provider_category or ProviderCategory.base in provider_category:
# Add enabled providers (base providers from environment)
enabled_providers = [p for p in self._enabled_providers]
providers.extend(enabled_providers)
providers_from_env = [p for p in self._enabled_providers]
providers.extend(providers_from_env)
if not provider_category or ProviderCategory.byok in provider_category:
# Add persisted BYOK providers from database
# Note: list_providers_async returns both org-specific and global providers,
# so we filter to only include BYOK providers to avoid duplicating base providers
persisted_providers = await self.provider_manager.list_providers_async(
providers_from_db = await self.provider_manager.list_providers_async(
name=provider_name,
provider_type=provider_type,
actor=actor,
)
# Filter to only BYOK providers (base providers are already in self._enabled_providers)
persisted_byok_providers = [p.cast_to_subtype() for p in persisted_providers if p.provider_category == ProviderCategory.byok]
providers.extend(persisted_byok_providers)
providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
providers.extend(providers_from_db)
if provider_name is not None:
providers = [p for p in providers if p.name == provider_name]
@@ -1216,19 +1170,32 @@ class SyncServer(object):
max_reasoning_tokens: Optional[int] = None,
enable_reasoner: Optional[bool] = None,
) -> LLMConfig:
# Use provider_manager to get LLMConfig from handle
try:
llm_config = await self.provider_manager.get_llm_config_from_handle(
handle=handle,
actor=actor,
)
except Exception as e:
# Convert to HandleNotFoundError for backwards compatibility
from letta.orm.errors import NoResultFound
provider_name, model_name = handle.split("/", 1)
provider = await self.get_provider_from_name_async(provider_name, actor)
if isinstance(e, NoResultFound):
raise HandleNotFoundError(handle, [])
raise
all_llm_configs = await provider.list_llm_models_async()
llm_configs = [config for config in all_llm_configs if config.handle == handle]
if not llm_configs:
llm_configs = [config for config in all_llm_configs if config.model == model_name]
if not llm_configs:
available_handles = [config.handle for config in all_llm_configs]
raise HandleNotFoundError(handle, available_handles)
except ValueError as e:
llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
if not llm_configs:
llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name]
if not llm_configs:
raise e
if len(llm_configs) == 1:
llm_config = llm_configs[0]
elif len(llm_configs) > 1:
raise LettaInvalidArgumentError(
f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name"
)
else:
llm_config = llm_configs[0]
if context_window_limit is not None:
if context_window_limit > llm_config.context_window:
@@ -1260,22 +1227,33 @@ class SyncServer(object):
async def get_embedding_config_from_handle_async(
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
) -> EmbeddingConfig:
# Use provider_manager to get EmbeddingConfig from handle
try:
embedding_config = await self.provider_manager.get_embedding_config_from_handle(
handle=handle,
actor=actor,
provider_name, model_name = handle.split("/", 1)
provider = await self.get_provider_from_name_async(provider_name, actor)
all_embedding_configs = await provider.list_embedding_models_async()
embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
if not embedding_configs:
raise LettaInvalidArgumentError(
f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name"
)
except LettaInvalidArgumentError as e:
# search local configs
embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
if not embedding_configs:
raise e
if len(embedding_configs) == 1:
embedding_config = embedding_configs[0]
elif len(embedding_configs) > 1:
raise LettaInvalidArgumentError(
f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name"
)
except Exception as e:
# Convert to LettaInvalidArgumentError for backwards compatibility
from letta.orm.errors import NoResultFound
else:
embedding_config = embedding_configs[0]
if isinstance(e, NoResultFound):
raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle")
raise
# Override chunk size if provided
embedding_config.embedding_chunk_size = embedding_chunk_size
if embedding_chunk_size:
embedding_config.embedding_chunk_size = embedding_chunk_size
return embedding_config
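
Both restored handle lookups begin with handle.split("/", 1), which implies a "<provider_name>/<model_name>" handle convention and keeps any further slashes inside the model name. For example (the handle below is hypothetical):

handle = "openrouter/meta-llama/llama-3-70b"  # hypothetical handle
provider_name, model_name = handle.split("/", 1)
assert provider_name == "openrouter"
assert model_name == "meta-llama/llama-3-70b"  # maxsplit=1 preserves inner slashes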
@@ -1294,6 +1272,46 @@ class SyncServer(object):
return provider
def get_local_llm_configs(self):
llm_models = []
# NOTE: deprecated
# try:
# llm_configs_dir = os.path.expanduser("~/.letta/llm_configs")
# if os.path.exists(llm_configs_dir):
# for filename in os.listdir(llm_configs_dir):
# if filename.endswith(".json"):
# filepath = os.path.join(llm_configs_dir, filename)
# try:
# with open(filepath, "r") as f:
# config_data = json.load(f)
# llm_config = LLMConfig(**config_data)
# llm_models.append(llm_config)
# except (json.JSONDecodeError, ValueError) as e:
# logger.warning(f"Error parsing LLM config file {filename}: {e}")
# except Exception as e:
# logger.warning(f"Error reading LLM configs directory: {e}")
return llm_models
def get_local_embedding_configs(self):
embedding_models = []
# NOTE: deprecated
# try:
# embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs")
# if os.path.exists(embedding_configs_dir):
# for filename in os.listdir(embedding_configs_dir):
# if filename.endswith(".json"):
# filepath = os.path.join(embedding_configs_dir, filename)
# try:
# with open(filepath, "r") as f:
# config_data = json.load(f)
# embedding_config = EmbeddingConfig(**config_data)
# embedding_models.append(embedding_config)
# except (json.JSONDecodeError, ValueError) as e:
# logger.warning(f"Error parsing embedding config file {filename}: {e}")
# except Exception as e:
# logger.warning(f"Error reading embedding configs directory: {e}")
return embedding_models
def add_llm_model(self, request: LLMConfig) -> LLMConfig:
"""Add a new LLM model"""

View File

@@ -154,7 +154,7 @@ class ProviderManager:
@trace_method
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
"""Delete a provider and its associated models."""
"""Delete a provider."""
async with db_registry.async_session() as session:
# Clear api key field
existing_provider = await ProviderModel.read_async(
@@ -163,15 +163,6 @@ class ProviderManager:
existing_provider.api_key = None
await existing_provider.update_async(session, actor=actor)
# Soft delete all models associated with this provider
provider_models = await ProviderModelORM.list_async(
db_session=session,
provider_id=provider_id,
check_is_deleted=True,
)
for model in provider_models:
await model.delete_async(session, actor=actor)
# Soft delete in provider table
await existing_provider.delete_async(session, actor=actor)
@@ -640,11 +631,11 @@ class ProviderManager:
await model.create_async(session)
logger.info(f" ✓ Successfully created LLM model {llm_config.handle} with ID {model.id}")
except Exception as e:
logger.info(f" ✗ Failed to create LLM model {llm_config.handle}: {e}")
logger.error(f" ✗ Failed to create LLM model {llm_config.handle}: {e}")
# Log the full error details
import traceback
logger.info(f" Full traceback: {traceback.format_exc()}")
logger.error(f" Full traceback: {traceback.format_exc()}")
# Roll back the session to clear the failed transaction
await session.rollback()
else:
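
A side note on the restored error logging: the two-call pattern of logger.error plus traceback.format_exc() can be collapsed into a single call with exc_info=True, which the standard library supports natively. A minimal illustration (not part of the diff):

import logging
import traceback

logger = logging.getLogger("letta.services.provider_manager")

try:
    raise RuntimeError("duplicate handle")  # stand-in for a failed create_async
except Exception as e:
    logger.error(f"Failed to create LLM model: {e}")
    logger.error(f"Full traceback: {traceback.format_exc()}")
    # Equivalent single call:
    # logger.error("Failed to create LLM model", exc_info=True)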

View File

@@ -499,436 +499,3 @@ async def test_byok_provider_auto_syncs_models(provider_manager, default_user, m
llm_config = await provider_manager.get_llm_config_from_handle(handle="my-openai-key/gpt-4o", actor=default_user)
assert llm_config.model == "gpt-4o"
assert llm_config.provider_name == "my-openai-key"
# ======================================================================================================================
# Server Startup Provider Sync Tests
# ======================================================================================================================
@pytest.mark.asyncio
async def test_server_startup_syncs_base_providers(default_user, default_organization, monkeypatch):
"""Test that server startup properly syncs base provider models from environment.
This test simulates the server startup process and verifies that:
1. Base providers from environment variables are synced to database
2. Provider models are fetched from mocked API endpoints
3. Models are properly persisted to the database with correct metadata
4. Models can be retrieved using handles
"""
from unittest.mock import AsyncMock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import AnthropicProvider, OpenAIProvider
from letta.server.server import SyncServer
# Mock OpenAI API responses
mock_openai_models = {
"data": [
{
"id": "gpt-4",
"object": "model",
"created": 1687882411,
"owned_by": "openai",
"max_model_len": 8192,
},
{
"id": "gpt-4-turbo",
"object": "model",
"created": 1712361441,
"owned_by": "system",
"max_model_len": 128000,
},
{
"id": "text-embedding-ada-002",
"object": "model",
"created": 1671217299,
"owned_by": "openai-internal",
},
{
"id": "gpt-4-vision", # Should be filtered out by OpenAI provider logic (has disallowed keyword)
"object": "model",
"created": 1698959748,
"owned_by": "system",
"max_model_len": 8192,
},
]
}
# Mock Anthropic API responses
mock_anthropic_models = {
"data": [
{
"id": "claude-3-5-sonnet-20241022",
"type": "model",
"display_name": "Claude 3.5 Sonnet",
"created_at": "2024-10-22T00:00:00Z",
},
{
"id": "claude-3-opus-20240229",
"type": "model",
"display_name": "Claude 3 Opus",
"created_at": "2024-02-29T00:00:00Z",
},
]
}
# Mock the API calls for OpenAI
async def mock_openai_get_model_list_async(*args, **kwargs):
return mock_openai_models
# Mock Anthropic models.list() response
from unittest.mock import MagicMock
mock_anthropic_response = MagicMock()
mock_anthropic_response.model_dump.return_value = mock_anthropic_models
# Mock the Anthropic AsyncAnthropic client
class MockAnthropicModels:
async def list(self):
return mock_anthropic_response
class MockAsyncAnthropic:
def __init__(self, *args, **kwargs):
self.models = MockAnthropicModels()
# Patch the actual API calling functions
monkeypatch.setattr(
"letta.llm_api.openai.openai_get_model_list_async",
mock_openai_get_model_list_async,
)
monkeypatch.setattr(
"anthropic.AsyncAnthropic",
MockAsyncAnthropic,
)
# Clear ALL provider-related env vars first to ensure clean state
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
monkeypatch.delenv("AZURE_API_KEY", raising=False)
monkeypatch.delenv("GROQ_API_KEY", raising=False)
monkeypatch.delenv("TOGETHER_API_KEY", raising=False)
monkeypatch.delenv("VLLM_API_BASE", raising=False)
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
monkeypatch.delenv("LMSTUDIO_BASE_URL", raising=False)
monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
monkeypatch.delenv("XAI_API_KEY", raising=False)
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
# Set environment variables to enable only OpenAI and Anthropic
monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key-12345")
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key-67890")
# Reload model_settings to pick up new env vars
from letta.settings import model_settings
monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key-12345")
monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key-67890")
monkeypatch.setattr(model_settings, "gemini_api_key", None)
monkeypatch.setattr(model_settings, "google_cloud_project", None)
monkeypatch.setattr(model_settings, "google_cloud_location", None)
monkeypatch.setattr(model_settings, "azure_api_key", None)
monkeypatch.setattr(model_settings, "groq_api_key", None)
monkeypatch.setattr(model_settings, "together_api_key", None)
monkeypatch.setattr(model_settings, "vllm_api_base", None)
monkeypatch.setattr(model_settings, "aws_access_key_id", None)
monkeypatch.setattr(model_settings, "aws_secret_access_key", None)
monkeypatch.setattr(model_settings, "lmstudio_base_url", None)
monkeypatch.setattr(model_settings, "deepseek_api_key", None)
monkeypatch.setattr(model_settings, "xai_api_key", None)
monkeypatch.setattr(model_settings, "openrouter_api_key", None)
# Create server instance (this will load enabled providers from environment)
server = SyncServer(init_with_default_org_and_user=False)
# Manually set up the default user/org (since we disabled auto-init)
server.default_user = default_user
server.default_org = default_organization
# Verify enabled providers were loaded
assert len(server._enabled_providers) == 3 # Exactly: letta, openai, anthropic
enabled_provider_names = [p.name for p in server._enabled_providers]
assert "letta" in enabled_provider_names
assert "openai" in enabled_provider_names
assert "anthropic" in enabled_provider_names
# First, sync base providers to database (this is what init_async does)
await server.provider_manager.sync_base_providers(
base_providers=server._enabled_providers,
actor=default_user,
)
# Now call the actual _sync_provider_models_async method
# This simulates what happens during server startup
await server._sync_provider_models_async()
# Verify OpenAI models were synced
openai_providers = await server.provider_manager.list_providers_async(
name="openai",
actor=default_user,
)
assert len(openai_providers) == 1, "OpenAI provider should exist"
openai_provider = openai_providers[0]
# Check OpenAI LLM models
openai_llm_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=openai_provider.id,
model_type="llm",
)
# Should have gpt-4 and gpt-4-turbo (gpt-4-vision filtered out due to "vision" keyword)
assert len(openai_llm_models) >= 2, f"Expected at least 2 OpenAI LLM models, got {len(openai_llm_models)}"
openai_model_names = [m.name for m in openai_llm_models]
assert "gpt-4" in openai_model_names
assert "gpt-4-turbo" in openai_model_names
# Check OpenAI embedding models
openai_embedding_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=openai_provider.id,
model_type="embedding",
)
assert len(openai_embedding_models) >= 1, "Expected at least 1 OpenAI embedding model"
embedding_model_names = [m.name for m in openai_embedding_models]
assert "text-embedding-ada-002" in embedding_model_names
# Verify model metadata is correct
gpt4_models = [m for m in openai_llm_models if m.name == "gpt-4"]
assert len(gpt4_models) > 0, "gpt-4 model should exist"
gpt4_model = gpt4_models[0]
assert gpt4_model.handle == "openai/gpt-4"
assert gpt4_model.model_endpoint_type == "openai"
assert gpt4_model.max_context_window == 8192
assert gpt4_model.enabled is True
# Verify Anthropic models were synced
anthropic_providers = await server.provider_manager.list_providers_async(
name="anthropic",
actor=default_user,
)
assert len(anthropic_providers) == 1, "Anthropic provider should exist"
anthropic_provider = anthropic_providers[0]
anthropic_llm_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=anthropic_provider.id,
model_type="llm",
)
# Should have Claude models
assert len(anthropic_llm_models) >= 2, f"Expected at least 2 Anthropic models, got {len(anthropic_llm_models)}"
anthropic_model_names = [m.name for m in anthropic_llm_models]
assert "claude-3-5-sonnet-20241022" in anthropic_model_names
assert "claude-3-opus-20240229" in anthropic_model_names
# Test that we can retrieve LLMConfig from handle
llm_config = await server.provider_manager.get_llm_config_from_handle(
handle="openai/gpt-4",
actor=default_user,
)
assert llm_config.model == "gpt-4"
assert llm_config.handle == "openai/gpt-4"
assert llm_config.provider_name == "openai"
assert llm_config.context_window == 8192
# Test that we can retrieve EmbeddingConfig from handle
embedding_config = await server.provider_manager.get_embedding_config_from_handle(
handle="openai/text-embedding-ada-002",
actor=default_user,
)
assert embedding_config.embedding_model == "text-embedding-ada-002"
assert embedding_config.handle == "openai/text-embedding-ada-002"
assert embedding_config.embedding_dim == 1536
@pytest.mark.asyncio
async def test_server_startup_handles_disabled_providers(default_user, default_organization, monkeypatch):
"""Test that server startup properly handles providers that are no longer enabled.
This test verifies that:
1. Base providers that are no longer enabled (env vars removed) are deleted
2. BYOK providers that are no longer enabled are NOT deleted (user-created)
3. The sync process handles providers gracefully when API calls fail
"""
from letta.schemas.providers import OpenAIProvider, ProviderCreate
from letta.server.server import SyncServer
# First, manually create providers in the database
provider_manager = ProviderManager()
# Create a base OpenAI provider (simulating it was synced before)
base_openai_create = ProviderCreate(
name="openai",
provider_type=ProviderType.openai,
api_key="sk-old-key",
base_url="https://api.openai.com/v1",
)
base_openai = await provider_manager.create_provider_async(
base_openai_create,
actor=default_user,
is_byok=False, # This is a base provider
)
# Create a BYOK provider (user-created)
byok_provider_create = ProviderCreate(
name="my-custom-openai",
provider_type=ProviderType.openai,
api_key="sk-my-key",
base_url="https://api.openai.com/v1",
)
byok_provider = await provider_manager.create_provider_async(
byok_provider_create,
actor=default_user,
is_byok=True,
)
assert byok_provider.provider_category == ProviderCategory.byok
# Now create server with NO environment variables set (all base providers disabled)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
from letta.settings import model_settings
monkeypatch.setattr(model_settings, "openai_api_key", None)
monkeypatch.setattr(model_settings, "anthropic_api_key", None)
# Create server instance
server = SyncServer(init_with_default_org_and_user=False)
server.default_user = default_user
server.default_org = default_organization
# Verify only letta provider is enabled (no openai)
enabled_names = [p.name for p in server._enabled_providers]
assert "letta" in enabled_names
assert "openai" not in enabled_names
# Sync base providers (should not include openai anymore)
await server.provider_manager.sync_base_providers(
base_providers=server._enabled_providers,
actor=default_user,
)
# Call _sync_provider_models_async
await server._sync_provider_models_async()
# Verify base OpenAI provider was deleted (no longer enabled)
try:
await server.provider_manager.get_provider_async(base_openai.id, actor=default_user)
assert False, "Base OpenAI provider should have been deleted"
except Exception:
# Expected - provider should not exist
pass
# Verify BYOK provider still exists (should NOT be deleted)
byok_still_exists = await server.provider_manager.get_provider_async(
byok_provider.id,
actor=default_user,
)
assert byok_still_exists is not None
assert byok_still_exists.name == "my-custom-openai"
assert byok_still_exists.provider_category == ProviderCategory.byok
@pytest.mark.asyncio
async def test_server_startup_handles_api_errors_gracefully(default_user, default_organization, monkeypatch):
"""Test that server startup handles API errors gracefully without crashing.
This test verifies that:
1. If a provider's API call fails during sync, it logs an error but continues
2. Other providers can still sync successfully
3. The server startup completes without crashing
"""
from letta.schemas.providers import AnthropicProvider, OpenAIProvider
from letta.server.server import SyncServer
# Mock OpenAI to fail
async def mock_openai_fail(*args, **kwargs):
raise Exception("OpenAI API is down")
# Mock Anthropic to succeed
from unittest.mock import MagicMock
mock_anthropic_response = MagicMock()
mock_anthropic_response.model_dump.return_value = {
"data": [
{
"id": "claude-3-5-sonnet-20241022",
"type": "model",
"display_name": "Claude 3.5 Sonnet",
"created_at": "2024-10-22T00:00:00Z",
}
]
}
class MockAnthropicModels:
async def list(self):
return mock_anthropic_response
class MockAsyncAnthropic:
def __init__(self, *args, **kwargs):
self.models = MockAnthropicModels()
monkeypatch.setattr(
"letta.llm_api.openai.openai_get_model_list_async",
mock_openai_fail,
)
monkeypatch.setattr(
"anthropic.AsyncAnthropic",
MockAsyncAnthropic,
)
# Set environment variables
monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key")
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key")
from letta.settings import model_settings
monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key")
monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key")
# Create server
server = SyncServer(init_with_default_org_and_user=False)
server.default_user = default_user
server.default_org = default_organization
# Sync base providers
await server.provider_manager.sync_base_providers(
base_providers=server._enabled_providers,
actor=default_user,
)
# This should NOT crash even though OpenAI fails
await server._sync_provider_models_async()
# Verify Anthropic still synced successfully
anthropic_providers = await server.provider_manager.list_providers_async(
name="anthropic",
actor=default_user,
)
assert len(anthropic_providers) == 1
anthropic_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=anthropic_providers[0].id,
model_type="llm",
)
assert len(anthropic_models) >= 1, "Anthropic models should have synced despite OpenAI failure"
# OpenAI should have no models (sync failed)
openai_providers = await server.provider_manager.list_providers_async(
name="openai",
actor=default_user,
)
if len(openai_providers) > 0:
openai_models = await server.provider_manager.list_models_async(
actor=default_user,
provider_id=openai_providers[0].id,
)
# Models might exist from previous runs, but the sync attempt should have been logged as failed
# The key is that the server didn't crash

View File

@@ -32,7 +32,6 @@ from letta.config import LettaConfig
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.server.server import SyncServer
from tests.helpers.utils import upload_file_and_wait
from tests.utils import wait_for_server
# Constants
SERVER_PORT = 8283
@@ -107,7 +106,7 @@ def client() -> LettaSDKClient:
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
wait_for_server(server_url, timeout=60)
time.sleep(5)
print("Running client tests with server:", server_url)
client = LettaSDKClient(base_url=server_url)
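
The revert drops the tests.utils.wait_for_server helper in favor of a fixed five-second sleep. The helper's implementation is not shown in this diff; a typical polling version (purely illustrative, assuming any HTTP response from the base URL means the server is up) looks like:

import time
import urllib.error
import urllib.request

def wait_for_server(url: str, timeout: float = 60.0, interval: float = 0.5) -> None:
    # Poll until the server answers, or raise after the timeout elapses.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            urllib.request.urlopen(url, timeout=2)
            return
        except urllib.error.HTTPError:
            return  # any HTTP response (even an error status) means it's up
        except Exception:
            time.sleep(interval)
    raise TimeoutError(f"Server at {url} did not come up within {timeout}s")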

View File

@@ -1740,110 +1740,3 @@ async def test_handle_uniqueness_per_org(default_user, provider_manager):
assert model is not None
assert model.provider_id == provider_1.id # Still original provider
assert model.max_context_window == 8192 # Still original
@pytest.mark.asyncio
async def test_delete_provider_cascades_to_models(default_user, provider_manager, monkeypatch):
"""Test that deleting a provider also soft-deletes its associated models."""
test_id = generate_test_id()
# Mock _sync_default_models_for_provider to avoid external API calls
async def mock_sync(provider, actor):
pass # Don't actually sync - we'll manually create models below
monkeypatch.setattr(provider_manager, "_sync_default_models_for_provider", mock_sync)
# 1. Create a BYOK provider (org-scoped, so the actor can delete it)
provider_create = ProviderCreate(
name=f"test-cascade-{test_id}",
provider_type=ProviderType.openai,
api_key="sk-test-key",
)
provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=True)
# 2. Manually sync models to the provider
llm_models = [
LLMConfig(
model=f"gpt-4o-{test_id}",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
context_window=128000,
handle=f"test-{test_id}/gpt-4o",
provider_name=provider.name,
provider_category=ProviderCategory.byok,
),
LLMConfig(
model=f"gpt-4o-mini-{test_id}",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
context_window=16384,
handle=f"test-{test_id}/gpt-4o-mini",
provider_name=provider.name,
provider_category=ProviderCategory.byok,
),
]
embedding_models = [
EmbeddingConfig(
embedding_model=f"text-embedding-3-small-{test_id}",
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
embedding_chunk_size=300,
handle=f"test-{test_id}/text-embedding-3-small",
),
]
await provider_manager.sync_provider_models_async(
provider=provider,
llm_models=llm_models,
embedding_models=embedding_models,
organization_id=default_user.organization_id, # Org-scoped for BYOK provider
)
# 3. Verify models exist before deletion
llm_models_before = await provider_manager.list_models_async(
actor=default_user,
model_type="llm",
provider_id=provider.id,
)
embedding_models_before = await provider_manager.list_models_async(
actor=default_user,
model_type="embedding",
provider_id=provider.id,
)
llm_handles_before = {m.handle for m in llm_models_before}
embedding_handles_before = {m.handle for m in embedding_models_before}
assert f"test-{test_id}/gpt-4o" in llm_handles_before
assert f"test-{test_id}/gpt-4o-mini" in llm_handles_before
assert f"test-{test_id}/text-embedding-3-small" in embedding_handles_before
# 4. Delete the provider
await provider_manager.delete_provider_by_id_async(provider.id, actor=default_user)
# 5. Verify models are soft-deleted (no longer returned in list)
all_llm_models_after = await provider_manager.list_models_async(
actor=default_user,
model_type="llm",
)
all_embedding_models_after = await provider_manager.list_models_async(
actor=default_user,
model_type="embedding",
)
all_llm_handles_after = {m.handle for m in all_llm_models_after}
all_embedding_handles_after = {m.handle for m in all_embedding_models_after}
# All models from the deleted provider should be gone
assert f"test-{test_id}/gpt-4o" not in all_llm_handles_after
assert f"test-{test_id}/gpt-4o-mini" not in all_llm_handles_after
assert f"test-{test_id}/text-embedding-3-small" not in all_embedding_handles_after
# 6. Verify provider is also deleted
providers_after = await provider_manager.list_providers_async(
actor=default_user,
name=f"test-cascade-{test_id}",
)
assert len(providers_after) == 0