fix: fix providers and models persistence (#8302)

This commit is contained in:
Ari Webb
2026-01-05 18:05:44 -08:00
committed by Caren Thomas
parent e56c5c5b49
commit 02f3e3f3b9
6 changed files with 491 additions and 67 deletions

View File

@@ -8,6 +8,26 @@ LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")
# Default endpoint for the hosted Letta inference service.
LETTA_MODEL_ENDPOINT = "https://inference.letta.com/v1/"
DEFAULT_TIMEZONE = "UTC"
# Provider ordering for model listing (matches original _enabled_providers list order)
# Lower value sorts earlier; callers fall back to 999 for providers not listed here,
# so unknown providers sort after all known ones.
PROVIDER_ORDER = {
    "letta": 0,
    "openai": 1,
    "anthropic": 2,
    "ollama": 3,
    "google_ai": 4,
    "google_vertex": 5,
    "azure": 6,
    "groq": 7,
    "together": 8,
    "vllm": 9,
    "bedrock": 10,
    "deepseek": 11,
    "xai": 12,
    "lmstudio": 13,
    "zai": 14,
    "openrouter": 15,  # Note: OpenRouter uses OpenRouterProvider, not a ProviderType enum
}
# URL path prefixes for the REST API routers.
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OLLAMA_API_PREFIX = "/v1"

View File

@@ -235,16 +235,17 @@ class LLMClientBase:
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Returns the override key for the given llm config.
For both base and BYOK providers, fetch the API key from the database.
Only fetches API key from database for BYOK providers.
Base providers use environment variables directly.
"""
api_key = None
# Fetch API key from database for both base and BYOK providers
# This ensures that base providers (from environment) also have their keys persisted and accessible
if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
# Only fetch API key from database for BYOK providers
# Base providers should always use environment variables
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
# If we got an empty string from the database (e.g., Letta provider), treat it as None
# If we got an empty string from the database, treat it as None
# so the client can fall back to environment variables or default behavior
if api_key == "":
api_key = None
@@ -254,16 +255,17 @@ class LLMClientBase:
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Returns the override key for the given llm config.
For both base and BYOK providers, fetch the API key from the database.
Only fetches API key from database for BYOK providers.
Base providers use environment variables directly.
"""
api_key = None
# Fetch API key from database for both base and BYOK providers
# This ensures that base providers (from environment) also have their keys persisted and accessible
if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
# Only fetch API key from database for BYOK providers
# Base providers should always use environment variables
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
# If we got an empty string from the database (e.g., Letta provider), treat it as None
# If we got an empty string from the database, treat it as None
# so the client can fall back to environment variables or default behavior
if api_key == "":
api_key = None

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional
from fastapi import APIRouter, Body, Depends, Query, status
from fastapi.responses import JSONResponse
from letta.schemas.enums import ProviderType
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.providers import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.validators import ProviderId
@@ -39,7 +39,14 @@ async def list_providers(
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
providers = await server.provider_manager.list_providers_async(
before=before, after=after, limit=limit, actor=actor, name=name, provider_type=provider_type, ascending=(order == "asc")
before=before,
after=after,
limit=limit,
actor=actor,
name=name,
provider_type=provider_type,
provider_category=[ProviderCategory.byok],
ascending=(order == "asc"),
)
return providers

View File

@@ -1097,6 +1097,18 @@ class SyncServer(object):
passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor)
return passage_count, document_count
def _get_provider_sort_key(self, model: LLMConfig) -> Tuple[int, str, str]:
    """Return the ordering tuple (provider priority, provider name, model name) for an LLM model.

    Providers absent from constants.PROVIDER_ORDER get priority 999 and therefore
    sort after every known provider.
    """
    priority = constants.PROVIDER_ORDER.get(model.provider_name, 999)
    name = model.provider_name or ""
    return priority, name, model.model or ""
def _get_embedding_provider_sort_key(self, model: EmbeddingConfig) -> Tuple[int, str, str]:
    """Return the ordering tuple (provider priority, provider name, model name) for an embedding model.

    EmbeddingConfig has no explicit provider field, so the provider name is
    recovered from the handle, which uses the "provider_name/model_name" format.
    Models without a parseable handle get an empty provider name and priority 999.
    """
    handle = model.handle or ""
    if "/" in handle:
        provider_name = handle.split("/", 1)[0]
    else:
        provider_name = ""
    priority = constants.PROVIDER_ORDER.get(provider_name, 999)
    return priority, provider_name, model.embedding_model or ""
@trace_method
async def list_llm_models_async(
self,
@@ -1105,86 +1117,122 @@ class SyncServer(object):
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[LLMConfig]:
"""List available LLM models from database cache"""
# Get provider IDs if filtering by provider
provider_ids = None
if provider_name or provider_type:
providers = await self.get_enabled_providers_async(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
)
provider_ids = [p.id for p in providers]
# If filtering was requested but no providers matched, return empty list
if not provider_ids:
return []
# Get models from database
provider_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="llm",
provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None,
enabled=True,
)
# Build LLMConfig objects from cached data
# Cache providers to avoid N+1 queries
provider_cache: Dict[str, Provider] = {}
"""List available LLM models - base from DB, BYOK from provider endpoints"""
llm_models = []
for model in provider_models:
# Skip if filtering by provider and model doesn't match
if provider_ids and model.provider_id not in provider_ids:
continue
# Get provider details (with caching to avoid N+1 queries)
if model.provider_id not in provider_cache:
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
provider = provider_cache[model.provider_id]
# Determine which categories to include
include_base = not provider_category or ProviderCategory.base in provider_category
include_byok = not provider_category or ProviderCategory.byok in provider_category
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url or model.model_endpoint_type,
context_window=model.max_context_window or 16384,
handle=model.handle,
provider_name=provider.name,
provider_category=provider.provider_category,
# Get base provider models from database
if include_base:
provider_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="llm",
enabled=True,
)
llm_models.append(llm_config)
# Build LLMConfig objects from database
provider_cache: Dict[str, Provider] = {}
for model in provider_models:
# Get provider details (with caching to avoid N+1 queries)
if model.provider_id not in provider_cache:
provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
provider = provider_cache[model.provider_id]
# Skip non-base providers (they're handled separately)
if provider.provider_category != ProviderCategory.base:
continue
# Apply provider_name/provider_type filters if specified
if provider_name and provider.name != provider_name:
continue
if provider_type and provider.provider_type != provider_type:
continue
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url or model.model_endpoint_type,
context_window=model.max_context_window or 16384,
handle=model.handle,
provider_name=provider.name,
provider_category=provider.provider_category,
)
llm_models.append(llm_config)
# Get BYOK provider models by hitting provider endpoints directly
if include_byok:
byok_providers = await self.provider_manager.list_providers_async(
actor=actor,
name=provider_name,
provider_type=provider_type,
provider_category=[ProviderCategory.byok],
)
for provider in byok_providers:
try:
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_llm_models_async()
llm_models.extend(models)
except Exception as e:
logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
# Sort by provider order (matching old _enabled_providers order), then by model name
llm_models.sort(key=self._get_provider_sort_key)
return llm_models
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
    """List available embedding models - base from DB, BYOK from provider endpoints.

    Base-provider models are served from the database cache; BYOK provider models
    are fetched live by calling each provider's list_embedding_models_async().
    (This block previously contained interleaved pre-change lines from a diff —
    duplicate docstrings, comments, and a duplicated embedding_dim assignment —
    which have been collapsed into the coherent post-change implementation.)

    Args:
        actor: User performing the request; scopes provider lookups.

    Returns:
        EmbeddingConfig list sorted by PROVIDER_ORDER, then provider name, then model name.
    """
    embedding_models = []
    # Get base provider models from database
    provider_models = await self.provider_manager.list_models_async(
        actor=actor,
        model_type="embedding",
        enabled=True,
    )
    # Build EmbeddingConfig objects from database (base providers only)
    # Cache providers to avoid N+1 queries
    provider_cache: Dict[str, Provider] = {}
    for model in provider_models:
        if model.provider_id not in provider_cache:
            provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
        provider = provider_cache[model.provider_id]
        # Skip non-base providers (they're handled separately below)
        if provider.provider_category != ProviderCategory.base:
            continue
        embedding_config = EmbeddingConfig(
            embedding_model=model.name,
            embedding_endpoint_type=model.model_endpoint_type,
            embedding_endpoint=provider.base_url or model.model_endpoint_type,
            embedding_dim=model.embedding_dim or 1536,  # Use model's dimension or default
            embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
            handle=model.handle,
        )
        embedding_models.append(embedding_config)
    # Get BYOK provider models by hitting provider endpoints directly
    byok_providers = await self.provider_manager.list_providers_async(
        actor=actor,
        provider_category=[ProviderCategory.byok],
    )
    for provider in byok_providers:
        try:
            typed_provider = provider.cast_to_subtype()
            models = await typed_provider.list_embedding_models_async()
            embedding_models.extend(models)
        except Exception as e:
            # Best-effort: one unreachable BYOK endpoint must not break the whole listing
            logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")
    # Sort by provider order (matching old _enabled_providers order), then by model name
    embedding_models.sort(key=self._get_embedding_provider_sort_key)
    return embedding_models
