* fix(core): use INSERT ON CONFLICT DO NOTHING for provider model sync Replaces try/except around model.create_async() with pg_insert() .on_conflict_do_nothing() to prevent UniqueViolationError from being raised at the asyncpg driver level during concurrent model syncs. The previous approach caught the exception in Python but ddtrace still captured it at the driver level, causing Datadog error tracking noise. Fixes Datadog issue d8dec148-d535-11f0-95eb-da7ad0900000 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * cleaner impl * fix --------- Co-authored-by: Letta <noreply@letta.com> Co-authored-by: Ari Webb <ari@letta.com>
1084 lines
50 KiB
Python
1084 lines
50 KiB
Python
from typing import List, Optional, Tuple, Union
|
|
|
|
from sqlalchemy import and_, select
|
|
|
|
from letta.log import get_logger
|
|
from letta.orm.errors import UniqueConstraintViolationError
|
|
from letta.orm.provider import Provider as ProviderModel
|
|
from letta.orm.provider_model import ProviderModel as ProviderModelORM
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.provider_model import ProviderModel as PydanticProviderModel
|
|
from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate
|
|
from letta.schemas.secret import Secret
|
|
from letta.schemas.user import User as PydanticUser
|
|
from letta.server.db import db_registry
|
|
from letta.utils import enforce_types
|
|
from letta.validators import raise_on_invalid_id
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class ProviderManager:
|
|
    @enforce_types
    @trace_method
    async def create_provider_async(self, request: ProviderCreate, actor: PydanticUser, is_byok: bool = True) -> PydanticProvider:
        """Create a new provider if it doesn't already exist.

        If a soft-deleted provider with the same name exists in the target scope,
        it is restored (together with its soft-deleted models) instead of creating
        a new row. For BYOK providers, available models are synced after creation.

        Args:
            request: ProviderCreate object with provider details
            actor: User creating the provider
            is_byok: If True, creates a BYOK provider (default). If False, creates a base provider.

        Raises:
            ValueError: If the requested name collides with an existing base provider.
        """
        async with db_registry.async_session() as session:
            # Check for name conflicts
            if is_byok:
                # BYOK providers cannot use the same name as base providers
                existing_base_providers = await ProviderModel.list_async(
                    db_session=session,
                    name=request.name,
                    organization_id=None,  # Base providers have NULL organization_id
                    limit=1,
                )
                if existing_base_providers:
                    raise ValueError(
                        f"Provider name '{request.name}' conflicts with an existing base provider. Please choose a different name."
                    )
            else:
                # Base providers must have unique names among themselves
                # (the DB constraint won't catch this because NULL != NULL)
                existing_base_providers = await ProviderModel.list_async(
                    db_session=session,
                    name=request.name,
                    organization_id=None,  # Base providers have NULL organization_id
                    limit=1,
                )
                if existing_base_providers:
                    raise ValueError(f"Base provider name '{request.name}' already exists. Please choose a different name.")

            # Check if there's a soft-deleted provider with the same name that we can restore
            org_id = actor.organization_id if is_byok else None
            if org_id is not None:
                stmt = select(ProviderModel).where(
                    and_(
                        ProviderModel.name == request.name,
                        ProviderModel.organization_id == org_id,
                        ProviderModel.is_deleted == True,
                    )
                )
            else:
                # NULL organization_id requires IS NULL, not equality
                stmt = select(ProviderModel).where(
                    and_(
                        ProviderModel.name == request.name,
                        ProviderModel.organization_id.is_(None),
                        ProviderModel.is_deleted == True,
                    )
                )
            result = await session.execute(stmt)
            deleted_provider = result.scalar_one_or_none()

            if deleted_provider:
                # Restore the soft-deleted provider and update its fields
                logger.info(f"Restoring soft-deleted provider '{request.name}' with id: {deleted_provider.id}")
                deleted_provider.is_deleted = False
                deleted_provider.provider_type = request.provider_type
                deleted_provider.provider_category = ProviderCategory.byok if is_byok else ProviderCategory.base
                deleted_provider.base_url = request.base_url
                deleted_provider.region = request.region
                deleted_provider.api_version = request.api_version

                # Update encrypted fields (async to avoid blocking event loop)
                if request.api_key is not None:
                    api_key_secret = await Secret.from_plaintext_async(request.api_key)
                    deleted_provider.api_key_enc = api_key_secret.get_encrypted()
                if request.access_key is not None:
                    access_key_secret = await Secret.from_plaintext_async(request.access_key)
                    deleted_provider.access_key_enc = access_key_secret.get_encrypted()

                await deleted_provider.update_async(session, actor=actor)

                # Also restore any soft-deleted models associated with this provider
                # This is needed because the unique constraint on provider_models doesn't include is_deleted,
                # so soft-deleted models would block creation of new models with the same handle
                from sqlalchemy import update

                restore_models_stmt = (
                    update(ProviderModelORM)
                    .where(
                        and_(
                            ProviderModelORM.provider_id == deleted_provider.id,
                            ProviderModelORM.is_deleted == True,
                        )
                    )
                    .values(is_deleted=False)
                )
                result = await session.execute(restore_models_stmt)
                if result.rowcount > 0:
                    logger.info(f"Restored {result.rowcount} soft-deleted model(s) for provider '{request.name}'")

                # Commit the provider and model restoration before syncing
                # This is needed because _sync_default_models_for_provider opens a new session
                # that can't see uncommitted changes from this session
                await session.commit()

                provider_pydantic = deleted_provider.to_pydantic()

                # For BYOK providers, automatically sync available models
                # This will add any new models and remove any that are no longer available
                if is_byok:
                    await self._sync_default_models_for_provider(provider_pydantic, actor)

                return provider_pydantic

            # Create provider with the appropriate category
            provider_data = request.model_dump()

            # Unset deprecated api_key and access_key as to not write plaintext values, api_key_enc and access_key_enc will be set below
            provider_data.pop("api_key", None)
            provider_data.pop("access_key", None)

            provider_data["provider_category"] = ProviderCategory.byok if is_byok else ProviderCategory.base
            provider = PydanticProvider(**provider_data)

            # if provider.name == provider.provider_type.value:
            #     raise ValueError("Provider name must be unique and different from provider type")

            # Fill in schema-default base_url if not provided
            # This ensures providers like ZAI get their default endpoint persisted to DB
            # rather than relying on cast_to_subtype() at read time
            if provider.base_url is None:
                typed_provider = provider.cast_to_subtype()
                if typed_provider.base_url is not None:
                    provider.base_url = typed_provider.base_url

            # Only assign organization id for non-base providers
            # Base providers should be globally accessible (org_id = None)
            if is_byok:
                provider.organization_id = actor.organization_id

            # Lazily create the provider id prior to persistence
            provider.resolve_identifier()

            # Explicitly populate encrypted fields from plaintext (async to avoid blocking event loop)
            if request.api_key is not None:
                provider.api_key_enc = await Secret.from_plaintext_async(request.api_key)
            if request.access_key is not None:
                provider.access_key_enc = await Secret.from_plaintext_async(request.access_key)

            new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
            await new_provider.create_async(session, actor=actor)
            provider_pydantic = new_provider.to_pydantic()

            # For BYOK providers, automatically sync available models
            if is_byok:
                await self._sync_default_models_for_provider(provider_pydantic, actor)

            return provider_pydantic
