diff --git a/.github/workflows/notify-letta-cloud.yml b/.github/workflows/notify-letta-cloud.yml deleted file mode 100644 index 0874be59..00000000 --- a/.github/workflows/notify-letta-cloud.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Notify Letta Cloud - -on: - push: - branches: - - main - -jobs: - notify: - runs-on: ubuntu-latest - if: ${{ !contains(github.event.head_commit.message, '[sync-skip]') }} - steps: - - name: Trigger repository_dispatch - run: | - curl -X POST \ - -H "Authorization: token ${{ secrets.SYNC_PAT }}" \ - -H "Accept: application/vnd.github.v3+json" \ - https://api.github.com/repos/letta-ai/letta-cloud/dispatches \ - -d '{"event_type":"oss-update"}' diff --git a/examples/files/README.md b/examples/files/README.md index e6b4a421..736c3a8a 100644 --- a/examples/files/README.md +++ b/examples/files/README.md @@ -31,4 +31,4 @@ The demo will: 3. Create an agent named "Clippy" 4. Start an interactive chat session -Type 'quit' or 'exit' to end the conversation. +Type 'quit' or 'exit' to end the conversation. \ No newline at end of file diff --git a/examples/files/main.py b/examples/files/main.py index 6ef04978..98c29c55 100644 --- a/examples/files/main.py +++ b/examples/files/main.py @@ -63,7 +63,6 @@ except Exception as e: # 1. From an existing file # 2. From a string by encoding it into a base64 string # -# # 1. From an existing file # "rb" means "read binary" diff --git a/letta/__init__.py b/letta/__init__.py index 39426a1b..468c69e9 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -5,7 +5,7 @@ try: __version__ = version("letta") except PackageNotFoundError: # Fallback for development installations - __version__ = "0.10.0" + __version__ = "0.11.4" if os.environ.get("LETTA_VERSION"): __version__ = os.environ["LETTA_VERSION"] diff --git a/letta/constants.py b/letta/constants.py index ff77ff69..5513fab8 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -10,6 +10,7 @@ DEFAULT_TIMEZONE = "UTC" ADMIN_PREFIX = "/v1/admin" API_PREFIX = "/v1" +OLLAMA_API_PREFIX = "/v1" OPENAI_API_PREFIX = "/openai" COMPOSIO_ENTITY_ENV_VAR_KEY = "COMPOSIO_ENTITY" @@ -51,8 +52,9 @@ TOOL_CALL_ID_MAX_LEN = 29 # Max steps for agent loop DEFAULT_MAX_STEPS = 50 -# minimum context window size +# context window size MIN_CONTEXT_WINDOW = 4096 +DEFAULT_CONTEXT_WINDOW = 32000 # number of concurrent embedding requests to sent EMBEDDING_BATCH_SIZE = 200 @@ -64,6 +66,7 @@ DEFAULT_MIN_MESSAGE_BUFFER_LENGTH = 15 # embeddings MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset DEFAULT_EMBEDDING_CHUNK_SIZE = 300 +DEFAULT_EMBEDDING_DIM = 1024 # tokenizers EMBEDDING_TO_TOKENIZER_MAP = { diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py new file mode 100644 index 00000000..97d68281 --- /dev/null +++ b/letta/schemas/providers.py @@ -0,0 +1,1618 @@ +import warnings +from datetime import datetime +from typing import List, Literal, Optional + +import aiohttp +import requests +from pydantic import BaseModel, Field, model_validator + +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW +from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint +from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.letta_base import LettaBase +from letta.schemas.llm_config import LLMConfig +from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES +from letta.settings import model_settings + + +class ProviderBase(LettaBase): + __id_prefix__ = "provider" + + +class Provider(ProviderBase): + id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.") + name: str = Field(..., description="The name of the provider") + provider_type: ProviderType = Field(..., description="The type of the provider") + provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)") + api_key: Optional[str] = Field(None, description="API key or secret key used for requests to the provider.") + base_url: Optional[str] = Field(None, description="Base URL for the provider.") + access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") + region: Optional[str] = Field(None, description="Region used for requests to the provider.") + organization_id: Optional[str] = Field(None, description="The organization id of the user") + updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.") + + @model_validator(mode="after") + def default_base_url(self): + if self.provider_type == ProviderType.openai and self.base_url is None: + self.base_url = model_settings.openai_api_base + return self + + def resolve_identifier(self): + if not self.id: + self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) + + def check_api_key(self): + """Check if the API key is valid for the provider""" + raise NotImplementedError + + def list_llm_models(self) -> List[LLMConfig]: + return [] + + async def list_llm_models_async(self) -> List[LLMConfig]: + return [] + + def list_embedding_models(self) -> List[EmbeddingConfig]: + return [] + + async def list_embedding_models_async(self) -> List[EmbeddingConfig]: + return self.list_embedding_models() + + def get_model_context_window(self, model_name: str) -> Optional[int]: + raise NotImplementedError + + async def get_model_context_window_async(self, model_name: str) -> Optional[int]: + raise NotImplementedError + + def provider_tag(self) -> str: + """String representation of the provider for display purposes""" + raise NotImplementedError + + def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str: + """ + Get the handle for a model, with support for custom overrides. + + Args: + model_name (str): The name of the model. + is_embedding (bool, optional): Whether the handle is for an embedding model. Defaults to False. + + Returns: + str: The handle for the model. + """ + base_name = base_name if base_name else self.name + + overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES + if base_name in overrides and model_name in overrides[base_name]: + model_name = overrides[base_name][model_name] + + return f"{base_name}/{model_name}" + + def cast_to_subtype(self): + match self.provider_type: + case ProviderType.letta: + return LettaProvider(**self.model_dump(exclude_none=True)) + case ProviderType.openai: + return OpenAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.anthropic: + return AnthropicProvider(**self.model_dump(exclude_none=True)) + case ProviderType.bedrock: + return BedrockProvider(**self.model_dump(exclude_none=True)) + case ProviderType.ollama: + return OllamaProvider(**self.model_dump(exclude_none=True)) + case ProviderType.google_ai: + return GoogleAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.google_vertex: + return GoogleVertexProvider(**self.model_dump(exclude_none=True)) + case ProviderType.azure: + return AzureProvider(**self.model_dump(exclude_none=True)) + case ProviderType.groq: + return GroqProvider(**self.model_dump(exclude_none=True)) + case ProviderType.together: + return TogetherProvider(**self.model_dump(exclude_none=True)) + case ProviderType.vllm_chat_completions: + return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True)) + case ProviderType.vllm_completions: + return VLLMCompletionsProvider(**self.model_dump(exclude_none=True)) + case ProviderType.xai: + return XAIProvider(**self.model_dump(exclude_none=True)) + case _: + raise ValueError(f"Unknown provider type: {self.provider_type}") + + +class ProviderCreate(ProviderBase): + name: str = Field(..., description="The name of the provider.") + provider_type: ProviderType = Field(..., description="The type of the provider.") + api_key: str = Field(..., description="API key or secret key used for requests to the provider.") + access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") + region: Optional[str] = Field(None, description="Region used for requests to the provider.") + + +class ProviderUpdate(ProviderBase): + api_key: str = Field(..., description="API key or secret key used for requests to the provider.") + access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") + region: Optional[str] = Field(None, description="Region used for requests to the provider.") + + +class ProviderCheck(BaseModel): + provider_type: ProviderType = Field(..., description="The type of the provider.") + api_key: str = Field(..., description="API key or secret key used for requests to the provider.") + access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.") + region: Optional[str] = Field(None, description="Region used for requests to the provider.") + + +class LettaProvider(Provider): + provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + + def list_llm_models(self) -> List[LLMConfig]: + return [ + LLMConfig( + model="letta-free", # NOTE: renamed + model_endpoint_type="openai", + model_endpoint=LETTA_MODEL_ENDPOINT, + context_window=30000, + handle=self.get_handle("letta-free"), + provider_name=self.name, + provider_category=self.provider_category, + ) + ] + + async def list_llm_models_async(self) -> List[LLMConfig]: + return [ + LLMConfig( + model="letta-free", # NOTE: renamed + model_endpoint_type="openai", + model_endpoint=LETTA_MODEL_ENDPOINT, + context_window=30000, + handle=self.get_handle("letta-free"), + provider_name=self.name, + provider_category=self.provider_category, + ) + ] + + def list_embedding_models(self): + return [ + EmbeddingConfig( + embedding_model="letta-free", # NOTE: renamed + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_dim=1024, + embedding_chunk_size=300, + handle=self.get_handle("letta-free", is_embedding=True), + batch_size=32, + ) + ] + + +class OpenAIProvider(Provider): + provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = Field(..., description="API key for the OpenAI API.") + base_url: str = Field(..., description="Base URL for the OpenAI API.") + + def check_api_key(self): + from letta.llm_api.openai import openai_check_valid_api_key + + openai_check_valid_api_key(self.base_url, self.api_key) + + def _get_models(self) -> List[dict]: + from letta.llm_api.openai import openai_get_model_list + + # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... + # See: https://openrouter.ai/docs/requests + extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None + + # Similar to Nebius + extra_params = {"verbose": True} if "nebius.com" in self.base_url else None + + response = openai_get_model_list( + self.base_url, + api_key=self.api_key, + extra_params=extra_params, + # fix_url=True, # NOTE: make sure together ends with /v1 + ) + + if "data" in response: + data = response["data"] + else: + # TogetherAI's response is missing the 'data' field + data = response + + return data + + async def _get_models_async(self) -> List[dict]: + from letta.llm_api.openai import openai_get_model_list_async + + # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... + # See: https://openrouter.ai/docs/requests + extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None + + # Similar to Nebius + extra_params = {"verbose": True} if "nebius.com" in self.base_url else None + + response = await openai_get_model_list_async( + self.base_url, + api_key=self.api_key, + extra_params=extra_params, + # fix_url=True, # NOTE: make sure together ends with /v1 + ) + + if "data" in response: + data = response["data"] + else: + # TogetherAI's response is missing the 'data' field + data = response + + return data + + def list_llm_models(self) -> List[LLMConfig]: + data = self._get_models() + return self._list_llm_models(data) + + async def list_llm_models_async(self) -> List[LLMConfig]: + data = await self._get_models_async() + return self._list_llm_models(data) + + def _list_llm_models(self, data) -> List[LLMConfig]: + configs = [] + for model in data: + assert "id" in model, f"OpenAI model missing 'id' field: {model}" + model_name = model["id"] + + if "context_length" in model: + # Context length is returned in OpenRouter as "context_length" + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + if not context_window_size: + continue + + # TogetherAI includes the type, which we can use to filter out embedding models + if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url: + if "type" in model and model["type"] not in ["chat", "language"]: + continue + + # for TogetherAI, we need to skip the models that don't support JSON mode / function calling + # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: { + # "error": { + # "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling", + # "type": "invalid_request_error", + # "param": null, + # "code": "constraints_model" + # } + # } + if "config" not in model: + continue + + if "nebius.com" in self.base_url: + # Nebius includes the type, which we can use to filter for text models + try: + model_type = model["architecture"]["modality"] + if model_type not in ["text->text", "text+image->text"]: + # print(f"Skipping model w/ modality {model_type}:\n{model}") + continue + except KeyError: + print(f"Couldn't access architecture type field, skipping model:\n{model}") + continue + + # for openai, filter models + if self.base_url == "https://api.openai.com/v1": + allowed_types = ["gpt-4", "o1", "o3", "o4"] + # NOTE: o1-mini and o1-preview do not support tool calling + # NOTE: o1-mini does not support system messages + # NOTE: o1-pro is only available in Responses API + disallowed_types = ["transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"] + skip = True + for model_type in allowed_types: + if model_name.startswith(model_type): + skip = False + break + for keyword in disallowed_types: + if keyword in model_name: + skip = True + break + # ignore this model + if skip: + continue + + # set the handle to openai-proxy if the base URL isn't OpenAI + if self.base_url != "https://api.openai.com/v1": + handle = self.get_handle(model_name, base_name="openai-proxy") + else: + handle = self.get_handle(model_name) + + llm_config = LLMConfig( + model=model_name, + model_endpoint_type="openai", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=handle, + provider_name=self.name, + provider_category=self.provider_category, + ) + + # gpt-4o-mini has started to regress with pretty bad emoji spam loops + # this is to counteract that + if "gpt-4o-mini" in model_name: + llm_config.frequency_penalty = 1.0 + if "gpt-4.1-mini" in model_name: + llm_config.frequency_penalty = 1.0 + + configs.append(llm_config) + + # for OpenAI, sort in reverse order + if self.base_url == "https://api.openai.com/v1": + # alphnumeric sort + configs.sort(key=lambda x: x.model, reverse=True) + + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + if self.base_url == "https://api.openai.com/v1": + # TODO: actually automatically list models for OpenAI + return [ + EmbeddingConfig( + embedding_model="text-embedding-ada-002", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=1536, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-ada-002", is_embedding=True), + batch_size=1024, + ), + EmbeddingConfig( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-3-small", is_embedding=True), + batch_size=1024, + ), + EmbeddingConfig( + embedding_model="text-embedding-3-large", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-3-large", is_embedding=True), + batch_size=1024, + ), + ] + + else: + # Actually attempt to list + data = self._get_models() + return self._list_embedding_models(data) + + async def list_embedding_models_async(self) -> List[EmbeddingConfig]: + if self.base_url == "https://api.openai.com/v1": + # TODO: actually automatically list models for OpenAI + return [ + EmbeddingConfig( + embedding_model="text-embedding-ada-002", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=1536, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-ada-002", is_embedding=True), + batch_size=1024, + ), + EmbeddingConfig( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-3-small", is_embedding=True), + batch_size=1024, + ), + EmbeddingConfig( + embedding_model="text-embedding-3-large", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-3-large", is_embedding=True), + batch_size=1024, + ), + ] + + else: + # Actually attempt to list + data = await self._get_models_async() + return self._list_embedding_models(data) + + def _list_embedding_models(self, data) -> List[EmbeddingConfig]: + configs = [] + for model in data: + assert "id" in model, f"Model missing 'id' field: {model}" + model_name = model["id"] + + if "context_length" in model: + # Context length is returned in Nebius as "context_length" + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + # We need the context length for embeddings too + if not context_window_size: + continue + + if "nebius.com" in self.base_url: + # Nebius includes the type, which we can use to filter for embedidng models + try: + model_type = model["architecture"]["modality"] + if model_type not in ["text->embedding"]: + # print(f"Skipping model w/ modality {model_type}:\n{model}") + continue + except KeyError: + print(f"Couldn't access architecture type field, skipping model:\n{model}") + continue + + elif "together.ai" in self.base_url or "together.xyz" in self.base_url: + # TogetherAI includes the type, which we can use to filter for embedding models + if "type" in model and model["type"] not in ["embedding"]: + # print(f"Skipping model w/ modality {model_type}:\n{model}") + continue + + else: + # For other providers we should skip by default, since we don't want to assume embeddings are supported + continue + + configs.append( + EmbeddingConfig( + embedding_model=model_name, + embedding_endpoint_type=self.provider_type, + embedding_endpoint=self.base_url, + embedding_dim=context_window_size, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=self.get_handle(model, is_embedding=True), + ) + ) + + return configs + + def get_model_context_window_size(self, model_name: str): + if model_name in LLM_MAX_TOKENS: + return LLM_MAX_TOKENS[model_name] + else: + return LLM_MAX_TOKENS["DEFAULT"] + + +class DeepSeekProvider(OpenAIProvider): + """ + DeepSeek ChatCompletions API is similar to OpenAI's reasoning API, + but with slight differences: + * For example, DeepSeek's API requires perfect interleaving of user/assistant + * It also does not support native function calling + """ + + provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.") + api_key: str = Field(..., description="API key for the DeepSeek API.") + + def get_model_context_window_size(self, model_name: str) -> Optional[int]: + # DeepSeek doesn't return context window in the model listing, + # so these are hardcoded from their website + if model_name == "deepseek-reasoner": + return 64000 + elif model_name == "deepseek-chat": + return 64000 + else: + return None + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list + + response = openai_get_model_list(self.base_url, api_key=self.api_key) + + if "data" in response: + data = response["data"] + else: + data = response + + configs = [] + for model in data: + assert "id" in model, f"DeepSeek model missing 'id' field: {model}" + model_name = model["id"] + + # In case DeepSeek starts supporting it in the future: + if "context_length" in model: + # Context length is returned in OpenRouter as "context_length" + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + if not context_window_size: + warnings.warn(f"Couldn't find context window size for model {model_name}") + continue + + # Not used for deepseek-reasoner, but otherwise is true + put_inner_thoughts_in_kwargs = False if model_name == "deepseek-reasoner" else True + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="deepseek", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + # No embeddings supported + return [] + + +class LMStudioOpenAIProvider(OpenAIProvider): + provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.") + api_key: Optional[str] = Field(None, description="API key for the LMStudio API.") + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list + + # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models' + MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0" + response = openai_get_model_list(MODEL_ENDPOINT_URL) + + """ + Example response: + + { + "object": "list", + "data": [ + { + "id": "qwen2-vl-7b-instruct", + "object": "model", + "type": "vlm", + "publisher": "mlx-community", + "arch": "qwen2_vl", + "compatibility_type": "mlx", + "quantization": "4bit", + "state": "not-loaded", + "max_context_length": 32768 + }, + ... + """ + if "data" not in response: + warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}") + return [] + + configs = [] + for model in response["data"]: + assert "id" in model, f"Model missing 'id' field: {model}" + model_name = model["id"] + + if "type" not in model: + warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}") + continue + elif model["type"] not in ["vlm", "llm"]: + continue + + if "max_context_length" in model: + context_window_size = model["max_context_length"] + else: + warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}") + continue + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="openai", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + ) + ) + + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + from letta.llm_api.openai import openai_get_model_list + + # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models' + MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0" + response = openai_get_model_list(MODEL_ENDPOINT_URL) + + """ + Example response: + { + "object": "list", + "data": [ + { + "id": "text-embedding-nomic-embed-text-v1.5", + "object": "model", + "type": "embeddings", + "publisher": "nomic-ai", + "arch": "nomic-bert", + "compatibility_type": "gguf", + "quantization": "Q4_0", + "state": "not-loaded", + "max_context_length": 2048 + } + ... + """ + if "data" not in response: + warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}") + return [] + + configs = [] + for model in response["data"]: + assert "id" in model, f"Model missing 'id' field: {model}" + model_name = model["id"] + + if "type" not in model: + warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}") + continue + elif model["type"] not in ["embeddings"]: + continue + + if "max_context_length" in model: + context_window_size = model["max_context_length"] + else: + warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}") + continue + + configs.append( + EmbeddingConfig( + embedding_model=model_name, + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=context_window_size, + embedding_chunk_size=300, # NOTE: max is 2048 + handle=self.get_handle(model_name), + ), + ) + + return configs + + +class XAIProvider(OpenAIProvider): + """https://docs.x.ai/docs/api-reference""" + + provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = Field(..., description="API key for the xAI/Grok API.") + base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") + + def get_model_context_window_size(self, model_name: str) -> Optional[int]: + # xAI doesn't return context window in the model listing, + # so these are hardcoded from their website + if model_name == "grok-2-1212": + return 131072 + # NOTE: disabling the minis for now since they return weird MM parts + # elif model_name == "grok-3-mini-fast-beta": + # return 131072 + # elif model_name == "grok-3-mini-beta": + # return 131072 + elif model_name == "grok-3-fast-beta": + return 131072 + elif model_name == "grok-3-beta": + return 131072 + else: + return None + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list + + response = openai_get_model_list(self.base_url, api_key=self.api_key) + + if "data" in response: + data = response["data"] + else: + data = response + + configs = [] + for model in data: + assert "id" in model, f"xAI/Grok model missing 'id' field: {model}" + model_name = model["id"] + + # In case xAI starts supporting it in the future: + if "context_length" in model: + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + if not context_window_size: + warnings.warn(f"Couldn't find context window size for model {model_name}") + continue + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="xai", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + # No embeddings supported + return [] + + +class AnthropicProvider(Provider): + provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = Field(..., description="API key for the Anthropic API.") + base_url: str = "https://api.anthropic.com/v1" + + def check_api_key(self): + from letta.llm_api.anthropic import anthropic_check_valid_api_key + + anthropic_check_valid_api_key(self.api_key) + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.anthropic import anthropic_get_model_list + + models = anthropic_get_model_list(api_key=self.api_key) + return self._list_llm_models(models) + + async def list_llm_models_async(self) -> List[LLMConfig]: + from letta.llm_api.anthropic import anthropic_get_model_list_async + + models = await anthropic_get_model_list_async(api_key=self.api_key) + return self._list_llm_models(models) + + def _list_llm_models(self, models) -> List[LLMConfig]: + from letta.llm_api.anthropic import MODEL_LIST + + configs = [] + for model in models: + + if model["type"] != "model": + continue + + if "id" not in model: + continue + + # Don't support 2.0 and 2.1 + if model["id"].startswith("claude-2"): + continue + + # Anthropic doesn't return the context window in their API + if "context_window" not in model: + # Remap list to name: context_window + model_library = {m["name"]: m["context_window"] for m in MODEL_LIST} + # Attempt to look it up in a hardcoded list + if model["id"] in model_library: + model["context_window"] = model_library[model["id"]] + else: + # On fallback, we can set 200k (generally safe), but we should warn the user + warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000") + model["context_window"] = 200000 + + max_tokens = 8192 + if "claude-3-opus" in model["id"]: + max_tokens = 4096 + if "claude-3-haiku" in model["id"]: + max_tokens = 4096 + # TODO: set for 3-7 extended thinking mode + + # We set this to false by default, because Anthropic can + # natively support tags inside of content fields + # However, putting COT inside of tool calls can make it more + # reliable for tool calling (no chance of a non-tool call step) + # Since tool_choice_type 'any' doesn't work with in-content COT + # NOTE For Haiku, it can be flaky if we don't enable this by default + # inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False + inner_thoughts_in_kwargs = True # we no longer support thinking tags + + configs.append( + LLMConfig( + model=model["id"], + model_endpoint_type="anthropic", + model_endpoint=self.base_url, + context_window=model["context_window"], + handle=self.get_handle(model["id"]), + put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, + max_tokens=max_tokens, + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + +class MistralProvider(Provider): + provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = Field(..., description="API key for the Mistral API.") + base_url: str = "https://api.mistral.ai/v1" + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.mistral import mistral_get_model_list + + # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... + # See: https://openrouter.ai/docs/requests + response = mistral_get_model_list(self.base_url, api_key=self.api_key) + + assert "data" in response, f"Mistral model query response missing 'data' field: {response}" + + configs = [] + for model in response["data"]: + # If model has chat completions and function calling enabled + if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]: + configs.append( + LLMConfig( + model=model["id"], + model_endpoint_type="openai", + model_endpoint=self.base_url, + context_window=model["max_context_length"], + handle=self.get_handle(model["id"]), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + # Not supported for mistral + return [] + + def get_model_context_window(self, model_name: str) -> Optional[int]: + # Redoing this is fine because it's a pretty lightweight call + models = self.list_llm_models() + + for m in models: + if model_name in m["id"]: + return int(m["max_context_length"]) + + return None + + +class OllamaProvider(OpenAIProvider): + """Ollama provider that uses the native /api/generate endpoint + + See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion + """ + + provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field(..., description="Base URL for the Ollama API.") + api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).") + default_prompt_formatter: str = Field( + ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API." + ) + + async def list_llm_models_async(self) -> List[LLMConfig]: + """Async version of list_llm_models below""" + endpoint = f"{self.base_url}/api/tags" + async with aiohttp.ClientSession() as session: + async with session.get(endpoint) as response: + if response.status != 200: + raise Exception(f"Failed to list Ollama models: {response.text}") + response_json = await response.json() + + configs = [] + for model in response_json["models"]: + context_window = self.get_model_context_window(model["name"]) + if context_window is None: + print(f"Ollama model {model['name']} has no context window") + continue + configs.append( + LLMConfig( + model=model["name"], + model_endpoint_type="ollama", + model_endpoint=self.base_url, + model_wrapper=self.default_prompt_formatter, + context_window=context_window, + handle=self.get_handle(model["name"]), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + def list_llm_models(self) -> List[LLMConfig]: + # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models + response = requests.get(f"{self.base_url}/api/tags") + if response.status_code != 200: + raise Exception(f"Failed to list Ollama models: {response.text}") + response_json = response.json() + + configs = [] + for model in response_json["models"]: + context_window = self.get_model_context_window(model["name"]) + if context_window is None: + print(f"Ollama model {model['name']} has no context window") + continue + configs.append( + LLMConfig( + model=model["name"], + model_endpoint_type="ollama", + model_endpoint=self.base_url, + model_wrapper=self.default_prompt_formatter, + context_window=context_window, + handle=self.get_handle(model["name"]), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + def get_model_context_window(self, model_name: str) -> Optional[int]: + response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) + response_json = response.json() + + ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675 + # possible_keys = [ + # # OPT + # "max_position_embeddings", + # # GPT-2 + # "n_positions", + # # MPT + # "max_seq_len", + # # ChatGLM2 + # "seq_length", + # # Command-R + # "model_max_length", + # # Others + # "max_sequence_length", + # "max_seq_length", + # "seq_len", + # ] + # max_position_embeddings + # parse model cards: nous, dolphon, llama + if "model_info" not in response_json: + if "error" in response_json: + print(f"Ollama fetch model info error for {model_name}: {response_json['error']}") + return None + for key, value in response_json["model_info"].items(): + if "context_length" in key: + return value + return None + + def _get_model_embedding_dim(self, model_name: str): + response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) + response_json = response.json() + return self._get_model_embedding_dim_impl(response_json, model_name) + + async def _get_model_embedding_dim_async(self, model_name: str): + async with aiohttp.ClientSession() as session: + async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response: + response_json = await response.json() + return self._get_model_embedding_dim_impl(response_json, model_name) + + @staticmethod + def _get_model_embedding_dim_impl(response_json: dict, model_name: str): + if "model_info" not in response_json: + if "error" in response_json: + print(f"Ollama fetch model info error for {model_name}: {response_json['error']}") + return None + for key, value in response_json["model_info"].items(): + if "embedding_length" in key: + return value + return None + + async def list_embedding_models_async(self) -> List[EmbeddingConfig]: + """Async version of list_embedding_models below""" + endpoint = f"{self.base_url}/api/tags" + async with aiohttp.ClientSession() as session: + async with session.get(endpoint) as response: + if response.status != 200: + raise Exception(f"Failed to list Ollama models: {response.text}") + response_json = await response.json() + + configs = [] + for model in response_json["models"]: + embedding_dim = await self._get_model_embedding_dim_async(model["name"]) + if not embedding_dim: + print(f"Ollama model {model['name']} has no embedding dimension") + continue + configs.append( + EmbeddingConfig( + embedding_model=model["name"], + embedding_endpoint_type="ollama", + embedding_endpoint=self.base_url, + embedding_dim=embedding_dim, + embedding_chunk_size=300, + handle=self.get_handle(model["name"], is_embedding=True), + ) + ) + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models + response = requests.get(f"{self.base_url}/api/tags") + if response.status_code != 200: + raise Exception(f"Failed to list Ollama models: {response.text}") + response_json = response.json() + + configs = [] + for model in response_json["models"]: + embedding_dim = self._get_model_embedding_dim(model["name"]) + if not embedding_dim: + print(f"Ollama model {model['name']} has no embedding dimension") + continue + configs.append( + EmbeddingConfig( + embedding_model=model["name"], + embedding_endpoint_type="ollama", + embedding_endpoint=self.base_url, + embedding_dim=embedding_dim, + embedding_chunk_size=300, + handle=self.get_handle(model["name"], is_embedding=True), + ) + ) + return configs + + +class GroqProvider(OpenAIProvider): + provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = "https://api.groq.com/openai/v1" + api_key: str = Field(..., description="API key for the Groq API.") + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list + + response = openai_get_model_list(self.base_url, api_key=self.api_key) + configs = [] + for model in response["data"]: + if not "context_window" in model: + continue + configs.append( + LLMConfig( + model=model["id"], + model_endpoint_type="groq", + model_endpoint=self.base_url, + context_window=model["context_window"], + handle=self.get_handle(model["id"]), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + return [] + + +class TogetherProvider(OpenAIProvider): + """TogetherAI provider that uses the /completions API + + TogetherAI can also be used via the /chat/completions API + by settings OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key + and API URL, however /completions is preferred because their /chat/completions + function calling support is limited. + """ + + provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = "https://api.together.ai/v1" + api_key: str = Field(..., description="API key for the TogetherAI API.") + default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list + + models = openai_get_model_list(self.base_url, api_key=self.api_key) + return self._list_llm_models(models) + + async def list_llm_models_async(self) -> List[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + models = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + return self._list_llm_models(models) + + def _list_llm_models(self, models) -> List[LLMConfig]: + pass + + # TogetherAI's response is missing the 'data' field + # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}" + if "data" in models: + data = models["data"] + else: + data = models + + configs = [] + for model in data: + assert "id" in model, f"TogetherAI model missing 'id' field: {model}" + model_name = model["id"] + + if "context_length" in model: + # Context length is returned in OpenRouter as "context_length" + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + # We need the context length for embeddings too + if not context_window_size: + continue + + # Skip models that are too small for Letta + if context_window_size <= MIN_CONTEXT_WINDOW: + continue + + # TogetherAI includes the type, which we can use to filter for embedding models + if "type" in model and model["type"] not in ["chat", "language"]: + continue + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="together", + model_endpoint=self.base_url, + model_wrapper=self.default_prompt_formatter, + context_window=context_window_size, + handle=self.get_handle(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + # TODO renable once we figure out how to pass API keys through properly + return [] + + # from letta.llm_api.openai import openai_get_model_list + + # response = openai_get_model_list(self.base_url, api_key=self.api_key) + + # # TogetherAI's response is missing the 'data' field + # # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}" + # if "data" in response: + # data = response["data"] + # else: + # data = response + + # configs = [] + # for model in data: + # assert "id" in model, f"TogetherAI model missing 'id' field: {model}" + # model_name = model["id"] + + # if "context_length" in model: + # # Context length is returned in OpenRouter as "context_length" + # context_window_size = model["context_length"] + # else: + # context_window_size = self.get_model_context_window_size(model_name) + + # if not context_window_size: + # continue + + # # TogetherAI includes the type, which we can use to filter out embedding models + # if "type" in model and model["type"] not in ["embedding"]: + # continue + + # configs.append( + # EmbeddingConfig( + # embedding_model=model_name, + # embedding_endpoint_type="openai", + # embedding_endpoint=self.base_url, + # embedding_dim=context_window_size, + # embedding_chunk_size=300, # TODO: change? + # ) + # ) + + # return configs + + +class GoogleAIProvider(Provider): + # gemini + provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str = Field(..., description="API key for the Google AI API.") + base_url: str = "https://generativelanguage.googleapis.com" + + def check_api_key(self): + from letta.llm_api.google_ai_client import google_ai_check_valid_api_key + + google_ai_check_valid_api_key(self.api_key) + + def list_llm_models(self): + from letta.llm_api.google_ai_client import google_ai_get_model_list + + model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) + model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] + model_options = [str(m["name"]) for m in model_options] + + # filter by model names + model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] + + # Add support for all gemini models + model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] + + configs = [] + for model in model_options: + configs.append( + LLMConfig( + model=model, + model_endpoint_type="google_ai", + model_endpoint=self.base_url, + context_window=self.get_model_context_window(model), + handle=self.get_handle(model), + max_tokens=8192, + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + async def list_llm_models_async(self): + import asyncio + + from letta.llm_api.google_ai_client import google_ai_get_model_list_async + + # Get and filter the model list + model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) + model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] + model_options = [str(m["name"]) for m in model_options] + + # filter by model names + model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] + + # Add support for all gemini models + model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] + + # Prepare tasks for context window lookups in parallel + async def create_config(model): + context_window = await self.get_model_context_window_async(model) + return LLMConfig( + model=model, + model_endpoint_type="google_ai", + model_endpoint=self.base_url, + context_window=context_window, + handle=self.get_handle(model), + max_tokens=8192, + provider_name=self.name, + provider_category=self.provider_category, + ) + + # Execute all config creation tasks concurrently + configs = await asyncio.gather(*[create_config(model) for model in model_options]) + + return configs + + def list_embedding_models(self): + from letta.llm_api.google_ai_client import google_ai_get_model_list + + # TODO: use base_url instead + model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) + return self._list_embedding_models(model_options) + + async def list_embedding_models_async(self): + from letta.llm_api.google_ai_client import google_ai_get_model_list_async + + # TODO: use base_url instead + model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) + return self._list_embedding_models(model_options) + + def _list_embedding_models(self, model_options): + # filter by 'generateContent' models + model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]] + model_options = [str(m["name"]) for m in model_options] + model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] + + configs = [] + for model in model_options: + configs.append( + EmbeddingConfig( + embedding_model=model, + embedding_endpoint_type="google_ai", + embedding_endpoint=self.base_url, + embedding_dim=768, + embedding_chunk_size=300, # NOTE: max is 2048 + handle=self.get_handle(model, is_embedding=True), + batch_size=1024, + ) + ) + return configs + + def get_model_context_window(self, model_name: str) -> Optional[int]: + from letta.llm_api.google_ai_client import google_ai_get_model_context_window + + if model_name in LLM_MAX_TOKENS: + return LLM_MAX_TOKENS[model_name] + else: + return google_ai_get_model_context_window(self.base_url, self.api_key, model_name) + + async def get_model_context_window_async(self, model_name: str) -> Optional[int]: + from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async + + if model_name in LLM_MAX_TOKENS: + return LLM_MAX_TOKENS[model_name] + else: + return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name) + + +class GoogleVertexProvider(Provider): + provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.") + google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.") + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.google_constants import GOOGLE_MODEL_TO_CONTEXT_LENGTH + + configs = [] + for model, context_length in GOOGLE_MODEL_TO_CONTEXT_LENGTH.items(): + configs.append( + LLMConfig( + model=model, + model_endpoint_type="google_vertex", + model_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}", + context_window=context_length, + handle=self.get_handle(model), + max_tokens=8192, + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + from letta.llm_api.google_constants import GOOGLE_EMBEDING_MODEL_TO_DIM + + configs = [] + for model, dim in GOOGLE_EMBEDING_MODEL_TO_DIM.items(): + configs.append( + EmbeddingConfig( + embedding_model=model, + embedding_endpoint_type="google_vertex", + embedding_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}", + embedding_dim=dim, + embedding_chunk_size=300, # NOTE: max is 2048 + handle=self.get_handle(model, is_embedding=True), + batch_size=1024, + ) + ) + return configs + + +class AzureProvider(Provider): + provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation + base_url: str = Field( + ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`." + ) + api_key: str = Field(..., description="API key for the Azure API.") + api_version: str = Field(latest_api_version, description="API version for the Azure API") + + @model_validator(mode="before") + def set_default_api_version(cls, values): + """ + This ensures that api_version is always set to the default if None is passed in. + """ + if values.get("api_version") is None: + values["api_version"] = cls.model_fields["latest_api_version"].default + return values + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.azure_openai import azure_openai_get_chat_completion_model_list + + model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version) + configs = [] + for model_option in model_options: + model_name = model_option["id"] + context_window_size = self.get_model_context_window(model_name) + model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version) + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="azure", + model_endpoint=model_endpoint, + context_window=context_window_size, + handle=self.get_handle(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ), + ) + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list + + model_options = azure_openai_get_embeddings_model_list( + self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True + ) + configs = [] + for model_option in model_options: + model_name = model_option["id"] + model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version) + configs.append( + EmbeddingConfig( + embedding_model=model_name, + embedding_endpoint_type="azure", + embedding_endpoint=model_endpoint, + embedding_dim=768, + embedding_chunk_size=300, # NOTE: max is 2048 + handle=self.get_handle(model_name), + batch_size=1024, + ), + ) + return configs + + def get_model_context_window(self, model_name: str) -> Optional[int]: + """ + This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model. + """ + context_window = AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, None) + if context_window is None: + context_window = LLM_MAX_TOKENS.get(model_name, 4096) + return context_window + + +class VLLMChatCompletionsProvider(Provider): + """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy""" + + # NOTE: vLLM only serves one model at a time (so could configure that through env variables) + provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field(..., description="Base URL for the vLLM API.") + + def list_llm_models(self) -> List[LLMConfig]: + # not supported with vLLM + from letta.llm_api.openai import openai_get_model_list + + assert self.base_url, "base_url is required for vLLM provider" + response = openai_get_model_list(self.base_url, api_key=None) + + configs = [] + for model in response["data"]: + configs.append( + LLMConfig( + model=model["id"], + model_endpoint_type="openai", + model_endpoint=self.base_url, + context_window=model["max_model_len"], + handle=self.get_handle(model["id"]), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + # not supported with vLLM + return [] + + +class VLLMCompletionsProvider(Provider): + """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper""" + + # NOTE: vLLM only serves one model at a time (so could configure that through env variables) + provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + base_url: str = Field(..., description="Base URL for the vLLM API.") + default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") + + def list_llm_models(self) -> List[LLMConfig]: + # not supported with vLLM + from letta.llm_api.openai import openai_get_model_list + + response = openai_get_model_list(self.base_url, api_key=None) + + configs = [] + for model in response["data"]: + configs.append( + LLMConfig( + model=model["id"], + model_endpoint_type="vllm", + model_endpoint=self.base_url, + model_wrapper=self.default_prompt_formatter, + context_window=model["max_model_len"], + handle=self.get_handle(model["id"]), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + # not supported with vLLM + return [] + + +class CohereProvider(OpenAIProvider): + pass + + +class BedrockProvider(Provider): + provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + region: str = Field(..., description="AWS region for Bedrock") + + def check_api_key(self): + """Check if the Bedrock credentials are valid""" + from letta.errors import LLMAuthenticationError + from letta.llm_api.aws_bedrock import bedrock_get_model_list + + try: + # For BYOK providers, use the custom credentials + if self.provider_category == ProviderCategory.byok: + # If we can list models, the credentials are valid + bedrock_get_model_list( + region_name=self.region, + access_key_id=self.access_key, + secret_access_key=self.api_key, # api_key stores the secret access key + ) + else: + # For base providers, use default credentials + bedrock_get_model_list(region_name=self.region) + except Exception as e: + raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}") + + def list_llm_models(self): + from letta.llm_api.aws_bedrock import bedrock_get_model_list + + models = bedrock_get_model_list(self.region) + + configs = [] + for model_summary in models: + model_arn = model_summary["inferenceProfileArn"] + configs.append( + LLMConfig( + model=model_arn, + model_endpoint_type=self.provider_type.value, + model_endpoint=None, + context_window=self.get_model_context_window(model_arn), + handle=self.get_handle(model_arn), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + return configs + + async def list_llm_models_async(self) -> List[LLMConfig]: + from letta.llm_api.aws_bedrock import bedrock_get_model_list_async + + models = await bedrock_get_model_list_async( + self.access_key, + self.api_key, + self.region, + ) + + configs = [] + for model_summary in models: + model_arn = model_summary["inferenceProfileArn"] + configs.append( + LLMConfig( + model=model_arn, + model_endpoint_type=self.provider_type.value, + model_endpoint=None, + context_window=self.get_model_context_window(model_arn), + handle=self.get_handle(model_arn), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + def list_embedding_models(self): + return [] + + def get_model_context_window(self, model_name: str) -> Optional[int]: + # Context windows for Claude models + from letta.llm_api.aws_bedrock import bedrock_get_model_context_window + + return bedrock_get_model_context_window(model_name) + + def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str: + print(model_name) + model = model_name.split(".")[-1] + return f"{self.name}/{model}" diff --git a/letta/schemas/providers/ollama.py b/letta/schemas/providers/ollama.py index d34d86d7..4cd70612 100644 --- a/letta/schemas/providers/ollama.py +++ b/letta/schemas/providers/ollama.py @@ -3,7 +3,7 @@ from typing import Literal import aiohttp from pydantic import Field -from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, DEFAULT_CONTEXT_WINDOW, DEFAULT_EMBEDDING_DIM, OLLAMA_API_PREFIX from letta.log import get_logger from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ProviderCategory, ProviderType @@ -12,8 +12,6 @@ from letta.schemas.providers.openai import OpenAIProvider logger = get_logger(__name__) -ollama_prefix = "/v1" - class OllamaProvider(OpenAIProvider): """Ollama provider that uses the native /api/generate endpoint @@ -41,19 +39,30 @@ class OllamaProvider(OpenAIProvider): response_json = await response.json() configs = [] - for model in response_json["models"]: - context_window = await self._get_model_context_window(model["name"]) + for model in response_json.get("models", []): + model_name = model["name"] + model_details = await self._get_model_details_async(model_name) + if not model_details or "completion" not in model_details.get("capabilities", []): + continue + + context_window = None + model_info = model_details.get("model_info", {}) + if architecture := model_info.get("general.architecture"): + if context_length := model_info.get(f"{architecture}.context_length"): + context_window = int(context_length) + if context_window is None: - print(f"Ollama model {model['name']} has no context window, using default 32000") - context_window = 32000 + logger.warning(f"Ollama model {model_name} has no context window, using default {DEFAULT_CONTEXT_WINDOW}") + context_window = DEFAULT_CONTEXT_WINDOW + configs.append( LLMConfig( - model=model["name"], + model=model_name, model_endpoint_type=ProviderType.ollama, - model_endpoint=f"{self.base_url}{ollama_prefix}", + model_endpoint=f"{self.base_url}{OLLAMA_API_PREFIX}", model_wrapper=self.default_prompt_formatter, context_window=context_window, - handle=self.get_handle(model["name"]), + handle=self.get_handle(model_name), provider_name=self.name, provider_category=self.provider_category, ) @@ -76,22 +85,23 @@ class OllamaProvider(OpenAIProvider): for model in response_json["models"]: embedding_dim = await self._get_model_embedding_dim(model["name"]) if not embedding_dim: - print(f"Ollama model {model['name']} has no embedding dimension, using default 1024") - # continue - embedding_dim = 1024 + logger.warning(f"Ollama model {model_name} has no embedding dimension, using default {DEFAULT_EMBEDDING_DIM}") + embedding_dim = DEFAULT_EMBEDDING_DIM + configs.append( EmbeddingConfig( - embedding_model=model["name"], + embedding_model=model_name, embedding_endpoint_type=ProviderType.ollama, - embedding_endpoint=f"{self.base_url}{ollama_prefix}", + embedding_endpoint=f"{self.base_url}{OLLAMA_API_PREFIX}", embedding_dim=embedding_dim, embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, - handle=self.get_handle(model["name"], is_embedding=True), + handle=self.get_handle(model_name, is_embedding=True), ) ) return configs - async def _get_model_context_window(self, model_name: str) -> int | None: + async def _get_model_details_async(self, model_name: str) -> dict | None: + """Get detailed information for a specific model from /api/show.""" endpoint = f"{self.base_url}/api/show" payload = {"name": model_name} @@ -102,39 +112,7 @@ class OllamaProvider(OpenAIProvider): error_text = await response.text() logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}") return None - - response_json = await response.json() - model_info = response_json.get("model_info", {}) - - if architecture := model_info.get("general.architecture"): - if context_length := model_info.get(f"{architecture}.context_length"): - return int(context_length) - + return await response.json() except Exception as e: - logger.warning(f"Failed to get model context window for {model_name} with error: {e}") - - return None - - async def _get_model_embedding_dim(self, model_name: str) -> int | None: - endpoint = f"{self.base_url}/api/show" - payload = {"name": model_name} - - try: - async with aiohttp.ClientSession() as session: - async with session.post(endpoint, json=payload) as response: - if response.status != 200: - error_text = await response.text() - logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}") - return None - - response_json = await response.json() - model_info = response_json.get("model_info", {}) - - if architecture := model_info.get("general.architecture"): - if embedding_length := model_info.get(f"{architecture}.embedding_length"): - return int(embedding_length) - - except Exception as e: - logger.warning(f"Failed to get model embedding dimension for {model_name} with error: {e}") - - return None + logger.warning(f"Failed to get model details for {model_name} with error: {e}") + return None diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index 86a0b54f..75579145 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -30,9 +30,7 @@ logger = get_logger(__name__) responses={ 200: { "description": "Successful response", - "content": { - "text/event-stream": {"description": "Server-Sent Events stream"}, - }, + "content": {"text/event-stream": {}}, } }, ) diff --git a/letta/server/rest_api/routers/v1/voice.py b/letta/server/rest_api/routers/v1/voice.py index e33bf9bf..a8b172de 100644 --- a/letta/server/rest_api/routers/v1/voice.py +++ b/letta/server/rest_api/routers/v1/voice.py @@ -25,9 +25,7 @@ logger = get_logger(__name__) responses={ 200: { "description": "Successful response", - "content": { - "text/event-stream": {"description": "Server-Sent Events stream"}, - }, + "content": {"text/event-stream": {}}, } }, ) diff --git a/letta/services/agent_file_manager.py b/letta/services/agent_file_manager.py new file mode 100644 index 00000000..78f34b89 --- /dev/null +++ b/letta/services/agent_file_manager.py @@ -0,0 +1,685 @@ +from datetime import datetime, timezone +from typing import Dict, List + +from letta.errors import AgentFileExportError, AgentFileImportError +from letta.log import get_logger +from letta.schemas.agent import AgentState, CreateAgent +from letta.schemas.agent_file import ( + AgentFileSchema, + AgentSchema, + BlockSchema, + FileAgentSchema, + FileSchema, + GroupSchema, + ImportResult, + MessageSchema, + SourceSchema, + ToolSchema, +) +from letta.schemas.block import Block +from letta.schemas.file import FileMetadata +from letta.schemas.message import Message +from letta.schemas.source import Source +from letta.schemas.tool import Tool +from letta.schemas.user import User +from letta.services.agent_manager import AgentManager +from letta.services.block_manager import BlockManager +from letta.services.file_manager import FileManager +from letta.services.file_processor.embedder.base_embedder import BaseEmbedder +from letta.services.file_processor.file_processor import FileProcessor +from letta.services.file_processor.parser.mistral_parser import MistralFileParser +from letta.services.files_agents_manager import FileAgentManager +from letta.services.group_manager import GroupManager +from letta.services.mcp_manager import MCPManager +from letta.services.message_manager import MessageManager +from letta.services.source_manager import SourceManager +from letta.services.tool_manager import ToolManager +from letta.utils import get_latest_alembic_revision + +logger = get_logger(__name__) + + +class AgentFileManager: + """ + Manages export and import of agent files between database and AgentFileSchema format. + + Handles: + - ID mapping between database IDs and human-readable file IDs + - Coordination across multiple entity managers + - Transaction safety during imports + - Referential integrity validation + """ + + def __init__( + self, + agent_manager: AgentManager, + tool_manager: ToolManager, + source_manager: SourceManager, + block_manager: BlockManager, + group_manager: GroupManager, + mcp_manager: MCPManager, + file_manager: FileManager, + file_agent_manager: FileAgentManager, + message_manager: MessageManager, + embedder: BaseEmbedder, + file_parser: MistralFileParser, + using_pinecone: bool = False, + ): + self.agent_manager = agent_manager + self.tool_manager = tool_manager + self.source_manager = source_manager + self.block_manager = block_manager + self.group_manager = group_manager + self.mcp_manager = mcp_manager + self.file_manager = file_manager + self.file_agent_manager = file_agent_manager + self.message_manager = message_manager + self.embedder = embedder + self.file_parser = file_parser + self.using_pinecone = using_pinecone + + # ID mapping state for export + self._db_to_file_ids: Dict[str, str] = {} + + # Counters for generating Stripe-style IDs + self._id_counters: Dict[str, int] = { + AgentSchema.__id_prefix__: 0, + GroupSchema.__id_prefix__: 0, + BlockSchema.__id_prefix__: 0, + FileSchema.__id_prefix__: 0, + SourceSchema.__id_prefix__: 0, + ToolSchema.__id_prefix__: 0, + MessageSchema.__id_prefix__: 0, + FileAgentSchema.__id_prefix__: 0, + # MCPServerSchema.__id_prefix__: 0, + } + + def _reset_state(self): + """Reset internal state for a new operation""" + self._db_to_file_ids.clear() + for key in self._id_counters: + self._id_counters[key] = 0 + + def _generate_file_id(self, entity_type: str) -> str: + """Generate a Stripe-style ID for the given entity type""" + counter = self._id_counters[entity_type] + file_id = f"{entity_type}-{counter}" + self._id_counters[entity_type] += 1 + return file_id + + def _map_db_to_file_id(self, db_id: str, entity_type: str, allow_new: bool = True) -> str: + """Map a database UUID to a file ID, creating if needed (export only)""" + if db_id in self._db_to_file_ids: + return self._db_to_file_ids[db_id] + + if not allow_new: + raise AgentFileExportError( + f"Unexpected new {entity_type} ID '{db_id}' encountered during conversion. " + f"All IDs should have been mapped during agent processing." + ) + + file_id = self._generate_file_id(entity_type) + self._db_to_file_ids[db_id] = file_id + return file_id + + def _extract_unique_tools(self, agent_states: List[AgentState]) -> List: + """Extract unique tools across all agent states by ID""" + all_tools = [] + for agent_state in agent_states: + if agent_state.tools: + all_tools.extend(agent_state.tools) + + unique_tools = {} + for tool in all_tools: + unique_tools[tool.id] = tool + + return sorted(unique_tools.values(), key=lambda x: x.name) + + def _extract_unique_blocks(self, agent_states: List[AgentState]) -> List: + """Extract unique blocks across all agent states by ID""" + all_blocks = [] + for agent_state in agent_states: + if agent_state.memory and agent_state.memory.blocks: + all_blocks.extend(agent_state.memory.blocks) + + unique_blocks = {} + for block in all_blocks: + unique_blocks[block.id] = block + + return sorted(unique_blocks.values(), key=lambda x: x.label) + + async def _extract_unique_sources_and_files_from_agents( + self, agent_states: List[AgentState], actor: User, files_agents_cache: dict = None + ) -> tuple[List[Source], List[FileMetadata]]: + """Extract unique sources and files from agent states using bulk operations""" + + all_source_ids = set() + all_file_ids = set() + + for agent_state in agent_states: + files_agents = await self.file_agent_manager.list_files_for_agent( + agent_id=agent_state.id, actor=actor, is_open_only=False, return_as_blocks=False + ) + # cache the results for reuse during conversion + if files_agents_cache is not None: + files_agents_cache[agent_state.id] = files_agents + + for file_agent in files_agents: + all_source_ids.add(file_agent.source_id) + all_file_ids.add(file_agent.file_id) + sources = await self.source_manager.get_sources_by_ids_async(list(all_source_ids), actor) + files = await self.file_manager.get_files_by_ids_async(list(all_file_ids), actor, include_content=True) + + return sources, files + + async def _convert_agent_state_to_schema(self, agent_state: AgentState, actor: User, files_agents_cache: dict = None) -> AgentSchema: + """Convert AgentState to AgentSchema with ID remapping""" + + agent_file_id = self._map_db_to_file_id(agent_state.id, AgentSchema.__id_prefix__) + + # use cached file-agent data if available, otherwise fetch + if files_agents_cache is not None and agent_state.id in files_agents_cache: + files_agents = files_agents_cache[agent_state.id] + else: + files_agents = await self.file_agent_manager.list_files_for_agent( + agent_id=agent_state.id, actor=actor, is_open_only=False, return_as_blocks=False + ) + agent_schema = await AgentSchema.from_agent_state( + agent_state, message_manager=self.message_manager, files_agents=files_agents, actor=actor + ) + agent_schema.id = agent_file_id + + if agent_schema.messages: + for message in agent_schema.messages: + message_file_id = self._map_db_to_file_id(message.id, MessageSchema.__id_prefix__) + message.id = message_file_id + message.agent_id = agent_file_id + + if agent_schema.in_context_message_ids: + agent_schema.in_context_message_ids = [ + self._map_db_to_file_id(message_id, MessageSchema.__id_prefix__, allow_new=False) + for message_id in agent_schema.in_context_message_ids + ] + + if agent_schema.tool_ids: + agent_schema.tool_ids = [self._map_db_to_file_id(tool_id, ToolSchema.__id_prefix__) for tool_id in agent_schema.tool_ids] + + if agent_schema.source_ids: + agent_schema.source_ids = [ + self._map_db_to_file_id(source_id, SourceSchema.__id_prefix__) for source_id in agent_schema.source_ids + ] + + if agent_schema.block_ids: + agent_schema.block_ids = [self._map_db_to_file_id(block_id, BlockSchema.__id_prefix__) for block_id in agent_schema.block_ids] + + if agent_schema.files_agents: + for file_agent in agent_schema.files_agents: + file_agent.file_id = self._map_db_to_file_id(file_agent.file_id, FileSchema.__id_prefix__) + file_agent.source_id = self._map_db_to_file_id(file_agent.source_id, SourceSchema.__id_prefix__) + file_agent.agent_id = agent_file_id + + return agent_schema + + def _convert_tool_to_schema(self, tool) -> ToolSchema: + """Convert Tool to ToolSchema with ID remapping""" + tool_file_id = self._map_db_to_file_id(tool.id, ToolSchema.__id_prefix__, allow_new=False) + tool_schema = ToolSchema.from_tool(tool) + tool_schema.id = tool_file_id + return tool_schema + + def _convert_block_to_schema(self, block) -> BlockSchema: + """Convert Block to BlockSchema with ID remapping""" + block_file_id = self._map_db_to_file_id(block.id, BlockSchema.__id_prefix__, allow_new=False) + block_schema = BlockSchema.from_block(block) + block_schema.id = block_file_id + return block_schema + + def _convert_source_to_schema(self, source) -> SourceSchema: + """Convert Source to SourceSchema with ID remapping""" + source_file_id = self._map_db_to_file_id(source.id, SourceSchema.__id_prefix__, allow_new=False) + source_schema = SourceSchema.from_source(source) + source_schema.id = source_file_id + return source_schema + + def _convert_file_to_schema(self, file_metadata) -> FileSchema: + """Convert FileMetadata to FileSchema with ID remapping""" + file_file_id = self._map_db_to_file_id(file_metadata.id, FileSchema.__id_prefix__, allow_new=False) + file_schema = FileSchema.from_file_metadata(file_metadata) + file_schema.id = file_file_id + file_schema.source_id = self._map_db_to_file_id(file_metadata.source_id, SourceSchema.__id_prefix__, allow_new=False) + return file_schema + + async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema: + """ + Export agents and their related entities to AgentFileSchema format. + + Args: + agent_ids: List of agent UUIDs to export + + Returns: + AgentFileSchema with all related entities + + Raises: + AgentFileExportError: If export fails + """ + try: + self._reset_state() + + agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids=agent_ids, actor=actor) + + # Validate that all requested agents were found + if len(agent_states) != len(agent_ids): + found_ids = {agent.id for agent in agent_states} + missing_ids = [agent_id for agent_id in agent_ids if agent_id not in found_ids] + raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}") + + # cache for file-agent relationships to avoid duplicate queries + files_agents_cache = {} # Maps agent_id to list of file_agent relationships + + # Extract unique entities across all agents + tool_set = self._extract_unique_tools(agent_states) + block_set = self._extract_unique_blocks(agent_states) + + # Extract sources and files from agent states BEFORE conversion (with caching) + source_set, file_set = await self._extract_unique_sources_and_files_from_agents(agent_states, actor, files_agents_cache) + + # Convert to schemas with ID remapping (reusing cached file-agent data) + agent_schemas = [ + await self._convert_agent_state_to_schema(agent_state, actor=actor, files_agents_cache=files_agents_cache) + for agent_state in agent_states + ] + tool_schemas = [self._convert_tool_to_schema(tool) for tool in tool_set] + block_schemas = [self._convert_block_to_schema(block) for block in block_set] + source_schemas = [self._convert_source_to_schema(source) for source in source_set] + file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set] + + logger.info(f"Exporting {len(agent_ids)} agents to agent file format") + + # Return AgentFileSchema with converted entities + return AgentFileSchema( + agents=agent_schemas, + groups=[], # TODO: Extract and convert groups + blocks=block_schemas, + files=file_schemas, + sources=source_schemas, + tools=tool_schemas, + # mcp_servers=[], # TODO: Extract and convert MCP servers + metadata={"revision_id": await get_latest_alembic_revision()}, + created_at=datetime.now(timezone.utc), + ) + + except Exception as e: + logger.error(f"Failed to export agent file: {e}") + raise AgentFileExportError(f"Export failed: {e}") from e + + async def import_file(self, schema: AgentFileSchema, actor: User, dry_run: bool = False) -> ImportResult: + """ + Import AgentFileSchema into the database. + + Args: + schema: The agent file schema to import + dry_run: If True, validate but don't commit changes + + Returns: + ImportResult with success status and details + + Raises: + AgentFileImportError: If import fails + """ + try: + self._reset_state() + + if dry_run: + logger.info("Starting dry run import validation") + else: + logger.info("Starting agent file import") + + # Validate schema first + self._validate_schema(schema) + + if dry_run: + return ImportResult( + success=True, + message="Dry run validation passed", + imported_count=0, + ) + + # Import in dependency order + imported_count = 0 + file_to_db_ids = {} # Maps file IDs to new database IDs + # in-memory cache for file metadata to avoid repeated db calls + file_metadata_cache = {} # Maps database file ID to FileMetadata + + # 1. Create tools first (no dependencies) - using bulk upsert for efficiency + if schema.tools: + # convert tool schemas to pydantic tools + pydantic_tools = [] + for tool_schema in schema.tools: + pydantic_tools.append(Tool(**tool_schema.model_dump(exclude={"id"}))) + + # bulk upsert all tools at once + created_tools = await self.tool_manager.bulk_upsert_tools_async(pydantic_tools, actor) + + # map file ids to database ids + # note: tools are matched by name during upsert, so we need to match by name here too + created_tools_by_name = {tool.name: tool for tool in created_tools} + for tool_schema in schema.tools: + created_tool = created_tools_by_name.get(tool_schema.name) + if created_tool: + file_to_db_ids[tool_schema.id] = created_tool.id + imported_count += 1 + else: + logger.warning(f"Tool {tool_schema.name} was not created during bulk upsert") + + # 2. Create blocks (no dependencies) - using batch create for efficiency + if schema.blocks: + # convert block schemas to pydantic blocks (excluding IDs to create new blocks) + pydantic_blocks = [] + for block_schema in schema.blocks: + pydantic_blocks.append(Block(**block_schema.model_dump(exclude={"id"}))) + + # batch create all blocks at once + created_blocks = await self.block_manager.batch_create_blocks_async(pydantic_blocks, actor) + + # map file ids to database ids + for block_schema, created_block in zip(schema.blocks, created_blocks): + file_to_db_ids[block_schema.id] = created_block.id + imported_count += 1 + + # 3. Create sources (no dependencies) - using bulk upsert for efficiency + if schema.sources: + # convert source schemas to pydantic sources + pydantic_sources = [] + for source_schema in schema.sources: + source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"}) + pydantic_sources.append(Source(**source_data)) + + # bulk upsert all sources at once + created_sources = await self.source_manager.bulk_upsert_sources_async(pydantic_sources, actor) + + # map file ids to database ids + # note: sources are matched by name during upsert, so we need to match by name here too + created_sources_by_name = {source.name: source for source in created_sources} + for source_schema in schema.sources: + created_source = created_sources_by_name.get(source_schema.name) + if created_source: + file_to_db_ids[source_schema.id] = created_source.id + imported_count += 1 + else: + logger.warning(f"Source {source_schema.name} was not created during bulk upsert") + + # 4. Create files (depends on sources) + for file_schema in schema.files: + # Convert FileSchema back to FileMetadata + file_data = file_schema.model_dump(exclude={"id", "content"}) + # Remap source_id from file ID to database ID + file_data["source_id"] = file_to_db_ids[file_schema.source_id] + file_metadata = FileMetadata(**file_data) + created_file = await self.file_manager.create_file(file_metadata, actor, text=file_schema.content) + file_to_db_ids[file_schema.id] = created_file.id + imported_count += 1 + + # 5. Process files for chunking/embedding (depends on files and sources) + file_processor = FileProcessor( + file_parser=self.file_parser, + embedder=self.embedder, + actor=actor, + using_pinecone=self.using_pinecone, + ) + + for file_schema in schema.files: + if file_schema.content: # Only process files with content + file_db_id = file_to_db_ids[file_schema.id] + source_db_id = file_to_db_ids[file_schema.source_id] + + # Get the created file metadata (with caching) + if file_db_id not in file_metadata_cache: + file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor) + file_metadata = file_metadata_cache[file_db_id] + + # Save the db call of fetching content again + file_metadata.content = file_schema.content + + # Process the file for chunking/embedding + passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_db_id) + imported_count += len(passages) + + # 6. Create agents with empty message history + for agent_schema in schema.agents: + # Convert AgentSchema back to CreateAgent, remapping tool/block IDs + agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"}) + + # Remap tool_ids from file IDs to database IDs + if agent_data.get("tool_ids"): + agent_data["tool_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["tool_ids"]] + + # Remap block_ids from file IDs to database IDs + if agent_data.get("block_ids"): + agent_data["block_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["block_ids"]] + + agent_create = CreateAgent(**agent_data) + created_agent = await self.agent_manager.create_agent_async(agent_create, actor, _init_with_no_messages=True) + file_to_db_ids[agent_schema.id] = created_agent.id + imported_count += 1 + + # 7. Create messages and update agent message_ids + for agent_schema in schema.agents: + agent_db_id = file_to_db_ids[agent_schema.id] + message_file_to_db_ids = {} + + # Create messages for this agent + messages = [] + for message_schema in agent_schema.messages: + # Convert MessageSchema back to Message, setting agent_id to new DB ID + message_data = message_schema.model_dump(exclude={"id"}) + message_data["agent_id"] = agent_db_id # Remap agent_id to new database ID + message_obj = Message(**message_data) + messages.append(message_obj) + # Map file ID to the generated database ID immediately + message_file_to_db_ids[message_schema.id] = message_obj.id + + created_messages = await self.message_manager.create_many_messages_async(pydantic_msgs=messages, actor=actor) + imported_count += len(created_messages) + + # Remap in_context_message_ids from file IDs to database IDs + in_context_db_ids = [message_file_to_db_ids[message_schema_id] for message_schema_id in agent_schema.in_context_message_ids] + + # Update agent with the correct message_ids + await self.agent_manager.update_message_ids_async(agent_id=agent_db_id, message_ids=in_context_db_ids, actor=actor) + + # 8. Create file-agent relationships (depends on agents and files) + for agent_schema in schema.agents: + if agent_schema.files_agents: + agent_db_id = file_to_db_ids[agent_schema.id] + + # Prepare files for bulk attachment + files_for_agent = [] + visible_content_map = {} + + for file_agent_schema in agent_schema.files_agents: + file_db_id = file_to_db_ids[file_agent_schema.file_id] + + # Use cached file metadata if available + if file_db_id not in file_metadata_cache: + file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor) + file_metadata = file_metadata_cache[file_db_id] + files_for_agent.append(file_metadata) + + if file_agent_schema.visible_content: + visible_content_map[file_db_id] = file_agent_schema.visible_content + + # Bulk attach files to agent + await self.file_agent_manager.attach_files_bulk( + agent_id=agent_db_id, files_metadata=files_for_agent, visible_content_map=visible_content_map, actor=actor + ) + imported_count += len(files_for_agent) + + return ImportResult( + success=True, + message=f"Import completed successfully. Imported {imported_count} entities.", + imported_count=imported_count, + id_mappings=file_to_db_ids, + ) + + except Exception as e: + logger.exception(f"Failed to import agent file: {e}") + raise AgentFileImportError(f"Import failed: {e}") from e + + def _validate_id_format(self, schema: AgentFileSchema) -> List[str]: + """Validate that all IDs follow the expected format""" + errors = [] + + # Define entity types and their expected prefixes + entity_checks = [ + (schema.agents, AgentSchema.__id_prefix__), + (schema.groups, GroupSchema.__id_prefix__), + (schema.blocks, BlockSchema.__id_prefix__), + (schema.files, FileSchema.__id_prefix__), + (schema.sources, SourceSchema.__id_prefix__), + (schema.tools, ToolSchema.__id_prefix__), + ] + + for entities, expected_prefix in entity_checks: + for entity in entities: + if not entity.id.startswith(f"{expected_prefix}-"): + errors.append(f"Invalid ID format: {entity.id} should start with '{expected_prefix}-'") + else: + # Check that the suffix is a valid integer + try: + suffix = entity.id[len(expected_prefix) + 1 :] + int(suffix) + except ValueError: + errors.append(f"Invalid ID format: {entity.id} should have integer suffix") + + # Also check message IDs within agents + for agent in schema.agents: + for message in agent.messages: + if not message.id.startswith(f"{MessageSchema.__id_prefix__}-"): + errors.append(f"Invalid message ID format: {message.id} should start with '{MessageSchema.__id_prefix__}-'") + else: + # Check that the suffix is a valid integer + try: + suffix = message.id[len(MessageSchema.__id_prefix__) + 1 :] + int(suffix) + except ValueError: + errors.append(f"Invalid message ID format: {message.id} should have integer suffix") + + return errors + + def _validate_duplicate_ids(self, schema: AgentFileSchema) -> List[str]: + """Validate that there are no duplicate IDs within or across entity types""" + errors = [] + all_ids = set() + + # Check each entity type for internal duplicates and collect all IDs + entity_collections = [ + ("agents", schema.agents), + ("groups", schema.groups), + ("blocks", schema.blocks), + ("files", schema.files), + ("sources", schema.sources), + ("tools", schema.tools), + ] + + for entity_type, entities in entity_collections: + entity_ids = [entity.id for entity in entities] + + # Check for duplicates within this entity type + seen = set() + duplicates = set() + for entity_id in entity_ids: + if entity_id in seen: + duplicates.add(entity_id) + else: + seen.add(entity_id) + + if duplicates: + errors.append(f"Duplicate {entity_type} IDs found: {duplicates}") + + # Check for duplicates across all entity types + for entity_id in entity_ids: + if entity_id in all_ids: + errors.append(f"Duplicate ID across entity types: {entity_id}") + all_ids.add(entity_id) + + # Also check message IDs within agents + for agent in schema.agents: + message_ids = [msg.id for msg in agent.messages] + + # Check for duplicates within agent messages + seen = set() + duplicates = set() + for message_id in message_ids: + if message_id in seen: + duplicates.add(message_id) + else: + seen.add(message_id) + + if duplicates: + errors.append(f"Duplicate message IDs in agent {agent.id}: {duplicates}") + + # Check for duplicates across all entity types + for message_id in message_ids: + if message_id in all_ids: + errors.append(f"Duplicate ID across entity types: {message_id}") + all_ids.add(message_id) + + return errors + + def _validate_file_source_references(self, schema: AgentFileSchema) -> List[str]: + """Validate that all file source_id references exist""" + errors = [] + source_ids = {source.id for source in schema.sources} + + for file in schema.files: + if file.source_id not in source_ids: + errors.append(f"File {file.id} references non-existent source {file.source_id}") + + return errors + + def _validate_file_agent_references(self, schema: AgentFileSchema) -> List[str]: + """Validate that all file-agent relationships reference existing entities""" + errors = [] + file_ids = {file.id for file in schema.files} + source_ids = {source.id for source in schema.sources} + {agent.id for agent in schema.agents} + + for agent in schema.agents: + for file_agent in agent.files_agents: + if file_agent.file_id not in file_ids: + errors.append(f"File-agent relationship references non-existent file {file_agent.file_id}") + if file_agent.source_id not in source_ids: + errors.append(f"File-agent relationship references non-existent source {file_agent.source_id}") + if file_agent.agent_id != agent.id: + errors.append(f"File-agent relationship has mismatched agent_id {file_agent.agent_id} vs {agent.id}") + + return errors + + def _validate_schema(self, schema: AgentFileSchema): + """ + Validate the agent file schema for consistency and referential integrity. + + Args: + schema: The schema to validate + + Raises: + AgentFileImportError: If validation fails + """ + errors = [] + + # 1. ID Format Validation + errors.extend(self._validate_id_format(schema)) + + # 2. Duplicate ID Detection + errors.extend(self._validate_duplicate_ids(schema)) + + # 3. File Source Reference Validation + errors.extend(self._validate_file_source_references(schema)) + + # 4. File-Agent Reference Validation + errors.extend(self._validate_file_agent_references(schema)) + + if errors: + raise AgentFileImportError(f"Schema validation failed: {'; '.join(errors)}") + + logger.info("Schema validation passed") diff --git a/letta/services/tool_executor/builtin_tool_executor.py b/letta/services/tool_executor/builtin_tool_executor.py index a4146320..8e7cfcb8 100644 --- a/letta/services/tool_executor/builtin_tool_executor.py +++ b/letta/services/tool_executor/builtin_tool_executor.py @@ -1,5 +1,6 @@ import asyncio import json +import os import time from typing import Any, Dict, List, Literal, Optional diff --git a/pyproject.toml b/pyproject.toml index 8d554bb3..e13c7ecf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ letta = "letta.main:app" [tool.poetry] name = "letta" -version = "0.10.0" +version = "0.11.4" packages = [ {include = "letta"}, ] diff --git a/scripts/docker-compose.yml b/scripts/docker-compose.yml new file mode 100644 index 00000000..3347d213 --- /dev/null +++ b/scripts/docker-compose.yml @@ -0,0 +1,32 @@ +version: '3.7' +services: + redis: + image: redis:alpine + container_name: redis + healthcheck: + test: ['CMD-SHELL', 'redis-cli ping | grep PONG'] + interval: 1s + timeout: 3s + retries: 5 + ports: + - '6379:6379' + volumes: + - ./data/redis:/data + command: redis-server --appendonly yes + postgres: + image: ankane/pgvector + container_name: postgres + healthcheck: + test: ['CMD-SHELL', 'pg_isready -U postgres'] + interval: 1s + timeout: 3s + retries: 5 + ports: + - '5432:5432' + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: letta + volumes: + - ./data/postgres:/var/lib/postgresql/data + - ./scripts/postgres-db-init/init.sql:/docker-entrypoint-initdb.d/init.sql diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 5f08452b..3d5203bc 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -156,6 +156,7 @@ async def test_sleeptime_group_chat(server, actor): # 6. Verify run status after sleep time.sleep(2) + for run_id in run_ids: job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor) assert job.status == JobStatus.running or job.status == JobStatus.completed diff --git a/tests/test_tool_sandbox/restaurant_management_system/adjust_menu_prices.py b/tests/test_tool_sandbox/restaurant_management_system/adjust_menu_prices.py index 1d3d9d3e..ffe734b3 100644 --- a/tests/test_tool_sandbox/restaurant_management_system/adjust_menu_prices.py +++ b/tests/test_tool_sandbox/restaurant_management_system/adjust_menu_prices.py @@ -8,10 +8,9 @@ def adjust_menu_prices(percentage: float) -> str: str: A formatted string summarizing the price adjustments. """ import cowsay - from tqdm import tqdm - from core.menu import Menu, MenuItem # Import a class from the codebase from core.utils import format_currency # Use a utility function to test imports + from tqdm import tqdm if not isinstance(percentage, (int, float)): raise TypeError("percentage must be a number")