Revert "feat: enable provider models persistence" (#6590)
Revert "feat: enable provider models persistence (#6193)" This reverts commit 9682aff32640a6ee8cf71a6f18c9fa7cda25c40e.
This commit is contained in:
committed by
Caren Thomas
parent
bbd52e291c
commit
8440e319e2
@@ -220,38 +220,24 @@ class LLMClientBase:
     def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
         """
         Returns the override key for the given llm config.
-        For both base and BYOK providers, fetch the API key from the database.
         """
         api_key = None
-        # Fetch API key from database for both base and BYOK providers
-        # This ensures that base providers (from environment) also have their keys persisted and accessible
-        if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
+        if llm_config.provider_category == ProviderCategory.byok:
             from letta.services.provider_manager import ProviderManager

             api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
-            # If we got an empty string from the database (e.g., Letta provider), treat it as None
-            # so the client can fall back to environment variables or default behavior
-            if api_key == "":
-                api_key = None

         return api_key, None, None

     async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
         """
         Returns the override key for the given llm config.
-        For both base and BYOK providers, fetch the API key from the database.
         """
         api_key = None
-        # Fetch API key from database for both base and BYOK providers
-        # This ensures that base providers (from environment) also have their keys persisted and accessible
-        if llm_config.provider_category in [ProviderCategory.byok, ProviderCategory.base]:
+        if llm_config.provider_category == ProviderCategory.byok:
             from letta.services.provider_manager import ProviderManager

             api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
-            # If we got an empty string from the database (e.g., Letta provider), treat it as None
-            # so the client can fall back to environment variables or default behavior
-            if api_key == "":
-                api_key = None

         return api_key, None, None
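After the revert, only BYOK providers resolve their API key through the database; base providers go back to reading credentials from the environment inside each client. A minimal sketch of the restored dispatch, assuming a hypothetical resolve_api_key helper wrapped around the ProviderManager call shown above:

    import os

    from letta.schemas.enums import ProviderCategory

    def resolve_api_key(llm_config, actor):
        # BYOK keys come from the provider store; base providers fall back
        # to environment variables (OPENAI_API_KEY shown as one example).
        if llm_config.provider_category == ProviderCategory.byok:
            from letta.services.provider_manager import ProviderManager

            return ProviderManager().get_override_key(llm_config.provider_name, actor=actor)
        return os.environ.get("OPENAI_API_KEY")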
@@ -8,13 +8,10 @@ from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.providers.base import Provider

-LETTA_EMBEDDING_ENDPOINT = "https://embeddings.letta.com/"
-

 class LettaProvider(Provider):
     provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
     provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
-    base_url: str = Field(LETTA_EMBEDDING_ENDPOINT, description="Base URL for the Letta API (used for embeddings).")

     async def list_llm_models_async(self) -> list[LLMConfig]:
         return [
@@ -34,7 +31,7 @@ class LettaProvider(Provider):
         EmbeddingConfig(
             embedding_model="letta-free",  # NOTE: renamed
             embedding_endpoint_type="openai",
-            embedding_endpoint=self.base_url,
+            embedding_endpoint="https://embeddings.letta.com/",
             embedding_dim=1536,
             embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
             handle=self.get_handle("letta-free", is_embedding=True),
@@ -113,7 +113,7 @@ class Secret(BaseModel):
                 "MIGRATION_NEEDED: Reading from plaintext column instead of encrypted column. "
                 "This indicates data that hasn't been migrated to the _enc column yet. "
                 "Please run migrate data to _enc columns as plaintext columns will be deprecated.",
-                # stack_info=True,
+                stack_info=True,
             )
             return cls.from_plaintext(plaintext_value)
         return cls.from_plaintext(None)
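Re-enabling stack_info=True attaches the current call stack to the warning record, which is what makes the MIGRATION_NEEDED message actionable: it shows which code path is still reading the plaintext column. This is stock Python logging behavior; a small self-contained illustration:

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger("secret")

    def read_unmigrated_secret():
        # stack_info=True appends a formatted stack trace to the record,
        # so the log shows the caller that triggered the plaintext read.
        logger.warning("MIGRATION_NEEDED: reading plaintext column", stack_info=True)

    read_unmigrated_secret()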
@@ -109,7 +109,7 @@ from letta.services.tool_manager import ToolManager
 from letta.services.user_manager import UserManager
 from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
 from letta.streaming_interface import AgentChunkStreamingInterface
-from letta.utils import get_friendly_error_msg, get_persona_text, safe_create_task
+from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task

 config = LettaConfig.load()
 logger = get_logger(__name__)
@@ -203,10 +203,12 @@ class SyncServer(object):
         """Initialize the MCP clients (there may be multiple)"""
         self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {}

-        # collect providers (always has Letta as a default)
-        from letta.constants import LETTA_MODEL_ENDPOINT
-
-        self._enabled_providers: List[Provider] = [LettaProvider(name="letta", base_url=LETTA_MODEL_ENDPOINT)]
+        # TODO: Remove these in memory caches
+        self._llm_config_cache = {}
+        self._embedding_config_cache = {}
+
+        # collect providers (always has Letta as a default)
+        self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
         if model_settings.openai_api_key:
             self._enabled_providers.append(
                 OpenAIProvider(
@@ -340,12 +342,6 @@ class SyncServer(object):
         print(f"Default user: {self.default_user} and org: {self.default_org}")
         await self.tool_manager.upsert_base_tools_async(actor=self.default_user)

-        # Sync environment-based providers to database (idempotent, safe for multi-pod startup)
-        await self.provider_manager.sync_base_providers(base_providers=self._enabled_providers, actor=self.default_user)
-
-        # Sync provider models to database
-        await self._sync_provider_models_async()
-
         # For OSS users, create a local sandbox config
         oss_default_user = await self.user_manager.get_default_actor_async()
         use_venv = False if not tool_settings.tool_exec_venv_name else True
@@ -382,65 +378,6 @@ class SyncServer(object):
                 force_recreate=True,
             )

-    def _get_enabled_provider(self, provider_name: str) -> Optional[Provider]:
-        """Find and return an enabled provider by name.
-
-        Args:
-            provider_name: The name of the provider to find
-
-        Returns:
-            The matching enabled provider, or None if not found
-        """
-        for provider in self._enabled_providers:
-            if provider.name == provider_name:
-                return provider
-        return None
-
-    async def _sync_provider_models_async(self):
-        """Sync all provider models to database at startup."""
-        logger.info("Syncing provider models to database")
-
-        # Get persisted providers from database (they now have IDs)
-        persisted_providers = await self.provider_manager.list_providers_async(actor=self.default_user)
-
-        for persisted_provider in persisted_providers:
-            try:
-                # Find the matching enabled provider instance to call list_models on
-                enabled_provider = self._get_enabled_provider(persisted_provider.name)
-
-                if not enabled_provider:
-                    # Only delete base providers that are no longer enabled
-                    # BYOK providers are user-created and should not be automatically deleted
-                    if persisted_provider.provider_category == ProviderCategory.base:
-                        logger.info(f"Base provider {persisted_provider.name} is no longer enabled, deleting from database")
-                        try:
-                            await self.provider_manager.delete_provider_by_id_async(
-                                provider_id=persisted_provider.id, actor=self.default_user
-                            )
-                        except NoResultFound:
-                            # Provider was already deleted (race condition in multi-pod startup)
-                            logger.debug(f"Provider {persisted_provider.name} was already deleted, skipping")
-                    else:
-                        logger.debug(f"No enabled provider for BYOK provider {persisted_provider.name}, skipping model sync")
-                    continue
-
-                # Fetch models from provider
-                llm_models = await enabled_provider.list_llm_models_async()
-                embedding_models = await enabled_provider.list_embedding_models_async()
-
-                # Save to database with the persisted provider (which has an ID)
-                await self.provider_manager.sync_provider_models_async(
-                    provider=persisted_provider,
-                    llm_models=llm_models,
-                    embedding_models=embedding_models,
-                    organization_id=None,  # Global models
-                )
-                logger.info(
-                    f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
-                )
-            except Exception as e:
-                logger.error(f"Failed to sync models for provider {persisted_provider.name}: {e}", exc_info=True)
-
     async def init_mcp_clients(self):
         # TODO: remove this
         mcp_server_configs = self.get_mcp_servers()
@@ -468,6 +405,39 @@ class SyncServer(object):
         logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
         logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")

+    @trace_method
+    def get_cached_llm_config(self, actor: User, **kwargs):
+        key = make_key(**kwargs)
+        if key not in self._llm_config_cache:
+            self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
+        logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
+        return self._llm_config_cache[key]
+
+    @trace_method
+    async def get_cached_llm_config_async(self, actor: User, **kwargs):
+        key = make_key(**kwargs)
+        if key not in self._llm_config_cache:
+            self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs)
+        logger.info(f"LLM config cache size: {len(self._llm_config_cache)} entries")
+        return self._llm_config_cache[key]
+
+    @trace_method
+    def get_cached_embedding_config(self, actor: User, **kwargs):
+        key = make_key(**kwargs)
+        if key not in self._embedding_config_cache:
+            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
+        logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
+        return self._embedding_config_cache[key]
+
+    # @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs))
+    @trace_method
+    async def get_cached_embedding_config_async(self, actor: User, **kwargs):
+        key = make_key(**kwargs)
+        if key not in self._embedding_config_cache:
+            self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs)
+        logger.info(f"Embedding config cache size: {len(self._embedding_config_cache)} entries")
+        return self._embedding_config_cache[key]
+
     @trace_method
     async def create_agent_async(
         self,
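The restored get_cached_* helpers are plain dict memoization keyed by the call's kwargs. make_key comes from letta.utils and its implementation is not shown in this diff; a sorted-kwargs tuple is the usual shape, which the sketch below assumes:

    from typing import Any, Dict, Tuple

    def make_key(**kwargs) -> Tuple:
        # Assumed shape: an order-independent, hashable view of the kwargs.
        return tuple(sorted(kwargs.items()))

    _llm_config_cache: Dict[Tuple, Any] = {}

    def get_cached(fetch, **kwargs):
        key = make_key(**kwargs)
        if key not in _llm_config_cache:
            _llm_config_cache[key] = fetch(**kwargs)  # computed only on a miss
        return _llm_config_cache[key]

As the restored TODO notes, these caches are in-memory and per-process, with no invalidation.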
@@ -501,9 +471,10 @@ class SyncServer(object):
             "max_reasoning_tokens": request.max_reasoning_tokens,
             "enable_reasoner": request.enable_reasoner,
         }
-        log_event(name="start get_llm_config_from_handle", attributes=config_params)
-        request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
-        log_event(name="end get_llm_config_from_handle", attributes=config_params)
+        config_params.update(additional_config_params)
+        log_event(name="start get_cached_llm_config", attributes=config_params)
+        request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
+        log_event(name="end get_cached_llm_config", attributes=config_params)
         if request.model and isinstance(request.model, str):
             assert request.llm_config.handle == request.model, (
                 f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
@@ -533,9 +504,9 @@ class SyncServer(object):
             "handle": request.embedding,
             "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
         }
-        log_event(name="start get_embedding_config_from_handle", attributes=embedding_config_params)
-        request.embedding_config = await self.get_embedding_config_from_handle_async(actor=actor, **embedding_config_params)
-        log_event(name="end get_embedding_config_from_handle", attributes=embedding_config_params)
+        log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
+        request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params)
+        log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)

         log_event(name="start create_agent db")
         main_agent = await self.agent_manager.create_agent_async(
@@ -584,9 +555,9 @@ class SyncServer(object):
             "context_window_limit": request.context_window_limit,
             "max_tokens": request.max_tokens,
         }
-        log_event(name="start get_llm_config_from_handle", attributes=config_params)
-        request.llm_config = await self.get_llm_config_from_handle_async(actor=actor, **config_params)
-        log_event(name="end get_llm_config_from_handle", attributes=config_params)
+        log_event(name="start get_cached_llm_config", attributes=config_params)
+        request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
+        log_event(name="end get_cached_llm_config", attributes=config_params)

         # update with model_settings
         if request.model_settings is not None:
@@ -1090,85 +1061,73 @@ class SyncServer(object):
         provider_name: Optional[str] = None,
         provider_type: Optional[ProviderType] = None,
     ) -> List[LLMConfig]:
-        """List available LLM models from database cache"""
-        # Get provider IDs if filtering by provider
-        provider_ids = None
-        if provider_name or provider_type:
-            providers = await self.get_enabled_providers_async(
-                provider_category=provider_category,
-                provider_name=provider_name,
-                provider_type=provider_type,
-                actor=actor,
-            )
-            provider_ids = [p.id for p in providers]
-
-            # If filtering was requested but no providers matched, return empty list
-            if not provider_ids:
-                return []
-
-        # Get models from database
-        provider_models = await self.provider_manager.list_models_async(
-            actor=actor,
-            model_type="llm",
-            provider_id=provider_ids[0] if provider_ids and len(provider_ids) == 1 else None,
-            enabled=True,
-        )
-
-        # Build LLMConfig objects from cached data
-        # Cache providers to avoid N+1 queries
-        provider_cache: Dict[str, Provider] = {}
-        llm_models = []
-        for model in provider_models:
-            # Skip if filtering by provider and model doesn't match
-            if provider_ids and model.provider_id not in provider_ids:
-                continue
-
-            # Get provider details (with caching to avoid N+1 queries)
-            if model.provider_id not in provider_cache:
-                provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
-            provider = provider_cache[model.provider_id]
-
-            llm_config = LLMConfig(
-                model=model.name,
-                model_endpoint_type=model.model_endpoint_type,
-                model_endpoint=provider.base_url or model.model_endpoint_type,
-                context_window=model.max_context_window or 16384,
-                handle=model.handle,
-                provider_name=provider.name,
-                provider_category=provider.provider_category,
-            )
-            llm_models.append(llm_config)
-
-        return llm_models
+        """Asynchronously list available models with maximum concurrency"""
+        import asyncio
+
+        providers = await self.get_enabled_providers_async(
+            provider_category=provider_category,
+            provider_name=provider_name,
+            provider_type=provider_type,
+            actor=actor,
+        )
+
+        async def get_provider_models(provider: Provider) -> list[LLMConfig]:
+            try:
+                async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS):
+                    return await provider.list_llm_models_async()
+            except asyncio.TimeoutError:
+                logger.warning(f"Timeout while listing LLM models for provider {provider}")
+                return []
+            except Exception as e:
+                logger.exception(f"Error while listing LLM models for provider {provider}: {e}")
+                return []
+
+        # Execute all provider model listing tasks concurrently
+        provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
+
+        # Flatten the results
+        llm_models = []
+        for models in provider_results:
+            llm_models.extend(models)
+
+        # Get local configs - if this is potentially slow, consider making it async too
+        local_configs = self.get_local_llm_configs()
+        llm_models.extend(local_configs)
+
+        # dedupe by handle for uniqueness
+        # Seems like this is required from the tests?
+        seen_handles = set()
+        unique_models = []
+        for model in llm_models:
+            if model.handle not in seen_handles:
+                seen_handles.add(model.handle)
+                unique_models.append(model)
+
+        return unique_models

     async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
-        """List available embedding models from database cache"""
-        # Get models from database
-        provider_models = await self.provider_manager.list_models_async(
-            actor=actor,
-            model_type="embedding",
-            enabled=True,
-        )
-
-        # Build EmbeddingConfig objects from cached data
-        # Cache providers to avoid N+1 queries
-        provider_cache: Dict[str, Provider] = {}
-        embedding_models = []
-        for model in provider_models:
-            # Get provider details (with caching to avoid N+1 queries)
-            if model.provider_id not in provider_cache:
-                provider_cache[model.provider_id] = await self.provider_manager.get_provider_async(model.provider_id, actor)
-            provider = provider_cache[model.provider_id]
-
-            embedding_config = EmbeddingConfig(
-                embedding_model=model.name,
-                embedding_endpoint_type=model.model_endpoint_type,
-                embedding_endpoint=provider.base_url or model.model_endpoint_type,
-                embedding_dim=model.embedding_dim or 1536,  # Use model's dimension or default
-                embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
-                handle=model.handle,
-            )
-            embedding_models.append(embedding_config)
+        """Asynchronously list available embedding models with maximum concurrency"""
+        import asyncio
+
+        # Get all eligible providers first
+        providers = await self.get_enabled_providers_async(actor=actor)
+
+        # Fetch embedding models from each provider concurrently
+        async def get_provider_embedding_models(provider):
+            try:
+                # All providers now have list_embedding_models_async
+                return await provider.list_embedding_models_async()
+            except Exception as e:
+                logger.exception(f"An error occurred while listing embedding models for provider {provider}: {e}")
+                return []
+
+        # Execute all provider model listing tasks concurrently
+        provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
+
+        # Flatten the results
+        embedding_models = []
+        for models in provider_results:
+            embedding_models.extend(models)

         return embedding_models
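The restored listing path fans out one request per enabled provider, bounds each with asyncio.timeout (Python 3.11+), and treats timeouts and errors as an empty contribution, so one slow or broken provider cannot stall or fail the whole listing. A reduced, runnable sketch of the same pattern with stand-in fetchers:

    import asyncio

    async def fetch_models(name: str, delay: float) -> list[str]:
        await asyncio.sleep(delay)  # stands in for a provider API call
        return [f"{name}/model-a", f"{name}/model-b"]

    async def fetch_with_timeout(name: str, delay: float, limit: float = 1.0) -> list[str]:
        try:
            async with asyncio.timeout(limit):  # cancels the call once it exceeds the limit
                return await fetch_models(name, delay)
        except TimeoutError:
            return []  # a slow provider contributes nothing instead of blocking everyone

    async def main() -> None:
        results = await asyncio.gather(
            fetch_with_timeout("openai", 0.1),
            fetch_with_timeout("slow-provider", 5.0),
        )
        models = [m for sub in results for m in sub]
        # Dedupe by handle, keeping the first occurrence, as the restored code does.
        seen: set[str] = set()
        unique: list[str] = []
        for m in models:
            if m not in seen:
                seen.add(m)
                unique.append(m)
        print(unique)

    asyncio.run(main())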
@@ -1181,22 +1140,17 @@
     ) -> List[Provider]:
         providers = []
         if not provider_category or ProviderCategory.base in provider_category:
-            # Add enabled providers (base providers from environment)
-            enabled_providers = [p for p in self._enabled_providers]
-            providers.extend(enabled_providers)
+            providers_from_env = [p for p in self._enabled_providers]
+            providers.extend(providers_from_env)

         if not provider_category or ProviderCategory.byok in provider_category:
-            # Add persisted BYOK providers from database
-            # Note: list_providers_async returns both org-specific and global providers,
-            # so we filter to only include BYOK providers to avoid duplicating base providers
-            persisted_providers = await self.provider_manager.list_providers_async(
+            providers_from_db = await self.provider_manager.list_providers_async(
                 name=provider_name,
                 provider_type=provider_type,
                 actor=actor,
             )
-            # Filter to only BYOK providers (base providers are already in self._enabled_providers)
-            persisted_byok_providers = [p.cast_to_subtype() for p in persisted_providers if p.provider_category == ProviderCategory.byok]
-            providers.extend(persisted_byok_providers)
+            providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
+            providers.extend(providers_from_db)

         if provider_name is not None:
             providers = [p for p in providers if p.name == provider_name]
@@ -1216,19 +1170,32 @@
         max_reasoning_tokens: Optional[int] = None,
         enable_reasoner: Optional[bool] = None,
     ) -> LLMConfig:
-        # Use provider_manager to get LLMConfig from handle
         try:
-            llm_config = await self.provider_manager.get_llm_config_from_handle(
-                handle=handle,
-                actor=actor,
-            )
-        except Exception as e:
-            # Convert to HandleNotFoundError for backwards compatibility
-            from letta.orm.errors import NoResultFound
+            provider_name, model_name = handle.split("/", 1)
+            provider = await self.get_provider_from_name_async(provider_name, actor)

-            if isinstance(e, NoResultFound):
-                raise HandleNotFoundError(handle, [])
-            raise
+            all_llm_configs = await provider.list_llm_models_async()
+            llm_configs = [config for config in all_llm_configs if config.handle == handle]
+            if not llm_configs:
+                llm_configs = [config for config in all_llm_configs if config.model == model_name]
+            if not llm_configs:
+                available_handles = [config.handle for config in all_llm_configs]
+                raise HandleNotFoundError(handle, available_handles)
+        except ValueError as e:
+            llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
+            if not llm_configs:
+                llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name]
+            if not llm_configs:
+                raise e
+
+        if len(llm_configs) == 1:
+            llm_config = llm_configs[0]
+        elif len(llm_configs) > 1:
+            raise LettaInvalidArgumentError(
+                f"Multiple LLM models with name {model_name} supported by {provider_name}", argument_name="model_name"
+            )
+        else:
+            llm_config = llm_configs[0]

         if context_window_limit is not None:
             if context_window_limit > llm_config.context_window:
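Handles are provider_name/model_name strings, so the restored resolution splits once, lists the named provider's models, and matches first on the full handle and then on the bare model name. A condensed sketch of that lookup order with stand-in data; LookupError and ValueError stand in for HandleNotFoundError and LettaInvalidArgumentError:

    from dataclasses import dataclass

    @dataclass
    class ModelConfig:
        handle: str
        model: str

    def resolve(handle: str, configs: list[ModelConfig]) -> ModelConfig:
        provider_name, model_name = handle.split("/", 1)
        # Exact handle match first, then fall back to the model name alone.
        matches = [c for c in configs if c.handle == handle]
        if not matches:
            matches = [c for c in configs if c.model == model_name]
        if not matches:
            raise LookupError(f"{handle!r} not found under provider {provider_name!r}")
        if len(matches) > 1:
            raise ValueError(f"multiple models named {model_name!r}")
        return matches[0]

    configs = [ModelConfig("openai/gpt-4", "gpt-4"), ModelConfig("openai/gpt-4-turbo", "gpt-4-turbo")]
    assert resolve("openai/gpt-4", configs).model == "gpt-4"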
@@ -1260,22 +1227,33 @@
     async def get_embedding_config_from_handle_async(
         self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
     ) -> EmbeddingConfig:
-        # Use provider_manager to get EmbeddingConfig from handle
         try:
-            embedding_config = await self.provider_manager.get_embedding_config_from_handle(
-                handle=handle,
-                actor=actor,
-            )
-        except Exception as e:
-            # Convert to LettaInvalidArgumentError for backwards compatibility
-            from letta.orm.errors import NoResultFound
-
-            if isinstance(e, NoResultFound):
-                raise LettaInvalidArgumentError(f"Embedding model {handle} not found", argument_name="handle")
-            raise
-
-        # Override chunk size if provided
-        embedding_config.embedding_chunk_size = embedding_chunk_size
+            provider_name, model_name = handle.split("/", 1)
+            provider = await self.get_provider_from_name_async(provider_name, actor)
+
+            all_embedding_configs = await provider.list_embedding_models_async()
+            embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
+            if not embedding_configs:
+                raise LettaInvalidArgumentError(
+                    f"Embedding model {model_name} is not supported by {provider_name}", argument_name="model_name"
+                )
+        except LettaInvalidArgumentError as e:
+            # search local configs
+            embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
+            if not embedding_configs:
+                raise e
+
+        if len(embedding_configs) == 1:
+            embedding_config = embedding_configs[0]
+        elif len(embedding_configs) > 1:
+            raise LettaInvalidArgumentError(
+                f"Multiple embedding models with name {model_name} supported by {provider_name}", argument_name="model_name"
+            )
+        else:
+            embedding_config = embedding_configs[0]
+
+        if embedding_chunk_size:
+            embedding_config.embedding_chunk_size = embedding_chunk_size

         return embedding_config
@@ -1294,6 +1272,46 @@

         return provider

+    def get_local_llm_configs(self):
+        llm_models = []
+        # NOTE: deprecated
+        # try:
+        #     llm_configs_dir = os.path.expanduser("~/.letta/llm_configs")
+        #     if os.path.exists(llm_configs_dir):
+        #         for filename in os.listdir(llm_configs_dir):
+        #             if filename.endswith(".json"):
+        #                 filepath = os.path.join(llm_configs_dir, filename)
+        #                 try:
+        #                     with open(filepath, "r") as f:
+        #                         config_data = json.load(f)
+        #                         llm_config = LLMConfig(**config_data)
+        #                         llm_models.append(llm_config)
+        #                 except (json.JSONDecodeError, ValueError) as e:
+        #                     logger.warning(f"Error parsing LLM config file {filename}: {e}")
+        # except Exception as e:
+        #     logger.warning(f"Error reading LLM configs directory: {e}")
+        return llm_models
+
+    def get_local_embedding_configs(self):
+        embedding_models = []
+        # NOTE: deprecated
+        # try:
+        #     embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs")
+        #     if os.path.exists(embedding_configs_dir):
+        #         for filename in os.listdir(embedding_configs_dir):
+        #             if filename.endswith(".json"):
+        #                 filepath = os.path.join(embedding_configs_dir, filename)
+        #                 try:
+        #                     with open(filepath, "r") as f:
+        #                         config_data = json.load(f)
+        #                         embedding_config = EmbeddingConfig(**config_data)
+        #                         embedding_models.append(embedding_config)
+        #                 except (json.JSONDecodeError, ValueError) as e:
+        #                     logger.warning(f"Error parsing embedding config file {filename}: {e}")
+        # except Exception as e:
+        #     logger.warning(f"Error reading embedding configs directory: {e}")
+        return embedding_models
+
     def add_llm_model(self, request: LLMConfig) -> LLMConfig:
         """Add a new LLM model"""
@@ -154,7 +154,7 @@ class ProviderManager:
     @trace_method
     @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
     async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
-        """Delete a provider and its associated models."""
+        """Delete a provider."""
        async with db_registry.async_session() as session:
            # Clear api key field
            existing_provider = await ProviderModel.read_async(
@@ -163,15 +163,6 @@
             existing_provider.api_key = None
             await existing_provider.update_async(session, actor=actor)

-            # Soft delete all models associated with this provider
-            provider_models = await ProviderModelORM.list_async(
-                db_session=session,
-                provider_id=provider_id,
-                check_is_deleted=True,
-            )
-            for model in provider_models:
-                await model.delete_async(session, actor=actor)
-
             # Soft delete in provider table
             await existing_provider.delete_async(session, actor=actor)
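The deleted block above was the cascade half of provider deletion under the persistence feature: every ProviderModel row belonging to the provider was soft-deleted before the provider row itself. Condensed from the removed lines (same ORM calls, surrounding session handling elided):

    async def delete_provider_cascade(session, existing_provider, provider_id, actor):
        # Soft-delete child model rows first so listings never show orphans...
        provider_models = await ProviderModelORM.list_async(
            db_session=session, provider_id=provider_id, check_is_deleted=True
        )
        for model in provider_models:
            await model.delete_async(session, actor=actor)
        # ...then soft-delete the provider row itself.
        await existing_provider.delete_async(session, actor=actor)

After the revert, only the provider row is soft-deleted.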
@@ -640,11 +631,11 @@
                 await model.create_async(session)
                 logger.info(f"  ✓ Successfully created LLM model {llm_config.handle} with ID {model.id}")
             except Exception as e:
-                logger.info(f"  ✗ Failed to create LLM model {llm_config.handle}: {e}")
+                logger.error(f"  ✗ Failed to create LLM model {llm_config.handle}: {e}")
                 # Log the full error details
                 import traceback

-                logger.info(f"  Full traceback: {traceback.format_exc()}")
+                logger.error(f"  Full traceback: {traceback.format_exc()}")
                 # Roll back the session to clear the failed transaction
                 await session.rollback()
             else:
@@ -499,436 +499,3 @@ async def test_byok_provider_auto_syncs_models(provider_manager, default_user, m
     llm_config = await provider_manager.get_llm_config_from_handle(handle="my-openai-key/gpt-4o", actor=default_user)
     assert llm_config.model == "gpt-4o"
     assert llm_config.provider_name == "my-openai-key"
-
-
-# ======================================================================================================================
-# Server Startup Provider Sync Tests
-# ======================================================================================================================
-
-
-@pytest.mark.asyncio
-async def test_server_startup_syncs_base_providers(default_user, default_organization, monkeypatch):
-    """Test that server startup properly syncs base provider models from environment.
-
-    This test simulates the server startup process and verifies that:
-    1. Base providers from environment variables are synced to database
-    2. Provider models are fetched from mocked API endpoints
-    3. Models are properly persisted to the database with correct metadata
-    4. Models can be retrieved using handles
-    """
-    from unittest.mock import AsyncMock
-
-    from letta.schemas.embedding_config import EmbeddingConfig
-    from letta.schemas.llm_config import LLMConfig
-    from letta.schemas.providers import AnthropicProvider, OpenAIProvider
-    from letta.server.server import SyncServer
-
-    # Mock OpenAI API responses
-    mock_openai_models = {
-        "data": [
-            {
-                "id": "gpt-4",
-                "object": "model",
-                "created": 1687882411,
-                "owned_by": "openai",
-                "max_model_len": 8192,
-            },
-            {
-                "id": "gpt-4-turbo",
-                "object": "model",
-                "created": 1712361441,
-                "owned_by": "system",
-                "max_model_len": 128000,
-            },
-            {
-                "id": "text-embedding-ada-002",
-                "object": "model",
-                "created": 1671217299,
-                "owned_by": "openai-internal",
-            },
-            {
-                "id": "gpt-4-vision",  # Should be filtered out by OpenAI provider logic (has disallowed keyword)
-                "object": "model",
-                "created": 1698959748,
-                "owned_by": "system",
-                "max_model_len": 8192,
-            },
-        ]
-    }
-
-    # Mock Anthropic API responses
-    mock_anthropic_models = {
-        "data": [
-            {
-                "id": "claude-3-5-sonnet-20241022",
-                "type": "model",
-                "display_name": "Claude 3.5 Sonnet",
-                "created_at": "2024-10-22T00:00:00Z",
-            },
-            {
-                "id": "claude-3-opus-20240229",
-                "type": "model",
-                "display_name": "Claude 3 Opus",
-                "created_at": "2024-02-29T00:00:00Z",
-            },
-        ]
-    }
-
-    # Mock the API calls for OpenAI
-    async def mock_openai_get_model_list_async(*args, **kwargs):
-        return mock_openai_models
-
-    # Mock Anthropic models.list() response
-    from unittest.mock import MagicMock
-
-    mock_anthropic_response = MagicMock()
-    mock_anthropic_response.model_dump.return_value = mock_anthropic_models
-
-    # Mock the Anthropic AsyncAnthropic client
-    class MockAnthropicModels:
-        async def list(self):
-            return mock_anthropic_response
-
-    class MockAsyncAnthropic:
-        def __init__(self, *args, **kwargs):
-            self.models = MockAnthropicModels()
-
-    # Patch the actual API calling functions
-    monkeypatch.setattr(
-        "letta.llm_api.openai.openai_get_model_list_async",
-        mock_openai_get_model_list_async,
-    )
-    monkeypatch.setattr(
-        "anthropic.AsyncAnthropic",
-        MockAsyncAnthropic,
-    )
-
-    # Clear ALL provider-related env vars first to ensure clean state
-    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
-    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
-    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
-    monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
-    monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
-    monkeypatch.delenv("AZURE_API_KEY", raising=False)
-    monkeypatch.delenv("GROQ_API_KEY", raising=False)
-    monkeypatch.delenv("TOGETHER_API_KEY", raising=False)
-    monkeypatch.delenv("VLLM_API_BASE", raising=False)
-    monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
-    monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
-    monkeypatch.delenv("LMSTUDIO_BASE_URL", raising=False)
-    monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
-    monkeypatch.delenv("XAI_API_KEY", raising=False)
-    monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
-
-    # Set environment variables to enable only OpenAI and Anthropic
-    monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key-12345")
-    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key-67890")
-
-    # Reload model_settings to pick up new env vars
-    from letta.settings import model_settings
-
-    monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key-12345")
-    monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key-67890")
-    monkeypatch.setattr(model_settings, "gemini_api_key", None)
-    monkeypatch.setattr(model_settings, "google_cloud_project", None)
-    monkeypatch.setattr(model_settings, "google_cloud_location", None)
-    monkeypatch.setattr(model_settings, "azure_api_key", None)
-    monkeypatch.setattr(model_settings, "groq_api_key", None)
-    monkeypatch.setattr(model_settings, "together_api_key", None)
-    monkeypatch.setattr(model_settings, "vllm_api_base", None)
-    monkeypatch.setattr(model_settings, "aws_access_key_id", None)
-    monkeypatch.setattr(model_settings, "aws_secret_access_key", None)
-    monkeypatch.setattr(model_settings, "lmstudio_base_url", None)
-    monkeypatch.setattr(model_settings, "deepseek_api_key", None)
-    monkeypatch.setattr(model_settings, "xai_api_key", None)
-    monkeypatch.setattr(model_settings, "openrouter_api_key", None)
-
-    # Create server instance (this will load enabled providers from environment)
-    server = SyncServer(init_with_default_org_and_user=False)
-
-    # Manually set up the default user/org (since we disabled auto-init)
-    server.default_user = default_user
-    server.default_org = default_organization
-
-    # Verify enabled providers were loaded
-    assert len(server._enabled_providers) == 3  # Exactly: letta, openai, anthropic
-    enabled_provider_names = [p.name for p in server._enabled_providers]
-    assert "letta" in enabled_provider_names
-    assert "openai" in enabled_provider_names
-    assert "anthropic" in enabled_provider_names
-
-    # First, sync base providers to database (this is what init_async does)
-    await server.provider_manager.sync_base_providers(
-        base_providers=server._enabled_providers,
-        actor=default_user,
-    )
-
-    # Now call the actual _sync_provider_models_async method
-    # This simulates what happens during server startup
-    await server._sync_provider_models_async()
-
-    # Verify OpenAI models were synced
-    openai_providers = await server.provider_manager.list_providers_async(
-        name="openai",
-        actor=default_user,
-    )
-    assert len(openai_providers) == 1, "OpenAI provider should exist"
-    openai_provider = openai_providers[0]
-
-    # Check OpenAI LLM models
-    openai_llm_models = await server.provider_manager.list_models_async(
-        actor=default_user,
-        provider_id=openai_provider.id,
-        model_type="llm",
-    )
-
-    # Should have gpt-4 and gpt-4-turbo (gpt-4-vision filtered out due to "vision" keyword)
-    assert len(openai_llm_models) >= 2, f"Expected at least 2 OpenAI LLM models, got {len(openai_llm_models)}"
-    openai_model_names = [m.name for m in openai_llm_models]
-    assert "gpt-4" in openai_model_names
-    assert "gpt-4-turbo" in openai_model_names
-
-    # Check OpenAI embedding models
-    openai_embedding_models = await server.provider_manager.list_models_async(
-        actor=default_user,
-        provider_id=openai_provider.id,
-        model_type="embedding",
-    )
-    assert len(openai_embedding_models) >= 1, "Expected at least 1 OpenAI embedding model"
-    embedding_model_names = [m.name for m in openai_embedding_models]
-    assert "text-embedding-ada-002" in embedding_model_names
-
-    # Verify model metadata is correct
-    gpt4_models = [m for m in openai_llm_models if m.name == "gpt-4"]
-    assert len(gpt4_models) > 0, "gpt-4 model should exist"
-    gpt4_model = gpt4_models[0]
-    assert gpt4_model.handle == "openai/gpt-4"
-    assert gpt4_model.model_endpoint_type == "openai"
-    assert gpt4_model.max_context_window == 8192
-    assert gpt4_model.enabled is True
-
-    # Verify Anthropic models were synced
-    anthropic_providers = await server.provider_manager.list_providers_async(
-        name="anthropic",
-        actor=default_user,
-    )
-    assert len(anthropic_providers) == 1, "Anthropic provider should exist"
-    anthropic_provider = anthropic_providers[0]
-
-    anthropic_llm_models = await server.provider_manager.list_models_async(
-        actor=default_user,
-        provider_id=anthropic_provider.id,
-        model_type="llm",
-    )
-
-    # Should have Claude models
-    assert len(anthropic_llm_models) >= 2, f"Expected at least 2 Anthropic models, got {len(anthropic_llm_models)}"
-    anthropic_model_names = [m.name for m in anthropic_llm_models]
-    assert "claude-3-5-sonnet-20241022" in anthropic_model_names
-    assert "claude-3-opus-20240229" in anthropic_model_names
-
-    # Test that we can retrieve LLMConfig from handle
-    llm_config = await server.provider_manager.get_llm_config_from_handle(
-        handle="openai/gpt-4",
-        actor=default_user,
-    )
-    assert llm_config.model == "gpt-4"
-    assert llm_config.handle == "openai/gpt-4"
-    assert llm_config.provider_name == "openai"
-    assert llm_config.context_window == 8192
-
-    # Test that we can retrieve EmbeddingConfig from handle
-    embedding_config = await server.provider_manager.get_embedding_config_from_handle(
-        handle="openai/text-embedding-ada-002",
-        actor=default_user,
-    )
-    assert embedding_config.embedding_model == "text-embedding-ada-002"
-    assert embedding_config.handle == "openai/text-embedding-ada-002"
-    assert embedding_config.embedding_dim == 1536
-
-
-@pytest.mark.asyncio
-async def test_server_startup_handles_disabled_providers(default_user, default_organization, monkeypatch):
-    """Test that server startup properly handles providers that are no longer enabled.
-
-    This test verifies that:
-    1. Base providers that are no longer enabled (env vars removed) are deleted
-    2. BYOK providers that are no longer enabled are NOT deleted (user-created)
-    3. The sync process handles providers gracefully when API calls fail
-    """
-    from letta.schemas.providers import OpenAIProvider, ProviderCreate
-    from letta.server.server import SyncServer
-
-    # First, manually create providers in the database
-    provider_manager = ProviderManager()
-
-    # Create a base OpenAI provider (simulating it was synced before)
-    base_openai_create = ProviderCreate(
-        name="openai",
-        provider_type=ProviderType.openai,
-        api_key="sk-old-key",
-        base_url="https://api.openai.com/v1",
-    )
-    base_openai = await provider_manager.create_provider_async(
-        base_openai_create,
-        actor=default_user,
-        is_byok=False,  # This is a base provider
-    )
-
-    # Create a BYOK provider (user-created)
-    byok_provider_create = ProviderCreate(
-        name="my-custom-openai",
-        provider_type=ProviderType.openai,
-        api_key="sk-my-key",
-        base_url="https://api.openai.com/v1",
-    )
-    byok_provider = await provider_manager.create_provider_async(
-        byok_provider_create,
-        actor=default_user,
-        is_byok=True,
-    )
-    assert byok_provider.provider_category == ProviderCategory.byok
-
-    # Now create server with NO environment variables set (all base providers disabled)
-    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
-    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
-
-    from letta.settings import model_settings
-
-    monkeypatch.setattr(model_settings, "openai_api_key", None)
-    monkeypatch.setattr(model_settings, "anthropic_api_key", None)
-
-    # Create server instance
-    server = SyncServer(init_with_default_org_and_user=False)
-    server.default_user = default_user
-    server.default_org = default_organization
-
-    # Verify only letta provider is enabled (no openai)
-    enabled_names = [p.name for p in server._enabled_providers]
-    assert "letta" in enabled_names
-    assert "openai" not in enabled_names
-
-    # Sync base providers (should not include openai anymore)
-    await server.provider_manager.sync_base_providers(
-        base_providers=server._enabled_providers,
-        actor=default_user,
-    )
-
-    # Call _sync_provider_models_async
-    await server._sync_provider_models_async()
-
-    # Verify base OpenAI provider was deleted (no longer enabled)
-    try:
-        await server.provider_manager.get_provider_async(base_openai.id, actor=default_user)
-        assert False, "Base OpenAI provider should have been deleted"
-    except Exception:
-        # Expected - provider should not exist
-        pass
-
-    # Verify BYOK provider still exists (should NOT be deleted)
-    byok_still_exists = await server.provider_manager.get_provider_async(
-        byok_provider.id,
-        actor=default_user,
-    )
-    assert byok_still_exists is not None
-    assert byok_still_exists.name == "my-custom-openai"
-    assert byok_still_exists.provider_category == ProviderCategory.byok
-
-
-@pytest.mark.asyncio
-async def test_server_startup_handles_api_errors_gracefully(default_user, default_organization, monkeypatch):
-    """Test that server startup handles API errors gracefully without crashing.
-
-    This test verifies that:
-    1. If a provider's API call fails during sync, it logs an error but continues
-    2. Other providers can still sync successfully
-    3. The server startup completes without crashing
-    """
-    from letta.schemas.providers import AnthropicProvider, OpenAIProvider
-    from letta.server.server import SyncServer
-
-    # Mock OpenAI to fail
-    async def mock_openai_fail(*args, **kwargs):
-        raise Exception("OpenAI API is down")
-
-    # Mock Anthropic to succeed
-    from unittest.mock import MagicMock
-
-    mock_anthropic_response = MagicMock()
-    mock_anthropic_response.model_dump.return_value = {
-        "data": [
-            {
-                "id": "claude-3-5-sonnet-20241022",
-                "type": "model",
-                "display_name": "Claude 3.5 Sonnet",
-                "created_at": "2024-10-22T00:00:00Z",
-            }
-        ]
-    }
-
-    class MockAnthropicModels:
-        async def list(self):
-            return mock_anthropic_response
-
-    class MockAsyncAnthropic:
-        def __init__(self, *args, **kwargs):
-            self.models = MockAnthropicModels()
-
-    monkeypatch.setattr(
-        "letta.llm_api.openai.openai_get_model_list_async",
-        mock_openai_fail,
-    )
-    monkeypatch.setattr(
-        "anthropic.AsyncAnthropic",
-        MockAsyncAnthropic,
-    )
-
-    # Set environment variables
-    monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key")
-    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key")
-
-    from letta.settings import model_settings
-
-    monkeypatch.setattr(model_settings, "openai_api_key", "sk-test-key")
-    monkeypatch.setattr(model_settings, "anthropic_api_key", "sk-ant-test-key")
-
-    # Create server
-    server = SyncServer(init_with_default_org_and_user=False)
-    server.default_user = default_user
-    server.default_org = default_organization
-
-    # Sync base providers
-    await server.provider_manager.sync_base_providers(
-        base_providers=server._enabled_providers,
-        actor=default_user,
-    )
-
-    # This should NOT crash even though OpenAI fails
-    await server._sync_provider_models_async()
-
-    # Verify Anthropic still synced successfully
-    anthropic_providers = await server.provider_manager.list_providers_async(
-        name="anthropic",
-        actor=default_user,
-    )
-    assert len(anthropic_providers) == 1
-
-    anthropic_models = await server.provider_manager.list_models_async(
-        actor=default_user,
-        provider_id=anthropic_providers[0].id,
-        model_type="llm",
-    )
-    assert len(anthropic_models) >= 1, "Anthropic models should have synced despite OpenAI failure"
-
-    # OpenAI should have no models (sync failed)
-    openai_providers = await server.provider_manager.list_providers_async(
-        name="openai",
-        actor=default_user,
-    )
-    if len(openai_providers) > 0:
-        openai_models = await server.provider_manager.list_models_async(
-            actor=default_user,
-            provider_id=openai_providers[0].id,
-        )
-        # Models might exist from previous runs, but the sync attempt should have been logged as failed
-        # The key is that the server didn't crash
@@ -32,7 +32,6 @@ from letta.config import LettaConfig
 from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
 from letta.server.server import SyncServer
 from tests.helpers.utils import upload_file_and_wait
-from tests.utils import wait_for_server

 # Constants
 SERVER_PORT = 8283
@@ -107,7 +106,7 @@ def client() -> LettaSDKClient:
     print("Starting server thread")
     thread = threading.Thread(target=run_server, daemon=True)
     thread.start()
-    wait_for_server(server_url, timeout=60)
+    time.sleep(5)

     print("Running client tests with server:", server_url)
     client = LettaSDKClient(base_url=server_url)
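The fixture goes back to a fixed time.sleep(5) instead of polling for readiness. For reference, a typical readiness poll looks like the sketch below; the real tests.utils.wait_for_server may differ:

    import time
    import urllib.request

    def wait_for_server(url: str, timeout: float = 60.0, interval: float = 0.5) -> None:
        # Poll until the server answers, failing loudly after the deadline.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                urllib.request.urlopen(url, timeout=interval)
                return
            except OSError:
                time.sleep(interval)
        raise TimeoutError(f"server at {url} did not come up within {timeout}s")

A fixed sleep is simpler but either wastes time or races the server on slow machines.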
@@ -1740,110 +1740,3 @@ async def test_handle_uniqueness_per_org(default_user, provider_manager):
     assert model is not None
     assert model.provider_id == provider_1.id  # Still original provider
     assert model.max_context_window == 8192  # Still original
-
-
-@pytest.mark.asyncio
-async def test_delete_provider_cascades_to_models(default_user, provider_manager, monkeypatch):
-    """Test that deleting a provider also soft-deletes its associated models."""
-    test_id = generate_test_id()
-
-    # Mock _sync_default_models_for_provider to avoid external API calls
-    async def mock_sync(provider, actor):
-        pass  # Don't actually sync - we'll manually create models below
-
-    monkeypatch.setattr(provider_manager, "_sync_default_models_for_provider", mock_sync)
-
-    # 1. Create a BYOK provider (org-scoped, so the actor can delete it)
-    provider_create = ProviderCreate(
-        name=f"test-cascade-{test_id}",
-        provider_type=ProviderType.openai,
-        api_key="sk-test-key",
-    )
-    provider = await provider_manager.create_provider_async(provider_create, actor=default_user, is_byok=True)
-
-    # 2. Manually sync models to the provider
-    llm_models = [
-        LLMConfig(
-            model=f"gpt-4o-{test_id}",
-            model_endpoint_type="openai",
-            model_endpoint="https://api.openai.com/v1",
-            context_window=128000,
-            handle=f"test-{test_id}/gpt-4o",
-            provider_name=provider.name,
-            provider_category=ProviderCategory.byok,
-        ),
-        LLMConfig(
-            model=f"gpt-4o-mini-{test_id}",
-            model_endpoint_type="openai",
-            model_endpoint="https://api.openai.com/v1",
-            context_window=16384,
-            handle=f"test-{test_id}/gpt-4o-mini",
-            provider_name=provider.name,
-            provider_category=ProviderCategory.byok,
-        ),
-    ]
-
-    embedding_models = [
-        EmbeddingConfig(
-            embedding_model=f"text-embedding-3-small-{test_id}",
-            embedding_endpoint_type="openai",
-            embedding_endpoint="https://api.openai.com/v1",
-            embedding_dim=1536,
-            embedding_chunk_size=300,
-            handle=f"test-{test_id}/text-embedding-3-small",
-        ),
-    ]
-
-    await provider_manager.sync_provider_models_async(
-        provider=provider,
-        llm_models=llm_models,
-        embedding_models=embedding_models,
-        organization_id=default_user.organization_id,  # Org-scoped for BYOK provider
-    )
-
-    # 3. Verify models exist before deletion
-    llm_models_before = await provider_manager.list_models_async(
-        actor=default_user,
-        model_type="llm",
-        provider_id=provider.id,
-    )
-    embedding_models_before = await provider_manager.list_models_async(
-        actor=default_user,
-        model_type="embedding",
-        provider_id=provider.id,
-    )
-
-    llm_handles_before = {m.handle for m in llm_models_before}
-    embedding_handles_before = {m.handle for m in embedding_models_before}
-
-    assert f"test-{test_id}/gpt-4o" in llm_handles_before
-    assert f"test-{test_id}/gpt-4o-mini" in llm_handles_before
-    assert f"test-{test_id}/text-embedding-3-small" in embedding_handles_before
-
-    # 4. Delete the provider
-    await provider_manager.delete_provider_by_id_async(provider.id, actor=default_user)
-
-    # 5. Verify models are soft-deleted (no longer returned in list)
-    all_llm_models_after = await provider_manager.list_models_async(
-        actor=default_user,
-        model_type="llm",
-    )
-    all_embedding_models_after = await provider_manager.list_models_async(
-        actor=default_user,
-        model_type="embedding",
-    )
-
-    all_llm_handles_after = {m.handle for m in all_llm_models_after}
-    all_embedding_handles_after = {m.handle for m in all_embedding_models_after}
-
-    # All models from the deleted provider should be gone
-    assert f"test-{test_id}/gpt-4o" not in all_llm_handles_after
-    assert f"test-{test_id}/gpt-4o-mini" not in all_llm_handles_after
-    assert f"test-{test_id}/text-embedding-3-small" not in all_embedding_handles_after
-
-    # 6. Verify provider is also deleted
-    providers_after = await provider_manager.list_providers_async(
-        actor=default_user,
-        name=f"test-cascade-{test_id}",
-    )
-    assert len(providers_after) == 0