|
|
|
|
    @enforce_types
    @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
    @trace_method
    async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
        """Update provider details.

        Plaintext api_key/access_key values in the update are compared against the
        currently stored (decrypted) values and only re-encrypted when they differ,
        to avoid rewriting the encrypted columns on no-op updates.
        """
        async with db_registry.async_session() as session:
            # Retrieve the existing provider by ID
            existing_provider = await ProviderModel.read_async(
                db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
            )

            # Update only the fields that are provided in ProviderUpdate
            update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)

            # Handle encryption for api_key if provided
            # Only re-encrypt if the value has actually changed
            if "api_key" in update_data and update_data["api_key"] is not None:
                # Check if value changed
                existing_api_key = None
                if existing_provider.api_key_enc:
                    existing_secret = Secret.from_encrypted(existing_provider.api_key_enc)
                    existing_api_key = await existing_secret.get_plaintext_async()

                # Only re-encrypt if different (async to avoid blocking event loop)
                if existing_api_key != update_data["api_key"]:
                    api_key_secret = await Secret.from_plaintext_async(update_data["api_key"])
                    existing_provider.api_key_enc = api_key_secret.get_encrypted()

                # Remove from update_data since we set directly on existing_provider
                update_data.pop("api_key", None)
                update_data.pop("api_key_enc", None)

            # Handle encryption for access_key if provided
            # Only re-encrypt if the value has actually changed
            if "access_key" in update_data and update_data["access_key"] is not None:
                # Check if value changed
                existing_access_key = None
                if existing_provider.access_key_enc:
                    existing_secret = Secret.from_encrypted(existing_provider.access_key_enc)
                    existing_access_key = await existing_secret.get_plaintext_async()

                # Only re-encrypt if different (async to avoid blocking event loop)
                if existing_access_key != update_data["access_key"]:
                    access_key_secret = await Secret.from_plaintext_async(update_data["access_key"])
                    existing_provider.access_key_enc = access_key_secret.get_encrypted()

                # Remove from update_data since we set directly on existing_provider
                update_data.pop("access_key", None)
                update_data.pop("access_key_enc", None)

            # Apply remaining updates
            for key, value in update_data.items():
                setattr(existing_provider, key, value)

            # Commit the updated provider
            await existing_provider.update_async(session, actor=actor)
            return existing_provider.to_pydantic()
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
|
async def update_provider_last_synced_async(self, provider_id: str, actor: Optional[PydanticUser] = None) -> None:
|
|
"""Update the last_synced timestamp for a provider.
|
|
|
|
Note: actor is optional to support system-level operations (e.g., during server initialization
|
|
for global providers). When actor is provided, org-scoping is enforced.
|
|
"""
|
|
from datetime import datetime, timezone
|
|
|
|
async with db_registry.async_session() as session:
|
|
provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
|
|
provider.last_synced = datetime.now(timezone.utc)
|
|
await session.commit()
|
|
|
|
    @enforce_types
    @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
    @trace_method
    async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
        """Delete a provider and its associated models.

        The delete is a soft delete: credentials are wiped first, the cleared row is
        persisted, every associated model row is soft-deleted, and finally the
        provider row itself is soft-deleted.
        """
        async with db_registry.async_session() as session:
            # Clear api key field
            existing_provider = await ProviderModel.read_async(
                db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
            )
            existing_provider.api_key_enc = None
            existing_provider.access_key_enc = None

            # Only accessing these deprecated fields to clear, which may trigger a warning
            existing_provider.api_key = None
            existing_provider.access_key = None

            logger.info("Soft deleting provider with id: %s", provider_id)

            # Persist the credential wipe before the soft delete
            await existing_provider.update_async(session, actor=actor)

            # Soft delete all models associated with this provider
            provider_models = await ProviderModelORM.list_async(
                db_session=session,
                provider_id=provider_id,
                check_is_deleted=True,
            )
            for model in provider_models:
                await model.delete_async(session, actor=actor)

            # Soft delete in provider table
            await existing_provider.delete_async(session, actor=actor)

            # context manager now handles commits
            # await session.commit()
