Files
letta-server/letta/services/provider_manager.py

1072 lines
50 KiB
Python

from typing import List, Optional, Tuple, Union
from sqlalchemy import and_, select
from letta.log import get_logger
from letta.orm.errors import UniqueConstraintViolationError
from letta.orm.provider import Provider as ProviderModel
from letta.orm.provider_model import ProviderModel as ProviderModelORM
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.provider_model import ProviderModel as PydanticProviderModel
from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate
from letta.schemas.secret import Secret
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
class ProviderManager:
@enforce_types
@trace_method
async def create_provider_async(self, request: ProviderCreate, actor: PydanticUser, is_byok: bool = True) -> PydanticProvider:
    """Create a new provider if it doesn't already exist.

    If a soft-deleted provider with the same name/scope exists, it is restored
    (and its soft-deleted models un-deleted) instead of inserting a new row,
    because the DB unique constraint does not account for is_deleted.

    Args:
        request: ProviderCreate object with provider details
        actor: User creating the provider
        is_byok: If True, creates a BYOK provider (default). If False, creates a base provider.

    Returns:
        The created (or restored) provider as a Pydantic model.

    Raises:
        ValueError: If the requested name collides with an existing base provider.
    """
    async with db_registry.async_session() as session:
        # Check for name conflicts
        if is_byok:
            # BYOK providers cannot use the same name as base providers
            existing_base_providers = await ProviderModel.list_async(
                db_session=session,
                name=request.name,
                organization_id=None,  # Base providers have NULL organization_id
                limit=1,
            )
            if existing_base_providers:
                raise ValueError(
                    f"Provider name '{request.name}' conflicts with an existing base provider. Please choose a different name."
                )
        else:
            # Base providers must have unique names among themselves
            # (the DB constraint won't catch this because NULL != NULL)
            existing_base_providers = await ProviderModel.list_async(
                db_session=session,
                name=request.name,
                organization_id=None,  # Base providers have NULL organization_id
                limit=1,
            )
            if existing_base_providers:
                raise ValueError(f"Base provider name '{request.name}' already exists. Please choose a different name.")

        # Check if there's a soft-deleted provider with the same name that we can restore
        org_id = actor.organization_id if is_byok else None
        if org_id is not None:
            stmt = select(ProviderModel).where(
                and_(
                    ProviderModel.name == request.name,
                    ProviderModel.organization_id == org_id,
                    ProviderModel.is_deleted == True,
                )
            )
        else:
            # NULL organization ids need an explicit IS NULL comparison
            stmt = select(ProviderModel).where(
                and_(
                    ProviderModel.name == request.name,
                    ProviderModel.organization_id.is_(None),
                    ProviderModel.is_deleted == True,
                )
            )
        result = await session.execute(stmt)
        deleted_provider = result.scalar_one_or_none()

        if deleted_provider:
            # Restore the soft-deleted provider and update its fields
            logger.info(f"Restoring soft-deleted provider '{request.name}' with id: {deleted_provider.id}")
            deleted_provider.is_deleted = False
            deleted_provider.provider_type = request.provider_type
            deleted_provider.provider_category = ProviderCategory.byok if is_byok else ProviderCategory.base
            deleted_provider.base_url = request.base_url
            deleted_provider.region = request.region
            deleted_provider.api_version = request.api_version
            # Update encrypted fields (async to avoid blocking event loop)
            if request.api_key is not None:
                api_key_secret = await Secret.from_plaintext_async(request.api_key)
                deleted_provider.api_key_enc = api_key_secret.get_encrypted()
            if request.access_key is not None:
                access_key_secret = await Secret.from_plaintext_async(request.access_key)
                deleted_provider.access_key_enc = access_key_secret.get_encrypted()
            await deleted_provider.update_async(session, actor=actor)

            # Also restore any soft-deleted models associated with this provider.
            # This is needed because the unique constraint on provider_models doesn't include is_deleted,
            # so soft-deleted models would block creation of new models with the same handle.
            from sqlalchemy import update

            restore_models_stmt = (
                update(ProviderModelORM)
                .where(
                    and_(
                        ProviderModelORM.provider_id == deleted_provider.id,
                        ProviderModelORM.is_deleted == True,
                    )
                )
                .values(is_deleted=False)
            )
            result = await session.execute(restore_models_stmt)
            if result.rowcount > 0:
                logger.info(f"Restored {result.rowcount} soft-deleted model(s) for provider '{request.name}'")

            # Commit the provider and model restoration before syncing.
            # This is needed because _sync_default_models_for_provider opens a new session
            # that can't see uncommitted changes from this session.
            await session.commit()
            provider_pydantic = deleted_provider.to_pydantic()

            # For BYOK providers, automatically sync available models.
            # This will add any new models and remove any that are no longer available.
            if is_byok:
                await self._sync_default_models_for_provider(provider_pydantic, actor)
            return provider_pydantic

        # Create provider with the appropriate category
        provider_data = request.model_dump()
        # Unset deprecated api_key and access_key as to not write plaintext values, api_key_enc and access_key_enc will be set below
        provider_data.pop("api_key", None)
        provider_data.pop("access_key", None)
        provider_data["provider_category"] = ProviderCategory.byok if is_byok else ProviderCategory.base
        provider = PydanticProvider(**provider_data)
        # if provider.name == provider.provider_type.value:
        #     raise ValueError("Provider name must be unique and different from provider type")

        # Fill in schema-default base_url if not provided.
        # This ensures providers like ZAI get their default endpoint persisted to DB
        # rather than relying on cast_to_subtype() at read time.
        if provider.base_url is None:
            typed_provider = provider.cast_to_subtype()
            if typed_provider.base_url is not None:
                provider.base_url = typed_provider.base_url

        # Only assign organization id for non-base providers.
        # Base providers should be globally accessible (org_id = None).
        if is_byok:
            provider.organization_id = actor.organization_id
        # Lazily create the provider id prior to persistence
        provider.resolve_identifier()

        # Explicitly populate encrypted fields from plaintext (async to avoid blocking event loop)
        if request.api_key is not None:
            provider.api_key_enc = await Secret.from_plaintext_async(request.api_key)
        if request.access_key is not None:
            provider.access_key_enc = await Secret.from_plaintext_async(request.access_key)

        new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
        await new_provider.create_async(session, actor=actor)
        provider_pydantic = new_provider.to_pydantic()

        # For BYOK providers, automatically sync available models
        if is_byok:
            await self._sync_default_models_for_provider(provider_pydantic, actor)
        return provider_pydantic
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@trace_method
async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
    """Update provider details, re-encrypting credential fields only when their value changed."""

    async def _maybe_reencrypt(current_enc, new_plaintext):
        """Return fresh ciphertext for new_plaintext, or None if it matches the stored value."""
        current_plaintext = None
        if current_enc:
            current_plaintext = await Secret.from_encrypted(current_enc).get_plaintext_async()
        if current_plaintext == new_plaintext:
            return None
        fresh_secret = await Secret.from_plaintext_async(new_plaintext)
        return fresh_secret.get_encrypted()

    async with db_registry.async_session() as session:
        provider_row = await ProviderModel.read_async(
            db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
        )
        # Only the fields explicitly supplied in the update are applied.
        changes = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)

        # Credential fields are handled out-of-band: encrypt asynchronously (so the
        # event loop is never blocked) and write the ciphertext column directly.
        if changes.get("api_key") is not None:
            fresh_enc = await _maybe_reencrypt(provider_row.api_key_enc, changes["api_key"])
            if fresh_enc is not None:
                provider_row.api_key_enc = fresh_enc
        # Never setattr plaintext/ciphertext key fields through the generic path.
        changes.pop("api_key", None)
        changes.pop("api_key_enc", None)

        if changes.get("access_key") is not None:
            fresh_enc = await _maybe_reencrypt(provider_row.access_key_enc, changes["access_key"])
            if fresh_enc is not None:
                provider_row.access_key_enc = fresh_enc
        changes.pop("access_key", None)
        changes.pop("access_key_enc", None)

        # Remaining fields map 1:1 onto ORM columns.
        for field_name, field_value in changes.items():
            setattr(provider_row, field_name, field_value)

        await provider_row.update_async(session, actor=actor)
        return provider_row.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
