# Source: letta-server/tests/managers/test_provider_manager.py
# Kian Jones 25d54dd896 chore: enable F821, F401, W293 (#9503)
# * auto fixes
# * auto fix pt2 and transitive deps and undefined var checking locals()
# * manual fixes (ignored or letta-code fixed)
# * fix circular import
# 2026-02-24 10:55:08 -08:00
"""Tests for ProviderManager encryption/decryption logic."""
import pytest
from letta.orm.provider import Provider as ProviderModel
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.providers import ProviderCreate, ProviderUpdate
from letta.schemas.secret import Secret
from letta.server.db import db_registry
from letta.services.organization_manager import OrganizationManager
from letta.services.provider_manager import ProviderManager
from letta.services.user_manager import UserManager
from letta.settings import settings
@pytest.fixture
async def default_organization():
    """Yield the default organization, created fresh via OrganizationManager."""
    org_manager = OrganizationManager()
    default_org = await org_manager.create_default_organization_async()
    yield default_org
@pytest.fixture
async def default_user(default_organization):
    """Yield the default actor, created inside the default organization."""
    user_manager = UserManager()
    actor = await user_manager.create_default_actor_async(org_id=default_organization.id)
    yield actor
@pytest.fixture
async def provider_manager():
    """Return a fresh ProviderManager instance for each test."""
    manager = ProviderManager()
    return manager
@pytest.fixture
def encryption_key():
    """Ensure an encryption key is configured; restore the original on teardown."""
    saved_key = settings.encryption_key
    if not saved_key:
        # NOTE(review): despite the name, this literal is 30 characters, not 32 —
        # presumably acceptable to the Secret implementation; confirm if length matters.
        settings.encryption_key = "test-encryption-key-32-bytes!!"
    yield settings.encryption_key
    # Put back whatever was configured before the test ran.
    settings.encryption_key = saved_key
# ======================================================================================================================
# Provider Encryption Tests
# ======================================================================================================================
@pytest.mark.asyncio
async def test_provider_create_encrypts_api_key(provider_manager, default_user, encryption_key):
    """Test that creating a provider encrypts the api_key and stores it in api_key_enc."""
    # Create a provider with plaintext api_key
    provider_create = ProviderCreate(
        name="test-openai-provider",
        provider_type=ProviderType.openai,
        api_key="sk-test-plaintext-api-key-12345",
        base_url="https://api.openai.com/v1",
    )
    # Create provider through manager
    created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
    # Verify provider was created
    assert created_provider is not None
    assert created_provider.name == "test-openai-provider"
    assert created_provider.provider_type == ProviderType.openai
    # Verify encrypted api_key can be decrypted back to the original plaintext
    assert created_provider.api_key_enc.get_plaintext() == "sk-test-plaintext-api-key-12345"
    # Read directly from database to verify encryption at the ORM layer
    async with db_registry.async_session() as session:
        provider_orm = await ProviderModel.read_async(
            db_session=session,
            identifier=created_provider.id,
            actor=default_user,
        )
        # Verify encrypted column is populated and decrypts correctly
        assert provider_orm.api_key_enc is not None
        decrypted = Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext()
        assert decrypted == "sk-test-plaintext-api-key-12345"
@pytest.mark.asyncio
async def test_provider_read_decrypts_api_key(provider_manager, default_user, encryption_key):
    """Test that reading a provider decrypts the api_key from api_key_enc."""
    # Create a provider
    provider_create = ProviderCreate(
        name="test-anthropic-provider",
        provider_type=ProviderType.anthropic,
        api_key="sk-ant-test-key-67890",
    )
    created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
    provider_id = created_provider.id
    # Read the provider back
    retrieved_provider = await provider_manager.get_provider_async(provider_id, actor=default_user)
    # Verify the api_key is decrypted correctly via api_key_enc
    decrypted_key = retrieved_provider.api_key_enc.get_plaintext()
    assert decrypted_key == "sk-ant-test-key-67890"
