feat: enable provider models persistence (#6193)

* Revert "fix test"

This reverts commit 5126815f23cefb4edad3e3bf9e7083209dcc7bf1.

* fix server and better test

* test fix, get api key for base and byok?

* set letta default endpoint

* try to fix timeout for test

* fix for letta api key

* Delete apps/core/tests/sdk_v1/conftest.py

* Update utils.py

* clean up a few issues

* fix filtering on list_llm_models

* soft delete models with provider

* add one more test

* fix ci

* add timeout

* band aid for letta embedding provider

* info instead of error logs when creating models
This commit is contained in:
Ari Webb
2025-12-09 14:33:06 -08:00
committed by Caren Thomas
parent b4af037c19
commit 848a73125c
8 changed files with 754 additions and 205 deletions

View File

@@ -499,3 +499,436 @@ async def test_byok_provider_auto_syncs_models(provider_manager, default_user, m
llm_config = await provider_manager.get_llm_config_from_handle(handle="my-openai-key/gpt-4o", actor=default_user)
assert llm_config.model == "gpt-4o"
assert llm_config.provider_name == "my-openai-key"
# ======================================================================================================================
# Server Startup Provider Sync Tests
# ======================================================================================================================
@pytest.mark.asyncio
async def test_server_startup_syncs_base_providers(default_user, default_organization, monkeypatch):
    """Simulate server startup and verify base-provider model sync.

    Verifies that:
    1. Base providers from environment variables are synced to the database.
    2. Provider models are fetched from the (mocked) API endpoints.
    3. Models are persisted with the expected metadata.
    4. Models can be resolved back through their handles.
    """
    from unittest.mock import AsyncMock, MagicMock

    from letta.schemas.embedding_config import EmbeddingConfig
    from letta.schemas.llm_config import LLMConfig
    from letta.schemas.providers import AnthropicProvider, OpenAIProvider
    from letta.server.server import SyncServer

    # Canned OpenAI /models payload. gpt-4-vision is expected to be dropped
    # by the OpenAI provider's disallowed-keyword filtering.
    fake_openai_payload = {
        "data": [
            {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai", "max_model_len": 8192},
            {"id": "gpt-4-turbo", "object": "model", "created": 1712361441, "owned_by": "system", "max_model_len": 128000},
            {"id": "text-embedding-ada-002", "object": "model", "created": 1671217299, "owned_by": "openai-internal"},
            {"id": "gpt-4-vision", "object": "model", "created": 1698959748, "owned_by": "system", "max_model_len": 8192},
        ]
    }

    # Canned Anthropic models payload.
    fake_anthropic_payload = {
        "data": [
            {
                "id": "claude-3-5-sonnet-20241022",
                "type": "model",
                "display_name": "Claude 3.5 Sonnet",
                "created_at": "2024-10-22T00:00:00Z",
            },
            {
                "id": "claude-3-opus-20240229",
                "type": "model",
                "display_name": "Claude 3 Opus",
                "created_at": "2024-02-29T00:00:00Z",
            },
        ]
    }

    async def fake_openai_model_list(*args, **kwargs):
        return fake_openai_payload

    # Stub Anthropic client whose models.list() returns a MagicMock that
    # model_dump()s into the canned payload above.
    anthropic_response = MagicMock()
    anthropic_response.model_dump.return_value = fake_anthropic_payload

    class _StubAnthropicModels:
        async def list(self):
            return anthropic_response

    class _StubAsyncAnthropic:
        def __init__(self, *args, **kwargs):
            self.models = _StubAnthropicModels()

    # Patch the real network entry points.
    monkeypatch.setattr("letta.llm_api.openai.openai_get_model_list_async", fake_openai_model_list)
    monkeypatch.setattr("anthropic.AsyncAnthropic", _StubAsyncAnthropic)

    # Start from a clean slate: drop every provider-related env var...
    for env_var in (
        "OPENAI_API_KEY",
        "ANTHROPIC_API_KEY",
        "GEMINI_API_KEY",
        "GOOGLE_CLOUD_PROJECT",
        "GOOGLE_CLOUD_LOCATION",
        "AZURE_API_KEY",
        "GROQ_API_KEY",
        "TOGETHER_API_KEY",
        "VLLM_API_BASE",
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "LMSTUDIO_BASE_URL",
        "DEEPSEEK_API_KEY",
        "XAI_API_KEY",
        "OPENROUTER_API_KEY",
    ):
        monkeypatch.delenv(env_var, raising=False)

    # ...then enable exactly OpenAI and Anthropic.
    monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key-12345")
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key-67890")

    # model_settings was populated at import time, so mirror the env state
    # onto the already-loaded settings object.
    from letta.settings import model_settings

    monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key-12345")
    monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key-67890")
    for settings_attr in (
        "gemini_api_key",
        "google_cloud_project",
        "google_cloud_location",
        "azure_api_key",
        "groq_api_key",
        "together_api_key",
        "vllm_api_base",
        "aws_access_key_id",
        "aws_secret_access_key",
        "lmstudio_base_url",
        "deepseek_api_key",
        "xai_api_key",
        "openrouter_api_key",
    ):
        monkeypatch.setattr(model_settings, settings_attr, None)

    # Build the server (loads enabled providers from the environment) and
    # wire in the test org/user since auto-init is disabled.
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.default_org = default_organization

    # Exactly letta plus the two providers we enabled should be present.
    assert len(server._enabled_providers) == 3  # Exactly: letta, openai, anthropic
    enabled_names = [p.name for p in server._enabled_providers]
    assert "letta" in enabled_names
    assert "openai" in enabled_names
    assert "anthropic" in enabled_names

    # Mirror init_async: persist base providers, then sync their models
    # (the same calls server startup performs).
    await server.provider_manager.sync_base_providers(
        base_providers=server._enabled_providers,
        actor=default_user,
    )
    await server._sync_provider_models_async()

    # --- OpenAI assertions -------------------------------------------------
    openai_providers = await server.provider_manager.list_providers_async(name="openai", actor=default_user)
    assert len(openai_providers) == 1, "OpenAI provider should exist"
    openai_provider = openai_providers[0]

    openai_llm_models = await server.provider_manager.list_models_async(
        actor=default_user,
        provider_id=openai_provider.id,
        model_type="llm",
    )
    # gpt-4 and gpt-4-turbo survive; gpt-4-vision is filtered by keyword.
    assert len(openai_llm_models) >= 2, f"Expected at least 2 OpenAI LLM models, got {len(openai_llm_models)}"
    openai_llm_names = [m.name for m in openai_llm_models]
    assert "gpt-4" in openai_llm_names
    assert "gpt-4-turbo" in openai_llm_names

    openai_embedding_models = await server.provider_manager.list_models_async(
        actor=default_user,
        provider_id=openai_provider.id,
        model_type="embedding",
    )
    assert len(openai_embedding_models) >= 1, "Expected at least 1 OpenAI embedding model"
    assert "text-embedding-ada-002" in [m.name for m in openai_embedding_models]

    # Spot-check the persisted metadata on gpt-4.
    matching_gpt4 = [m for m in openai_llm_models if m.name == "gpt-4"]
    assert len(matching_gpt4) > 0, "gpt-4 model should exist"
    gpt4 = matching_gpt4[0]
    assert gpt4.handle == "openai/gpt-4"
    assert gpt4.model_endpoint_type == "openai"
    assert gpt4.max_context_window == 8192
    assert gpt4.enabled is True

    # --- Anthropic assertions ---------------------------------------------
    anthropic_providers = await server.provider_manager.list_providers_async(name="anthropic", actor=default_user)
    assert len(anthropic_providers) == 1, "Anthropic provider should exist"
    anthropic_provider = anthropic_providers[0]

    anthropic_llm_models = await server.provider_manager.list_models_async(
        actor=default_user,
        provider_id=anthropic_provider.id,
        model_type="llm",
    )
    assert len(anthropic_llm_models) >= 2, f"Expected at least 2 Anthropic models, got {len(anthropic_llm_models)}"
    anthropic_names = [m.name for m in anthropic_llm_models]
    assert "claude-3-5-sonnet-20241022" in anthropic_names
    assert "claude-3-opus-20240229" in anthropic_names

    # --- Handle round-trips ------------------------------------------------
    llm_config = await server.provider_manager.get_llm_config_from_handle(
        handle="openai/gpt-4",
        actor=default_user,
    )
    assert llm_config.model == "gpt-4"
    assert llm_config.handle == "openai/gpt-4"
    assert llm_config.provider_name == "openai"
    assert llm_config.context_window == 8192

    embedding_config = await server.provider_manager.get_embedding_config_from_handle(
        handle="openai/text-embedding-ada-002",
        actor=default_user,
    )
    assert embedding_config.embedding_model == "text-embedding-ada-002"
    assert embedding_config.handle == "openai/text-embedding-ada-002"
    assert embedding_config.embedding_dim == 1536
@pytest.mark.asyncio
async def test_server_startup_handles_disabled_providers(default_user, default_organization, monkeypatch):
    """Test that server startup properly handles providers that are no longer enabled.

    This test verifies that:
    1. Base providers that are no longer enabled (env vars removed) are deleted
    2. BYOK providers that are no longer enabled are NOT deleted (user-created)
    3. The sync process handles providers gracefully when API calls fail
    """
    from letta.schemas.providers import OpenAIProvider, ProviderCreate
    from letta.server.server import SyncServer

    # First, manually create providers in the database.
    provider_manager = ProviderManager()

    # Create a base OpenAI provider (simulating it was synced before).
    base_openai_create = ProviderCreate(
        name="openai",
        provider_type=ProviderType.openai,
        api_key="sk-old-key",
        base_url="https://api.openai.com/v1",
    )
    base_openai = await provider_manager.create_provider_async(
        base_openai_create,
        actor=default_user,
        is_byok=False,  # This is a base provider
    )

    # Create a BYOK provider (user-created).
    byok_provider_create = ProviderCreate(
        name="my-custom-openai",
        provider_type=ProviderType.openai,
        api_key="sk-my-key",
        base_url="https://api.openai.com/v1",
    )
    byok_provider = await provider_manager.create_provider_async(
        byok_provider_create,
        actor=default_user,
        is_byok=True,
    )
    assert byok_provider.provider_category == ProviderCategory.byok

    # Now create the server with NO environment variables set
    # (all base providers disabled).
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)

    from letta.settings import model_settings

    monkeypatch.setattr(model_settings, "openai_api_key", None)
    monkeypatch.setattr(model_settings, "anthropic_api_key", None)

    # Create the server instance and wire in the test org/user
    # (auto-init is disabled).
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.default_org = default_organization

    # Verify only the letta provider is enabled (no openai).
    enabled_names = [p.name for p in server._enabled_providers]
    assert "letta" in enabled_names
    assert "openai" not in enabled_names

    # Sync base providers (should not include openai anymore).
    await server.provider_manager.sync_base_providers(
        base_providers=server._enabled_providers,
        actor=default_user,
    )

    # Call _sync_provider_models_async (what server startup does).
    await server._sync_provider_models_async()

    # Verify the base OpenAI provider was deleted (no longer enabled).
    # BUGFIX: the previous try/except pattern caught Exception around its own
    # `assert False`, so the raised AssertionError was swallowed and this
    # check could never fail. pytest.raises asserts the lookup really raises.
    with pytest.raises(Exception):
        await server.provider_manager.get_provider_async(base_openai.id, actor=default_user)

    # Verify the BYOK provider still exists (should NOT be deleted).
    byok_still_exists = await server.provider_manager.get_provider_async(
        byok_provider.id,
        actor=default_user,
    )
    assert byok_still_exists is not None
    assert byok_still_exists.name == "my-custom-openai"
    assert byok_still_exists.provider_category == ProviderCategory.byok
@pytest.mark.asyncio
async def test_server_startup_handles_api_errors_gracefully(default_user, default_organization, monkeypatch):
    """Test that server startup handles API errors gracefully without crashing.

    This test verifies that:
    1. If a provider's API call fails during sync, it logs an error but continues
    2. Other providers can still sync successfully
    3. The server startup completes without crashing
    """
    from unittest.mock import MagicMock

    from letta.schemas.providers import AnthropicProvider, OpenAIProvider
    from letta.server.server import SyncServer

    # OpenAI side: every model-list call blows up.
    async def failing_openai_model_list(*args, **kwargs):
        raise Exception("OpenAI API is down")

    # Anthropic side: a working stub client returning a single model.
    stub_response = MagicMock()
    stub_response.model_dump.return_value = {
        "data": [
            {
                "id": "claude-3-5-sonnet-20241022",
                "type": "model",
                "display_name": "Claude 3.5 Sonnet",
                "created_at": "2024-10-22T00:00:00Z",
            }
        ]
    }

    class _StubAnthropicModels:
        async def list(self):
            return stub_response

    class _StubAsyncAnthropic:
        def __init__(self, *args, **kwargs):
            self.models = _StubAnthropicModels()

    monkeypatch.setattr("letta.llm_api.openai.openai_get_model_list_async", failing_openai_model_list)
    monkeypatch.setattr("anthropic.AsyncAnthropic", _StubAsyncAnthropic)

    # Enable both providers via env vars and the loaded settings object.
    monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key")
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key")

    from letta.settings import model_settings

    monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key")
    monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key")

    # Build the server and wire in the test identities.
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.default_org = default_organization

    await server.provider_manager.sync_base_providers(
        base_providers=server._enabled_providers,
        actor=default_user,
    )

    # The whole point: this must NOT raise even though OpenAI fails.
    await server._sync_provider_models_async()

    # Anthropic should have synced despite the OpenAI failure.
    anthropic_providers = await server.provider_manager.list_providers_async(
        name="anthropic",
        actor=default_user,
    )
    assert len(anthropic_providers) == 1
    anthropic_models = await server.provider_manager.list_models_async(
        actor=default_user,
        provider_id=anthropic_providers[0].id,
        model_type="llm",
    )
    assert len(anthropic_models) >= 1, "Anthropic models should have synced despite OpenAI failure"

    # OpenAI's sync failed, so it should have no freshly synced models.
    # Models might exist from previous runs, so no hard assertion here —
    # the key check is that the server did not crash above.
    openai_providers = await server.provider_manager.list_providers_async(
        name="openai",
        actor=default_user,
    )
    if len(openai_providers) > 0:
        await server.provider_manager.list_models_async(
            actor=default_user,
            provider_id=openai_providers[0].id,
        )

View File

@@ -32,6 +32,7 @@ from letta.config import LettaConfig
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.server.server import SyncServer
from tests.helpers.utils import upload_file_and_wait
from tests.utils import wait_for_server
# Constants
SERVER_PORT = 8283
@@ -106,7 +107,7 @@ def client() -> LettaSDKClient:
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
wait_for_server(server_url, timeout=60)
print("Running client tests with server:", server_url)
client = LettaSDKClient(base_url=server_url)

View File

@@ -1740,3 +1740,110 @@ async def test_handle_uniqueness_per_org(default_user, provider_manager):
assert model is not None
assert model.provider_id == provider_1.id # Still original provider
assert model.max_context_window == 8192 # Still original
@pytest.mark.asyncio
async def test_delete_provider_cascades_to_models(default_user, provider_manager, monkeypatch):
    """Test that deleting a provider also soft-deletes its associated models."""
    test_id = generate_test_id()

    # Stub out the default-model sync so no external API is hit; the models
    # are created manually below instead.
    async def noop_sync(provider, actor):
        pass

    monkeypatch.setattr(provider_manager, "_sync_default_models_for_provider", noop_sync)

    # 1. Create a BYOK provider (org-scoped, so the actor can delete it).
    provider = await provider_manager.create_provider_async(
        ProviderCreate(
            name=f"test-cascade-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-test-key",
        ),
        actor=default_user,
        is_byok=True,
    )

    # 2. Manually attach two LLM models and one embedding model.
    llm_specs = [
        (f"gpt-4o-{test_id}", 128000, f"test-{test_id}/gpt-4o"),
        (f"gpt-4o-mini-{test_id}", 16384, f"test-{test_id}/gpt-4o-mini"),
    ]
    llm_models = [
        LLMConfig(
            model=model_name,
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=window,
            handle=handle,
            provider_name=provider.name,
            provider_category=ProviderCategory.byok,
        )
        for model_name, window, handle in llm_specs
    ]
    embedding_models = [
        EmbeddingConfig(
            embedding_model=f"text-embedding-3-small-{test_id}",
            embedding_endpoint_type="openai",
            embedding_endpoint="https://api.openai.com/v1",
            embedding_dim=1536,
            embedding_chunk_size=300,
            handle=f"test-{test_id}/text-embedding-3-small",
        ),
    ]
    await provider_manager.sync_provider_models_async(
        provider=provider,
        llm_models=llm_models,
        embedding_models=embedding_models,
        organization_id=default_user.organization_id,  # Org-scoped for BYOK provider
    )

    async def _handles(model_type, provider_id=None):
        # Collect handles for the given model type, optionally scoped to
        # one provider (provider_id is only passed when explicitly given,
        # matching the original call shapes).
        kwargs = {"actor": default_user, "model_type": model_type}
        if provider_id is not None:
            kwargs["provider_id"] = provider_id
        return {m.handle for m in await provider_manager.list_models_async(**kwargs)}

    # 3. The models must be visible before deletion.
    llm_handles_before = await _handles("llm", provider.id)
    embedding_handles_before = await _handles("embedding", provider.id)
    assert f"test-{test_id}/gpt-4o" in llm_handles_before
    assert f"test-{test_id}/gpt-4o-mini" in llm_handles_before
    assert f"test-{test_id}/text-embedding-3-small" in embedding_handles_before

    # 4. Delete the provider.
    await provider_manager.delete_provider_by_id_async(provider.id, actor=default_user)

    # 5. The provider's models should no longer be listed (soft-deleted).
    llm_handles_after = await _handles("llm")
    embedding_handles_after = await _handles("embedding")
    assert f"test-{test_id}/gpt-4o" not in llm_handles_after
    assert f"test-{test_id}/gpt-4o-mini" not in llm_handles_after
    assert f"test-{test_id}/text-embedding-3-small" not in embedding_handles_after

    # 6. The provider itself should be gone too.
    providers_after = await provider_manager.list_providers_async(
        actor=default_user,
        name=f"test-cascade-{test_id}",
    )
    assert len(providers_after) == 0