From 848a73125cd9716bc307ddf706285e103572f208 Mon Sep 17 00:00:00 2001 From: Ari Webb Date: Tue, 9 Dec 2025 14:33:06 -0800 Subject: [PATCH] feat: enable provider models persistence (#6193) * Revert "fix test" This reverts commit 5126815f23cefb4edad3e3bf9e7083209dcc7bf1. * fix server and better test * test fix, get api key for base and byok? * set letta default endpoint * try to fix timeout for test * fix for letta api key * Delete apps/core/tests/sdk_v1/conftest.py * Update utils.py * clean up a few issues * fix filtering on list_llm_models * soft delete models with provider * add one more test * fix ci * add timeout * band-aid for letta embedding provider * info instead of error logs when creating models --- letta/llm_api/llm_client_base.py | 18 +- letta/schemas/providers/letta.py | 5 +- letta/schemas/secret.py | 2 +- letta/server/server.py | 376 ++++++++++---------- letta/services/provider_manager.py | 15 +- tests/managers/test_provider_manager.py | 433 ++++++++++++++++++++++++ tests/test_sdk_client.py | 3 +- tests/test_server_providers.py | 107 ++++++ 8 files changed, 754 insertions(+), 205 deletions(-) diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index b25df76a..7c43ce8e 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -220,24 +220,38 @@ class LLMClientBase: def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]: """ Returns the override key for the given llm config. + For both base and BYOK providers, fetch the API key from the database. """ api_key = None - if llm_config.provider_category == ProviderCategory.byok: + # Fetch API key from database for both base and BYOK providers + # This ensures that base providers (from environment) also have their keys persisted and accessible + if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]: from letta.services.provider_manager import ProviderManager api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor) + # If we got an empty string from the database (e.g., Letta provider), treat it as None + # so the client can fall back to environment variables or default behavior + if api_key == "": + api_key = None return api_key, None, None async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]: """ Returns the override key for the given llm config. + For both base and BYOK providers, fetch the API key from the database. 
""" api_key = None - if llm_config.provider_category == ProviderCategory.byok: + # Fetch API key from database for both base and BYOK providers + # This ensures that base providers (from environment) also have their keys persisted and accessible + if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]: from letta.services.provider_manager import ProviderManager api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor) + # If we got an empty string from the database (e.g., Letta provider), treat it as None + # so the client can fall back to environment variables or default behavior + if api_key == "": + api_key = None return api_key, None, None diff --git a/letta/schemas/providers/letta.py b/letta/schemas/providers/letta.py index 34151fac..4b223ada 100644 --- a/letta/schemas/providers/letta.py +++ b/letta/schemas/providers/letta.py @@ -8,10 +8,13 @@ from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.providers.base import Provider +LETTA_EMBEDDING_ENDPOINT = "https://embeddings.letta.com/" + class LettaProvider(Provider): provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field(LETTA_EMBEDDING_ENDPOINT, description="Base URL for the Letta API (used for embeddings).") async def list_llm_models_async(self) -> list[LLMConfig]: return [ @@ -31,7 +34,7 @@ class LettaProvider(Provider): EmbeddingConfig( embedding_model="letta-free", # NOTE: renamed embedding_endpoint_type="openai", - embedding_endpoint="https://embeddings.letta.com/", + embedding_endpoint=self.base_url, embedding_dim=1536, embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, handle=self.get_handle("letta-free", is_embedding=True), diff --git a/letta/schemas/secret.py b/letta/schemas/secret.py index 2790c4aa..9471ff03 100644 --- a/letta/schemas/secret.py +++ b/letta/schemas/secret.py @@ -113,7 +113,7 @@ class Secret(BaseModel): "MIGRATION_NEEDED: Reading from plaintext column instead of encrypted column. " "This indicates data that hasn't been migrated to the _enc column yet. 
" "Please run migrate data to _enc columns as plaintext columns will be deprecated.", - stack_info=True, + # stack_info=True, ) return cls.from_plaintext(plaintext_value) return cls.from_plaintext(None) diff --git a/letta/server/server.py b/letta/server/server.py index 20cd7f85..59267582 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -109,7 +109,7 @@ from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import DatabaseChoice, model_settings, settings, tool_settings from letta.streaming_interface import AgentChunkStreamingInterface -from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task +from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task config = LettaConfig.load() logger = get_logger(__name__) @@ -203,12 +203,10 @@ class SyncServer(object): """Initialize the MCP clients (there may be multiple)""" self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {} - # TODO: Remove these in memory caches - self._llm_config_cache = {} - self._embedding_config_cache = {} - # collect providers (always has Letta as a default) - self._enabled_providers: List[Provider] = [LettaProvider(name="letta")] + from letta.constants import LETTA_MODEL_ENDPOINT + + self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)] if model_settings.openai_api_key: self._enabled_providers.append( OpenAIProvider( @@ -342,6 +340,12 @@ class SyncServer(object): print(f"Default user: {self.default_user} and org: {self.default_org}") await self.tool_manager.upsert_base_tools_async(actor=self.default_user) + # Sync environment-based providers to database (idempotent, safe for multi-pod startup) + await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user) + + # Sync provider models to database + await self._sync_provider_models_async() + # For OSS users, create a local sandbox config oss_default_user = await self.user_manager.get_default_actor_async() use_venv = False if not tool_settings.tool_exec_venv_name else True @@ -378,6 +382,65 @@ class SyncServer(object): force_recreate=True, ) + def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]: + """Find and return an enabled provider by name. 
+ + Args: + provider_name: The name of the provider to find + + Returns: + The matching enabled provider, or None if not found + """ + for provider in self._enabled_providers: + if provider.name == provider_name: + return provider + return None + + async def _sync_provider_models_async(self): + """Sync all provider models to database at startup.""" + logger.info("Syncing provider models to database") + + # Get persisted providers from database (they now have IDs) + persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user) + + for persisted_provider in persisted_providers: + try: + # Find the matching enabled provider instance to call list_models on + enabled_provider = self._get_enabled_provider(persisted_provider.name) + + if not enabled_provider: + # Only delete base providers that are no longer enabled + # BYOK providers are user-created and should not be automatically deleted + if persisted_provider.provider_category == ProviderCategory.base: + logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database") + try: + await self.provider_manager.delete_provider_by_id_async( + provider_id=persisted_provider.id, actor=self.default_user + ) + except NoResultFound: + # Provider was already deleted (race condition in multi-pod startup) + logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping") + else: + logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync") + continue + + # Fetch models from provider + llm_models = await enabled_provider.list_llm_models_async() + embedding_models = await enabled_provider.list_embedding_models_async() + + # Save to database with the persisted provider (which has an ID) + await self.provider_manager.sync_provider_models_async( + provider=persisted_provider, + llm_models=llm_models, + embedding_models=embedding_models, + organization_id=None, # Global models + ) + logger.info( + f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}" + ) + except Exception as e: + logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True) + async def init_mcp_clients(self): # TODO: remove this mcp_server_configs = self.get_mcp_servers() @@ -405,39 +468,6 @@ class SyncServer(object): logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}") logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}") - @trace_method - def get_cached_llm_config(self, actor: User, **kwargs): - key = make_key(**kwargs) - if key not in self._llm_config_cache: - self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs) - logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries") - return self._llm_config_cache[key] - - @trace_method - async def get_cached_llm_config_async(self, actor: User, **kwargs): - key = make_key(**kwargs) - if key not in self._llm_config_cache: - self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs) - logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries") - return self._llm_config_cache[key] - - @trace_method - def get_cached_embedding_config(self, actor: User, **kwargs): - key = make_key(**kwargs) - if key not in self._embedding_config_cache: - self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs) - logger.info(f"Embedding config cache size: 
{len(self._embedding_config_cache)} entries") - return self._embedding_config_cache[key] - - # @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs)) - @trace_method - async def get_cached_embedding_config_async(self, actor: User, **kwargs): - key = make_key(**kwargs) - if key not in self._embedding_config_cache: - self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs) - logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries") - return self._embedding_config_cache[key] - @trace_method async def create_agent_async( self, @@ -471,10 +501,9 @@ class SyncServer(object): "max_reasoning_tokens": request.max_reasoning_tokens, "enable_reasoner": request.enable_reasoner, } - config_params.update(additional_config_params) - log_event(name="start get_cached_llm_config", attributes=config_params) - request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params) - log_event(name="end get_cached_llm_config", attributes=config_params) + log_event(name="start get_llm_config_from_handle", attributes=config_params) + request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params) + log_event(name="end get_llm_config_from_handle", attributes=config_params) if request.model and isinstance(request.model, str): assert request.llm_config.handle == request.model, ( f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}" @@ -504,9 +533,9 @@ class SyncServer(object): "handle": request.embedding, "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE, } - log_event(name="start get_cached_embedding_config", attributes=embedding_config_params) - request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params) - log_event(name="end get_cached_embedding_config", attributes=embedding_config_params) + log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params) + request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params) + log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params) log_event(name="start create_agent db") main_agent = await self.agent_manager.create_agent_async( @@ -555,9 +584,9 @@ class SyncServer(object): "context_window_limit": request.context_window_limit, "max_tokens": request.max_tokens, } - log_event(name="start get_cached_llm_config", attributes=config_params) - request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params) - log_event(name="end get_cached_llm_config", attributes=config_params) + log_event(name="start get_llm_config_from_handle", attributes=config_params) + request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params) + log_event(name="end get_llm_config_from_handle", attributes=config_params) # update with model_settings if request.model_settings is not None: @@ -1061,73 +1090,85 @@ class SyncServer(object): provider_name: Optional[str] = None, provider_type: Optional[ProviderType] = None, ) -> List[LLMConfig]: - """Asynchronously list available models with maximum concurrency""" - import asyncio + """List available LLM models from database cache""" + # Get provider IDs if filtering by provider + provider_ids = None + if provider_name or provider_type: + providers = await self.get_enabled_providers_async( + 
provider_category=provider_category, + provider_name=provider_name, + provider_type=provider_type, + actor=actor, + ) + provider_ids = [p.id for p in providers] - providers = await self.get_enabled_providers_async( - provider_category=provider_category, - provider_name=provider_name, - provider_type=provider_type, + # If filtering was requested but no providers matched, return empty list + if not provider_ids: + return [] + + # Get models from database + provider_models = await self.provider_manager.list_models_async( actor=actor, + model_type="llm", + provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None, + enabled=True, ) - async def get_provider_models(provider: Provider) -> list[LLMConfig]: - try: - async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS): - return await provider.list_llm_models_async() - except asyncio.TimeoutError: - logger.warning(f"Timeout while listing LLM models for provider {provider}") - return [] - except Exception as e: - logger.exception(f"Error while listing LLM models for provider {provider}: {e}") - return [] - - # Execute all provider model listing tasks concurrently - provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers]) - - # Flatten the results + # Build LLMConfig objects from cached data + # Cache providers to avoid N+1 queries + provider_cache: Dict[str, Provider] = {} llm_models = [] - for models in provider_results: - llm_models.extend(models) + for model in provider_models: + # Skip if filtering by provider and model doesn't match + if provider_ids and model.provider_id not in provider_ids: + continue - # Get local configs - if this is potentially slow, consider making it async too - local_configs = self.get_local_llm_configs() - llm_models.extend(local_configs) + # Get provider details (with caching to avoid N+1 queries) + if model.provider_id not in provider_cache: + provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor) + provider = provider_cache[model.provider_id] - # dedupe by handle for uniqueness - # Seems like this is required from the tests? 
- seen_handles = set() - unique_models = [] - for model in llm_models: - if model.handle not in seen_handles: - seen_handles.add(model.handle) - unique_models.append(model) + llm_config = LLMConfig( + model=model.name, + model_endpoint_type=model.model_endpoint_type, + model_endpoint=provider.base_url or model.model_endpoint_type, + context_window=model.max_context_window or 16384, + handle=model.handle, + provider_name=provider.name, + provider_category=provider.provider_category, + ) + llm_models.append(llm_config) - return unique_models + return llm_models async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]: - """Asynchronously list available embedding models with maximum concurrency""" - import asyncio + """List available embedding models from database cache""" + # Get models from database + provider_models = await self.provider_manager.list_models_async( + actor=actor, + model_type="embedding", + enabled=True, + ) - # Get all eligible providers first - providers = await self.get_enabled_providers_async(actor=actor) - - # Fetch embedding models from each provider concurrently - async def get_provider_embedding_models(provider): - try: - # All providers now have list_embedding_models_async - return await provider.list_embedding_models_async() - except Exception as e: - logger.exception(f"An error occurred while listing embedding models for provider {provider}: {e}") - return [] - - # Execute all provider model listing tasks concurrently - provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers]) - - # Flatten the results + # Build EmbeddingConfig objects from cached data + # Cache providers to avoid N+1 queries + provider_cache: Dict[str, Provider] = {} embedding_models = [] - for models in provider_results: - embedding_models.extend(models) + for model in provider_models: + # Get provider details (with caching to avoid N+1 queries) + if model.provider_id not in provider_cache: + provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor) + provider = provider_cache[model.provider_id] + + embedding_config = EmbeddingConfig( + embedding_model=model.name, + embedding_endpoint_type=model.model_endpoint_type, + embedding_endpoint=provider.base_url or model.model_endpoint_type, + embedding_dim=model.embedding_dim or 1536, # Use model's dimension or default + embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=model.handle, + ) + embedding_models.append(embedding_config) return embedding_models @@ -1140,17 +1181,22 @@ class SyncServer(object): ) -> List[Provider]: providers = [] if not provider_category or ProviderCategory.base in provider_category: - providers_from_env = [p for p in self._enabled_providers] - providers.extend(providers_from_env) + # Add enabled providers (base providers from environment) + enabled_providers = [p for p in self._enabled_providers] + providers.extend(enabled_providers) if not provider_category or ProviderCategory.byok in provider_category: - providers_from_db = await self.provider_manager.list_providers_async( + # Add persisted BYOK providers from database + # Note: list_providers_async returns both org-specific and global providers, + # so we filter to only include BYOK providers to avoid duplicating base providers + persisted_providers = await self.provider_manager.list_providers_async( name=provider_name, provider_type=provider_type, actor=actor, ) - providers_from_db = [p.cast_to_subtype() for p in 
providers_from_db] - providers.extend(providers_from_db) + # Filter to only BYOK providers (base providers are already in self._enabled_providers) + persisted_byok_providers = [p.cast_to_subtype() for p in persisted_providers if p.provider_category == ProviderCategory.byok] + providers.extend(persisted_byok_providers) if provider_name is not None: providers = [p for p in providers if p.name == provider_name] @@ -1170,32 +1216,19 @@ class SyncServer(object): max_reasoning_tokens: Optional[int] = None, enable_reasoner: Optional[bool] = None, ) -> LLMConfig: + # Use provider_manager to get LLMConfig from handle try: - provider_name, model_name = handle.split("/", 1) - provider = await self.get_provider_from_name_async(provider_name, actor) - - all_llm_configs = await provider.list_llm_models_async() - llm_configs = [config for config in all_llm_configs if config.handle == handle] - if not llm_configs: - llm_configs = [config for config in all_llm_configs if config.model == model_name] - if not llm_configs: - available_handles = [config.handle for config in all_llm_configs] - raise HandleNotFoundError(handle, available_handles) - except ValueError as e: - llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle] - if not llm_configs: - llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name] - if not llm_configs: - raise e - - if len(llm_configs) == 1: - llm_config = llm_configs[0] - elif len(llm_configs) > 1: - raise LettaInvalidArgumentError( - f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name" + llm_config = await self.provider_manager.get_llm_config_from_handle( + handle=handle, + actor=actor, ) - else: - llm_config = llm_configs[0] + except Exception as e: + # Convert to HandleNotFoundError for backwards compatibility + from letta.orm.errors import NoResultFound + + if isinstance(e, NoResultFound): + raise HandleNotFoundError(handle, []) + raise if context_window_limit is not None: if context_window_limit > llm_config.context_window: @@ -1227,33 +1260,22 @@ class SyncServer(object): async def get_embedding_config_from_handle_async( self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE ) -> EmbeddingConfig: + # Use provider_manager to get EmbeddingConfig from handle try: - provider_name, model_name = handle.split("/", 1) - provider = await self.get_provider_from_name_async(provider_name, actor) - - all_embedding_configs = await provider.list_embedding_models_async() - embedding_configs = [config for config in all_embedding_configs if config.handle == handle] - if not embedding_configs: - raise LettaInvalidArgumentError( - f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name" - ) - except LettaInvalidArgumentError as e: - # search local configs - embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle] - if not embedding_configs: - raise e - - if len(embedding_configs) == 1: - embedding_config = embedding_configs[0] - elif len(embedding_configs) > 1: - raise LettaInvalidArgumentError( - f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name" + embedding_config = await self.provider_manager.get_embedding_config_from_handle( + handle=handle, + actor=actor, ) - else: - embedding_config = embedding_configs[0] + except Exception as e: + # Convert to LettaInvalidArgumentError for 
backwards compatibility + from letta.orm.errors import NoResultFound - if embedding_chunk_size: - embedding_config.embedding_chunk_size = embedding_chunk_size + if isinstance(e, NoResultFound): + raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle") + raise + + # Override chunk size if provided + embedding_config.embedding_chunk_size = embedding_chunk_size return embedding_config @@ -1272,46 +1294,6 @@ class SyncServer(object): return provider - def get_local_llm_configs(self): - llm_models = [] - # NOTE: deprecated - # try: - # llm_configs_dir = os.path.expanduser("~/.letta/llm_configs") - # if os.path.exists(llm_configs_dir): - # for filename in os.listdir(llm_configs_dir): - # if filename.endswith(".json"): - # filepath = os.path.join(llm_configs_dir, filename) - # try: - # with open(filepath, "r") as f: - # config_data = json.load(f) - # llm_config = LLMConfig(**config_data) - # llm_models.append(llm_config) - # except (json.JSONDecodeError, ValueError) as e: - # logger.warning(f"Error parsing LLM config file {filename}: {e}") - # except Exception as e: - # logger.warning(f"Error reading LLM configs directory: {e}") - return llm_models - - def get_local_embedding_configs(self): - embedding_models = [] - # NOTE: deprecated - # try: - # embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs") - # if os.path.exists(embedding_configs_dir): - # for filename in os.listdir(embedding_configs_dir): - # if filename.endswith(".json"): - # filepath = os.path.join(embedding_configs_dir, filename) - # try: - # with open(filepath, "r") as f: - # config_data = json.load(f) - # embedding_config = EmbeddingConfig(**config_data) - # embedding_models.append(embedding_config) - # except (json.JSONDecodeError, ValueError) as e: - # logger.warning(f"Error parsing embedding config file {filename}: {e}") - # except Exception as e: - # logger.warning(f"Error reading embedding configs directory: {e}") - return embedding_models - def add_llm_model(self, request: LLMConfig) -> LLMConfig: """Add a new LLM model""" diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index f99f7944..8d327e4c 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -154,7 +154,7 @@ class ProviderManager: @trace_method @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser): - """Delete a provider.""" + """Delete a provider and its associated models.""" async with db_registry.async_session() as session: # Clear api key field existing_provider = await ProviderModel.read_async( @@ -163,6 +163,15 @@ class ProviderManager: existing_provider.api_key = None await existing_provider.update_async(session, actor=actor) + # Soft delete all models associated with this provider + provider_models = await ProviderModelORM.list_async( + db_session=session, + provider_id=provider_id, + check_is_deleted=True, + ) + for model in provider_models: + await model.delete_async(session, actor=actor) + # Soft delete in provider table await existing_provider.delete_async(session, actor=actor) @@ -631,11 +640,11 @@ class ProviderManager: await model.create_async(session) logger.info(f" ✓ Successfully created LLM model {llm_config.handle} with ID {model.id}") except Exception as e: - logger.error(f" ✗ Failed to create LLM model {llm_config.handle}: {e}") + logger.info(f" ✗ Failed to create LLM model {llm_config.handle}: {e}") # 
Log the full error details import traceback - logger.error(f" Full traceback: {traceback.format_exc()}") + logger.info(f" Full traceback: {traceback.format_exc()}") # Roll back the session to clear the failed transaction await session.rollback() else: diff --git a/tests/managers/test_provider_manager.py b/tests/managers/test_provider_manager.py index 61e2597f..dd0bda6b 100644 --- a/tests/managers/test_provider_manager.py +++ b/tests/managers/test_provider_manager.py @@ -499,3 +499,436 @@ async def test_byok_provider_auto_syncs_models(provider_manager, default_user, m llm_config = await provider_manager.get_llm_config_from_handle(handle="my-openai-key/gpt-4o", actor=default_user) assert llm_config.model == "gpt-4o" assert llm_config.provider_name == "my-openai-key" + + +# ====================================================================================================================== +# Server Startup Provider Sync Tests +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_server_startup_syncs_base_providers(default_user, default_organization, monkeypatch): + """Test that server startup properly syncs base provider models from environment. + + This test simulates the server startup process and verifies that: + 1. Base providers from environment variables are synced to database + 2. Provider models are fetched from mocked API endpoints + 3. Models are properly persisted to the database with correct metadata + 4. Models can be retrieved using handles + """ + from unittest.mock import AsyncMock + + from letta.schemas.embedding_config import EmbeddingConfig + from letta.schemas.llm_config import LLMConfig + from letta.schemas.providers import AnthropicProvider, OpenAIProvider + from letta.server.server import SyncServer + + # Mock OpenAI API responses + mock_openai_models = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "created": 1687882411, + "owned_by": "openai", + "max_model_len": 8192, + }, + { + "id": "gpt-4-turbo", + "object": "model", + "created": 1712361441, + "owned_by": "system", + "max_model_len": 128000, + }, + { + "id": "text-embedding-ada-002", + "object": "model", + "created": 1671217299, + "owned_by": "openai-internal", + }, + { + "id": "gpt-4-vision", # Should be filtered out by OpenAI provider logic (has disallowed keyword) + "object": "model", + "created": 1698959748, + "owned_by": "system", + "max_model_len": 8192, + }, + ] + } + + # Mock Anthropic API responses + mock_anthropic_models = { + "data": [ + { + "id": "claude-3-5-sonnet-20241022", + "type": "model", + "display_name": "Claude 3.5 Sonnet", + "created_at": "2024-10-22T00:00:00Z", + }, + { + "id": "claude-3-opus-20240229", + "type": "model", + "display_name": "Claude 3 Opus", + "created_at": "2024-02-29T00:00:00Z", + }, + ] + } + + # Mock the API calls for OpenAI + async def mock_openai_get_model_list_async(*args, **kwargs): + return mock_openai_models + + # Mock Anthropic models.list() response + from unittest.mock import MagicMock + + mock_anthropic_response = MagicMock() + mock_anthropic_response.model_dump.return_value = mock_anthropic_models + + # Mock the Anthropic AsyncAnthropic client + class MockAnthropicModels: + async def list(self): + return mock_anthropic_response + + class MockAsyncAnthropic: + def __init__(self, *args, **kwargs): + self.models = MockAnthropicModels() + + # Patch the actual API calling functions + monkeypatch.setattr( + 
"letta.llm_api.openai.openai_get_model_list_async", + mock_openai_get_model_list_async, + ) + monkeypatch.setattr( + "anthropic.AsyncAnthropic", + MockAsyncAnthropic, + ) + + # Clear ALL provider-related env vars first to ensure clean state + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + monkeypatch.delenv("GEMINI_API_KEY", raising=False) + monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) + monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False) + monkeypatch.delenv("AZURE_API_KEY", raising=False) + monkeypatch.delenv("GROQ_API_KEY", raising=False) + monkeypatch.delenv("TOGETHER_API_KEY", raising=False) + monkeypatch.delenv("VLLM_API_BASE", raising=False) + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + monkeypatch.delenv("LMSTUDIO_BASE_URL", raising=False) + monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False) + monkeypatch.delenv("XAI_API_KEY", raising=False) + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + + # Set environment variables to enable only OpenAI and Anthropic + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key-12345") + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key-67890") + + # Reload model_settings to pick up new env vars + from letta.settings import model_settings + + monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key-12345") + monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key-67890") + monkeypatch.setattr(model_settings, "gemini_api_key", None) + monkeypatch.setattr(model_settings, "google_cloud_project", None) + monkeypatch.setattr(model_settings, "google_cloud_location", None) + monkeypatch.setattr(model_settings, "azure_api_key", None) + monkeypatch.setattr(model_settings, "groq_api_key", None) + monkeypatch.setattr(model_settings, "together_api_key", None) + monkeypatch.setattr(model_settings, "vllm_api_base", None) + monkeypatch.setattr(model_settings, "aws_access_key_id", None) + monkeypatch.setattr(model_settings, "aws_secret_access_key", None) + monkeypatch.setattr(model_settings, "lmstudio_base_url", None) + monkeypatch.setattr(model_settings, "deepseek_api_key", None) + monkeypatch.setattr(model_settings, "xai_api_key", None) + monkeypatch.setattr(model_settings, "openrouter_api_key", None) + + # Create server instance (this will load enabled providers from environment) + server = SyncServer(init_with_default_org_and_user=False) + + # Manually set up the default user/org (since we disabled auto-init) + server.default_user = default_user + server.default_org = default_organization + + # Verify enabled providers were loaded + assert len(server._enabled_providers) == 3 # Exactly: letta, openai, anthropic + enabled_provider_names = [p.name for p in server._enabled_providers] + assert "letta" in enabled_provider_names + assert "openai" in enabled_provider_names + assert "anthropic" in enabled_provider_names + + # First, sync base providers to database (this is what init_async does) + await server.provider_manager.sync_base_providers( + base_providers=server._enabled_providers, + actor=default_user, + ) + + # Now call the actual _sync_provider_models_async method + # This simulates what happens during server startup + await server._sync_provider_models_async() + + # Verify OpenAI models were synced + openai_providers = await server.provider_manager.list_providers_async( + name="openai", + actor=default_user, + ) + assert len(openai_providers) == 1, "OpenAI 
provider should exist" + openai_provider = openai_providers[0] + + # Check OpenAI LLM models + openai_llm_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=openai_provider.id, + model_type="llm", + ) + + # Should have gpt-4 and gpt-4-turbo (gpt-4-vision filtered out due to "vision" keyword) + assert len(openai_llm_models) >= 2, f"Expected at least 2 OpenAI LLM models, got {len(openai_llm_models)}" + openai_model_names = [m.name for m in openai_llm_models] + assert "gpt-4" in openai_model_names + assert "gpt-4-turbo" in openai_model_names + + # Check OpenAI embedding models + openai_embedding_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=openai_provider.id, + model_type="embedding", + ) + assert len(openai_embedding_models) >= 1, "Expected at least 1 OpenAI embedding model" + embedding_model_names = [m.name for m in openai_embedding_models] + assert "text-embedding-ada-002" in embedding_model_names + + # Verify model metadata is correct + gpt4_models = [m for m in openai_llm_models if m.name == "gpt-4"] + assert len(gpt4_models) > 0, "gpt-4 model should exist" + gpt4_model = gpt4_models[0] + assert gpt4_model.handle == "openai/gpt-4" + assert gpt4_model.model_endpoint_type == "openai" + assert gpt4_model.max_context_window == 8192 + assert gpt4_model.enabled is True + + # Verify Anthropic models were synced + anthropic_providers = await server.provider_manager.list_providers_async( + name="anthropic", + actor=default_user, + ) + assert len(anthropic_providers) == 1, "Anthropic provider should exist" + anthropic_provider = anthropic_providers[0] + + anthropic_llm_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=anthropic_provider.id, + model_type="llm", + ) + + # Should have Claude models + assert len(anthropic_llm_models) >= 2, f"Expected at least 2 Anthropic models, got {len(anthropic_llm_models)}" + anthropic_model_names = [m.name for m in anthropic_llm_models] + assert "claude-3-5-sonnet-20241022" in anthropic_model_names + assert "claude-3-opus-20240229" in anthropic_model_names + + # Test that we can retrieve LLMConfig from handle + llm_config = await server.provider_manager.get_llm_config_from_handle( + handle="openai/gpt-4", + actor=default_user, + ) + assert llm_config.model == "gpt-4" + assert llm_config.handle == "openai/gpt-4" + assert llm_config.provider_name == "openai" + assert llm_config.context_window == 8192 + + # Test that we can retrieve EmbeddingConfig from handle + embedding_config = await server.provider_manager.get_embedding_config_from_handle( + handle="openai/text-embedding-ada-002", + actor=default_user, + ) + assert embedding_config.embedding_model == "text-embedding-ada-002" + assert embedding_config.handle == "openai/text-embedding-ada-002" + assert embedding_config.embedding_dim == 1536 + + +@pytest.mark.asyncio +async def test_server_startup_handles_disabled_providers(default_user, default_organization, monkeypatch): + """Test that server startup properly handles providers that are no longer enabled. + + This test verifies that: + 1. Base providers that are no longer enabled (env vars removed) are deleted + 2. BYOK providers that are no longer enabled are NOT deleted (user-created) + 3. 
The sync process handles providers gracefully when API calls fail + """ + from letta.schemas.providers import OpenAIProvider, ProviderCreate + from letta.server.server import SyncServer + + # First, manually create providers in the database + provider_manager = ProviderManager() + + # Create a base OpenAI provider (simulating it was synced before) + base_openai_create = ProviderCreate( + name="openai", + provider_type=ProviderType.openai, + api_key="sk-old-key", + base_url="https://api.openai.com/v1", + ) + base_openai = await provider_manager.create_provider_async( + base_openai_create, + actor=default_user, + is_byok=False, # This is a base provider + ) + + # Create a BYOK provider (user-created) + byok_provider_create = ProviderCreate( + name="my-custom-openai", + provider_type=ProviderType.openai, + api_key="sk-my-key", + base_url="https://api.openai.com/v1", + ) + byok_provider = await provider_manager.create_provider_async( + byok_provider_create, + actor=default_user, + is_byok=True, + ) + assert byok_provider.provider_category == ProviderCategory.byok + + # Now create server with NO environment variables set (all base providers disabled) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + + from letta.settings import model_settings + + monkeypatch.setattr(model_settings, "openai_api_key", None) + monkeypatch.setattr(model_settings, "anthropic_api_key", None) + + # Create server instance + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.default_org = default_organization + + # Verify only letta provider is enabled (no openai) + enabled_names = [p.name for p in server._enabled_providers] + assert "letta" in enabled_names + assert "openai" not in enabled_names + + # Sync base providers (should not include openai anymore) + await server.provider_manager.sync_base_providers( + base_providers=server._enabled_providers, + actor=default_user, + ) + + # Call _sync_provider_models_async + await server._sync_provider_models_async() + + # Verify base OpenAI provider was deleted (no longer enabled) + try: + await server.provider_manager.get_provider_async(base_openai.id, actor=default_user) + assert False, "Base OpenAI provider should have been deleted" + except Exception: + # Expected - provider should not exist + pass + + # Verify BYOK provider still exists (should NOT be deleted) + byok_still_exists = await server.provider_manager.get_provider_async( + byok_provider.id, + actor=default_user, + ) + assert byok_still_exists is not None + assert byok_still_exists.name == "my-custom-openai" + assert byok_still_exists.provider_category == ProviderCategory.byok + + +@pytest.mark.asyncio +async def test_server_startup_handles_api_errors_gracefully(default_user, default_organization, monkeypatch): + """Test that server startup handles API errors gracefully without crashing. + + This test verifies that: + 1. If a provider's API call fails during sync, it logs an error but continues + 2. Other providers can still sync successfully + 3. 
The server startup completes without crashing + """ + from letta.schemas.providers import AnthropicProvider, OpenAIProvider + from letta.server.server import SyncServer + + # Mock OpenAI to fail + async def mock_openai_fail(*args, **kwargs): + raise Exception("OpenAI API is down") + + # Mock Anthropic to succeed + from unittest.mock import MagicMock + + mock_anthropic_response = MagicMock() + mock_anthropic_response.model_dump.return_value = { + "data": [ + { + "id": "claude-3-5-sonnet-20241022", + "type": "model", + "display_name": "Claude 3.5 Sonnet", + "created_at": "2024-10-22T00:00:00Z", + } + ] + } + + class MockAnthropicModels: + async def list(self): + return mock_anthropic_response + + class MockAsyncAnthropic: + def __init__(self, *args, **kwargs): + self.models = MockAnthropicModels() + + monkeypatch.setattr( + "letta.llm_api.openai.openai_get_model_list_async", + mock_openai_fail, + ) + monkeypatch.setattr( + "anthropic.AsyncAnthropic", + MockAsyncAnthropic, + ) + + # Set environment variables + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key") + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key") + + from letta.settings import model_settings + + monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key") + monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key") + + # Create server + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.default_org = default_organization + + # Sync base providers + await server.provider_manager.sync_base_providers( + base_providers=server._enabled_providers, + actor=default_user, + ) + + # This should NOT crash even though OpenAI fails + await server._sync_provider_models_async() + + # Verify Anthropic still synced successfully + anthropic_providers = await server.provider_manager.list_providers_async( + name="anthropic", + actor=default_user, + ) + assert len(anthropic_providers) == 1 + + anthropic_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=anthropic_providers[0].id, + model_type="llm", + ) + assert len(anthropic_models) >= 1, "Anthropic models should have synced despite OpenAI failure" + + # OpenAI should have no models (sync failed) + openai_providers = await server.provider_manager.list_providers_async( + name="openai", + actor=default_user, + ) + if len(openai_providers) > 0: + openai_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=openai_providers[0].id, + ) + # Models might exist from previous runs, but the sync attempt should have been logged as failed + # The key is that the server didn't crash diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index cfc05dd3..0272e0e3 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -32,6 +32,7 @@ from letta.config import LettaConfig from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.server.server import SyncServer from tests.helpers.utils import upload_file_and_wait +from tests.utils import wait_for_server # Constants SERVER_PORT = 8283 @@ -106,7 +107,7 @@ def client() -> LettaSDKClient: print("Starting server thread") thread = threading.Thread(target=run_server, daemon=True) thread.start() - time.sleep(5) + wait_for_server(server_url, timeout=60) print("Running client tests with server:", server_url) client = LettaSDKClient(base_url=server_url) diff --git a/tests/test_server_providers.py b/tests/test_server_providers.py index 2306c22f..67494bc5 100644 
--- a/tests/test_server_providers.py +++ b/tests/test_server_providers.py @@ -1740,3 +1740,110 @@ async def test_handle_uniqueness_per_org(default_user, provider_manager): assert model is not None assert model.provider_id == provider_1.id # Still original provider assert model.max_context_window == 8192 # Still original + + +@pytest.mark.asyncio +async def test_delete_provider_cascades_to_models(default_user, provider_manager, monkeypatch): + """Test that deleting a provider also soft-deletes its associated models.""" + test_id = generate_test_id() + + # Mock _sync_default_models_for_provider to avoid external API calls + async def mock_sync(provider, actor): + pass # Don't actually sync - we'll manually create models below + + monkeypatch.setattr(provider_manager, "_sync_default_models_for_provider", mock_sync) + + # 1. Create a BYOK provider (org-scoped, so the actor can delete it) + provider_create = ProviderCreate( + name=f"test-cascade-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-test-key", + ) + provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=True) + + # 2. Manually sync models to the provider + llm_models = [ + LLMConfig( + model=f"gpt-4o-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=128000, + handle=f"test-{test_id}/gpt-4o", + provider_name=provider.name, + provider_category=ProviderCategory.byok, + ), + LLMConfig( + model=f"gpt-4o-mini-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=16384, + handle=f"test-{test_id}/gpt-4o-mini", + provider_name=provider.name, + provider_category=ProviderCategory.byok, + ), + ] + + embedding_models = [ + EmbeddingConfig( + embedding_model=f"text-embedding-3-small-{test_id}", + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + embedding_chunk_size=300, + handle=f"test-{test_id}/text-embedding-3-small", + ), + ] + + await provider_manager.sync_provider_models_async( + provider=provider, + llm_models=llm_models, + embedding_models=embedding_models, + organization_id=default_user.organization_id, # Org-scoped for BYOK provider + ) + + # 3. Verify models exist before deletion + llm_models_before = await provider_manager.list_models_async( + actor=default_user, + model_type="llm", + provider_id=provider.id, + ) + embedding_models_before = await provider_manager.list_models_async( + actor=default_user, + model_type="embedding", + provider_id=provider.id, + ) + + llm_handles_before = {m.handle for m in llm_models_before} + embedding_handles_before = {m.handle for m in embedding_models_before} + + assert f"test-{test_id}/gpt-4o" in llm_handles_before + assert f"test-{test_id}/gpt-4o-mini" in llm_handles_before + assert f"test-{test_id}/text-embedding-3-small" in embedding_handles_before + + # 4. Delete the provider + await provider_manager.delete_provider_by_id_async(provider.id, actor=default_user) + + # 5. 
Verify models are soft-deleted (no longer returned in list) + all_llm_models_after = await provider_manager.list_models_async( + actor=default_user, + model_type="llm", + ) + all_embedding_models_after = await provider_manager.list_models_async( + actor=default_user, + model_type="embedding", + ) + + all_llm_handles_after = {m.handle for m in all_llm_models_after} + all_embedding_handles_after = {m.handle for m in all_embedding_models_after} + + # All models from the deleted provider should be gone + assert f"test-{test_id}/gpt-4o" not in all_llm_handles_after + assert f"test-{test_id}/gpt-4o-mini" not in all_llm_handles_after + assert f"test-{test_id}/text-embedding-3-small" not in all_embedding_handles_after + + # 6. Verify provider is also deleted + providers_after = await provider_manager.list_providers_async( + actor=default_user, + name=f"test-cascade-{test_id}", + ) + assert len(providers_after) == 0
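
Reviewer note on the tests/test_sdk_client.py hunk: the patch replaces the fixed time.sleep(5) with wait_for_server(server_url, timeout=60), imported from tests.utils, but that helper's implementation is not part of this diff. Below is a minimal sketch of such a helper, assuming the server exposes a GET /v1/health/ endpoint and that the requests library is available in the test environment; the actual implementation in tests/utils.py may differ.

import time

import requests


def wait_for_server(base_url: str, timeout: float = 60.0, interval: float = 0.5) -> None:
    """Poll the server's health endpoint until it responds, or raise after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # Any 2xx response means the server is up and serving requests.
            # /v1/health/ is an assumed route; adjust if the server exposes a different one.
            if requests.get(f"{base_url}/v1/health/", timeout=5).ok:
                return
        except requests.RequestException:
            pass  # server not accepting connections yet; retry after a short pause
        time.sleep(interval)
    raise TimeoutError(f"Server at {base_url} did not become healthy within {timeout} seconds")

Polling against a deadline is both faster on quick startups and more robust than a fixed sleep, which matters here because the provider model sync added in this patch makes server boot time variable.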