@pytest.mark.asyncio
async def test_provider_update_encrypts_new_api_key(provider_manager, default_user, encryption_key):
    """Test that updating a provider's api_key encrypts the new value."""
    # Create initial provider
    provider_create = ProviderCreate(
        name="test-groq-provider",
        provider_type=ProviderType.groq,
        api_key="gsk-initial-key-123",
    )
    created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
    provider_id = created_provider.id
    # Update the api_key
    provider_update = ProviderUpdate(
        api_key="gsk-updated-key-456",
    )
    updated_provider = await provider_manager.update_provider_async(provider_id, provider_update, actor=default_user)
    # Verify the updated key is accessible via the encrypted field
    assert updated_provider.api_key_enc.get_plaintext() == "gsk-updated-key-456"
    # Read from DB to verify the new encrypted value replaced the old one
    async with db_registry.async_session() as session:
        provider_orm = await ProviderModel.read_async(
            db_session=session,
            identifier=provider_id,
            actor=default_user,
        )
        assert provider_orm.api_key_enc is not None
        decrypted = Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext()
        assert decrypted == "gsk-updated-key-456"
@pytest.mark.asyncio
async def test_bedrock_credentials_encryption(provider_manager, default_user, encryption_key):
    """Test that Bedrock provider encrypts both access_key and api_key (secret_key)."""
    # Create Bedrock provider with both keys
    provider_create = ProviderCreate(
        name="test-bedrock-provider",
        provider_type=ProviderType.bedrock,
        api_key="secret-access-key-xyz",  # This is the secret key
        access_key="access-key-id-abc",  # This is the access key ID
        region="us-east-1",
    )
    created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
    # Verify both keys are accessible via encrypted fields
    assert created_provider.api_key_enc.get_plaintext() == "secret-access-key-xyz"
    assert created_provider.access_key_enc.get_plaintext() == "access-key-id-abc"
    # Read from DB to verify both are encrypted
    async with db_registry.async_session() as session:
        provider_orm = await ProviderModel.read_async(
            db_session=session,
            identifier=created_provider.id,
            actor=default_user,
        )
        # Verify both encrypted columns are populated
        assert provider_orm.api_key_enc is not None
        assert provider_orm.access_key_enc is not None
        # Verify encrypted values decrypt back to the original plaintext
        assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == "secret-access-key-xyz"
        assert Secret.from_encrypted(provider_orm.access_key_enc).get_plaintext() == "access-key-id-abc"
    # Test the manager method for getting Bedrock credentials (looked up by provider name)
    access_key, secret_key, region = await provider_manager.get_bedrock_credentials_async("test-bedrock-provider", actor=default_user)
    assert access_key == "access-key-id-abc"
    assert secret_key == "secret-access-key-xyz"
    assert region == "us-east-1"
@pytest.mark.asyncio
async def test_provider_secret_not_exposed_in_logs(provider_manager, default_user, encryption_key):
    """Test that Secret objects don't expose plaintext in string representations."""
    # Create a provider
    provider_create = ProviderCreate(
        name="test-secret-provider",
        provider_type=ProviderType.openai,
        api_key="sk-very-secret-key-do-not-log",
    )
    created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
    # Get the Secret object
    api_key_secret = created_provider.api_key_enc
    # Verify string representation doesn't expose the key
    secret_str = str(api_key_secret)
    secret_repr = repr(api_key_secret)
    assert "sk-very-secret-key-do-not-log" not in secret_str
    assert "sk-very-secret-key-do-not-log" not in secret_repr
    # Representations should be masked ("****") or at least identify the type
    assert "****" in secret_str or "Secret" in secret_str
    assert "****" in secret_repr or "Secret" in secret_repr
