fix: fix providers and models persistence (#8302)

This commit is contained in:
Ari Webb
2026-01-05 18:05:44 -08:00
committed by Caren Thomas
parent e56c5c5b49
commit 02f3e3f3b9
6 changed files with 491 additions and 67 deletions

View File

@@ -2089,3 +2089,341 @@ async def test_get_enabled_providers_async_queries_database(default_user, provid
assert f"test-base-provider-{test_id}" in openai_names
assert f"test-byok-provider-{test_id}" not in openai_names # This is anthropic type
# =============================================================================
# BYOK Provider and Model Listing Integration Tests
# =============================================================================
@pytest.mark.asyncio
async def test_list_providers_filters_by_category(default_user, provider_manager):
    """Test that list_providers_async correctly filters by provider_category."""
    test_id = generate_test_id()

    # Persist one provider of each category so both filters have a hit and a miss.
    base_create = ProviderCreate(
        name=f"test-base-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-base-key",
    )
    base = await provider_manager.create_provider_async(base_create, actor=default_user, is_byok=False)

    byok_create = ProviderCreate(
        name=f"test-byok-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-key",
    )
    byok = await provider_manager.create_provider_async(byok_create, actor=default_user, is_byok=True)

    # Categories should reflect the is_byok flag used at creation time.
    assert base.provider_category == ProviderCategory.base
    assert byok.provider_category == ProviderCategory.byok

    # Filtering on the byok category must include only the BYOK provider.
    byok_only = await provider_manager.list_providers_async(
        actor=default_user,
        provider_category=[ProviderCategory.byok],
    )
    listed = [p.name for p in byok_only]
    assert f"test-byok-{test_id}" in listed
    assert f"test-base-{test_id}" not in listed

    # Filtering on the base category must include only the base provider.
    base_only = await provider_manager.list_providers_async(
        actor=default_user,
        provider_category=[ProviderCategory.base],
    )
    listed = [p.name for p in base_only]
    assert f"test-base-{test_id}" in listed
    assert f"test-byok-{test_id}" not in listed
@pytest.mark.asyncio
async def test_base_provider_api_key_not_stored_in_db(default_user, provider_manager):
    """Test that sync_base_providers does NOT store API keys for base providers."""
    # Sync a base provider that carries an API key in memory only.
    await provider_manager.sync_base_providers(
        base_providers=[OpenAIProvider(name="test-openai-no-key", api_key="sk-should-not-be-stored")],
        actor=default_user,
    )

    # Read it back from the database.
    found = await provider_manager.list_providers_async(name="test-openai-no-key", actor=default_user)
    assert len(found) == 1
    stored = found[0]
    assert stored.provider_category == ProviderCategory.base

    # Base providers must not persist their key; if an encrypted field exists
    # at all, its plaintext must be empty.
    if stored.api_key_enc:
        plaintext = await stored.api_key_enc.get_plaintext_async()
        assert plaintext == "" or plaintext is None, "Base provider API key should not be stored in database"
@pytest.mark.asyncio
async def test_byok_provider_api_key_stored_in_db(default_user, provider_manager):
    """Test that BYOK providers DO have their API keys stored in the database."""
    test_id = generate_test_id()

    # Create a BYOK provider that carries an API key.
    create_req = ProviderCreate(
        name=f"test-byok-with-key-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-should-be-stored",
    )
    await provider_manager.create_provider_async(create_req, actor=default_user, is_byok=True)

    # Read the provider back and confirm the key round-trips through encryption.
    found = await provider_manager.list_providers_async(name=f"test-byok-with-key-{test_id}", actor=default_user)
    assert len(found) == 1
    stored = found[0]
    assert stored.provider_category == ProviderCategory.byok
    assert stored.api_key_enc is not None
    plaintext = await stored.api_key_enc.get_plaintext_async()
    assert plaintext == "sk-byok-should-be-stored", "BYOK provider API key should be stored in database"
@pytest.mark.asyncio
async def test_server_list_llm_models_base_from_db(default_user, provider_manager):
    """Test that server.list_llm_models_async fetches base models from database."""
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Persist a base provider plus one LLM model (base models live in the DB).
    create_req = ProviderCreate(
        name=f"test-base-llm-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-base-key",
    )
    provider = await provider_manager.create_provider_async(create_req, actor=default_user, is_byok=False)
    model = LLMConfig(
        model=f"base-gpt-4o-{test_id}",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
        handle=f"test-base-llm-{test_id}/gpt-4o",
        provider_name=provider.name,
        provider_category=ProviderCategory.base,
    )
    await provider_manager.sync_provider_models_async(
        provider=provider,
        llm_models=[model],
        embedding_models=[],
        organization_id=None,
    )

    # Stand up a server wired to the same manager, with env-configured
    # providers cleared so listing is backed purely by the database.
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []  # Clear to test database-backed listing

    # Unfiltered listing must surface the DB-backed base model.
    handles = [m.handle for m in await server.list_llm_models_async(actor=default_user)]
    assert f"test-base-llm-{test_id}/gpt-4o" in handles, "Base model should be in list"

    # Filtering on the base category must also include it.
    filtered = await server.list_llm_models_async(
        actor=default_user,
        provider_category=[ProviderCategory.base],
    )
    assert f"test-base-llm-{test_id}/gpt-4o" in [m.handle for m in filtered]
@pytest.mark.asyncio
async def test_server_list_llm_models_byok_from_provider_api(default_user, provider_manager):
    """Test that server.list_llm_models_async fetches BYOK models from provider API, not DB.

    Note: BYOK models are fetched by calling the provider's list_llm_models_async() method,
    which hits the actual provider API. This test uses mocking to verify that flow.
    """
    from letta.schemas.providers import Provider
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Persist a BYOK provider only — its models are never synced to the DB.
    create_req = ProviderCreate(
        name=f"test-byok-llm-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-key",
    )
    byok_provider = await provider_manager.create_provider_async(create_req, actor=default_user, is_byok=True)

    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []

    # Canned models the mocked provider "API" will return.
    fake_models = [
        LLMConfig(
            model=f"byok-gpt-4o-{test_id}",
            model_endpoint_type="openai",
            model_endpoint="https://custom.openai.com/v1",
            context_window=64000,
            handle=f"test-byok-llm-{test_id}/gpt-4o",
            provider_name=byok_provider.name,
            provider_category=ProviderCategory.byok,
        )
    ]
    fake_typed = MagicMock()
    fake_typed.list_llm_models_async = AsyncMock(return_value=fake_models)

    # Intercept cast_to_subtype so the listing path talks to our fake provider
    # client instead of hitting a real provider API.
    with patch.object(Provider, "cast_to_subtype", return_value=fake_typed):
        listed = await server.list_llm_models_async(
            actor=default_user,
            provider_category=[ProviderCategory.byok],
        )
        # The provider API path was exercised (proving we did not read the DB).
        fake_typed.list_llm_models_async.assert_called()
        # And the mocked models came back through the server.
        assert f"test-byok-llm-{test_id}/gpt-4o" in [m.handle for m in listed]
@pytest.mark.asyncio
async def test_server_list_embedding_models_base_from_db(default_user, provider_manager):
    """Test that server.list_embedding_models_async fetches base models from database.

    Note: Similar to LLM models, base embedding models are stored in DB while BYOK
    embedding models would be fetched from provider API.
    """
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Persist a base provider with a single embedding model in the DB.
    create_req = ProviderCreate(
        name=f"test-base-embed-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-base-key",
    )
    provider = await provider_manager.create_provider_async(create_req, actor=default_user, is_byok=False)
    embed_model = EmbeddingConfig(
        embedding_model=f"base-text-embedding-{test_id}",
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_dim=1536,
        embedding_chunk_size=300,
        handle=f"test-base-embed-{test_id}/text-embedding-3-small",
    )
    await provider_manager.sync_provider_models_async(
        provider=provider,
        llm_models=[],
        embedding_models=[embed_model],
        organization_id=None,
    )

    # Server wired to the same manager, env providers cleared so listing is
    # purely database-backed.
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []

    # The DB-backed embedding model must appear in the unfiltered listing.
    handles = [m.handle for m in await server.list_embedding_models_async(actor=default_user)]
    assert f"test-base-embed-{test_id}/text-embedding-3-small" in handles
@pytest.mark.asyncio
async def test_provider_ordering_matches_constants(default_user, provider_manager):
    """Test that provider ordering in model listing matches PROVIDER_ORDER in constants."""
    from letta.constants import PROVIDER_ORDER
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # (name, type, expected PROVIDER_ORDER rank) — deliberately listed out of
    # rank order so the sort assertion below is meaningful.
    providers_to_create = [
        ("zai", ProviderType.zai, 14),  # Lower priority
        ("openai", ProviderType.openai, 1),  # Higher priority
        ("anthropic", ProviderType.anthropic, 2),  # Medium priority
    ]
    created_providers = []
    for name_suffix, provider_type, expected_order in providers_to_create:
        # Reuse an existing provider with this name, otherwise create it.
        # BUGFIX: previously `provider` was only assigned inside the
        # not-existing branch, so a pre-existing provider left `provider`
        # unbound (UnboundLocalError) or stale from a prior iteration when
        # the model below was created.
        existing = await provider_manager.list_providers_async(name=name_suffix, actor=default_user)
        if existing:
            provider = existing[0]
        else:
            provider_create = ProviderCreate(
                name=f"{name_suffix}",  # Use actual provider name for ordering
                provider_type=provider_type,
                api_key=f"sk-{name_suffix}-key",
            )
            provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False)
        created_providers.append((provider, expected_order))

        # Attach one test model to this provider so it shows up in listings.
        llm_model = LLMConfig(
            model=f"test-model-{test_id}",
            model_endpoint_type="openai",
            model_endpoint="https://api.example.com/v1",
            context_window=8192,
            handle=f"{name_suffix}/test-model-{test_id}",
            provider_name=provider.name,
            provider_category=ProviderCategory.base,
        )
        await provider_manager.sync_provider_models_async(
            provider=provider,
            llm_models=[llm_model],
            embedding_models=[],
            organization_id=None,
        )

    # Sanity-check the constants this test depends on.
    assert PROVIDER_ORDER.get("openai") == 1
    assert PROVIDER_ORDER.get("anthropic") == 2
    assert PROVIDER_ORDER.get("zai") == 14

    # Server wired to the same manager; env providers cleared so only the
    # DB-backed models above are listed.
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []

    all_models = await server.list_llm_models_async(actor=default_user)
    # Restrict to the models created by this test run.
    test_models = [m for m in all_models if f"test-model-{test_id}" in m.handle]
    if len(test_models) >= 2:
        provider_names_in_order = [m.provider_name for m in test_models]
        # Map names to their configured rank (unknown names sort last) and
        # require the listing to already be in that order.
        indices = [PROVIDER_ORDER.get(name, 999) for name in provider_names_in_order]
        assert indices == sorted(indices), f"Models should be sorted by PROVIDER_ORDER, got: {provider_names_in_order}"