async def update_provider_last_synced_async(self, provider_id: str, actor: Optional[PydanticUser] = None) -> None:
    """Stamp a provider's last_synced column with the current UTC time.

    Note: actor is optional to support system-level operations (e.g., during server
    initialization for global providers). When actor is provided, org-scoping is enforced.
    """
    from datetime import datetime, timezone

    async with db_registry.async_session() as session:
        provider_row = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
        provider_row.last_synced = datetime.now(timezone.utc)
        await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@trace_method
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
    """Soft-delete a provider and all of its models, scrubbing stored credentials first."""
    async with db_registry.async_session() as session:
        provider_row = await ProviderModel.read_async(
            db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
        )
        # Scrub every credential column. The deprecated plaintext fields are only
        # touched here to clear them, which may trigger a deprecation warning.
        provider_row.api_key_enc = None
        provider_row.access_key_enc = None
        provider_row.api_key = None
        provider_row.access_key = None
        logger.info("Soft deleting provider with id: %s", provider_id)
        await provider_row.update_async(session, actor=actor)

        # Cascade the soft delete to every model belonging to this provider.
        associated_models = await ProviderModelORM.list_async(
            db_session=session,
            provider_id=provider_id,
            check_is_deleted=True,
        )
        for associated_model in associated_models:
            await associated_model.delete_async(session, actor=actor)

        # Finally soft-delete the provider row itself; the session context
        # manager handles the commit.
        await provider_row.delete_async(session, actor=actor)