@pytest.mark.asyncio
async def test_provider_pydantic_to_orm_serialization(provider_manager, default_user, encryption_key):
    """Test the full Pydantic → ORM → Pydantic round-trip maintains data integrity."""
    # Create a provider through the normal flow
    provider_create = ProviderCreate(
        name="test-roundtrip-provider",
        provider_type=ProviderType.openai,
        api_key="sk-roundtrip-test-key-999",
        base_url="https://api.openai.com/v1",
    )
    # Step 1: Create provider (Pydantic → ORM)
    created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
    original_api_key = created_provider.api_key_enc.get_plaintext()
    # Step 2: Read provider back (ORM → Pydantic)
    retrieved_provider = await provider_manager.get_provider_async(created_provider.id, actor=default_user)
    # Verify data integrity
    assert retrieved_provider.api_key_enc.get_plaintext() == original_api_key
    assert retrieved_provider.name == "test-roundtrip-provider"
    assert retrieved_provider.provider_type == ProviderType.openai
    assert retrieved_provider.base_url == "https://api.openai.com/v1"
    # Verify Secret object works correctly
    api_key_secret = retrieved_provider.api_key_enc
    assert api_key_secret.get_plaintext() == original_api_key
    # Step 3: Convert to ORM again (should preserve encrypted field)
    orm_data = retrieved_provider.model_dump(to_orm=True)
    # Verify encrypted field is in the ORM data and still decrypts correctly
    assert "api_key_enc" in orm_data
    assert orm_data["api_key_enc"] is not None
    assert Secret.from_encrypted(orm_data["api_key_enc"]).get_plaintext() == original_api_key
@pytest.mark.asyncio
async def test_provider_with_none_api_key(provider_manager, default_user, encryption_key):
    """Test that providers can be created with None api_key (some providers may not need it)."""
    # Create a provider without an api_key
    provider_create = ProviderCreate(
        name="test-no-key-provider",
        provider_type=ProviderType.ollama,
        api_key="",  # Empty string
        base_url="http://localhost:11434",
    )
    created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
    # Verify provider was created
    assert created_provider is not None
    assert created_provider.name == "test-no-key-provider"
    # Read from DB
    async with db_registry.async_session() as session:
        provider_orm = await ProviderModel.read_async(
            db_session=session,
            identifier=created_provider.id,
            actor=default_user,
        )
        # api_key_enc should handle empty string appropriately:
        # encrypted column populated, decrypting yields the empty string back
        assert provider_orm.api_key_enc is not None
        assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == ""
@pytest.mark.asyncio
async def test_list_providers_decrypts_all(provider_manager, default_user, encryption_key):
    """Test that listing multiple providers decrypts all their api_keys correctly."""
    # Create multiple providers with distinct, index-tagged keys
    providers_to_create = [
        ProviderCreate(name=f"test-provider-{i}", provider_type=ProviderType.openai, api_key=f"sk-key-{i}") for i in range(3)
    ]
    created_ids = []
    for provider_create in providers_to_create:
        provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
        created_ids.append(provider.id)
    # List all providers
    all_providers = await provider_manager.list_providers_async(actor=default_user)
    # Filter to our test providers (the list may contain others)
    test_providers = [p for p in all_providers if p.id in created_ids]
    # Verify all are decrypted correctly; sorting by name restores creation order
    assert len(test_providers) == 3
    for i, provider in enumerate(sorted(test_providers, key=lambda p: p.name)):
        secret = provider.api_key_enc
        assert secret.get_plaintext() == f"sk-key-{i}"
