fix: fix providers and models persistence (#8302)
@@ -8,6 +8,26 @@ LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")
 LETTA_MODEL_ENDPOINT = "https://inference.letta.com/v1/"
 DEFAULT_TIMEZONE = "UTC"
+
+# Provider ordering for model listing (matches original _enabled_providers list order)
+PROVIDER_ORDER = {
+    "letta": 0,
+    "openai": 1,
+    "anthropic": 2,
+    "ollama": 3,
+    "google_ai": 4,
+    "google_vertex": 5,
+    "azure": 6,
+    "groq": 7,
+    "together": 8,
+    "vllm": 9,
+    "bedrock": 10,
+    "deepseek": 11,
+    "xai": 12,
+    "lmstudio": 13,
+    "zai": 14,
+    "openrouter": 15,  # Note: OpenRouter uses OpenRouterProvider, not a ProviderType enum
+}
 
 ADMIN_PREFIX = "/v1/admin"
 API_PREFIX = "/v1"
 OLLAMA_API_PREFIX = "/v1"
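As a quick illustration (a standalone sketch, not part of the commit): a priority map like this becomes a sort key by looking up the provider name with a large fallback for unknown providers, mirroring the `_get_provider_sort_key` helper added later in this diff.

```python
# Standalone sketch (not part of the commit): how a PROVIDER_ORDER-style map
# drives sorting. The 999 fallback for unknown providers mirrors the
# _get_provider_sort_key helper added later in this diff.
PROVIDER_ORDER = {"letta": 0, "openai": 1, "anthropic": 2}

def sort_key(handle: str):
    provider, _, model = handle.partition("/")
    return (PROVIDER_ORDER.get(provider, 999), provider, model)

handles = ["anthropic/claude-3-5-sonnet", "custom/my-model", "openai/gpt-4o", "letta/letta-free"]
print(sorted(handles, key=sort_key))
# ['letta/letta-free', 'openai/gpt-4o', 'anthropic/claude-3-5-sonnet', 'custom/my-model']
```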
@@ -235,16 +235,17 @@ class LLMClientBase:
     def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
         """
         Returns the override key for the given llm config.
-        For both base and BYOK providers, fetch the API key from the database.
+        Only fetches API key from database for BYOK providers.
+        Base providers use environment variables directly.
         """
         api_key = None
-        # Fetch API key from database for both base and BYOK providers
-        # This ensures that base providers (from environment) also have their keys persisted and accessible
-        if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
+        # Only fetch API key from database for BYOK providers
+        # Base providers should always use environment variables
+        if llm_config.provider_category == ProviderCategory.byok:
             from letta.services.provider_manager import ProviderManager
 
             api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
-        # If we got an empty string from the database (e.g., Letta provider), treat it as None
+        # If we got an empty string from the database, treat it as None
         # so the client can fall back to environment variables or default behavior
         if api_key == "":
             api_key = None
@@ -254,16 +255,17 @@ class LLMClientBase:
     async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
         """
         Returns the override key for the given llm config.
-        For both base and BYOK providers, fetch the API key from the database.
+        Only fetches API key from database for BYOK providers.
+        Base providers use environment variables directly.
         """
         api_key = None
-        # Fetch API key from database for both base and BYOK providers
-        # This ensures that base providers (from environment) also have their keys persisted and accessible
-        if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
+        # Only fetch API key from database for BYOK providers
+        # Base providers should always use environment variables
+        if llm_config.provider_category == ProviderCategory.byok:
             from letta.services.provider_manager import ProviderManager
 
             api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
-        # If we got an empty string from the database (e.g., Letta provider), treat it as None
+        # If we got an empty string from the database, treat it as None
         # so the client can fall back to environment variables or default behavior
         if api_key == "":
             api_key = None
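`get_byok_overrides` and its async twin implement the same resolution rule. A hedged standalone sketch of that rule (the env-var helper and category strings here are illustrative, not the actual Letta API):

```python
# Distilled from the two methods above: BYOK providers read their key from the
# database; base providers (and BYOK rows with an empty stored key) fall back
# to environment variables. The env_var argument is a hypothetical stand-in.
import os
from typing import Optional

def resolve_api_key(provider_category: str, db_key: Optional[str], env_var: str) -> Optional[str]:
    if provider_category == "byok" and db_key:  # empty string treated as missing
        return db_key
    return os.environ.get(env_var)  # base providers always use env vars

assert resolve_api_key("byok", "sk-user-key", "OPENAI_API_KEY") == "sk-user-key"
os.environ["OPENAI_API_KEY"] = "sk-env-key"
# An empty DB key falls through to the environment:
assert resolve_api_key("byok", "", "OPENAI_API_KEY") == "sk-env-key"
assert resolve_api_key("base", "ignored", "OPENAI_API_KEY") == "sk-env-key"
```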
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional
 from fastapi import APIRouter, Body, Depends, Query, status
 from fastapi.responses import JSONResponse
 
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.providers import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate
 from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
 from letta.validators import ProviderId
@@ -39,7 +39,14 @@ async def list_providers(
     """
     actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
     providers = await server.provider_manager.list_providers_async(
-        before=before, after=after, limit=limit, actor=actor, name=name, provider_type=provider_type, ascending=(order == "asc")
+        before=before,
+        after=after,
+        limit=limit,
+        actor=actor,
+        name=name,
+        provider_type=provider_type,
+        provider_category=[ProviderCategory.byok],
+        ascending=(order == "asc"),
     )
     return providers
 
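With the router change above, the providers listing endpoint is pinned to BYOK providers; base providers synced from environment variables no longer appear there. An illustrative client call (base URL, port, and token are placeholders; the path follows the `API_PREFIX` and router above):

```python
# Illustrative only: listing providers over the REST API should now return
# BYOK providers exclusively. Endpoint details are assumptions based on the
# router in this diff, not a documented contract.
import requests

resp = requests.get(
    "http://localhost:8283/v1/providers",
    headers={"Authorization": "Bearer <token>"},
)
for provider in resp.json():
    # Every returned provider is expected to be BYOK after this change
    print(provider["name"], provider.get("provider_category"))
```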
@@ -1097,6 +1097,18 @@ class SyncServer(object):
         passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor)
         return passage_count, document_count
 
+    def _get_provider_sort_key(self, model: LLMConfig) -> Tuple[int, str, str]:
+        """Get sort key for a model: (provider_priority, provider_name, model_name)"""
+        provider_priority = constants.PROVIDER_ORDER.get(model.provider_name, 999)
+        return (provider_priority, model.provider_name or "", model.model or "")
+
+    def _get_embedding_provider_sort_key(self, model: EmbeddingConfig) -> Tuple[int, str, str]:
+        """Get sort key for an embedding model: (provider_priority, provider_name, model_name)"""
+        # Extract provider name from handle (format: "provider_name/model_name")
+        provider_name = model.handle.split("/")[0] if model.handle and "/" in model.handle else ""
+        provider_priority = constants.PROVIDER_ORDER.get(provider_name, 999)
+        return (provider_priority, provider_name, model.embedding_model or "")
+
     @trace_method
     async def list_llm_models_async(
         self,
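These helpers lean on Python's element-by-element tuple comparison: provider priority wins first, then provider name, then model name. A quick standalone check of that behavior:

```python
# Standalone sketch: Python compares tuples left to right, so sorting the
# (priority, provider, model) keys above orders by priority first, then
# alphabetically within a provider.
rows = [
    (2, "anthropic", "claude-3-5-sonnet"),
    (1, "openai", "gpt-4o-mini"),
    (1, "openai", "gpt-4o"),
    (999, "unknown", "some-model"),
]
assert sorted(rows) == [
    (1, "openai", "gpt-4o"),
    (1, "openai", "gpt-4o-mini"),
    (2, "anthropic", "claude-3-5-sonnet"),
    (999, "unknown", "some-model"),
]
```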
@@ -1105,86 +1117,122 @@ class SyncServer(object):
         provider_name: Optional[str] = None,
         provider_type: Optional[ProviderType] = None,
     ) -> List[LLMConfig]:
-        """List available LLM models from database cache"""
-        # Get provider IDs if filtering by provider
-        provider_ids = None
-        if provider_name or provider_type:
-            providers = await self.get_enabled_providers_async(
-                provider_category=provider_category,
-                provider_name=provider_name,
-                provider_type=provider_type,
-                actor=actor,
-            )
-            provider_ids = [p.id for p in providers]
-
-            # If filtering was requested but no providers matched, return empty list
-            if not provider_ids:
-                return []
-
-        # Get models from database
-        provider_models = await self.provider_manager.list_models_async(
-            actor=actor,
-            model_type="llm",
-            provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None,
-            enabled=True,
-        )
-
-        # Build LLMConfig objects from cached data
-        # Cache providers to avoid N+1 queries
-        provider_cache: Dict[str, Provider] = {}
+        """List available LLM models - base from DB, BYOK from provider endpoints"""
         llm_models = []
-        for model in provider_models:
-            # Skip if filtering by provider and model doesn't match
-            if provider_ids and model.provider_id not in provider_ids:
-                continue
-
-            # Get provider details (with caching to avoid N+1 queries)
-            if model.provider_id not in provider_cache:
-                provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
-            provider = provider_cache[model.provider_id]
+        # Determine which categories to include
+        include_base = not provider_category or ProviderCategory.base in provider_category
+        include_byok = not provider_category or ProviderCategory.byok in provider_category
 
-            llm_config = LLMConfig(
-                model=model.name,
-                model_endpoint_type=model.model_endpoint_type,
-                model_endpoint=provider.base_url or model.model_endpoint_type,
-                context_window=model.max_context_window or 16384,
-                handle=model.handle,
-                provider_name=provider.name,
-                provider_category=provider.provider_category,
+        # Get base provider models from database
+        if include_base:
+            provider_models = await self.provider_manager.list_models_async(
+                actor=actor,
+                model_type="llm",
+                enabled=True,
             )
-            llm_models.append(llm_config)
 
+            # Build LLMConfig objects from database
+            provider_cache: Dict[str, Provider] = {}
+            for model in provider_models:
+                # Get provider details (with caching to avoid N+1 queries)
+                if model.provider_id not in provider_cache:
+                    provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
+                provider = provider_cache[model.provider_id]
+
+                # Skip non-base providers (they're handled separately)
+                if provider.provider_category != ProviderCategory.base:
+                    continue
+
+                # Apply provider_name/provider_type filters if specified
+                if provider_name and provider.name != provider_name:
+                    continue
+                if provider_type and provider.provider_type != provider_type:
+                    continue
+
+                llm_config = LLMConfig(
+                    model=model.name,
+                    model_endpoint_type=model.model_endpoint_type,
+                    model_endpoint=provider.base_url or model.model_endpoint_type,
+                    context_window=model.max_context_window or 16384,
+                    handle=model.handle,
+                    provider_name=provider.name,
+                    provider_category=provider.provider_category,
+                )
+                llm_models.append(llm_config)
+
+        # Get BYOK provider models by hitting provider endpoints directly
+        if include_byok:
+            byok_providers = await self.provider_manager.list_providers_async(
+                actor=actor,
+                name=provider_name,
+                provider_type=provider_type,
+                provider_category=[ProviderCategory.byok],
+            )
+
+            for provider in byok_providers:
+                try:
+                    typed_provider = provider.cast_to_subtype()
+                    models = await typed_provider.list_llm_models_async()
+                    llm_models.extend(models)
+                except Exception as e:
+                    logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
+
+        # Sort by provider order (matching old _enabled_providers order), then by model name
+        llm_models.sort(key=self._get_provider_sort_key)
 
         return llm_models
 
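A hedged usage sketch of the method above (server and actor wiring elided): the `provider_category` argument now selects which branch runs, rather than pre-filtering provider IDs.

```python
# Usage sketch only; `server` and `actor` setup is elided.
from letta.schemas.enums import ProviderCategory

async def show_listing_modes(server, actor):
    # Base models only -> served from the database cache
    base_models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])

    # BYOK models only -> fetched live from each provider's own endpoint
    byok_models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.byok])

    # No filter -> both sources, merged and sorted by PROVIDER_ORDER
    all_models = await server.list_llm_models_async(actor=actor)
    return base_models, byok_models, all_models
```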
     async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
-        """List available embedding models from database cache"""
-        # Get models from database
+        """List available embedding models - base from DB, BYOK from provider endpoints"""
+        embedding_models = []
+
+        # Get base provider models from database
         provider_models = await self.provider_manager.list_models_async(
             actor=actor,
             model_type="embedding",
             enabled=True,
         )
 
-        # Build EmbeddingConfig objects from cached data
-        # Cache providers to avoid N+1 queries
+        # Build EmbeddingConfig objects from database (base providers only)
         provider_cache: Dict[str, Provider] = {}
-        embedding_models = []
         for model in provider_models:
             # Get provider details (with caching to avoid N+1 queries)
             if model.provider_id not in provider_cache:
                 provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
             provider = provider_cache[model.provider_id]
 
+            # Skip non-base providers (they're handled separately)
+            if provider.provider_category != ProviderCategory.base:
+                continue
+
             embedding_config = EmbeddingConfig(
                 embedding_model=model.name,
                 embedding_endpoint_type=model.model_endpoint_type,
                 embedding_endpoint=provider.base_url or model.model_endpoint_type,
-                embedding_dim=model.embedding_dim or 1536,  # Use model's dimension or default
+                embedding_dim=model.embedding_dim or 1536,
                 embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
                 handle=model.handle,
             )
             embedding_models.append(embedding_config)
 
+        # Get BYOK provider models by hitting provider endpoints directly
+        byok_providers = await self.provider_manager.list_providers_async(
+            actor=actor,
+            provider_category=[ProviderCategory.byok],
+        )
+
+        for provider in byok_providers:
+            try:
+                typed_provider = provider.cast_to_subtype()
+                models = await typed_provider.list_embedding_models_async()
+                embedding_models.extend(models)
+            except Exception as e:
+                logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")
+
+        # Sort by provider order (matching old _enabled_providers order), then by model name
+        embedding_models.sort(key=self._get_embedding_provider_sort_key)
+
         return embedding_models
 
     async def get_enabled_providers_async(
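Both listings use the same fail-soft pattern for BYOK providers: one unreachable provider logs a warning instead of failing the whole listing. Distilled into a standalone sketch (provider objects are assumed to follow the interfaces used in this diff):

```python
# Sketch of the fail-soft BYOK fetch loop above; not the actual server code.
import logging

logger = logging.getLogger(__name__)

async def collect_byok_llm_models(byok_providers):
    models = []
    for provider in byok_providers:
        try:
            typed_provider = provider.cast_to_subtype()  # concrete subclass, e.g. an OpenAI-style provider
            models.extend(await typed_provider.list_llm_models_async())
        except Exception as e:
            logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
    return models
```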
@@ -240,6 +240,7 @@ class ProviderManager:
         actor: PydanticUser,
         name: Optional[str] = None,
         provider_type: Optional[ProviderType] = None,
+        provider_category: Optional[List[ProviderCategory]] = None,
         before: Optional[str] = None,
         after: Optional[str] = None,
         limit: Optional[int] = 50,
@@ -280,7 +281,14 @@ class ProviderManager:
         )
 
         # Combine both lists
-        all_providers = org_providers + global_providers
+        all_providers = []
+        if not provider_category:
+            all_providers = org_providers + global_providers
+        else:
+            if ProviderCategory.byok in provider_category:
+                all_providers += org_providers
+            if ProviderCategory.base in provider_category:
+                all_providers += global_providers
 
         # Remove deprecated api_key and access_key fields from the response
         for provider in all_providers:
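A tiny standalone model of the combine logic above: org-scoped providers are BYOK, global providers are base, and the category filter picks which lists to merge.

```python
# Standalone sketch of the provider_category combine logic; names simplified.
from typing import List, Optional

def combine(org: List[str], glob: List[str], categories: Optional[List[str]]) -> List[str]:
    if not categories:
        return org + glob
    out: List[str] = []
    if "byok" in categories:
        out += org
    if "base" in categories:
        out += glob
    return out

assert combine(["my-openai"], ["openai"], None) == ["my-openai", "openai"]
assert combine(["my-openai"], ["openai"], ["byok"]) == ["my-openai"]
assert combine(["my-openai"], ["openai"], ["base"]) == ["openai"]
```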
@@ -575,13 +583,14 @@ class ProviderManager:
                 continue
 
             # Convert Provider to ProviderCreate
-            api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
-            access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
+            # NOTE: Do NOT store API keys for base providers in the database.
+            # Base providers should always use environment variables for API keys.
+            # This ensures keys stay in sync with env vars and aren't duplicated in DB.
             provider_create = ProviderCreate(
                 name=provider.name,
                 provider_type=provider.provider_type,
-                api_key=api_key or "",  # ProviderCreate requires api_key, use empty string if None
-                access_key=access_key,
+                api_key="",  # Base providers use env vars, not DB-stored keys
+                access_key=None,
                 region=provider.region,
                 base_url=provider.base_url,
                 api_version=provider.api_version,
@@ -2089,3 +2089,341 @@ async def test_get_enabled_providers_async_queries_database(default_user, provider_manager):
 
     assert f"test-base-provider-{test_id}" in openai_names
     assert f"test-byok-provider-{test_id}" not in openai_names  # This is anthropic type
+
+
+# =============================================================================
+# BYOK Provider and Model Listing Integration Tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_list_providers_filters_by_category(default_user, provider_manager):
+    """Test that list_providers_async correctly filters by provider_category."""
+    test_id = generate_test_id()
+
+    # Create a base provider
+    base_provider_create = ProviderCreate(
+        name=f"test-base-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-base-key",
+    )
+    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
+
+    # Create a BYOK provider
+    byok_provider_create = ProviderCreate(
+        name=f"test-byok-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-byok-key",
+    )
+    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
+
+    # Verify base provider has correct category
+    assert base_provider.provider_category == ProviderCategory.base
+
+    # Verify BYOK provider has correct category
+    assert byok_provider.provider_category == ProviderCategory.byok
+
+    # List only BYOK providers
+    byok_providers = await provider_manager.list_providers_async(
+        actor=default_user,
+        provider_category=[ProviderCategory.byok],
+    )
+    byok_names = [p.name for p in byok_providers]
+
+    assert f"test-byok-{test_id}" in byok_names
+    assert f"test-base-{test_id}" not in byok_names
+
+    # List only base providers
+    base_providers = await provider_manager.list_providers_async(
+        actor=default_user,
+        provider_category=[ProviderCategory.base],
+    )
+    base_names = [p.name for p in base_providers]
+
+    assert f"test-base-{test_id}" in base_names
+    assert f"test-byok-{test_id}" not in base_names
+
+
+@pytest.mark.asyncio
+async def test_base_provider_api_key_not_stored_in_db(default_user, provider_manager):
+    """Test that sync_base_providers does NOT store API keys for base providers."""
+    # Create base providers with API keys
+    base_providers = [
+        OpenAIProvider(name="test-openai-no-key", api_key="sk-should-not-be-stored"),
+    ]
+
+    # Sync to database
+    await provider_manager.sync_base_providers(base_providers=base_providers, actor=default_user)
+
+    # Retrieve the provider from database
+    providers = await provider_manager.list_providers_async(name="test-openai-no-key", actor=default_user)
+    assert len(providers) == 1
+
+    provider = providers[0]
+    assert provider.provider_category == ProviderCategory.base
+
+    # The API key should be empty (not stored) for base providers
+    if provider.api_key_enc:
+        api_key = await provider.api_key_enc.get_plaintext_async()
+        assert api_key == "" or api_key is None, "Base provider API key should not be stored in database"
+
+
+@pytest.mark.asyncio
+async def test_byok_provider_api_key_stored_in_db(default_user, provider_manager):
+    """Test that BYOK providers DO have their API keys stored in the database."""
+    test_id = generate_test_id()
+
+    # Create a BYOK provider with API key
+    byok_provider_create = ProviderCreate(
+        name=f"test-byok-with-key-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-byok-should-be-stored",
+    )
+    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
+
+    # Retrieve the provider from database
+    providers = await provider_manager.list_providers_async(name=f"test-byok-with-key-{test_id}", actor=default_user)
+    assert len(providers) == 1
+
+    provider = providers[0]
+    assert provider.provider_category == ProviderCategory.byok
+
+    # The API key SHOULD be stored for BYOK providers
+    assert provider.api_key_enc is not None
+    api_key = await provider.api_key_enc.get_plaintext_async()
+    assert api_key == "sk-byok-should-be-stored", "BYOK provider API key should be stored in database"
+
+
+@pytest.mark.asyncio
+async def test_server_list_llm_models_base_from_db(default_user, provider_manager):
+    """Test that server.list_llm_models_async fetches base models from database."""
+    from letta.server.server import SyncServer
+
+    test_id = generate_test_id()
+
+    # Create base provider and models (these ARE stored in DB)
+    base_provider_create = ProviderCreate(
+        name=f"test-base-llm-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-base-key",
+    )
+    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
+
+    base_llm_model = LLMConfig(
+        model=f"base-gpt-4o-{test_id}",
+        model_endpoint_type="openai",
+        model_endpoint="https://api.openai.com/v1",
+        context_window=128000,
+        handle=f"test-base-llm-{test_id}/gpt-4o",
+        provider_name=base_provider.name,
+        provider_category=ProviderCategory.base,
+    )
+
+    await provider_manager.sync_provider_models_async(
+        provider=base_provider,
+        llm_models=[base_llm_model],
+        embedding_models=[],
+        organization_id=None,
+    )
+
+    # Create server and list models
+    server = SyncServer(init_with_default_org_and_user=False)
+    server.default_user = default_user
+    server.provider_manager = provider_manager
+    server._enabled_providers = []  # Clear to test database-backed listing
+
+    # List all models - base models come from DB
+    all_models = await server.list_llm_models_async(actor=default_user)
+    all_handles = [m.handle for m in all_models]
+
+    assert f"test-base-llm-{test_id}/gpt-4o" in all_handles, "Base model should be in list"
+
+    # List only base models
+    base_models = await server.list_llm_models_async(
+        actor=default_user,
+        provider_category=[ProviderCategory.base],
+    )
+    base_handles = [m.handle for m in base_models]
+
+    assert f"test-base-llm-{test_id}/gpt-4o" in base_handles
+
+
+@pytest.mark.asyncio
+async def test_server_list_llm_models_byok_from_provider_api(default_user, provider_manager):
+    """Test that server.list_llm_models_async fetches BYOK models from provider API, not DB.
+
+    Note: BYOK models are fetched by calling the provider's list_llm_models_async() method,
+    which hits the actual provider API. This test uses mocking to verify that flow.
+    """
+    from letta.schemas.providers import Provider
+    from letta.server.server import SyncServer
+
+    test_id = generate_test_id()
+
+    # Create a BYOK provider (but don't sync models to DB - they come from API)
+    byok_provider_create = ProviderCreate(
+        name=f"test-byok-llm-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-byok-key",
+    )
+    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)
+
+    # Create server
+    server = SyncServer(init_with_default_org_and_user=False)
+    server.default_user = default_user
+    server.provider_manager = provider_manager
+    server._enabled_providers = []
+
+    # Mock the BYOK provider's list_llm_models_async to return test models
+    mock_byok_models = [
+        LLMConfig(
+            model=f"byok-gpt-4o-{test_id}",
+            model_endpoint_type="openai",
+            model_endpoint="https://custom.openai.com/v1",
+            context_window=64000,
+            handle=f"test-byok-llm-{test_id}/gpt-4o",
+            provider_name=byok_provider.name,
+            provider_category=ProviderCategory.byok,
+        )
+    ]
+
+    # Create a mock typed provider that returns our test models
+    mock_typed_provider = MagicMock()
+    mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_byok_models)
+
+    # Patch cast_to_subtype on the Provider class to return our mock
+    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
+        # List BYOK models - should call provider API via cast_to_subtype().list_llm_models_async()
+        byok_models = await server.list_llm_models_async(
+            actor=default_user,
+            provider_category=[ProviderCategory.byok],
+        )
+
+        # Verify the mock was called (proving we hit provider API, not DB)
+        mock_typed_provider.list_llm_models_async.assert_called()
+
+        # Verify we got the mocked models back
+        byok_handles = [m.handle for m in byok_models]
+        assert f"test-byok-llm-{test_id}/gpt-4o" in byok_handles
+
+
+@pytest.mark.asyncio
+async def test_server_list_embedding_models_base_from_db(default_user, provider_manager):
+    """Test that server.list_embedding_models_async fetches base models from database.
+
+    Note: Similar to LLM models, base embedding models are stored in DB while BYOK
+    embedding models would be fetched from provider API.
+    """
+    from letta.server.server import SyncServer
+
+    test_id = generate_test_id()
+
+    # Create base provider and embedding models (these ARE stored in DB)
+    base_provider_create = ProviderCreate(
+        name=f"test-base-embed-{test_id}",
+        provider_type=ProviderType.openai,
+        api_key="sk-base-key",
+    )
+    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
+
+    base_embedding_model = EmbeddingConfig(
+        embedding_model=f"base-text-embedding-{test_id}",
+        embedding_endpoint_type="openai",
+        embedding_endpoint="https://api.openai.com/v1",
+        embedding_dim=1536,
+        embedding_chunk_size=300,
+        handle=f"test-base-embed-{test_id}/text-embedding-3-small",
+    )
+
+    await provider_manager.sync_provider_models_async(
+        provider=base_provider,
+        llm_models=[],
+        embedding_models=[base_embedding_model],
+        organization_id=None,
+    )
+
+    # Create server and list models
+    server = SyncServer(init_with_default_org_and_user=False)
+    server.default_user = default_user
+    server.provider_manager = provider_manager
+    server._enabled_providers = []
+
+    # List all embedding models - base models come from DB
+    all_models = await server.list_embedding_models_async(actor=default_user)
+    all_handles = [m.handle for m in all_models]
+
+    assert f"test-base-embed-{test_id}/text-embedding-3-small" in all_handles
+
+
+@pytest.mark.asyncio
+async def test_provider_ordering_matches_constants(default_user, provider_manager):
+    """Test that provider ordering in model listing matches PROVIDER_ORDER in constants."""
+    from letta.constants import PROVIDER_ORDER
+    from letta.server.server import SyncServer
+
+    test_id = generate_test_id()
+
+    # Create providers with different names that should have different ordering
+    providers_to_create = [
+        ("zai", ProviderType.zai, 14),  # Lower priority
+        ("openai", ProviderType.openai, 1),  # Higher priority
+        ("anthropic", ProviderType.anthropic, 2),  # Medium priority
+    ]
+
+    created_providers = []
+    for name_suffix, provider_type, expected_order in providers_to_create:
+        provider_create = ProviderCreate(
+            name=f"{name_suffix}",  # Use actual provider name for ordering
+            provider_type=provider_type,
+            api_key=f"sk-{name_suffix}-key",
+        )
+        # Check if provider already exists
+        existing = await provider_manager.list_providers_async(name=name_suffix, actor=default_user)
+        if not existing:
+            provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=False)
+            created_providers.append((provider, expected_order))
+
+            # Create a model for this provider
+            llm_model = LLMConfig(
+                model=f"test-model-{test_id}",
+                model_endpoint_type="openai",
+                model_endpoint="https://api.example.com/v1",
+                context_window=8192,
+                handle=f"{name_suffix}/test-model-{test_id}",
+                provider_name=provider.name,
+                provider_category=ProviderCategory.base,
+            )
+
+            await provider_manager.sync_provider_models_async(
+                provider=provider,
+                llm_models=[llm_model],
+                embedding_models=[],
+                organization_id=None,
+            )
+
+    # Verify PROVIDER_ORDER has expected values
+    assert PROVIDER_ORDER.get("openai") == 1
+    assert PROVIDER_ORDER.get("anthropic") == 2
+    assert PROVIDER_ORDER.get("zai") == 14
+
+    # Create server and verify ordering
+    server = SyncServer(init_with_default_org_and_user=False)
+    server.default_user = default_user
+    server.provider_manager = provider_manager
+    server._enabled_providers = []
+
+    # List models and check ordering
+    all_models = await server.list_llm_models_async(actor=default_user)
+
+    # Filter to only our test models
+    test_models = [m for m in all_models if f"test-model-{test_id}" in m.handle]
+
+    if len(test_models) >= 2:
+        # Verify models are sorted by provider order
+        provider_names_in_order = [m.provider_name for m in test_models]
+
+        # Get the indices in PROVIDER_ORDER
+        indices = [PROVIDER_ORDER.get(name, 999) for name in provider_names_in_order]
+
+        # Verify the list is sorted by provider order
+        assert indices == sorted(indices), f"Models should be sorted by PROVIDER_ORDER, got: {provider_names_in_order}"