|
|
|
|
    @enforce_types
    @trace_method
    async def list_providers_async(
        self,
        actor: PydanticUser,
        name: Optional[str] = None,
        provider_type: Optional[ProviderType] = None,
        provider_category: Optional[List[ProviderCategory]] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        ascending: bool = False,
    ) -> List[PydanticProvider]:
        """
        List all providers with pagination support.
        Returns both global providers (organization_id=NULL) and organization-specific providers.

        Note: `limit` is applied to each of the two underlying queries separately,
        so the combined result may contain up to 2*limit entries.
        """
        filter_kwargs = {}
        if name:
            filter_kwargs["name"] = name
        if provider_type:
            filter_kwargs["provider_type"] = provider_type
        async with db_registry.async_session() as session:
            # Get organization-specific providers
            org_providers = await ProviderModel.list_async(
                db_session=session,
                before=before,
                after=after,
                limit=limit,
                actor=actor,
                ascending=ascending,
                check_is_deleted=True,
                **filter_kwargs,
            )

            # Get global providers (base providers with organization_id=NULL)
            global_filter_kwargs = {**filter_kwargs, "organization_id": None}
            global_providers = await ProviderModel.list_async(
                db_session=session,
                before=before,
                after=after,
                limit=limit,
                ascending=ascending,
                check_is_deleted=True,
                **global_filter_kwargs,
            )

            # Combine both lists
            all_providers = []
            if not provider_category:
                all_providers = org_providers + global_providers
            else:
                # Category filter maps: byok -> org-scoped rows, base -> global rows
                if ProviderCategory.byok in provider_category:
                    all_providers += org_providers
                if ProviderCategory.base in provider_category:
                    all_providers += global_providers

            # Remove deprecated api_key and access_key fields from the response
            # NOTE(review): this mutates session-tracked ORM rows before to_pydantic();
            # presumably only the deprecated plaintext columns are affected — confirm the
            # session does not flush unintended changes on exit.
            for provider in all_providers:
                provider.api_key = None
                provider.access_key = None

            return [provider.to_pydantic() for provider in all_providers]
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def list_providers(
|
|
self,
|
|
actor: PydanticUser,
|
|
name: Optional[str] = None,
|
|
provider_type: Optional[ProviderType] = None,
|
|
before: Optional[str] = None,
|
|
after: Optional[str] = None,
|
|
limit: Optional[int] = 50,
|
|
ascending: bool = False,
|
|
) -> List[PydanticProvider]:
|
|
"""
|
|
List all providers with pagination support (synchronous version).
|
|
Returns both global providers (organization_id=NULL) and organization-specific providers.
|
|
"""
|
|
filter_kwargs = {}
|
|
if name:
|
|
filter_kwargs["name"] = name
|
|
if provider_type:
|
|
filter_kwargs["provider_type"] = provider_type
|
|
with db_registry.get_session() as session:
|
|
# Get organization-specific providers
|
|
org_providers = ProviderModel.list(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
actor=actor,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**filter_kwargs,
|
|
)
|
|
|
|
# Get global providers (base providers with organization_id=NULL)
|
|
global_filter_kwargs = {**filter_kwargs, "organization_id": None}
|
|
global_providers = ProviderModel.list(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**global_filter_kwargs,
|
|
)
|
|
|
|
# Combine both lists
|
|
all_providers = org_providers + global_providers
|
|
|
|
return [provider.to_pydantic() for provider in all_providers]
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
|
@trace_method
|
|
async def get_provider_async(self, provider_id: str, actor: PydanticUser) -> PydanticProvider:
|
|
async with db_registry.async_session() as session:
|
|
# First try to get as organization-specific provider
|
|
try:
|
|
provider_model = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
|
|
return provider_model.to_pydantic()
|
|
except:
|
|
# If not found, try to get as global provider (organization_id=NULL)
|
|
from sqlalchemy import select
|
|
|
|
stmt = select(ProviderModel).where(
|
|
ProviderModel.id == provider_id,
|
|
ProviderModel.organization_id.is_(None),
|
|
ProviderModel.is_deleted == False,
|
|
)
|
|
result = await session.execute(stmt)
|
|
provider_model = result.scalar_one_or_none()
|
|
if provider_model:
|
|
# Remove deprecated api_key and access_key fields from the response
|
|
provider_model.api_key = None
|
|
provider_model.access_key = None
|
|
return provider_model.to_pydantic()
|
|
else:
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
raise NoResultFound(f"Provider not found with id='{provider_id}'")
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
|
providers = self.list_providers(name=provider_name, actor=actor)
|
|
return providers[0].id if providers else None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
|
providers = self.list_providers(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
return api_key_secret.get_plaintext() if api_key_secret else None
|
|
return None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
return await api_key_secret.get_plaintext_async() if api_key_secret else None
|
|
return None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_bedrock_credentials_async(
|
|
self, provider_name: Union[str, None], actor: PydanticUser
|
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the credentials before returning
|
|
access_key_secret = providers[0].access_key_enc
|
|
api_key_secret = providers[0].api_key_enc
|
|
access_key = await access_key_secret.get_plaintext_async() if access_key_secret else None
|
|
secret_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
|
region = providers[0].region
|
|
return access_key, secret_key, region
|
|
return None, None, None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def get_azure_credentials(
|
|
self, provider_name: Union[str, None], actor: PydanticUser
|
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
providers = self.list_providers(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
api_key = api_key_secret.get_plaintext() if api_key_secret else None
|
|
base_url = providers[0].base_url
|
|
api_version = providers[0].api_version
|
|
return api_key, base_url, api_version
|
|
return None, None, None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_azure_credentials_async(
|
|
self, provider_name: Union[str, None], actor: PydanticUser
|
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
api_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
|
base_url = providers[0].base_url
|
|
api_version = providers[0].api_version
|
|
return api_key, base_url, api_version
|
|
return None, None, None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
|
|
provider = PydanticProvider(
|
|
name=provider_check.provider_type.value,
|
|
provider_type=provider_check.provider_type,
|
|
api_key_enc=Secret.from_plaintext(provider_check.api_key),
|
|
provider_category=ProviderCategory.byok,
|
|
access_key_enc=Secret.from_plaintext(provider_check.access_key) if provider_check.access_key else None,
|
|
region=provider_check.region,
|
|
base_url=provider_check.base_url,
|
|
api_version=provider_check.api_version,
|
|
).cast_to_subtype()
|
|
|
|
# TODO: add more string sanity checks here before we hit actual endpoints
|
|
if not provider.api_key_enc or not await provider.api_key_enc.get_plaintext_async():
|
|
raise ValueError("API key is required!")
|
|
|
|
await provider.check_api_key()
|
|
|
|
    async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None:
        """Sync models for a newly created BYOK provider by querying the provider's API.

        Best-effort: any failure (network, auth, DB) is logged and swallowed so
        provider creation itself never fails because of a model-sync problem.
        """
        try:
            # Use cast_to_subtype() which properly handles all provider types and preserves api_key_enc
            typed_provider = provider.cast_to_subtype()
            llm_models = await typed_provider.list_llm_models_async()
            embedding_models = await typed_provider.list_embedding_models_async()

            await self.sync_provider_models_async(
                provider=provider,
                llm_models=llm_models,
                embedding_models=embedding_models,
                organization_id=actor.organization_id,
            )
            # Only stamp last_synced after the models were written successfully
            await self.update_provider_last_synced_async(provider.id, actor=actor)

        except Exception as e:
            logger.error(f"Failed to sync models for provider '{provider.name}': {e}")
            # Don't fail provider creation if model sync fails
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def sync_base_providers(self, base_providers: list[PydanticProvider], actor: PydanticUser) -> None:
|
|
"""
|
|
Sync base providers (from environment) to database (idempotent).
|
|
|
|
This method is safe to call from multiple pods simultaneously as it:
|
|
1. Checks if provider exists before creating
|
|
2. Handles race conditions with UniqueConstraintViolationError
|
|
3. Only creates providers that don't exist (no updates to avoid conflicts)
|
|
|
|
Args:
|
|
base_providers: List of base provider instances from environment variables
|
|
actor: User actor for database operations
|
|
"""
|
|
from letta.log import get_logger
|
|
from letta.orm.errors import UniqueConstraintViolationError
|
|
|
|
logger = get_logger(__name__)
|
|
logger.info(f"Syncing {len(base_providers)} base providers to database")
|
|
|
|
async with db_registry.async_session() as session:
|
|
for provider in base_providers:
|
|
try:
|
|
# Check if base provider already exists (base providers have organization_id=None)
|
|
existing_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
name=provider.name,
|
|
organization_id=None, # Base providers are global
|
|
limit=1,
|
|
)
|
|
|
|
if existing_providers:
|
|
logger.debug(f"Base provider '{provider.name}' already exists in database, skipping")
|
|
continue
|
|
|
|
# Convert Provider to ProviderCreate
|
|
# NOTE: Do NOT store API keys for base providers in the database.
|
|
# Base providers should always use environment variables for API keys.
|
|
# This ensures keys stay in sync with env vars and aren't duplicated in DB.
|
|
provider_create = ProviderCreate(
|
|
name=provider.name,
|
|
provider_type=provider.provider_type,
|
|
api_key="", # Base providers use env vars, not DB-stored keys
|
|
access_key=None,
|
|
region=provider.region,
|
|
base_url=provider.base_url,
|
|
api_version=provider.api_version,
|
|
)
|
|
|
|
# Create the provider in the database as a base provider
|
|
await self.create_provider_async(request=provider_create, actor=actor, is_byok=False)
|
|
logger.info(f"Successfully initialized base provider '{provider.name}' to database")
|
|
|
|
except UniqueConstraintViolationError:
|
|
# Race condition: another pod created this provider between our check and create
|
|
# This is expected and safe - just log and continue
|
|
logger.debug(f"Provider '{provider.name}' was created by another pod, skipping")
|
|
except Exception as e:
|
|
# Log error but don't fail startup - provider initialization is not critical
|
|
logger.error(f"Failed to sync provider '{provider.name}' to database: {e}", exc_info=True)
|
|
|
|
    @enforce_types
    @trace_method
    async def sync_provider_models_async(
        self,
        provider: PydanticProvider,
        llm_models: List[LLMConfig],
        embedding_models: List[EmbeddingConfig],
        organization_id: Optional[str] = None,
    ) -> None:
        """Sync models from a provider to the database - adds new models and removes old ones.

        Three phases, in order: (1) soft-delete existing models whose handle is not in
        the incoming lists and commit, (2) insert or update each incoming LLM model,
        (3) insert or update each incoming embedding model. Inserts use
        ignore_conflicts=True so concurrent syncs cannot raise unique-violation errors.
        """
        from letta.log import get_logger

        logger = get_logger(__name__)
        logger.info(f"=== Starting sync for provider '{provider.name}' (ID: {provider.id}) ===")
        logger.info(f" Organization ID: {organization_id}")
        logger.info(f" LLM models to sync: {[m.handle for m in llm_models]}")
        logger.info(f" Embedding models to sync: {[m.handle for m in embedding_models]}")

        async with db_registry.async_session() as session:
            # Get all existing models for this provider and organization
            # We need to handle None organization_id specially for SQL NULL comparisons
            from sqlalchemy import and_, select

            # Build the query conditions
            if organization_id is None:
                # For global models (organization_id IS NULL), excluding soft-deleted
                stmt = select(ProviderModelORM).where(
                    and_(
                        ProviderModelORM.provider_id == provider.id,
                        ProviderModelORM.organization_id.is_(None),
                        ProviderModelORM.is_deleted == False,  # Filter out soft-deleted models
                    )
                )
                result = await session.execute(stmt)
                existing_models = list(result.scalars().all())
            else:
                # For org-specific models
                existing_models = await ProviderModelORM.list_async(
                    db_session=session,
                    check_is_deleted=True,  # Filter out soft-deleted models
                    **{
                        "provider_id": provider.id,
                        "organization_id": organization_id,
                    },
                )

            # Build sets of handles for incoming models
            incoming_llm_handles = {llm.handle for llm in llm_models}
            incoming_embedding_handles = {emb.handle for emb in embedding_models}
            all_incoming_handles = incoming_llm_handles | incoming_embedding_handles

            # Determine which models to remove (existing models not in the incoming list)
            models_to_remove = []
            for existing_model in existing_models:
                if existing_model.handle not in all_incoming_handles:
                    models_to_remove.append(existing_model)

            # Remove models that are no longer in the sync list
            for model_to_remove in models_to_remove:
                await model_to_remove.delete_async(session)
                logger.debug(f"Removed model {model_to_remove.handle} from provider {provider.name}")

            # Commit the deletions
            await session.commit()

            # Process LLM models - add new ones
            logger.info(f"Processing {len(llm_models)} LLM models for provider {provider.name}")
            for llm_config in llm_models:
                logger.info(f" Checking LLM model: {llm_config.handle} (name: {llm_config.model})")

                # Check if model already exists by handle (excluding soft-deleted ones)
                existing = await ProviderModelORM.list_async(
                    db_session=session,
                    limit=1,
                    check_is_deleted=True,  # Filter out soft-deleted models
                    **{
                        "handle": llm_config.handle,
                        "organization_id": organization_id,
                        "model_type": "llm",  # Must check model_type since handle can be same for LLM and embedding
                    },
                )

                # Also check by name+provider_id (covers unique_model_per_provider_and_type constraint)
                if not existing:
                    existing = await ProviderModelORM.list_async(
                        db_session=session,
                        limit=1,
                        check_is_deleted=True,
                        **{
                            "name": llm_config.model,
                            "provider_id": provider.id,
                            "model_type": "llm",
                        },
                    )

                if not existing:
                    logger.info(f" Creating new LLM model {llm_config.handle}")
                    # Create new model entry
                    pydantic_model = PydanticProviderModel(
                        handle=llm_config.handle,
                        display_name=llm_config.model,
                        name=llm_config.model,
                        provider_id=provider.id,
                        organization_id=organization_id,
                        model_type="llm",
                        enabled=True,
                        model_endpoint_type=llm_config.model_endpoint_type,
                        max_context_window=llm_config.context_window,
                        supports_token_streaming=llm_config.model_endpoint_type in ["openai", "anthropic", "deepseek", "openrouter"],
                        supports_tool_calling=True,  # Assume true for LLMs for now
                    )

                    logger.info(
                        f" Model data: handle={pydantic_model.handle}, name={pydantic_model.name}, "
                        f"model_type={pydantic_model.model_type}, provider_id={pydantic_model.provider_id}, "
                        f"org_id={pydantic_model.organization_id}"
                    )

                    model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True))
                    # ignore_conflicts=True makes concurrent inserts race-safe (returns falsy on conflict)
                    result = await model.create_async(session, ignore_conflicts=True)
                    if result:
                        logger.info(f" ✓ Successfully created LLM model {llm_config.handle}")
                    else:
                        logger.info(f" LLM model {llm_config.handle} already exists (concurrent insert), skipping")
                else:
                    # Check if max_context_window or model_endpoint_type needs to be updated
                    existing_model = existing[0]
                    needs_update = False

                    if existing_model.max_context_window != llm_config.context_window:
                        logger.info(
                            f" Updating LLM model {llm_config.handle} max_context_window: "
                            f"{existing_model.max_context_window} -> {llm_config.context_window}"
                        )
                        existing_model.max_context_window = llm_config.context_window
                        needs_update = True

                    if existing_model.model_endpoint_type != llm_config.model_endpoint_type:
                        logger.info(
                            f" Updating LLM model {llm_config.handle} model_endpoint_type: "
                            f"{existing_model.model_endpoint_type} -> {llm_config.model_endpoint_type}"
                        )
                        existing_model.model_endpoint_type = llm_config.model_endpoint_type
                        needs_update = True

                    if needs_update:
                        await existing_model.update_async(session)
                    else:
                        logger.info(f" LLM model {llm_config.handle} already exists (ID: {existing[0].id}), skipping")

            # Process embedding models - add new ones
            logger.info(f"Processing {len(embedding_models)} embedding models for provider {provider.name}")
            for embedding_config in embedding_models:
                logger.info(f" Checking embedding model: {embedding_config.handle} (name: {embedding_config.embedding_model})")

                # Check if model already exists by handle (excluding soft-deleted ones)
                existing = await ProviderModelORM.list_async(
                    db_session=session,
                    limit=1,
                    check_is_deleted=True,  # Filter out soft-deleted models
                    **{
                        "handle": embedding_config.handle,
                        "organization_id": organization_id,
                        "model_type": "embedding",  # Must check model_type since handle can be same for LLM and embedding
                    },
                )

                # Also check by name+provider_id (covers unique_model_per_provider_and_type constraint)
                if not existing:
                    existing = await ProviderModelORM.list_async(
                        db_session=session,
                        limit=1,
                        check_is_deleted=True,
                        **{
                            "name": embedding_config.embedding_model,
                            "provider_id": provider.id,
                            "model_type": "embedding",
                        },
                    )

                if not existing:
                    logger.info(f" Creating new embedding model {embedding_config.handle}")
                    # Create new model entry
                    pydantic_model = PydanticProviderModel(
                        handle=embedding_config.handle,
                        display_name=embedding_config.embedding_model,
                        name=embedding_config.embedding_model,
                        provider_id=provider.id,
                        organization_id=organization_id,
                        model_type="embedding",
                        enabled=True,
                        model_endpoint_type=embedding_config.embedding_endpoint_type,
                        embedding_dim=embedding_config.embedding_dim if hasattr(embedding_config, "embedding_dim") else None,
                    )

                    logger.info(
                        f" Model data: handle={pydantic_model.handle}, name={pydantic_model.name}, "
                        f"model_type={pydantic_model.model_type}, provider_id={pydantic_model.provider_id}, "
                        f"org_id={pydantic_model.organization_id}"
                    )

                    model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True))
                    # Same race-safe insert path as the LLM loop above
                    result = await model.create_async(session, ignore_conflicts=True)
                    if result:
                        logger.info(f" ✓ Successfully created embedding model {embedding_config.handle}")
                    else:
                        logger.info(f" Embedding model {embedding_config.handle} already exists (concurrent insert), skipping")
                else:
                    # Check if model_endpoint_type needs to be updated
                    existing_model = existing[0]
                    if existing_model.model_endpoint_type != embedding_config.embedding_endpoint_type:
                        logger.info(
                            f" Updating embedding model {embedding_config.handle} model_endpoint_type: "
                            f"{existing_model.model_endpoint_type} -> {embedding_config.embedding_endpoint_type}"
                        )
                        existing_model.model_endpoint_type = embedding_config.embedding_endpoint_type
                        await existing_model.update_async(session)
                    else:
                        logger.info(f" Embedding model {embedding_config.handle} already exists (ID: {existing[0].id}), skipping")
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_model_by_handle_async(
|
|
self,
|
|
handle: str,
|
|
actor: PydanticUser,
|
|
model_type: Optional[str] = None,
|
|
) -> Optional[PydanticProviderModel]:
|
|
"""Get a model by its handle. Handles are unique per organization."""
|
|
async with db_registry.async_session() as session:
|
|
from sqlalchemy import and_, or_, select
|
|
|
|
# Build conditions for the query
|
|
conditions = [
|
|
ProviderModelORM.handle == handle,
|
|
ProviderModelORM.is_deleted == False, # Filter out soft-deleted models
|
|
]
|
|
|
|
if model_type:
|
|
conditions.append(ProviderModelORM.model_type == model_type)
|
|
|
|
# Search for models that are either:
|
|
# 1. Organization-specific (matching actor's org)
|
|
# 2. Global (organization_id is NULL)
|
|
conditions.append(or_(ProviderModelORM.organization_id == actor.organization_id, ProviderModelORM.organization_id.is_(None)))
|
|
|
|
stmt = select(ProviderModelORM).where(and_(*conditions))
|
|
result = await session.execute(stmt)
|
|
models = list(result.scalars().all())
|
|
|
|
# Find the model the user has access to
|
|
# Prioritize org-specific models over global models
|
|
org_model = None
|
|
global_model = None
|
|
|
|
for model in models:
|
|
if model.organization_id == actor.organization_id:
|
|
org_model = model
|
|
elif model.organization_id is None:
|
|
global_model = model
|
|
|
|
# Return org-specific model if it exists, otherwise return global model
|
|
if org_model:
|
|
return org_model.to_pydantic()
|
|
elif global_model:
|
|
return global_model.to_pydantic()
|
|
|
|
return None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def list_models_async(
|
|
self,
|
|
actor: PydanticUser,
|
|
model_type: Optional[str] = None,
|
|
provider_id: Optional[str] = None,
|
|
enabled: Optional[bool] = True,
|
|
limit: Optional[int] = None,
|
|
) -> List[PydanticProviderModel]:
|
|
"""List models available to an actor (both global and org-scoped)."""
|
|
async with db_registry.async_session() as session:
|
|
# Build filters
|
|
filters = {}
|
|
if model_type:
|
|
filters["model_type"] = model_type
|
|
if provider_id:
|
|
filters["provider_id"] = provider_id
|
|
if enabled is not None:
|
|
filters["enabled"] = enabled
|
|
|
|
# Get org-scoped models (excluding soft-deleted ones)
|
|
org_filters = {**filters, "organization_id": actor.organization_id}
|
|
org_models = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
limit=limit,
|
|
check_is_deleted=True, # Filter out soft-deleted models
|
|
**org_filters,
|
|
)
|
|
|
|
# Get global models - need to handle NULL organization_id specially
|
|
from sqlalchemy import and_, select
|
|
|
|
# Build conditions for global models query
|
|
conditions = [
|
|
ProviderModelORM.organization_id.is_(None),
|
|
ProviderModelORM.is_deleted == False, # Filter out soft-deleted models
|
|
]
|
|
if model_type:
|
|
conditions.append(ProviderModelORM.model_type == model_type)
|
|
if provider_id:
|
|
conditions.append(ProviderModelORM.provider_id == provider_id)
|
|
if enabled is not None:
|
|
conditions.append(ProviderModelORM.enabled == enabled)
|
|
|
|
stmt = select(ProviderModelORM).where(and_(*conditions))
|
|
if limit:
|
|
stmt = stmt.limit(limit)
|
|
result = await session.execute(stmt)
|
|
global_models = list(result.scalars().all())
|
|
|
|
# Combine and deduplicate by handle AND model_type (org-scoped takes precedence)
|
|
# Use (handle, model_type) tuple as key since same handle can exist for LLM and embedding
|
|
all_models = {(m.handle, m.model_type): m for m in global_models}
|
|
all_models.update({(m.handle, m.model_type): m for m in org_models})
|
|
|
|
return [m.to_pydantic() for m in all_models.values()]
|
|
|
|
    @enforce_types
    @trace_method
    async def get_llm_config_from_handle(
        self,
        handle: str,
        actor: PydanticUser,
    ) -> LLMConfig:
        """Get an LLMConfig from a model handle.

        Resolution order: first the provider-model table (base providers), then
        a dynamic model listing from a matching BYOK provider.

        Args:
            handle: The model handle to look up (formatted "provider_name/model_name")
            actor: The user actor for permission checking

        Returns:
            LLMConfig constructed from the provider and model data

        Raises:
            NoResultFound: If the handle doesn't exist in the database or BYOK provider
        """
        from letta.orm.errors import NoResultFound

        # Look up the model by handle in the database (for base providers)
        model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="llm")

        if not model:
            # Model not in DB - check if it's from a BYOK provider
            # Handle format is "provider_name/model_name"
            if "/" in handle:
                provider_name, model_name = handle.split("/", 1)
                byok_providers = await self.list_providers_async(
                    actor=actor,
                    name=provider_name,
                    provider_category=[ProviderCategory.byok],
                )
                if byok_providers:
                    # Fetch models dynamically from BYOK provider
                    provider = byok_providers[0]
                    typed_provider = provider.cast_to_subtype()
                    try:
                        all_llm_configs = await typed_provider.list_llm_models_async()
                        # Match by handle first (original logic)
                        llm_configs = [config for config in all_llm_configs if config.handle == handle]
                        # Fallback to match by model name (original logic)
                        if not llm_configs:
                            llm_configs = [config for config in all_llm_configs if config.model == model_name]
                        if llm_configs:
                            return llm_configs[0]
                    except Exception as e:
                        # Best-effort: a failing BYOK provider surfaces as the
                        # NoResultFound below rather than an unhandled error.
                        logger.warning(f"Failed to fetch models from BYOK provider {provider_name}: {e}")

            raise NoResultFound(f"LLM model not found with handle='{handle}'")

        # Get the provider for this model and cast to subtype to access provider-specific methods
        provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
        typed_provider = provider.cast_to_subtype()

        # Get the default max_output_tokens from the provider (provider-specific logic)
        max_tokens = typed_provider.get_default_max_output_tokens(model.name)

        # Determine the model endpoint - use provider's base_url if set,
        # otherwise use provider-specific defaults

        if typed_provider.base_url:
            model_endpoint = typed_provider.base_url
        elif provider.provider_type == ProviderType.chatgpt_oauth:
            # ChatGPT OAuth uses the ChatGPT backend API, not a generic endpoint pattern
            from letta.schemas.providers.chatgpt_oauth import CHATGPT_CODEX_ENDPOINT

            model_endpoint = CHATGPT_CODEX_ENDPOINT
        else:
            # Fallback convention: https://api.<provider_type>.com/v1
            model_endpoint = f"https://api.{provider.provider_type.value}.com/v1"

        # Construct the LLMConfig from the model and provider data
        llm_config = LLMConfig(
            model=model.name,
            model_endpoint_type=model.model_endpoint_type,
            model_endpoint=model_endpoint,
            context_window=model.max_context_window or 16384,  # Default if not set
            handle=model.handle,
            provider_name=provider.name,
            provider_category=provider.provider_category,
            max_tokens=max_tokens,
        )

        return llm_config
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_embedding_config_from_handle(
|
|
self,
|
|
handle: str,
|
|
actor: PydanticUser,
|
|
) -> EmbeddingConfig:
|
|
"""Get an EmbeddingConfig from a model handle.
|
|
|
|
Args:
|
|
handle: The model handle to look up
|
|
actor: The user actor for permission checking
|
|
|
|
Returns:
|
|
EmbeddingConfig constructed from the provider and model data
|
|
|
|
Raises:
|
|
NoResultFound: If the handle doesn't exist in the database or BYOK provider
|
|
"""
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
# Look up the model by handle in the database (for base providers)
|
|
model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="embedding")
|
|
|
|
if not model:
|
|
# Model not in DB - check if it's from a BYOK provider
|
|
# Handle format is "provider_name/model_name"
|
|
if "/" in handle:
|
|
provider_name, model_name = handle.split("/", 1)
|
|
byok_providers = await self.list_providers_async(
|
|
actor=actor,
|
|
name=provider_name,
|
|
provider_category=[ProviderCategory.byok],
|
|
)
|
|
if byok_providers:
|
|
# Fetch models dynamically from BYOK provider
|
|
provider = byok_providers[0]
|
|
typed_provider = provider.cast_to_subtype()
|
|
try:
|
|
all_embedding_configs = await typed_provider.list_embedding_models_async()
|
|
# Match by handle (original logic - no model_name fallback for embeddings)
|
|
embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
|
|
if embedding_configs:
|
|
return embedding_configs[0]
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider_name}: {e}")
|
|
|
|
raise NoResultFound(f"Embedding model not found with handle='{handle}'")
|
|
|
|
# Get the provider for this model
|
|
provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
|
|
|
|
# Construct the EmbeddingConfig from the model and provider data
|
|
embedding_config = EmbeddingConfig(
|
|
embedding_model=model.name,
|
|
embedding_endpoint_type=model.model_endpoint_type,
|
|
embedding_endpoint=provider.base_url or f"https://api.{provider.provider_type.value}.com/v1",
|
|
embedding_dim=model.embedding_dim or 1536, # Use model's dimension or default
|
|
embedding_chunk_size=300, # Default chunk size
|
|
handle=model.handle,
|
|
)
|
|
|
|
return embedding_config
|