# ======================================================================================================================
# Handle to Config Conversion Tests
# ======================================================================================================================
@pytest.mark.asyncio
async def test_handle_to_llm_config_conversion(provider_manager, default_user):
    """Test that handle to LLMConfig conversion works correctly with database lookup."""
    from letta.orm.errors import NoResultFound
    from letta.schemas.embedding_config import EmbeddingConfig
    from letta.schemas.llm_config import LLMConfig

    # Create a test provider
    provider_create = ProviderCreate(
        name="test-handle-provider", provider_type=ProviderType.openai, api_key="sk-test-handle-key", base_url="https://api.openai.com/v1"
    )
    provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
    # Sync some test models under this provider
    llm_models = [
        LLMConfig(
            model="gpt-4",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=8192,
            handle="test-handle-provider/gpt-4",
            provider_name=provider.name,
            provider_category=ProviderCategory.base,
        ),
        LLMConfig(
            model="gpt-3.5-turbo",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=4096,
            handle="test-handle-provider/gpt-3.5-turbo",
            provider_name=provider.name,
            provider_category=ProviderCategory.base,
        ),
    ]
    embedding_models = [
        EmbeddingConfig(
            embedding_model="text-embedding-ada-002",
            embedding_endpoint_type="openai",
            embedding_endpoint="https://api.openai.com/v1",
            embedding_dim=1536,
            embedding_chunk_size=300,
            handle="test-handle-provider/text-embedding-ada-002",
        )
    ]
    await provider_manager.sync_provider_models_async(
        provider=provider, llm_models=llm_models, embedding_models=embedding_models, organization_id=default_user.organization_id
    )
    # Test LLM config from handle
    llm_config = await provider_manager.get_llm_config_from_handle(handle="test-handle-provider/gpt-4", actor=default_user)
    # Verify the returned config
    assert llm_config.model == "gpt-4"
    assert llm_config.handle == "test-handle-provider/gpt-4"
    assert llm_config.context_window == 8192
    assert llm_config.model_endpoint == "https://api.openai.com/v1"
    assert llm_config.provider_name == "test-handle-provider"
    # Test embedding config from handle
    embedding_config = await provider_manager.get_embedding_config_from_handle(
        handle="test-handle-provider/text-embedding-ada-002", actor=default_user
    )
    # Verify the returned config
    assert embedding_config.embedding_model == "text-embedding-ada-002"
    assert embedding_config.handle == "test-handle-provider/text-embedding-ada-002"
    assert embedding_config.embedding_dim == 1536
    assert embedding_config.embedding_chunk_size == 300
    assert embedding_config.embedding_endpoint == "https://api.openai.com/v1"
    # Test context window limit override would be done at server level
    # The provider_manager method doesn't support context_window_limit directly
    # Test error handling for non-existent handle
    with pytest.raises(NoResultFound):
        await provider_manager.get_llm_config_from_handle(handle="nonexistent/model", actor=default_user)
@pytest.mark.asyncio
async def test_byok_provider_auto_syncs_models(provider_manager, default_user, monkeypatch):
    """Test that creating a BYOK provider attempts to sync its models."""
    from letta.schemas.embedding_config import EmbeddingConfig
    from letta.schemas.llm_config import LLMConfig

    # Mock the list_llm_models_async method
    async def mock_list_llm():
        return [
            LLMConfig(
                model="gpt-4o",
                model_endpoint_type="openai",
                model_endpoint="https://api.openai.com/v1",
                context_window=128000,
                handle="openai/gpt-4o",
                provider_name="openai",
                provider_category=ProviderCategory.base,
            ),
            LLMConfig(
                model="gpt-4",
                model_endpoint_type="openai",
                model_endpoint="https://api.openai.com/v1",
                context_window=8192,
                handle="openai/gpt-4",
                provider_name="openai",
                provider_category=ProviderCategory.base,
            ),
        ]

    # Mock the list_embedding_models_async method
    async def mock_list_embedding():
        return [
            EmbeddingConfig(
                embedding_model="text-embedding-ada-002",
                embedding_endpoint_type="openai",
                embedding_endpoint="https://api.openai.com/v1",
                embedding_dim=1536,
                embedding_chunk_size=300,
                handle="openai/text-embedding-ada-002",
            )
        ]

    # Mock the _sync_default_models_for_provider method directly
    async def mock_sync(provider, actor):
        # Get mock models and update them for this provider
        llm_models = await mock_list_llm()
        embedding_models = await mock_list_embedding()
        # Update models to match the BYOK provider (name, handle, category)
        for model in llm_models:
            model.provider_name = provider.name
            model.handle = f"{provider.name}/{model.model}"
            model.provider_category = provider.provider_category
        for model in embedding_models:
            model.handle = f"{provider.name}/{model.embedding_model}"
        # Call sync_provider_models_async with mock data
        await provider_manager.sync_provider_models_async(
            provider=provider, llm_models=llm_models, embedding_models=embedding_models, organization_id=actor.organization_id
        )

    monkeypatch.setattr(provider_manager, "_sync_default_models_for_provider", mock_sync)
    # Create a BYOK OpenAI provider (simulates UI "Add API Key" flow)
    provider_create = ProviderCreate(name="my-openai-key", provider_type=ProviderType.openai, api_key="sk-my-personal-key-123")
    # Create the BYOK provider (is_byok=True is the default)
    provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=True)
    # Verify provider was created
    assert provider.name == "my-openai-key"
    assert provider.provider_type == ProviderType.openai
    # List models for this provider - they should have been auto-synced
    models = await provider_manager.list_models_async(actor=default_user, provider_id=provider.id)
    # Should have both LLM and embedding models
    llm_models = [m for m in models if m.model_type == "llm"]
    embedding_models = [m for m in models if m.model_type == "embedding"]
    assert len(llm_models) > 0, "No LLM models were synced"
    assert len(embedding_models) > 0, "No embedding models were synced"
    # Verify handles are correctly formatted with BYOK provider name
    for model in models:
        assert model.handle.startswith(f"{provider.name}/")
    # Test that we can get LLM config from handle
    llm_config = await provider_manager.get_llm_config_from_handle(handle="my-openai-key/gpt-4o", actor=default_user)
    assert llm_config.model == "gpt-4o"
    assert llm_config.provider_name == "my-openai-key"
