Files
letta-server/letta/services/provider_manager.py
Kian Jones 71e0a8aab9 fix(core): use INSERT ON CONFLICT DO NOTHING for provider model sync (#9342)
* fix(core): use INSERT ON CONFLICT DO NOTHING for provider model sync

Replaces try/except around model.create_async() with pg_insert()
.on_conflict_do_nothing() to prevent UniqueViolationError from being
raised at the asyncpg driver level during concurrent model syncs.
The previous approach caught the exception in Python but ddtrace still
captured it at the driver level, causing Datadog error tracking noise.

Fixes Datadog issue d8dec148-d535-11f0-95eb-da7ad0900000

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* cleaner impl

* fix

---------

Co-authored-by: Letta <noreply@letta.com>
Co-authored-by: Ari Webb <ari@letta.com>
2026-02-24 10:52:07 -08:00

1084 lines
50 KiB
Python

from typing import List, Optional, Tuple, Union
from sqlalchemy import and_, select
from letta.log import get_logger
from letta.orm.errors import UniqueConstraintViolationError
from letta.orm.provider import Provider as ProviderModel
from letta.orm.provider_model import ProviderModel as ProviderModelORM
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.provider_model import ProviderModel as PydanticProviderModel
from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate
from letta.schemas.secret import Secret
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
# Module-level logger, named after this module via get_logger(__name__).
logger = get_logger(__name__)
class ProviderManager:
@enforce_types
@trace_method
async def create_provider_async(self, request: ProviderCreate, actor: PydanticUser, is_byok: bool = True) -> PydanticProvider:
    """Create a provider, or revive a soft-deleted provider that has the same name.

    Args:
        request: ProviderCreate object with provider details.
        actor: User creating the provider.
        is_byok: If True (default), creates a BYOK provider scoped to the actor's
            organization. If False, creates a base provider with no organization.

    Returns:
        The created (or restored) provider as a Pydantic object.

    Raises:
        ValueError: If the requested name matches an existing base provider.
    """
    from sqlalchemy import update

    category = ProviderCategory.byok if is_byok else ProviderCategory.base

    async with db_registry.async_session() as session:
        # Both BYOK and base providers are checked against the base providers
        # (organization_id is NULL); only the error message differs. For base
        # providers the DB constraint cannot catch duplicates because NULL != NULL.
        clashing_base = await ProviderModel.list_async(
            db_session=session,
            name=request.name,
            organization_id=None,
            limit=1,
        )
        if clashing_base:
            if is_byok:
                raise ValueError(
                    f"Provider name '{request.name}' conflicts with an existing base provider. Please choose a different name."
                )
            raise ValueError(f"Base provider name '{request.name}' already exists. Please choose a different name.")

        # Look for a soft-deleted provider with the same name in the same scope.
        scope_id = actor.organization_id if is_byok else None
        if scope_id is None:
            org_clause = ProviderModel.organization_id.is_(None)
        else:
            org_clause = ProviderModel.organization_id == scope_id
        lookup = select(ProviderModel).where(
            and_(
                ProviderModel.name == request.name,
                org_clause,
                ProviderModel.is_deleted == True,
            )
        )
        revived = (await session.execute(lookup)).scalar_one_or_none()

        if revived:
            # Bring the soft-deleted row back and overwrite it with the request.
            logger.info(f"Restoring soft-deleted provider '{request.name}' with id: {revived.id}")
            revived.is_deleted = False
            revived.provider_type = request.provider_type
            revived.provider_category = category
            revived.base_url = request.base_url
            revived.region = request.region
            revived.api_version = request.api_version

            # Encrypt asynchronously so the event loop is not blocked.
            if request.api_key is not None:
                revived.api_key_enc = (await Secret.from_plaintext_async(request.api_key)).get_encrypted()
            if request.access_key is not None:
                revived.access_key_enc = (await Secret.from_plaintext_async(request.access_key)).get_encrypted()

            await revived.update_async(session, actor=actor)

            # The unique constraint on provider_models does not include is_deleted,
            # so soft-deleted models would block re-creation under the same handle;
            # restore them together with the provider.
            restored = await session.execute(
                update(ProviderModelORM)
                .where(
                    and_(
                        ProviderModelORM.provider_id == revived.id,
                        ProviderModelORM.is_deleted == True,
                    )
                )
                .values(is_deleted=False)
            )
            if restored.rowcount > 0:
                logger.info(f"Restored {restored.rowcount} soft-deleted model(s) for provider '{request.name}'")

            # Commit first: the model sync below opens its own session, which
            # cannot see uncommitted changes made in this one.
            await session.commit()

            restored_provider = revived.to_pydantic()
            if is_byok:
                # Adds newly available models and removes ones no longer offered.
                await self._sync_default_models_for_provider(restored_provider, actor)
            return restored_provider

        # No row to restore: build a brand-new provider.
        fields = request.model_dump()
        # Never persist the deprecated plaintext keys; the *_enc fields are set below.
        for plaintext_field in ("api_key", "access_key"):
            fields.pop(plaintext_field, None)
        fields["provider_category"] = category
        provider = PydanticProvider(**fields)

        # Persist the subtype's default base_url (e.g. for ZAI) rather than
        # relying on cast_to_subtype() at read time.
        if provider.base_url is None:
            default_url = provider.cast_to_subtype().base_url
            if default_url is not None:
                provider.base_url = default_url

        # Base providers stay global (organization_id = None).
        if is_byok:
            provider.organization_id = actor.organization_id

        # Lazily create the provider id prior to persistence.
        provider.resolve_identifier()

        if request.api_key is not None:
            provider.api_key_enc = await Secret.from_plaintext_async(request.api_key)
        if request.access_key is not None:
            provider.access_key_enc = await Secret.from_plaintext_async(request.access_key)

        row = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
        await row.create_async(session, actor=actor)
        created = row.to_pydantic()

        if is_byok:
            await self._sync_default_models_for_provider(created, actor)
        return created
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@trace_method
async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
    """Apply the fields set on ``provider_update`` to an existing provider.

    Plaintext ``api_key`` / ``access_key`` values are never written directly:
    each is compared with the decrypted stored value and re-encrypted into the
    matching ``*_enc`` column only when it differs.

    Args:
        provider_id: Identifier of the provider to update.
        provider_update: Partial update; unset and None fields are ignored.
        actor: User performing the update.

    Returns:
        The updated provider as a Pydantic object.
    """
    credential_fields = (("api_key", "api_key_enc"), ("access_key", "access_key_enc"))

    async with db_registry.async_session() as session:
        target = await ProviderModel.read_async(
            db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
        )
        changes = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)

        for plain_field, enc_field in credential_fields:
            new_plain = changes.get(plain_field)
            if new_plain is None:
                continue

            # Decrypt what is stored (if anything) to detect an actual change.
            stored_enc = getattr(target, enc_field)
            current_plain = None
            if stored_enc:
                current_plain = await Secret.from_encrypted(stored_enc).get_plaintext_async()

            # Only re-encrypt when the value differs (async to keep the loop free).
            if current_plain != new_plain:
                fresh_secret = await Secret.from_plaintext_async(new_plain)
                setattr(target, enc_field, fresh_secret.get_encrypted())

            # Handled directly on the ORM object; keep it out of the generic pass.
            changes.pop(plain_field, None)
            changes.pop(enc_field, None)

        for field_name, field_value in changes.items():
            setattr(target, field_name, field_value)

        await target.update_async(session, actor=actor)
        return target.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
