209 lines
9.8 KiB
Python
209 lines
9.8 KiB
Python
from datetime import datetime
|
|
|
|
from pydantic import BaseModel, Field, model_validator
|
|
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
|
|
from letta.schemas.enums import ProviderCategory, ProviderType
|
|
from letta.schemas.letta_base import LettaBase
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
|
|
from letta.settings import model_settings
|
|
|
|
|
|
class ProviderBase(LettaBase):
|
|
__id_prefix__ = "provider"
|
|
|
|
|
|
class Provider(ProviderBase):
|
|
id: str | None = Field(None, description="The id of the provider, lazily created by the database manager.")
|
|
name: str = Field(..., description="The name of the provider")
|
|
provider_type: ProviderType = Field(..., description="The type of the provider")
|
|
provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
|
|
api_key: str | None = Field(None, description="API key or secret key used for requests to the provider.")
|
|
base_url: str | None = Field(None, description="Base URL for the provider.")
|
|
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
|
|
region: str | None = Field(None, description="Region used for requests to the provider.")
|
|
api_version: str | None = Field(None, description="API version used for requests to the provider.")
|
|
organization_id: str | None = Field(None, description="The organization id of the user")
|
|
updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
|
|
|
|
@model_validator(mode="after")
|
|
def default_base_url(self):
|
|
if self.provider_type == ProviderType.openai and self.base_url is None:
|
|
self.base_url = model_settings.openai_api_base
|
|
return self
|
|
|
|
def resolve_identifier(self):
|
|
if not self.id:
|
|
self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
|
|
|
|
async def check_api_key(self):
|
|
"""Check if the API key is valid for the provider"""
|
|
raise NotImplementedError
|
|
|
|
def list_llm_models(self) -> list[LLMConfig]:
|
|
"""List available LLM models (deprecated: use list_llm_models_async)"""
|
|
import asyncio
|
|
import warnings
|
|
|
|
warnings.warn("list_llm_models is deprecated, use list_llm_models_async instead", DeprecationWarning, stacklevel=2)
|
|
|
|
# Simplified asyncio handling - just use asyncio.run()
|
|
# This works in most contexts and avoids complex event loop detection
|
|
try:
|
|
return asyncio.run(self.list_llm_models_async())
|
|
except RuntimeError as e:
|
|
# If we're in an active event loop context, use a thread pool
|
|
if "cannot be called from a running event loop" in str(e):
|
|
import concurrent.futures
|
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
future = executor.submit(asyncio.run, self.list_llm_models_async())
|
|
return future.result()
|
|
else:
|
|
raise
|
|
|
|
async def list_llm_models_async(self) -> list[LLMConfig]:
|
|
return []
|
|
|
|
def list_embedding_models(self) -> list[EmbeddingConfig]:
|
|
"""List available embedding models (deprecated: use list_embedding_models_async)"""
|
|
import asyncio
|
|
import warnings
|
|
|
|
warnings.warn("list_embedding_models is deprecated, use list_embedding_models_async instead", DeprecationWarning, stacklevel=2)
|
|
|
|
# Simplified asyncio handling - just use asyncio.run()
|
|
# This works in most contexts and avoids complex event loop detection
|
|
try:
|
|
return asyncio.run(self.list_embedding_models_async())
|
|
except RuntimeError as e:
|
|
# If we're in an active event loop context, use a thread pool
|
|
if "cannot be called from a running event loop" in str(e):
|
|
import concurrent.futures
|
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
future = executor.submit(asyncio.run, self.list_embedding_models_async())
|
|
return future.result()
|
|
else:
|
|
raise
|
|
|
|
async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
|
|
"""List available embedding models. The following do not have support for embedding models:
|
|
Anthropic, Bedrock, Cerebras, Deepseek, Groq, Mistral, xAI
|
|
"""
|
|
return []
|
|
|
|
def get_model_context_window(self, model_name: str) -> int | None:
|
|
raise NotImplementedError
|
|
|
|
async def get_model_context_window_async(self, model_name: str) -> int | None:
|
|
raise NotImplementedError
|
|
|
|
def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str:
|
|
"""
|
|
Get the handle for a model, with support for custom overrides.
|
|
|
|
Args:
|
|
model_name (str): The name of the model.
|
|
is_embedding (bool, optional): Whether the handle is for an embedding model. Defaults to False.
|
|
|
|
Returns:
|
|
str: The handle for the model.
|
|
"""
|
|
base_name = base_name if base_name else self.name
|
|
|
|
overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES
|
|
if base_name in overrides and model_name in overrides[base_name]:
|
|
model_name = overrides[base_name][model_name]
|
|
|
|
return f"{base_name}/{model_name}"
|
|
|
|
def cast_to_subtype(self):
|
|
# Import here to avoid circular imports
|
|
from letta.schemas.providers import (
|
|
AnthropicProvider,
|
|
AzureProvider,
|
|
BedrockProvider,
|
|
CerebrasProvider,
|
|
DeepSeekProvider,
|
|
GoogleAIProvider,
|
|
GoogleVertexProvider,
|
|
GroqProvider,
|
|
LettaProvider,
|
|
LMStudioOpenAIProvider,
|
|
MistralProvider,
|
|
OllamaProvider,
|
|
OpenAIProvider,
|
|
TogetherProvider,
|
|
VLLMProvider,
|
|
XAIProvider,
|
|
)
|
|
|
|
if self.base_url == "":
|
|
self.base_url = None
|
|
|
|
match self.provider_type:
|
|
case ProviderType.letta:
|
|
return LettaProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.openai:
|
|
return OpenAIProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.anthropic:
|
|
return AnthropicProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.google_ai:
|
|
return GoogleAIProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.google_vertex:
|
|
return GoogleVertexProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.azure:
|
|
return AzureProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.groq:
|
|
return GroqProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.together:
|
|
return TogetherProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.ollama:
|
|
return OllamaProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.vllm:
|
|
return VLLMProvider(**self.model_dump(exclude_none=True)) # Removed support for CompletionsProvider
|
|
case ProviderType.mistral:
|
|
return MistralProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.deepseek:
|
|
return DeepSeekProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.cerebras:
|
|
return CerebrasProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.xai:
|
|
return XAIProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.lmstudio_openai:
|
|
return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True))
|
|
case ProviderType.bedrock:
|
|
return BedrockProvider(**self.model_dump(exclude_none=True))
|
|
case _:
|
|
raise ValueError(f"Unknown provider type: {self.provider_type}")
|
|
|
|
|
|
class ProviderCreate(ProviderBase):
|
|
name: str = Field(..., description="The name of the provider.")
|
|
provider_type: ProviderType = Field(..., description="The type of the provider.")
|
|
api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
|
|
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
|
|
region: str | None = Field(None, description="Region used for requests to the provider.")
|
|
base_url: str | None = Field(None, description="Base URL used for requests to the provider.")
|
|
api_version: str | None = Field(None, description="API version used for requests to the provider.")
|
|
|
|
|
|
class ProviderUpdate(ProviderBase):
|
|
api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
|
|
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
|
|
region: str | None = Field(None, description="Region used for requests to the provider.")
|
|
base_url: str | None = Field(None, description="Base URL used for requests to the provider.")
|
|
api_version: str | None = Field(None, description="API version used for requests to the provider.")
|
|
|
|
|
|
class ProviderCheck(BaseModel):
|
|
provider_type: ProviderType = Field(..., description="The type of the provider.")
|
|
api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
|
|
access_key: str | None = Field(None, description="Access key used for requests to the provider.")
|
|
region: str | None = Field(None, description="Region used for requests to the provider.")
|
|
base_url: str | None = Field(None, description="Base URL used for requests to the provider.")
|
|
api_version: str | None = Field(None, description="API version used for requests to the provider.")
|