@enforce_types
@trace_method
async def list_providers_async(
    self,
    actor: PydanticUser,
    name: Optional[str] = None,
    provider_type: Optional[ProviderType] = None,
    provider_category: Optional[List[ProviderCategory]] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
    ascending: bool = False,
) -> List[PydanticProvider]:
    """
    List all providers with pagination support.

    Returns both global providers (organization_id=NULL) and organization-specific
    providers, optionally narrowed by name, type, and category.
    """
    base_filters = {}
    if name:
        base_filters["name"] = name
    if provider_type:
        base_filters["provider_type"] = provider_type

    async with db_registry.async_session() as session:
        # Providers owned by the actor's organization (BYOK).
        org_scoped = await ProviderModel.list_async(
            db_session=session,
            before=before,
            after=after,
            limit=limit,
            actor=actor,
            ascending=ascending,
            check_is_deleted=True,
            **base_filters,
        )
        # Globally shared base providers (organization_id IS NULL).
        global_scoped = await ProviderModel.list_async(
            db_session=session,
            before=before,
            after=after,
            limit=limit,
            ascending=ascending,
            check_is_deleted=True,
            **{**base_filters, "organization_id": None},
        )

        # No category filter means "everything"; otherwise include each bucket
        # only when its category was requested.
        combined = []
        if not provider_category:
            combined = org_scoped + global_scoped
        else:
            if ProviderCategory.byok in provider_category:
                combined += org_scoped
            if ProviderCategory.base in provider_category:
                combined += global_scoped

        # Blank out the deprecated plaintext credential fields before serializing.
        for row in combined:
            row.api_key = None
            row.access_key = None
        return [row.to_pydantic() for row in combined]
@enforce_types
@trace_method
def list_providers(
    self,
    actor: PydanticUser,
    name: Optional[str] = None,
    provider_type: Optional[ProviderType] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
    ascending: bool = False,
) -> List[PydanticProvider]:
    """
    List all providers with pagination support (synchronous version).

    Returns both global providers (organization_id=NULL) and organization-specific providers.
    """
    base_filters = {}
    if name:
        base_filters["name"] = name
    if provider_type:
        base_filters["provider_type"] = provider_type

    with db_registry.get_session() as session:
        # Providers owned by the actor's organization.
        org_scoped = ProviderModel.list(
            db_session=session,
            before=before,
            after=after,
            limit=limit,
            actor=actor,
            ascending=ascending,
            check_is_deleted=True,
            **base_filters,
        )
        # Globally shared base providers (organization_id IS NULL).
        global_scoped = ProviderModel.list(
            db_session=session,
            before=before,
            after=after,
            limit=limit,
            ascending=ascending,
            check_is_deleted=True,
            **{**base_filters, "organization_id": None},
        )
        return [row.to_pydantic() for row in org_scoped + global_scoped]
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@trace_method
async def get_provider_async(self, provider_id: str, actor: PydanticUser) -> PydanticProvider:
    """Fetch a provider by id: org-scoped lookup first, then global fallback.

    Args:
        provider_id: The provider identifier (prefix-validated by the decorator).
        actor: User performing the lookup (used for org-scoping).

    Returns:
        The provider as a Pydantic model, with deprecated plaintext key fields cleared
        on the global-provider path.

    Raises:
        NoResultFound: If neither an org-scoped nor a global provider matches.
    """
    from letta.orm.errors import NoResultFound

    async with db_registry.async_session() as session:
        # First try to get as organization-specific provider
        try:
            provider_model = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
            return provider_model.to_pydantic()
        except NoResultFound:
            # Was a bare `except:` — that also swallowed CancelledError and
            # programming errors; only "not found" should trigger the fallback.
            # Fall back to a global provider (organization_id IS NULL).
            stmt = select(ProviderModel).where(
                ProviderModel.id == provider_id,
                ProviderModel.organization_id.is_(None),
                ProviderModel.is_deleted == False,
            )
            result = await session.execute(stmt)
            provider_model = result.scalar_one_or_none()
            if provider_model is None:
                raise NoResultFound(f"Provider not found with id='{provider_id}'")
            # Remove deprecated api_key and access_key fields from the response
            provider_model.api_key = None
            provider_model.access_key = None
            return provider_model.to_pydantic()