# ======================================================================================================================
# Server Startup Provider Sync Tests
# ======================================================================================================================
@pytest.mark.asyncio
async def test_server_startup_syncs_base_providers(default_user, default_organization, monkeypatch):
    """Test that server startup properly syncs base provider models from environment.

    This test simulates the server startup process and verifies that:
    1. Base providers from environment variables are synced to database
    2. Provider models are fetched from mocked API endpoints
    3. Models are properly persisted to the database with correct metadata
    4. Models can be retrieved using handles
    """
    from letta.server.server import SyncServer

    # Mock OpenAI API responses
    mock_openai_models = {
        "data": [
            {
                "id": "gpt-4",
                "object": "model",
                "created": 1687882411,
                "owned_by": "openai",
                "max_model_len": 8192,
            },
            {
                "id": "gpt-4-turbo",
                "object": "model",
                "created": 1712361441,
                "owned_by": "system",
                "max_model_len": 128000,
            },
            {
                "id": "text-embedding-ada-002",
                "object": "model",
                "created": 1671217299,
                "owned_by": "openai-internal",
            },
            {
                "id": "gpt-4-vision",  # Should be filtered out by OpenAI provider logic (has disallowed keyword)
                "object": "model",
                "created": 1698959748,
                "owned_by": "system",
                "max_model_len": 8192,
            },
        ]
    }
    # Mock Anthropic API responses
    mock_anthropic_models = {
        "data": [
            {
                "id": "claude-3-5-sonnet-20241022",
                "type": "model",
                "display_name": "Claude 3.5 Sonnet",
                "created_at": "2024-10-22T00:00:00Z",
            },
            {
                "id": "claude-3-opus-20240229",
                "type": "model",
                "display_name": "Claude 3 Opus",
                "created_at": "2024-02-29T00:00:00Z",
            },
        ]
    }

    # Mock the API calls for OpenAI
    async def mock_openai_get_model_list_async(*args, **kwargs):
        return mock_openai_models

    # Mock Anthropic models.list() response
    from unittest.mock import MagicMock

    mock_anthropic_response = MagicMock()
    mock_anthropic_response.model_dump.return_value = mock_anthropic_models

    # Mock the Anthropic AsyncAnthropic client
    class MockAnthropicModels:
        async def list(self):
            return mock_anthropic_response

    class MockAsyncAnthropic:
        def __init__(self, *args, **kwargs):
            self.models = MockAnthropicModels()

    # Patch the actual API calling functions
    monkeypatch.setattr(
        "letta.llm_api.openai.openai_get_model_list_async",
        mock_openai_get_model_list_async,
    )
    monkeypatch.setattr(
        "anthropic.AsyncAnthropic",
        MockAsyncAnthropic,
    )
    # Clear ALL provider-related env vars first to ensure clean state
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    monkeypatch.delenv("AZURE_API_KEY", raising=False)
    monkeypatch.delenv("GROQ_API_KEY", raising=False)
    monkeypatch.delenv("TOGETHER_API_KEY", raising=False)
    monkeypatch.delenv("VLLM_API_BASE", raising=False)
    monkeypatch.delenv("SGLANG_API_BASE", raising=False)
    monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
    monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
    monkeypatch.delenv("LMSTUDIO_BASE_URL", raising=False)
    monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
    monkeypatch.delenv("XAI_API_KEY", raising=False)
    monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
    monkeypatch.delenv("ZAI_API_KEY", raising=False)
    # Set environment variables to enable only OpenAI and Anthropic
    monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key-12345")
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key-67890")
    # Reload model_settings to pick up new env vars
    # (model_settings is a cached singleton, so patch its attributes directly)
    from letta.settings import model_settings

    monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key-12345")
    monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key-67890")
    monkeypatch.setattr(model_settings, "gemini_api_key", None)
    monkeypatch.setattr(model_settings, "google_cloud_project", None)
    monkeypatch.setattr(model_settings, "google_cloud_location", None)
    monkeypatch.setattr(model_settings, "azure_api_key", None)
    monkeypatch.setattr(model_settings, "groq_api_key", None)
    monkeypatch.setattr(model_settings, "together_api_key", None)
    monkeypatch.setattr(model_settings, "vllm_api_base", None)
    monkeypatch.setattr(model_settings, "sglang_api_base", None)
    monkeypatch.setattr(model_settings, "aws_access_key_id", None)
    monkeypatch.setattr(model_settings, "aws_secret_access_key", None)
    monkeypatch.setattr(model_settings, "lmstudio_base_url", None)
    monkeypatch.setattr(model_settings, "deepseek_api_key", None)
    monkeypatch.setattr(model_settings, "xai_api_key", None)
    monkeypatch.setattr(model_settings, "openrouter_api_key", None)
    monkeypatch.setattr(model_settings, "zai_api_key", None)
    # Create server instance (this will load enabled providers from environment)
    server = SyncServer(init_with_default_org_and_user=False)
    # Manually set up the default user/org (since we disabled auto-init)
    server.default_user = default_user
    server.default_org = default_organization
    # Verify enabled providers were loaded
    assert len(server._enabled_providers) == 3  # Exactly: letta, openai, anthropic
    enabled_provider_names = [p.name for p in server._enabled_providers]
    assert "letta" in enabled_provider_names
    assert "openai" in enabled_provider_names
    assert "anthropic" in enabled_provider_names
    # First, sync base providers to database (this is what init_async does)
    await server.provider_manager.sync_base_providers(
        base_providers=server._enabled_providers,
        actor=default_user,
    )
    # Now call the actual _sync_provider_models_async method
    # This simulates what happens during server startup
    await server._sync_provider_models_async()
    # Verify OpenAI models were synced
    openai_providers = await server.provider_manager.list_providers_async(
        name="openai",
        actor=default_user,
    )
    assert len(openai_providers) == 1, "OpenAI provider should exist"
    openai_provider = openai_providers[0]
    # Check OpenAI LLM models
    openai_llm_models = await server.provider_manager.list_models_async(
        actor=default_user,
        provider_id=openai_provider.id,
        model_type="llm",
    )
    # Should have gpt-4 and gpt-4-turbo (gpt-4-vision filtered out due to "vision" keyword)
    assert len(openai_llm_models) >= 2, f"Expected at least 2 OpenAI LLM models, got {len(openai_llm_models)}"
    openai_model_names = [m.name for m in openai_llm_models]
    assert "gpt-4" in openai_model_names
    assert "gpt-4-turbo" in openai_model_names
    # Check OpenAI embedding models
    openai_embedding_models = await server.provider_manager.list_models_async(
        actor=default_user,
        provider_id=openai_provider.id,
        model_type="embedding",
    )
    assert len(openai_embedding_models) >= 1, "Expected at least 1 OpenAI embedding model"
    embedding_model_names = [m.name for m in openai_embedding_models]
    assert "text-embedding-ada-002" in embedding_model_names
    # Verify model metadata is correct
    gpt4_models = [m for m in openai_llm_models if m.name == "gpt-4"]
    assert len(gpt4_models) > 0, "gpt-4 model should exist"
    gpt4_model = gpt4_models[0]
    assert gpt4_model.handle == "openai/gpt-4"
    assert gpt4_model.model_endpoint_type == "openai"
    assert gpt4_model.max_context_window == 8192
    assert gpt4_model.enabled is True
    # Verify Anthropic models were synced
    anthropic_providers = await server.provider_manager.list_providers_async(
        name="anthropic",
        actor=default_user,
    )
    assert len(anthropic_providers) == 1, "Anthropic provider should exist"
    anthropic_provider = anthropic_providers[0]
    anthropic_llm_models = await server.provider_manager.list_models_async(
        actor=default_user,
        provider_id=anthropic_provider.id,
        model_type="llm",
    )
    # Should have Claude models
    assert len(anthropic_llm_models) >= 2, f"Expected at least 2 Anthropic models, got {len(anthropic_llm_models)}"
    anthropic_model_names = [m.name for m in anthropic_llm_models]
    assert "claude-3-5-sonnet-20241022" in anthropic_model_names
    assert "claude-3-opus-20240229" in anthropic_model_names
    # Test that we can retrieve LLMConfig from handle
    llm_config = await server.provider_manager.get_llm_config_from_handle(
        handle="openai/gpt-4",
        actor=default_user,
    )
    assert llm_config.model == "gpt-4"
    assert llm_config.handle == "openai/gpt-4"
    assert llm_config.provider_name == "openai"
    assert llm_config.context_window == 8192
    # Test that we can retrieve EmbeddingConfig from handle
    embedding_config = await server.provider_manager.get_embedding_config_from_handle(
        handle="openai/text-embedding-ada-002",
        actor=default_user,
    )
    assert embedding_config.embedding_model == "text-embedding-ada-002"
    assert embedding_config.handle == "openai/text-embedding-ada-002"
    assert embedding_config.embedding_dim == 1536
