* auto fixes * auto fix pt2 and transitive deps and undefined var checking locals() * manual fixes (ignored or letta-code fixed) * fix circular import * remove all ignores, add FastAPI rules and Ruff rules * add ty and precommit * ruff stuff * ty check fixes * ty check fixes pt 2 * error on invalid
1086 lines
50 KiB
Python
1086 lines
50 KiB
Python
from typing import List, Optional, Tuple, Union
|
|
|
|
from sqlalchemy import and_, select
|
|
|
|
from letta.log import get_logger
|
|
from letta.orm.provider import Provider as ProviderModel
|
|
from letta.orm.provider_model import ProviderModel as ProviderModelORM
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.provider_model import ProviderModel as PydanticProviderModel
|
|
from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate
|
|
from letta.schemas.secret import Secret
|
|
from letta.schemas.user import User as PydanticUser
|
|
from letta.server.db import db_registry
|
|
from letta.utils import enforce_types
|
|
from letta.validators import raise_on_invalid_id
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class ProviderManager:
|
|
@enforce_types
|
|
@trace_method
|
|
async def create_provider_async(self, request: ProviderCreate, actor: PydanticUser, is_byok: bool = True) -> PydanticProvider:
|
|
"""Create a new provider if it doesn't already exist.
|
|
|
|
Args:
|
|
request: ProviderCreate object with provider details
|
|
actor: User creating the provider
|
|
is_byok: If True, creates a BYOK provider (default). If False, creates a base provider.
|
|
"""
|
|
async with db_registry.async_session() as session:
|
|
# Check for name conflicts
|
|
if is_byok:
|
|
# BYOK providers cannot use the same name as base providers
|
|
existing_base_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
name=request.name,
|
|
organization_id=None, # Base providers have NULL organization_id
|
|
limit=1,
|
|
)
|
|
if existing_base_providers:
|
|
raise ValueError(
|
|
f"Provider name '{request.name}' conflicts with an existing base provider. Please choose a different name."
|
|
)
|
|
else:
|
|
# Base providers must have unique names among themselves
|
|
# (the DB constraint won't catch this because NULL != NULL)
|
|
existing_base_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
name=request.name,
|
|
organization_id=None, # Base providers have NULL organization_id
|
|
limit=1,
|
|
)
|
|
if existing_base_providers:
|
|
raise ValueError(f"Base provider name '{request.name}' already exists. Please choose a different name.")
|
|
|
|
# Check if there's a soft-deleted provider with the same name that we can restore
|
|
org_id = actor.organization_id if is_byok else None
|
|
if org_id is not None:
|
|
stmt = select(ProviderModel).where(
|
|
and_(
|
|
ProviderModel.name == request.name,
|
|
ProviderModel.organization_id == org_id,
|
|
ProviderModel.is_deleted == True,
|
|
)
|
|
)
|
|
else:
|
|
stmt = select(ProviderModel).where(
|
|
and_(
|
|
ProviderModel.name == request.name,
|
|
ProviderModel.organization_id.is_(None),
|
|
ProviderModel.is_deleted == True,
|
|
)
|
|
)
|
|
result = await session.execute(stmt)
|
|
deleted_provider = result.scalar_one_or_none()
|
|
|
|
if deleted_provider:
|
|
# Restore the soft-deleted provider and update its fields
|
|
logger.info(f"Restoring soft-deleted provider '{request.name}' with id: {deleted_provider.id}")
|
|
deleted_provider.is_deleted = False
|
|
deleted_provider.provider_type = request.provider_type
|
|
deleted_provider.provider_category = ProviderCategory.byok if is_byok else ProviderCategory.base
|
|
deleted_provider.base_url = request.base_url
|
|
deleted_provider.region = request.region
|
|
deleted_provider.api_version = request.api_version
|
|
|
|
# Update encrypted fields (async to avoid blocking event loop)
|
|
if request.api_key is not None:
|
|
api_key_secret = await Secret.from_plaintext_async(request.api_key)
|
|
deleted_provider.api_key_enc = api_key_secret.get_encrypted()
|
|
if request.access_key is not None:
|
|
access_key_secret = await Secret.from_plaintext_async(request.access_key)
|
|
deleted_provider.access_key_enc = access_key_secret.get_encrypted()
|
|
|
|
await deleted_provider.update_async(session, actor=actor)
|
|
|
|
# Also restore any soft-deleted models associated with this provider
|
|
# This is needed because the unique constraint on provider_models doesn't include is_deleted,
|
|
# so soft-deleted models would block creation of new models with the same handle
|
|
from sqlalchemy import update
|
|
|
|
restore_models_stmt = (
|
|
update(ProviderModelORM)
|
|
.where(
|
|
and_(
|
|
ProviderModelORM.provider_id == deleted_provider.id,
|
|
ProviderModelORM.is_deleted == True,
|
|
)
|
|
)
|
|
.values(is_deleted=False)
|
|
)
|
|
result = await session.execute(restore_models_stmt)
|
|
if result.rowcount > 0:
|
|
logger.info(f"Restored {result.rowcount} soft-deleted model(s) for provider '{request.name}'")
|
|
|
|
# Commit the provider and model restoration before syncing
|
|
# This is needed because _sync_default_models_for_provider opens a new session
|
|
# that can't see uncommitted changes from this session
|
|
await session.commit()
|
|
|
|
provider_pydantic = deleted_provider.to_pydantic()
|
|
|
|
# For BYOK providers, automatically sync available models
|
|
# This will add any new models and remove any that are no longer available
|
|
if is_byok:
|
|
await self._sync_default_models_for_provider(provider_pydantic, actor)
|
|
|
|
return provider_pydantic
|
|
|
|
# Create provider with the appropriate category
|
|
provider_data = request.model_dump()
|
|
|
|
# Unset deprecated api_key and access_key as to not write plaintext values, api_key_enc and access_key_enc will be set below
|
|
provider_data.pop("api_key", None)
|
|
provider_data.pop("access_key", None)
|
|
|
|
provider_data["provider_category"] = ProviderCategory.byok if is_byok else ProviderCategory.base
|
|
provider = PydanticProvider(**provider_data)
|
|
|
|
# if provider.name == provider.provider_type.value:
|
|
# raise ValueError("Provider name must be unique and different from provider type")
|
|
|
|
# Fill in schema-default base_url if not provided
|
|
# This ensures providers like ZAI get their default endpoint persisted to DB
|
|
# rather than relying on cast_to_subtype() at read time
|
|
if provider.base_url is None:
|
|
typed_provider = provider.cast_to_subtype()
|
|
if typed_provider.base_url is not None:
|
|
provider.base_url = typed_provider.base_url
|
|
|
|
# Only assign organization id for non-base providers
|
|
# Base providers should be globally accessible (org_id = None)
|
|
if is_byok:
|
|
provider.organization_id = actor.organization_id
|
|
|
|
# Lazily create the provider id prior to persistence
|
|
provider.resolve_identifier()
|
|
|
|
# Explicitly populate encrypted fields from plaintext (async to avoid blocking event loop)
|
|
if request.api_key is not None:
|
|
provider.api_key_enc = await Secret.from_plaintext_async(request.api_key)
|
|
if request.access_key is not None:
|
|
provider.access_key_enc = await Secret.from_plaintext_async(request.access_key)
|
|
|
|
new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
|
|
await new_provider.create_async(session, actor=actor)
|
|
provider_pydantic = new_provider.to_pydantic()
|
|
|
|
# For BYOK providers, automatically sync available models
|
|
if is_byok:
|
|
await self._sync_default_models_for_provider(provider_pydantic, actor)
|
|
|
|
return provider_pydantic
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
|
@trace_method
|
|
async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
|
|
"""Update provider details."""
|
|
async with db_registry.async_session() as session:
|
|
# Retrieve the existing provider by ID
|
|
existing_provider = await ProviderModel.read_async(
|
|
db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
|
|
)
|
|
|
|
# Update only the fields that are provided in ProviderUpdate
|
|
update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
|
|
|
# Handle encryption for api_key if provided
|
|
# Only re-encrypt if the value has actually changed
|
|
if "api_key" in update_data and update_data["api_key"] is not None:
|
|
# Check if value changed
|
|
existing_api_key = None
|
|
if existing_provider.api_key_enc:
|
|
existing_secret = Secret.from_encrypted(existing_provider.api_key_enc)
|
|
existing_api_key = await existing_secret.get_plaintext_async()
|
|
|
|
# Only re-encrypt if different (async to avoid blocking event loop)
|
|
if existing_api_key != update_data["api_key"]:
|
|
api_key_secret = await Secret.from_plaintext_async(update_data["api_key"])
|
|
existing_provider.api_key_enc = api_key_secret.get_encrypted()
|
|
|
|
# Remove from update_data since we set directly on existing_provider
|
|
update_data.pop("api_key", None)
|
|
update_data.pop("api_key_enc", None)
|
|
|
|
# Handle encryption for access_key if provided
|
|
# Only re-encrypt if the value has actually changed
|
|
if "access_key" in update_data and update_data["access_key"] is not None:
|
|
# Check if value changed
|
|
existing_access_key = None
|
|
if existing_provider.access_key_enc:
|
|
existing_secret = Secret.from_encrypted(existing_provider.access_key_enc)
|
|
existing_access_key = await existing_secret.get_plaintext_async()
|
|
|
|
# Only re-encrypt if different (async to avoid blocking event loop)
|
|
if existing_access_key != update_data["access_key"]:
|
|
access_key_secret = await Secret.from_plaintext_async(update_data["access_key"])
|
|
existing_provider.access_key_enc = access_key_secret.get_encrypted()
|
|
|
|
# Remove from update_data since we set directly on existing_provider
|
|
update_data.pop("access_key", None)
|
|
update_data.pop("access_key_enc", None)
|
|
|
|
# Apply remaining updates
|
|
for key, value in update_data.items():
|
|
setattr(existing_provider, key, value)
|
|
|
|
# Commit the updated provider
|
|
await existing_provider.update_async(session, actor=actor)
|
|
return existing_provider.to_pydantic()
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
|
async def update_provider_last_synced_async(self, provider_id: str, actor: Optional[PydanticUser] = None) -> None:
|
|
"""Update the last_synced timestamp for a provider.
|
|
|
|
Note: actor is optional to support system-level operations (e.g., during server initialization
|
|
for global providers). When actor is provided, org-scoping is enforced.
|
|
"""
|
|
from datetime import datetime, timezone
|
|
|
|
async with db_registry.async_session() as session:
|
|
provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
|
|
provider.last_synced = datetime.now(timezone.utc)
|
|
await session.commit()
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
|
@trace_method
|
|
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
|
|
"""Delete a provider and its associated models."""
|
|
async with db_registry.async_session() as session:
|
|
# Clear api key field
|
|
existing_provider = await ProviderModel.read_async(
|
|
db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
|
|
)
|
|
existing_provider.api_key_enc = None
|
|
existing_provider.access_key_enc = None
|
|
|
|
# Only accessing these deprecated fields to clear, which may trigger a warning
|
|
existing_provider.api_key = None
|
|
existing_provider.access_key = None
|
|
|
|
logger.info("Soft deleting provider with id: %s", provider_id)
|
|
|
|
await existing_provider.update_async(session, actor=actor)
|
|
|
|
# Soft delete all models associated with this provider
|
|
provider_models = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
provider_id=provider_id,
|
|
check_is_deleted=True,
|
|
)
|
|
for model in provider_models:
|
|
await model.delete_async(session, actor=actor)
|
|
|
|
# Soft delete in provider table
|
|
await existing_provider.delete_async(session, actor=actor)
|
|
|
|
# context manager now handles commits
|
|
# await session.commit()
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def list_providers_async(
|
|
self,
|
|
actor: PydanticUser,
|
|
name: Optional[str] = None,
|
|
provider_type: Optional[ProviderType] = None,
|
|
provider_category: Optional[List[ProviderCategory]] = None,
|
|
before: Optional[str] = None,
|
|
after: Optional[str] = None,
|
|
limit: Optional[int] = 50,
|
|
ascending: bool = False,
|
|
) -> List[PydanticProvider]:
|
|
"""
|
|
List all providers with pagination support.
|
|
Returns both global providers (organization_id=NULL) and organization-specific providers.
|
|
"""
|
|
filter_kwargs = {}
|
|
if name:
|
|
filter_kwargs["name"] = name
|
|
if provider_type:
|
|
filter_kwargs["provider_type"] = provider_type
|
|
async with db_registry.async_session() as session:
|
|
# Get organization-specific providers
|
|
org_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
actor=actor,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**filter_kwargs,
|
|
)
|
|
|
|
# Get global providers (base providers with organization_id=NULL)
|
|
global_filter_kwargs = {**filter_kwargs, "organization_id": None}
|
|
global_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**global_filter_kwargs,
|
|
)
|
|
|
|
# Combine both lists
|
|
all_providers = []
|
|
if not provider_category:
|
|
all_providers = org_providers + global_providers
|
|
else:
|
|
if ProviderCategory.byok in provider_category:
|
|
all_providers += org_providers
|
|
if ProviderCategory.base in provider_category:
|
|
all_providers += global_providers
|
|
|
|
# Remove deprecated api_key and access_key fields from the response
|
|
for provider in all_providers:
|
|
provider.api_key = None
|
|
provider.access_key = None
|
|
|
|
return [provider.to_pydantic() for provider in all_providers]
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def list_providers(
|
|
self,
|
|
actor: PydanticUser,
|
|
name: Optional[str] = None,
|
|
provider_type: Optional[ProviderType] = None,
|
|
before: Optional[str] = None,
|
|
after: Optional[str] = None,
|
|
limit: Optional[int] = 50,
|
|
ascending: bool = False,
|
|
) -> List[PydanticProvider]:
|
|
"""
|
|
List all providers with pagination support (synchronous version).
|
|
Returns both global providers (organization_id=NULL) and organization-specific providers.
|
|
"""
|
|
filter_kwargs = {}
|
|
if name:
|
|
filter_kwargs["name"] = name
|
|
if provider_type:
|
|
filter_kwargs["provider_type"] = provider_type
|
|
with db_registry.get_session() as session:
|
|
# Get organization-specific providers
|
|
org_providers = ProviderModel.list(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
actor=actor,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**filter_kwargs,
|
|
)
|
|
|
|
# Get global providers (base providers with organization_id=NULL)
|
|
global_filter_kwargs = {**filter_kwargs, "organization_id": None}
|
|
global_providers = ProviderModel.list(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**global_filter_kwargs,
|
|
)
|
|
|
|
# Combine both lists
|
|
all_providers = org_providers + global_providers
|
|
|
|
return [provider.to_pydantic() for provider in all_providers]
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
|
@trace_method
|
|
async def get_provider_async(self, provider_id: str, actor: PydanticUser) -> PydanticProvider:
|
|
async with db_registry.async_session() as session:
|
|
# First try to get as organization-specific provider
|
|
try:
|
|
provider_model = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
|
|
return provider_model.to_pydantic()
|
|
except Exception:
|
|
# If not found, try to get as global provider (organization_id=NULL)
|
|
from sqlalchemy import select
|
|
|
|
stmt = select(ProviderModel).where(
|
|
ProviderModel.id == provider_id,
|
|
ProviderModel.organization_id.is_(None),
|
|
ProviderModel.is_deleted == False,
|
|
)
|
|
result = await session.execute(stmt)
|
|
provider_model = result.scalar_one_or_none()
|
|
if provider_model:
|
|
# Remove deprecated api_key and access_key fields from the response
|
|
provider_model.api_key = None
|
|
provider_model.access_key = None
|
|
return provider_model.to_pydantic()
|
|
else:
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
raise NoResultFound(f"Provider not found with id='{provider_id}'")
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
|
providers = self.list_providers(name=provider_name, actor=actor)
|
|
return providers[0].id if providers else None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
|
providers = self.list_providers(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
return api_key_secret.get_plaintext() if api_key_secret else None
|
|
return None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
return await api_key_secret.get_plaintext_async() if api_key_secret else None
|
|
return None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_bedrock_credentials_async(
|
|
self, provider_name: Union[str, None], actor: PydanticUser
|
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the credentials before returning
|
|
access_key_secret = providers[0].access_key_enc
|
|
api_key_secret = providers[0].api_key_enc
|
|
access_key = await access_key_secret.get_plaintext_async() if access_key_secret else None
|
|
secret_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
|
region = providers[0].region
|
|
return access_key, secret_key, region
|
|
return None, None, None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def get_azure_credentials(
|
|
self, provider_name: Union[str, None], actor: PydanticUser
|
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
providers = self.list_providers(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
api_key = api_key_secret.get_plaintext() if api_key_secret else None
|
|
base_url = providers[0].base_url
|
|
api_version = providers[0].api_version
|
|
return api_key, base_url, api_version
|
|
return None, None, None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_azure_credentials_async(
|
|
self, provider_name: Union[str, None], actor: PydanticUser
|
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
api_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
|
base_url = providers[0].base_url
|
|
api_version = providers[0].api_version
|
|
return api_key, base_url, api_version
|
|
return None, None, None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
|
|
provider = PydanticProvider(
|
|
name=provider_check.provider_type.value,
|
|
provider_type=provider_check.provider_type,
|
|
api_key_enc=Secret.from_plaintext(provider_check.api_key),
|
|
provider_category=ProviderCategory.byok,
|
|
access_key_enc=Secret.from_plaintext(provider_check.access_key) if provider_check.access_key else None,
|
|
region=provider_check.region,
|
|
base_url=provider_check.base_url,
|
|
api_version=provider_check.api_version,
|
|
).cast_to_subtype()
|
|
|
|
# TODO: add more string sanity checks here before we hit actual endpoints
|
|
if not provider.api_key_enc or not await provider.api_key_enc.get_plaintext_async():
|
|
raise ValueError("API key is required!")
|
|
|
|
await provider.check_api_key()
|
|
|
|
async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None:
|
|
"""Sync models for a newly created BYOK provider by querying the provider's API."""
|
|
try:
|
|
# Use cast_to_subtype() which properly handles all provider types and preserves api_key_enc
|
|
typed_provider = provider.cast_to_subtype()
|
|
llm_models = await typed_provider.list_llm_models_async()
|
|
embedding_models = await typed_provider.list_embedding_models_async()
|
|
|
|
await self.sync_provider_models_async(
|
|
provider=provider,
|
|
llm_models=llm_models,
|
|
embedding_models=embedding_models,
|
|
organization_id=actor.organization_id,
|
|
)
|
|
await self.update_provider_last_synced_async(provider.id, actor=actor)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to sync models for provider '{provider.name}': {e}")
|
|
# Don't fail provider creation if model sync fails
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def sync_base_providers(self, base_providers: list[PydanticProvider], actor: PydanticUser) -> None:
|
|
"""
|
|
Sync base providers (from environment) to database (idempotent).
|
|
|
|
This method is safe to call from multiple pods simultaneously as it:
|
|
1. Checks if provider exists before creating
|
|
2. Handles race conditions with UniqueConstraintViolationError
|
|
3. Only creates providers that don't exist (no updates to avoid conflicts)
|
|
|
|
Args:
|
|
base_providers: List of base provider instances from environment variables
|
|
actor: User actor for database operations
|
|
"""
|
|
from letta.log import get_logger
|
|
from letta.orm.errors import UniqueConstraintViolationError
|
|
|
|
logger = get_logger(__name__)
|
|
logger.info(f"Syncing {len(base_providers)} base providers to database")
|
|
|
|
async with db_registry.async_session() as session:
|
|
for provider in base_providers:
|
|
try:
|
|
# Check if base provider already exists (base providers have organization_id=None)
|
|
existing_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
name=provider.name,
|
|
organization_id=None, # Base providers are global
|
|
limit=1,
|
|
)
|
|
|
|
if existing_providers:
|
|
logger.debug(f"Base provider '{provider.name}' already exists in database, skipping")
|
|
continue
|
|
|
|
# Convert Provider to ProviderCreate
|
|
# NOTE: Do NOT store API keys for base providers in the database.
|
|
# Base providers should always use environment variables for API keys.
|
|
# This ensures keys stay in sync with env vars and aren't duplicated in DB.
|
|
provider_create = ProviderCreate(
|
|
name=provider.name,
|
|
provider_type=provider.provider_type,
|
|
api_key="", # Base providers use env vars, not DB-stored keys
|
|
access_key=None,
|
|
region=provider.region,
|
|
base_url=provider.base_url,
|
|
api_version=provider.api_version,
|
|
)
|
|
|
|
# Create the provider in the database as a base provider
|
|
await self.create_provider_async(request=provider_create, actor=actor, is_byok=False)
|
|
logger.info(f"Successfully initialized base provider '{provider.name}' to database")
|
|
|
|
except UniqueConstraintViolationError:
|
|
# Race condition: another pod created this provider between our check and create
|
|
# This is expected and safe - just log and continue
|
|
logger.debug(f"Provider '{provider.name}' was created by another pod, skipping")
|
|
except Exception as e:
|
|
# Log error but don't fail startup - provider initialization is not critical
|
|
logger.error(f"Failed to sync provider '{provider.name}' to database: {e}", exc_info=True)
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def sync_provider_models_async(
|
|
self,
|
|
provider: PydanticProvider,
|
|
llm_models: List[LLMConfig],
|
|
embedding_models: List[EmbeddingConfig],
|
|
organization_id: Optional[str] = None,
|
|
) -> None:
|
|
"""Sync models from a provider to the database - adds new models and removes old ones."""
|
|
from letta.log import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
logger.info(f"=== Starting sync for provider '{provider.name}' (ID: {provider.id}) ===")
|
|
logger.info(f" Organization ID: {organization_id}")
|
|
logger.info(f" LLM models to sync: {[m.handle for m in llm_models]}")
|
|
logger.info(f" Embedding models to sync: {[m.handle for m in embedding_models]}")
|
|
|
|
async with db_registry.async_session() as session:
|
|
# Get all existing models for this provider and organization
|
|
# We need to handle None organization_id specially for SQL NULL comparisons
|
|
from sqlalchemy import and_, select
|
|
|
|
# Build the query conditions
|
|
if organization_id is None:
|
|
# For global models (organization_id IS NULL), excluding soft-deleted
|
|
stmt = select(ProviderModelORM).where(
|
|
and_(
|
|
ProviderModelORM.provider_id == provider.id,
|
|
ProviderModelORM.organization_id.is_(None),
|
|
ProviderModelORM.is_deleted == False, # Filter out soft-deleted models
|
|
)
|
|
)
|
|
result = await session.execute(stmt)
|
|
existing_models = list(result.scalars().all())
|
|
else:
|
|
# For org-specific models
|
|
existing_models = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
check_is_deleted=True, # Filter out soft-deleted models
|
|
**{
|
|
"provider_id": provider.id,
|
|
"organization_id": organization_id,
|
|
},
|
|
)
|
|
|
|
# Build sets of handles for incoming models
|
|
incoming_llm_handles = {llm.handle for llm in llm_models}
|
|
incoming_embedding_handles = {emb.handle for emb in embedding_models}
|
|
all_incoming_handles = incoming_llm_handles | incoming_embedding_handles
|
|
|
|
# Determine which models to remove (existing models not in the incoming list)
|
|
models_to_remove = []
|
|
for existing_model in existing_models:
|
|
if existing_model.handle not in all_incoming_handles:
|
|
models_to_remove.append(existing_model)
|
|
|
|
# Remove models that are no longer in the sync list
|
|
for model_to_remove in models_to_remove:
|
|
await model_to_remove.delete_async(session)
|
|
logger.debug(f"Removed model {model_to_remove.handle} from provider {provider.name}")
|
|
|
|
# Commit the deletions
|
|
await session.commit()
|
|
|
|
# Process LLM models - add new ones
|
|
logger.info(f"Processing {len(llm_models)} LLM models for provider {provider.name}")
|
|
for llm_config in llm_models:
|
|
logger.info(f" Checking LLM model: {llm_config.handle} (name: {llm_config.model})")
|
|
|
|
# Check if model already exists by handle (excluding soft-deleted ones)
|
|
existing = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
limit=1,
|
|
check_is_deleted=True, # Filter out soft-deleted models
|
|
**{
|
|
"handle": llm_config.handle,
|
|
"organization_id": organization_id,
|
|
"model_type": "llm", # Must check model_type since handle can be same for LLM and embedding
|
|
},
|
|
)
|
|
|
|
# Also check by name+provider_id (covers unique_model_per_provider_and_type constraint)
|
|
if not existing:
|
|
existing = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
limit=1,
|
|
check_is_deleted=True,
|
|
**{
|
|
"name": llm_config.model,
|
|
"provider_id": provider.id,
|
|
"model_type": "llm",
|
|
},
|
|
)
|
|
|
|
if not existing:
|
|
logger.info(f" Creating new LLM model {llm_config.handle}")
|
|
# Create new model entry
|
|
pydantic_model = PydanticProviderModel(
|
|
handle=llm_config.handle,
|
|
display_name=llm_config.model,
|
|
name=llm_config.model,
|
|
provider_id=provider.id,
|
|
organization_id=organization_id,
|
|
model_type="llm",
|
|
enabled=True,
|
|
model_endpoint_type=llm_config.model_endpoint_type,
|
|
max_context_window=llm_config.context_window,
|
|
supports_token_streaming=llm_config.model_endpoint_type in ["openai", "anthropic", "deepseek", "openrouter"],
|
|
supports_tool_calling=True, # Assume true for LLMs for now
|
|
)
|
|
|
|
logger.info(
|
|
f" Model data: handle={pydantic_model.handle}, name={pydantic_model.name}, "
|
|
f"model_type={pydantic_model.model_type}, provider_id={pydantic_model.provider_id}, "
|
|
f"org_id={pydantic_model.organization_id}"
|
|
)
|
|
|
|
model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True))
|
|
result = await model.create_async(session, ignore_conflicts=True)
|
|
if result:
|
|
logger.info(f" ✓ Successfully created LLM model {llm_config.handle}")
|
|
else:
|
|
logger.info(f" LLM model {llm_config.handle} already exists (concurrent insert), skipping")
|
|
else:
|
|
# Check if max_context_window or model_endpoint_type needs to be updated
|
|
existing_model = existing[0]
|
|
needs_update = False
|
|
|
|
if existing_model.max_context_window != llm_config.context_window:
|
|
logger.info(
|
|
f" Updating LLM model {llm_config.handle} max_context_window: "
|
|
f"{existing_model.max_context_window} -> {llm_config.context_window}"
|
|
)
|
|
existing_model.max_context_window = llm_config.context_window
|
|
needs_update = True
|
|
|
|
if existing_model.model_endpoint_type != llm_config.model_endpoint_type:
|
|
logger.info(
|
|
f" Updating LLM model {llm_config.handle} model_endpoint_type: "
|
|
f"{existing_model.model_endpoint_type} -> {llm_config.model_endpoint_type}"
|
|
)
|
|
existing_model.model_endpoint_type = llm_config.model_endpoint_type
|
|
needs_update = True
|
|
|
|
if needs_update:
|
|
await existing_model.update_async(session)
|
|
else:
|
|
logger.info(f" LLM model {llm_config.handle} already exists (ID: {existing[0].id}), skipping")
|
|
|
|
# Process embedding models - add new ones
|
|
logger.info(f"Processing {len(embedding_models)} embedding models for provider {provider.name}")
|
|
for embedding_config in embedding_models:
|
|
logger.info(f" Checking embedding model: {embedding_config.handle} (name: {embedding_config.embedding_model})")
|
|
|
|
# Check if model already exists by handle (excluding soft-deleted ones)
|
|
existing = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
limit=1,
|
|
check_is_deleted=True, # Filter out soft-deleted models
|
|
**{
|
|
"handle": embedding_config.handle,
|
|
"organization_id": organization_id,
|
|
"model_type": "embedding", # Must check model_type since handle can be same for LLM and embedding
|
|
},
|
|
)
|
|
|
|
# Also check by name+provider_id (covers unique_model_per_provider_and_type constraint)
|
|
if not existing:
|
|
existing = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
limit=1,
|
|
check_is_deleted=True,
|
|
**{
|
|
"name": embedding_config.embedding_model,
|
|
"provider_id": provider.id,
|
|
"model_type": "embedding",
|
|
},
|
|
)
|
|
|
|
if not existing:
|
|
logger.info(f" Creating new embedding model {embedding_config.handle}")
|
|
# Create new model entry
|
|
pydantic_model = PydanticProviderModel(
|
|
handle=embedding_config.handle,
|
|
display_name=embedding_config.embedding_model,
|
|
name=embedding_config.embedding_model,
|
|
provider_id=provider.id,
|
|
organization_id=organization_id,
|
|
model_type="embedding",
|
|
enabled=True,
|
|
model_endpoint_type=embedding_config.embedding_endpoint_type,
|
|
embedding_dim=embedding_config.embedding_dim if hasattr(embedding_config, "embedding_dim") else None,
|
|
)
|
|
|
|
logger.info(
|
|
f" Model data: handle={pydantic_model.handle}, name={pydantic_model.name}, "
|
|
f"model_type={pydantic_model.model_type}, provider_id={pydantic_model.provider_id}, "
|
|
f"org_id={pydantic_model.organization_id}"
|
|
)
|
|
|
|
model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True))
|
|
result = await model.create_async(session, ignore_conflicts=True)
|
|
if result:
|
|
logger.info(f" ✓ Successfully created embedding model {embedding_config.handle}")
|
|
else:
|
|
logger.info(f" Embedding model {embedding_config.handle} already exists (concurrent insert), skipping")
|
|
else:
|
|
# Check if model_endpoint_type needs to be updated
|
|
existing_model = existing[0]
|
|
if existing_model.model_endpoint_type != embedding_config.embedding_endpoint_type:
|
|
logger.info(
|
|
f" Updating embedding model {embedding_config.handle} model_endpoint_type: "
|
|
f"{existing_model.model_endpoint_type} -> {embedding_config.embedding_endpoint_type}"
|
|
)
|
|
existing_model.model_endpoint_type = embedding_config.embedding_endpoint_type
|
|
await existing_model.update_async(session)
|
|
else:
|
|
logger.info(f" Embedding model {embedding_config.handle} already exists (ID: {existing[0].id}), skipping")
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_model_by_handle_async(
|
|
self,
|
|
handle: str,
|
|
actor: PydanticUser,
|
|
model_type: Optional[str] = None,
|
|
) -> Optional[PydanticProviderModel]:
|
|
"""Get a model by its handle. Handles are unique per organization."""
|
|
async with db_registry.async_session() as session:
|
|
from sqlalchemy import and_, or_, select
|
|
|
|
# Build conditions for the query
|
|
conditions = [
|
|
ProviderModelORM.handle == handle,
|
|
ProviderModelORM.is_deleted == False, # Filter out soft-deleted models
|
|
]
|
|
|
|
if model_type:
|
|
conditions.append(ProviderModelORM.model_type == model_type)
|
|
|
|
# Search for models that are either:
|
|
# 1. Organization-specific (matching actor's org)
|
|
# 2. Global (organization_id is NULL)
|
|
conditions.append(or_(ProviderModelORM.organization_id == actor.organization_id, ProviderModelORM.organization_id.is_(None)))
|
|
|
|
stmt = select(ProviderModelORM).where(and_(*conditions))
|
|
result = await session.execute(stmt)
|
|
models = list(result.scalars().all())
|
|
|
|
# Find the model the user has access to
|
|
# Prioritize org-specific models over global models
|
|
org_model = None
|
|
global_model = None
|
|
|
|
for model in models:
|
|
if model.organization_id == actor.organization_id:
|
|
org_model = model
|
|
elif model.organization_id is None:
|
|
global_model = model
|
|
|
|
# Return org-specific model if it exists, otherwise return global model
|
|
if org_model:
|
|
return org_model.to_pydantic()
|
|
elif global_model:
|
|
return global_model.to_pydantic()
|
|
|
|
return None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def list_models_async(
|
|
self,
|
|
actor: PydanticUser,
|
|
model_type: Optional[str] = None,
|
|
provider_id: Optional[str] = None,
|
|
enabled: Optional[bool] = True,
|
|
limit: Optional[int] = None,
|
|
) -> List[PydanticProviderModel]:
|
|
"""List models available to an actor (both global and org-scoped)."""
|
|
async with db_registry.async_session() as session:
|
|
# Build filters
|
|
filters = {}
|
|
if model_type:
|
|
filters["model_type"] = model_type
|
|
if provider_id:
|
|
filters["provider_id"] = provider_id
|
|
if enabled is not None:
|
|
filters["enabled"] = enabled
|
|
|
|
# Get org-scoped models (excluding soft-deleted ones)
|
|
org_filters = {**filters, "organization_id": actor.organization_id}
|
|
org_models = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
limit=limit,
|
|
check_is_deleted=True, # Filter out soft-deleted models
|
|
**org_filters,
|
|
)
|
|
|
|
# Get global models - need to handle NULL organization_id specially
|
|
from sqlalchemy import and_, select
|
|
|
|
# Build conditions for global models query
|
|
conditions = [
|
|
ProviderModelORM.organization_id.is_(None),
|
|
ProviderModelORM.is_deleted == False, # Filter out soft-deleted models
|
|
]
|
|
if model_type:
|
|
conditions.append(ProviderModelORM.model_type == model_type)
|
|
if provider_id:
|
|
conditions.append(ProviderModelORM.provider_id == provider_id)
|
|
if enabled is not None:
|
|
conditions.append(ProviderModelORM.enabled == enabled)
|
|
|
|
stmt = select(ProviderModelORM).where(and_(*conditions))
|
|
if limit:
|
|
stmt = stmt.limit(limit)
|
|
result = await session.execute(stmt)
|
|
global_models = list(result.scalars().all())
|
|
|
|
# Combine and deduplicate by handle AND model_type (org-scoped takes precedence)
|
|
# Use (handle, model_type) tuple as key since same handle can exist for LLM and embedding
|
|
all_models = {(m.handle, m.model_type): m for m in global_models}
|
|
all_models.update({(m.handle, m.model_type): m for m in org_models})
|
|
|
|
return [m.to_pydantic() for m in all_models.values()]
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_llm_config_from_handle(
|
|
self,
|
|
handle: str,
|
|
actor: PydanticUser,
|
|
) -> LLMConfig:
|
|
"""Get an LLMConfig from a model handle.
|
|
|
|
Args:
|
|
handle: The model handle to look up
|
|
actor: The user actor for permission checking
|
|
|
|
Returns:
|
|
LLMConfig constructed from the provider and model data
|
|
|
|
Raises:
|
|
NoResultFound: If the handle doesn't exist in the database or BYOK provider
|
|
"""
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
# Look up the model by handle in the database (for base providers)
|
|
model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="llm")
|
|
|
|
if not model:
|
|
# Model not in DB - check if it's from a BYOK provider
|
|
# Handle format is "provider_name/model_name"
|
|
if "/" in handle:
|
|
provider_name, model_name = handle.split("/", 1)
|
|
byok_providers = await self.list_providers_async(
|
|
actor=actor,
|
|
name=provider_name,
|
|
provider_category=[ProviderCategory.byok],
|
|
)
|
|
if byok_providers:
|
|
# Fetch models dynamically from BYOK provider
|
|
provider = byok_providers[0]
|
|
typed_provider = provider.cast_to_subtype()
|
|
try:
|
|
all_llm_configs = await typed_provider.list_llm_models_async()
|
|
# Match by handle first (original logic)
|
|
llm_configs = [config for config in all_llm_configs if config.handle == handle]
|
|
# Fallback to match by model name (original logic)
|
|
if not llm_configs:
|
|
llm_configs = [config for config in all_llm_configs if config.model == model_name]
|
|
if llm_configs:
|
|
return llm_configs[0]
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch models from BYOK provider {provider_name}: {e}")
|
|
|
|
raise NoResultFound(f"LLM model not found with handle='{handle}'")
|
|
|
|
# Get the provider for this model and cast to subtype to access provider-specific methods
|
|
provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
|
|
typed_provider = provider.cast_to_subtype()
|
|
|
|
# Get the default max_output_tokens from the provider (provider-specific logic)
|
|
max_tokens = typed_provider.get_default_max_output_tokens(model.name)
|
|
|
|
# Determine the model endpoint - use provider's OpenAI-compatible base_url if available,
|
|
# otherwise fall back to raw base_url or provider-specific defaults
|
|
|
|
if hasattr(typed_provider, "openai_compat_base_url"):
|
|
# For providers like ollama/vllm/lmstudio that need /v1 appended for OpenAI compatibility
|
|
model_endpoint = typed_provider.openai_compat_base_url
|
|
elif typed_provider.base_url:
|
|
model_endpoint = typed_provider.base_url
|
|
elif provider.provider_type == ProviderType.chatgpt_oauth:
|
|
# ChatGPT OAuth uses the ChatGPT backend API, not a generic endpoint pattern
|
|
from letta.schemas.providers.chatgpt_oauth import CHATGPT_CODEX_ENDPOINT
|
|
|
|
model_endpoint = CHATGPT_CODEX_ENDPOINT
|
|
else:
|
|
model_endpoint = f"https://api.{provider.provider_type.value}.com/v1"
|
|
|
|
# Construct the LLMConfig from the model and provider data
|
|
llm_config = LLMConfig(
|
|
model=model.name,
|
|
model_endpoint_type=model.model_endpoint_type,
|
|
model_endpoint=model_endpoint,
|
|
context_window=model.max_context_window or 16384, # Default if not set
|
|
handle=model.handle,
|
|
provider_name=provider.name,
|
|
provider_category=provider.provider_category,
|
|
max_tokens=max_tokens,
|
|
)
|
|
|
|
return llm_config
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_embedding_config_from_handle(
|
|
self,
|
|
handle: str,
|
|
actor: PydanticUser,
|
|
) -> EmbeddingConfig:
|
|
"""Get an EmbeddingConfig from a model handle.
|
|
|
|
Args:
|
|
handle: The model handle to look up
|
|
actor: The user actor for permission checking
|
|
|
|
Returns:
|
|
EmbeddingConfig constructed from the provider and model data
|
|
|
|
Raises:
|
|
NoResultFound: If the handle doesn't exist in the database or BYOK provider
|
|
"""
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
# Look up the model by handle in the database (for base providers)
|
|
model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="embedding")
|
|
|
|
if not model:
|
|
# Model not in DB - check if it's from a BYOK provider
|
|
# Handle format is "provider_name/model_name"
|
|
if "/" in handle:
|
|
provider_name, _model_name = handle.split("/", 1)
|
|
byok_providers = await self.list_providers_async(
|
|
actor=actor,
|
|
name=provider_name,
|
|
provider_category=[ProviderCategory.byok],
|
|
)
|
|
if byok_providers:
|
|
# Fetch models dynamically from BYOK provider
|
|
provider = byok_providers[0]
|
|
typed_provider = provider.cast_to_subtype()
|
|
try:
|
|
all_embedding_configs = await typed_provider.list_embedding_models_async()
|
|
# Match by handle (original logic - no model_name fallback for embeddings)
|
|
embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
|
|
if embedding_configs:
|
|
return embedding_configs[0]
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider_name}: {e}")
|
|
|
|
raise NoResultFound(f"Embedding model not found with handle='{handle}'")
|
|
|
|
# Get the provider for this model
|
|
provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
|
|
|
|
# Construct the EmbeddingConfig from the model and provider data
|
|
embedding_config = EmbeddingConfig(
|
|
embedding_model=model.name,
|
|
embedding_endpoint_type=model.model_endpoint_type,
|
|
embedding_endpoint=provider.base_url or f"https://api.{provider.provider_type.value}.com/v1",
|
|
embedding_dim=model.embedding_dim or 1536, # Use model's dimension or default
|
|
embedding_chunk_size=300, # Default chunk size
|
|
handle=model.handle,
|
|
)
|
|
|
|
return embedding_config
|