diff --git a/letta/constants.py b/letta/constants.py
index 882f9ce3..48e160a7 100644
--- a/letta/constants.py
+++ b/letta/constants.py
@@ -8,6 +8,26 @@ LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")
 LETTA_MODEL_ENDPOINT = "https://inference.letta.com/v1/"
 DEFAULT_TIMEZONE = "UTC"
 
+# Provider ordering for model listing (matches original _enabled_providers list order)
+PROVIDER_ORDER = {
+    "letta": 0,
+    "openai": 1,
+    "anthropic": 2,
+    "ollama": 3,
+    "google_ai": 4,
+    "google_vertex": 5,
+    "azure": 6,
+    "groq": 7,
+    "together": 8,
+    "vllm": 9,
+    "bedrock": 10,
+    "deepseek": 11,
+    "xai": 12,
+    "lmstudio": 13,
+    "zai": 14,
+    "openrouter": 15,  # Note: OpenRouter uses OpenRouterProvider, not a ProviderType enum
+}
+
 ADMIN_PREFIX = "/v1/admin"
 API_PREFIX = "/v1"
 OLLAMA_API_PREFIX = "/v1"
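
The PROVIDER_ORDER table feeds the sort-key helpers added to SyncServer in letta/server/server.py below: unknown provider names fall back to priority 999 and sort last, with provider name and model name as tiebreakers. A minimal standalone sketch of the intended behavior (FakeConfig is a hypothetical stand-in for LLMConfig, not part of this patch):

from dataclasses import dataclass

PROVIDER_ORDER = {"letta": 0, "openai": 1, "anthropic": 2, "zai": 14}  # abbreviated

@dataclass
class FakeConfig:
    provider_name: str
    model: str

def sort_key(cfg: FakeConfig):
    # Unknown providers get priority 999; ties break on provider name, then model name.
    return (PROVIDER_ORDER.get(cfg.provider_name, 999), cfg.provider_name or "", cfg.model or "")

models = [
    FakeConfig("zai", "glm-4"),
    FakeConfig("mystery", "m1"),  # not in PROVIDER_ORDER, so it sorts last
    FakeConfig("openai", "gpt-4o-mini"),
    FakeConfig("openai", "gpt-4o"),
]
models.sort(key=sort_key)
assert [(m.provider_name, m.model) for m in models] == [
    ("openai", "gpt-4o"),
    ("openai", "gpt-4o-mini"),
    ("zai", "glm-4"),
    ("mystery", "m1"),
]
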
""" api_key = None - # Fetch API key from database for both base and BYOK providers - # This ensures that base providers (from environment) also have their keys persisted and accessible - if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]: + # Only fetch API key from database for BYOK providers + # Base providers should always use environment variables + if llm_config.provider_category == ProviderCategory.byok: from letta.services.provider_manager import ProviderManager api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor) - # If we got an empty string from the database (e.g., Letta provider), treat it as None + # If we got an empty string from the database, treat it as None # so the client can fall back to environment variables or default behavior if api_key == "": api_key = None diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py index 0e131462..5d5135f9 100644 --- a/letta/server/rest_api/routers/v1/providers.py +++ b/letta/server/rest_api/routers/v1/providers.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional from fastapi import APIRouter, Body, Depends, Query, status from fastapi.responses import JSONResponse -from letta.schemas.enums import ProviderType +from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.providers import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server from letta.validators import ProviderId @@ -39,7 +39,14 @@ async def list_providers( """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) providers = await server.provider_manager.list_providers_async( - before=before, after=after, limit=limit, actor=actor, name=name, provider_type=provider_type, ascending=(order == "asc") + before=before, + after=after, + limit=limit, + actor=actor, + name=name, + provider_type=provider_type, + provider_category=[ProviderCategory.byok], + ascending=(order == "asc"), ) return providers diff --git a/letta/server/server.py b/letta/server/server.py index 7b79e288..074888d6 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1097,6 +1097,18 @@ class SyncServer(object): passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor) return passage_count, document_count + def _get_provider_sort_key(self, model: LLMConfig) -> Tuple[int, str, str]: + """Get sort key for a model: (provider_priority, provider_name, model_name)""" + provider_priority = constants.PROVIDER_ORDER.get(model.provider_name, 999) + return (provider_priority, model.provider_name or "", model.model or "") + + def _get_embedding_provider_sort_key(self, model: EmbeddingConfig) -> Tuple[int, str, str]: + """Get sort key for an embedding model: (provider_priority, provider_name, model_name)""" + # Extract provider name from handle (format: "provider_name/model_name") + provider_name = model.handle.split("/")[0] if model.handle and "/" in model.handle else "" + provider_priority = constants.PROVIDER_ORDER.get(provider_name, 999) + return (provider_priority, provider_name, model.embedding_model or "") + @trace_method async def list_llm_models_async( self, @@ -1105,86 +1117,122 @@ class SyncServer(object): provider_name: Optional[str] = None, provider_type: Optional[ProviderType] = None, ) -> List[LLMConfig]: - """List 
diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py
index 0e131462..5d5135f9 100644
--- a/letta/server/rest_api/routers/v1/providers.py
+++ b/letta/server/rest_api/routers/v1/providers.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional
 from fastapi import APIRouter, Body, Depends, Query, status
 from fastapi.responses import JSONResponse
 
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.providers import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate
 from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
 from letta.validators import ProviderId
@@ -39,7 +39,14 @@ async def list_providers(
     """
     actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
     providers = await server.provider_manager.list_providers_async(
-        before=before, after=after, limit=limit, actor=actor, name=name, provider_type=provider_type, ascending=(order == "asc")
+        before=before,
+        after=after,
+        limit=limit,
+        actor=actor,
+        name=name,
+        provider_type=provider_type,
+        provider_category=[ProviderCategory.byok],
+        ascending=(order == "asc"),
     )
     return providers
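
One consequence for the REST surface: GET /v1/providers now returns only user-created (BYOK) providers, so base providers synced from environment variables no longer appear in the response. A rough client-side sketch, assuming a locally running server (URL, port, and auth are placeholders):

import httpx

resp = httpx.get("http://localhost:8283/v1/providers/")
resp.raise_for_status()
for p in resp.json():
    # After this change, env-var (base) providers are filtered out server-side.
    assert p["provider_category"] == "byok"
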
diff --git a/letta/server/server.py b/letta/server/server.py
index 7b79e288..074888d6 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -1097,6 +1097,18 @@ class SyncServer(object):
         passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor)
         return passage_count, document_count
 
+    def _get_provider_sort_key(self, model: LLMConfig) -> Tuple[int, str, str]:
+        """Get sort key for a model: (provider_priority, provider_name, model_name)"""
+        provider_priority = constants.PROVIDER_ORDER.get(model.provider_name, 999)
+        return (provider_priority, model.provider_name or "", model.model or "")
+
+    def _get_embedding_provider_sort_key(self, model: EmbeddingConfig) -> Tuple[int, str, str]:
+        """Get sort key for an embedding model: (provider_priority, provider_name, model_name)"""
+        # Extract provider name from handle (format: "provider_name/model_name")
+        provider_name = model.handle.split("/")[0] if model.handle and "/" in model.handle else ""
+        provider_priority = constants.PROVIDER_ORDER.get(provider_name, 999)
+        return (provider_priority, provider_name, model.embedding_model or "")
+
     @trace_method
     async def list_llm_models_async(
         self,
@@ -1105,86 +1117,122 @@ class SyncServer(object):
         actor: User,
         provider_category: Optional[List[ProviderCategory]] = None,
         provider_name: Optional[str] = None,
         provider_type: Optional[ProviderType] = None,
     ) -> List[LLMConfig]:
-        """List available LLM models from database cache"""
-        # Get provider IDs if filtering by provider
-        provider_ids = None
-        if provider_name or provider_type:
-            providers = await self.get_enabled_providers_async(
-                provider_category=provider_category,
-                provider_name=provider_name,
-                provider_type=provider_type,
-                actor=actor,
-            )
-            provider_ids = [p.id for p in providers]
-
-            # If filtering was requested but no providers matched, return empty list
-            if not provider_ids:
-                return []
-
-        # Get models from database
-        provider_models = await self.provider_manager.list_models_async(
-            actor=actor,
-            model_type="llm",
-            provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None,
-            enabled=True,
-        )
-
-        # Build LLMConfig objects from cached data
-        # Cache providers to avoid N+1 queries
-        provider_cache: Dict[str, Provider] = {}
+        """List available LLM models - base from DB, BYOK from provider endpoints"""
         llm_models = []
-        for model in provider_models:
-            # Skip if filtering by provider and model doesn't match
-            if provider_ids and model.provider_id not in provider_ids:
-                continue
-
-            # Get provider details (with caching to avoid N+1 queries)
-            if model.provider_id not in provider_cache:
-                provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
-            provider = provider_cache[model.provider_id]
 
+        # Determine which categories to include
+        include_base = not provider_category or ProviderCategory.base in provider_category
+        include_byok = not provider_category or ProviderCategory.byok in provider_category
 
-        llm_config = LLMConfig(
-                model=model.name,
-                model_endpoint_type=model.model_endpoint_type,
-                model_endpoint=provider.base_url or model.model_endpoint_type,
-                context_window=model.max_context_window or 16384,
-                handle=model.handle,
-                provider_name=provider.name,
-                provider_category=provider.provider_category,
+        # Get base provider models from database
+        if include_base:
+            provider_models = await self.provider_manager.list_models_async(
+                actor=actor,
+                model_type="llm",
+                enabled=True,
             )
-            llm_models.append(llm_config)
+
+            # Build LLMConfig objects from database
+            provider_cache: Dict[str, Provider] = {}
+            for model in provider_models:
+                # Get provider details (with caching to avoid N+1 queries)
+                if model.provider_id not in provider_cache:
+                    provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
+                provider = provider_cache[model.provider_id]
+
+                # Skip non-base providers (they're handled separately)
+                if provider.provider_category != ProviderCategory.base:
+                    continue
+
+                # Apply provider_name/provider_type filters if specified
+                if provider_name and provider.name != provider_name:
+                    continue
+                if provider_type and provider.provider_type != provider_type:
+                    continue
+
+                llm_config = LLMConfig(
+                    model=model.name,
+                    model_endpoint_type=model.model_endpoint_type,
+                    model_endpoint=provider.base_url or model.model_endpoint_type,
+                    context_window=model.max_context_window or 16384,
+                    handle=model.handle,
+                    provider_name=provider.name,
+                    provider_category=provider.provider_category,
+                )
+                llm_models.append(llm_config)
+
+        # Get BYOK provider models by hitting provider endpoints directly
+        if include_byok:
+            byok_providers = await self.provider_manager.list_providers_async(
+                actor=actor,
+                name=provider_name,
+                provider_type=provider_type,
+                provider_category=[ProviderCategory.byok],
+            )
+
+            for provider in byok_providers:
+                try:
+                    typed_provider = provider.cast_to_subtype()
+                    models = await typed_provider.list_llm_models_async()
+                    llm_models.extend(models)
+                except Exception as e:
+                    logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
+
+        # Sort by provider order (matching old _enabled_providers order), then by provider name and model name
+        llm_models.sort(key=self._get_provider_sort_key)
 
         return llm_models
 
     async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
-        """List available embedding models from database cache"""
-        # Get models from database
+        """List available embedding models - base from DB, BYOK from provider endpoints"""
+        embedding_models = []
+
+        # Get base provider models from database
         provider_models = await self.provider_manager.list_models_async(
             actor=actor,
             model_type="embedding",
             enabled=True,
         )
 
-        # Build EmbeddingConfig objects from cached data
-        # Cache providers to avoid N+1 queries
+        # Build EmbeddingConfig objects from database (base providers only)
         provider_cache: Dict[str, Provider] = {}
-        embedding_models = []
         for model in provider_models:
             # Get provider details (with caching to avoid N+1 queries)
             if model.provider_id not in provider_cache:
                 provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
             provider = provider_cache[model.provider_id]
 
+            # Skip non-base providers (they're handled separately)
+            if provider.provider_category != ProviderCategory.base:
+                continue
+
             embedding_config = EmbeddingConfig(
                 embedding_model=model.name,
                 embedding_endpoint_type=model.model_endpoint_type,
                 embedding_endpoint=provider.base_url or model.model_endpoint_type,
-                embedding_dim=model.embedding_dim or 1536,  # Use model's dimension or default
+                embedding_dim=model.embedding_dim or 1536,
                 embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
                 handle=model.handle,
             )
             embedding_models.append(embedding_config)
 
+        # Get BYOK provider models by hitting provider endpoints directly
+        byok_providers = await self.provider_manager.list_providers_async(
+            actor=actor,
+            provider_category=[ProviderCategory.byok],
+        )
+
+        for provider in byok_providers:
+            try:
+                typed_provider = provider.cast_to_subtype()
+                models = await typed_provider.list_embedding_models_async()
+                embedding_models.extend(models)
+            except Exception as e:
+                logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")
+
+        # Sort by provider order (matching old _enabled_providers order), then by provider name and model name
+        embedding_models.sort(key=self._get_embedding_provider_sort_key)
+
         return embedding_models
 
     async def get_enabled_providers_async(
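
Note that the include_base/include_byok derivation in list_llm_models_async treats an empty provider_category list the same as None, i.e. "no filter". A small self-contained mirror of that logic:

from enum import Enum
from typing import List, Optional, Tuple

class ProviderCategory(str, Enum):
    base = "base"
    byok = "byok"

def categories_to_include(provider_category: Optional[List[ProviderCategory]]) -> Tuple[bool, bool]:
    # Mirrors the derivation above: a falsy filter means "include everything".
    include_base = not provider_category or ProviderCategory.base in provider_category
    include_byok = not provider_category or ProviderCategory.byok in provider_category
    return include_base, include_byok

assert categories_to_include(None) == (True, True)
assert categories_to_include([]) == (True, True)  # empty filter behaves like no filter
assert categories_to_include([ProviderCategory.byok]) == (False, True)
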
diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py
index 7cf8a2aa..b69b6364 100644
--- a/letta/services/provider_manager.py
+++ b/letta/services/provider_manager.py
@@ -240,6 +240,7 @@ class ProviderManager:
         actor: PydanticUser,
         name: Optional[str] = None,
         provider_type: Optional[ProviderType] = None,
+        provider_category: Optional[List[ProviderCategory]] = None,
         before: Optional[str] = None,
         after: Optional[str] = None,
         limit: Optional[int] = 50,
@@ -280,7 +281,14 @@ class ProviderManager:
         )
 
         # Combine both lists
-        all_providers = org_providers + global_providers
+        all_providers = []
+        if not provider_category:
+            all_providers = org_providers + global_providers
+        else:
+            if ProviderCategory.byok in provider_category:
+                all_providers += org_providers
+            if ProviderCategory.base in provider_category:
+                all_providers += global_providers
 
         # Remove deprecated api_key and access_key fields from the response
         for provider in all_providers:
@@ -575,13 +583,14 @@ class ProviderManager:
             continue
 
            # Convert Provider to ProviderCreate
-            api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
-            access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
+            # NOTE: Do NOT store API keys for base providers in the database.
+            # Base providers should always use environment variables for API keys.
+            # This ensures keys stay in sync with env vars and aren't duplicated in DB.
             provider_create = ProviderCreate(
                 name=provider.name,
                 provider_type=provider.provider_type,
-                api_key=api_key or "",  # ProviderCreate requires api_key, use empty string if None
-                access_key=access_key,
+                api_key="",  # Base providers use env vars, not DB-stored keys
+                access_key=None,
                 region=provider.region,
                 base_url=provider.base_url,
                 api_version=provider.api_version,
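
The new provider_category filter piggybacks on the existing split between the two queries in list_providers_async: org-scoped rows are the BYOK providers, global rows are the base providers synced from environment variables. A simplified sketch of the combine step (plain strings stand in for Provider objects):

def combine(org_providers, global_providers, provider_category=None):
    # Mirrors list_providers_async: no filter returns both lists, BYOK first.
    if not provider_category:
        return org_providers + global_providers
    combined = []
    if "byok" in provider_category:
        combined += org_providers  # BYOK == org-scoped rows
    if "base" in provider_category:
        combined += global_providers  # base == global rows synced from env vars
    return combined

assert combine(["my-openai"], ["openai"], ["byok"]) == ["my-openai"]
assert combine(["my-openai"], ["openai"]) == ["my-openai", "openai"]
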
diff --git a/tests/test_server_providers.py b/tests/test_server_providers.py
index 7ae9580e..3bd8de9f 100644
--- a/tests/test_server_providers.py
+++ b/tests/test_server_providers.py
@@ -2089,3 +2089,341 @@ async def test_get_enabled_providers_async_queries_database(default_user, provid
     assert f"test-base-provider-{test_id}" in openai_names
     assert f"test-byok-provider-{test_id}" not in openai_names  # This is anthropic type
+
+
+# =============================================================================
+# BYOK Provider and Model Listing Integration Tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_list_providers_filters_by_category(default_user, provider_manager):
+    """Test that list_providers_async correctly filters by provider_category."""
+    test_id = generate_test_id()
+
+    # Create a base provider
+    base_provider_create = ProviderCreate(
+        name=f"test-base-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-base-key",
+    )
+    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
+
+    # Create a BYOK provider
+    byok_provider_create = ProviderCreate(
+        name=f"test-byok-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-byok-key",
+    )
+    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
+
+    # Verify base provider has correct category
+    assert base_provider.provider_category == ProviderCategory.base
+
+    # Verify BYOK provider has correct category
+    assert byok_provider.provider_category == ProviderCategory.byok
+
+    # List only BYOK providers
+    byok_providers = await provider_manager.list_providers_async(
+        actor=default_user,
+        provider_category=[ProviderCategory.byok],
+    )
+    byok_names = [p.name for p in byok_providers]
+
+    assert f"test-byok-{test_id}" in byok_names
+    assert f"test-base-{test_id}" not in byok_names
+
+    # List only base providers
+    base_providers = await provider_manager.list_providers_async(
+        actor=default_user,
+        provider_category=[ProviderCategory.base],
+    )
+    base_names = [p.name for p in base_providers]
+
+    assert f"test-base-{test_id}" in base_names
+    assert f"test-byok-{test_id}" not in base_names
+
+
+@pytest.mark.asyncio
+async def test_base_provider_api_key_not_stored_in_db(default_user, provider_manager):
+    """Test that sync_base_providers does NOT store API keys for base providers."""
+    # Create base providers with API keys
+    base_providers = [
+        OpenAIProvider(name="test-openai-no-key", api_key="sk-should-not-be-stored"),
+    ]
+
+    # Sync to database
+    await provider_manager.sync_base_providers(base_providers=base_providers, actor=default_user)
+
+    # Retrieve the provider from database
+    providers = await provider_manager.list_providers_async(name="test-openai-no-key", actor=default_user)
+    assert len(providers) == 1
+
+    provider = providers[0]
+    assert provider.provider_category == ProviderCategory.base
+
+    # The API key should be empty (not stored) for base providers
+    if provider.api_key_enc:
+        api_key = await provider.api_key_enc.get_plaintext_async()
+        assert api_key == "" or api_key is None, "Base provider API key should not be stored in database"
+
+
+@pytest.mark.asyncio
+async def test_byok_provider_api_key_stored_in_db(default_user, provider_manager):
+    """Test that BYOK providers DO have their API keys stored in the database."""
+    test_id = generate_test_id()
+
+    # Create a BYOK provider with API key
+    byok_provider_create = ProviderCreate(
+        name=f"test-byok-with-key-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-byok-should-be-stored",
+    )
+    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
+
+    # Retrieve the provider from database
+    providers = await provider_manager.list_providers_async(name=f"test-byok-with-key-{test_id}", actor=default_user)
+    assert len(providers) == 1
+
+    provider = providers[0]
+    assert provider.provider_category == ProviderCategory.byok
+
+    # The API key SHOULD be stored for BYOK providers
+    assert provider.api_key_enc is not None
+    api_key = await provider.api_key_enc.get_plaintext_async()
+    assert api_key == "sk-byok-should-be-stored", "BYOK provider API key should be stored in database"
+
+
+@pytest.mark.asyncio
+async def test_server_list_llm_models_base_from_db(default_user, provider_manager):
+    """Test that server.list_llm_models_async fetches base models from database."""
+    from letta.server.server import SyncServer
+
+    test_id = generate_test_id()
+
+    # Create base provider and models (these ARE stored in DB)
+    base_provider_create = ProviderCreate(
+        name=f"test-base-llm-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-base-key",
+    )
+    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
+
+    base_llm_model = LLMConfig(
+        model=f"base-gpt-4o-{test_id}",
+        model_endpoint_type="openai",
+        model_endpoint="https://api.openai.com/v1",
+        context_window=128000,
+        handle=f"test-base-llm-{test_id}/gpt-4o",
+        provider_name=base_provider.name,
+        provider_category=ProviderCategory.base,
+    )
+
+    await provider_manager.sync_provider_models_async(
+        provider=base_provider,
+        llm_models=[base_llm_model],
+        embedding_models=[],
+        organization_id=None,
+    )
+
+    # Create server and list models
+    server = SyncServer(init_with_default_org_and_user=False)
+    server.default_user = default_user
+    server.provider_manager = provider_manager
+    server._enabled_providers = []  # Clear to test database-backed listing
+
+    # List all models - base models come from DB
+    all_models = await server.list_llm_models_async(actor=default_user)
+    all_handles = [m.handle for m in all_models]
+
+    assert f"test-base-llm-{test_id}/gpt-4o" in all_handles, "Base model should be in list"
+
+    # List only base models
+    base_models = await server.list_llm_models_async(
+        actor=default_user,
+        provider_category=[ProviderCategory.base],
+    )
+    base_handles = [m.handle for m in base_models]
+
+    assert f"test-base-llm-{test_id}/gpt-4o" in base_handles
+ """ + from letta.server.server import SyncServer + + test_id = generate_test_id() + + # Create base provider and embedding models (these ARE stored in DB) + base_provider_create = ProviderCreate( + name=f"test-base-embed-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-base-key", + ) + base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False) + + base_embedding_model = EmbeddingConfig( + embedding_model=f"base-text-embedding-{test_id}", + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + embedding_chunk_size=300, + handle=f"test-base-embed-{test_id}/text-embedding-3-small", + ) + + await provider_manager.sync_provider_models_async( + provider=base_provider, + llm_models=[], + embedding_models=[base_embedding_model], + organization_id=None, + ) + + # Create server and list models + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.provider_manager = provider_manager + server._enabled_providers = [] + + # List all embedding models - base models come from DB + all_models = await server.list_embedding_models_async(actor=default_user) + all_handles = [m.handle for m in all_models] + + assert f"test-base-embed-{test_id}/text-embedding-3-small" in all_handles + + +@pytest.mark.asyncio +async def test_provider_ordering_matches_constants(default_user, provider_manager): + """Test that provider ordering in model listing matches PROVIDER_ORDER in constants.""" + from letta.constants import PROVIDER_ORDER + from letta.server.server import SyncServer + + test_id = generate_test_id() + + # Create providers with different names that should have different ordering + providers_to_create = [ + ("zai", ProviderType.zai, 14), # Lower priority + ("openai", ProviderType.openai, 1), # Higher priority + ("anthropic", ProviderType.anthropic, 2), # Medium priority + ] + + created_providers = [] + for name_suffix, provider_type, expected_order in providers_to_create: + provider_create = ProviderCreate( + name=f"{name_suffix}", # Use actual provider name for ordering + provider_type=provider_type, + api_key=f"sk-{name_suffix}-key", + ) + # Check if provider already exists + existing = await provider_manager.list_providers_async(name=name_suffix, actor=default_user) + if not existing: + provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False) + created_providers.append((provider, expected_order)) + + # Create a model for this provider + llm_model = LLMConfig( + model=f"test-model-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.example.com/v1", + context_window=8192, + handle=f"{name_suffix}/test-model-{test_id}", + provider_name=provider.name, + provider_category=ProviderCategory.base, + ) + + await provider_manager.sync_provider_models_async( + provider=provider, + llm_models=[llm_model], + embedding_models=[], + organization_id=None, + ) + + # Verify PROVIDER_ORDER has expected values + assert PROVIDER_ORDER.get("openai") == 1 + assert PROVIDER_ORDER.get("anthropic") == 2 + assert PROVIDER_ORDER.get("zai") == 14 + + # Create server and verify ordering + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.provider_manager = provider_manager + server._enabled_providers = [] + + # List models and check ordering + all_models = await server.list_llm_models_async(actor=default_user) + + # Filter to only our 
+
+
+@pytest.mark.asyncio
+async def test_provider_ordering_matches_constants(default_user, provider_manager):
+    """Test that provider ordering in model listing matches PROVIDER_ORDER in constants."""
+    from letta.constants import PROVIDER_ORDER
+    from letta.server.server import SyncServer
+
+    test_id = generate_test_id()
+
+    # Create providers with different names that should have different ordering
+    providers_to_create = [
+        ("zai", ProviderType.zai, 14),  # Lower priority
+        ("openai", ProviderType.openai, 1),  # Higher priority
+        ("anthropic", ProviderType.anthropic, 2),  # Medium priority
+    ]
+
+    created_providers = []
+    for name_suffix, provider_type, expected_order in providers_to_create:
+        provider_create = ProviderCreate(
+            name=name_suffix,  # Use actual provider name for ordering
+            provider_type=provider_type,
+            api_key=f"sk-{name_suffix}-key",
+        )
+        # Check if provider already exists; only create models for providers we created
+        existing = await provider_manager.list_providers_async(name=name_suffix, actor=default_user)
+        if not existing:
+            provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False)
+            created_providers.append((provider, expected_order))
+
+            # Create a model for this provider
+            llm_model = LLMConfig(
+                model=f"test-model-{test_id}",
+                model_endpoint_type="openai",
+                model_endpoint="https://api.example.com/v1",
+                context_window=8192,
+                handle=f"{name_suffix}/test-model-{test_id}",
+                provider_name=provider.name,
+                provider_category=ProviderCategory.base,
+            )
+
+            await provider_manager.sync_provider_models_async(
+                provider=provider,
+                llm_models=[llm_model],
+                embedding_models=[],
+                organization_id=None,
+            )
+
+    # Verify PROVIDER_ORDER has expected values
+    assert PROVIDER_ORDER.get("openai") == 1
+    assert PROVIDER_ORDER.get("anthropic") == 2
+    assert PROVIDER_ORDER.get("zai") == 14
+
+    # Create server and verify ordering
+    server = SyncServer(init_with_default_org_and_user=False)
+    server.default_user = default_user
+    server.provider_manager = provider_manager
+    server._enabled_providers = []
+
+    # List models and check ordering
+    all_models = await server.list_llm_models_async(actor=default_user)
+
+    # Filter to only our test models
+    test_models = [m for m in all_models if f"test-model-{test_id}" in m.handle]
+
+    if len(test_models) >= 2:
+        # Verify models are sorted by provider order
+        provider_names_in_order = [m.provider_name for m in test_models]
+
+        # Get the indices in PROVIDER_ORDER
+        indices = [PROVIDER_ORDER.get(name, 999) for name in provider_names_in_order]
+
+        # Verify the list is sorted by provider order
+        assert indices == sorted(indices), f"Models should be sorted by PROVIDER_ORDER, got: {provider_names_in_order}"