@pytest.mark.asyncio
async def test_server_startup_handles_disabled_providers(default_user, default_organization, monkeypatch):
    """Test that server startup properly handles providers that are no longer enabled.

    This test verifies that:
    1. Base providers that are no longer enabled (env vars removed) are deleted
    2. BYOK providers that are no longer enabled are NOT deleted (user-created)
    3. The sync process handles providers gracefully when API calls fail
    """
    from letta.schemas.providers import ProviderCreate
    from letta.server.server import SyncServer

    # First, manually create providers in the database
    provider_manager = ProviderManager()
    # Create a base OpenAI provider (simulating it was synced before)
    base_openai_create = ProviderCreate(
        name="openai",
        provider_type=ProviderType.openai,
        api_key="sk-old-key",
        base_url="https://api.openai.com/v1",
    )
    base_openai = await provider_manager.create_provider_async(
        base_openai_create,
        actor=default_user,
        is_byok=False,  # This is a base provider
    )
    # Create a BYOK provider (user-created)
    byok_provider_create = ProviderCreate(
        name="my-custom-openai",
        provider_type=ProviderType.openai,
        api_key="sk-my-key",
        base_url="https://api.openai.com/v1",
    )
    byok_provider = await provider_manager.create_provider_async(
        byok_provider_create,
        actor=default_user,
        is_byok=True,
    )
    assert byok_provider.provider_category == ProviderCategory.byok
    # Now create server with NO environment variables set (all base providers disabled)
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    from letta.settings import model_settings

    monkeypatch.setattr(model_settings, "openai_api_key", None)
    monkeypatch.setattr(model_settings, "anthropic_api_key", None)
    # Create server instance
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.default_org = default_organization
    # Verify only letta provider is enabled (no openai)
    enabled_names = [p.name for p in server._enabled_providers]
    assert "letta" in enabled_names
    assert "openai" not in enabled_names
    # Sync base providers (should not include openai anymore)
    await server.provider_manager.sync_base_providers(
        base_providers=server._enabled_providers,
        actor=default_user,
    )
    # Call _sync_provider_models_async
    await server._sync_provider_models_async()
    # Verify base OpenAI provider was deleted (no longer enabled)
    try:
        await server.provider_manager.get_provider_async(base_openai.id, actor=default_user)
        assert False, "Base OpenAI provider should have been deleted"
    except Exception:
        # Expected - provider should not exist
        pass
    # Verify BYOK provider still exists (should NOT be deleted)
    byok_still_exists = await server.provider_manager.get_provider_async(
        byok_provider.id,
        actor=default_user,
    )
    assert byok_still_exists is not None
    assert byok_still_exists.name == "my-custom-openai"
    assert byok_still_exists.provider_category == ProviderCategory.byok
