import warnings from datetime import datetime from typing import List, Literal, Optional import aiohttp import requests from pydantic import BaseModel, Field, model_validator from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.letta_base import LettaBase from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES from letta.settings import model_settings class ProviderBase(LettaBase): __id_prefix__ = "provider" class Provider(ProviderBase): id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.") name: str = Field(..., description="The name of the provider") provider_type: ProviderType = Field(..., description="The type of the provider") provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)") api_key: Optional[str] = Field(None, description="API key or secret key used for requests to the provider.") base_url: Optional[str] = Field(None, description="Base URL for the provider.") access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") region: Optional[str] = Field(None, description="Region used for requests to the provider.") organization_id: Optional[str] = Field(None, description="The organization id of the user") updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.") @model_validator(mode="after") def default_base_url(self): if self.provider_type == ProviderType.openai and self.base_url is None: self.base_url = model_settings.openai_api_base return self def resolve_identifier(self): if not self.id: self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) def check_api_key(self): """Check if the API key is valid for the provider""" raise NotImplementedError def list_llm_models(self) -> List[LLMConfig]: return [] async def list_llm_models_async(self) -> List[LLMConfig]: return [] def list_embedding_models(self) -> List[EmbeddingConfig]: return [] async def list_embedding_models_async(self) -> List[EmbeddingConfig]: return self.list_embedding_models() def get_model_context_window(self, model_name: str) -> Optional[int]: raise NotImplementedError async def get_model_context_window_async(self, model_name: str) -> Optional[int]: raise NotImplementedError def provider_tag(self) -> str: """String representation of the provider for display purposes""" raise NotImplementedError def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str: """ Get the handle for a model, with support for custom overrides. Args: model_name (str): The name of the model. is_embedding (bool, optional): Whether the handle is for an embedding model. Defaults to False. Returns: str: The handle for the model. 
""" base_name = base_name if base_name else self.name overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES if base_name in overrides and model_name in overrides[base_name]: model_name = overrides[base_name][model_name] return f"{base_name}/{model_name}" def cast_to_subtype(self): match self.provider_type: case ProviderType.letta: return LettaProvider(**self.model_dump(exclude_none=True)) case ProviderType.openai: return OpenAIProvider(**self.model_dump(exclude_none=True)) case ProviderType.anthropic: return AnthropicProvider(**self.model_dump(exclude_none=True)) case ProviderType.bedrock: return BedrockProvider(**self.model_dump(exclude_none=True)) case ProviderType.ollama: return OllamaProvider(**self.model_dump(exclude_none=True)) case ProviderType.google_ai: return GoogleAIProvider(**self.model_dump(exclude_none=True)) case ProviderType.google_vertex: return GoogleVertexProvider(**self.model_dump(exclude_none=True)) case ProviderType.azure: return AzureProvider(**self.model_dump(exclude_none=True)) case ProviderType.groq: return GroqProvider(**self.model_dump(exclude_none=True)) case ProviderType.together: return TogetherProvider(**self.model_dump(exclude_none=True)) case ProviderType.vllm_chat_completions: return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True)) case ProviderType.vllm_completions: return VLLMCompletionsProvider(**self.model_dump(exclude_none=True)) case ProviderType.xai: return XAIProvider(**self.model_dump(exclude_none=True)) case _: raise ValueError(f"Unknown provider type: {self.provider_type}") class ProviderCreate(ProviderBase): name: str = Field(..., description="The name of the provider.") provider_type: ProviderType = Field(..., description="The type of the provider.") api_key: str = Field(..., description="API key or secret key used for requests to the provider.") access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") region: Optional[str] = Field(None, description="Region used for requests to the provider.") class ProviderUpdate(ProviderBase): api_key: str = Field(..., description="API key or secret key used for requests to the provider.") access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") region: Optional[str] = Field(None, description="Region used for requests to the provider.") class ProviderCheck(BaseModel): provider_type: ProviderType = Field(..., description="The type of the provider.") api_key: str = Field(..., description="API key or secret key used for requests to the provider.") access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") region: Optional[str] = Field(None, description="Region used for requests to the provider.") class LettaProvider(Provider): provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") def list_llm_models(self) -> List[LLMConfig]: return [ LLMConfig( model="letta-free", # NOTE: renamed model_endpoint_type="openai", model_endpoint=LETTA_MODEL_ENDPOINT, context_window=30000, handle=self.get_handle("letta-free"), provider_name=self.name, provider_category=self.provider_category, ) ] async def list_llm_models_async(self) -> List[LLMConfig]: return [ LLMConfig( model="letta-free", # NOTE: renamed model_endpoint_type="openai", model_endpoint=LETTA_MODEL_ENDPOINT, 
                context_window=30000,
                handle=self.get_handle("letta-free"),
                provider_name=self.name,
                provider_category=self.provider_category,
            )
        ]

    def list_embedding_models(self):
        return [
            EmbeddingConfig(
                embedding_model="letta-free",  # NOTE: renamed
                embedding_endpoint_type="hugging-face",
                embedding_endpoint="https://embeddings.memgpt.ai",
                embedding_dim=1024,
                embedding_chunk_size=300,
                handle=self.get_handle("letta-free", is_embedding=True),
                batch_size=32,
            )
        ]


class OpenAIProvider(Provider):
    provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str = Field(..., description="API key for the OpenAI API.")
    base_url: str = Field(..., description="Base URL for the OpenAI API.")

    def check_api_key(self):
        from letta.llm_api.openai import openai_check_valid_api_key

        openai_check_valid_api_key(self.base_url, self.api_key)

    def _get_models(self) -> List[dict]:
        from letta.llm_api.openai import openai_get_model_list

        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
        # See: https://openrouter.ai/docs/requests
        extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
        # Similar to Nebius; use a separate branch so this doesn't overwrite the OpenRouter params above
        if "nebius.com" in self.base_url:
            extra_params = {"verbose": True}

        response = openai_get_model_list(
            self.base_url,
            api_key=self.api_key,
            extra_params=extra_params,
            # fix_url=True,  # NOTE: make sure together ends with /v1
        )

        if "data" in response:
            data = response["data"]
        else:
            # TogetherAI's response is missing the 'data' field
            data = response

        return data

    async def _get_models_async(self) -> List[dict]:
        from letta.llm_api.openai import openai_get_model_list_async

        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
        # See: https://openrouter.ai/docs/requests
        extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
        # Similar to Nebius; use a separate branch so this doesn't overwrite the OpenRouter params above
        if "nebius.com" in self.base_url:
            extra_params = {"verbose": True}

        response = await openai_get_model_list_async(
            self.base_url,
            api_key=self.api_key,
            extra_params=extra_params,
            # fix_url=True,  # NOTE: make sure together ends with /v1
        )

        if "data" in response:
            data = response["data"]
        else:
            # TogetherAI's response is missing the 'data' field
            data = response

        return data

    def list_llm_models(self) -> List[LLMConfig]:
        data = self._get_models()
        return self._list_llm_models(data)

    async def list_llm_models_async(self) -> List[LLMConfig]:
        data = await self._get_models_async()
        return self._list_llm_models(data)

    def _list_llm_models(self, data) -> List[LLMConfig]:
        configs = []
        for model in data:
            assert "id" in model, f"OpenAI model missing 'id' field: {model}"
            model_name = model["id"]

            if "context_length" in model:
                # Context length is returned in OpenRouter as "context_length"
                context_window_size = model["context_length"]
            else:
                context_window_size = self.get_model_context_window_size(model_name)

            if not context_window_size:
                continue

            # TogetherAI includes the type, which we can use to filter out embedding models
            if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url:
                if "type" in model and model["type"] not in ["chat", "language"]:
                    continue

                # for TogetherAI, we need to skip the models that don't support JSON mode / function calling
                # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: {
                #   "error": {
                #     "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling",
                #     "type": "invalid_request_error",
                #     "param": null,
                #     "code": "constraints_model"
                #   }
                # }
                if "config" not in model:
                    continue

            if "nebius.com" in self.base_url:
                # Nebius includes the type, which we can use to filter for text models
                try:
                    model_type = model["architecture"]["modality"]
                    if model_type not in ["text->text", "text+image->text"]:
                        # print(f"Skipping model w/ modality {model_type}:\n{model}")
                        continue
                except KeyError:
                    print(f"Couldn't access architecture type field, skipping model:\n{model}")
                    continue

            # for openai, filter models
            if self.base_url == "https://api.openai.com/v1":
                allowed_types = ["gpt-4", "o1", "o3", "o4"]
                # NOTE: o1-mini and o1-preview do not support tool calling
                # NOTE: o1-mini does not support system messages
                # NOTE: o1-pro is only available in Responses API
                disallowed_types = ["transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"]
                skip = True
                for model_type in allowed_types:
                    if model_name.startswith(model_type):
                        skip = False
                        break
                for keyword in disallowed_types:
                    if keyword in model_name:
                        skip = True
                        break  # ignore this model
                if skip:
                    continue

            # set the handle to openai-proxy if the base URL isn't OpenAI
            if self.base_url != "https://api.openai.com/v1":
                handle = self.get_handle(model_name, base_name="openai-proxy")
            else:
                handle = self.get_handle(model_name)

            llm_config = LLMConfig(
                model=model_name,
                model_endpoint_type="openai",
                model_endpoint=self.base_url,
                context_window=context_window_size,
                handle=handle,
                provider_name=self.name,
                provider_category=self.provider_category,
            )

            # gpt-4o-mini has started to regress with pretty bad emoji spam loops
            # this is to counteract that
            if "gpt-4o-mini" in model_name:
                llm_config.frequency_penalty = 1.0
            if "gpt-4.1-mini" in
model_name: llm_config.frequency_penalty = 1.0 configs.append(llm_config) # for OpenAI, sort in reverse order if self.base_url == "https://api.openai.com/v1": # alphnumeric sort configs.sort(key=lambda x: x.model, reverse=True) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: if self.base_url == "https://api.openai.com/v1": # TODO: actually automatically list models for OpenAI return [ EmbeddingConfig( embedding_model="text-embedding-ada-002", embedding_endpoint_type="openai", embedding_endpoint=self.base_url, embedding_dim=1536, embedding_chunk_size=300, handle=self.get_handle("text-embedding-ada-002", is_embedding=True), batch_size=1024, ), EmbeddingConfig( embedding_model="text-embedding-3-small", embedding_endpoint_type="openai", embedding_endpoint=self.base_url, embedding_dim=2000, embedding_chunk_size=300, handle=self.get_handle("text-embedding-3-small", is_embedding=True), batch_size=1024, ), EmbeddingConfig( embedding_model="text-embedding-3-large", embedding_endpoint_type="openai", embedding_endpoint=self.base_url, embedding_dim=2000, embedding_chunk_size=300, handle=self.get_handle("text-embedding-3-large", is_embedding=True), batch_size=1024, ), ] else: # Actually attempt to list data = self._get_models() return self._list_embedding_models(data) async def list_embedding_models_async(self) -> List[EmbeddingConfig]: if self.base_url == "https://api.openai.com/v1": # TODO: actually automatically list models for OpenAI return [ EmbeddingConfig( embedding_model="text-embedding-ada-002", embedding_endpoint_type="openai", embedding_endpoint=self.base_url, embedding_dim=1536, embedding_chunk_size=300, handle=self.get_handle("text-embedding-ada-002", is_embedding=True), batch_size=1024, ), EmbeddingConfig( embedding_model="text-embedding-3-small", embedding_endpoint_type="openai", embedding_endpoint=self.base_url, embedding_dim=2000, embedding_chunk_size=300, handle=self.get_handle("text-embedding-3-small", is_embedding=True), batch_size=1024, ), EmbeddingConfig( embedding_model="text-embedding-3-large", embedding_endpoint_type="openai", embedding_endpoint=self.base_url, embedding_dim=2000, embedding_chunk_size=300, handle=self.get_handle("text-embedding-3-large", is_embedding=True), batch_size=1024, ), ] else: # Actually attempt to list data = await self._get_models_async() return self._list_embedding_models(data) def _list_embedding_models(self, data) -> List[EmbeddingConfig]: configs = [] for model in data: assert "id" in model, f"Model missing 'id' field: {model}" model_name = model["id"] if "context_length" in model: # Context length is returned in Nebius as "context_length" context_window_size = model["context_length"] else: context_window_size = self.get_model_context_window_size(model_name) # We need the context length for embeddings too if not context_window_size: continue if "nebius.com" in self.base_url: # Nebius includes the type, which we can use to filter for embedidng models try: model_type = model["architecture"]["modality"] if model_type not in ["text->embedding"]: # print(f"Skipping model w/ modality {model_type}:\n{model}") continue except KeyError: print(f"Couldn't access architecture type field, skipping model:\n{model}") continue elif "together.ai" in self.base_url or "together.xyz" in self.base_url: # TogetherAI includes the type, which we can use to filter for embedding models if "type" in model and model["type"] not in ["embedding"]: # print(f"Skipping model w/ modality {model_type}:\n{model}") continue else: # For other providers we 
should skip by default, since we don't want to assume embeddings are supported
                continue

            configs.append(
                EmbeddingConfig(
                    embedding_model=model_name,
                    embedding_endpoint_type=self.provider_type,
                    embedding_endpoint=self.base_url,
                    embedding_dim=context_window_size,
                    embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
                    handle=self.get_handle(model_name, is_embedding=True),
                )
            )
        return configs

    def get_model_context_window_size(self, model_name: str):
        if model_name in LLM_MAX_TOKENS:
            return LLM_MAX_TOKENS[model_name]
        else:
            return LLM_MAX_TOKENS["DEFAULT"]


class DeepSeekProvider(OpenAIProvider):
    """
    DeepSeek ChatCompletions API is similar to OpenAI's reasoning API,
    but with slight differences:
    * For example, DeepSeek's API requires perfect interleaving of user/assistant
    * It also does not support native function calling
    """

    provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
    api_key: str = Field(..., description="API key for the DeepSeek API.")

    def get_model_context_window_size(self, model_name: str) -> Optional[int]:
        # DeepSeek doesn't return context window in the model listing,
        # so these are hardcoded from their website
        if model_name == "deepseek-reasoner":
            return 64000
        elif model_name == "deepseek-chat":
            return 64000
        else:
            return None

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.openai import openai_get_model_list

        response = openai_get_model_list(self.base_url, api_key=self.api_key)

        if "data" in response:
            data = response["data"]
        else:
            data = response

        configs = []
        for model in data:
            assert "id" in model, f"DeepSeek model missing 'id' field: {model}"
            model_name = model["id"]

            # In case DeepSeek starts supporting it in the future:
            if "context_length" in model:
                # Context length is returned in OpenRouter as "context_length"
                context_window_size = model["context_length"]
            else:
                context_window_size = self.get_model_context_window_size(model_name)

            if not context_window_size:
                warnings.warn(f"Couldn't find context window size for model {model_name}")
                continue

            # Not used for deepseek-reasoner, but otherwise is true
            put_inner_thoughts_in_kwargs = False if model_name == "deepseek-reasoner" else True

            configs.append(
                LLMConfig(
                    model=model_name,
                    model_endpoint_type="deepseek",
                    model_endpoint=self.base_url,
                    context_window=context_window_size,
                    handle=self.get_handle(model_name),
                    put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )

        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        # No embeddings supported
        return []


class LMStudioOpenAIProvider(OpenAIProvider):
    provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
    api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.openai import openai_get_model_list

        # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
        # NOTE: removesuffix, not strip — str.strip("/v1") strips the characters '/', 'v', '1', not the suffix
        MODEL_ENDPOINT_URL = f"{self.base_url.removesuffix('/v1')}/api/v0"
        response =
openai_get_model_list(MODEL_ENDPOINT_URL) """ Example response: { "object": "list", "data": [ { "id": "qwen2-vl-7b-instruct", "object": "model", "type": "vlm", "publisher": "mlx-community", "arch": "qwen2_vl", "compatibility_type": "mlx", "quantization": "4bit", "state": "not-loaded", "max_context_length": 32768 }, ... """ if "data" not in response: warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}") return [] configs = [] for model in response["data"]: assert "id" in model, f"Model missing 'id' field: {model}" model_name = model["id"] if "type" not in model: warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}") continue elif model["type"] not in ["vlm", "llm"]: continue if "max_context_length" in model: context_window_size = model["max_context_length"] else: warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}") continue configs.append( LLMConfig( model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name), ) ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: from letta.llm_api.openai import openai_get_model_list # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models' MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0" response = openai_get_model_list(MODEL_ENDPOINT_URL) """ Example response: { "object": "list", "data": [ { "id": "text-embedding-nomic-embed-text-v1.5", "object": "model", "type": "embeddings", "publisher": "nomic-ai", "arch": "nomic-bert", "compatibility_type": "gguf", "quantization": "Q4_0", "state": "not-loaded", "max_context_length": 2048 } ... """ if "data" not in response: warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}") return [] configs = [] for model in response["data"]: assert "id" in model, f"Model missing 'id' field: {model}" model_name = model["id"] if "type" not in model: warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}") continue elif model["type"] not in ["embeddings"]: continue if "max_context_length" in model: context_window_size = model["max_context_length"] else: warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}") continue configs.append( EmbeddingConfig( embedding_model=model_name, embedding_endpoint_type="openai", embedding_endpoint=self.base_url, embedding_dim=context_window_size, embedding_chunk_size=300, # NOTE: max is 2048 handle=self.get_handle(model_name), ), ) return configs class XAIProvider(OpenAIProvider): """https://docs.x.ai/docs/api-reference""" provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") api_key: str = Field(..., description="API key for the xAI/Grok API.") base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") def get_model_context_window_size(self, model_name: str) -> Optional[int]: # xAI doesn't return context window in the model listing, # so these are hardcoded from their website if model_name == "grok-2-1212": return 131072 # NOTE: disabling the minis for now since they return weird MM parts # elif model_name == "grok-3-mini-fast-beta": # return 131072 # elif model_name == "grok-3-mini-beta": # return 131072 elif model_name == "grok-3-fast-beta": return 131072 elif model_name == "grok-3-beta": return 
131072 else: return None def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.openai import openai_get_model_list response = openai_get_model_list(self.base_url, api_key=self.api_key) if "data" in response: data = response["data"] else: data = response configs = [] for model in data: assert "id" in model, f"xAI/Grok model missing 'id' field: {model}" model_name = model["id"] # In case xAI starts supporting it in the future: if "context_length" in model: context_window_size = model["context_length"] else: context_window_size = self.get_model_context_window_size(model_name) if not context_window_size: warnings.warn(f"Couldn't find context window size for model {model_name}") continue configs.append( LLMConfig( model=model_name, model_endpoint_type="xai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name), provider_name=self.name, provider_category=self.provider_category, ) ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: # No embeddings supported return [] class AnthropicProvider(Provider): provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") api_key: str = Field(..., description="API key for the Anthropic API.") base_url: str = "https://api.anthropic.com/v1" def check_api_key(self): from letta.llm_api.anthropic import anthropic_check_valid_api_key anthropic_check_valid_api_key(self.api_key) def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.anthropic import anthropic_get_model_list models = anthropic_get_model_list(api_key=self.api_key) return self._list_llm_models(models) async def list_llm_models_async(self) -> List[LLMConfig]: from letta.llm_api.anthropic import anthropic_get_model_list_async models = await anthropic_get_model_list_async(api_key=self.api_key) return self._list_llm_models(models) def _list_llm_models(self, models) -> List[LLMConfig]: from letta.llm_api.anthropic import MODEL_LIST configs = [] for model in models: if model["type"] != "model": continue if "id" not in model: continue # Don't support 2.0 and 2.1 if model["id"].startswith("claude-2"): continue # Anthropic doesn't return the context window in their API if "context_window" not in model: # Remap list to name: context_window model_library = {m["name"]: m["context_window"] for m in MODEL_LIST} # Attempt to look it up in a hardcoded list if model["id"] in model_library: model["context_window"] = model_library[model["id"]] else: # On fallback, we can set 200k (generally safe), but we should warn the user warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000") model["context_window"] = 200000 max_tokens = 8192 if "claude-3-opus" in model["id"]: max_tokens = 4096 if "claude-3-haiku" in model["id"]: max_tokens = 4096 # TODO: set for 3-7 extended thinking mode # We set this to false by default, because Anthropic can # natively support tags inside of content fields # However, putting COT inside of tool calls can make it more # reliable for tool calling (no chance of a non-tool call step) # Since tool_choice_type 'any' doesn't work with in-content COT # NOTE For Haiku, it can be flaky if we don't enable this by default # inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False inner_thoughts_in_kwargs = True # we no longer support thinking tags configs.append( LLMConfig( 
model=model["id"], model_endpoint_type="anthropic", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"]), put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, max_tokens=max_tokens, provider_name=self.name, provider_category=self.provider_category, ) ) return configs class MistralProvider(Provider): provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") api_key: str = Field(..., description="API key for the Mistral API.") base_url: str = "https://api.mistral.ai/v1" def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.mistral import mistral_get_model_list # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... # See: https://openrouter.ai/docs/requests response = mistral_get_model_list(self.base_url, api_key=self.api_key) assert "data" in response, f"Mistral model query response missing 'data' field: {response}" configs = [] for model in response["data"]: # If model has chat completions and function calling enabled if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]: configs.append( LLMConfig( model=model["id"], model_endpoint_type="openai", model_endpoint=self.base_url, context_window=model["max_context_length"], handle=self.get_handle(model["id"]), provider_name=self.name, provider_category=self.provider_category, ) ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: # Not supported for mistral return [] def get_model_context_window(self, model_name: str) -> Optional[int]: # Redoing this is fine because it's a pretty lightweight call models = self.list_llm_models() for m in models: if model_name in m["id"]: return int(m["max_context_length"]) return None class OllamaProvider(OpenAIProvider): """Ollama provider that uses the native /api/generate endpoint See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion """ provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field(..., description="Base URL for the Ollama API.") api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).") default_prompt_formatter: str = Field( ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API." 
) async def list_llm_models_async(self) -> List[LLMConfig]: """Async version of list_llm_models below""" endpoint = f"{self.base_url}/api/tags" async with aiohttp.ClientSession() as session: async with session.get(endpoint) as response: if response.status != 200: raise Exception(f"Failed to list Ollama models: {response.text}") response_json = await response.json() configs = [] for model in response_json["models"]: context_window = self.get_model_context_window(model["name"]) if context_window is None: print(f"Ollama model {model['name']} has no context window") continue configs.append( LLMConfig( model=model["name"], model_endpoint_type="ollama", model_endpoint=self.base_url, model_wrapper=self.default_prompt_formatter, context_window=context_window, handle=self.get_handle(model["name"]), provider_name=self.name, provider_category=self.provider_category, ) ) return configs def list_llm_models(self) -> List[LLMConfig]: # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models response = requests.get(f"{self.base_url}/api/tags") if response.status_code != 200: raise Exception(f"Failed to list Ollama models: {response.text}") response_json = response.json() configs = [] for model in response_json["models"]: context_window = self.get_model_context_window(model["name"]) if context_window is None: print(f"Ollama model {model['name']} has no context window") continue configs.append( LLMConfig( model=model["name"], model_endpoint_type="ollama", model_endpoint=self.base_url, model_wrapper=self.default_prompt_formatter, context_window=context_window, handle=self.get_handle(model["name"]), provider_name=self.name, provider_category=self.provider_category, ) ) return configs def get_model_context_window(self, model_name: str) -> Optional[int]: response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) response_json = response.json() ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675 # possible_keys = [ # # OPT # "max_position_embeddings", # # GPT-2 # "n_positions", # # MPT # "max_seq_len", # # ChatGLM2 # "seq_length", # # Command-R # "model_max_length", # # Others # "max_sequence_length", # "max_seq_length", # "seq_len", # ] # max_position_embeddings # parse model cards: nous, dolphon, llama if "model_info" not in response_json: if "error" in response_json: print(f"Ollama fetch model info error for {model_name}: {response_json['error']}") return None for key, value in response_json["model_info"].items(): if "context_length" in key: return value return None def _get_model_embedding_dim(self, model_name: str): response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) response_json = response.json() return self._get_model_embedding_dim_impl(response_json, model_name) async def _get_model_embedding_dim_async(self, model_name: str): async with aiohttp.ClientSession() as session: async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response: response_json = await response.json() return self._get_model_embedding_dim_impl(response_json, model_name) @staticmethod def _get_model_embedding_dim_impl(response_json: dict, model_name: str): if "model_info" not in response_json: if "error" in response_json: print(f"Ollama fetch model info error for {model_name}: {response_json['error']}") return None for key, value in response_json["model_info"].items(): if "embedding_length" in key: return value return None async def 
list_embedding_models_async(self) -> List[EmbeddingConfig]: """Async version of list_embedding_models below""" endpoint = f"{self.base_url}/api/tags" async with aiohttp.ClientSession() as session: async with session.get(endpoint) as response: if response.status != 200: raise Exception(f"Failed to list Ollama models: {response.text}") response_json = await response.json() configs = [] for model in response_json["models"]: embedding_dim = await self._get_model_embedding_dim_async(model["name"]) if not embedding_dim: print(f"Ollama model {model['name']} has no embedding dimension") continue configs.append( EmbeddingConfig( embedding_model=model["name"], embedding_endpoint_type="ollama", embedding_endpoint=self.base_url, embedding_dim=embedding_dim, embedding_chunk_size=300, handle=self.get_handle(model["name"], is_embedding=True), ) ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models response = requests.get(f"{self.base_url}/api/tags") if response.status_code != 200: raise Exception(f"Failed to list Ollama models: {response.text}") response_json = response.json() configs = [] for model in response_json["models"]: embedding_dim = self._get_model_embedding_dim(model["name"]) if not embedding_dim: print(f"Ollama model {model['name']} has no embedding dimension") continue configs.append( EmbeddingConfig( embedding_model=model["name"], embedding_endpoint_type="ollama", embedding_endpoint=self.base_url, embedding_dim=embedding_dim, embedding_chunk_size=300, handle=self.get_handle(model["name"], is_embedding=True), ) ) return configs class GroqProvider(OpenAIProvider): provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = "https://api.groq.com/openai/v1" api_key: str = Field(..., description="API key for the Groq API.") def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.openai import openai_get_model_list response = openai_get_model_list(self.base_url, api_key=self.api_key) configs = [] for model in response["data"]: if "context_window" not in model: continue configs.append( LLMConfig( model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"]), provider_name=self.name, provider_category=self.provider_category, ) ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: return [] class TogetherProvider(OpenAIProvider): """TogetherAI provider that uses the /completions API TogetherAI can also be used via the /chat/completions API by settings OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key and API URL, however /completions is preferred because their /chat/completions function calling support is limited. 
""" provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = "https://api.together.ai/v1" api_key: str = Field(..., description="API key for the TogetherAI API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.openai import openai_get_model_list models = openai_get_model_list(self.base_url, api_key=self.api_key) return self._list_llm_models(models) async def list_llm_models_async(self) -> List[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async models = await openai_get_model_list_async(self.base_url, api_key=self.api_key) return self._list_llm_models(models) def _list_llm_models(self, models) -> List[LLMConfig]: pass # TogetherAI's response is missing the 'data' field # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}" if "data" in models: data = models["data"] else: data = models configs = [] for model in data: assert "id" in model, f"TogetherAI model missing 'id' field: {model}" model_name = model["id"] if "context_length" in model: # Context length is returned in OpenRouter as "context_length" context_window_size = model["context_length"] else: context_window_size = self.get_model_context_window_size(model_name) # We need the context length for embeddings too if not context_window_size: continue # Skip models that are too small for Letta if context_window_size <= MIN_CONTEXT_WINDOW: continue # TogetherAI includes the type, which we can use to filter for embedding models if "type" in model and model["type"] not in ["chat", "language"]: continue configs.append( LLMConfig( model=model_name, model_endpoint_type="together", model_endpoint=self.base_url, model_wrapper=self.default_prompt_formatter, context_window=context_window_size, handle=self.get_handle(model_name), provider_name=self.name, provider_category=self.provider_category, ) ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: # TODO renable once we figure out how to pass API keys through properly return [] # from letta.llm_api.openai import openai_get_model_list # response = openai_get_model_list(self.base_url, api_key=self.api_key) # # TogetherAI's response is missing the 'data' field # # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}" # if "data" in response: # data = response["data"] # else: # data = response # configs = [] # for model in data: # assert "id" in model, f"TogetherAI model missing 'id' field: {model}" # model_name = model["id"] # if "context_length" in model: # # Context length is returned in OpenRouter as "context_length" # context_window_size = model["context_length"] # else: # context_window_size = self.get_model_context_window_size(model_name) # if not context_window_size: # continue # # TogetherAI includes the type, which we can use to filter out embedding models # if "type" in model and model["type"] not in ["embedding"]: # continue # configs.append( # EmbeddingConfig( # embedding_model=model_name, # embedding_endpoint_type="openai", # embedding_endpoint=self.base_url, # embedding_dim=context_window_size, # embedding_chunk_size=300, # TODO: change? 
# ) # ) # return configs class GoogleAIProvider(Provider): # gemini provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") api_key: str = Field(..., description="API key for the Google AI API.") base_url: str = "https://generativelanguage.googleapis.com" def check_api_key(self): from letta.llm_api.google_ai_client import google_ai_check_valid_api_key google_ai_check_valid_api_key(self.api_key) def list_llm_models(self): from letta.llm_api.google_ai_client import google_ai_get_model_list model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options] # filter by model names model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] # Add support for all gemini models model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] configs = [] for model in model_options: configs.append( LLMConfig( model=model, model_endpoint_type="google_ai", model_endpoint=self.base_url, context_window=self.get_model_context_window(model), handle=self.get_handle(model), max_tokens=8192, provider_name=self.name, provider_category=self.provider_category, ) ) return configs async def list_llm_models_async(self): import asyncio from letta.llm_api.google_ai_client import google_ai_get_model_list_async # Get and filter the model list model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options] # filter by model names model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] # Add support for all gemini models model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] # Prepare tasks for context window lookups in parallel async def create_config(model): context_window = await self.get_model_context_window_async(model) return LLMConfig( model=model, model_endpoint_type="google_ai", model_endpoint=self.base_url, context_window=context_window, handle=self.get_handle(model), max_tokens=8192, provider_name=self.name, provider_category=self.provider_category, ) # Execute all config creation tasks concurrently configs = await asyncio.gather(*[create_config(model) for model in model_options]) return configs def list_embedding_models(self): from letta.llm_api.google_ai_client import google_ai_get_model_list # TODO: use base_url instead model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) return self._list_embedding_models(model_options) async def list_embedding_models_async(self): from letta.llm_api.google_ai_client import google_ai_get_model_list_async # TODO: use base_url instead model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) return self._list_embedding_models(model_options) def _list_embedding_models(self, model_options): # filter by 'generateContent' models model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options] model_options = [mo[len("models/") :] if 
mo.startswith("models/") else mo for mo in model_options] configs = [] for model in model_options: configs.append( EmbeddingConfig( embedding_model=model, embedding_endpoint_type="google_ai", embedding_endpoint=self.base_url, embedding_dim=768, embedding_chunk_size=300, # NOTE: max is 2048 handle=self.get_handle(model, is_embedding=True), batch_size=1024, ) ) return configs def get_model_context_window(self, model_name: str) -> Optional[int]: from letta.llm_api.google_ai_client import google_ai_get_model_context_window if model_name in LLM_MAX_TOKENS: return LLM_MAX_TOKENS[model_name] else: return google_ai_get_model_context_window(self.base_url, self.api_key, model_name) async def get_model_context_window_async(self, model_name: str) -> Optional[int]: from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async if model_name in LLM_MAX_TOKENS: return LLM_MAX_TOKENS[model_name] else: return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name) class GoogleVertexProvider(Provider): provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.") google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.") def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.google_constants import GOOGLE_MODEL_TO_CONTEXT_LENGTH configs = [] for model, context_length in GOOGLE_MODEL_TO_CONTEXT_LENGTH.items(): configs.append( LLMConfig( model=model, model_endpoint_type="google_vertex", model_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}", context_window=context_length, handle=self.get_handle(model), max_tokens=8192, provider_name=self.name, provider_category=self.provider_category, ) ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: from letta.llm_api.google_constants import GOOGLE_EMBEDING_MODEL_TO_DIM configs = [] for model, dim in GOOGLE_EMBEDING_MODEL_TO_DIM.items(): configs.append( EmbeddingConfig( embedding_model=model, embedding_endpoint_type="google_vertex", embedding_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}", embedding_dim=dim, embedding_chunk_size=300, # NOTE: max is 2048 handle=self.get_handle(model, is_embedding=True), batch_size=1024, ) ) return configs class AzureProvider(Provider): provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation base_url: str = Field( ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`." 
) api_key: str = Field(..., description="API key for the Azure API.") api_version: str = Field(latest_api_version, description="API version for the Azure API") @model_validator(mode="before") def set_default_api_version(cls, values): """ This ensures that api_version is always set to the default if None is passed in. """ if values.get("api_version") is None: values["api_version"] = cls.model_fields["latest_api_version"].default return values def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.azure_openai import azure_openai_get_chat_completion_model_list model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version) configs = [] for model_option in model_options: model_name = model_option["id"] context_window_size = self.get_model_context_window(model_name) model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version) configs.append( LLMConfig( model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size, handle=self.get_handle(model_name), provider_name=self.name, provider_category=self.provider_category, ), ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list model_options = azure_openai_get_embeddings_model_list( self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True ) configs = [] for model_option in model_options: model_name = model_option["id"] model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version) configs.append( EmbeddingConfig( embedding_model=model_name, embedding_endpoint_type="azure", embedding_endpoint=model_endpoint, embedding_dim=768, embedding_chunk_size=300, # NOTE: max is 2048 handle=self.get_handle(model_name), batch_size=1024, ), ) return configs def get_model_context_window(self, model_name: str) -> Optional[int]: """ This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model. 
""" context_window = AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, None) if context_window is None: context_window = LLM_MAX_TOKENS.get(model_name, 4096) return context_window class VLLMChatCompletionsProvider(Provider): """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy""" # NOTE: vLLM only serves one model at a time (so could configure that through env variables) provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field(..., description="Base URL for the vLLM API.") def list_llm_models(self) -> List[LLMConfig]: # not supported with vLLM from letta.llm_api.openai import openai_get_model_list assert self.base_url, "base_url is required for vLLM provider" response = openai_get_model_list(self.base_url, api_key=None) configs = [] for model in response["data"]: configs.append( LLMConfig( model=model["id"], model_endpoint_type="openai", model_endpoint=self.base_url, context_window=model["max_model_len"], handle=self.get_handle(model["id"]), provider_name=self.name, provider_category=self.provider_category, ) ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: # not supported with vLLM return [] class VLLMCompletionsProvider(Provider): """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper""" # NOTE: vLLM only serves one model at a time (so could configure that through env variables) provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field(..., description="Base URL for the vLLM API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") def list_llm_models(self) -> List[LLMConfig]: # not supported with vLLM from letta.llm_api.openai import openai_get_model_list response = openai_get_model_list(self.base_url, api_key=None) configs = [] for model in response["data"]: configs.append( LLMConfig( model=model["id"], model_endpoint_type="vllm", model_endpoint=self.base_url, model_wrapper=self.default_prompt_formatter, context_window=model["max_model_len"], handle=self.get_handle(model["id"]), provider_name=self.name, provider_category=self.provider_category, ) ) return configs def list_embedding_models(self) -> List[EmbeddingConfig]: # not supported with vLLM return [] class CohereProvider(OpenAIProvider): pass class BedrockProvider(Provider): provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") region: str = Field(..., description="AWS region for Bedrock") def check_api_key(self): """Check if the Bedrock credentials are valid""" from letta.errors import LLMAuthenticationError from letta.llm_api.aws_bedrock import bedrock_get_model_list try: # For BYOK providers, use the custom credentials if self.provider_category == ProviderCategory.byok: # If we can list models, the credentials are valid bedrock_get_model_list( region_name=self.region, access_key_id=self.access_key, secret_access_key=self.api_key, # api_key stores the secret access key ) else: 
                # For base providers, use default credentials
                bedrock_get_model_list(region_name=self.region)
        except Exception as e:
            raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}")

    def list_llm_models(self):
        from letta.llm_api.aws_bedrock import bedrock_get_model_list

        models = bedrock_get_model_list(self.region)

        configs = []
        for model_summary in models:
            model_arn = model_summary["inferenceProfileArn"]
            configs.append(
                LLMConfig(
                    model=model_arn,
                    model_endpoint_type=self.provider_type.value,
                    model_endpoint=None,
                    context_window=self.get_model_context_window(model_arn),
                    handle=self.get_handle(model_arn),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return configs

    async def list_llm_models_async(self) -> List[LLMConfig]:
        from letta.llm_api.aws_bedrock import bedrock_get_model_list_async

        models = await bedrock_get_model_list_async(
            self.access_key,
            self.api_key,
            self.region,
        )

        configs = []
        for model_summary in models:
            model_arn = model_summary["inferenceProfileArn"]
            configs.append(
                LLMConfig(
                    model=model_arn,
                    model_endpoint_type=self.provider_type.value,
                    model_endpoint=None,
                    context_window=self.get_model_context_window(model_arn),
                    handle=self.get_handle(model_arn),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return configs

    def list_embedding_models(self):
        return []

    def get_model_context_window(self, model_name: str) -> Optional[int]:
        # Context windows for Claude models
        from letta.llm_api.aws_bedrock import bedrock_get_model_context_window

        return bedrock_get_model_context_window(model_name)

    def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str:
        # Bedrock model identifiers are inference profile ARNs; keep only the trailing
        # dot-separated segment so the handle stays readable (e.g. "<name>/claude-...")
        model = model_name.split(".")[-1]
        return f"{self.name}/{model}"
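

# Minimal usage sketch (not part of the library API): it assumes a local Ollama server at
# http://localhost:11434 and uses "chatml" as the prompt formatter — both are example values,
# not required defaults. Only code defined in this module is exercised.
if __name__ == "__main__":
    provider = OllamaProvider(
        name="ollama",
        base_url="http://localhost:11434",  # assumed local Ollama endpoint
        default_prompt_formatter="chatml",  # assumed wrapper name; use whichever formatter you prefer
    )

    # Handles are "<provider name>/<model name>", with any LLM_HANDLE_OVERRIDES applied in get_handle()
    print(provider.get_handle("llama3"))  # e.g. "ollama/llama3"

    # Listing models requires the Ollama server to be reachable, so guard the call
    try:
        for config in provider.list_llm_models():
            print(config.handle, config.context_window)
    except Exception as e:
        print(f"Could not reach Ollama: {e}")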