From 8440e319e28107066bd0911eee32be4ec4dd7228 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 9 Dec 2025 16:46:26 -0800 Subject: [PATCH] Revert "feat: enable provider models persistence" (#6590) Revert "feat: enable provider models persistence (#6193)" This reverts commit 9682aff32640a6ee8cf71a6f18c9fa7cda25c40e. --- letta/llm_api/llm_client_base.py | 18 +- letta/schemas/providers/letta.py | 5 +- letta/schemas/secret.py | 2 +- letta/server/server.py | 376 ++++++++++---------- letta/services/provider_manager.py | 15 +- tests/managers/test_provider_manager.py | 433 ------------------------ tests/test_sdk_client.py | 3 +- tests/test_server_providers.py | 107 ------ 8 files changed, 205 insertions(+), 754 deletions(-) diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 7c43ce8e..b25df76a 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -220,38 +220,24 @@ class LLMClientBase: def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]: """ Returns the override key for the given llm config. - For both base and BYOK providers, fetch the API key from the database. """ api_key = None - # Fetch API key from database for both base and BYOK providers - # This ensures that base providers (from environment) also have their keys persisted and accessible - if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]: + if llm_config.provider_category == ProviderCategory.byok: from letta.services.provider_manager import ProviderManager api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor) - # If we got an empty string from the database (e.g., Letta provider), treat it as None - # so the client can fall back to environment variables or default behavior - if api_key == "": - api_key = None return api_key, None, None async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]: """ Returns the override key for the given llm config. - For both base and BYOK providers, fetch the API key from the database. 
""" api_key = None - # Fetch API key from database for both base and BYOK providers - # This ensures that base providers (from environment) also have their keys persisted and accessible - if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]: + if llm_config.provider_category == ProviderCategory.byok: from letta.services.provider_manager import ProviderManager api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor) - # If we got an empty string from the database (e.g., Letta provider), treat it as None - # so the client can fall back to environment variables or default behavior - if api_key == "": - api_key = None return api_key, None, None diff --git a/letta/schemas/providers/letta.py b/letta/schemas/providers/letta.py index 4b223ada..34151fac 100644 --- a/letta/schemas/providers/letta.py +++ b/letta/schemas/providers/letta.py @@ -8,13 +8,10 @@ from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.providers.base import Provider -LETTA_EMBEDDING_ENDPOINT = "https://embeddings.letta.com/" - class LettaProvider(Provider): provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - base_url: str = Field(LETTA_EMBEDDING_ENDPOINT, description="Base URL for the Letta API (used for embeddings).") async def list_llm_models_async(self) -> list[LLMConfig]: return [ @@ -34,7 +31,7 @@ class LettaProvider(Provider): EmbeddingConfig( embedding_model="letta-free", # NOTE: renamed embedding_endpoint_type="openai", - embedding_endpoint=self.base_url, + embedding_endpoint="https://embeddings.letta.com/", embedding_dim=1536, embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, handle=self.get_handle("letta-free", is_embedding=True), diff --git a/letta/schemas/secret.py b/letta/schemas/secret.py index 9471ff03..2790c4aa 100644 --- a/letta/schemas/secret.py +++ b/letta/schemas/secret.py @@ -113,7 +113,7 @@ class Secret(BaseModel): "MIGRATION_NEEDED: Reading from plaintext column instead of encrypted column. " "This indicates data that hasn't been migrated to the _enc column yet. 
" "Please run migrate data to _enc columns as plaintext columns will be deprecated.", - # stack_info=True, + stack_info=True, ) return cls.from_plaintext(plaintext_value) return cls.from_plaintext(None) diff --git a/letta/server/server.py b/letta/server/server.py index 59267582..20cd7f85 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -109,7 +109,7 @@ from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import DatabaseChoice, model_settings, settings, tool_settings from letta.streaming_interface import AgentChunkStreamingInterface -from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task +from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task config = LettaConfig.load() logger = get_logger(__name__) @@ -203,10 +203,12 @@ class SyncServer(object): """Initialize the MCP clients (there may be multiple)""" self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {} - # collect providers (always has Letta as a default) - from letta.constants import LETTA_MODEL_ENDPOINT + # TODO: Remove these in memory caches + self._llm_config_cache = {} + self._embedding_config_cache = {} - self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)] + # collect providers (always has Letta as a default) + self._enabled_providers: List[Provider] = [LettaProvider(name="letta")] if model_settings.openai_api_key: self._enabled_providers.append( OpenAIProvider( @@ -340,12 +342,6 @@ class SyncServer(object): print(f"Default user: {self.default_user} and org: {self.default_org}") await self.tool_manager.upsert_base_tools_async(actor=self.default_user) - # Sync environment-based providers to database (idempotent, safe for multi-pod startup) - await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user) - - # Sync provider models to database - await self._sync_provider_models_async() - # For OSS users, create a local sandbox config oss_default_user = await self.user_manager.get_default_actor_async() use_venv = False if not tool_settings.tool_exec_venv_name else True @@ -382,65 +378,6 @@ class SyncServer(object): force_recreate=True, ) - def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]: - """Find and return an enabled provider by name. 
- - Args: - provider_name: The name of the provider to find - - Returns: - The matching enabled provider, or None if not found - """ - for provider in self._enabled_providers: - if provider.name == provider_name: - return provider - return None - - async def _sync_provider_models_async(self): - """Sync all provider models to database at startup.""" - logger.info("Syncing provider models to database") - - # Get persisted providers from database (they now have IDs) - persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user) - - for persisted_provider in persisted_providers: - try: - # Find the matching enabled provider instance to call list_models on - enabled_provider = self._get_enabled_provider(persisted_provider.name) - - if not enabled_provider: - # Only delete base providers that are no longer enabled - # BYOK providers are user-created and should not be automatically deleted - if persisted_provider.provider_category == ProviderCategory.base: - logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database") - try: - await self.provider_manager.delete_provider_by_id_async( - provider_id=persisted_provider.id, actor=self.default_user - ) - except NoResultFound: - # Provider was already deleted (race condition in multi-pod startup) - logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping") - else: - logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync") - continue - - # Fetch models from provider - llm_models = await enabled_provider.list_llm_models_async() - embedding_models = await enabled_provider.list_embedding_models_async() - - # Save to database with the persisted provider (which has an ID) - await self.provider_manager.sync_provider_models_async( - provider=persisted_provider, - llm_models=llm_models, - embedding_models=embedding_models, - organization_id=None, # Global models - ) - logger.info( - f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}" - ) - except Exception as e: - logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True) - async def init_mcp_clients(self): # TODO: remove this mcp_server_configs = self.get_mcp_servers() @@ -468,6 +405,39 @@ class SyncServer(object): logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}") logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}") + @trace_method + def get_cached_llm_config(self, actor: User, **kwargs): + key = make_key(**kwargs) + if key not in self._llm_config_cache: + self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs) + logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries") + return self._llm_config_cache[key] + + @trace_method + async def get_cached_llm_config_async(self, actor: User, **kwargs): + key = make_key(**kwargs) + if key not in self._llm_config_cache: + self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs) + logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries") + return self._llm_config_cache[key] + + @trace_method + def get_cached_embedding_config(self, actor: User, **kwargs): + key = make_key(**kwargs) + if key not in self._embedding_config_cache: + self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs) + logger.info(f"Embedding config cache size: 
{len(self._embedding_config_cache)} entries") + return self._embedding_config_cache[key] + + # @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs)) + @trace_method + async def get_cached_embedding_config_async(self, actor: User, **kwargs): + key = make_key(**kwargs) + if key not in self._embedding_config_cache: + self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs) + logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries") + return self._embedding_config_cache[key] + @trace_method async def create_agent_async( self, @@ -501,9 +471,10 @@ class SyncServer(object): "max_reasoning_tokens": request.max_reasoning_tokens, "enable_reasoner": request.enable_reasoner, } - log_event(name="start get_llm_config_from_handle", attributes=config_params) - request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params) - log_event(name="end get_llm_config_from_handle", attributes=config_params) + config_params.update(additional_config_params) + log_event(name="start get_cached_llm_config", attributes=config_params) + request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params) + log_event(name="end get_cached_llm_config", attributes=config_params) if request.model and isinstance(request.model, str): assert request.llm_config.handle == request.model, ( f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}" @@ -533,9 +504,9 @@ class SyncServer(object): "handle": request.embedding, "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE, } - log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params) - request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params) - log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params) + log_event(name="start get_cached_embedding_config", attributes=embedding_config_params) + request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params) + log_event(name="end get_cached_embedding_config", attributes=embedding_config_params) log_event(name="start create_agent db") main_agent = await self.agent_manager.create_agent_async( @@ -584,9 +555,9 @@ class SyncServer(object): "context_window_limit": request.context_window_limit, "max_tokens": request.max_tokens, } - log_event(name="start get_llm_config_from_handle", attributes=config_params) - request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params) - log_event(name="end get_llm_config_from_handle", attributes=config_params) + log_event(name="start get_cached_llm_config", attributes=config_params) + request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params) + log_event(name="end get_cached_llm_config", attributes=config_params) # update with model_settings if request.model_settings is not None: @@ -1090,85 +1061,73 @@ class SyncServer(object): provider_name: Optional[str] = None, provider_type: Optional[ProviderType] = None, ) -> List[LLMConfig]: - """List available LLM models from database cache""" - # Get provider IDs if filtering by provider - provider_ids = None - if provider_name or provider_type: - providers = await self.get_enabled_providers_async( - provider_category=provider_category, - provider_name=provider_name, - 
provider_type=provider_type, - actor=actor, - ) - provider_ids = [p.id for p in providers] + """Asynchronously list available models with maximum concurrency""" + import asyncio - # If filtering was requested but no providers matched, return empty list - if not provider_ids: + providers = await self.get_enabled_providers_async( + provider_category=provider_category, + provider_name=provider_name, + provider_type=provider_type, + actor=actor, + ) + + async def get_provider_models(provider: Provider) -> list[LLMConfig]: + try: + async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS): + return await provider.list_llm_models_async() + except asyncio.TimeoutError: + logger.warning(f"Timeout while listing LLM models for provider {provider}") + return [] + except Exception as e: + logger.exception(f"Error while listing LLM models for provider {provider}: {e}") return [] - # Get models from database - provider_models = await self.provider_manager.list_models_async( - actor=actor, - model_type="llm", - provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None, - enabled=True, - ) + # Execute all provider model listing tasks concurrently + provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers]) - # Build LLMConfig objects from cached data - # Cache providers to avoid N+1 queries - provider_cache: Dict[str, Provider] = {} + # Flatten the results llm_models = [] - for model in provider_models: - # Skip if filtering by provider and model doesn't match - if provider_ids and model.provider_id not in provider_ids: - continue + for models in provider_results: + llm_models.extend(models) - # Get provider details (with caching to avoid N+1 queries) - if model.provider_id not in provider_cache: - provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor) - provider = provider_cache[model.provider_id] + # Get local configs - if this is potentially slow, consider making it async too + local_configs = self.get_local_llm_configs() + llm_models.extend(local_configs) - llm_config = LLMConfig( - model=model.name, - model_endpoint_type=model.model_endpoint_type, - model_endpoint=provider.base_url or model.model_endpoint_type, - context_window=model.max_context_window or 16384, - handle=model.handle, - provider_name=provider.name, - provider_category=provider.provider_category, - ) - llm_models.append(llm_config) + # dedupe by handle for uniqueness + # Seems like this is required from the tests? 
+ seen_handles = set() + unique_models = [] + for model in llm_models: + if model.handle not in seen_handles: + seen_handles.add(model.handle) + unique_models.append(model) - return llm_models + return unique_models async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]: - """List available embedding models from database cache""" - # Get models from database - provider_models = await self.provider_manager.list_models_async( - actor=actor, - model_type="embedding", - enabled=True, - ) + """Asynchronously list available embedding models with maximum concurrency""" + import asyncio - # Build EmbeddingConfig objects from cached data - # Cache providers to avoid N+1 queries - provider_cache: Dict[str, Provider] = {} + # Get all eligible providers first + providers = await self.get_enabled_providers_async(actor=actor) + + # Fetch embedding models from each provider concurrently + async def get_provider_embedding_models(provider): + try: + # All providers now have list_embedding_models_async + return await provider.list_embedding_models_async() + except Exception as e: + logger.exception(f"An error occurred while listing embedding models for provider {provider}: {e}") + return [] + + # Execute all provider model listing tasks concurrently + provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers]) + + # Flatten the results embedding_models = [] - for model in provider_models: - # Get provider details (with caching to avoid N+1 queries) - if model.provider_id not in provider_cache: - provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor) - provider = provider_cache[model.provider_id] - - embedding_config = EmbeddingConfig( - embedding_model=model.name, - embedding_endpoint_type=model.model_endpoint_type, - embedding_endpoint=provider.base_url or model.model_endpoint_type, - embedding_dim=model.embedding_dim or 1536, # Use model's dimension or default - embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE, - handle=model.handle, - ) - embedding_models.append(embedding_config) + for models in provider_results: + embedding_models.extend(models) return embedding_models @@ -1181,22 +1140,17 @@ class SyncServer(object): ) -> List[Provider]: providers = [] if not provider_category or ProviderCategory.base in provider_category: - # Add enabled providers (base providers from environment) - enabled_providers = [p for p in self._enabled_providers] - providers.extend(enabled_providers) + providers_from_env = [p for p in self._enabled_providers] + providers.extend(providers_from_env) if not provider_category or ProviderCategory.byok in provider_category: - # Add persisted BYOK providers from database - # Note: list_providers_async returns both org-specific and global providers, - # so we filter to only include BYOK providers to avoid duplicating base providers - persisted_providers = await self.provider_manager.list_providers_async( + providers_from_db = await self.provider_manager.list_providers_async( name=provider_name, provider_type=provider_type, actor=actor, ) - # Filter to only BYOK providers (base providers are already in self._enabled_providers) - persisted_byok_providers = [p.cast_to_subtype() for p in persisted_providers if p.provider_category == ProviderCategory.byok] - providers.extend(persisted_byok_providers) + providers_from_db = [p.cast_to_subtype() for p in providers_from_db] + providers.extend(providers_from_db) if provider_name is not None: providers = 
[p for p in providers if p.name == provider_name] @@ -1216,19 +1170,32 @@ class SyncServer(object): max_reasoning_tokens: Optional[int] = None, enable_reasoner: Optional[bool] = None, ) -> LLMConfig: - # Use provider_manager to get LLMConfig from handle try: - llm_config = await self.provider_manager.get_llm_config_from_handle( - handle=handle, - actor=actor, - ) - except Exception as e: - # Convert to HandleNotFoundError for backwards compatibility - from letta.orm.errors import NoResultFound + provider_name, model_name = handle.split("/", 1) + provider = await self.get_provider_from_name_async(provider_name, actor) - if isinstance(e, NoResultFound): - raise HandleNotFoundError(handle, []) - raise + all_llm_configs = await provider.list_llm_models_async() + llm_configs = [config for config in all_llm_configs if config.handle == handle] + if not llm_configs: + llm_configs = [config for config in all_llm_configs if config.model == model_name] + if not llm_configs: + available_handles = [config.handle for config in all_llm_configs] + raise HandleNotFoundError(handle, available_handles) + except ValueError as e: + llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle] + if not llm_configs: + llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name] + if not llm_configs: + raise e + + if len(llm_configs) == 1: + llm_config = llm_configs[0] + elif len(llm_configs) > 1: + raise LettaInvalidArgumentError( + f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name" + ) + else: + llm_config = llm_configs[0] if context_window_limit is not None: if context_window_limit > llm_config.context_window: @@ -1260,22 +1227,33 @@ class SyncServer(object): async def get_embedding_config_from_handle_async( self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE ) -> EmbeddingConfig: - # Use provider_manager to get EmbeddingConfig from handle try: - embedding_config = await self.provider_manager.get_embedding_config_from_handle( - handle=handle, - actor=actor, + provider_name, model_name = handle.split("/", 1) + provider = await self.get_provider_from_name_async(provider_name, actor) + + all_embedding_configs = await provider.list_embedding_models_async() + embedding_configs = [config for config in all_embedding_configs if config.handle == handle] + if not embedding_configs: + raise LettaInvalidArgumentError( + f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name" + ) + except LettaInvalidArgumentError as e: + # search local configs + embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle] + if not embedding_configs: + raise e + + if len(embedding_configs) == 1: + embedding_config = embedding_configs[0] + elif len(embedding_configs) > 1: + raise LettaInvalidArgumentError( + f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name" ) - except Exception as e: - # Convert to LettaInvalidArgumentError for backwards compatibility - from letta.orm.errors import NoResultFound + else: + embedding_config = embedding_configs[0] - if isinstance(e, NoResultFound): - raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle") - raise - - # Override chunk size if provided - embedding_config.embedding_chunk_size = embedding_chunk_size + if embedding_chunk_size: + 
embedding_config.embedding_chunk_size = embedding_chunk_size return embedding_config @@ -1294,6 +1272,46 @@ class SyncServer(object): return provider + def get_local_llm_configs(self): + llm_models = [] + # NOTE: deprecated + # try: + # llm_configs_dir = os.path.expanduser("~/.letta/llm_configs") + # if os.path.exists(llm_configs_dir): + # for filename in os.listdir(llm_configs_dir): + # if filename.endswith(".json"): + # filepath = os.path.join(llm_configs_dir, filename) + # try: + # with open(filepath, "r") as f: + # config_data = json.load(f) + # llm_config = LLMConfig(**config_data) + # llm_models.append(llm_config) + # except (json.JSONDecodeError, ValueError) as e: + # logger.warning(f"Error parsing LLM config file {filename}: {e}") + # except Exception as e: + # logger.warning(f"Error reading LLM configs directory: {e}") + return llm_models + + def get_local_embedding_configs(self): + embedding_models = [] + # NOTE: deprecated + # try: + # embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs") + # if os.path.exists(embedding_configs_dir): + # for filename in os.listdir(embedding_configs_dir): + # if filename.endswith(".json"): + # filepath = os.path.join(embedding_configs_dir, filename) + # try: + # with open(filepath, "r") as f: + # config_data = json.load(f) + # embedding_config = EmbeddingConfig(**config_data) + # embedding_models.append(embedding_config) + # except (json.JSONDecodeError, ValueError) as e: + # logger.warning(f"Error parsing embedding config file {filename}: {e}") + # except Exception as e: + # logger.warning(f"Error reading embedding configs directory: {e}") + return embedding_models + def add_llm_model(self, request: LLMConfig) -> LLMConfig: """Add a new LLM model""" diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 8d327e4c..f99f7944 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -154,7 +154,7 @@ class ProviderManager: @trace_method @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser): - """Delete a provider and its associated models.""" + """Delete a provider.""" async with db_registry.async_session() as session: # Clear api key field existing_provider = await ProviderModel.read_async( @@ -163,15 +163,6 @@ class ProviderManager: existing_provider.api_key = None await existing_provider.update_async(session, actor=actor) - # Soft delete all models associated with this provider - provider_models = await ProviderModelORM.list_async( - db_session=session, - provider_id=provider_id, - check_is_deleted=True, - ) - for model in provider_models: - await model.delete_async(session, actor=actor) - # Soft delete in provider table await existing_provider.delete_async(session, actor=actor) @@ -640,11 +631,11 @@ class ProviderManager: await model.create_async(session) logger.info(f" ✓ Successfully created LLM model {llm_config.handle} with ID {model.id}") except Exception as e: - logger.info(f" ✗ Failed to create LLM model {llm_config.handle}: {e}") + logger.error(f" ✗ Failed to create LLM model {llm_config.handle}: {e}") # Log the full error details import traceback - logger.info(f" Full traceback: {traceback.format_exc()}") + logger.error(f" Full traceback: {traceback.format_exc()}") # Roll back the session to clear the failed transaction await session.rollback() else: diff --git a/tests/managers/test_provider_manager.py 
b/tests/managers/test_provider_manager.py index dd0bda6b..61e2597f 100644 --- a/tests/managers/test_provider_manager.py +++ b/tests/managers/test_provider_manager.py @@ -499,436 +499,3 @@ async def test_byok_provider_auto_syncs_models(provider_manager, default_user, m llm_config = await provider_manager.get_llm_config_from_handle(handle="my-openai-key/gpt-4o", actor=default_user) assert llm_config.model == "gpt-4o" assert llm_config.provider_name == "my-openai-key" - - -# ====================================================================================================================== -# Server Startup Provider Sync Tests -# ====================================================================================================================== - - -@pytest.mark.asyncio -async def test_server_startup_syncs_base_providers(default_user, default_organization, monkeypatch): - """Test that server startup properly syncs base provider models from environment. - - This test simulates the server startup process and verifies that: - 1. Base providers from environment variables are synced to database - 2. Provider models are fetched from mocked API endpoints - 3. Models are properly persisted to the database with correct metadata - 4. Models can be retrieved using handles - """ - from unittest.mock import AsyncMock - - from letta.schemas.embedding_config import EmbeddingConfig - from letta.schemas.llm_config import LLMConfig - from letta.schemas.providers import AnthropicProvider, OpenAIProvider - from letta.server.server import SyncServer - - # Mock OpenAI API responses - mock_openai_models = { - "data": [ - { - "id": "gpt-4", - "object": "model", - "created": 1687882411, - "owned_by": "openai", - "max_model_len": 8192, - }, - { - "id": "gpt-4-turbo", - "object": "model", - "created": 1712361441, - "owned_by": "system", - "max_model_len": 128000, - }, - { - "id": "text-embedding-ada-002", - "object": "model", - "created": 1671217299, - "owned_by": "openai-internal", - }, - { - "id": "gpt-4-vision", # Should be filtered out by OpenAI provider logic (has disallowed keyword) - "object": "model", - "created": 1698959748, - "owned_by": "system", - "max_model_len": 8192, - }, - ] - } - - # Mock Anthropic API responses - mock_anthropic_models = { - "data": [ - { - "id": "claude-3-5-sonnet-20241022", - "type": "model", - "display_name": "Claude 3.5 Sonnet", - "created_at": "2024-10-22T00:00:00Z", - }, - { - "id": "claude-3-opus-20240229", - "type": "model", - "display_name": "Claude 3 Opus", - "created_at": "2024-02-29T00:00:00Z", - }, - ] - } - - # Mock the API calls for OpenAI - async def mock_openai_get_model_list_async(*args, **kwargs): - return mock_openai_models - - # Mock Anthropic models.list() response - from unittest.mock import MagicMock - - mock_anthropic_response = MagicMock() - mock_anthropic_response.model_dump.return_value = mock_anthropic_models - - # Mock the Anthropic AsyncAnthropic client - class MockAnthropicModels: - async def list(self): - return mock_anthropic_response - - class MockAsyncAnthropic: - def __init__(self, *args, **kwargs): - self.models = MockAnthropicModels() - - # Patch the actual API calling functions - monkeypatch.setattr( - "letta.llm_api.openai.openai_get_model_list_async", - mock_openai_get_model_list_async, - ) - monkeypatch.setattr( - "anthropic.AsyncAnthropic", - MockAsyncAnthropic, - ) - - # Clear ALL provider-related env vars first to ensure clean state - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - monkeypatch.delenv("ANTHROPIC_API_KEY", 
raising=False) - monkeypatch.delenv("GEMINI_API_KEY", raising=False) - monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) - monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False) - monkeypatch.delenv("AZURE_API_KEY", raising=False) - monkeypatch.delenv("GROQ_API_KEY", raising=False) - monkeypatch.delenv("TOGETHER_API_KEY", raising=False) - monkeypatch.delenv("VLLM_API_BASE", raising=False) - monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) - monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) - monkeypatch.delenv("LMSTUDIO_BASE_URL", raising=False) - monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False) - monkeypatch.delenv("XAI_API_KEY", raising=False) - monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) - - # Set environment variables to enable only OpenAI and Anthropic - monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key-12345") - monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key-67890") - - # Reload model_settings to pick up new env vars - from letta.settings import model_settings - - monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key-12345") - monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key-67890") - monkeypatch.setattr(model_settings, "gemini_api_key", None) - monkeypatch.setattr(model_settings, "google_cloud_project", None) - monkeypatch.setattr(model_settings, "google_cloud_location", None) - monkeypatch.setattr(model_settings, "azure_api_key", None) - monkeypatch.setattr(model_settings, "groq_api_key", None) - monkeypatch.setattr(model_settings, "together_api_key", None) - monkeypatch.setattr(model_settings, "vllm_api_base", None) - monkeypatch.setattr(model_settings, "aws_access_key_id", None) - monkeypatch.setattr(model_settings, "aws_secret_access_key", None) - monkeypatch.setattr(model_settings, "lmstudio_base_url", None) - monkeypatch.setattr(model_settings, "deepseek_api_key", None) - monkeypatch.setattr(model_settings, "xai_api_key", None) - monkeypatch.setattr(model_settings, "openrouter_api_key", None) - - # Create server instance (this will load enabled providers from environment) - server = SyncServer(init_with_default_org_and_user=False) - - # Manually set up the default user/org (since we disabled auto-init) - server.default_user = default_user - server.default_org = default_organization - - # Verify enabled providers were loaded - assert len(server._enabled_providers) == 3 # Exactly: letta, openai, anthropic - enabled_provider_names = [p.name for p in server._enabled_providers] - assert "letta" in enabled_provider_names - assert "openai" in enabled_provider_names - assert "anthropic" in enabled_provider_names - - # First, sync base providers to database (this is what init_async does) - await server.provider_manager.sync_base_providers( - base_providers=server._enabled_providers, - actor=default_user, - ) - - # Now call the actual _sync_provider_models_async method - # This simulates what happens during server startup - await server._sync_provider_models_async() - - # Verify OpenAI models were synced - openai_providers = await server.provider_manager.list_providers_async( - name="openai", - actor=default_user, - ) - assert len(openai_providers) == 1, "OpenAI provider should exist" - openai_provider = openai_providers[0] - - # Check OpenAI LLM models - openai_llm_models = await server.provider_manager.list_models_async( - actor=default_user, - provider_id=openai_provider.id, - model_type="llm", - ) - - # Should have gpt-4 and gpt-4-turbo (gpt-4-vision filtered out due to "vision" keyword) - 
assert len(openai_llm_models) >= 2, f"Expected at least 2 OpenAI LLM models, got {len(openai_llm_models)}" - openai_model_names = [m.name for m in openai_llm_models] - assert "gpt-4" in openai_model_names - assert "gpt-4-turbo" in openai_model_names - - # Check OpenAI embedding models - openai_embedding_models = await server.provider_manager.list_models_async( - actor=default_user, - provider_id=openai_provider.id, - model_type="embedding", - ) - assert len(openai_embedding_models) >= 1, "Expected at least 1 OpenAI embedding model" - embedding_model_names = [m.name for m in openai_embedding_models] - assert "text-embedding-ada-002" in embedding_model_names - - # Verify model metadata is correct - gpt4_models = [m for m in openai_llm_models if m.name == "gpt-4"] - assert len(gpt4_models) > 0, "gpt-4 model should exist" - gpt4_model = gpt4_models[0] - assert gpt4_model.handle == "openai/gpt-4" - assert gpt4_model.model_endpoint_type == "openai" - assert gpt4_model.max_context_window == 8192 - assert gpt4_model.enabled is True - - # Verify Anthropic models were synced - anthropic_providers = await server.provider_manager.list_providers_async( - name="anthropic", - actor=default_user, - ) - assert len(anthropic_providers) == 1, "Anthropic provider should exist" - anthropic_provider = anthropic_providers[0] - - anthropic_llm_models = await server.provider_manager.list_models_async( - actor=default_user, - provider_id=anthropic_provider.id, - model_type="llm", - ) - - # Should have Claude models - assert len(anthropic_llm_models) >= 2, f"Expected at least 2 Anthropic models, got {len(anthropic_llm_models)}" - anthropic_model_names = [m.name for m in anthropic_llm_models] - assert "claude-3-5-sonnet-20241022" in anthropic_model_names - assert "claude-3-opus-20240229" in anthropic_model_names - - # Test that we can retrieve LLMConfig from handle - llm_config = await server.provider_manager.get_llm_config_from_handle( - handle="openai/gpt-4", - actor=default_user, - ) - assert llm_config.model == "gpt-4" - assert llm_config.handle == "openai/gpt-4" - assert llm_config.provider_name == "openai" - assert llm_config.context_window == 8192 - - # Test that we can retrieve EmbeddingConfig from handle - embedding_config = await server.provider_manager.get_embedding_config_from_handle( - handle="openai/text-embedding-ada-002", - actor=default_user, - ) - assert embedding_config.embedding_model == "text-embedding-ada-002" - assert embedding_config.handle == "openai/text-embedding-ada-002" - assert embedding_config.embedding_dim == 1536 - - -@pytest.mark.asyncio -async def test_server_startup_handles_disabled_providers(default_user, default_organization, monkeypatch): - """Test that server startup properly handles providers that are no longer enabled. - - This test verifies that: - 1. Base providers that are no longer enabled (env vars removed) are deleted - 2. BYOK providers that are no longer enabled are NOT deleted (user-created) - 3. 
The sync process handles providers gracefully when API calls fail - """ - from letta.schemas.providers import OpenAIProvider, ProviderCreate - from letta.server.server import SyncServer - - # First, manually create providers in the database - provider_manager = ProviderManager() - - # Create a base OpenAI provider (simulating it was synced before) - base_openai_create = ProviderCreate( - name="openai", - provider_type=ProviderType.openai, - api_key="sk-old-key", - base_url="https://api.openai.com/v1", - ) - base_openai = await provider_manager.create_provider_async( - base_openai_create, - actor=default_user, - is_byok=False, # This is a base provider - ) - - # Create a BYOK provider (user-created) - byok_provider_create = ProviderCreate( - name="my-custom-openai", - provider_type=ProviderType.openai, - api_key="sk-my-key", - base_url="https://api.openai.com/v1", - ) - byok_provider = await provider_manager.create_provider_async( - byok_provider_create, - actor=default_user, - is_byok=True, - ) - assert byok_provider.provider_category == ProviderCategory.byok - - # Now create server with NO environment variables set (all base providers disabled) - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - - from letta.settings import model_settings - - monkeypatch.setattr(model_settings, "openai_api_key", None) - monkeypatch.setattr(model_settings, "anthropic_api_key", None) - - # Create server instance - server = SyncServer(init_with_default_org_and_user=False) - server.default_user = default_user - server.default_org = default_organization - - # Verify only letta provider is enabled (no openai) - enabled_names = [p.name for p in server._enabled_providers] - assert "letta" in enabled_names - assert "openai" not in enabled_names - - # Sync base providers (should not include openai anymore) - await server.provider_manager.sync_base_providers( - base_providers=server._enabled_providers, - actor=default_user, - ) - - # Call _sync_provider_models_async - await server._sync_provider_models_async() - - # Verify base OpenAI provider was deleted (no longer enabled) - try: - await server.provider_manager.get_provider_async(base_openai.id, actor=default_user) - assert False, "Base OpenAI provider should have been deleted" - except Exception: - # Expected - provider should not exist - pass - - # Verify BYOK provider still exists (should NOT be deleted) - byok_still_exists = await server.provider_manager.get_provider_async( - byok_provider.id, - actor=default_user, - ) - assert byok_still_exists is not None - assert byok_still_exists.name == "my-custom-openai" - assert byok_still_exists.provider_category == ProviderCategory.byok - - -@pytest.mark.asyncio -async def test_server_startup_handles_api_errors_gracefully(default_user, default_organization, monkeypatch): - """Test that server startup handles API errors gracefully without crashing. - - This test verifies that: - 1. If a provider's API call fails during sync, it logs an error but continues - 2. Other providers can still sync successfully - 3. 
The server startup completes without crashing - """ - from letta.schemas.providers import AnthropicProvider, OpenAIProvider - from letta.server.server import SyncServer - - # Mock OpenAI to fail - async def mock_openai_fail(*args, **kwargs): - raise Exception("OpenAI API is down") - - # Mock Anthropic to succeed - from unittest.mock import MagicMock - - mock_anthropic_response = MagicMock() - mock_anthropic_response.model_dump.return_value = { - "data": [ - { - "id": "claude-3-5-sonnet-20241022", - "type": "model", - "display_name": "Claude 3.5 Sonnet", - "created_at": "2024-10-22T00:00:00Z", - } - ] - } - - class MockAnthropicModels: - async def list(self): - return mock_anthropic_response - - class MockAsyncAnthropic: - def __init__(self, *args, **kwargs): - self.models = MockAnthropicModels() - - monkeypatch.setattr( - "letta.llm_api.openai.openai_get_model_list_async", - mock_openai_fail, - ) - monkeypatch.setattr( - "anthropic.AsyncAnthropic", - MockAsyncAnthropic, - ) - - # Set environment variables - monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key") - monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key") - - from letta.settings import model_settings - - monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key") - monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key") - - # Create server - server = SyncServer(init_with_default_org_and_user=False) - server.default_user = default_user - server.default_org = default_organization - - # Sync base providers - await server.provider_manager.sync_base_providers( - base_providers=server._enabled_providers, - actor=default_user, - ) - - # This should NOT crash even though OpenAI fails - await server._sync_provider_models_async() - - # Verify Anthropic still synced successfully - anthropic_providers = await server.provider_manager.list_providers_async( - name="anthropic", - actor=default_user, - ) - assert len(anthropic_providers) == 1 - - anthropic_models = await server.provider_manager.list_models_async( - actor=default_user, - provider_id=anthropic_providers[0].id, - model_type="llm", - ) - assert len(anthropic_models) >= 1, "Anthropic models should have synced despite OpenAI failure" - - # OpenAI should have no models (sync failed) - openai_providers = await server.provider_manager.list_providers_async( - name="openai", - actor=default_user, - ) - if len(openai_providers) > 0: - openai_models = await server.provider_manager.list_models_async( - actor=default_user, - provider_id=openai_providers[0].id, - ) - # Models might exist from previous runs, but the sync attempt should have been logged as failed - # The key is that the server didn't crash diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 0272e0e3..cfc05dd3 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -32,7 +32,6 @@ from letta.config import LettaConfig from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.server.server import SyncServer from tests.helpers.utils import upload_file_and_wait -from tests.utils import wait_for_server # Constants SERVER_PORT = 8283 @@ -107,7 +106,7 @@ def client() -> LettaSDKClient: print("Starting server thread") thread = threading.Thread(target=run_server, daemon=True) thread.start() - wait_for_server(server_url, timeout=60) + time.sleep(5) print("Running client tests with server:", server_url) client = LettaSDKClient(base_url=server_url) diff --git a/tests/test_server_providers.py b/tests/test_server_providers.py index 67494bc5..2306c22f 100644 
--- a/tests/test_server_providers.py +++ b/tests/test_server_providers.py @@ -1740,110 +1740,3 @@ async def test_handle_uniqueness_per_org(default_user, provider_manager): assert model is not None assert model.provider_id == provider_1.id # Still original provider assert model.max_context_window == 8192 # Still original - - -@pytest.mark.asyncio -async def test_delete_provider_cascades_to_models(default_user, provider_manager, monkeypatch): - """Test that deleting a provider also soft-deletes its associated models.""" - test_id = generate_test_id() - - # Mock _sync_default_models_for_provider to avoid external API calls - async def mock_sync(provider, actor): - pass # Don't actually sync - we'll manually create models below - - monkeypatch.setattr(provider_manager, "_sync_default_models_for_provider", mock_sync) - - # 1. Create a BYOK provider (org-scoped, so the actor can delete it) - provider_create = ProviderCreate( - name=f"test-cascade-{test_id}", - provider_type=ProviderType.openai, - api_key="sk-test-key", - ) - provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=True) - - # 2. Manually sync models to the provider - llm_models = [ - LLMConfig( - model=f"gpt-4o-{test_id}", - model_endpoint_type="openai", - model_endpoint="https://api.openai.com/v1", - context_window=128000, - handle=f"test-{test_id}/gpt-4o", - provider_name=provider.name, - provider_category=ProviderCategory.byok, - ), - LLMConfig( - model=f"gpt-4o-mini-{test_id}", - model_endpoint_type="openai", - model_endpoint="https://api.openai.com/v1", - context_window=16384, - handle=f"test-{test_id}/gpt-4o-mini", - provider_name=provider.name, - provider_category=ProviderCategory.byok, - ), - ] - - embedding_models = [ - EmbeddingConfig( - embedding_model=f"text-embedding-3-small-{test_id}", - embedding_endpoint_type="openai", - embedding_endpoint="https://api.openai.com/v1", - embedding_dim=1536, - embedding_chunk_size=300, - handle=f"test-{test_id}/text-embedding-3-small", - ), - ] - - await provider_manager.sync_provider_models_async( - provider=provider, - llm_models=llm_models, - embedding_models=embedding_models, - organization_id=default_user.organization_id, # Org-scoped for BYOK provider - ) - - # 3. Verify models exist before deletion - llm_models_before = await provider_manager.list_models_async( - actor=default_user, - model_type="llm", - provider_id=provider.id, - ) - embedding_models_before = await provider_manager.list_models_async( - actor=default_user, - model_type="embedding", - provider_id=provider.id, - ) - - llm_handles_before = {m.handle for m in llm_models_before} - embedding_handles_before = {m.handle for m in embedding_models_before} - - assert f"test-{test_id}/gpt-4o" in llm_handles_before - assert f"test-{test_id}/gpt-4o-mini" in llm_handles_before - assert f"test-{test_id}/text-embedding-3-small" in embedding_handles_before - - # 4. Delete the provider - await provider_manager.delete_provider_by_id_async(provider.id, actor=default_user) - - # 5. 
Verify models are soft-deleted (no longer returned in list) - all_llm_models_after = await provider_manager.list_models_async( - actor=default_user, - model_type="llm", - ) - all_embedding_models_after = await provider_manager.list_models_async( - actor=default_user, - model_type="embedding", - ) - - all_llm_handles_after = {m.handle for m in all_llm_models_after} - all_embedding_handles_after = {m.handle for m in all_embedding_models_after} - - # All models from the deleted provider should be gone - assert f"test-{test_id}/gpt-4o" not in all_llm_handles_after - assert f"test-{test_id}/gpt-4o-mini" not in all_llm_handles_after - assert f"test-{test_id}/text-embedding-3-small" not in all_embedding_handles_after - - # 6. Verify provider is also deleted - providers_after = await provider_manager.list_providers_async( - actor=default_user, - name=f"test-cascade-{test_id}", - ) - assert len(providers_after) == 0
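--

Two of the code paths this revert restores, sketched below for reference. These are illustrative only and not part of the patch.

1) Config caching. get_cached_llm_config_async and get_cached_embedding_config_async memoize resolved configs in per-process dicts keyed by make_key(**kwargs). A minimal sketch of that pattern, assuming make_key builds an order-insensitive hashable key from its keyword arguments (the real helper is imported from letta.utils; its implementation is not shown in this patch):

    from typing import Any, Callable, Dict, Tuple

    def make_key(**kwargs: Any) -> Tuple[Tuple[str, Any], ...]:
        # Assumed behavior: hashable, order-insensitive key built from kwargs.
        return tuple(sorted(kwargs.items()))

    class ConfigCache:
        """Per-process memoization, shaped like SyncServer.get_cached_llm_config."""

        def __init__(self, resolve: Callable[..., Any]) -> None:
            self._resolve = resolve  # e.g. get_llm_config_from_handle
            self._cache: Dict[Tuple[Tuple[str, Any], ...], Any] = {}

        def get(self, **kwargs: Any) -> Any:
            key = make_key(**kwargs)
            if key not in self._cache:
                # Cache miss: resolve once, then reuse for the process lifetime.
                self._cache[key] = self._resolve(**kwargs)
            return self._cache[key]

Because the cache is unbounded and local to each process, entries live until restart; the reverted feature had replaced this with database-persisted provider models, and the revert trades that persistence back for the simpler in-memory path.

2) Concurrent model listing. The restored list_llm_models_async fans out to every enabled provider with asyncio.gather, bounds each call with asyncio.timeout (Python 3.11+), degrades failures to empty lists, and dedupes by handle. A condensed sketch of that shape; the provider objects and the timeout value are stand-ins for the real ones in letta.constants:

    import asyncio

    GET_PROVIDERS_TIMEOUT_SECONDS = 10.0  # assumed value; the real constant lives in letta.constants

    async def list_all_llm_models(providers: list) -> list:
        async def one(provider):
            try:
                async with asyncio.timeout(GET_PROVIDERS_TIMEOUT_SECONDS):
                    return await provider.list_llm_models_async()
            except Exception:
                # A slow or broken provider yields [] instead of failing the whole listing.
                return []

        results = await asyncio.gather(*(one(p) for p in providers))
        seen, unique = set(), []
        for model in (m for models in results for m in models):
            if model.handle not in seen:
                seen.add(model.handle)
                unique.append(model)
        return unique

The dedupe-by-handle step corresponds to the restored "dedupe by handle for uniqueness" block: duplicate handles can appear when local configs overlap with provider listings, so the first occurrence wins.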