Files
letta-server/letta/schemas/providers/base.py
2025-08-18 14:22:36 -07:00

209 lines
9.8 KiB
Python

from datetime import datetime
from pydantic import BaseModel, Field, model_validator
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.letta_base import LettaBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
from letta.settings import model_settings
class ProviderBase(LettaBase):
    # Common base for the provider schemas in this module. It only sets the
    # prefix passed to generate_id() when a provider id is created.
    __id_prefix__ = "provider"
class Provider(ProviderBase):
id: str | None = Field(None, description="The id of the provider, lazily created by the database manager.")
name: str = Field(..., description="The name of the provider")
provider_type: ProviderType = Field(..., description="The type of the provider")
provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
api_key: str | None = Field(None, description="API key or secret key used for requests to the provider.")
base_url: str | None = Field(None, description="Base URL for the provider.")
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
region: str | None = Field(None, description="Region used for requests to the provider.")
api_version: str | None = Field(None, description="API version used for requests to the provider.")
organization_id: str | None = Field(None, description="The organization id of the user")
updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
@model_validator(mode="after")
def default_base_url(self):
if self.provider_type == ProviderType.openai and self.base_url is None:
self.base_url = model_settings.openai_api_base
return self
def resolve_identifier(self):
if not self.id:
self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
async def check_api_key(self):
"""Check if the API key is valid for the provider"""
raise NotImplementedError
def list_llm_models(self) -> list[LLMConfig]:
"""List available LLM models (deprecated: use list_llm_models_async)"""
import asyncio
import warnings
warnings.warn("list_llm_models is deprecated, use list_llm_models_async instead", DeprecationWarning, stacklevel=2)
# Simplified asyncio handling - just use asyncio.run()
# This works in most contexts and avoids complex event loop detection
try:
return asyncio.run(self.list_llm_models_async())
except RuntimeError as e:
# If we're in an active event loop context, use a thread pool
if "cannot be called from a running event loop" in str(e):
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, self.list_llm_models_async())
return future.result()
else:
raise
async def list_llm_models_async(self) -> list[LLMConfig]:
return []
def list_embedding_models(self) -> list[EmbeddingConfig]:
"""List available embedding models (deprecated: use list_embedding_models_async)"""
import asyncio
import warnings
warnings.warn("list_embedding_models is deprecated, use list_embedding_models_async instead", DeprecationWarning, stacklevel=2)
# Simplified asyncio handling - just use asyncio.run()
# This works in most contexts and avoids complex event loop detection
try:
return asyncio.run(self.list_embedding_models_async())
except RuntimeError as e:
# If we're in an active event loop context, use a thread pool
if "cannot be called from a running event loop" in str(e):
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, self.list_embedding_models_async())
return future.result()
else:
raise
async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
"""List available embedding models. The following do not have support for embedding models:
Anthropic, Bedrock, Cerebras, Deepseek, Groq, Mistral, xAI
"""
return []
def get_model_context_window(self, model_name: str) -> int | None:
raise NotImplementedError
async def get_model_context_window_async(self, model_name: str) -> int | None:
raise NotImplementedError
def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str:
"""
Get the handle for a model, with support for custom overrides.
Args:
model_name (str): The name of the model.
is_embedding (bool, optional): Whether the handle is for an embedding model. Defaults to False.
Returns:
str: The handle for the model.
"""
base_name = base_name if base_name else self.name
overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES
if base_name in overrides and model_name in overrides[base_name]:
model_name = overrides[base_name][model_name]
return f"{base_name}/{model_name}"
def cast_to_subtype(self):
# Import here to avoid circular imports
from letta.schemas.providers import (
AnthropicProvider,
AzureProvider,
BedrockProvider,
CerebrasProvider,
DeepSeekProvider,
GoogleAIProvider,
GoogleVertexProvider,
GroqProvider,
LettaProvider,
LMStudioOpenAIProvider,
MistralProvider,
OllamaProvider,
OpenAIProvider,
TogetherProvider,
VLLMProvider,
XAIProvider,
)
if self.base_url == "":
self.base_url = None
match self.provider_type:
case ProviderType.letta:
return LettaProvider(**self.model_dump(exclude_none=True))
case ProviderType.openai:
return OpenAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.anthropic:
return AnthropicProvider(**self.model_dump(exclude_none=True))
case ProviderType.google_ai:
return GoogleAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.google_vertex:
return GoogleVertexProvider(**self.model_dump(exclude_none=True))
case ProviderType.azure:
return AzureProvider(**self.model_dump(exclude_none=True))
case ProviderType.groq:
return GroqProvider(**self.model_dump(exclude_none=True))
case ProviderType.together:
return TogetherProvider(**self.model_dump(exclude_none=True))
case ProviderType.ollama:
return OllamaProvider(**self.model_dump(exclude_none=True))
case ProviderType.vllm:
return VLLMProvider(**self.model_dump(exclude_none=True)) # Removed support for CompletionsProvider
case ProviderType.mistral:
return MistralProvider(**self.model_dump(exclude_none=True))
case ProviderType.deepseek:
return DeepSeekProvider(**self.model_dump(exclude_none=True))
case ProviderType.cerebras:
return CerebrasProvider(**self.model_dump(exclude_none=True))
case ProviderType.xai:
return XAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.lmstudio_openai:
return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.bedrock:
return BedrockProvider(**self.model_dump(exclude_none=True))
case _:
raise ValueError(f"Unknown provider type: {self.provider_type}")
class ProviderCreate(ProviderBase):
    # Fields accepted when creating a provider: name, provider_type and
    # api_key are required; the remaining connection settings default to None.
    name: str = Field(
        default=...,
        description="The name of the provider.",
    )
    provider_type: ProviderType = Field(
        default=...,
        description="The type of the provider.",
    )
    api_key: str = Field(
        default=...,
        description="API key or secret key used for requests to the provider.",
    )
    access_key: str | None = Field(
        default=None,
        description="Access key used for requests to the provider.",
    )
    region: str | None = Field(
        default=None,
        description="Region used for requests to the provider.",
    )
    base_url: str | None = Field(
        default=None,
        description="Base URL used for requests to the provider.",
    )
    api_version: str | None = Field(
        default=None,
        description="API version used for requests to the provider.",
    )
class ProviderUpdate(ProviderBase):
    # Fields accepted when updating a provider: api_key is required; the
    # remaining connection settings default to None.
    api_key: str = Field(
        default=...,
        description="API key or secret key used for requests to the provider.",
    )
    access_key: str | None = Field(
        default=None,
        description="Access key used for requests to the provider.",
    )
    region: str | None = Field(
        default=None,
        description="Region used for requests to the provider.",
    )
    base_url: str | None = Field(
        default=None,
        description="Base URL used for requests to the provider.",
    )
    api_version: str | None = Field(
        default=None,
        description="API version used for requests to the provider.",
    )
class ProviderCheck(BaseModel):
    # Plain pydantic model carrying a provider type and credentials:
    # provider_type and api_key are required; the remaining connection
    # settings default to None.
    provider_type: ProviderType = Field(
        default=...,
        description="The type of the provider.",
    )
    api_key: str = Field(
        default=...,
        description="API key or secret key used for requests to the provider.",
    )
    access_key: str | None = Field(
        default=None,
        description="Access key used for requests to the provider.",
    )
    region: str | None = Field(
        default=None,
        description="Region used for requests to the provider.",
    )
    base_url: str | None = Field(
        default=None,
        description="Base URL used for requests to the provider.",
    )
    api_version: str | None = Field(
        default=None,
        description="API version used for requests to the provider.",
    )