Swap the order of @trace_method and @raise_on_invalid_id decorators across all service managers so that @trace_method is always the first wrapper applied to the function (positioned directly above the method). This ensures the ID validation happens before tracing begins, which is the intended execution order. Files modified: - agent_manager.py (23 occurrences) - archive_manager.py (11 occurrences) - block_manager.py (7 occurrences) - file_manager.py (6 occurrences) - group_manager.py (9 occurrences) - identity_manager.py (10 occurrences) - job_manager.py (7 occurrences) - message_manager.py (2 occurrences) - provider_manager.py (3 occurrences) - sandbox_config_manager.py (7 occurrences) - source_manager.py (5 occurrences) - step_manager.py (13 occurrences)
909 lines
41 KiB
Python
909 lines
41 KiB
Python
from typing import List, Optional, Tuple, Union
|
|
|
|
from letta.log import get_logger
|
|
from letta.orm.provider import Provider as ProviderModel
|
|
from letta.orm.provider_model import ProviderModel as ProviderModelORM
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.provider_model import ProviderModel as PydanticProviderModel
|
|
from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate
|
|
from letta.schemas.secret import Secret
|
|
from letta.schemas.user import User as PydanticUser
|
|
from letta.server.db import db_registry
|
|
from letta.utils import enforce_types
|
|
from letta.validators import raise_on_invalid_id
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class ProviderManager:
|
|
@enforce_types
|
|
@trace_method
|
|
async def create_provider_async(self, request: ProviderCreate, actor: PydanticUser, is_byok: bool = True) -> PydanticProvider:
|
|
"""Create a new provider if it doesn't already exist.
|
|
|
|
Args:
|
|
request: ProviderCreate object with provider details
|
|
actor: User creating the provider
|
|
is_byok: If True, creates a BYOK provider (default). If False, creates a base provider.
|
|
"""
|
|
async with db_registry.async_session() as session:
|
|
from letta.schemas.enums import ProviderCategory
|
|
|
|
# Check for name conflicts
|
|
if is_byok:
|
|
# BYOK providers cannot use the same name as base providers
|
|
existing_base_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
name=request.name,
|
|
organization_id=None, # Base providers have NULL organization_id
|
|
limit=1,
|
|
)
|
|
if existing_base_providers:
|
|
raise ValueError(
|
|
f"Provider name '{request.name}' conflicts with an existing base provider. Please choose a different name."
|
|
)
|
|
else:
|
|
# Base providers must have unique names among themselves
|
|
# (the DB constraint won't catch this because NULL != NULL)
|
|
existing_base_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
name=request.name,
|
|
organization_id=None, # Base providers have NULL organization_id
|
|
limit=1,
|
|
)
|
|
if existing_base_providers:
|
|
raise ValueError(f"Base provider name '{request.name}' already exists. Please choose a different name.")
|
|
|
|
# Create provider with the appropriate category
|
|
provider_data = request.model_dump()
|
|
|
|
# Unset deprecated api_key and access_key as to not write plaintext values, api_key_enc and access_key_enc will be set below
|
|
provider_data.pop("api_key", None)
|
|
provider_data.pop("access_key", None)
|
|
|
|
provider_data["provider_category"] = ProviderCategory.byok if is_byok else ProviderCategory.base
|
|
provider = PydanticProvider(**provider_data)
|
|
|
|
# if provider.name == provider.provider_type.value:
|
|
# raise ValueError("Provider name must be unique and different from provider type")
|
|
|
|
# Only assign organization id for non-base providers
|
|
# Base providers should be globally accessible (org_id = None)
|
|
if is_byok:
|
|
provider.organization_id = actor.organization_id
|
|
|
|
# Lazily create the provider id prior to persistence
|
|
provider.resolve_identifier()
|
|
|
|
# Explicitly populate encrypted fields from plaintext
|
|
if request.api_key is not None:
|
|
provider.api_key_enc = Secret.from_plaintext(request.api_key)
|
|
if request.access_key is not None:
|
|
provider.access_key_enc = Secret.from_plaintext(request.access_key)
|
|
|
|
new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
|
|
await new_provider.create_async(session, actor=actor)
|
|
provider_pydantic = new_provider.to_pydantic()
|
|
|
|
# For BYOK providers, automatically sync available models
|
|
if is_byok:
|
|
await self._sync_default_models_for_provider(provider_pydantic, actor)
|
|
|
|
return provider_pydantic
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
|
@trace_method
|
|
async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
|
|
"""Update provider details."""
|
|
async with db_registry.async_session() as session:
|
|
# Retrieve the existing provider by ID
|
|
existing_provider = await ProviderModel.read_async(
|
|
db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
|
|
)
|
|
|
|
# Update only the fields that are provided in ProviderUpdate
|
|
update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
|
|
|
# Handle encryption for api_key if provided
|
|
# Only re-encrypt if the value has actually changed
|
|
if "api_key" in update_data and update_data["api_key"] is not None:
|
|
# Check if value changed
|
|
existing_api_key = None
|
|
if existing_provider.api_key_enc:
|
|
existing_secret = Secret.from_encrypted(existing_provider.api_key_enc)
|
|
existing_api_key = await existing_secret.get_plaintext_async()
|
|
|
|
# Only re-encrypt if different
|
|
if existing_api_key != update_data["api_key"]:
|
|
existing_provider.api_key_enc = Secret.from_plaintext(update_data["api_key"]).get_encrypted()
|
|
|
|
# Remove from update_data since we set directly on existing_provider
|
|
update_data.pop("api_key", None)
|
|
update_data.pop("api_key_enc", None)
|
|
|
|
# Handle encryption for access_key if provided
|
|
# Only re-encrypt if the value has actually changed
|
|
if "access_key" in update_data and update_data["access_key"] is not None:
|
|
# Check if value changed
|
|
existing_access_key = None
|
|
if existing_provider.access_key_enc:
|
|
existing_secret = Secret.from_encrypted(existing_provider.access_key_enc)
|
|
existing_access_key = await existing_secret.get_plaintext_async()
|
|
|
|
# Only re-encrypt if different
|
|
if existing_access_key != update_data["access_key"]:
|
|
existing_provider.access_key_enc = Secret.from_plaintext(update_data["access_key"]).get_encrypted()
|
|
|
|
# Remove from update_data since we set directly on existing_provider
|
|
update_data.pop("access_key", None)
|
|
update_data.pop("access_key_enc", None)
|
|
|
|
# Apply remaining updates
|
|
for key, value in update_data.items():
|
|
setattr(existing_provider, key, value)
|
|
|
|
# Commit the updated provider
|
|
await existing_provider.update_async(session, actor=actor)
|
|
return existing_provider.to_pydantic()
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
|
@trace_method
|
|
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
|
|
"""Delete a provider."""
|
|
async with db_registry.async_session() as session:
|
|
# Clear api key field
|
|
existing_provider = await ProviderModel.read_async(
|
|
db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
|
|
)
|
|
existing_provider.api_key_enc = None
|
|
existing_provider.access_key_enc = None
|
|
|
|
# Only accessing these deprecated fields to clear, which may trigger a warning
|
|
existing_provider.api_key = None
|
|
existing_provider.access_key = None
|
|
|
|
logger.info("Soft deleting provider with id: %s", provider_id)
|
|
|
|
await existing_provider.update_async(session, actor=actor)
|
|
|
|
# Soft delete in provider table
|
|
await existing_provider.delete_async(session, actor=actor)
|
|
|
|
await session.commit()
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def list_providers_async(
|
|
self,
|
|
actor: PydanticUser,
|
|
name: Optional[str] = None,
|
|
provider_type: Optional[ProviderType] = None,
|
|
before: Optional[str] = None,
|
|
after: Optional[str] = None,
|
|
limit: Optional[int] = 50,
|
|
ascending: bool = False,
|
|
) -> List[PydanticProvider]:
|
|
"""
|
|
List all providers with pagination support.
|
|
Returns both global providers (organization_id=NULL) and organization-specific providers.
|
|
"""
|
|
filter_kwargs = {}
|
|
if name:
|
|
filter_kwargs["name"] = name
|
|
if provider_type:
|
|
filter_kwargs["provider_type"] = provider_type
|
|
async with db_registry.async_session() as session:
|
|
# Get organization-specific providers
|
|
org_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
actor=actor,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**filter_kwargs,
|
|
)
|
|
|
|
# Get global providers (base providers with organization_id=NULL)
|
|
global_filter_kwargs = {**filter_kwargs, "organization_id": None}
|
|
global_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**global_filter_kwargs,
|
|
)
|
|
|
|
# Combine both lists
|
|
all_providers = org_providers + global_providers
|
|
|
|
# Remove deprecated api_key and access_key fields from the response
|
|
for provider in all_providers:
|
|
provider.api_key = None
|
|
provider.access_key = None
|
|
|
|
return [provider.to_pydantic() for provider in all_providers]
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def list_providers(
|
|
self,
|
|
actor: PydanticUser,
|
|
name: Optional[str] = None,
|
|
provider_type: Optional[ProviderType] = None,
|
|
before: Optional[str] = None,
|
|
after: Optional[str] = None,
|
|
limit: Optional[int] = 50,
|
|
ascending: bool = False,
|
|
) -> List[PydanticProvider]:
|
|
"""
|
|
List all providers with pagination support (synchronous version).
|
|
Returns both global providers (organization_id=NULL) and organization-specific providers.
|
|
"""
|
|
filter_kwargs = {}
|
|
if name:
|
|
filter_kwargs["name"] = name
|
|
if provider_type:
|
|
filter_kwargs["provider_type"] = provider_type
|
|
with db_registry.get_session() as session:
|
|
# Get organization-specific providers
|
|
org_providers = ProviderModel.list(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
actor=actor,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**filter_kwargs,
|
|
)
|
|
|
|
# Get global providers (base providers with organization_id=NULL)
|
|
global_filter_kwargs = {**filter_kwargs, "organization_id": None}
|
|
global_providers = ProviderModel.list(
|
|
db_session=session,
|
|
before=before,
|
|
after=after,
|
|
limit=limit,
|
|
ascending=ascending,
|
|
check_is_deleted=True,
|
|
**global_filter_kwargs,
|
|
)
|
|
|
|
# Combine both lists
|
|
all_providers = org_providers + global_providers
|
|
|
|
return [provider.to_pydantic() for provider in all_providers]
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
|
@trace_method
|
|
async def get_provider_async(self, provider_id: str, actor: PydanticUser) -> PydanticProvider:
|
|
async with db_registry.async_session() as session:
|
|
# First try to get as organization-specific provider
|
|
try:
|
|
provider_model = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
|
|
return provider_model.to_pydantic()
|
|
except:
|
|
# If not found, try to get as global provider (organization_id=NULL)
|
|
from sqlalchemy import select
|
|
|
|
stmt = select(ProviderModel).where(
|
|
ProviderModel.id == provider_id,
|
|
ProviderModel.organization_id.is_(None),
|
|
ProviderModel.is_deleted == False,
|
|
)
|
|
result = await session.execute(stmt)
|
|
provider_model = result.scalar_one_or_none()
|
|
if provider_model:
|
|
# Remove deprecated api_key and access_key fields from the response
|
|
provider_model.api_key = None
|
|
provider_model.access_key = None
|
|
return provider_model.to_pydantic()
|
|
else:
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
raise NoResultFound(f"Provider not found with id='{provider_id}'")
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
|
providers = self.list_providers(name=provider_name, actor=actor)
|
|
return providers[0].id if providers else None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
|
providers = self.list_providers(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
return api_key_secret.get_plaintext() if api_key_secret else None
|
|
return None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
return await api_key_secret.get_plaintext_async() if api_key_secret else None
|
|
return None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_bedrock_credentials_async(
|
|
self, provider_name: Union[str, None], actor: PydanticUser
|
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the credentials before returning
|
|
access_key_secret = providers[0].access_key_enc
|
|
api_key_secret = providers[0].api_key_enc
|
|
access_key = await access_key_secret.get_plaintext_async() if access_key_secret else None
|
|
secret_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
|
region = providers[0].region
|
|
return access_key, secret_key, region
|
|
return None, None, None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
def get_azure_credentials(
|
|
self, provider_name: Union[str, None], actor: PydanticUser
|
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
providers = self.list_providers(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
api_key = api_key_secret.get_plaintext() if api_key_secret else None
|
|
base_url = providers[0].base_url
|
|
api_version = providers[0].api_version
|
|
return api_key, base_url, api_version
|
|
return None, None, None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_azure_credentials_async(
|
|
self, provider_name: Union[str, None], actor: PydanticUser
|
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
|
if providers:
|
|
# Decrypt the API key before returning
|
|
api_key_secret = providers[0].api_key_enc
|
|
api_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
|
base_url = providers[0].base_url
|
|
api_version = providers[0].api_version
|
|
return api_key, base_url, api_version
|
|
return None, None, None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
|
|
provider = PydanticProvider(
|
|
name=provider_check.provider_type.value,
|
|
provider_type=provider_check.provider_type,
|
|
api_key_enc=Secret.from_plaintext(provider_check.api_key),
|
|
provider_category=ProviderCategory.byok,
|
|
access_key_enc=Secret.from_plaintext(provider_check.access_key) if provider_check.access_key else None,
|
|
region=provider_check.region,
|
|
base_url=provider_check.base_url,
|
|
api_version=provider_check.api_version,
|
|
).cast_to_subtype()
|
|
|
|
# TODO: add more string sanity checks here before we hit actual endpoints
|
|
if not provider.api_key_enc or not await provider.api_key_enc.get_plaintext_async():
|
|
raise ValueError("API key is required!")
|
|
|
|
await provider.check_api_key()
|
|
|
|
async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None:
|
|
"""Sync models for a newly created BYOK provider by querying the provider's API."""
|
|
from letta.log import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
try:
|
|
# Get the provider class and create an instance
|
|
from letta.schemas.providers.anthropic import AnthropicProvider
|
|
from letta.schemas.providers.azure import AzureProvider
|
|
from letta.schemas.providers.bedrock import BedrockProvider
|
|
from letta.schemas.providers.google_gemini import GoogleAIProvider
|
|
from letta.schemas.providers.groq import GroqProvider
|
|
from letta.schemas.providers.ollama import OllamaProvider
|
|
from letta.schemas.providers.openai import OpenAIProvider
|
|
|
|
provider_type_to_class = {
|
|
"openai": OpenAIProvider,
|
|
"anthropic": AnthropicProvider,
|
|
"groq": GroqProvider,
|
|
"google": GoogleAIProvider,
|
|
"ollama": OllamaProvider,
|
|
"bedrock": BedrockProvider,
|
|
"azure": AzureProvider,
|
|
}
|
|
|
|
provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type)
|
|
provider_class = provider_type_to_class.get(provider_type)
|
|
|
|
if not provider_class:
|
|
logger.warning(f"No provider class found for type '{provider_type}'")
|
|
return
|
|
|
|
# Create provider instance with necessary parameters
|
|
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
|
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
|
kwargs = {
|
|
"name": provider.name,
|
|
"api_key": api_key,
|
|
"provider_category": provider.provider_category,
|
|
}
|
|
if provider.base_url:
|
|
kwargs["base_url"] = provider.base_url
|
|
if access_key:
|
|
kwargs["access_key"] = access_key
|
|
if provider.region:
|
|
kwargs["region"] = provider.region
|
|
if provider.api_version:
|
|
kwargs["api_version"] = provider.api_version
|
|
|
|
provider_instance = provider_class(**kwargs)
|
|
|
|
# Query the provider's API for available models
|
|
llm_models = await provider_instance.list_llm_models_async()
|
|
embedding_models = await provider_instance.list_embedding_models_async()
|
|
|
|
# Update handles and provider_name for BYOK providers
|
|
for model in llm_models:
|
|
model.provider_name = provider.name
|
|
model.handle = f"{provider.name}/{model.model}"
|
|
model.provider_category = provider.provider_category
|
|
|
|
for model in embedding_models:
|
|
model.handle = f"{provider.name}/{model.embedding_model}"
|
|
|
|
# Use existing sync_provider_models_async to save to database
|
|
await self.sync_provider_models_async(
|
|
provider=provider, llm_models=llm_models, embedding_models=embedding_models, organization_id=actor.organization_id
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to sync models for provider '{provider.name}': {e}")
|
|
# Don't fail provider creation if model sync fails
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def sync_base_providers(self, base_providers: list[PydanticProvider], actor: PydanticUser) -> None:
|
|
"""
|
|
Sync base providers (from environment) to database (idempotent).
|
|
|
|
This method is safe to call from multiple pods simultaneously as it:
|
|
1. Checks if provider exists before creating
|
|
2. Handles race conditions with UniqueConstraintViolationError
|
|
3. Only creates providers that don't exist (no updates to avoid conflicts)
|
|
|
|
Args:
|
|
base_providers: List of base provider instances from environment variables
|
|
actor: User actor for database operations
|
|
"""
|
|
from letta.log import get_logger
|
|
from letta.orm.errors import UniqueConstraintViolationError
|
|
|
|
logger = get_logger(__name__)
|
|
logger.info(f"Syncing {len(base_providers)} base providers to database")
|
|
|
|
async with db_registry.async_session() as session:
|
|
for provider in base_providers:
|
|
try:
|
|
# Check if base provider already exists (base providers have organization_id=None)
|
|
existing_providers = await ProviderModel.list_async(
|
|
db_session=session,
|
|
name=provider.name,
|
|
organization_id=None, # Base providers are global
|
|
limit=1,
|
|
)
|
|
|
|
if existing_providers:
|
|
logger.debug(f"Base provider '{provider.name}' already exists in database, skipping")
|
|
continue
|
|
|
|
# Convert Provider to ProviderCreate
|
|
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
|
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
|
provider_create = ProviderCreate(
|
|
name=provider.name,
|
|
provider_type=provider.provider_type,
|
|
api_key=api_key or "", # ProviderCreate requires api_key, use empty string if None
|
|
access_key=access_key,
|
|
region=provider.region,
|
|
base_url=provider.base_url,
|
|
api_version=provider.api_version,
|
|
)
|
|
|
|
# Create the provider in the database as a base provider
|
|
await self.create_provider_async(request=provider_create, actor=actor, is_byok=False)
|
|
logger.info(f"Successfully initialized base provider '{provider.name}' to database")
|
|
|
|
except UniqueConstraintViolationError:
|
|
# Race condition: another pod created this provider between our check and create
|
|
# This is expected and safe - just log and continue
|
|
logger.debug(f"Provider '{provider.name}' was created by another pod, skipping")
|
|
except Exception as e:
|
|
# Log error but don't fail startup - provider initialization is not critical
|
|
logger.error(f"Failed to sync provider '{provider.name}' to database: {e}", exc_info=True)
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def sync_provider_models_async(
|
|
self,
|
|
provider: PydanticProvider,
|
|
llm_models: List[LLMConfig],
|
|
embedding_models: List[EmbeddingConfig],
|
|
organization_id: Optional[str] = None,
|
|
) -> None:
|
|
"""Sync models from a provider to the database - adds new models and removes old ones."""
|
|
from letta.log import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
logger.info(f"=== Starting sync for provider '{provider.name}' (ID: {provider.id}) ===")
|
|
logger.info(f" Organization ID: {organization_id}")
|
|
logger.info(f" LLM models to sync: {[m.handle for m in llm_models]}")
|
|
logger.info(f" Embedding models to sync: {[m.handle for m in embedding_models]}")
|
|
|
|
async with db_registry.async_session() as session:
|
|
# Get all existing models for this provider and organization
|
|
# We need to handle None organization_id specially for SQL NULL comparisons
|
|
from sqlalchemy import and_, select
|
|
|
|
# Build the query conditions
|
|
if organization_id is None:
|
|
# For global models (organization_id IS NULL), excluding soft-deleted
|
|
stmt = select(ProviderModelORM).where(
|
|
and_(
|
|
ProviderModelORM.provider_id == provider.id,
|
|
ProviderModelORM.organization_id.is_(None),
|
|
ProviderModelORM.is_deleted == False, # Filter out soft-deleted models
|
|
)
|
|
)
|
|
result = await session.execute(stmt)
|
|
existing_models = list(result.scalars().all())
|
|
else:
|
|
# For org-specific models
|
|
existing_models = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
check_is_deleted=True, # Filter out soft-deleted models
|
|
**{
|
|
"provider_id": provider.id,
|
|
"organization_id": organization_id,
|
|
},
|
|
)
|
|
|
|
# Build sets of handles for incoming models
|
|
incoming_llm_handles = {llm.handle for llm in llm_models}
|
|
incoming_embedding_handles = {emb.handle for emb in embedding_models}
|
|
all_incoming_handles = incoming_llm_handles | incoming_embedding_handles
|
|
|
|
# Determine which models to remove (existing models not in the incoming list)
|
|
models_to_remove = []
|
|
for existing_model in existing_models:
|
|
if existing_model.handle not in all_incoming_handles:
|
|
models_to_remove.append(existing_model)
|
|
|
|
# Remove models that are no longer in the sync list
|
|
for model_to_remove in models_to_remove:
|
|
await model_to_remove.delete_async(session)
|
|
logger.debug(f"Removed model {model_to_remove.handle} from provider {provider.name}")
|
|
|
|
# Commit the deletions
|
|
await session.commit()
|
|
|
|
# Process LLM models - add new ones
|
|
logger.info(f"Processing {len(llm_models)} LLM models for provider {provider.name}")
|
|
for llm_config in llm_models:
|
|
logger.info(f" Checking LLM model: {llm_config.handle} (name: {llm_config.model})")
|
|
|
|
# Check if model already exists (excluding soft-deleted ones)
|
|
existing = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
limit=1,
|
|
check_is_deleted=True, # Filter out soft-deleted models
|
|
**{
|
|
"handle": llm_config.handle,
|
|
"organization_id": organization_id,
|
|
"model_type": "llm", # Must check model_type since handle can be same for LLM and embedding
|
|
},
|
|
)
|
|
|
|
if not existing:
|
|
logger.info(f" Creating new LLM model {llm_config.handle}")
|
|
# Create new model entry
|
|
pydantic_model = PydanticProviderModel(
|
|
handle=llm_config.handle,
|
|
display_name=llm_config.model,
|
|
name=llm_config.model,
|
|
provider_id=provider.id,
|
|
organization_id=organization_id,
|
|
model_type="llm",
|
|
enabled=True,
|
|
model_endpoint_type=llm_config.model_endpoint_type,
|
|
max_context_window=llm_config.context_window,
|
|
supports_token_streaming=llm_config.model_endpoint_type in ["openai", "anthropic", "deepseek"],
|
|
supports_tool_calling=True, # Assume true for LLMs for now
|
|
)
|
|
|
|
logger.info(
|
|
f" Model data: handle={pydantic_model.handle}, name={pydantic_model.name}, "
|
|
f"model_type={pydantic_model.model_type}, provider_id={pydantic_model.provider_id}, "
|
|
f"org_id={pydantic_model.organization_id}"
|
|
)
|
|
|
|
# Convert to ORM
|
|
model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True))
|
|
try:
|
|
await model.create_async(session)
|
|
logger.info(f" ✓ Successfully created LLM model {llm_config.handle} with ID {model.id}")
|
|
except Exception as e:
|
|
logger.error(f" ✗ Failed to create LLM model {llm_config.handle}: {e}")
|
|
# Log the full error details
|
|
import traceback
|
|
|
|
logger.error(f" Full traceback: {traceback.format_exc()}")
|
|
# Roll back the session to clear the failed transaction
|
|
await session.rollback()
|
|
else:
|
|
logger.info(f" LLM model {llm_config.handle} already exists (ID: {existing[0].id}), skipping")
|
|
|
|
# Process embedding models - add new ones
|
|
logger.info(f"Processing {len(embedding_models)} embedding models for provider {provider.name}")
|
|
for embedding_config in embedding_models:
|
|
logger.info(f" Checking embedding model: {embedding_config.handle} (name: {embedding_config.embedding_model})")
|
|
|
|
# Check if model already exists (excluding soft-deleted ones)
|
|
existing = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
limit=1,
|
|
check_is_deleted=True, # Filter out soft-deleted models
|
|
**{
|
|
"handle": embedding_config.handle,
|
|
"organization_id": organization_id,
|
|
"model_type": "embedding", # Must check model_type since handle can be same for LLM and embedding
|
|
},
|
|
)
|
|
|
|
if not existing:
|
|
logger.info(f" Creating new embedding model {embedding_config.handle}")
|
|
# Create new model entry
|
|
pydantic_model = PydanticProviderModel(
|
|
handle=embedding_config.handle,
|
|
display_name=embedding_config.embedding_model,
|
|
name=embedding_config.embedding_model,
|
|
provider_id=provider.id,
|
|
organization_id=organization_id,
|
|
model_type="embedding",
|
|
enabled=True,
|
|
model_endpoint_type=embedding_config.embedding_endpoint_type,
|
|
embedding_dim=embedding_config.embedding_dim if hasattr(embedding_config, "embedding_dim") else None,
|
|
)
|
|
|
|
logger.info(
|
|
f" Model data: handle={pydantic_model.handle}, name={pydantic_model.name}, "
|
|
f"model_type={pydantic_model.model_type}, provider_id={pydantic_model.provider_id}, "
|
|
f"org_id={pydantic_model.organization_id}"
|
|
)
|
|
|
|
# Convert to ORM
|
|
model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True))
|
|
try:
|
|
await model.create_async(session)
|
|
logger.info(f" ✓ Successfully created embedding model {embedding_config.handle} with ID {model.id}")
|
|
except Exception as e:
|
|
logger.error(f" ✗ Failed to create embedding model {embedding_config.handle}: {e}")
|
|
# Log the full error details
|
|
import traceback
|
|
|
|
logger.error(f" Full traceback: {traceback.format_exc()}")
|
|
# Roll back the session to clear the failed transaction
|
|
await session.rollback()
|
|
else:
|
|
logger.info(f" Embedding model {embedding_config.handle} already exists (ID: {existing[0].id}), skipping")
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_model_by_handle_async(
|
|
self,
|
|
handle: str,
|
|
actor: PydanticUser,
|
|
model_type: Optional[str] = None,
|
|
) -> Optional[PydanticProviderModel]:
|
|
"""Get a model by its handle. Handles are unique per organization."""
|
|
async with db_registry.async_session() as session:
|
|
from sqlalchemy import and_, or_, select
|
|
|
|
# Build conditions for the query
|
|
conditions = [
|
|
ProviderModelORM.handle == handle,
|
|
ProviderModelORM.is_deleted == False, # Filter out soft-deleted models
|
|
]
|
|
|
|
if model_type:
|
|
conditions.append(ProviderModelORM.model_type == model_type)
|
|
|
|
# Search for models that are either:
|
|
# 1. Organization-specific (matching actor's org)
|
|
# 2. Global (organization_id is NULL)
|
|
conditions.append(or_(ProviderModelORM.organization_id == actor.organization_id, ProviderModelORM.organization_id.is_(None)))
|
|
|
|
stmt = select(ProviderModelORM).where(and_(*conditions))
|
|
result = await session.execute(stmt)
|
|
models = list(result.scalars().all())
|
|
|
|
# Find the model the user has access to
|
|
# Prioritize org-specific models over global models
|
|
org_model = None
|
|
global_model = None
|
|
|
|
for model in models:
|
|
if model.organization_id == actor.organization_id:
|
|
org_model = model
|
|
elif model.organization_id is None:
|
|
global_model = model
|
|
|
|
# Return org-specific model if it exists, otherwise return global model
|
|
if org_model:
|
|
return org_model.to_pydantic()
|
|
elif global_model:
|
|
return global_model.to_pydantic()
|
|
|
|
return None
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def list_models_async(
|
|
self,
|
|
actor: PydanticUser,
|
|
model_type: Optional[str] = None,
|
|
provider_id: Optional[str] = None,
|
|
enabled: Optional[bool] = True,
|
|
limit: Optional[int] = None,
|
|
) -> List[PydanticProviderModel]:
|
|
"""List models available to an actor (both global and org-scoped)."""
|
|
async with db_registry.async_session() as session:
|
|
# Build filters
|
|
filters = {}
|
|
if model_type:
|
|
filters["model_type"] = model_type
|
|
if provider_id:
|
|
filters["provider_id"] = provider_id
|
|
if enabled is not None:
|
|
filters["enabled"] = enabled
|
|
|
|
# Get org-scoped models (excluding soft-deleted ones)
|
|
org_filters = {**filters, "organization_id": actor.organization_id}
|
|
org_models = await ProviderModelORM.list_async(
|
|
db_session=session,
|
|
limit=limit,
|
|
check_is_deleted=True, # Filter out soft-deleted models
|
|
**org_filters,
|
|
)
|
|
|
|
# Get global models - need to handle NULL organization_id specially
|
|
from sqlalchemy import and_, select
|
|
|
|
# Build conditions for global models query
|
|
conditions = [
|
|
ProviderModelORM.organization_id.is_(None),
|
|
ProviderModelORM.is_deleted == False, # Filter out soft-deleted models
|
|
]
|
|
if model_type:
|
|
conditions.append(ProviderModelORM.model_type == model_type)
|
|
if provider_id:
|
|
conditions.append(ProviderModelORM.provider_id == provider_id)
|
|
if enabled is not None:
|
|
conditions.append(ProviderModelORM.enabled == enabled)
|
|
|
|
stmt = select(ProviderModelORM).where(and_(*conditions))
|
|
if limit:
|
|
stmt = stmt.limit(limit)
|
|
result = await session.execute(stmt)
|
|
global_models = list(result.scalars().all())
|
|
|
|
# Combine and deduplicate by handle AND model_type (org-scoped takes precedence)
|
|
# Use (handle, model_type) tuple as key since same handle can exist for LLM and embedding
|
|
all_models = {(m.handle, m.model_type): m for m in global_models}
|
|
all_models.update({(m.handle, m.model_type): m for m in org_models})
|
|
|
|
return [m.to_pydantic() for m in all_models.values()]
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_llm_config_from_handle(
|
|
self,
|
|
handle: str,
|
|
actor: PydanticUser,
|
|
) -> LLMConfig:
|
|
"""Get an LLMConfig from a model handle.
|
|
|
|
Args:
|
|
handle: The model handle to look up
|
|
actor: The user actor for permission checking
|
|
|
|
Returns:
|
|
LLMConfig constructed from the provider and model data
|
|
|
|
Raises:
|
|
NoResultFound: If the handle doesn't exist in the database
|
|
"""
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
# Look up the model by handle
|
|
model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="llm")
|
|
|
|
if not model:
|
|
raise NoResultFound(f"LLM model not found with handle='{handle}'")
|
|
|
|
# Get the provider for this model
|
|
provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
|
|
|
|
# Construct the LLMConfig from the model and provider data
|
|
llm_config = LLMConfig(
|
|
model=model.name,
|
|
model_endpoint_type=model.model_endpoint_type,
|
|
model_endpoint=provider.base_url or f"https://api.{provider.provider_type.value}.com/v1",
|
|
context_window=model.max_context_window or 16384, # Default if not set
|
|
handle=model.handle,
|
|
provider_name=provider.name,
|
|
provider_category=provider.provider_category,
|
|
)
|
|
|
|
return llm_config
|
|
|
|
@enforce_types
|
|
@trace_method
|
|
async def get_embedding_config_from_handle(
|
|
self,
|
|
handle: str,
|
|
actor: PydanticUser,
|
|
) -> EmbeddingConfig:
|
|
"""Get an EmbeddingConfig from a model handle.
|
|
|
|
Args:
|
|
handle: The model handle to look up
|
|
actor: The user actor for permission checking
|
|
|
|
Returns:
|
|
EmbeddingConfig constructed from the provider and model data
|
|
|
|
Raises:
|
|
NoResultFound: If the handle doesn't exist in the database
|
|
"""
|
|
from letta.orm.errors import NoResultFound
|
|
|
|
# Look up the model by handle
|
|
model = await self.get_model_by_handle_async(handle=handle, actor=actor, model_type="embedding")
|
|
|
|
if not model:
|
|
raise NoResultFound(f"Embedding model not found with handle='{handle}'")
|
|
|
|
# Get the provider for this model
|
|
provider = await self.get_provider_async(provider_id=model.provider_id, actor=actor)
|
|
|
|
# Construct the EmbeddingConfig from the model and provider data
|
|
embedding_config = EmbeddingConfig(
|
|
embedding_model=model.name,
|
|
embedding_endpoint_type=model.model_endpoint_type,
|
|
embedding_endpoint=provider.base_url or f"https://api.{provider.provider_type.value}.com/v1",
|
|
embedding_dim=model.embedding_dim or 1536, # Use model's dimension or default
|
|
embedding_chunk_size=300, # Default chunk size
|
|
handle=model.handle,
|
|
)
|
|
|
|
return embedding_config
|