fix: providers and models persistence (#8302)
This commit is contained in:
@@ -2089,3 +2089,341 @@ async def test_get_enabled_providers_async_queries_database(default_user, provid
|
||||
|
||||
assert f"test-base-provider-{test_id}" in openai_names
|
||||
assert f"test-byok-provider-{test_id}" not in openai_names # This is anthropic type
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BYOK Provider and Model Listing Integration Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_providers_filters_by_category(default_user, provider_manager):
    """Test that list_providers_async correctly filters by provider_category."""
    test_id = generate_test_id()

    # Set up one provider per category so each filter has something to include
    # and something to exclude.
    base_provider = await provider_manager.create_provider_async(
        ProviderCreate(
            name=f"test-base-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-base-key",
        ),
        actor=default_user,
        is_byok=False,
    )
    byok_provider = await provider_manager.create_provider_async(
        ProviderCreate(
            name=f"test-byok-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-byok-key",
        ),
        actor=default_user,
        is_byok=True,
    )

    # Categories should be assigned at creation time.
    assert base_provider.provider_category == ProviderCategory.base
    assert byok_provider.provider_category == ProviderCategory.byok

    # Filtering by BYOK must return only the BYOK provider.
    byok_names = [
        p.name
        for p in await provider_manager.list_providers_async(
            actor=default_user,
            provider_category=[ProviderCategory.byok],
        )
    ]
    assert f"test-byok-{test_id}" in byok_names
    assert f"test-base-{test_id}" not in byok_names

    # Filtering by base must return only the base provider.
    base_names = [
        p.name
        for p in await provider_manager.list_providers_async(
            actor=default_user,
            provider_category=[ProviderCategory.base],
        )
    ]
    assert f"test-base-{test_id}" in base_names
    assert f"test-byok-{test_id}" not in base_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_base_provider_api_key_not_stored_in_db(default_user, provider_manager):
    """Test that sync_base_providers does NOT store API keys for base providers."""
    # Sync a base provider that carries an API key; the key must not be persisted.
    await provider_manager.sync_base_providers(
        base_providers=[OpenAIProvider(name="test-openai-no-key", api_key="sk-should-not-be-stored")],
        actor=default_user,
    )

    # Look the provider back up from the database.
    matches = await provider_manager.list_providers_async(name="test-openai-no-key", actor=default_user)
    assert len(matches) == 1

    synced = matches[0]
    assert synced.provider_category == ProviderCategory.base

    # Base providers must not carry a stored key; an absent encrypted field is
    # also acceptable, so only decrypt when one is present.
    if synced.api_key_enc:
        plaintext = await synced.api_key_enc.get_plaintext_async()
        assert plaintext == "" or plaintext is None, "Base provider API key should not be stored in database"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_byok_provider_api_key_stored_in_db(default_user, provider_manager):
    """Test that BYOK providers DO have their API keys stored in the database."""
    test_id = generate_test_id()

    # Create a BYOK provider whose key should round-trip through the database.
    await provider_manager.create_provider_async(
        ProviderCreate(
            name=f"test-byok-with-key-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-byok-should-be-stored",
        ),
        actor=default_user,
        is_byok=True,
    )

    # Fetch it back from the database.
    matches = await provider_manager.list_providers_async(name=f"test-byok-with-key-{test_id}", actor=default_user)
    assert len(matches) == 1

    stored = matches[0]
    assert stored.provider_category == ProviderCategory.byok

    # BYOK keys must be persisted and decrypt to the original plaintext.
    assert stored.api_key_enc is not None
    plaintext = await stored.api_key_enc.get_plaintext_async()
    assert plaintext == "sk-byok-should-be-stored", "BYOK provider API key should be stored in database"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_server_list_llm_models_base_from_db(default_user, provider_manager):
    """Test that server.list_llm_models_async fetches base models from database."""
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Base providers and their models ARE persisted in the database.
    base_provider = await provider_manager.create_provider_async(
        ProviderCreate(
            name=f"test-base-llm-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-base-key",
        ),
        actor=default_user,
        is_byok=False,
    )

    base_llm_model = LLMConfig(
        model=f"base-gpt-4o-{test_id}",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
        handle=f"test-base-llm-{test_id}/gpt-4o",
        provider_name=base_provider.name,
        provider_category=ProviderCategory.base,
    )
    await provider_manager.sync_provider_models_async(
        provider=base_provider,
        llm_models=[base_llm_model],
        embedding_models=[],
        organization_id=None,
    )

    # Wire up a server with no in-memory providers, forcing DB-backed listing.
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []  # Clear to test database-backed listing

    # The base model must appear in the unfiltered listing...
    all_handles = [m.handle for m in await server.list_llm_models_async(actor=default_user)]
    assert f"test-base-llm-{test_id}/gpt-4o" in all_handles, "Base model should be in list"

    # ...and in the base-category-filtered listing.
    base_handles = [
        m.handle
        for m in await server.list_llm_models_async(
            actor=default_user,
            provider_category=[ProviderCategory.base],
        )
    ]
    assert f"test-base-llm-{test_id}/gpt-4o" in base_handles
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_server_list_llm_models_byok_from_provider_api(default_user, provider_manager):
    """Test that server.list_llm_models_async fetches BYOK models from provider API, not DB.

    Note: BYOK models are fetched by calling the provider's list_llm_models_async() method,
    which hits the actual provider API. This test uses mocking to verify that flow.
    """
    from letta.schemas.providers import Provider
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Create a BYOK provider (but don't sync models to DB - they come from API)
    byok_provider = await provider_manager.create_provider_async(
        ProviderCreate(
            name=f"test-byok-llm-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-byok-key",
        ),
        actor=default_user,
        is_byok=True,
    )

    # Server with no in-memory providers configured.
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []

    # Models the mocked provider "API" will report.
    mock_byok_models = [
        LLMConfig(
            model=f"byok-gpt-4o-{test_id}",
            model_endpoint_type="openai",
            model_endpoint="https://custom.openai.com/v1",
            context_window=64000,
            handle=f"test-byok-llm-{test_id}/gpt-4o",
            provider_name=byok_provider.name,
            provider_category=ProviderCategory.byok,
        )
    ]

    # Stand-in for the typed provider that would normally hit the real API.
    mock_typed_provider = MagicMock()
    mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_byok_models)

    # Route cast_to_subtype to the mock so listing goes through the API path.
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
        byok_models = await server.list_llm_models_async(
            actor=default_user,
            provider_category=[ProviderCategory.byok],
        )

    # The provider API (not the DB) must have been consulted...
    mock_typed_provider.list_llm_models_async.assert_called()

    # ...and its models must be what the server returned.
    byok_handles = [m.handle for m in byok_models]
    assert f"test-byok-llm-{test_id}/gpt-4o" in byok_handles
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_server_list_embedding_models_base_from_db(default_user, provider_manager):
    """Test that server.list_embedding_models_async fetches base models from database.

    Note: Similar to LLM models, base embedding models are stored in DB while BYOK
    embedding models would be fetched from provider API.
    """
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Persist a base provider plus one embedding model in the database.
    base_provider = await provider_manager.create_provider_async(
        ProviderCreate(
            name=f"test-base-embed-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-base-key",
        ),
        actor=default_user,
        is_byok=False,
    )
    await provider_manager.sync_provider_models_async(
        provider=base_provider,
        llm_models=[],
        embedding_models=[
            EmbeddingConfig(
                embedding_model=f"base-text-embedding-{test_id}",
                embedding_endpoint_type="openai",
                embedding_endpoint="https://api.openai.com/v1",
                embedding_dim=1536,
                embedding_chunk_size=300,
                handle=f"test-base-embed-{test_id}/text-embedding-3-small",
            )
        ],
        organization_id=None,
    )

    # Server with no in-memory providers, so listing must come from the DB.
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []

    # The persisted embedding model must show up in the listing.
    all_handles = [m.handle for m in await server.list_embedding_models_async(actor=default_user)]
    assert f"test-base-embed-{test_id}/text-embedding-3-small" in all_handles
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_provider_ordering_matches_constants(default_user, provider_manager):
    """Test that provider ordering in model listing matches PROVIDER_ORDER in constants.

    Creates (or reuses) three providers with known PROVIDER_ORDER ranks, attaches
    one test model to each, then verifies the server lists those models sorted by
    the constant's ordering.
    """
    from letta.constants import PROVIDER_ORDER
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Create providers with different names that should have different ordering
    providers_to_create = [
        ("zai", ProviderType.zai, 14),  # Lower priority
        ("openai", ProviderType.openai, 1),  # Higher priority
        ("anthropic", ProviderType.anthropic, 2),  # Medium priority
    ]

    created_providers = []
    for name_suffix, provider_type, expected_order in providers_to_create:
        # Reuse an existing provider if one is already registered; otherwise create
        # it. Previously, when the provider already existed, `provider` was left
        # unbound (NameError on the first iteration) or stale from the previous
        # iteration, so models were synced onto the wrong provider.
        existing = await provider_manager.list_providers_async(name=name_suffix, actor=default_user)
        if existing:
            provider = existing[0]
        else:
            provider_create = ProviderCreate(
                name=f"{name_suffix}",  # Use actual provider name for ordering
                provider_type=provider_type,
                api_key=f"sk-{name_suffix}-key",
            )
            provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False)
        created_providers.append((provider, expected_order))

        # Create a model for this provider
        llm_model = LLMConfig(
            model=f"test-model-{test_id}",
            model_endpoint_type="openai",
            model_endpoint="https://api.example.com/v1",
            context_window=8192,
            handle=f"{name_suffix}/test-model-{test_id}",
            provider_name=provider.name,
            provider_category=ProviderCategory.base,
        )
        await provider_manager.sync_provider_models_async(
            provider=provider,
            llm_models=[llm_model],
            embedding_models=[],
            organization_id=None,
        )

    # Verify PROVIDER_ORDER has expected values
    assert PROVIDER_ORDER.get("openai") == 1
    assert PROVIDER_ORDER.get("anthropic") == 2
    assert PROVIDER_ORDER.get("zai") == 14

    # Create server and verify ordering
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []

    # List models and check ordering
    all_models = await server.list_llm_models_async(actor=default_user)

    # Filter to only our test models
    test_models = [m for m in all_models if f"test-model-{test_id}" in m.handle]

    if len(test_models) >= 2:
        # Verify models are sorted by provider order
        provider_names_in_order = [m.provider_name for m in test_models]

        # Map provider names to their PROVIDER_ORDER rank (unknown names sort last)
        indices = [PROVIDER_ORDER.get(name, 999) for name in provider_names_in_order]

        # Verify the list is sorted by provider order
        assert indices == sorted(indices), f"Models should be sorted by PROVIDER_ORDER, got: {provider_names_in_order}"
|
||||
|
||||
Reference in New Issue
Block a user