async def update_provider_last_synced_async(self, provider_id: str, actor: Optional[PydanticUser] = None) -> None:
    """Stamp a provider's ``last_synced`` column with the current UTC time.

    Args:
        provider_id: Identifier of the provider to stamp.
        actor: Optional user, forwarded to ``read_async``. It may be omitted for
            system-level operations (e.g. syncing global providers during
            server initialization).
    """
    from datetime import datetime, timezone

    async with db_registry.async_session() as session:
        record = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
        synced_at = datetime.now(timezone.utc)
        record.last_synced = synced_at
        await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@trace_method
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
    """Wipe a provider's credentials, then soft delete it and its models.

    Args:
        provider_id: Identifier of the provider to delete.
        actor: User performing the deletion.
    """
    async with db_registry.async_session() as session:
        doomed = await ProviderModel.read_async(
            db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
        )

        # Clear the encrypted columns first, then the deprecated plaintext ones
        # (touching the deprecated fields may trigger a warning).
        for credential_field in ("api_key_enc", "access_key_enc", "api_key", "access_key"):
            setattr(doomed, credential_field, None)

        logger.info("Soft deleting provider with id: %s", provider_id)
        await doomed.update_async(session, actor=actor)

        # Soft delete every model that belongs to this provider.
        attached_models = await ProviderModelORM.list_async(
            db_session=session,
            provider_id=provider_id,
            check_is_deleted=True,
        )
        for attached in attached_models:
            await attached.delete_async(session, actor=actor)

        # Finally soft delete the provider row itself. No explicit commit here:
        # per the original author's note, the session context manager handles it.
        await doomed.delete_async(session, actor=actor)