@pytest.mark.asyncio
async def test_server_startup_handles_api_errors_gracefully(default_user, default_organization, monkeypatch):
    """Test that server startup handles API errors gracefully without crashing.

    This test verifies that:
    1. If a provider's API call fails during sync, it logs an error but continues
    2. Other providers can still sync successfully
    3. The server startup completes without crashing
    """
    from letta.server.server import SyncServer

    # Mock OpenAI to fail
    async def mock_openai_fail(*args, **kwargs):
        raise Exception("OpenAI API is down")

    # Mock Anthropic to succeed
    from unittest.mock import MagicMock

    mock_anthropic_response = MagicMock()
    mock_anthropic_response.model_dump.return_value = {
        "data": [
            {
                "id": "claude-3-5-sonnet-20241022",
                "type": "model",
                "display_name": "Claude 3.5 Sonnet",
                "created_at": "2024-10-22T00:00:00Z",
            }
        ]
    }

    class MockAnthropicModels:
        async def list(self):
            return mock_anthropic_response

    class MockAsyncAnthropic:
        def __init__(self, *args, **kwargs):
            self.models = MockAnthropicModels()

    monkeypatch.setattr(
        "letta.llm_api.openai.openai_get_model_list_async",
        mock_openai_fail,
    )
    monkeypatch.setattr(
        "anthropic.AsyncAnthropic",
        MockAsyncAnthropic,
    )
    # Set environment variables
    monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key")
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key")
    from letta.settings import model_settings

    monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key")
    monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key")
    # Create server
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.default_org = default_organization
    # Sync base providers
    await server.provider_manager.sync_base_providers(
        base_providers=server._enabled_providers,
        actor=default_user,
    )
    # This should NOT crash even though OpenAI fails
    await server._sync_provider_models_async()
    # Verify Anthropic still synced successfully
    anthropic_providers = await server.provider_manager.list_providers_async(
        name="anthropic",
        actor=default_user,
    )
    assert len(anthropic_providers) == 1
    anthropic_models = await server.provider_manager.list_models_async(
        actor=default_user,
        provider_id=anthropic_providers[0].id,
        model_type="llm",
    )
    assert len(anthropic_models) >= 1, "Anthropic models should have synced despite OpenAI failure"
    # OpenAI should have no models (sync failed)
    openai_providers = await server.provider_manager.list_providers_async(
        name="openai",
        actor=default_user,
    )
    if len(openai_providers) > 0:
        openai_models = await server.provider_manager.list_models_async(
            actor=default_user,
            provider_id=openai_providers[0].id,
        )
        # Models might exist from previous runs, but the sync attempt should have been logged as failed
        # The key is that the server didn't crash