async def get_enabled_providers_async(

View File

@@ -240,6 +240,7 @@ class ProviderManager:
actor: PydanticUser,
name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
provider_category: Optional[List[ProviderCategory]] = None,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
@@ -280,7 +281,14 @@ class ProviderManager:
)
# Combine both lists
all_providers = org_providers + global_providers
all_providers = []
if not provider_category:
all_providers = org_providers + global_providers
else:
if ProviderCategory.byok in provider_category:
all_providers += org_providers
if ProviderCategory.base in provider_category:
all_providers += global_providers
# Remove deprecated api_key and access_key fields from the response
for provider in all_providers:
@@ -575,13 +583,14 @@ class ProviderManager:
continue
# Convert Provider to ProviderCreate
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
# NOTE: Do NOT store API keys for base providers in the database.
# Base providers should always use environment variables for API keys.
# This ensures keys stay in sync with env vars and aren't duplicated in DB.
provider_create = ProviderCreate(
name=provider.name,
provider_type=provider.provider_type,
api_key=api_key or "", # ProviderCreate requires api_key, use empty string if None
access_key=access_key,
api_key="", # Base providers use env vars, not DB-stored keys
access_key=None,
region=provider.region,
base_url=provider.base_url,
api_version=provider.api_version,

View File

@@ -2089,3 +2089,341 @@ async def test_get_enabled_providers_async_queries_database(default_user, provid
assert f"test-base-provider-{test_id}" in openai_names
assert f"test-byok-provider-{test_id}" not in openai_names # This is anthropic type
# =============================================================================
# BYOK Provider and Model Listing Integration Tests
# =============================================================================


@pytest.mark.asyncio
async def test_list_providers_filters_by_category(default_user, provider_manager):
    """Test that list_providers_async correctly filters by provider_category.

    Creates one provider with is_byok=False and one with is_byok=True, then
    verifies each category filter returns only the matching provider.
    """
    test_id = generate_test_id()
    # Create a base provider (is_byok=False)
    base_provider_create = ProviderCreate(
        name=f"test-base-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-base-key",
    )
    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
    # Create a BYOK provider (is_byok=True)
    byok_provider_create = ProviderCreate(
        name=f"test-byok-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-key",
    )
    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
    # Verify base provider has correct category
    assert base_provider.provider_category == ProviderCategory.base
    # Verify BYOK provider has correct category
    assert byok_provider.provider_category == ProviderCategory.byok
    # List only BYOK providers
    byok_providers = await provider_manager.list_providers_async(
        actor=default_user,
        provider_category=[ProviderCategory.byok],
    )
    byok_names = [p.name for p in byok_providers]
    assert f"test-byok-{test_id}" in byok_names
    assert f"test-base-{test_id}" not in byok_names
    # List only base providers
    base_providers = await provider_manager.list_providers_async(
        actor=default_user,
        provider_category=[ProviderCategory.base],
    )
    base_names = [p.name for p in base_providers]
    assert f"test-base-{test_id}" in base_names
    assert f"test-byok-{test_id}" not in base_names
@pytest.mark.asyncio
async def test_base_provider_api_key_not_stored_in_db(default_user, provider_manager):
    """Test that sync_base_providers does NOT store API keys for base providers."""
    # NOTE(review): this test uses a fixed provider name (no generate_test_id suffix),
    # unlike its sibling tests — presumably sync_base_providers upserts by name so
    # reruns are safe; confirm against the sync implementation.
    # Create base providers with API keys
    base_providers = [
        OpenAIProvider(name="test-openai-no-key", api_key="sk-should-not-be-stored"),
    ]
    # Sync to database
    await provider_manager.sync_base_providers(base_providers=base_providers, actor=default_user)
    # Retrieve the provider from database
    providers = await provider_manager.list_providers_async(name="test-openai-no-key", actor=default_user)
    assert len(providers) == 1
    provider = providers[0]
    assert provider.provider_category == ProviderCategory.base
    # The API key should be empty (not stored) for base providers
    if provider.api_key_enc:
        api_key = await provider.api_key_enc.get_plaintext_async()
        assert api_key == "" or api_key is None, "Base provider API key should not be stored in database"
@pytest.mark.asyncio
async def test_byok_provider_api_key_stored_in_db(default_user, provider_manager):
    """Test that BYOK providers DO have their API keys stored in the database.

    BYOK keys are user-supplied, so the database row must retain them (unlike
    base providers, whose keys come from environment variables).
    """
    test_id = generate_test_id()
    # Create a BYOK provider with API key
    byok_provider_create = ProviderCreate(
        name=f"test-byok-with-key-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-should-be-stored",
    )
    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
    # Retrieve the provider from database (re-read rather than trusting the create result)
    providers = await provider_manager.list_providers_async(name=f"test-byok-with-key-{test_id}", actor=default_user)
    assert len(providers) == 1
    provider = providers[0]
    assert provider.provider_category == ProviderCategory.byok
    # The API key SHOULD be stored for BYOK providers
    assert provider.api_key_enc is not None
    api_key = await provider.api_key_enc.get_plaintext_async()
    assert api_key == "sk-byok-should-be-stored", "BYOK provider API key should be stored in database"
@pytest.mark.asyncio
async def test_server_list_llm_models_base_from_db(default_user, provider_manager):
    """Test that server.list_llm_models_async fetches base models from database.

    Syncs one LLM model for a base provider into the DB, then verifies it shows
    up both in the unfiltered listing and when filtering by ProviderCategory.base.
    """
    from letta.server.server import SyncServer

    test_id = generate_test_id()
    # Create base provider and models (these ARE stored in DB)
    base_provider_create = ProviderCreate(
        name=f"test-base-llm-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-base-key",
    )
    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
    base_llm_model = LLMConfig(
        model=f"base-gpt-4o-{test_id}",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
        handle=f"test-base-llm-{test_id}/gpt-4o",
        provider_name=base_provider.name,
        provider_category=ProviderCategory.base,
    )
    await provider_manager.sync_provider_models_async(
        provider=base_provider,
        llm_models=[base_llm_model],
        embedding_models=[],
        organization_id=None,
    )
    # Create server and list models
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []  # Clear to test database-backed listing
    # List all models - base models come from DB
    all_models = await server.list_llm_models_async(actor=default_user)
    all_handles = [m.handle for m in all_models]
    assert f"test-base-llm-{test_id}/gpt-4o" in all_handles, "Base model should be in list"
    # List only base models
    base_models = await server.list_llm_models_async(
        actor=default_user,
        provider_category=[ProviderCategory.base],
    )
    base_handles = [m.handle for m in base_models]
    assert f"test-base-llm-{test_id}/gpt-4o" in base_handles
@pytest.mark.asyncio
async def test_server_list_llm_models_byok_from_provider_api(default_user, provider_manager):
    """Test that server.list_llm_models_async fetches BYOK models from provider API, not DB.

    Note: BYOK models are fetched by calling the provider's list_llm_models_async() method,
    which hits the actual provider API. This test uses mocking to verify that flow.
    """
    from letta.schemas.providers import Provider
    from letta.server.server import SyncServer

    test_id = generate_test_id()
    # Create a BYOK provider (but don't sync models to DB - they come from API)
    byok_provider_create = ProviderCreate(
        name=f"test-byok-llm-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-key",
    )
    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
    # Create server
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []
    # Mock the BYOK provider's list_llm_models_async to return test models
    mock_byok_models = [
        LLMConfig(
            model=f"byok-gpt-4o-{test_id}",
            model_endpoint_type="openai",
            model_endpoint="https://custom.openai.com/v1",
            context_window=64000,
            handle=f"test-byok-llm-{test_id}/gpt-4o",
            provider_name=byok_provider.name,
            provider_category=ProviderCategory.byok,
        )
    ]
    # Create a mock typed provider that returns our test models
    mock_typed_provider = MagicMock()
    mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_byok_models)
    # Patch cast_to_subtype on the Provider class to return our mock
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
        # List BYOK models - should call provider API via cast_to_subtype().list_llm_models_async()
        byok_models = await server.list_llm_models_async(
            actor=default_user,
            provider_category=[ProviderCategory.byok],
        )
    # Verify the mock was called (proving we hit provider API, not DB)
    mock_typed_provider.list_llm_models_async.assert_called()
    # Verify we got the mocked models back
    byok_handles = [m.handle for m in byok_models]
    assert f"test-byok-llm-{test_id}/gpt-4o" in byok_handles
@pytest.mark.asyncio
async def test_server_list_embedding_models_base_from_db(default_user, provider_manager):
    """Test that server.list_embedding_models_async fetches base models from database.

    Note: Similar to LLM models, base embedding models are stored in DB while BYOK
    embedding models would be fetched from provider API.
    """
    from letta.server.server import SyncServer

    test_id = generate_test_id()
    # Create base provider and embedding models (these ARE stored in DB)
    base_provider_create = ProviderCreate(
        name=f"test-base-embed-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-base-key",
    )
    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
    base_embedding_model = EmbeddingConfig(
        embedding_model=f"base-text-embedding-{test_id}",
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_dim=1536,
        embedding_chunk_size=300,
        handle=f"test-base-embed-{test_id}/text-embedding-3-small",
    )
    await provider_manager.sync_provider_models_async(
        provider=base_provider,
        llm_models=[],
        embedding_models=[base_embedding_model],
        organization_id=None,
    )
    # Create server and list models
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []
    # List all embedding models - base models come from DB
    all_models = await server.list_embedding_models_async(actor=default_user)
    all_handles = [m.handle for m in all_models]
    assert f"test-base-embed-{test_id}/text-embedding-3-small" in all_handles
@pytest.mark.asyncio
async def test_provider_ordering_matches_constants(default_user, provider_manager):
    """Test that provider ordering in model listing matches PROVIDER_ORDER in constants.

    Creates (or reuses) providers named after real provider keys with different
    PROVIDER_ORDER priorities, syncs one test model for each, then verifies
    list_llm_models_async returns them sorted by that priority.
    """
    from letta.constants import PROVIDER_ORDER
    from letta.server.server import SyncServer

    test_id = generate_test_id()
    # Create providers with different names that should have different ordering
    providers_to_create = [
        ("zai", ProviderType.zai, 14),  # Lower priority
        ("openai", ProviderType.openai, 1),  # Higher priority
        ("anthropic", ProviderType.anthropic, 2),  # Medium priority
    ]
    created_providers = []
    for name_suffix, provider_type, expected_order in providers_to_create:
        # Reuse an existing provider if one is already present; otherwise create it.
        # BUGFIX: `provider` was previously only bound inside the "not existing"
        # branch, causing a NameError (or reuse of the previous iteration's
        # provider) below whenever the provider already existed.
        existing = await provider_manager.list_providers_async(name=name_suffix, actor=default_user)
        if existing:
            provider = existing[0]
        else:
            provider_create = ProviderCreate(
                name=f"{name_suffix}",  # Use actual provider name for ordering
                provider_type=provider_type,
                api_key=f"sk-{name_suffix}-key",
            )
            provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False)
        created_providers.append((provider, expected_order))
        # Create a model for this provider
        llm_model = LLMConfig(
            model=f"test-model-{test_id}",
            model_endpoint_type="openai",
            model_endpoint="https://api.example.com/v1",
            context_window=8192,
            handle=f"{name_suffix}/test-model-{test_id}",
            provider_name=provider.name,
            provider_category=ProviderCategory.base,
        )
        await provider_manager.sync_provider_models_async(
            provider=provider,
            llm_models=[llm_model],
            embedding_models=[],
            organization_id=None,
        )
    # Verify PROVIDER_ORDER has expected values
    assert PROVIDER_ORDER.get("openai") == 1
    assert PROVIDER_ORDER.get("anthropic") == 2
    assert PROVIDER_ORDER.get("zai") == 14
    # Create server and verify ordering
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []
    # List models and check ordering
    all_models = await server.list_llm_models_async(actor=default_user)
    # Filter to only our test models
    test_models = [m for m in all_models if f"test-model-{test_id}" in m.handle]
    if len(test_models) >= 2:
        # Verify models are sorted by provider order
        provider_names_in_order = [m.provider_name for m in test_models]
        # Get the indices in PROVIDER_ORDER
        indices = [PROVIDER_ORDER.get(name, 999) for name in provider_names_in_order]
        # Verify the list is sorted by provider order
        assert indices == sorted(indices), f"Models should be sorted by PROVIDER_ORDER, got: {provider_names_in_order}"