@enforce_types
@trace_method
async def list_providers_async(
    self,
    actor: PydanticUser,
    name: Optional[str] = None,
    provider_type: Optional[ProviderType] = None,
    provider_category: Optional[List[ProviderCategory]] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
    ascending: bool = False,
) -> List[PydanticProvider]:
    """List providers visible to ``actor``, with pagination arguments.

    Two queries are issued with the same pagination arguments: one scoped by
    the actor and one for global providers (organization_id=NULL). Results are
    returned actor-scoped first, then global.

    Args:
        actor: User whose providers are listed.
        name: Optional exact-name filter (ignored when falsy).
        provider_type: Optional provider-type filter (ignored when falsy).
        provider_category: When given, ``byok`` selects the actor-scoped result
            and ``base`` selects the global result; when empty/None both are returned.
        before, after, limit, ascending: Forwarded to each query.

    Returns:
        The selected providers as Pydantic objects.
    """
    shared_filters = {
        field: value
        for field, value in (("name", name), ("provider_type", provider_type))
        if value
    }

    async with db_registry.async_session() as session:
        # Providers scoped through the actor.
        scoped = await ProviderModel.list_async(
            db_session=session,
            before=before,
            after=after,
            limit=limit,
            actor=actor,
            ascending=ascending,
            check_is_deleted=True,
            **shared_filters,
        )
        # Global (base) providers have organization_id=NULL.
        unscoped = await ProviderModel.list_async(
            db_session=session,
            before=before,
            after=after,
            limit=limit,
            ascending=ascending,
            check_is_deleted=True,
            **{**shared_filters, "organization_id": None},
        )

        want_everything = not provider_category
        selected = []
        if want_everything or ProviderCategory.byok in provider_category:
            selected.extend(scoped)
        if want_everything or ProviderCategory.base in provider_category:
            selected.extend(unscoped)

        # Blank the deprecated plaintext key fields before building the response.
        for row in selected:
            row.api_key = None
            row.access_key = None

        return [row.to_pydantic() for row in selected]
@enforce_types
@trace_method
def list_providers(
    self,
    actor: PydanticUser,
    name: Optional[str] = None,
    provider_type: Optional[ProviderType] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
    ascending: bool = False,
) -> List[PydanticProvider]:
    """Synchronous listing of providers visible to ``actor``.

    Issues one actor-scoped query and one query for global providers
    (organization_id=NULL), and returns the actor-scoped rows followed by the
    global rows.

    Args:
        actor: User whose providers are listed.
        name: Optional exact-name filter (ignored when falsy).
        provider_type: Optional provider-type filter (ignored when falsy).
        before, after, limit, ascending: Forwarded to each query.

    Returns:
        The combined providers as Pydantic objects.
    """
    shared_filters = {
        field: value
        for field, value in (("name", name), ("provider_type", provider_type))
        if value
    }

    with db_registry.get_session() as session:
        # Providers scoped through the actor.
        scoped = ProviderModel.list(
            db_session=session,
            before=before,
            after=after,
            limit=limit,
            actor=actor,
            ascending=ascending,
            check_is_deleted=True,
            **shared_filters,
        )
        # Global (base) providers have organization_id=NULL.
        unscoped = ProviderModel.list(
            db_session=session,
            before=before,
            after=after,
            limit=limit,
            ascending=ascending,
            check_is_deleted=True,
            **{**shared_filters, "organization_id": None},
        )
        combined = scoped + unscoped
        return [row.to_pydantic() for row in combined]
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@trace_method
async def get_provider_async(self, provider_id: str, actor: PydanticUser) -> PydanticProvider:
    """Fetch a provider by id, trying the actor-scoped lookup before the global one.

    Args:
        provider_id: Identifier of the provider to fetch.
        actor: User requesting the provider.

    Returns:
        The provider as a Pydantic object.

    Raises:
        NoResultFound: If neither the actor-scoped lookup nor the global
            (organization_id=NULL, not soft-deleted) lookup finds the provider.
    """
    async with db_registry.async_session() as session:
        # First try to get it as an actor-scoped provider.
        try:
            provider_model = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
            return provider_model.to_pydantic()
        except Exception:
            # `except Exception` (not a bare `except:`) so that task cancellation,
            # KeyboardInterrupt and SystemExit propagate instead of silently
            # triggering the fallback query below.
            # If not found, try to get it as a global provider (organization_id=NULL).
            stmt = select(ProviderModel).where(
                ProviderModel.id == provider_id,
                ProviderModel.organization_id.is_(None),
                ProviderModel.is_deleted == False,
            )
            result = await session.execute(stmt)
            provider_model = result.scalar_one_or_none()
            if provider_model:
                # Remove deprecated api_key and access_key fields from the response.
                provider_model.api_key = None
                provider_model.access_key = None
                return provider_model.to_pydantic()

            from letta.orm.errors import NoResultFound

            raise NoResultFound(f"Provider not found with id='{provider_id}'")
@enforce_types
@trace_method
def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
    """Return the id of the first provider listed under ``provider_name``, or None if none is listed."""
    matches = self.list_providers(name=provider_name, actor=actor)
    if not matches:
        return None
    return matches[0].id
@enforce_types
@trace_method
def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
    """Return the decrypted API key of the first provider listed under ``provider_name``.

    Returns None when no provider is listed or when the provider has no
    encrypted API key.
    """
    matches = self.list_providers(name=provider_name, actor=actor)
    if not matches:
        return None
    secret = matches[0].api_key_enc
    if not secret:
        return None
    # Decrypt only at the point of return.
    return secret.get_plaintext()