@enforce_types
@trace_method
def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
    """Resolve a provider name to its id, or None when no such provider is visible to the actor."""
    matches = self.list_providers(name=provider_name, actor=actor)
    if not matches:
        return None
    return matches[0].id
@enforce_types
@trace_method
def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
    """Return the decrypted API key for the named provider, or None if unavailable."""
    matches = self.list_providers(name=provider_name, actor=actor)
    if not matches:
        return None
    encrypted_key = matches[0].api_key_enc
    if not encrypted_key:
        return None
    # Decrypt only at the point of use.
    return encrypted_key.get_plaintext()
@enforce_types
@trace_method
async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
    """Return the decrypted API key for the named provider, or None if unavailable (async)."""
    matches = await self.list_providers_async(name=provider_name, actor=actor)
    if not matches:
        return None
    encrypted_key = matches[0].api_key_enc
    if not encrypted_key:
        return None
    # Decrypt only at the point of use, without blocking the event loop.
    return await encrypted_key.get_plaintext_async()
@enforce_types
@trace_method
async def get_bedrock_credentials_async(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Return (access_key, secret_key, region) for the named provider, decrypting stored secrets."""
    matches = await self.list_providers_async(name=provider_name, actor=actor)
    if not matches:
        return None, None, None
    record = matches[0]
    access_key = await record.access_key_enc.get_plaintext_async() if record.access_key_enc else None
    # The api_key_enc column holds what callers treat as the secret key here.
    secret_key = await record.api_key_enc.get_plaintext_async() if record.api_key_enc else None
    return access_key, secret_key, record.region
@enforce_types
@trace_method
def get_azure_credentials(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Return (api_key, base_url, api_version) for the named provider, decrypting the key."""
    matches = self.list_providers(name=provider_name, actor=actor)
    if not matches:
        return None, None, None
    record = matches[0]
    decrypted_key = record.api_key_enc.get_plaintext() if record.api_key_enc else None
    return decrypted_key, record.base_url, record.api_version
@enforce_types
@trace_method
async def get_azure_credentials_async(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Return (api_key, base_url, api_version) for the named provider, decrypting the key (async)."""
    matches = await self.list_providers_async(name=provider_name, actor=actor)
    if not matches:
        return None, None, None
    record = matches[0]
    decrypted_key = await record.api_key_enc.get_plaintext_async() if record.api_key_enc else None
    return decrypted_key, record.base_url, record.api_version
@enforce_types
@trace_method
async def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
    """Validate a candidate provider's credentials by hitting the provider's API.

    Args:
        provider_check: Candidate credentials and connection settings to verify.

    Raises:
        ValueError: If no API key was supplied.
    """
    # Encrypt asynchronously so the event loop is not blocked — the sync
    # Secret.from_plaintext was inconsistent with the from_plaintext_async
    # usage everywhere else in this manager.
    api_key_secret = await Secret.from_plaintext_async(provider_check.api_key)
    access_key_secret = await Secret.from_plaintext_async(provider_check.access_key) if provider_check.access_key else None
    provider = PydanticProvider(
        name=provider_check.provider_type.value,
        provider_type=provider_check.provider_type,
        api_key_enc=api_key_secret,
        provider_category=ProviderCategory.byok,
        access_key_enc=access_key_secret,
        region=provider_check.region,
        base_url=provider_check.base_url,
        api_version=provider_check.api_version,
    ).cast_to_subtype()

    # TODO: add more string sanity checks here before we hit actual endpoints
    if not provider.api_key_enc or not await provider.api_key_enc.get_plaintext_async():
        raise ValueError("API key is required!")

    await provider.check_api_key()
async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None:
    """Sync models for a newly created BYOK provider by querying the provider's API.

    Best-effort: any failure is logged and swallowed so provider creation
    itself never fails because model discovery did.
    """
    try:
        # cast_to_subtype() handles every provider type and carries api_key_enc along.
        typed_provider = provider.cast_to_subtype()
        discovered_llms = await typed_provider.list_llm_models_async()
        discovered_embeddings = await typed_provider.list_embedding_models_async()
        await self.sync_provider_models_async(
            provider=provider,
            llm_models=discovered_llms,
            embedding_models=discovered_embeddings,
            organization_id=actor.organization_id,
        )
        await self.update_provider_last_synced_async(provider.id, actor=actor)
    except Exception as e:
        logger.error(f"Failed to sync models for provider '{provider.name}': {e}")
@enforce_types
@trace_method
async def sync_base_providers(self, base_providers: list[PydanticProvider], actor: PydanticUser) -> None:
    """
    Sync base providers (from environment) to database (idempotent).

    This method is safe to call from multiple pods simultaneously as it:
    1. Checks if provider exists before creating
    2. Handles race conditions with UniqueConstraintViolationError
    3. Only creates providers that don't exist (no updates to avoid conflicts)

    Args:
        base_providers: List of base provider instances from environment variables
        actor: User actor for database operations
    """
    # Uses the module-level logger and UniqueConstraintViolationError import;
    # previously both were redundantly re-imported here, shadowing the module bindings.
    logger.info(f"Syncing {len(base_providers)} base providers to database")
    async with db_registry.async_session() as session:
        for provider in base_providers:
            try:
                # Check if base provider already exists (base providers have organization_id=None)
                existing_providers = await ProviderModel.list_async(
                    db_session=session,
                    name=provider.name,
                    organization_id=None,  # Base providers are global
                    limit=1,
                )
                if existing_providers:
                    logger.debug(f"Base provider '{provider.name}' already exists in database, skipping")
                    continue

                # Convert Provider to ProviderCreate.
                # NOTE: Do NOT store API keys for base providers in the database.
                # Base providers should always use environment variables for API keys.
                # This ensures keys stay in sync with env vars and aren't duplicated in DB.
                provider_create = ProviderCreate(
                    name=provider.name,
                    provider_type=provider.provider_type,
                    api_key="",  # Base providers use env vars, not DB-stored keys
                    access_key=None,
                    region=provider.region,
                    base_url=provider.base_url,
                    api_version=provider.api_version,
                )

                # Create the provider in the database as a base provider
                await self.create_provider_async(request=provider_create, actor=actor, is_byok=False)
                logger.info(f"Successfully initialized base provider '{provider.name}' to database")
            except UniqueConstraintViolationError:
                # Race condition: another pod created this provider between our check and create.
                # This is expected and safe - just log and continue.
                logger.debug(f"Provider '{provider.name}' was created by another pod, skipping")
            except Exception as e:
                # Log error but don't fail startup - provider initialization is not critical
                logger.error(f"Failed to sync provider '{provider.name}' to database: {e}", exc_info=True)
@enforce_types
@trace_method
async def sync_provider_models_async(
    self,
    provider: PydanticProvider,
    llm_models: List[LLMConfig],
    embedding_models: List[EmbeddingConfig],
    organization_id: Optional[str] = None,
) -> None:
    """Sync models from a provider to the database - adds new models and removes old ones.

    Existing models missing from the incoming lists are soft-deleted; new handles are
    inserted; models already present get context-window / endpoint-type refreshed
    when those drifted.

    Args:
        provider: The provider whose model catalog is being synced.
        llm_models: LLM configs currently offered by the provider.
        embedding_models: Embedding configs currently offered by the provider.
        organization_id: Org scope for the models; None means global models.
    """
    # Uses the module-level logger / and_ / select; the previous local
    # re-imports shadowed the module bindings for no benefit.
    logger.info(f"=== Starting sync for provider '{provider.name}' (ID: {provider.id}) ===")
    logger.info(f" Organization ID: {organization_id}")
    logger.info(f" LLM models to sync: {[m.handle for m in llm_models]}")
    logger.info(f" Embedding models to sync: {[m.handle for m in embedding_models]}")
    async with db_registry.async_session() as session:
        # Get all existing (non-deleted) models for this provider and organization.
        # NULL organization_id needs an explicit IS NULL predicate because kwargs
        # filtering cannot express it.
        if organization_id is None:
            stmt = select(ProviderModelORM).where(
                and_(
                    ProviderModelORM.provider_id == provider.id,
                    ProviderModelORM.organization_id.is_(None),
                    ProviderModelORM.is_deleted == False,  # Filter out soft-deleted models
                )
            )
            result = await session.execute(stmt)
            existing_models = list(result.scalars().all())
        else:
            # For org-specific models
            existing_models = await ProviderModelORM.list_async(
                db_session=session,
                check_is_deleted=True,  # Filter out soft-deleted models
                **{
                    "provider_id": provider.id,
                    "organization_id": organization_id,
                },
            )

        # Build sets of handles for incoming models
        incoming_llm_handles = {llm.handle for llm in llm_models}
        incoming_embedding_handles = {emb.handle for emb in embedding_models}
        all_incoming_handles = incoming_llm_handles | incoming_embedding_handles

        # Remove models that are no longer in the sync list
        models_to_remove = [m for m in existing_models if m.handle not in all_incoming_handles]
        for model_to_remove in models_to_remove:
            await model_to_remove.delete_async(session)
            logger.debug(f"Removed model {model_to_remove.handle} from provider {provider.name}")
        # Commit the deletions before inserting replacements with the same handles
        await session.commit()

        # Process LLM models - add new ones, refresh changed ones
        logger.info(f"Processing {len(llm_models)} LLM models for provider {provider.name}")
        for llm_config in llm_models:
            logger.info(f" Checking LLM model: {llm_config.handle} (name: {llm_config.model})")
            # Check if model already exists (excluding soft-deleted ones)
            existing = await ProviderModelORM.list_async(
                db_session=session,
                limit=1,
                check_is_deleted=True,  # Filter out soft-deleted models
                **{
                    "handle": llm_config.handle,
                    "organization_id": organization_id,
                    "model_type": "llm",  # Must check model_type since handle can be same for LLM and embedding
                },
            )
            if not existing:
                logger.info(f" Creating new LLM model {llm_config.handle}")
                # Create new model entry
                pydantic_model = PydanticProviderModel(
                    handle=llm_config.handle,
                    display_name=llm_config.model,
                    name=llm_config.model,
                    provider_id=provider.id,
                    organization_id=organization_id,
                    model_type="llm",
                    enabled=True,
                    model_endpoint_type=llm_config.model_endpoint_type,
                    max_context_window=llm_config.context_window,
                    supports_token_streaming=llm_config.model_endpoint_type in ["openai", "anthropic", "deepseek"],
                    supports_tool_calling=True,  # Assume true for LLMs for now
                )
                logger.info(
                    f" Model data: handle={pydantic_model.handle}, name={pydantic_model.name}, "
                    f"model_type={pydantic_model.model_type}, provider_id={pydantic_model.provider_id}, "
                    f"org_id={pydantic_model.organization_id}"
                )
                # Convert to ORM
                model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True))
                try:
                    await model.create_async(session)
                    logger.info(f" ✓ Successfully created LLM model {llm_config.handle} with ID {model.id}")
                except Exception as e:
                    # Log at ERROR with traceback (was INFO, which hid real failures
                    # and was inconsistent with the embedding branch below).
                    logger.error(f" ✗ Failed to create LLM model {llm_config.handle}: {e}", exc_info=True)
                    # Roll back the session to clear the failed transaction
                    await session.rollback()
            else:
                # Check if max_context_window or model_endpoint_type needs to be updated
                existing_model = existing[0]
                needs_update = False
                if existing_model.max_context_window != llm_config.context_window:
                    logger.info(
                        f" Updating LLM model {llm_config.handle} max_context_window: "
                        f"{existing_model.max_context_window} -> {llm_config.context_window}"
                    )
                    existing_model.max_context_window = llm_config.context_window
                    needs_update = True
                if existing_model.model_endpoint_type != llm_config.model_endpoint_type:
                    logger.info(
                        f" Updating LLM model {llm_config.handle} model_endpoint_type: "
                        f"{existing_model.model_endpoint_type} -> {llm_config.model_endpoint_type}"
                    )
                    existing_model.model_endpoint_type = llm_config.model_endpoint_type
                    needs_update = True
                if needs_update:
                    await existing_model.update_async(session)
                else:
                    logger.info(f" LLM model {llm_config.handle} already exists (ID: {existing[0].id}), skipping")

        # Process embedding models - add new ones, refresh changed ones
        logger.info(f"Processing {len(embedding_models)} embedding models for provider {provider.name}")
        for embedding_config in embedding_models:
            logger.info(f" Checking embedding model: {embedding_config.handle} (name: {embedding_config.embedding_model})")
            # Check if model already exists (excluding soft-deleted ones)
            existing = await ProviderModelORM.list_async(
                db_session=session,
                limit=1,
                check_is_deleted=True,  # Filter out soft-deleted models
                **{
                    "handle": embedding_config.handle,
                    "organization_id": organization_id,
                    "model_type": "embedding",  # Must check model_type since handle can be same for LLM and embedding
                },
            )
            if not existing:
                logger.info(f" Creating new embedding model {embedding_config.handle}")
                # Create new model entry
                pydantic_model = PydanticProviderModel(
                    handle=embedding_config.handle,
                    display_name=embedding_config.embedding_model,
                    name=embedding_config.embedding_model,
                    provider_id=provider.id,
                    organization_id=organization_id,
                    model_type="embedding",
                    enabled=True,
                    model_endpoint_type=embedding_config.embedding_endpoint_type,
                    embedding_dim=embedding_config.embedding_dim if hasattr(embedding_config, "embedding_dim") else None,
                )
                logger.info(
                    f" Model data: handle={pydantic_model.handle}, name={pydantic_model.name}, "
                    f"model_type={pydantic_model.model_type}, provider_id={pydantic_model.provider_id}, "
                    f"org_id={pydantic_model.organization_id}"
                )
                # Convert to ORM
                model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True))
                try:
                    await model.create_async(session)
                    logger.info(f" ✓ Successfully created embedding model {embedding_config.handle} with ID {model.id}")
                except Exception as e:
                    logger.error(f" ✗ Failed to create embedding model {embedding_config.handle}: {e}", exc_info=True)
                    # Roll back the session to clear the failed transaction
                    await session.rollback()
            else:
                # Check if model_endpoint_type needs to be updated
                existing_model = existing[0]
                if existing_model.model_endpoint_type != embedding_config.embedding_endpoint_type:
                    logger.info(
                        f" Updating embedding model {embedding_config.handle} model_endpoint_type: "
                        f"{existing_model.model_endpoint_type} -> {embedding_config.embedding_endpoint_type}"
                    )
                    existing_model.model_endpoint_type = embedding_config.embedding_endpoint_type
                    await existing_model.update_async(session)
                else:
                    logger.info(f" Embedding model {embedding_config.handle} already exists (ID: {existing[0].id}), skipping")
@enforce_types
@trace_method
async def get_model_by_handle_async(
    self,
    handle: str,
    actor: PydanticUser,
    model_type: Optional[str] = None,
) -> Optional[PydanticProviderModel]:
    """Get a model by its handle. Handles are unique per organization.

    Org-scoped matches take precedence over global (organization_id IS NULL) matches.
    """
    async with db_registry.async_session() as session:
        from sqlalchemy import or_

        predicates = [
            ProviderModelORM.handle == handle,
            ProviderModelORM.is_deleted == False,  # skip soft-deleted rows
        ]
        if model_type:
            predicates.append(ProviderModelORM.model_type == model_type)
        # A model is visible if it belongs to the actor's org or is global.
        predicates.append(
            or_(ProviderModelORM.organization_id == actor.organization_id, ProviderModelORM.organization_id.is_(None))
        )

        result = await session.execute(select(ProviderModelORM).where(and_(*predicates)))
        candidates = list(result.scalars().all())

        # Prefer the org-scoped row; remember a global row as fallback.
        primary = None
        fallback = None
        for candidate in candidates:
            if candidate.organization_id == actor.organization_id:
                primary = candidate
            elif candidate.organization_id is None:
                fallback = candidate
        chosen = primary if primary is not None else fallback
        return chosen.to_pydantic() if chosen is not None else None
@enforce_types
@trace_method
async def list_models_async(
    self,
    actor: PydanticUser,
    model_type: Optional[str] = None,
    provider_id: Optional[str] = None,
    enabled: Optional[bool] = True,
    limit: Optional[int] = None,
) -> List[PydanticProviderModel]:
    """List models available to an actor (both global and org-scoped).

    Org-scoped rows shadow global rows sharing the same (handle, model_type) pair.
    """
    async with db_registry.async_session() as session:
        shared_filters = {}
        if model_type:
            shared_filters["model_type"] = model_type
        if provider_id:
            shared_filters["provider_id"] = provider_id
        if enabled is not None:
            shared_filters["enabled"] = enabled

        # Models belonging to the actor's organization (soft-deleted excluded).
        org_models = await ProviderModelORM.list_async(
            db_session=session,
            limit=limit,
            check_is_deleted=True,
            **{**shared_filters, "organization_id": actor.organization_id},
        )

        # Global models require an explicit IS NULL predicate.
        predicates = [
            ProviderModelORM.organization_id.is_(None),
            ProviderModelORM.is_deleted == False,  # Filter out soft-deleted models
        ]
        if model_type:
            predicates.append(ProviderModelORM.model_type == model_type)
        if provider_id:
            predicates.append(ProviderModelORM.provider_id == provider_id)
        if enabled is not None:
            predicates.append(ProviderModelORM.enabled == enabled)
        stmt = select(ProviderModelORM).where(and_(*predicates))
        if limit:
            stmt = stmt.limit(limit)
        global_models = list((await session.execute(stmt)).scalars().all())

        # Dedupe on (handle, model_type); org-scoped entries win over global ones.
        merged = {(m.handle, m.model_type): m for m in global_models}
        merged.update({(m.handle, m.model_type): m for m in org_models})
        return [m.to_pydantic() for m in merged.values()]
    @enforce_types
    @trace_method
    async def get_llm_config_from_handle(
        self,
        handle: str,
        actor: PydanticUser,
    ) -> LLMConfig:
        """Get an LLMConfig from a model handle.

        Resolution order: the database is consulted first (base-provider models);
        if the handle is not persisted and looks like "provider_name/model_name",
        the actor's matching BYOK provider is queried live for its LLM models.

        Args:
            handle: The model handle to look up
            actor: The user actor for permission checking

        Returns:
            LLMConfig constructed from the provider and model data

        Raises:
            NoResultFound: If the handle doesn't exist in the database or BYOK provider
        """
        from letta.orm.errors import NoResultFound

        # Look up the model by handle in the database (for base providers)
        model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="llm")
        if not model:
            # Model not in DB - check if it's from a BYOK provider
            # Handle format is "provider_name/model_name"
            if "/" in handle:
                provider_name, model_name = handle.split("/", 1)
                byok_providers = await self.list_providers_async(
                    actor=actor,
                    name=provider_name,
                    provider_category=[ProviderCategory.byok],
                )
                if byok_providers:
                    # Fetch models dynamically from BYOK provider
                    provider = byok_providers[0]
                    typed_provider = provider.cast_to_subtype()
                    try:
                        all_llm_configs = await typed_provider.list_llm_models_async()
                        # Match by handle first (original logic)
                        llm_configs = [config for config in all_llm_configs if config.handle == handle]
                        # Fallback to match by model name (original logic)
                        if not llm_configs:
                            llm_configs = [config for config in all_llm_configs if config.model == model_name]
                        if llm_configs:
                            return llm_configs[0]
                    except Exception as e:
                        # Best-effort: a failing BYOK fetch degrades to the
                        # NoResultFound below rather than propagating.
                        logger.warning(f"Failed to fetch models from BYOK provider {provider_name}: {e}")
            raise NoResultFound(f"LLM model not found with handle='{handle}'")

        # Get the provider for this model and cast to subtype to access provider-specific methods
        provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
        typed_provider = provider.cast_to_subtype()

        # Get the default max_output_tokens from the provider (provider-specific logic)
        max_tokens = typed_provider.get_default_max_output_tokens(model.name)

        # Determine the model endpoint - use provider's base_url if set,
        # otherwise use provider-specific defaults
        if typed_provider.base_url:
            model_endpoint = typed_provider.base_url
        elif provider.provider_type == ProviderType.chatgpt_oauth:
            # ChatGPT OAuth uses the ChatGPT backend API, not a generic endpoint pattern
            from letta.schemas.providers.chatgpt_oauth import CHATGPT_CODEX_ENDPOINT

            model_endpoint = CHATGPT_CODEX_ENDPOINT
        else:
            # NOTE(review): assumes "https://api.<provider_type>.com/v1" is a valid
            # endpoint for every remaining provider type — confirm for new providers.
            model_endpoint = f"https://api.{provider.provider_type.value}.com/v1"

        # Construct the LLMConfig from the model and provider data
        llm_config = LLMConfig(
            model=model.name,
            model_endpoint_type=model.model_endpoint_type,
            model_endpoint=model_endpoint,
            context_window=model.max_context_window or 16384,  # Default if not set
            handle=model.handle,
            provider_name=provider.name,
            provider_category=provider.provider_category,
            max_tokens=max_tokens,
        )
        return llm_config
@enforce_types
@trace_method
async def get_embedding_config_from_handle(
self,
handle: str,
actor: PydanticUser,
) -> EmbeddingConfig:
"""Get an EmbeddingConfig from a model handle.
Args:
handle: The model handle to look up
actor: The user actor for permission checking
Returns:
EmbeddingConfig constructed from the provider and model data
Raises:
NoResultFound: If the handle doesn't exist in the database or BYOK provider
"""
from letta.orm.errors import NoResultFound
# Look up the model by handle in the database (for base providers)
model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="embedding")
if not model:
# Model not in DB - check if it's from a BYOK provider
# Handle format is "provider_name/model_name"
if "/" in handle:
provider_name, model_name = handle.split("/", 1)
byok_providers = await self.list_providers_async(
actor=actor,
name=provider_name,
provider_category=[ProviderCategory.byok],
)
if byok_providers:
# Fetch models dynamically from BYOK provider
provider = byok_providers[0]
typed_provider = provider.cast_to_subtype()
try:
all_embedding_configs = await typed_provider.list_embedding_models_async()
# Match by handle (original logic - no model_name fallback for embeddings)
embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
if embedding_configs:
return embedding_configs[0]
except Exception as e:
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider_name}: {e}")
raise NoResultFound(f"Embedding model not found with handle='{handle}'")
# Get the provider for this model
provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
# Construct the EmbeddingConfig from the model and provider data
embedding_config = EmbeddingConfig(
embedding_model=model.name,
embedding_endpoint_type=model.model_endpoint_type,
embedding_endpoint=provider.base_url or f"https://api.{provider.provider_type.value}.com/v1",
embedding_dim=model.embedding_dim or 1536, # Use model's dimension or default
embedding_chunk_size=300, # Default chunk size
handle=model.handle,
)
return embedding_config