diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 66c04c44..599969f3 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -235,24 +235,38 @@ class LLMClientBase: def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]: """ Returns the override key for the given llm config. + For both base and BYOK providers, fetch the API key from the database. """ api_key = None - if llm_config.provider_category == ProviderCategory.byok: + # Fetch API key from database for both base and BYOK providers + # This ensures that base providers (from environment) also have their keys persisted and accessible + if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]: from letta.services.provider_manager import ProviderManager api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor) + # If we got an empty string from the database (e.g., Letta provider), treat it as None + # so the client can fall back to environment variables or default behavior + if api_key == "": + api_key = None return api_key, None, None async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]: """ Returns the override key for the given llm config. + For both base and BYOK providers, fetch the API key from the database. """ api_key = None - if llm_config.provider_category == ProviderCategory.byok: + # Fetch API key from database for both base and BYOK providers + # This ensures that base providers (from environment) also have their keys persisted and accessible + if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]: from letta.services.provider_manager import ProviderManager api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor) + # If we got an empty string from the database (e.g., Letta provider), treat it as None + # so the client can fall back to environment variables or default behavior + if api_key == "": + api_key = None return api_key, None, None diff --git a/letta/schemas/providers/letta.py b/letta/schemas/providers/letta.py index 69eb0875..d843f1ba 100644 --- a/letta/schemas/providers/letta.py +++ b/letta/schemas/providers/letta.py @@ -8,10 +8,13 @@ from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.providers.base import Provider +LETTA_EMBEDDING_ENDPOINT = "https://embeddings.letta.com/" + class LettaProvider(Provider): provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field(LETTA_EMBEDDING_ENDPOINT, description="Base URL for the Letta API (used for embeddings).") async def list_llm_models_async(self) -> list[LLMConfig]: return [ @@ -32,7 +35,7 @@ class LettaProvider(Provider): EmbeddingConfig( embedding_model="letta-free", # NOTE: renamed embedding_endpoint_type="openai", - embedding_endpoint="https://embeddings.letta.com/", + embedding_endpoint=self.base_url, embedding_dim=1536, embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, handle=self.get_handle("letta-free", is_embedding=True), diff --git a/letta/server/server.py b/letta/server/server.py index b581d8a2..4f606035 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -116,7 +116,7 @@ from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import DatabaseChoice, model_settings, settings, tool_settings from letta.streaming_interface import AgentChunkStreamingInterface -from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task +from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task config = LettaConfig.load() logger = get_logger(__name__) @@ -210,12 +210,10 @@ class SyncServer(object): """Initialize the MCP clients (there may be multiple)""" self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {} - # TODO: Remove these in memory caches - self._llm_config_cache = {} - self._embedding_config_cache = {} - # collect providers (always has Letta as a default) - self._enabled_providers: List[Provider] = [LettaProvider(name="letta")] + from letta.constants import LETTA_MODEL_ENDPOINT + + self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)] if model_settings.openai_api_key: self._enabled_providers.append( OpenAIProvider( @@ -347,6 +345,12 @@ class SyncServer(object): print(f"Default user: {self.default_user} and org: {self.default_org}") await self.tool_manager.upsert_base_tools_async(actor=self.default_user) + # Sync environment-based providers to database (idempotent, safe for multi-pod startup) + await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user) + + # Sync provider models to database + await self._sync_provider_models_async() + # For OSS users, create a local sandbox config oss_default_user = await self.user_manager.get_default_actor_async() use_venv = False if not tool_settings.tool_exec_venv_name else True @@ -383,6 +387,65 @@ class SyncServer(object): force_recreate=True, ) + def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]: + """Find and return an enabled provider by name. + + Args: + provider_name: The name of the provider to find + + Returns: + The matching enabled provider, or None if not found + """ + for provider in self._enabled_providers: + if provider.name == provider_name: + return provider + return None + + async def _sync_provider_models_async(self): + """Sync all provider models to database at startup.""" + logger.info("Syncing provider models to database") + + # Get persisted providers from database (they now have IDs) + persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user) + + for persisted_provider in persisted_providers: + try: + # Find the matching enabled provider instance to call list_models on + enabled_provider = self._get_enabled_provider(persisted_provider.name) + + if not enabled_provider: + # Only delete base providers that are no longer enabled + # BYOK providers are user-created and should not be automatically deleted + if persisted_provider.provider_category == ProviderCategory.base: + logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database") + try: + await self.provider_manager.delete_provider_by_id_async( + provider_id=persisted_provider.id, actor=self.default_user + ) + except NoResultFound: + # Provider was already deleted (race condition in multi-pod startup) + logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping") + else: + logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync") + continue + + # Fetch models from provider + llm_models = await enabled_provider.list_llm_models_async() + embedding_models = await enabled_provider.list_embedding_models_async() + + # Save to database with the persisted provider (which has an ID) + await self.provider_manager.sync_provider_models_async( + provider=persisted_provider, + llm_models=llm_models, + embedding_models=embedding_models, + organization_id=None, # Global models + ) + logger.info( + f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}" + ) + except Exception as e: + logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True) + async def init_mcp_clients(self): # TODO: remove this mcp_server_configs = self.get_mcp_servers() @@ -410,39 +473,6 @@ class SyncServer(object): logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}") logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}") - @trace_method - def get_cached_llm_config(self, actor: User, **kwargs): - key = make_key(**kwargs) - if key not in self._llm_config_cache: - self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs) - logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries") - return self._llm_config_cache[key] - - @trace_method - async def get_cached_llm_config_async(self, actor: User, **kwargs): - key = make_key(**kwargs) - if key not in self._llm_config_cache: - self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs) - logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries") - return self._llm_config_cache[key] - - @trace_method - def get_cached_embedding_config(self, actor: User, **kwargs): - key = make_key(**kwargs) - if key not in self._embedding_config_cache: - self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs) - logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries") - return self._embedding_config_cache[key] - - # @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs)) - @trace_method - async def get_cached_embedding_config_async(self, actor: User, **kwargs): - key = make_key(**kwargs) - if key not in self._embedding_config_cache: - self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs) - logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries") - return self._embedding_config_cache[key] - @trace_method async def create_agent_async( self, @@ -476,10 +506,9 @@ class SyncServer(object): "max_reasoning_tokens": request.max_reasoning_tokens, "enable_reasoner": request.enable_reasoner, } - config_params.update(additional_config_params) - log_event(name="start get_cached_llm_config", attributes=config_params) - request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params) - log_event(name="end get_cached_llm_config", attributes=config_params) + log_event(name="start get_llm_config_from_handle", attributes=config_params) + request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params) + log_event(name="end get_llm_config_from_handle", attributes=config_params) if request.model and isinstance(request.model, str): assert request.llm_config.handle == request.model, ( f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}" @@ -507,9 +536,9 @@ class SyncServer(object): "handle": request.embedding, "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE, } - log_event(name="start get_cached_embedding_config", attributes=embedding_config_params) - request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params) - log_event(name="end get_cached_embedding_config", attributes=embedding_config_params) + log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params) + request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params) + log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params) log_event(name="start create_agent db") main_agent = await self.agent_manager.create_agent_async( @@ -558,9 +587,9 @@ class SyncServer(object): "context_window_limit": request.context_window_limit, "max_tokens": request.max_tokens, } - log_event(name="start get_cached_llm_config", attributes=config_params) - request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params) - log_event(name="end get_cached_llm_config", attributes=config_params) + log_event(name="start get_llm_config_from_handle", attributes=config_params) + request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params) + log_event(name="end get_llm_config_from_handle", attributes=config_params) # update with model_settings if request.model_settings is not None: @@ -1070,73 +1099,85 @@ class SyncServer(object): provider_name: Optional[str] = None, provider_type: Optional[ProviderType] = None, ) -> List[LLMConfig]: - """Asynchronously list available models with maximum concurrency""" - import asyncio + """List available LLM models from database cache""" + # Get provider IDs if filtering by provider + provider_ids = None + if provider_name or provider_type: + providers = await self.get_enabled_providers_async( + provider_category=provider_category, + provider_name=provider_name, + provider_type=provider_type, + actor=actor, + ) + provider_ids = [p.id for p in providers] - providers = await self.get_enabled_providers_async( - provider_category=provider_category, - provider_name=provider_name, - provider_type=provider_type, + # If filtering was requested but no providers matched, return empty list + if not provider_ids: + return [] + + # Get models from database + provider_models = await self.provider_manager.list_models_async( actor=actor, + model_type="llm", + provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None, + enabled=True, ) - async def get_provider_models(provider: Provider) -> list[LLMConfig]: - try: - async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS): - return await provider.list_llm_models_async() - except asyncio.TimeoutError: - logger.warning(f"Timeout while listing LLM models for provider {provider}") - return [] - except Exception as e: - logger.exception(f"Error while listing LLM models for provider {provider}: {e}") - return [] - - # Execute all provider model listing tasks concurrently - provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers]) - - # Flatten the results + # Build LLMConfig objects from cached data + # Cache providers to avoid N+1 queries + provider_cache: Dict[str, Provider] = {} llm_models = [] - for models in provider_results: - llm_models.extend(models) + for model in provider_models: + # Skip if filtering by provider and model doesn't match + if provider_ids and model.provider_id not in provider_ids: + continue - # Get local configs - if this is potentially slow, consider making it async too - local_configs = self.get_local_llm_configs() - llm_models.extend(local_configs) + # Get provider details (with caching to avoid N+1 queries) + if model.provider_id not in provider_cache: + provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor) + provider = provider_cache[model.provider_id] - # dedupe by handle for uniqueness - # Seems like this is required from the tests? - seen_handles = set() - unique_models = [] - for model in llm_models: - if model.handle not in seen_handles: - seen_handles.add(model.handle) - unique_models.append(model) + llm_config = LLMConfig( + model=model.name, + model_endpoint_type=model.model_endpoint_type, + model_endpoint=provider.base_url or model.model_endpoint_type, + context_window=model.max_context_window or 16384, + handle=model.handle, + provider_name=provider.name, + provider_category=provider.provider_category, + ) + llm_models.append(llm_config) - return unique_models + return llm_models async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]: - """Asynchronously list available embedding models with maximum concurrency""" - import asyncio + """List available embedding models from database cache""" + # Get models from database + provider_models = await self.provider_manager.list_models_async( + actor=actor, + model_type="embedding", + enabled=True, + ) - # Get all eligible providers first - providers = await self.get_enabled_providers_async(actor=actor) - - # Fetch embedding models from each provider concurrently - async def get_provider_embedding_models(provider): - try: - # All providers now have list_embedding_models_async - return await provider.list_embedding_models_async() - except Exception as e: - logger.exception(f"An error occurred while listing embedding models for provider {provider}: {e}") - return [] - - # Execute all provider model listing tasks concurrently - provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers]) - - # Flatten the results + # Build EmbeddingConfig objects from cached data + # Cache providers to avoid N+1 queries + provider_cache: Dict[str, Provider] = {} embedding_models = [] - for models in provider_results: - embedding_models.extend(models) + for model in provider_models: + # Get provider details (with caching to avoid N+1 queries) + if model.provider_id not in provider_cache: + provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor) + provider = provider_cache[model.provider_id] + + embedding_config = EmbeddingConfig( + embedding_model=model.name, + embedding_endpoint_type=model.model_endpoint_type, + embedding_endpoint=provider.base_url or model.model_endpoint_type, + embedding_dim=model.embedding_dim or 1536, # Use model's dimension or default + embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=model.handle, + ) + embedding_models.append(embedding_config) return embedding_models @@ -1147,25 +1188,17 @@ class SyncServer(object): provider_name: Optional[str] = None, provider_type: Optional[ProviderType] = None, ) -> List[Provider]: - providers = [] - if not provider_category or ProviderCategory.base in provider_category: - providers_from_env = [p for p in self._enabled_providers] - providers.extend(providers_from_env) + # Query all persisted providers from database + persisted_providers = await self.provider_manager.list_providers_async( + name=provider_name, + provider_type=provider_type, + actor=actor, + ) + providers = [p.cast_to_subtype() for p in persisted_providers] - if not provider_category or ProviderCategory.byok in provider_category: - providers_from_db = await self.provider_manager.list_providers_async( - name=provider_name, - provider_type=provider_type, - actor=actor, - ) - providers_from_db = [p.cast_to_subtype() for p in providers_from_db if p.provider_category == ProviderCategory.byok] - providers.extend(providers_from_db) - - if provider_name is not None: - providers = [p for p in providers if p.name == provider_name] - - if provider_type is not None: - providers = [p for p in providers if p.provider_type == provider_type] + # Filter by category if specified + if provider_category: + providers = [p for p in providers if p.provider_category in provider_category] return providers @@ -1179,32 +1212,19 @@ class SyncServer(object): max_reasoning_tokens: Optional[int] = None, enable_reasoner: Optional[bool] = None, ) -> LLMConfig: + # Use provider_manager to get LLMConfig from handle try: - provider_name, model_name = handle.split("/", 1) - provider = await self.get_provider_from_name_async(provider_name, actor) - - all_llm_configs = await provider.list_llm_models_async() - llm_configs = [config for config in all_llm_configs if config.handle == handle] - if not llm_configs: - llm_configs = [config for config in all_llm_configs if config.model == model_name] - if not llm_configs: - available_handles = [config.handle for config in all_llm_configs] - raise HandleNotFoundError(handle, available_handles) - except ValueError as e: - llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle] - if not llm_configs: - llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name] - if not llm_configs: - raise e - - if len(llm_configs) == 1: - llm_config = llm_configs[0] - elif len(llm_configs) > 1: - raise LettaInvalidArgumentError( - f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name" + llm_config = await self.provider_manager.get_llm_config_from_handle( + handle=handle, + actor=actor, ) - else: - llm_config = llm_configs[0] + except Exception as e: + # Convert to HandleNotFoundError for backwards compatibility + from letta.orm.errors import NoResultFound + + if isinstance(e, NoResultFound): + raise HandleNotFoundError(handle, []) + raise if context_window_limit is not None: if context_window_limit > llm_config.context_window: @@ -1236,33 +1256,22 @@ class SyncServer(object): async def get_embedding_config_from_handle_async( self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE ) -> EmbeddingConfig: + # Use provider_manager to get EmbeddingConfig from handle try: - provider_name, model_name = handle.split("/", 1) - provider = await self.get_provider_from_name_async(provider_name, actor) - - all_embedding_configs = await provider.list_embedding_models_async() - embedding_configs = [config for config in all_embedding_configs if config.handle == handle] - if not embedding_configs: - raise LettaInvalidArgumentError( - f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name" - ) - except LettaInvalidArgumentError as e: - # search local configs - embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle] - if not embedding_configs: - raise e - - if len(embedding_configs) == 1: - embedding_config = embedding_configs[0] - elif len(embedding_configs) > 1: - raise LettaInvalidArgumentError( - f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name" + embedding_config = await self.provider_manager.get_embedding_config_from_handle( + handle=handle, + actor=actor, ) - else: - embedding_config = embedding_configs[0] + except Exception as e: + # Convert to LettaInvalidArgumentError for backwards compatibility + from letta.orm.errors import NoResultFound - if embedding_chunk_size: - embedding_config.embedding_chunk_size = embedding_chunk_size + if isinstance(e, NoResultFound): + raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle") + raise + + # Override chunk size if provided + embedding_config.embedding_chunk_size = embedding_chunk_size return embedding_config @@ -1271,7 +1280,7 @@ class SyncServer(object): providers = [provider for provider in all_providers if provider.name == provider_name] if not providers: raise LettaInvalidArgumentError( - f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in self._enabled_providers])})", + f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in all_providers])})", argument_name="provider_name", ) elif len(providers) > 1: @@ -1282,46 +1291,6 @@ class SyncServer(object): return provider - def get_local_llm_configs(self): - llm_models = [] - # NOTE: deprecated - # try: - # llm_configs_dir = os.path.expanduser("~/.letta/llm_configs") - # if os.path.exists(llm_configs_dir): - # for filename in os.listdir(llm_configs_dir): - # if filename.endswith(".json"): - # filepath = os.path.join(llm_configs_dir, filename) - # try: - # with open(filepath, "r") as f: - # config_data = json.load(f) - # llm_config = LLMConfig(**config_data) - # llm_models.append(llm_config) - # except (json.JSONDecodeError, ValueError) as e: - # logger.warning(f"Error parsing LLM config file {filename}: {e}") - # except Exception as e: - # logger.warning(f"Error reading LLM configs directory: {e}") - return llm_models - - def get_local_embedding_configs(self): - embedding_models = [] - # NOTE: deprecated - # try: - # embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs") - # if os.path.exists(embedding_configs_dir): - # for filename in os.listdir(embedding_configs_dir): - # if filename.endswith(".json"): - # filepath = os.path.join(embedding_configs_dir, filename) - # try: - # with open(filepath, "r") as f: - # config_data = json.load(f) - # embedding_config = EmbeddingConfig(**config_data) - # embedding_models.append(embedding_config) - # except (json.JSONDecodeError, ValueError) as e: - # logger.warning(f"Error parsing embedding config file {filename}: {e}") - # except Exception as e: - # logger.warning(f"Error reading embedding configs directory: {e}") - return embedding_models - def add_llm_model(self, request: LLMConfig) -> LLMConfig: """Add a new LLM model""" diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index aa0af1cd..7cf8a2aa 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -201,7 +201,7 @@ class ProviderManager: @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) @trace_method async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser): - """Delete a provider.""" + """Delete a provider and its associated models.""" async with db_registry.async_session() as session: # Clear api key field existing_provider = await ProviderModel.read_async( @@ -218,6 +218,15 @@ class ProviderManager: await existing_provider.update_async(session, actor=actor) + # Soft delete all models associated with this provider + provider_models = await ProviderModelORM.list_async( + db_session=session, + provider_id=provider_id, + check_is_deleted=True, + ) + for model in provider_models: + await model.delete_async(session, actor=actor) + # Soft delete in provider table await existing_provider.delete_async(session, actor=actor) @@ -701,11 +710,11 @@ class ProviderManager: await model.create_async(session) logger.info(f" ✓ Successfully created LLM model {llm_config.handle} with ID {model.id}") except Exception as e: - logger.error(f" ✗ Failed to create LLM model {llm_config.handle}: {e}") + logger.info(f" ✗ Failed to create LLM model {llm_config.handle}: {e}") # Log the full error details import traceback - logger.error(f" Full traceback: {traceback.format_exc()}") + logger.info(f" Full traceback: {traceback.format_exc()}") # Roll back the session to clear the failed transaction await session.rollback() else: @@ -899,8 +908,12 @@ class ProviderManager: if not model: raise NoResultFound(f"LLM model not found with handle='{handle}'") - # Get the provider for this model + # Get the provider for this model and cast to subtype to access provider-specific methods provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor) + typed_provider = provider.cast_to_subtype() + + # Get the default max_output_tokens from the provider (provider-specific logic) + max_tokens = typed_provider.get_default_max_output_tokens(model.name) # Construct the LLMConfig from the model and provider data llm_config = LLMConfig( @@ -911,6 +924,7 @@ class ProviderManager: handle=model.handle, provider_name=provider.name, provider_category=provider.provider_category, + max_tokens=max_tokens, ) return llm_config diff --git a/tests/managers/test_agent_manager.py b/tests/managers/test_agent_manager.py index 9c3f5c51..44273221 100644 --- a/tests/managers/test_agent_manager.py +++ b/tests/managers/test_agent_manager.py @@ -258,14 +258,14 @@ async def test_create_agent_with_model_handle_uses_correct_llm_config(server: Sy """When CreateAgent.model is provided, ensure the correct handle is used to resolve llm_config. This verifies that the model handle passed by the client is forwarded into - SyncServer.get_cached_llm_config_async and that the resulting AgentState + SyncServer.get_llm_config_from_handle_async and that the resulting AgentState carries an llm_config with the same handle. """ # Track the arguments used to resolve the LLM config captured_kwargs: dict = {} - async def fake_get_cached_llm_config_async(self, actor, **kwargs): # type: ignore[override] + async def fake_get_llm_config_from_handle_async(self, actor, **kwargs): # type: ignore[override] from letta.schemas.llm_config import LLMConfig as PydanticLLMConfig captured_kwargs.update(kwargs) @@ -282,8 +282,8 @@ async def test_create_agent_with_model_handle_uses_correct_llm_config(server: Sy model_handle = "openai/gpt-4o-mini" - # Patch SyncServer.get_cached_llm_config_async so we don't depend on provider DB state - with patch.object(SyncServer, "get_cached_llm_config_async", new=fake_get_cached_llm_config_async): + # Patch SyncServer.get_llm_config_from_handle_async so we don't depend on provider DB state + with patch.object(SyncServer, "get_llm_config_from_handle_async", new=fake_get_llm_config_from_handle_async): created_agent = await server.create_agent_async( request=CreateAgent( name="agent_with_model_handle", diff --git a/tests/managers/test_provider_manager.py b/tests/managers/test_provider_manager.py index fec435d7..70c0f418 100644 --- a/tests/managers/test_provider_manager.py +++ b/tests/managers/test_provider_manager.py @@ -487,87 +487,435 @@ async def test_byok_provider_auto_syncs_models(provider_manager, default_user, m # ====================================================================================================================== -# No Encryption Key Tests +# Server Startup Provider Sync Tests # ====================================================================================================================== -@pytest.fixture -def no_encryption_key(): - """Fixture to ensure NO encryption key is set for tests.""" - original_key = settings.encryption_key - settings.encryption_key = None - yield None - settings.encryption_key = original_key +@pytest.mark.asyncio +async def test_server_startup_syncs_base_providers(default_user, default_organization, monkeypatch): + """Test that server startup properly syncs base provider models from environment. + + This test simulates the server startup process and verifies that: + 1. Base providers from environment variables are synced to database + 2. Provider models are fetched from mocked API endpoints + 3. Models are properly persisted to the database with correct metadata + 4. Models can be retrieved using handles + """ + from unittest.mock import AsyncMock + + from letta.schemas.embedding_config import EmbeddingConfig + from letta.schemas.llm_config import LLMConfig + from letta.schemas.providers import AnthropicProvider, OpenAIProvider + from letta.server.server import SyncServer + + # Mock OpenAI API responses + mock_openai_models = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "created": 1687882411, + "owned_by": "openai", + "max_model_len": 8192, + }, + { + "id": "gpt-4-turbo", + "object": "model", + "created": 1712361441, + "owned_by": "system", + "max_model_len": 128000, + }, + { + "id": "text-embedding-ada-002", + "object": "model", + "created": 1671217299, + "owned_by": "openai-internal", + }, + { + "id": "gpt-4-vision", # Should be filtered out by OpenAI provider logic (has disallowed keyword) + "object": "model", + "created": 1698959748, + "owned_by": "system", + "max_model_len": 8192, + }, + ] + } + + # Mock Anthropic API responses + mock_anthropic_models = { + "data": [ + { + "id": "claude-3-5-sonnet-20241022", + "type": "model", + "display_name": "Claude 3.5 Sonnet", + "created_at": "2024-10-22T00:00:00Z", + }, + { + "id": "claude-3-opus-20240229", + "type": "model", + "display_name": "Claude 3 Opus", + "created_at": "2024-02-29T00:00:00Z", + }, + ] + } + + # Mock the API calls for OpenAI + async def mock_openai_get_model_list_async(*args, **kwargs): + return mock_openai_models + + # Mock Anthropic models.list() response + from unittest.mock import MagicMock + + mock_anthropic_response = MagicMock() + mock_anthropic_response.model_dump.return_value = mock_anthropic_models + + # Mock the Anthropic AsyncAnthropic client + class MockAnthropicModels: + async def list(self): + return mock_anthropic_response + + class MockAsyncAnthropic: + def __init__(self, *args, **kwargs): + self.models = MockAnthropicModels() + + # Patch the actual API calling functions + monkeypatch.setattr( + "letta.llm_api.openai.openai_get_model_list_async", + mock_openai_get_model_list_async, + ) + monkeypatch.setattr( + "anthropic.AsyncAnthropic", + MockAsyncAnthropic, + ) + + # Clear ALL provider-related env vars first to ensure clean state + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + monkeypatch.delenv("GEMINI_API_KEY", raising=False) + monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) + monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False) + monkeypatch.delenv("AZURE_API_KEY", raising=False) + monkeypatch.delenv("GROQ_API_KEY", raising=False) + monkeypatch.delenv("TOGETHER_API_KEY", raising=False) + monkeypatch.delenv("VLLM_API_BASE", raising=False) + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + monkeypatch.delenv("LMSTUDIO_BASE_URL", raising=False) + monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False) + monkeypatch.delenv("XAI_API_KEY", raising=False) + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + monkeypatch.delenv("ZAI_API_KEY", raising=False) + + # Set environment variables to enable only OpenAI and Anthropic + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key-12345") + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key-67890") + + # Reload model_settings to pick up new env vars + from letta.settings import model_settings + + monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key-12345") + monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key-67890") + monkeypatch.setattr(model_settings, "gemini_api_key", None) + monkeypatch.setattr(model_settings, "google_cloud_project", None) + monkeypatch.setattr(model_settings, "google_cloud_location", None) + monkeypatch.setattr(model_settings, "azure_api_key", None) + monkeypatch.setattr(model_settings, "groq_api_key", None) + monkeypatch.setattr(model_settings, "together_api_key", None) + monkeypatch.setattr(model_settings, "vllm_api_base", None) + monkeypatch.setattr(model_settings, "aws_access_key_id", None) + monkeypatch.setattr(model_settings, "aws_secret_access_key", None) + monkeypatch.setattr(model_settings, "lmstudio_base_url", None) + monkeypatch.setattr(model_settings, "deepseek_api_key", None) + monkeypatch.setattr(model_settings, "xai_api_key", None) + monkeypatch.setattr(model_settings, "openrouter_api_key", None) + monkeypatch.setattr(model_settings, "zai_api_key", None) + + # Create server instance (this will load enabled providers from environment) + server = SyncServer(init_with_default_org_and_user=False) + + # Manually set up the default user/org (since we disabled auto-init) + server.default_user = default_user + server.default_org = default_organization + + # Verify enabled providers were loaded + assert len(server._enabled_providers) == 3 # Exactly: letta, openai, anthropic + enabled_provider_names = [p.name for p in server._enabled_providers] + assert "letta" in enabled_provider_names + assert "openai" in enabled_provider_names + assert "anthropic" in enabled_provider_names + + # First, sync base providers to database (this is what init_async does) + await server.provider_manager.sync_base_providers( + base_providers=server._enabled_providers, + actor=default_user, + ) + + # Now call the actual _sync_provider_models_async method + # This simulates what happens during server startup + await server._sync_provider_models_async() + + # Verify OpenAI models were synced + openai_providers = await server.provider_manager.list_providers_async( + name="openai", + actor=default_user, + ) + assert len(openai_providers) == 1, "OpenAI provider should exist" + openai_provider = openai_providers[0] + + # Check OpenAI LLM models + openai_llm_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=openai_provider.id, + model_type="llm", + ) + + # Should have gpt-4 and gpt-4-turbo (gpt-4-vision filtered out due to "vision" keyword) + assert len(openai_llm_models) >= 2, f"Expected at least 2 OpenAI LLM models, got {len(openai_llm_models)}" + openai_model_names = [m.name for m in openai_llm_models] + assert "gpt-4" in openai_model_names + assert "gpt-4-turbo" in openai_model_names + + # Check OpenAI embedding models + openai_embedding_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=openai_provider.id, + model_type="embedding", + ) + assert len(openai_embedding_models) >= 1, "Expected at least 1 OpenAI embedding model" + embedding_model_names = [m.name for m in openai_embedding_models] + assert "text-embedding-ada-002" in embedding_model_names + + # Verify model metadata is correct + gpt4_models = [m for m in openai_llm_models if m.name == "gpt-4"] + assert len(gpt4_models) > 0, "gpt-4 model should exist" + gpt4_model = gpt4_models[0] + assert gpt4_model.handle == "openai/gpt-4" + assert gpt4_model.model_endpoint_type == "openai" + assert gpt4_model.max_context_window == 8192 + assert gpt4_model.enabled is True + + # Verify Anthropic models were synced + anthropic_providers = await server.provider_manager.list_providers_async( + name="anthropic", + actor=default_user, + ) + assert len(anthropic_providers) == 1, "Anthropic provider should exist" + anthropic_provider = anthropic_providers[0] + + anthropic_llm_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=anthropic_provider.id, + model_type="llm", + ) + + # Should have Claude models + assert len(anthropic_llm_models) >= 2, f"Expected at least 2 Anthropic models, got {len(anthropic_llm_models)}" + anthropic_model_names = [m.name for m in anthropic_llm_models] + assert "claude-3-5-sonnet-20241022" in anthropic_model_names + assert "claude-3-opus-20240229" in anthropic_model_names + + # Test that we can retrieve LLMConfig from handle + llm_config = await server.provider_manager.get_llm_config_from_handle( + handle="openai/gpt-4", + actor=default_user, + ) + assert llm_config.model == "gpt-4" + assert llm_config.handle == "openai/gpt-4" + assert llm_config.provider_name == "openai" + assert llm_config.context_window == 8192 + + # Test that we can retrieve EmbeddingConfig from handle + embedding_config = await server.provider_manager.get_embedding_config_from_handle( + handle="openai/text-embedding-ada-002", + actor=default_user, + ) + assert embedding_config.embedding_model == "text-embedding-ada-002" + assert embedding_config.handle == "openai/text-embedding-ada-002" + assert embedding_config.embedding_dim == 1536 @pytest.mark.asyncio -async def test_provider_works_without_encryption_key(provider_manager, default_user, no_encryption_key): - """Test that providers can be created and read when no encryption key is configured. +async def test_server_startup_handles_disabled_providers(default_user, default_organization, monkeypatch): + """Test that server startup properly handles providers that are no longer enabled. - When LETTA_ENCRYPTION_KEY is not set, the Secret class should store values as - plaintext in the _enc column and successfully retrieve them. + This test verifies that: + 1. Base providers that are no longer enabled (env vars removed) are deleted + 2. BYOK providers that are no longer enabled are NOT deleted (user-created) + 3. The sync process handles providers gracefully when API calls fail """ - # Create a provider without encryption key configured - provider_create = ProviderCreate( - name="test-no-encryption-provider", + from letta.schemas.providers import OpenAIProvider, ProviderCreate + from letta.server.server import SyncServer + + # First, manually create providers in the database + provider_manager = ProviderManager() + + # Create a base OpenAI provider (simulating it was synced before) + base_openai_create = ProviderCreate( + name="openai", provider_type=ProviderType.openai, - api_key="sk-plaintext-key-12345", + api_key="sk-old-key", base_url="https://api.openai.com/v1", ) + base_openai = await provider_manager.create_provider_async( + base_openai_create, + actor=default_user, + is_byok=False, # This is a base provider + ) - # Create provider - should work even without encryption - created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + # Create a BYOK provider (user-created) + byok_provider_create = ProviderCreate( + name="my-custom-openai", + provider_type=ProviderType.openai, + api_key="sk-my-key", + base_url="https://api.openai.com/v1", + ) + byok_provider = await provider_manager.create_provider_async( + byok_provider_create, + actor=default_user, + is_byok=True, + ) + assert byok_provider.provider_category == ProviderCategory.byok - # Verify provider was created - assert created_provider is not None - assert created_provider.name == "test-no-encryption-provider" + # Now create server with NO environment variables set (all base providers disabled) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - # Verify api_key can be retrieved (stored as plaintext in _enc column) - assert created_provider.api_key_enc.get_plaintext() == "sk-plaintext-key-12345" + from letta.settings import model_settings - # Read the provider back from database - retrieved_provider = await provider_manager.get_provider_async(created_provider.id, actor=default_user) + monkeypatch.setattr(model_settings, "openai_api_key", None) + monkeypatch.setattr(model_settings, "anthropic_api_key", None) - # Verify round-trip works - assert retrieved_provider.api_key_enc.get_plaintext() == "sk-plaintext-key-12345" + # Create server instance + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.default_org = default_organization - # Verify the value in _enc column is actually plaintext (not encrypted) - async with db_registry.async_session() as session: - provider_orm = await ProviderModel.read_async( - db_session=session, - identifier=created_provider.id, - actor=default_user, - ) + # Verify only letta provider is enabled (no openai) + enabled_names = [p.name for p in server._enabled_providers] + assert "letta" in enabled_names + assert "openai" not in enabled_names - # The value should be stored as plaintext since no encryption key was available - assert provider_orm.api_key_enc is not None - # When no encryption key is set, the plaintext is stored directly - # so from_encrypted + get_plaintext should return the original value - assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == "sk-plaintext-key-12345" + # Sync base providers (should not include openai anymore) + await server.provider_manager.sync_base_providers( + base_providers=server._enabled_providers, + actor=default_user, + ) + + # Call _sync_provider_models_async + await server._sync_provider_models_async() + + # Verify base OpenAI provider was deleted (no longer enabled) + try: + await server.provider_manager.get_provider_async(base_openai.id, actor=default_user) + assert False, "Base OpenAI provider should have been deleted" + except Exception: + # Expected - provider should not exist + pass + + # Verify BYOK provider still exists (should NOT be deleted) + byok_still_exists = await server.provider_manager.get_provider_async( + byok_provider.id, + actor=default_user, + ) + assert byok_still_exists is not None + assert byok_still_exists.name == "my-custom-openai" + assert byok_still_exists.provider_category == ProviderCategory.byok @pytest.mark.asyncio -async def test_provider_update_works_without_encryption_key(provider_manager, default_user, no_encryption_key): - """Test that provider updates work when no encryption key is configured.""" - # Create initial provider - provider_create = ProviderCreate( - name="test-no-enc-update-provider", - provider_type=ProviderType.anthropic, - api_key="sk-ant-initial-key", +async def test_server_startup_handles_api_errors_gracefully(default_user, default_organization, monkeypatch): + """Test that server startup handles API errors gracefully without crashing. + + This test verifies that: + 1. If a provider's API call fails during sync, it logs an error but continues + 2. Other providers can still sync successfully + 3. The server startup completes without crashing + """ + from letta.schemas.providers import AnthropicProvider, OpenAIProvider + from letta.server.server import SyncServer + + # Mock OpenAI to fail + async def mock_openai_fail(*args, **kwargs): + raise Exception("OpenAI API is down") + + # Mock Anthropic to succeed + from unittest.mock import MagicMock + + mock_anthropic_response = MagicMock() + mock_anthropic_response.model_dump.return_value = { + "data": [ + { + "id": "claude-3-5-sonnet-20241022", + "type": "model", + "display_name": "Claude 3.5 Sonnet", + "created_at": "2024-10-22T00:00:00Z", + } + ] + } + + class MockAnthropicModels: + async def list(self): + return mock_anthropic_response + + class MockAsyncAnthropic: + def __init__(self, *args, **kwargs): + self.models = MockAnthropicModels() + + monkeypatch.setattr( + "letta.llm_api.openai.openai_get_model_list_async", + mock_openai_fail, + ) + monkeypatch.setattr( + "anthropic.AsyncAnthropic", + MockAsyncAnthropic, ) - created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + # Set environment variables + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key") + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key") - # Update the api_key - provider_update = ProviderUpdate( - api_key="sk-ant-updated-key", + from letta.settings import model_settings + + monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key") + monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key") + + # Create server + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.default_org = default_organization + + # Sync base providers + await server.provider_manager.sync_base_providers( + base_providers=server._enabled_providers, + actor=default_user, ) - updated_provider = await provider_manager.update_provider_async(created_provider.id, provider_update, actor=default_user) + # This should NOT crash even though OpenAI fails + await server._sync_provider_models_async() - # Verify the updated key is accessible - assert updated_provider.api_key_enc.get_plaintext() == "sk-ant-updated-key" + # Verify Anthropic still synced successfully + anthropic_providers = await server.provider_manager.list_providers_async( + name="anthropic", + actor=default_user, + ) + assert len(anthropic_providers) == 1 - # Verify via database read - retrieved_provider = await provider_manager.get_provider_async(created_provider.id, actor=default_user) - assert retrieved_provider.api_key_enc.get_plaintext() == "sk-ant-updated-key" + anthropic_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=anthropic_providers[0].id, + model_type="llm", + ) + assert len(anthropic_models) >= 1, "Anthropic models should have synced despite OpenAI failure" + + # OpenAI should have no models (sync failed) + openai_providers = await server.provider_manager.list_providers_async( + name="openai", + actor=default_user, + ) + if len(openai_providers) > 0: + openai_models = await server.provider_manager.list_models_async( + actor=default_user, + provider_id=openai_providers[0].id, + ) + # Models might exist from previous runs, but the sync attempt should have been logged as failed + # The key is that the server didn't crash diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index dd4d0571..91e340b8 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -32,6 +32,7 @@ from letta.config import LettaConfig from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.server.server import SyncServer from tests.helpers.utils import upload_file_and_wait +from tests.utils import wait_for_server # Constants SERVER_PORT = 8283 @@ -106,7 +107,7 @@ def client() -> LettaSDKClient: print("Starting server thread") thread = threading.Thread(target=run_server, daemon=True) thread.start() - time.sleep(5) + wait_for_server(server_url, timeout=60) print("Running client tests with server:", server_url) client = LettaSDKClient(base_url=server_url) diff --git a/tests/test_server_providers.py b/tests/test_server_providers.py index 2306c22f..7ae9580e 100644 --- a/tests/test_server_providers.py +++ b/tests/test_server_providers.py @@ -1740,3 +1740,352 @@ async def test_handle_uniqueness_per_org(default_user, provider_manager): assert model is not None assert model.provider_id == provider_1.id # Still original provider assert model.max_context_window == 8192 # Still original + + +@pytest.mark.asyncio +async def test_delete_provider_cascades_to_models(default_user, provider_manager, monkeypatch): + """Test that deleting a provider also soft-deletes its associated models.""" + test_id = generate_test_id() + + # Mock _sync_default_models_for_provider to avoid external API calls + async def mock_sync(provider, actor): + pass # Don't actually sync - we'll manually create models below + + monkeypatch.setattr(provider_manager, "_sync_default_models_for_provider", mock_sync) + + # 1. Create a BYOK provider (org-scoped, so the actor can delete it) + provider_create = ProviderCreate( + name=f"test-cascade-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-test-key", + ) + provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=True) + + # 2. Manually sync models to the provider + llm_models = [ + LLMConfig( + model=f"gpt-4o-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=128000, + handle=f"test-{test_id}/gpt-4o", + provider_name=provider.name, + provider_category=ProviderCategory.byok, + ), + LLMConfig( + model=f"gpt-4o-mini-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=16384, + handle=f"test-{test_id}/gpt-4o-mini", + provider_name=provider.name, + provider_category=ProviderCategory.byok, + ), + ] + + embedding_models = [ + EmbeddingConfig( + embedding_model=f"text-embedding-3-small-{test_id}", + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + embedding_chunk_size=300, + handle=f"test-{test_id}/text-embedding-3-small", + ), + ] + + await provider_manager.sync_provider_models_async( + provider=provider, + llm_models=llm_models, + embedding_models=embedding_models, + organization_id=default_user.organization_id, # Org-scoped for BYOK provider + ) + + # 3. Verify models exist before deletion + llm_models_before = await provider_manager.list_models_async( + actor=default_user, + model_type="llm", + provider_id=provider.id, + ) + embedding_models_before = await provider_manager.list_models_async( + actor=default_user, + model_type="embedding", + provider_id=provider.id, + ) + + llm_handles_before = {m.handle for m in llm_models_before} + embedding_handles_before = {m.handle for m in embedding_models_before} + + assert f"test-{test_id}/gpt-4o" in llm_handles_before + assert f"test-{test_id}/gpt-4o-mini" in llm_handles_before + assert f"test-{test_id}/text-embedding-3-small" in embedding_handles_before + + # 4. Delete the provider + await provider_manager.delete_provider_by_id_async(provider.id, actor=default_user) + + # 5. Verify models are soft-deleted (no longer returned in list) + all_llm_models_after = await provider_manager.list_models_async( + actor=default_user, + model_type="llm", + ) + all_embedding_models_after = await provider_manager.list_models_async( + actor=default_user, + model_type="embedding", + ) + + all_llm_handles_after = {m.handle for m in all_llm_models_after} + all_embedding_handles_after = {m.handle for m in all_embedding_models_after} + + # All models from the deleted provider should be gone + assert f"test-{test_id}/gpt-4o" not in all_llm_handles_after + assert f"test-{test_id}/gpt-4o-mini" not in all_llm_handles_after + assert f"test-{test_id}/text-embedding-3-small" not in all_embedding_handles_after + + # 6. Verify provider is also deleted + providers_after = await provider_manager.list_providers_async( + actor=default_user, + name=f"test-cascade-{test_id}", + ) + assert len(providers_after) == 0 + + +@pytest.mark.asyncio +async def test_get_llm_config_from_handle_includes_max_tokens(default_user, provider_manager): + """Test that get_llm_config_from_handle includes max_tokens from provider's get_default_max_output_tokens. + + This test verifies that: + 1. The max_tokens field is populated when retrieving LLMConfig from a handle + 2. The max_tokens value comes from the provider's get_default_max_output_tokens method + 3. Different providers return different default max_tokens values (e.g., OpenAI returns 16384) + """ + test_id = generate_test_id() + + # Create an OpenAI provider + provider_create = ProviderCreate( + name=f"test-openai-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-test-key", + base_url="https://api.openai.com/v1", + ) + provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False) + + # Sync a model with the provider + llm_models = [ + LLMConfig( + model=f"gpt-4o-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=128000, + handle=f"test-{test_id}/gpt-4o", + provider_name=provider.name, + provider_category=ProviderCategory.base, + ), + ] + + await provider_manager.sync_provider_models_async( + provider=provider, + llm_models=llm_models, + embedding_models=[], + organization_id=None, # Global model + ) + + # Get LLMConfig from handle + llm_config = await provider_manager.get_llm_config_from_handle( + handle=f"test-{test_id}/gpt-4o", + actor=default_user, + ) + + # Verify max_tokens is set and comes from OpenAI provider's default (16384 for non-o1/o3 models) + assert llm_config.max_tokens is not None, "max_tokens should be set" + assert llm_config.max_tokens == 16384, f"Expected max_tokens=16384 for OpenAI gpt-4o, got {llm_config.max_tokens}" + + # Test with a gpt-5 model (should have 16384) + llm_models_gpt5 = [ + LLMConfig( + model=f"gpt-5-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=200000, + handle=f"test-{test_id}/gpt-5", + provider_name=provider.name, + provider_category=ProviderCategory.base, + ), + ] + + await provider_manager.sync_provider_models_async( + provider=provider, + llm_models=llm_models_gpt5, + embedding_models=[], + organization_id=None, + ) + + llm_config_gpt5 = await provider_manager.get_llm_config_from_handle( + handle=f"test-{test_id}/gpt-5", + actor=default_user, + ) + + # gpt-5 models also have 16384 max_tokens + assert llm_config_gpt5.max_tokens == 16384, f"Expected max_tokens=16384 for gpt-5, got {llm_config_gpt5.max_tokens}" + + +@pytest.mark.asyncio +async def test_server_list_llm_models_async_reads_from_database(default_user, provider_manager): + """Test that the server's list_llm_models_async reads models from database, not in-memory. + + This test verifies that: + 1. Models synced to the database are returned by list_llm_models_async + 2. The LLMConfig objects are correctly constructed from database-cached models + 3. Provider filtering works correctly when reading from database + """ + from letta.server.server import SyncServer + + test_id = generate_test_id() + + # Create a provider in the database + provider_create = ProviderCreate( + name=f"test-db-provider-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-test-key", + base_url="https://custom.openai.com/v1", + ) + provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False) + + # Sync models to database + llm_models = [ + LLMConfig( + model=f"custom-model-1-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://custom.openai.com/v1", + context_window=32000, + handle=f"test-{test_id}/custom-model-1", + provider_name=provider.name, + provider_category=ProviderCategory.base, + ), + LLMConfig( + model=f"custom-model-2-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://custom.openai.com/v1", + context_window=64000, + handle=f"test-{test_id}/custom-model-2", + provider_name=provider.name, + provider_category=ProviderCategory.base, + ), + ] + + await provider_manager.sync_provider_models_async( + provider=provider, + llm_models=llm_models, + embedding_models=[], + organization_id=None, + ) + + # Create server instance + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.provider_manager = provider_manager + + # List LLM models via server + models = await server.list_llm_models_async( + actor=default_user, + provider_name=f"test-db-provider-{test_id}", + ) + + # Verify models were read from database + handles = {m.handle for m in models} + assert f"test-{test_id}/custom-model-1" in handles, "custom-model-1 should be in database" + assert f"test-{test_id}/custom-model-2" in handles, "custom-model-2 should be in database" + + # Verify LLMConfig properties are correctly populated from database + model_1 = next(m for m in models if m.handle == f"test-{test_id}/custom-model-1") + assert model_1.model == f"custom-model-1-{test_id}" + assert model_1.context_window == 32000 + assert model_1.model_endpoint == "https://custom.openai.com/v1" + assert model_1.provider_name == f"test-db-provider-{test_id}" + + model_2 = next(m for m in models if m.handle == f"test-{test_id}/custom-model-2") + assert model_2.model == f"custom-model-2-{test_id}" + assert model_2.context_window == 64000 + + +@pytest.mark.asyncio +async def test_get_enabled_providers_async_queries_database(default_user, provider_manager): + """Test that get_enabled_providers_async queries providers from database, not in-memory list. + + This test verifies that: + 1. Providers created in the database are returned by get_enabled_providers_async + 2. The method queries the database, not an in-memory _enabled_providers list + 3. Provider filtering by category works correctly from database + """ + from letta.server.server import SyncServer + + test_id = generate_test_id() + + # Create providers in the database + base_provider_create = ProviderCreate( + name=f"test-base-provider-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-test-key", + base_url="https://api.openai.com/v1", + ) + base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False) + + byok_provider_create = ProviderCreate( + name=f"test-byok-provider-{test_id}", + provider_type=ProviderType.anthropic, + api_key="sk-test-byok-key", + ) + byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True) + + # Create server instance - importantly, don't set _enabled_providers + # This ensures we're testing database queries, not in-memory list + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.provider_manager = provider_manager + # Clear in-memory providers to prove we're querying database + server._enabled_providers = [] + + # Get all providers - should query database + all_providers = await server.get_enabled_providers_async(actor=default_user) + provider_names = [p.name for p in all_providers] + + assert f"test-base-provider-{test_id}" in provider_names, "Base provider should be in database" + assert f"test-byok-provider-{test_id}" in provider_names, "BYOK provider should be in database" + + # Filter by provider category + base_only = await server.get_enabled_providers_async( + actor=default_user, + provider_category=[ProviderCategory.base], + ) + base_only_names = [p.name for p in base_only] + + assert f"test-base-provider-{test_id}" in base_only_names, "Base provider should be in base-only list" + assert f"test-byok-provider-{test_id}" not in base_only_names, "BYOK provider should NOT be in base-only list" + + byok_only = await server.get_enabled_providers_async( + actor=default_user, + provider_category=[ProviderCategory.byok], + ) + byok_only_names = [p.name for p in byok_only] + + assert f"test-byok-provider-{test_id}" in byok_only_names, "BYOK provider should be in byok-only list" + assert f"test-base-provider-{test_id}" not in byok_only_names, "Base provider should NOT be in byok-only list" + + # Filter by provider name + specific_provider = await server.get_enabled_providers_async( + actor=default_user, + provider_name=f"test-base-provider-{test_id}", + ) + + assert len(specific_provider) == 1 + assert specific_provider[0].name == f"test-base-provider-{test_id}" + assert specific_provider[0].provider_type == ProviderType.openai + + # Filter by provider type + openai_providers = await server.get_enabled_providers_async( + actor=default_user, + provider_type=ProviderType.openai, + ) + openai_names = [p.name for p in openai_providers] + + assert f"test-base-provider-{test_id}" in openai_names + assert f"test-byok-provider-{test_id}" not in openai_names # This is anthropic type