@enforce_types
@trace_method
async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
    """Async variant: return the decrypted API key of the first provider listed under ``provider_name``.

    Returns None when no provider is listed or when the provider has no
    encrypted API key.
    """
    matches = await self.list_providers_async(name=provider_name, actor=actor)
    if not matches:
        return None
    secret = matches[0].api_key_enc
    if not secret:
        return None
    # Decrypt only at the point of return.
    return await secret.get_plaintext_async()
@enforce_types
@trace_method
async def get_bedrock_credentials_async(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Return ``(access_key, secret_key, region)`` for the first provider listed under ``provider_name``.

    The secret key is read from the provider's encrypted ``api_key_enc`` field.
    Missing secrets yield None in their slot; when no provider is listed the
    result is ``(None, None, None)``.
    """
    matches = await self.list_providers_async(name=provider_name, actor=actor)
    if not matches:
        return None, None, None

    first = matches[0]
    access_secret = first.access_key_enc
    api_secret = first.api_key_enc

    # Decrypt the credentials before returning (access key first, then secret key).
    access_key = None
    if access_secret:
        access_key = await access_secret.get_plaintext_async()
    secret_key = None
    if api_secret:
        secret_key = await api_secret.get_plaintext_async()

    return access_key, secret_key, first.region
@enforce_types
@trace_method
def get_azure_credentials(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Return ``(api_key, base_url, api_version)`` for the first provider listed under ``provider_name``.

    The API key is decrypted from ``api_key_enc`` (None if absent). When no
    provider is listed the result is ``(None, None, None)``.
    """
    matches = self.list_providers(name=provider_name, actor=actor)
    if not matches:
        return None, None, None

    first = matches[0]
    secret = first.api_key_enc
    # Decrypt the API key before returning.
    api_key = None
    if secret:
        api_key = secret.get_plaintext()
    return api_key, first.base_url, first.api_version
@enforce_types
@trace_method
async def get_azure_credentials_async(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Async variant: return ``(api_key, base_url, api_version)`` for the first provider listed under ``provider_name``.

    The API key is decrypted from ``api_key_enc`` (None if absent). When no
    provider is listed the result is ``(None, None, None)``.
    """
    matches = await self.list_providers_async(name=provider_name, actor=actor)
    if not matches:
        return None, None, None

    first = matches[0]
    secret = first.api_key_enc
    # Decrypt the API key before returning.
    api_key = None
    if secret:
        api_key = await secret.get_plaintext_async()
    return api_key, first.base_url, first.api_version
@enforce_types
@trace_method
async def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
    """Validate the credentials in ``provider_check`` against the provider itself.

    Builds a transient BYOK provider (named after its provider type), casts it
    to its concrete subtype and delegates to that subtype's ``check_api_key``.

    Args:
        provider_check: Provider type plus the credentials and endpoint details to test.

    Raises:
        ValueError: If the API key is missing or decrypts to an empty value.
    """
    # Encrypt with the async helpers, as the other coroutines in this file do,
    # so this coroutine does not block the event loop.
    api_key_secret = await Secret.from_plaintext_async(provider_check.api_key)
    access_key_secret = None
    if provider_check.access_key:
        access_key_secret = await Secret.from_plaintext_async(provider_check.access_key)

    provider = PydanticProvider(
        name=provider_check.provider_type.value,
        provider_type=provider_check.provider_type,
        api_key_enc=api_key_secret,
        provider_category=ProviderCategory.byok,
        access_key_enc=access_key_secret,
        region=provider_check.region,
        base_url=provider_check.base_url,
        api_version=provider_check.api_version,
    ).cast_to_subtype()

    # TODO: add more string sanity checks here before we hit actual endpoints
    if not provider.api_key_enc or not await provider.api_key_enc.get_plaintext_async():
        raise ValueError("API key is required!")

    await provider.check_api_key()
async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None:
    """Query a BYOK provider's API for its models and sync them to the database.

    Best-effort: any exception is logged and swallowed so that a failed model
    sync does not fail the provider creation that triggered it.

    Args:
        provider: Provider whose models should be synced.
        actor: User whose organization the synced models are stored under.
    """
    try:
        # cast_to_subtype() handles all provider types and preserves api_key_enc.
        subtype = provider.cast_to_subtype()
        llm_configs = await subtype.list_llm_models_async()
        embedding_configs = await subtype.list_embedding_models_async()

        await self.sync_provider_models_async(
            provider=provider,
            llm_models=llm_configs,
            embedding_models=embedding_configs,
            organization_id=actor.organization_id,
        )
        await self.update_provider_last_synced_async(provider.id, actor=actor)
    except Exception as sync_error:
        # Deliberately not re-raised: see the docstring.
        logger.error(f"Failed to sync models for provider '{provider.name}': {sync_error}")
@enforce_types
@trace_method
async def sync_base_providers(self, base_providers: list[PydanticProvider], actor: PydanticUser) -> None:
    """Sync base providers (from environment) to the database (idempotent).

    Intended to be safe when several pods run it at once:
    1. Each provider is looked up by name before being created.
    2. A UniqueConstraintViolationError from a lost race is logged and skipped.
    3. Existing providers are never updated, only missing ones are created.

    Any other failure is logged and does not interrupt the loop.

    Args:
        base_providers: Base provider instances built from environment variables.
        actor: User actor for database operations.
    """
    from letta.log import get_logger
    from letta.orm.errors import UniqueConstraintViolationError

    logger = get_logger(__name__)
    logger.info(f"Syncing {len(base_providers)} base providers to database")

    async with db_registry.async_session() as session:
        for base_provider in base_providers:
            try:
                # Base providers are global, i.e. stored with organization_id=None.
                already_stored = await ProviderModel.list_async(
                    db_session=session,
                    name=base_provider.name,
                    organization_id=None,
                    limit=1,
                )
                if already_stored:
                    logger.debug(f"Base provider '{base_provider.name}' already exists in database, skipping")
                    continue

                # NOTE: API keys of base providers are NOT stored in the database.
                # They always come from environment variables, which keeps them in
                # sync with the environment and avoids duplicating them in the DB.
                creation_request = ProviderCreate(
                    name=base_provider.name,
                    provider_type=base_provider.provider_type,
                    api_key="",
                    access_key=None,
                    region=base_provider.region,
                    base_url=base_provider.base_url,
                    api_version=base_provider.api_version,
                )
                await self.create_provider_async(request=creation_request, actor=actor, is_byok=False)
                logger.info(f"Successfully initialized base provider '{base_provider.name}' to database")
            except UniqueConstraintViolationError:
                # Another pod created it between our check and our create: expected and safe.
                logger.debug(f"Provider '{base_provider.name}' was created by another pod, skipping")
            except Exception as sync_error:
                # Provider initialization is not critical; log and keep going.
                logger.error(f"Failed to sync provider '{base_provider.name}' to database: {sync_error}", exc_info=True)
@enforce_types
@trace_method
async def sync_provider_models_async(
    self,
    provider: PydanticProvider,
    llm_models: List[LLMConfig],
    embedding_models: List[EmbeddingConfig],
    organization_id: Optional[str] = None,
) -> None:
    """Make the stored models of ``provider`` match the given model lists.

    Steps, all in one session:
    1. Load the provider's non-deleted models for ``organization_id``.
    2. Delete the ones whose handle is in neither incoming list, then commit.
    3. For each LLM config: create it if missing, otherwise update
       ``max_context_window`` / ``model_endpoint_type`` when they differ.
    4. For each embedding config: create it if missing, otherwise update
       ``model_endpoint_type`` when it differs.

    Args:
        provider: Provider that owns the models.
        llm_models: LLM configs reported for the provider.
        embedding_models: Embedding configs reported for the provider.
        organization_id: Organization the models are stored under; None means global.
    """
    from letta.log import get_logger

    logger = get_logger(__name__)
    logger.info(f"=== Starting sync for provider '{provider.name}' (ID: {provider.id}) ===")
    logger.info(f" Organization ID: {organization_id}")
    logger.info(f" LLM models to sync: {[m.handle for m in llm_models]}")
    logger.info(f" Embedding models to sync: {[m.handle for m in embedding_models]}")

    streaming_endpoint_types = ["openai", "anthropic", "deepseek", "openrouter"]

    async with db_registry.async_session() as session:
        # Step 1: current models. A NULL organization_id needs an explicit IS NULL query.
        if organization_id is None:
            global_query = select(ProviderModelORM).where(
                and_(
                    ProviderModelORM.provider_id == provider.id,
                    ProviderModelORM.organization_id.is_(None),
                    ProviderModelORM.is_deleted == False,
                )
            )
            rows = await session.execute(global_query)
            current_models = list(rows.scalars().all())
        else:
            current_models = await ProviderModelORM.list_async(
                db_session=session,
                check_is_deleted=True,
                provider_id=provider.id,
                organization_id=organization_id,
            )

        # Step 2: drop models that are no longer reported, then commit the deletions.
        wanted_handles = {cfg.handle for cfg in llm_models} | {cfg.handle for cfg in embedding_models}
        stale_models = [stored for stored in current_models if stored.handle not in wanted_handles]
        for stale in stale_models:
            await stale.delete_async(session)
            logger.debug(f"Removed model {stale.handle} from provider {provider.name}")
        await session.commit()

        # Step 3: LLM models.
        logger.info(f"Processing {len(llm_models)} LLM models for provider {provider.name}")
        for llm_config in llm_models:
            logger.info(f" Checking LLM model: {llm_config.handle} (name: {llm_config.model})")

            # Look up by handle first; model_type is part of the lookup because an
            # LLM and an embedding model can share a handle.
            found = await ProviderModelORM.list_async(
                db_session=session,
                limit=1,
                check_is_deleted=True,
                handle=llm_config.handle,
                organization_id=organization_id,
                model_type="llm",
            )
            # Then by name + provider_id (covers unique_model_per_provider_and_type).
            if not found:
                found = await ProviderModelORM.list_async(
                    db_session=session,
                    limit=1,
                    check_is_deleted=True,
                    name=llm_config.model,
                    provider_id=provider.id,
                    model_type="llm",
                )

            if found:
                # Already stored: refresh the fields that may have drifted.
                stored = found[0]
                changed = False
                if stored.max_context_window != llm_config.context_window:
                    logger.info(
                        f" Updating LLM model {llm_config.handle} max_context_window: "
                        f"{stored.max_context_window} -> {llm_config.context_window}"
                    )
                    stored.max_context_window = llm_config.context_window
                    changed = True
                if stored.model_endpoint_type != llm_config.model_endpoint_type:
                    logger.info(
                        f" Updating LLM model {llm_config.handle} model_endpoint_type: "
                        f"{stored.model_endpoint_type} -> {llm_config.model_endpoint_type}"
                    )
                    stored.model_endpoint_type = llm_config.model_endpoint_type
                    changed = True
                if changed:
                    await stored.update_async(session)
                else:
                    logger.info(f" LLM model {llm_config.handle} already exists (ID: {found[0].id}), skipping")
                continue

            logger.info(f" Creating new LLM model {llm_config.handle}")
            llm_entry = PydanticProviderModel(
                handle=llm_config.handle,
                display_name=llm_config.model,
                name=llm_config.model,
                provider_id=provider.id,
                organization_id=organization_id,
                model_type="llm",
                enabled=True,
                model_endpoint_type=llm_config.model_endpoint_type,
                max_context_window=llm_config.context_window,
                supports_token_streaming=llm_config.model_endpoint_type in streaming_endpoint_types,
                supports_tool_calling=True,  # Assume true for LLMs for now
            )
            logger.info(
                f" Model data: handle={llm_entry.handle}, name={llm_entry.name}, "
                f"model_type={llm_entry.model_type}, provider_id={llm_entry.provider_id}, "
                f"org_id={llm_entry.organization_id}"
            )
            llm_row = ProviderModelORM(**llm_entry.model_dump(to_orm=True))
            # A falsy result is reported below as a concurrent insert.
            inserted = await llm_row.create_async(session, ignore_conflicts=True)
            if inserted:
                logger.info(f" ✓ Successfully created LLM model {llm_config.handle}")
            else:
                logger.info(f" LLM model {llm_config.handle} already exists (concurrent insert), skipping")

        # Step 4: embedding models.
        logger.info(f"Processing {len(embedding_models)} embedding models for provider {provider.name}")
        for embedding_config in embedding_models:
            logger.info(f" Checking embedding model: {embedding_config.handle} (name: {embedding_config.embedding_model})")

            # Same two-step lookup as for LLMs, with model_type="embedding".
            found = await ProviderModelORM.list_async(
                db_session=session,
                limit=1,
                check_is_deleted=True,
                handle=embedding_config.handle,
                organization_id=organization_id,
                model_type="embedding",
            )
            if not found:
                found = await ProviderModelORM.list_async(
                    db_session=session,
                    limit=1,
                    check_is_deleted=True,
                    name=embedding_config.embedding_model,
                    provider_id=provider.id,
                    model_type="embedding",
                )

            if found:
                # Already stored: only the endpoint type is refreshed for embeddings.
                stored = found[0]
                if stored.model_endpoint_type != embedding_config.embedding_endpoint_type:
                    logger.info(
                        f" Updating embedding model {embedding_config.handle} model_endpoint_type: "
                        f"{stored.model_endpoint_type} -> {embedding_config.embedding_endpoint_type}"
                    )
                    stored.model_endpoint_type = embedding_config.embedding_endpoint_type
                    await stored.update_async(session)
                else:
                    logger.info(f" Embedding model {embedding_config.handle} already exists (ID: {found[0].id}), skipping")
                continue

            logger.info(f" Creating new embedding model {embedding_config.handle}")
            embedding_entry = PydanticProviderModel(
                handle=embedding_config.handle,
                display_name=embedding_config.embedding_model,
                name=embedding_config.embedding_model,
                provider_id=provider.id,
                organization_id=organization_id,
                model_type="embedding",
                enabled=True,
                model_endpoint_type=embedding_config.embedding_endpoint_type,
                embedding_dim=getattr(embedding_config, "embedding_dim", None),
            )
            logger.info(
                f" Model data: handle={embedding_entry.handle}, name={embedding_entry.name}, "
                f"model_type={embedding_entry.model_type}, provider_id={embedding_entry.provider_id}, "
                f"org_id={embedding_entry.organization_id}"
            )
            embedding_row = ProviderModelORM(**embedding_entry.model_dump(to_orm=True))
            inserted = await embedding_row.create_async(session, ignore_conflicts=True)
            if inserted:
                logger.info(f" ✓ Successfully created embedding model {embedding_config.handle}")
            else:
                logger.info(f" Embedding model {embedding_config.handle} already exists (concurrent insert), skipping")
@enforce_types
@trace_method
async def get_model_by_handle_async(
    self,
    handle: str,
    actor: PydanticUser,
    model_type: Optional[str] = None,
) -> Optional[PydanticProviderModel]:
    """Look up a non-deleted model by handle, preferring the actor's organization.

    Args:
        handle: Model handle to search for.
        actor: User whose organization scopes the search.
        model_type: Optional model-type filter (ignored when falsy).

    Returns:
        The model stored under the actor's organization if one matches,
        otherwise a matching global model (organization_id is NULL), otherwise None.
    """
    from sqlalchemy import or_

    async with db_registry.async_session() as session:
        criteria = [
            ProviderModelORM.handle == handle,
            ProviderModelORM.is_deleted == False,  # skip soft-deleted models
        ]
        if model_type:
            criteria.append(ProviderModelORM.model_type == model_type)
        # Visible models are the actor's organization's models plus global ones.
        criteria.append(
            or_(ProviderModelORM.organization_id == actor.organization_id, ProviderModelORM.organization_id.is_(None))
        )

        rows = await session.execute(select(ProviderModelORM).where(and_(*criteria)))
        candidates = list(rows.scalars().all())

        # If several rows match a scope, the last one returned wins.
        org_model = next(
            (m for m in reversed(candidates) if m.organization_id == actor.organization_id),
            None,
        )
        global_model = next(
            (
                m
                for m in reversed(candidates)
                if not (m.organization_id == actor.organization_id) and m.organization_id is None
            ),
            None,
        )

        # Organization-specific models take precedence over global ones.
        chosen = org_model if org_model else global_model
        if chosen:
            return chosen.to_pydantic()
        return None
@enforce_types
@trace_method
async def list_models_async(
    self,
    actor: PydanticUser,
    model_type: Optional[str] = None,
    provider_id: Optional[str] = None,
    enabled: Optional[bool] = True,
    limit: Optional[int] = None,
) -> List[PydanticProviderModel]:
    """List models available to an actor (both global and org-scoped).

    Two lookups are made: one for rows owned by the actor's organization and
    one for rows with no organization (global). The results are merged and
    de-duplicated on ``(handle, model_type)``, with the org-scoped row
    winning whenever both exist.

    Args:
        actor: User whose ``organization_id`` scopes the org lookup.
        model_type: When truthy, only return models of this type.
        provider_id: When truthy, only return models of this provider.
        enabled: When not ``None``, only return models whose ``enabled``
            flag equals this value.
        limit: Upper bound on the number of models returned. It is passed
            as-is to the org-scoped lookup; a falsy value leaves the global
            lookup and the merged result uncapped.

    Returns:
        Pydantic models, org-scoped entries first, followed by global
        entries that are not shadowed by an org-scoped one.
    """
    async with db_registry.async_session() as session:
        # Equality filters for the org-scoped lookup.
        filters = {}
        if model_type:
            filters["model_type"] = model_type
        if provider_id:
            filters["provider_id"] = provider_id
        if enabled is not None:
            filters["enabled"] = enabled

        # Org-scoped models (excluding soft-deleted ones).
        org_filters = {**filters, "organization_id": actor.organization_id}
        org_models = await ProviderModelORM.list_async(
            db_session=session,
            limit=limit,
            check_is_deleted=True,  # Filter out soft-deleted models
            **org_filters,
        )

        # Global models: organization_id must be matched with IS NULL, so this
        # lookup is built by hand instead of going through the keyword filters.
        conditions = [
            ProviderModelORM.organization_id.is_(None),
            ProviderModelORM.is_deleted == False,  # Filter out soft-deleted models
        ]
        if model_type:
            conditions.append(ProviderModelORM.model_type == model_type)
        if provider_id:
            conditions.append(ProviderModelORM.provider_id == provider_id)
        if enabled is not None:
            conditions.append(ProviderModelORM.enabled == enabled)

        stmt = select(ProviderModelORM).where(and_(*conditions))
        if limit:
            stmt = stmt.limit(limit)

        result = await session.execute(stmt)
        global_models = result.scalars().all()

        # De-duplicate on (handle, model_type): the same handle can exist for
        # both an LLM and an embedding model, and org-scoped rows shadow
        # global rows with the same key.
        org_by_key = {(m.handle, m.model_type): m for m in org_models}
        global_by_key = {(m.handle, m.model_type): m for m in global_models}

        # Org-scoped entries go first so that the cap below can never drop an
        # org-scoped model in favour of a global one.
        merged = list(org_by_key.values())
        merged.extend(m for key, m in global_by_key.items() if key not in org_by_key)

        # Each lookup was limited on its own, so the merged list could hold up
        # to 2 * limit entries; enforce the requested bound on the final result.
        if limit:
            merged = merged[:limit]

        return [m.to_pydantic() for m in merged]
@enforce_types
@trace_method
async def get_llm_config_from_handle(
    self,
    handle: str,
    actor: PydanticUser,
) -> LLMConfig:
    """Resolve a model handle into an LLMConfig.

    The handle is first looked up among stored models of type ``"llm"``. If
    no stored model matches and the handle has the form
    ``provider_name/model_name``, the actor's BYOK provider with that name is
    asked for its model list instead.

    Args:
        handle: The model handle to resolve.
        actor: The user on whose behalf models and providers are looked up.

    Returns:
        An LLMConfig, either built from the stored model and its provider or
        taken directly from the BYOK provider's model list.

    Raises:
        NoResultFound: If neither the stored models nor a BYOK provider
            yields a model for the handle.
    """
    from letta.orm.errors import NoResultFound

    model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="llm")

    if not model:
        # No stored model: try a BYOK provider named by the "provider_name/model_name" prefix.
        provider_name, separator, model_name = handle.partition("/")
        if separator:
            byok_providers = await self.list_providers_async(
                actor=actor,
                name=provider_name,
                provider_category=[ProviderCategory.byok],
            )
            if byok_providers:
                byok_provider = byok_providers[0].cast_to_subtype()
                try:
                    available = await byok_provider.list_llm_models_async()
                    # Prefer an exact handle match; otherwise fall back to the bare model name.
                    matches = [cfg for cfg in available if cfg.handle == handle] or [
                        cfg for cfg in available if cfg.model == model_name
                    ]
                    if matches:
                        return matches[0]
                except Exception as e:
                    # Best effort: a failing provider is reported as "not found" below.
                    logger.warning(f"Failed to fetch models from BYOK provider {provider_name}: {e}")

        raise NoResultFound(f"LLM model not found with handle='{handle}'")

    # Stored model: load its provider and cast to the subtype for provider-specific helpers.
    provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
    typed_provider = provider.cast_to_subtype()

    # Provider-specific default for the output token budget.
    max_tokens = typed_provider.get_default_max_output_tokens(model.name)

    # An explicit base_url wins; otherwise use a provider-specific default.
    model_endpoint = typed_provider.base_url
    if not model_endpoint:
        if provider.provider_type == ProviderType.chatgpt_oauth:
            # ChatGPT OAuth uses the ChatGPT backend API, not the generic endpoint pattern.
            from letta.schemas.providers.chatgpt_oauth import CHATGPT_CODEX_ENDPOINT

            model_endpoint = CHATGPT_CODEX_ENDPOINT
        else:
            model_endpoint = f"https://api.{provider.provider_type.value}.com/v1"

    return LLMConfig(
        model=model.name,
        model_endpoint_type=model.model_endpoint_type,
        model_endpoint=model_endpoint,
        context_window=model.max_context_window or 16384,  # fallback when unset
        handle=model.handle,
        provider_name=provider.name,
        provider_category=provider.provider_category,
        max_tokens=max_tokens,
    )
@enforce_types
@trace_method
async def get_embedding_config_from_handle(
    self,
    handle: str,
    actor: PydanticUser,
) -> EmbeddingConfig:
    """Resolve a model handle into an EmbeddingConfig.

    The handle is first looked up among stored models of type
    ``"embedding"``. If no stored model matches and the handle has the form
    ``provider_name/model_name``, the actor's BYOK provider with that name is
    asked for its embedding model list instead.

    Args:
        handle: The model handle to resolve.
        actor: The user on whose behalf models and providers are looked up.

    Returns:
        An EmbeddingConfig, either built from the stored model and its
        provider or taken directly from the BYOK provider's model list.

    Raises:
        NoResultFound: If neither the stored models nor a BYOK provider
            yields an embedding model for the handle.
    """
    from letta.orm.errors import NoResultFound

    model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="embedding")

    if not model:
        # No stored model: try a BYOK provider named by the "provider_name/model_name" prefix.
        provider_name, separator, _model_name = handle.partition("/")
        if separator:
            byok_providers = await self.list_providers_async(
                actor=actor,
                name=provider_name,
                provider_category=[ProviderCategory.byok],
            )
            if byok_providers:
                byok_provider = byok_providers[0].cast_to_subtype()
                try:
                    available = await byok_provider.list_embedding_models_async()
                    # Embeddings are matched on the full handle only (no model-name fallback).
                    matches = [cfg for cfg in available if cfg.handle == handle]
                    if matches:
                        return matches[0]
                except Exception as e:
                    # Best effort: a failing provider is reported as "not found" below.
                    logger.warning(f"Failed to fetch embedding models from BYOK provider {provider_name}: {e}")

        raise NoResultFound(f"Embedding model not found with handle='{handle}'")

    # Stored model: load its provider to determine the endpoint.
    provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)

    return EmbeddingConfig(
        embedding_model=model.name,
        embedding_endpoint_type=model.model_endpoint_type,
        embedding_endpoint=provider.base_url or f"https://api.{provider.provider_type.value}.com/v1",
        embedding_dim=model.embedding_dim or 1536,  # fallback when the model has no dimension set
        embedding_chunk_size=300,  # fixed default chunk size
        handle=model.handle,
    )