From 9ca1664ed89aec0597ff670b2320c719c16cbfbf Mon Sep 17 00:00:00 2001
From: Andy Li <55300002+cliandy@users.noreply.github.com>
Date: Tue, 22 Jul 2025 16:09:50 -0700
Subject: [PATCH] feat: support for providers
---
letta/groups/dynamic_multi_agent.py | 1 +
letta/llm_api/anthropic.py | 4 +-
letta/llm_api/aws_bedrock.py | 118 +-
letta/llm_api/deepseek.py | 2 +-
letta/llm_api/google_ai_client.py | 38 -
letta/llm_api/google_constants.py | 9 +-
letta/llm_api/helpers.py | 2 +-
letta/llm_api/llm_api_tools.py | 11 +-
letta/llm_api/mistral.py | 49 +-
letta/llm_api/openai.py | 34 +-
.../sample_response_jsons/aws_bedrock.json | 38 +
.../lmstudio_embedding_list.json | 15 +
.../lmstudio_model_list.json | 15 +
letta/local_llm/constants.py | 22 -
.../llm_chat_completion_wrappers/airoboros.py | 18 +-
.../llm_chat_completion_wrappers/chatml.py | 15 +-
.../configurable_wrapper.py | 12 +-
.../llm_chat_completion_wrappers/dolphin.py | 6 +-
.../simple_summary_wrapper.py | 2 +-
letta/local_llm/ollama/api.py | 4 +-
letta/schemas/embedding_config.py | 8 +-
letta/schemas/enums.py | 2 +
letta/schemas/providers.py | 1618 -----------------
letta/schemas/providers/__init__.py | 47 +
letta/schemas/providers/anthropic.py | 78 +
letta/schemas/providers/azure.py | 80 +
letta/schemas/providers/base.py | 201 ++
letta/schemas/providers/bedrock.py | 78 +
letta/schemas/providers/cerebras.py | 79 +
letta/schemas/providers/cohere.py | 18 +
letta/schemas/providers/deepseek.py | 63 +
letta/schemas/providers/google_gemini.py | 102 ++
letta/schemas/providers/google_vertex.py | 54 +
letta/schemas/providers/groq.py | 35 +
letta/schemas/providers/letta.py | 39 +
letta/schemas/providers/lmstudio.py | 97 +
letta/schemas/providers/mistral.py | 41 +
letta/schemas/providers/ollama.py | 151 ++
letta/schemas/providers/openai.py | 241 +++
letta/schemas/providers/together.py | 85 +
letta/schemas/providers/vllm.py | 57 +
letta/schemas/providers/xai.py | 66 +
letta/server/rest_api/app.py | 7 +-
letta/server/rest_api/routers/v1/providers.py | 4 +-
letta/server/server.py | 24 +-
.../services/file_processor/file_processor.py | 81 +-
letta/services/group_manager.py | 7 +
letta/services/provider_manager.py | 4 +-
tests/integration_test_async_tool_sandbox.py | 138 +-
tests/test_multi_agent.py | 243 +--
tests/test_providers.py | 295 +--
tests/test_server.py | 7 +-
tests/test_tool_rule_solver.py | 2 +-
53 files changed, 2277 insertions(+), 2190 deletions(-)
create mode 100644 letta/llm_api/sample_response_jsons/aws_bedrock.json
create mode 100644 letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json
create mode 100644 letta/llm_api/sample_response_jsons/lmstudio_model_list.json
delete mode 100644 letta/schemas/providers.py
create mode 100644 letta/schemas/providers/__init__.py
create mode 100644 letta/schemas/providers/anthropic.py
create mode 100644 letta/schemas/providers/azure.py
create mode 100644 letta/schemas/providers/base.py
create mode 100644 letta/schemas/providers/bedrock.py
create mode 100644 letta/schemas/providers/cerebras.py
create mode 100644 letta/schemas/providers/cohere.py
create mode 100644 letta/schemas/providers/deepseek.py
create mode 100644 letta/schemas/providers/google_gemini.py
create mode 100644 letta/schemas/providers/google_vertex.py
create mode 100644 letta/schemas/providers/groq.py
create mode 100644 letta/schemas/providers/letta.py
create mode 100644 letta/schemas/providers/lmstudio.py
create mode 100644 letta/schemas/providers/mistral.py
create mode 100644 letta/schemas/providers/ollama.py
create mode 100644 letta/schemas/providers/openai.py
create mode 100644 letta/schemas/providers/together.py
create mode 100644 letta/schemas/providers/vllm.py
create mode 100644 letta/schemas/providers/xai.py
diff --git a/letta/groups/dynamic_multi_agent.py b/letta/groups/dynamic_multi_agent.py
index f89a6f3a..500d923d 100644
--- a/letta/groups/dynamic_multi_agent.py
+++ b/letta/groups/dynamic_multi_agent.py
@@ -94,6 +94,7 @@ class DynamicMultiAgent(Agent):
for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]:
if name.lower() in assistant_message.content.lower():
speaker_id = agent_id
+ assert speaker_id is not None, f"No names found in {assistant_message.content}"
# Sum usage
total_usage.prompt_tokens += usage_stats.prompt_tokens
diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py
index 2fb9b865..8b8445a2 100644
--- a/letta/llm_api/anthropic.py
+++ b/letta/llm_api/anthropic.py
@@ -717,7 +717,7 @@ def _prepare_anthropic_request(
data["temperature"] = 1.0
if "functions" in data:
- raise ValueError(f"'functions' unexpected in Anthropic API payload")
+ raise ValueError("'functions' unexpected in Anthropic API payload")
# Handle tools
if "tools" in data and data["tools"] is None:
@@ -1150,7 +1150,7 @@ def anthropic_chat_completions_process_stream(
accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments
if message_delta.function_call is not None:
- raise NotImplementedError(f"Old function_call style not support with stream=True")
+ raise NotImplementedError("Old function_call style not support with stream=True")
# overwrite response fields based on latest chunk
if not create_message_id:
diff --git a/letta/llm_api/aws_bedrock.py b/letta/llm_api/aws_bedrock.py
index 539ce0fd..fd994795 100644
--- a/letta/llm_api/aws_bedrock.py
+++ b/letta/llm_api/aws_bedrock.py
@@ -1,17 +1,30 @@
+"""
+Note that this formally only supports Anthropic Bedrock.
+TODO (cliandy): determine what other providers are supported and what is needed to add support.
+"""
+
import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from anthropic import AnthropicBedrock
+from letta.log import get_logger
from letta.settings import model_settings
+logger = get_logger(__name__)
+
def has_valid_aws_credentials() -> bool:
"""
Check if AWS credentials are properly configured.
"""
- valid_aws_credentials = os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY") and os.getenv("AWS_DEFAULT_REGION")
- return valid_aws_credentials
+ return all(
+ (
+ os.getenv("AWS_ACCESS_KEY_ID"),
+ os.getenv("AWS_SECRET_ACCESS_KEY"),
+ os.getenv("AWS_DEFAULT_REGION"),
+ )
+ )
def get_bedrock_client(
@@ -41,48 +54,11 @@ def get_bedrock_client(
return bedrock
-def bedrock_get_model_list(
- region_name: str,
- access_key_id: Optional[str] = None,
- secret_access_key: Optional[str] = None,
-) -> List[dict]:
- """
- Get list of available models from Bedrock.
-
- Args:
- region_name: AWS region name
- access_key_id: Optional AWS access key ID
- secret_access_key: Optional AWS secret access key
-
- TODO: Implement model_provider and output_modality filtering
- model_provider: Optional provider name to filter models. If None, returns all models.
- output_modality: Output modality to filter models. Defaults to "text".
-
- Returns:
- List of model summaries
-
- """
- import boto3
-
- try:
- bedrock = boto3.client(
- "bedrock",
- region_name=region_name,
- aws_access_key_id=access_key_id,
- aws_secret_access_key=secret_access_key,
- )
- response = bedrock.list_inference_profiles()
- return response["inferenceProfileSummaries"]
- except Exception as e:
- print(f"Error getting model list: {str(e)}")
- raise e
-
-
async def bedrock_get_model_list_async(
access_key_id: Optional[str] = None,
secret_access_key: Optional[str] = None,
default_region: Optional[str] = None,
-) -> List[dict]:
+) -> list[dict]:
from aioboto3.session import Session
try:
@@ -96,11 +72,11 @@ async def bedrock_get_model_list_async(
response = await bedrock.list_inference_profiles()
return response["inferenceProfileSummaries"]
except Exception as e:
- print(f"Error getting model list: {str(e)}")
+ logger.error(f"Error getting model list for bedrock: %s", e)
raise e
-def bedrock_get_model_details(region_name: str, model_id: str) -> Dict[str, Any]:
+def bedrock_get_model_details(region_name: str, model_id: str) -> dict[str, Any]:
"""
Get details for a specific model from Bedrock.
"""
@@ -121,54 +97,8 @@ def bedrock_get_model_context_window(model_id: str) -> int:
Get context window size for a specific model.
"""
# Bedrock doesn't provide this via API, so we maintain a mapping
- context_windows = {
- "anthropic.claude-3-5-sonnet-20241022-v2:0": 200000,
- "anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
- "anthropic.claude-3-5-haiku-20241022-v1:0": 200000,
- "anthropic.claude-3-haiku-20240307-v1:0": 200000,
- "anthropic.claude-3-opus-20240229-v1:0": 200000,
- "anthropic.claude-3-sonnet-20240229-v1:0": 200000,
- }
- return context_windows.get(model_id, 200000) # default to 100k if unknown
-
-
-"""
-{
- "id": "msg_123",
- "type": "message",
- "role": "assistant",
- "model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
- "content": [
- {
- "type": "text",
- "text": "I see the Firefox icon. Let me click on it and then navigate to a weather website."
- },
- {
- "type": "tool_use",
- "id": "toolu_123",
- "name": "computer",
- "input": {
- "action": "mouse_move",
- "coordinate": [
- 708,
- 736
- ]
- }
- },
- {
- "type": "tool_use",
- "id": "toolu_234",
- "name": "computer",
- "input": {
- "action": "left_click"
- }
- }
- ],
- "stop_reason": "tool_use",
- "stop_sequence": null,
- "usage": {
- "input_tokens": 3391,
- "output_tokens": 132
- }
-}
-"""
+ # 200k for anthropic: https://aws.amazon.com/bedrock/anthropic/
+ if model_id.startswith("anthropic"):
+ return 200_000
+ else:
+ return 100_000 # default to 100k if unknown
diff --git a/letta/llm_api/deepseek.py b/letta/llm_api/deepseek.py
index f0b2a45a..5d4eb9e1 100644
--- a/letta/llm_api/deepseek.py
+++ b/letta/llm_api/deepseek.py
@@ -120,7 +120,7 @@ def build_deepseek_chat_completions_request(
def add_functions_to_system_message(system_message: ChatMessage):
system_message.content += f" {''.join(json.dumps(f) for f in functions)} "
- system_message.content += f'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.'
+ system_message.content += 'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.'
if llm_config.model == "deepseek-reasoner": # R1 currently doesn't support function calling natively
add_functions_to_system_message(
diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py
index a8fa03f4..e7987aa2 100644
--- a/letta/llm_api/google_ai_client.py
+++ b/letta/llm_api/google_ai_client.py
@@ -66,44 +66,6 @@ def google_ai_check_valid_api_key(api_key: str):
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
-def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
- """Synchronous version to get model list from Google AI API using httpx."""
- import httpx
-
- from letta.utils import printd
-
- url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)
-
- try:
- with httpx.Client() as client:
- response = client.get(url, headers=headers)
- response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status
- response_data = response.json() # convert to dict from string
-
- # Grab the models out
- model_list = response_data["models"]
- return model_list
-
- except httpx.HTTPStatusError as http_err:
- # Handle HTTP errors (e.g., response 4XX, 5XX)
- printd(f"Got HTTPError, exception={http_err}")
- # Print the HTTP status code
- print(f"HTTP Error: {http_err.response.status_code}")
- # Print the response content (error message from server)
- print(f"Message: {http_err.response.text}")
- raise http_err
-
- except httpx.RequestError as req_err:
- # Handle other httpx-related errors (e.g., connection error)
- printd(f"Got RequestException, exception={req_err}")
- raise req_err
-
- except Exception as e:
- # Handle other potential errors
- printd(f"Got unknown Exception, exception={e}")
- raise e
-
-
async def google_ai_get_model_list_async(
base_url: str, api_key: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
) -> List[dict]:
diff --git a/letta/llm_api/google_constants.py b/letta/llm_api/google_constants.py
index 1c30d615..3a50fb1b 100644
--- a/letta/llm_api/google_constants.py
+++ b/letta/llm_api/google_constants.py
@@ -1,7 +1,12 @@
GOOGLE_MODEL_TO_CONTEXT_LENGTH = {
+ "gemini-2.5-pro": 1048576,
+ "gemini-2.5-flash": 1048576,
+ "gemini-live-2.5-flash": 1048576,
+ "gemini-2.0-flash-001": 1048576,
+ "gemini-2.0-flash-lite-001": 1048576,
+ # The following are either deprecated or discontinued.
"gemini-2.5-pro-exp-03-25": 1048576,
"gemini-2.5-flash-preview-04-17": 1048576,
- "gemini-2.0-flash-001": 1048576,
"gemini-2.0-pro-exp-02-05": 2097152,
"gemini-2.0-flash-lite-preview-02-05": 1048576,
"gemini-2.0-flash-thinking-exp-01-21": 1048576,
@@ -11,8 +16,6 @@ GOOGLE_MODEL_TO_CONTEXT_LENGTH = {
"gemini-1.0-pro-vision": 16384,
}
-GOOGLE_MODEL_TO_OUTPUT_LENGTH = {"gemini-2.0-flash-001": 8192, "gemini-2.5-pro-exp-03-25": 65536}
-
GOOGLE_EMBEDING_MODEL_TO_DIM = {"text-embedding-005": 768, "text-multilingual-embedding-002": 768}
GOOGLE_MODEL_FOR_API_KEY_CHECK = "gemini-2.0-flash-lite"
diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py
index 749ae974..cff7b96a 100644
--- a/letta/llm_api/helpers.py
+++ b/letta/llm_api/helpers.py
@@ -252,7 +252,7 @@ def unpack_all_inner_thoughts_from_kwargs(
) -> ChatCompletionResponse:
"""Strip the inner thoughts out of the tool call and put it in the message content"""
if len(response.choices) == 0:
- raise ValueError(f"Unpacking inner thoughts from empty response not supported")
+ raise ValueError("Unpacking inner thoughts from empty response not supported")
new_choices = []
for choice in response.choices:
diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py
index 0f46a17a..bf943ed4 100644
--- a/letta/llm_api/llm_api_tools.py
+++ b/letta/llm_api/llm_api_tools.py
@@ -67,7 +67,6 @@ def retry_with_exponential_backoff(
# Stop retrying if user hits Ctrl-C
raise KeyboardInterrupt("User intentionally stopped thread. Stopping...")
except requests.exceptions.HTTPError as http_err:
-
if not hasattr(http_err, "response") or not http_err.response:
raise
@@ -175,7 +174,6 @@ def create(
# openai
if llm_config.model_endpoint_type == "openai":
-
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy
raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
@@ -256,7 +254,6 @@ def create(
return response
elif llm_config.model_endpoint_type == "xai":
-
api_key = model_settings.xai_api_key
if function_call is None and functions is not None and len(functions) > 0:
@@ -464,7 +461,7 @@ def create(
# )
elif llm_config.model_endpoint_type == "groq":
if stream:
- raise NotImplementedError(f"Streaming not yet implemented for Groq.")
+ raise NotImplementedError("Streaming not yet implemented for Groq.")
if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions":
raise LettaConfigurationError(message="Groq key is missing from letta config file", missing_fields=["groq_api_key"])
@@ -517,7 +514,7 @@ def create(
"""TogetherAI endpoint that goes via /completions instead of /chat/completions"""
if stream:
- raise NotImplementedError(f"Streaming not yet implemented for TogetherAI (via the /completions endpoint).")
+ raise NotImplementedError("Streaming not yet implemented for TogetherAI (via the /completions endpoint).")
if model_settings.together_api_key is None and (
llm_config.model_endpoint == "https://api.together.ai/v1/completions"
@@ -547,7 +544,7 @@ def create(
"""Anthropic endpoint that goes via /embeddings instead of /chat/completions"""
if stream:
- raise NotImplementedError(f"Streaming not yet implemented for Anthropic (via the /embeddings endpoint).")
+ raise NotImplementedError("Streaming not yet implemented for Anthropic (via the /embeddings endpoint).")
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Anthropic API requests")
@@ -631,7 +628,7 @@ def create(
messages[0].content[0].text += f" {''.join(json.dumps(f) for f in functions)} "
messages[0].content[
0
- ].text += f'Select best function to call simply by responding with a single json block with the keys "function" and "params". Use double quotes around the arguments.'
+ ].text += 'Select best function to call simply by responding with a single json block with the keys "function" and "params". Use double quotes around the arguments.'
return get_chat_completion(
model=llm_config.model,
messages=messages,
diff --git a/letta/llm_api/mistral.py b/letta/llm_api/mistral.py
index 932cf874..68bb87e9 100644
--- a/letta/llm_api/mistral.py
+++ b/letta/llm_api/mistral.py
@@ -1,47 +1,22 @@
-import requests
+import aiohttp
-from letta.utils import printd, smart_urljoin
+from letta.log import get_logger
+from letta.utils import smart_urljoin
+
+logger = get_logger(__name__)
-def mistral_get_model_list(url: str, api_key: str) -> dict:
+async def mistral_get_model_list_async(url: str, api_key: str) -> dict:
url = smart_urljoin(url, "models")
headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
- printd(f"Sending request to {url}")
- response = None
- try:
+ logger.debug(f"Sending request to %s", url)
+
+ async with aiohttp.ClientSession() as session:
# TODO add query param "tool" to be true
- response = requests.get(url, headers=headers)
- response.raise_for_status() # Raises HTTPError for 4XX/5XX status
- response_json = response.json() # convert to dict from string
- return response_json
- except requests.exceptions.HTTPError as http_err:
- # Handle HTTP errors (e.g., response 4XX, 5XX)
- try:
- if response:
- response = response.json()
- except:
- pass
- printd(f"Got HTTPError, exception={http_err}, response={response}")
- raise http_err
- except requests.exceptions.RequestException as req_err:
- # Handle other requests-related errors (e.g., connection error)
- try:
- if response:
- response = response.json()
- except:
- pass
- printd(f"Got RequestException, exception={req_err}, response={response}")
- raise req_err
- except Exception as e:
- # Handle other potential errors
- try:
- if response:
- response = response.json()
- except:
- pass
- printd(f"Got unknown Exception, exception={e}, response={response}")
- raise e
+ async with session.get(url, headers=headers) as response:
+ response.raise_for_status()
+ return await response.json()
diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py
index 7372b1d0..b83c4de4 100644
--- a/letta/llm_api/openai.py
+++ b/letta/llm_api/openai.py
@@ -59,11 +59,15 @@ def openai_check_valid_api_key(base_url: str, api_key: Union[str, None]) -> None
def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool = False, extra_params: Optional[dict] = None) -> dict:
"""https://platform.openai.com/docs/api-reference/models/list"""
- from letta.utils import printd
# In some cases we may want to double-check the URL and do basic correction, eg:
# In Letta config the address for vLLM is w/o a /v1 suffix for simplicity
# However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit
+
+ import warnings
+
+ warnings.warn("The synchronous version of openai_get_model_list function is deprecated. Use the async one instead.", DeprecationWarning)
+
if fix_url:
if not url.endswith("/v1"):
url = smart_urljoin(url, "v1")
@@ -74,14 +78,14 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
- printd(f"Sending request to {url}")
+ logger.debug(f"Sending request to {url}")
response = None
try:
# TODO add query param "tool" to be true
response = requests.get(url, headers=headers, params=extra_params)
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
- printd(f"response = {response}")
+ logger.debug(f"response = {response}")
return response
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
@@ -90,7 +94,7 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool
response = response.json()
except:
pass
- printd(f"Got HTTPError, exception={http_err}, response={response}")
+ logger.debug(f"Got HTTPError, exception={http_err}, response={response}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
@@ -99,7 +103,7 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool
response = response.json()
except:
pass
- printd(f"Got RequestException, exception={req_err}, response={response}")
+ logger.debug(f"Got RequestException, exception={req_err}, response={response}")
raise req_err
except Exception as e:
# Handle other potential errors
@@ -108,7 +112,7 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool
response = response.json()
except:
pass
- printd(f"Got unknown Exception, exception={e}, response={response}")
+ logger.debug(f"Got unknown Exception, exception={e}, response={response}")
raise e
@@ -120,7 +124,6 @@ async def openai_get_model_list_async(
client: Optional["httpx.AsyncClient"] = None,
) -> dict:
"""https://platform.openai.com/docs/api-reference/models/list"""
- from letta.utils import printd
# In some cases we may want to double-check the URL and do basic correction
if fix_url and not url.endswith("/v1"):
@@ -132,7 +135,7 @@ async def openai_get_model_list_async(
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
- printd(f"Sending request to {url}")
+ logger.debug(f"Sending request to {url}")
# Use provided client or create a new one
close_client = False
@@ -144,24 +147,23 @@ async def openai_get_model_list_async(
response = await client.get(url, headers=headers, params=extra_params)
response.raise_for_status()
result = response.json()
- printd(f"response = {result}")
+ logger.debug(f"response = {result}")
return result
except httpx.HTTPStatusError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
- error_response = None
try:
error_response = http_err.response.json()
except:
error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text}
- printd(f"Got HTTPError, exception={http_err}, response={error_response}")
+ logger.debug(f"Got HTTPError, exception={http_err}, response={error_response}")
raise http_err
except httpx.RequestError as req_err:
# Handle other httpx-related errors (e.g., connection error)
- printd(f"Got RequestException, exception={req_err}")
+ logger.debug(f"Got RequestException, exception={req_err}")
raise req_err
except Exception as e:
# Handle other potential errors
- printd(f"Got unknown Exception, exception={e}")
+ logger.debug(f"Got unknown Exception, exception={e}")
raise e
finally:
if close_client:
@@ -480,7 +482,7 @@ def openai_chat_completions_process_stream(
)
if message_delta.function_call is not None:
- raise NotImplementedError(f"Old function_call style not support with stream=True")
+ raise NotImplementedError("Old function_call style not support with stream=True")
# overwrite response fields based on latest chunk
if not create_message_id:
@@ -503,7 +505,7 @@ def openai_chat_completions_process_stream(
logger.error(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
raise e
finally:
- logger.info(f"Finally ending streaming interface.")
+ logger.info("Finally ending streaming interface.")
if stream_interface:
stream_interface.stream_end()
@@ -525,7 +527,6 @@ def openai_chat_completions_process_stream(
assert len(chat_completion_response.choices) > 0, f"No response from provider {chat_completion_response}"
- # printd(chat_completion_response)
log_event(name="llm_response_received", attributes=chat_completion_response.model_dump())
return chat_completion_response
@@ -536,7 +537,6 @@ def openai_chat_completions_request_stream(
chat_completion_request: ChatCompletionRequest,
fix_url: bool = False,
) -> Generator[ChatCompletionChunkResponse, None, None]:
-
# In some cases we may want to double-check the URL and do basic correction, eg:
# In Letta config the address for vLLM is w/o a /v1 suffix for simplicity
# However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit
diff --git a/letta/llm_api/sample_response_jsons/aws_bedrock.json b/letta/llm_api/sample_response_jsons/aws_bedrock.json
new file mode 100644
index 00000000..c8ff79c8
--- /dev/null
+++ b/letta/llm_api/sample_response_jsons/aws_bedrock.json
@@ -0,0 +1,38 @@
+{
+ "id": "msg_123",
+ "type": "message",
+ "role": "assistant",
+ "model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
+ "content": [
+ {
+ "type": "text",
+ "text": "I see the Firefox icon. Let me click on it and then navigate to a weather website."
+ },
+ {
+ "type": "tool_use",
+ "id": "toolu_123",
+ "name": "computer",
+ "input": {
+ "action": "mouse_move",
+ "coordinate": [
+ 708,
+ 736
+ ]
+ }
+ },
+ {
+ "type": "tool_use",
+ "id": "toolu_234",
+ "name": "computer",
+ "input": {
+ "action": "left_click"
+ }
+ }
+ ],
+ "stop_reason": "tool_use",
+ "stop_sequence": null,
+ "usage": {
+ "input_tokens": 3391,
+ "output_tokens": 132
+ }
+}
diff --git a/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json b/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json
new file mode 100644
index 00000000..25489ff3
--- /dev/null
+++ b/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json
@@ -0,0 +1,15 @@
+{
+ "object": "list",
+ "data": [
+ {
+ "id": "text-embedding-nomic-embed-text-v1.5",
+ "object": "model",
+ "type": "embeddings",
+ "publisher": "nomic-ai",
+ "arch": "nomic-bert",
+ "compatibility_type": "gguf",
+ "quantization": "Q4_0",
+ "state": "not-loaded",
+ "max_context_length": 2048
+ },
+ ...
diff --git a/letta/llm_api/sample_response_jsons/lmstudio_model_list.json b/letta/llm_api/sample_response_jsons/lmstudio_model_list.json
new file mode 100644
index 00000000..8b7e7b70
--- /dev/null
+++ b/letta/llm_api/sample_response_jsons/lmstudio_model_list.json
@@ -0,0 +1,15 @@
+ {
+ "object": "list",
+ "data": [
+ {
+ "id": "qwen2-vl-7b-instruct",
+ "object": "model",
+ "type": "vlm",
+ "publisher": "mlx-community",
+ "arch": "qwen2_vl",
+ "compatibility_type": "mlx",
+ "quantization": "4bit",
+ "state": "not-loaded",
+ "max_context_length": 32768
+ },
+ ...,
diff --git a/letta/local_llm/constants.py b/letta/local_llm/constants.py
index 83681fde..2b51101d 100644
--- a/letta/local_llm/constants.py
+++ b/letta/local_llm/constants.py
@@ -1,27 +1,5 @@
-# import letta.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
from letta.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper
-DEFAULT_ENDPOINTS = {
- # Local
- "koboldcpp": "http://localhost:5001",
- "llamacpp": "http://localhost:8080",
- "lmstudio": "http://localhost:1234",
- "lmstudio-legacy": "http://localhost:1234",
- "ollama": "http://localhost:11434",
- "webui-legacy": "http://localhost:5000",
- "webui": "http://localhost:5000",
- "vllm": "http://localhost:8000",
- # APIs
- "openai": "https://api.openai.com",
- "anthropic": "https://api.anthropic.com",
- "groq": "https://api.groq.com/openai",
-}
-
-DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
-
-# DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
-# DEFAULT_WRAPPER_NAME = "airoboros-l2-70b-2.1"
-
DEFAULT_WRAPPER = ChatMLInnerMonologueWrapper
DEFAULT_WRAPPER_NAME = "chatml"
diff --git a/letta/local_llm/llm_chat_completion_wrappers/airoboros.py b/letta/local_llm/llm_chat_completion_wrappers/airoboros.py
index 5f076ec8..544d11d4 100644
--- a/letta/local_llm/llm_chat_completion_wrappers/airoboros.py
+++ b/letta/local_llm/llm_chat_completion_wrappers/airoboros.py
@@ -75,7 +75,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
func_str = ""
func_str += f"{schema['name']}:"
func_str += f"\n description: {schema['description']}"
- func_str += f"\n params:"
+ func_str += "\n params:"
for param_k, param_v in schema["parameters"]["properties"].items():
# TODO we're ignoring type
func_str += f"\n {param_k}: {param_v['description']}"
@@ -83,8 +83,8 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
return func_str
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
- prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
- prompt += f"\nAvailable functions:"
+ prompt += "\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
+ prompt += "\nAvailable functions:"
if function_documentation is not None:
prompt += f"\n{function_documentation}"
else:
@@ -150,7 +150,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
prompt += "\n### RESPONSE"
if self.include_assistant_prefix:
- prompt += f"\nASSISTANT:"
+ prompt += "\nASSISTANT:"
if self.include_opening_brance_in_prefix:
prompt += "\n{"
@@ -282,9 +282,9 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
func_str = ""
func_str += f"{schema['name']}:"
func_str += f"\n description: {schema['description']}"
- func_str += f"\n params:"
+ func_str += "\n params:"
if add_inner_thoughts:
- func_str += f"\n inner_thoughts: Deep inner monologue private to you only."
+ func_str += "\n inner_thoughts: Deep inner monologue private to you only."
for param_k, param_v in schema["parameters"]["properties"].items():
# TODO we're ignoring type
func_str += f"\n {param_k}: {param_v['description']}"
@@ -292,8 +292,8 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
return func_str
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
- prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
- prompt += f"\nAvailable functions:"
+ prompt += "\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
+ prompt += "\nAvailable functions:"
if function_documentation is not None:
prompt += f"\n{function_documentation}"
else:
@@ -375,7 +375,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
prompt += "\n### RESPONSE"
if self.include_assistant_prefix:
- prompt += f"\nASSISTANT:"
+ prompt += "\nASSISTANT:"
if self.assistant_prefix_extra:
prompt += self.assistant_prefix_extra
diff --git a/letta/local_llm/llm_chat_completion_wrappers/chatml.py b/letta/local_llm/llm_chat_completion_wrappers/chatml.py
index 62e8a8cf..71589959 100644
--- a/letta/local_llm/llm_chat_completion_wrappers/chatml.py
+++ b/letta/local_llm/llm_chat_completion_wrappers/chatml.py
@@ -71,7 +71,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
func_str = ""
func_str += f"{schema['name']}:"
func_str += f"\n description: {schema['description']}"
- func_str += f"\n params:"
+ func_str += "\n params:"
if add_inner_thoughts:
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
@@ -87,8 +87,8 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
prompt = ""
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
- prompt += f"Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
- prompt += f"\nAvailable functions:"
+ prompt += "Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
+ prompt += "\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{self._compile_function_description(function_dict)}"
@@ -101,8 +101,8 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
prompt += system_message
prompt += "\n"
if function_documentation is not None:
- prompt += f"Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
- prompt += f"\nAvailable functions:\n"
+ prompt += "Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
+ prompt += "\nAvailable functions:\n"
prompt += function_documentation
else:
prompt += self._compile_function_block(functions)
@@ -230,7 +230,6 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
prompt += f"\n<|im_start|>{role_str}\n{msg_str.strip()}<|im_end|>"
elif message["role"] == "system":
-
role_str = "system"
msg_str = self._compile_system_message(
system_message=message["content"], functions=functions, function_documentation=function_documentation
@@ -255,7 +254,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
raise ValueError(message)
if self.include_assistant_prefix:
- prompt += f"\n<|im_start|>assistant"
+ prompt += "\n<|im_start|>assistant"
if self.assistant_prefix_hint:
prompt += f"\n{FIRST_PREFIX_HINT if first_message else PREFIX_HINT}"
if self.supports_first_message and first_message:
@@ -386,7 +385,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper):
"You must always include inner thoughts, but you do not always have to call a function.",
]
)
- prompt += f"\nAvailable functions:"
+ prompt += "\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{self._compile_function_description(function_dict, add_inner_thoughts=False)}"
diff --git a/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py b/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py
index fa0fa839..9f53fa83 100644
--- a/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py
+++ b/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py
@@ -91,9 +91,9 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
func_str = ""
func_str += f"{schema['name']}:"
func_str += f"\n description: {schema['description']}"
- func_str += f"\n params:"
+ func_str += "\n params:"
if add_inner_thoughts:
- func_str += f"\n inner_thoughts: Deep inner monologue private to you only."
+ func_str += "\n inner_thoughts: Deep inner monologue private to you only."
for param_k, param_v in schema["parameters"]["properties"].items():
# TODO we're ignoring type
func_str += f"\n {param_k}: {param_v['description']}"
@@ -105,8 +105,8 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
prompt = ""
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
- prompt += f"Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
- prompt += f"\nAvailable functions:"
+ prompt += "Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
+ prompt += "\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{self._compile_function_description(function_dict)}"
@@ -117,8 +117,8 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
prompt = system_message
prompt += "\n"
if function_documentation is not None:
- prompt += f"Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
- prompt += f"\nAvailable functions:"
+ prompt += "Please select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
+ prompt += "\nAvailable functions:"
prompt += function_documentation
else:
prompt += self._compile_function_block(functions)
diff --git a/letta/local_llm/llm_chat_completion_wrappers/dolphin.py b/letta/local_llm/llm_chat_completion_wrappers/dolphin.py
index 6a7f0852..e393d9b1 100644
--- a/letta/local_llm/llm_chat_completion_wrappers/dolphin.py
+++ b/letta/local_llm/llm_chat_completion_wrappers/dolphin.py
@@ -85,7 +85,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
func_str = ""
func_str += f"{schema['name']}:"
func_str += f"\n description: {schema['description']}"
- func_str += f"\n params:"
+ func_str += "\n params:"
for param_k, param_v in schema["parameters"]["properties"].items():
# TODO we're ignoring type
func_str += f"\n {param_k}: {param_v['description']}"
@@ -93,8 +93,8 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
return func_str
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
- prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
- prompt += f"\nAvailable functions:"
+ prompt += "\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
+ prompt += "\nAvailable functions:"
if function_documentation is not None:
prompt += f"\n{function_documentation}"
else:
diff --git a/letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py b/letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py
index c69f0960..d20bd2d3 100644
--- a/letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py
+++ b/letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py
@@ -124,7 +124,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
if self.include_assistant_prefix:
# prompt += f"\nASSISTANT:"
- prompt += f"\nSUMMARY:"
+ prompt += "\nSUMMARY:"
# print(prompt)
return prompt
diff --git a/letta/local_llm/ollama/api.py b/letta/local_llm/ollama/api.py
index 00bdf509..69926a43 100644
--- a/letta/local_llm/ollama/api.py
+++ b/letta/local_llm/ollama/api.py
@@ -18,7 +18,7 @@ def get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_
if model is None:
raise LocalLLMError(
- f"Error: model name not specified. Set model in your config to the model you want to run (e.g. 'dolphin2.2-mistral')"
+ "Error: model name not specified. Set model in your config to the model you want to run (e.g. 'dolphin2.2-mistral')"
)
# Settings for the generation, includes the prompt + stop tokens, max length, etc
@@ -51,7 +51,7 @@ def get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_
# Set grammar
if grammar is not None:
# request["grammar_string"] = load_grammar_file(grammar)
- raise NotImplementedError(f"Ollama does not support grammars")
+ raise NotImplementedError("Ollama does not support grammars")
if not endpoint.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
diff --git a/letta/schemas/embedding_config.py b/letta/schemas/embedding_config.py
index 38e6542d..1020382b 100644
--- a/letta/schemas/embedding_config.py
+++ b/letta/schemas/embedding_config.py
@@ -2,6 +2,8 @@ from typing import Literal, Optional
from pydantic import BaseModel, Field
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
+
class EmbeddingConfig(BaseModel):
"""
@@ -62,7 +64,7 @@ class EmbeddingConfig(BaseModel):
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
- embedding_chunk_size=300,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
)
if (model_name == "text-embedding-3-small" and provider == "openai") or (not model_name and provider == "openai"):
return cls(
@@ -70,14 +72,14 @@ class EmbeddingConfig(BaseModel):
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=2000,
- embedding_chunk_size=300,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
)
elif model_name == "letta":
return cls(
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_model="BAAI/bge-large-en-v1.5",
embedding_dim=1024,
- embedding_chunk_size=300,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
embedding_endpoint_type="hugging-face",
)
else:
diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py
index cbe922ca..a76178f9 100644
--- a/letta/schemas/enums.py
+++ b/letta/schemas/enums.py
@@ -8,6 +8,7 @@ class ProviderType(str, Enum):
openai = "openai"
letta = "letta"
deepseek = "deepseek"
+ cerebras = "cerebras"
lmstudio_openai = "lmstudio_openai"
xai = "xai"
mistral = "mistral"
@@ -17,6 +18,7 @@ class ProviderType(str, Enum):
azure = "azure"
vllm = "vllm"
bedrock = "bedrock"
+ cohere = "cohere"
class ProviderCategory(str, Enum):
diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py
deleted file mode 100644
index 97d68281..00000000
--- a/letta/schemas/providers.py
+++ /dev/null
@@ -1,1618 +0,0 @@
-import warnings
-from datetime import datetime
-from typing import List, Literal, Optional
-
-import aiohttp
-import requests
-from pydantic import BaseModel, Field, model_validator
-
-from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
-from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
-from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
-from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
-from letta.schemas.enums import ProviderCategory, ProviderType
-from letta.schemas.letta_base import LettaBase
-from letta.schemas.llm_config import LLMConfig
-from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
-from letta.settings import model_settings
-
-
-class ProviderBase(LettaBase):
- __id_prefix__ = "provider"
-
-
-class Provider(ProviderBase):
- id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
- name: str = Field(..., description="The name of the provider")
- provider_type: ProviderType = Field(..., description="The type of the provider")
- provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
- api_key: Optional[str] = Field(None, description="API key or secret key used for requests to the provider.")
- base_url: Optional[str] = Field(None, description="Base URL for the provider.")
- access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.")
- region: Optional[str] = Field(None, description="Region used for requests to the provider.")
- organization_id: Optional[str] = Field(None, description="The organization id of the user")
- updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")
-
- @model_validator(mode="after")
- def default_base_url(self):
- if self.provider_type == ProviderType.openai and self.base_url is None:
- self.base_url = model_settings.openai_api_base
- return self
-
- def resolve_identifier(self):
- if not self.id:
- self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
-
- def check_api_key(self):
- """Check if the API key is valid for the provider"""
- raise NotImplementedError
-
- def list_llm_models(self) -> List[LLMConfig]:
- return []
-
- async def list_llm_models_async(self) -> List[LLMConfig]:
- return []
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- return []
-
- async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
- return self.list_embedding_models()
-
- def get_model_context_window(self, model_name: str) -> Optional[int]:
- raise NotImplementedError
-
- async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
- raise NotImplementedError
-
- def provider_tag(self) -> str:
- """String representation of the provider for display purposes"""
- raise NotImplementedError
-
- def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str:
- """
- Get the handle for a model, with support for custom overrides.
-
- Args:
- model_name (str): The name of the model.
- is_embedding (bool, optional): Whether the handle is for an embedding model. Defaults to False.
-
- Returns:
- str: The handle for the model.
- """
- base_name = base_name if base_name else self.name
-
- overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES
- if base_name in overrides and model_name in overrides[base_name]:
- model_name = overrides[base_name][model_name]
-
- return f"{base_name}/{model_name}"
-
- def cast_to_subtype(self):
- match self.provider_type:
- case ProviderType.letta:
- return LettaProvider(**self.model_dump(exclude_none=True))
- case ProviderType.openai:
- return OpenAIProvider(**self.model_dump(exclude_none=True))
- case ProviderType.anthropic:
- return AnthropicProvider(**self.model_dump(exclude_none=True))
- case ProviderType.bedrock:
- return BedrockProvider(**self.model_dump(exclude_none=True))
- case ProviderType.ollama:
- return OllamaProvider(**self.model_dump(exclude_none=True))
- case ProviderType.google_ai:
- return GoogleAIProvider(**self.model_dump(exclude_none=True))
- case ProviderType.google_vertex:
- return GoogleVertexProvider(**self.model_dump(exclude_none=True))
- case ProviderType.azure:
- return AzureProvider(**self.model_dump(exclude_none=True))
- case ProviderType.groq:
- return GroqProvider(**self.model_dump(exclude_none=True))
- case ProviderType.together:
- return TogetherProvider(**self.model_dump(exclude_none=True))
- case ProviderType.vllm_chat_completions:
- return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
- case ProviderType.vllm_completions:
- return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
- case ProviderType.xai:
- return XAIProvider(**self.model_dump(exclude_none=True))
- case _:
- raise ValueError(f"Unknown provider type: {self.provider_type}")
-
-
-class ProviderCreate(ProviderBase):
- name: str = Field(..., description="The name of the provider.")
- provider_type: ProviderType = Field(..., description="The type of the provider.")
- api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
- access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.")
- region: Optional[str] = Field(None, description="Region used for requests to the provider.")
-
-
-class ProviderUpdate(ProviderBase):
- api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
- access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.")
- region: Optional[str] = Field(None, description="Region used for requests to the provider.")
-
-
-class ProviderCheck(BaseModel):
- provider_type: ProviderType = Field(..., description="The type of the provider.")
- api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
- access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.")
- region: Optional[str] = Field(None, description="Region used for requests to the provider.")
-
-
-class LettaProvider(Provider):
- provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
-
- def list_llm_models(self) -> List[LLMConfig]:
- return [
- LLMConfig(
- model="letta-free", # NOTE: renamed
- model_endpoint_type="openai",
- model_endpoint=LETTA_MODEL_ENDPOINT,
- context_window=30000,
- handle=self.get_handle("letta-free"),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- ]
-
- async def list_llm_models_async(self) -> List[LLMConfig]:
- return [
- LLMConfig(
- model="letta-free", # NOTE: renamed
- model_endpoint_type="openai",
- model_endpoint=LETTA_MODEL_ENDPOINT,
- context_window=30000,
- handle=self.get_handle("letta-free"),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- ]
-
- def list_embedding_models(self):
- return [
- EmbeddingConfig(
- embedding_model="letta-free", # NOTE: renamed
- embedding_endpoint_type="hugging-face",
- embedding_endpoint="https://embeddings.memgpt.ai",
- embedding_dim=1024,
- embedding_chunk_size=300,
- handle=self.get_handle("letta-free", is_embedding=True),
- batch_size=32,
- )
- ]
-
-
-class OpenAIProvider(Provider):
- provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- api_key: str = Field(..., description="API key for the OpenAI API.")
- base_url: str = Field(..., description="Base URL for the OpenAI API.")
-
- def check_api_key(self):
- from letta.llm_api.openai import openai_check_valid_api_key
-
- openai_check_valid_api_key(self.base_url, self.api_key)
-
- def _get_models(self) -> List[dict]:
- from letta.llm_api.openai import openai_get_model_list
-
- # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
- # See: https://openrouter.ai/docs/requests
- extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
-
- # Similar to Nebius
- extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
-
- response = openai_get_model_list(
- self.base_url,
- api_key=self.api_key,
- extra_params=extra_params,
- # fix_url=True, # NOTE: make sure together ends with /v1
- )
-
- if "data" in response:
- data = response["data"]
- else:
- # TogetherAI's response is missing the 'data' field
- data = response
-
- return data
-
- async def _get_models_async(self) -> List[dict]:
- from letta.llm_api.openai import openai_get_model_list_async
-
- # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
- # See: https://openrouter.ai/docs/requests
- extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
-
- # Similar to Nebius
- extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
-
- response = await openai_get_model_list_async(
- self.base_url,
- api_key=self.api_key,
- extra_params=extra_params,
- # fix_url=True, # NOTE: make sure together ends with /v1
- )
-
- if "data" in response:
- data = response["data"]
- else:
- # TogetherAI's response is missing the 'data' field
- data = response
-
- return data
-
- def list_llm_models(self) -> List[LLMConfig]:
- data = self._get_models()
- return self._list_llm_models(data)
-
- async def list_llm_models_async(self) -> List[LLMConfig]:
- data = await self._get_models_async()
- return self._list_llm_models(data)
-
- def _list_llm_models(self, data) -> List[LLMConfig]:
- configs = []
- for model in data:
- assert "id" in model, f"OpenAI model missing 'id' field: {model}"
- model_name = model["id"]
-
- if "context_length" in model:
- # Context length is returned in OpenRouter as "context_length"
- context_window_size = model["context_length"]
- else:
- context_window_size = self.get_model_context_window_size(model_name)
-
- if not context_window_size:
- continue
-
- # TogetherAI includes the type, which we can use to filter out embedding models
- if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url:
- if "type" in model and model["type"] not in ["chat", "language"]:
- continue
-
- # for TogetherAI, we need to skip the models that don't support JSON mode / function calling
- # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: {
- # "error": {
- # "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling",
- # "type": "invalid_request_error",
- # "param": null,
- # "code": "constraints_model"
- # }
- # }
- if "config" not in model:
- continue
-
- if "nebius.com" in self.base_url:
- # Nebius includes the type, which we can use to filter for text models
- try:
- model_type = model["architecture"]["modality"]
- if model_type not in ["text->text", "text+image->text"]:
- # print(f"Skipping model w/ modality {model_type}:\n{model}")
- continue
- except KeyError:
- print(f"Couldn't access architecture type field, skipping model:\n{model}")
- continue
-
- # for openai, filter models
- if self.base_url == "https://api.openai.com/v1":
- allowed_types = ["gpt-4", "o1", "o3", "o4"]
- # NOTE: o1-mini and o1-preview do not support tool calling
- # NOTE: o1-mini does not support system messages
- # NOTE: o1-pro is only available in Responses API
- disallowed_types = ["transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"]
- skip = True
- for model_type in allowed_types:
- if model_name.startswith(model_type):
- skip = False
- break
- for keyword in disallowed_types:
- if keyword in model_name:
- skip = True
- break
- # ignore this model
- if skip:
- continue
-
- # set the handle to openai-proxy if the base URL isn't OpenAI
- if self.base_url != "https://api.openai.com/v1":
- handle = self.get_handle(model_name, base_name="openai-proxy")
- else:
- handle = self.get_handle(model_name)
-
- llm_config = LLMConfig(
- model=model_name,
- model_endpoint_type="openai",
- model_endpoint=self.base_url,
- context_window=context_window_size,
- handle=handle,
- provider_name=self.name,
- provider_category=self.provider_category,
- )
-
- # gpt-4o-mini has started to regress with pretty bad emoji spam loops
- # this is to counteract that
- if "gpt-4o-mini" in model_name:
- llm_config.frequency_penalty = 1.0
- if "gpt-4.1-mini" in model_name:
- llm_config.frequency_penalty = 1.0
-
- configs.append(llm_config)
-
- # for OpenAI, sort in reverse order
- if self.base_url == "https://api.openai.com/v1":
- # alphnumeric sort
- configs.sort(key=lambda x: x.model, reverse=True)
-
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- if self.base_url == "https://api.openai.com/v1":
- # TODO: actually automatically list models for OpenAI
- return [
- EmbeddingConfig(
- embedding_model="text-embedding-ada-002",
- embedding_endpoint_type="openai",
- embedding_endpoint=self.base_url,
- embedding_dim=1536,
- embedding_chunk_size=300,
- handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
- batch_size=1024,
- ),
- EmbeddingConfig(
- embedding_model="text-embedding-3-small",
- embedding_endpoint_type="openai",
- embedding_endpoint=self.base_url,
- embedding_dim=2000,
- embedding_chunk_size=300,
- handle=self.get_handle("text-embedding-3-small", is_embedding=True),
- batch_size=1024,
- ),
- EmbeddingConfig(
- embedding_model="text-embedding-3-large",
- embedding_endpoint_type="openai",
- embedding_endpoint=self.base_url,
- embedding_dim=2000,
- embedding_chunk_size=300,
- handle=self.get_handle("text-embedding-3-large", is_embedding=True),
- batch_size=1024,
- ),
- ]
-
- else:
- # Actually attempt to list
- data = self._get_models()
- return self._list_embedding_models(data)
-
- async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
- if self.base_url == "https://api.openai.com/v1":
- # TODO: actually automatically list models for OpenAI
- return [
- EmbeddingConfig(
- embedding_model="text-embedding-ada-002",
- embedding_endpoint_type="openai",
- embedding_endpoint=self.base_url,
- embedding_dim=1536,
- embedding_chunk_size=300,
- handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
- batch_size=1024,
- ),
- EmbeddingConfig(
- embedding_model="text-embedding-3-small",
- embedding_endpoint_type="openai",
- embedding_endpoint=self.base_url,
- embedding_dim=2000,
- embedding_chunk_size=300,
- handle=self.get_handle("text-embedding-3-small", is_embedding=True),
- batch_size=1024,
- ),
- EmbeddingConfig(
- embedding_model="text-embedding-3-large",
- embedding_endpoint_type="openai",
- embedding_endpoint=self.base_url,
- embedding_dim=2000,
- embedding_chunk_size=300,
- handle=self.get_handle("text-embedding-3-large", is_embedding=True),
- batch_size=1024,
- ),
- ]
-
- else:
- # Actually attempt to list
- data = await self._get_models_async()
- return self._list_embedding_models(data)
-
- def _list_embedding_models(self, data) -> List[EmbeddingConfig]:
- configs = []
- for model in data:
- assert "id" in model, f"Model missing 'id' field: {model}"
- model_name = model["id"]
-
- if "context_length" in model:
- # Context length is returned in Nebius as "context_length"
- context_window_size = model["context_length"]
- else:
- context_window_size = self.get_model_context_window_size(model_name)
-
- # We need the context length for embeddings too
- if not context_window_size:
- continue
-
- if "nebius.com" in self.base_url:
- # Nebius includes the type, which we can use to filter for embedidng models
- try:
- model_type = model["architecture"]["modality"]
- if model_type not in ["text->embedding"]:
- # print(f"Skipping model w/ modality {model_type}:\n{model}")
- continue
- except KeyError:
- print(f"Couldn't access architecture type field, skipping model:\n{model}")
- continue
-
- elif "together.ai" in self.base_url or "together.xyz" in self.base_url:
- # TogetherAI includes the type, which we can use to filter for embedding models
- if "type" in model and model["type"] not in ["embedding"]:
- # print(f"Skipping model w/ modality {model_type}:\n{model}")
- continue
-
- else:
- # For other providers we should skip by default, since we don't want to assume embeddings are supported
- continue
-
- configs.append(
- EmbeddingConfig(
- embedding_model=model_name,
- embedding_endpoint_type=self.provider_type,
- embedding_endpoint=self.base_url,
- embedding_dim=context_window_size,
- embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
- handle=self.get_handle(model, is_embedding=True),
- )
- )
-
- return configs
-
- def get_model_context_window_size(self, model_name: str):
- if model_name in LLM_MAX_TOKENS:
- return LLM_MAX_TOKENS[model_name]
- else:
- return LLM_MAX_TOKENS["DEFAULT"]
-
-
-class DeepSeekProvider(OpenAIProvider):
- """
- DeepSeek ChatCompletions API is similar to OpenAI's reasoning API,
- but with slight differences:
- * For example, DeepSeek's API requires perfect interleaving of user/assistant
- * It also does not support native function calling
- """
-
- provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
- api_key: str = Field(..., description="API key for the DeepSeek API.")
-
- def get_model_context_window_size(self, model_name: str) -> Optional[int]:
- # DeepSeek doesn't return context window in the model listing,
- # so these are hardcoded from their website
- if model_name == "deepseek-reasoner":
- return 64000
- elif model_name == "deepseek-chat":
- return 64000
- else:
- return None
-
- def list_llm_models(self) -> List[LLMConfig]:
- from letta.llm_api.openai import openai_get_model_list
-
- response = openai_get_model_list(self.base_url, api_key=self.api_key)
-
- if "data" in response:
- data = response["data"]
- else:
- data = response
-
- configs = []
- for model in data:
- assert "id" in model, f"DeepSeek model missing 'id' field: {model}"
- model_name = model["id"]
-
- # In case DeepSeek starts supporting it in the future:
- if "context_length" in model:
- # Context length is returned in OpenRouter as "context_length"
- context_window_size = model["context_length"]
- else:
- context_window_size = self.get_model_context_window_size(model_name)
-
- if not context_window_size:
- warnings.warn(f"Couldn't find context window size for model {model_name}")
- continue
-
- # Not used for deepseek-reasoner, but otherwise is true
- put_inner_thoughts_in_kwargs = False if model_name == "deepseek-reasoner" else True
-
- configs.append(
- LLMConfig(
- model=model_name,
- model_endpoint_type="deepseek",
- model_endpoint=self.base_url,
- context_window=context_window_size,
- handle=self.get_handle(model_name),
- put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
-
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- # No embeddings supported
- return []
-
-
-class LMStudioOpenAIProvider(OpenAIProvider):
- provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
- api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
-
- def list_llm_models(self) -> List[LLMConfig]:
- from letta.llm_api.openai import openai_get_model_list
-
- # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
- MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0"
- response = openai_get_model_list(MODEL_ENDPOINT_URL)
-
- """
- Example response:
-
- {
- "object": "list",
- "data": [
- {
- "id": "qwen2-vl-7b-instruct",
- "object": "model",
- "type": "vlm",
- "publisher": "mlx-community",
- "arch": "qwen2_vl",
- "compatibility_type": "mlx",
- "quantization": "4bit",
- "state": "not-loaded",
- "max_context_length": 32768
- },
- ...
- """
- if "data" not in response:
- warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
- return []
-
- configs = []
- for model in response["data"]:
- assert "id" in model, f"Model missing 'id' field: {model}"
- model_name = model["id"]
-
- if "type" not in model:
- warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
- continue
- elif model["type"] not in ["vlm", "llm"]:
- continue
-
- if "max_context_length" in model:
- context_window_size = model["max_context_length"]
- else:
- warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
- continue
-
- configs.append(
- LLMConfig(
- model=model_name,
- model_endpoint_type="openai",
- model_endpoint=self.base_url,
- context_window=context_window_size,
- handle=self.get_handle(model_name),
- )
- )
-
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- from letta.llm_api.openai import openai_get_model_list
-
- # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
- MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0"
- response = openai_get_model_list(MODEL_ENDPOINT_URL)
-
- """
- Example response:
- {
- "object": "list",
- "data": [
- {
- "id": "text-embedding-nomic-embed-text-v1.5",
- "object": "model",
- "type": "embeddings",
- "publisher": "nomic-ai",
- "arch": "nomic-bert",
- "compatibility_type": "gguf",
- "quantization": "Q4_0",
- "state": "not-loaded",
- "max_context_length": 2048
- }
- ...
- """
- if "data" not in response:
- warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
- return []
-
- configs = []
- for model in response["data"]:
- assert "id" in model, f"Model missing 'id' field: {model}"
- model_name = model["id"]
-
- if "type" not in model:
- warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
- continue
- elif model["type"] not in ["embeddings"]:
- continue
-
- if "max_context_length" in model:
- context_window_size = model["max_context_length"]
- else:
- warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
- continue
-
- configs.append(
- EmbeddingConfig(
- embedding_model=model_name,
- embedding_endpoint_type="openai",
- embedding_endpoint=self.base_url,
- embedding_dim=context_window_size,
- embedding_chunk_size=300, # NOTE: max is 2048
- handle=self.get_handle(model_name),
- ),
- )
-
- return configs
-
-
-class XAIProvider(OpenAIProvider):
- """https://docs.x.ai/docs/api-reference"""
-
- provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- api_key: str = Field(..., description="API key for the xAI/Grok API.")
- base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
-
- def get_model_context_window_size(self, model_name: str) -> Optional[int]:
- # xAI doesn't return context window in the model listing,
- # so these are hardcoded from their website
- if model_name == "grok-2-1212":
- return 131072
- # NOTE: disabling the minis for now since they return weird MM parts
- # elif model_name == "grok-3-mini-fast-beta":
- # return 131072
- # elif model_name == "grok-3-mini-beta":
- # return 131072
- elif model_name == "grok-3-fast-beta":
- return 131072
- elif model_name == "grok-3-beta":
- return 131072
- else:
- return None
-
- def list_llm_models(self) -> List[LLMConfig]:
- from letta.llm_api.openai import openai_get_model_list
-
- response = openai_get_model_list(self.base_url, api_key=self.api_key)
-
- if "data" in response:
- data = response["data"]
- else:
- data = response
-
- configs = []
- for model in data:
- assert "id" in model, f"xAI/Grok model missing 'id' field: {model}"
- model_name = model["id"]
-
- # In case xAI starts supporting it in the future:
- if "context_length" in model:
- context_window_size = model["context_length"]
- else:
- context_window_size = self.get_model_context_window_size(model_name)
-
- if not context_window_size:
- warnings.warn(f"Couldn't find context window size for model {model_name}")
- continue
-
- configs.append(
- LLMConfig(
- model=model_name,
- model_endpoint_type="xai",
- model_endpoint=self.base_url,
- context_window=context_window_size,
- handle=self.get_handle(model_name),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
-
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- # No embeddings supported
- return []
-
-
-class AnthropicProvider(Provider):
- provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- api_key: str = Field(..., description="API key for the Anthropic API.")
- base_url: str = "https://api.anthropic.com/v1"
-
- def check_api_key(self):
- from letta.llm_api.anthropic import anthropic_check_valid_api_key
-
- anthropic_check_valid_api_key(self.api_key)
-
- def list_llm_models(self) -> List[LLMConfig]:
- from letta.llm_api.anthropic import anthropic_get_model_list
-
- models = anthropic_get_model_list(api_key=self.api_key)
- return self._list_llm_models(models)
-
- async def list_llm_models_async(self) -> List[LLMConfig]:
- from letta.llm_api.anthropic import anthropic_get_model_list_async
-
- models = await anthropic_get_model_list_async(api_key=self.api_key)
- return self._list_llm_models(models)
-
- def _list_llm_models(self, models) -> List[LLMConfig]:
- from letta.llm_api.anthropic import MODEL_LIST
-
- configs = []
- for model in models:
-
- if model["type"] != "model":
- continue
-
- if "id" not in model:
- continue
-
- # Don't support 2.0 and 2.1
- if model["id"].startswith("claude-2"):
- continue
-
- # Anthropic doesn't return the context window in their API
- if "context_window" not in model:
- # Remap list to name: context_window
- model_library = {m["name"]: m["context_window"] for m in MODEL_LIST}
- # Attempt to look it up in a hardcoded list
- if model["id"] in model_library:
- model["context_window"] = model_library[model["id"]]
- else:
- # On fallback, we can set 200k (generally safe), but we should warn the user
- warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000")
- model["context_window"] = 200000
-
- max_tokens = 8192
- if "claude-3-opus" in model["id"]:
- max_tokens = 4096
- if "claude-3-haiku" in model["id"]:
- max_tokens = 4096
- # TODO: set for 3-7 extended thinking mode
-
- # We set this to false by default, because Anthropic can
- # natively support tags inside of content fields
- # However, putting COT inside of tool calls can make it more
- # reliable for tool calling (no chance of a non-tool call step)
- # Since tool_choice_type 'any' doesn't work with in-content COT
- # NOTE For Haiku, it can be flaky if we don't enable this by default
- # inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False
- inner_thoughts_in_kwargs = True # we no longer support thinking tags
-
- configs.append(
- LLMConfig(
- model=model["id"],
- model_endpoint_type="anthropic",
- model_endpoint=self.base_url,
- context_window=model["context_window"],
- handle=self.get_handle(model["id"]),
- put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
- max_tokens=max_tokens,
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
- return configs
-
-
-class MistralProvider(Provider):
- provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- api_key: str = Field(..., description="API key for the Mistral API.")
- base_url: str = "https://api.mistral.ai/v1"
-
- def list_llm_models(self) -> List[LLMConfig]:
- from letta.llm_api.mistral import mistral_get_model_list
-
- # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
- # See: https://openrouter.ai/docs/requests
- response = mistral_get_model_list(self.base_url, api_key=self.api_key)
-
- assert "data" in response, f"Mistral model query response missing 'data' field: {response}"
-
- configs = []
- for model in response["data"]:
- # If model has chat completions and function calling enabled
- if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]:
- configs.append(
- LLMConfig(
- model=model["id"],
- model_endpoint_type="openai",
- model_endpoint=self.base_url,
- context_window=model["max_context_length"],
- handle=self.get_handle(model["id"]),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
-
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- # Not supported for mistral
- return []
-
- def get_model_context_window(self, model_name: str) -> Optional[int]:
- # Redoing this is fine because it's a pretty lightweight call
- models = self.list_llm_models()
-
- for m in models:
- if model_name in m["id"]:
- return int(m["max_context_length"])
-
- return None
-
-
-class OllamaProvider(OpenAIProvider):
- """Ollama provider that uses the native /api/generate endpoint
-
- See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
- """
-
- provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- base_url: str = Field(..., description="Base URL for the Ollama API.")
- api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
- default_prompt_formatter: str = Field(
- ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
- )
-
- async def list_llm_models_async(self) -> List[LLMConfig]:
- """Async version of list_llm_models below"""
- endpoint = f"{self.base_url}/api/tags"
- async with aiohttp.ClientSession() as session:
- async with session.get(endpoint) as response:
- if response.status != 200:
- raise Exception(f"Failed to list Ollama models: {response.text}")
- response_json = await response.json()
-
- configs = []
- for model in response_json["models"]:
- context_window = self.get_model_context_window(model["name"])
- if context_window is None:
- print(f"Ollama model {model['name']} has no context window")
- continue
- configs.append(
- LLMConfig(
- model=model["name"],
- model_endpoint_type="ollama",
- model_endpoint=self.base_url,
- model_wrapper=self.default_prompt_formatter,
- context_window=context_window,
- handle=self.get_handle(model["name"]),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
- return configs
-
- def list_llm_models(self) -> List[LLMConfig]:
- # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
- response = requests.get(f"{self.base_url}/api/tags")
- if response.status_code != 200:
- raise Exception(f"Failed to list Ollama models: {response.text}")
- response_json = response.json()
-
- configs = []
- for model in response_json["models"]:
- context_window = self.get_model_context_window(model["name"])
- if context_window is None:
- print(f"Ollama model {model['name']} has no context window")
- continue
- configs.append(
- LLMConfig(
- model=model["name"],
- model_endpoint_type="ollama",
- model_endpoint=self.base_url,
- model_wrapper=self.default_prompt_formatter,
- context_window=context_window,
- handle=self.get_handle(model["name"]),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
- return configs
-
- def get_model_context_window(self, model_name: str) -> Optional[int]:
- response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
- response_json = response.json()
-
- ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
- # possible_keys = [
- # # OPT
- # "max_position_embeddings",
- # # GPT-2
- # "n_positions",
- # # MPT
- # "max_seq_len",
- # # ChatGLM2
- # "seq_length",
- # # Command-R
- # "model_max_length",
- # # Others
- # "max_sequence_length",
- # "max_seq_length",
- # "seq_len",
- # ]
- # max_position_embeddings
- # parse model cards: nous, dolphon, llama
- if "model_info" not in response_json:
- if "error" in response_json:
- print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
- return None
- for key, value in response_json["model_info"].items():
- if "context_length" in key:
- return value
- return None
-
- def _get_model_embedding_dim(self, model_name: str):
- response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
- response_json = response.json()
- return self._get_model_embedding_dim_impl(response_json, model_name)
-
- async def _get_model_embedding_dim_async(self, model_name: str):
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response:
- response_json = await response.json()
- return self._get_model_embedding_dim_impl(response_json, model_name)
-
- @staticmethod
- def _get_model_embedding_dim_impl(response_json: dict, model_name: str):
- if "model_info" not in response_json:
- if "error" in response_json:
- print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
- return None
- for key, value in response_json["model_info"].items():
- if "embedding_length" in key:
- return value
- return None
-
- async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
- """Async version of list_embedding_models below"""
- endpoint = f"{self.base_url}/api/tags"
- async with aiohttp.ClientSession() as session:
- async with session.get(endpoint) as response:
- if response.status != 200:
- raise Exception(f"Failed to list Ollama models: {response.text}")
- response_json = await response.json()
-
- configs = []
- for model in response_json["models"]:
- embedding_dim = await self._get_model_embedding_dim_async(model["name"])
- if not embedding_dim:
- print(f"Ollama model {model['name']} has no embedding dimension")
- continue
- configs.append(
- EmbeddingConfig(
- embedding_model=model["name"],
- embedding_endpoint_type="ollama",
- embedding_endpoint=self.base_url,
- embedding_dim=embedding_dim,
- embedding_chunk_size=300,
- handle=self.get_handle(model["name"], is_embedding=True),
- )
- )
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
- response = requests.get(f"{self.base_url}/api/tags")
- if response.status_code != 200:
- raise Exception(f"Failed to list Ollama models: {response.text}")
- response_json = response.json()
-
- configs = []
- for model in response_json["models"]:
- embedding_dim = self._get_model_embedding_dim(model["name"])
- if not embedding_dim:
- print(f"Ollama model {model['name']} has no embedding dimension")
- continue
- configs.append(
- EmbeddingConfig(
- embedding_model=model["name"],
- embedding_endpoint_type="ollama",
- embedding_endpoint=self.base_url,
- embedding_dim=embedding_dim,
- embedding_chunk_size=300,
- handle=self.get_handle(model["name"], is_embedding=True),
- )
- )
- return configs
-
-
-class GroqProvider(OpenAIProvider):
- provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- base_url: str = "https://api.groq.com/openai/v1"
- api_key: str = Field(..., description="API key for the Groq API.")
-
- def list_llm_models(self) -> List[LLMConfig]:
- from letta.llm_api.openai import openai_get_model_list
-
- response = openai_get_model_list(self.base_url, api_key=self.api_key)
- configs = []
- for model in response["data"]:
- if not "context_window" in model:
- continue
- configs.append(
- LLMConfig(
- model=model["id"],
- model_endpoint_type="groq",
- model_endpoint=self.base_url,
- context_window=model["context_window"],
- handle=self.get_handle(model["id"]),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- return []
-
-
-class TogetherProvider(OpenAIProvider):
- """TogetherAI provider that uses the /completions API
-
- TogetherAI can also be used via the /chat/completions API
- by settings OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key
- and API URL, however /completions is preferred because their /chat/completions
- function calling support is limited.
- """
-
- provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- base_url: str = "https://api.together.ai/v1"
- api_key: str = Field(..., description="API key for the TogetherAI API.")
- default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
-
- def list_llm_models(self) -> List[LLMConfig]:
- from letta.llm_api.openai import openai_get_model_list
-
- models = openai_get_model_list(self.base_url, api_key=self.api_key)
- return self._list_llm_models(models)
-
- async def list_llm_models_async(self) -> List[LLMConfig]:
- from letta.llm_api.openai import openai_get_model_list_async
-
- models = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
- return self._list_llm_models(models)
-
- def _list_llm_models(self, models) -> List[LLMConfig]:
- pass
-
- # TogetherAI's response is missing the 'data' field
- # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
- if "data" in models:
- data = models["data"]
- else:
- data = models
-
- configs = []
- for model in data:
- assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
- model_name = model["id"]
-
- if "context_length" in model:
- # Context length is returned in OpenRouter as "context_length"
- context_window_size = model["context_length"]
- else:
- context_window_size = self.get_model_context_window_size(model_name)
-
- # We need the context length for embeddings too
- if not context_window_size:
- continue
-
- # Skip models that are too small for Letta
- if context_window_size <= MIN_CONTEXT_WINDOW:
- continue
-
- # TogetherAI includes the type, which we can use to filter for embedding models
- if "type" in model and model["type"] not in ["chat", "language"]:
- continue
-
- configs.append(
- LLMConfig(
- model=model_name,
- model_endpoint_type="together",
- model_endpoint=self.base_url,
- model_wrapper=self.default_prompt_formatter,
- context_window=context_window_size,
- handle=self.get_handle(model_name),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
-
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- # TODO renable once we figure out how to pass API keys through properly
- return []
-
- # from letta.llm_api.openai import openai_get_model_list
-
- # response = openai_get_model_list(self.base_url, api_key=self.api_key)
-
- # # TogetherAI's response is missing the 'data' field
- # # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
- # if "data" in response:
- # data = response["data"]
- # else:
- # data = response
-
- # configs = []
- # for model in data:
- # assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
- # model_name = model["id"]
-
- # if "context_length" in model:
- # # Context length is returned in OpenRouter as "context_length"
- # context_window_size = model["context_length"]
- # else:
- # context_window_size = self.get_model_context_window_size(model_name)
-
- # if not context_window_size:
- # continue
-
- # # TogetherAI includes the type, which we can use to filter out embedding models
- # if "type" in model and model["type"] not in ["embedding"]:
- # continue
-
- # configs.append(
- # EmbeddingConfig(
- # embedding_model=model_name,
- # embedding_endpoint_type="openai",
- # embedding_endpoint=self.base_url,
- # embedding_dim=context_window_size,
- # embedding_chunk_size=300, # TODO: change?
- # )
- # )
-
- # return configs
-
-
-class GoogleAIProvider(Provider):
- # gemini
- provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- api_key: str = Field(..., description="API key for the Google AI API.")
- base_url: str = "https://generativelanguage.googleapis.com"
-
- def check_api_key(self):
- from letta.llm_api.google_ai_client import google_ai_check_valid_api_key
-
- google_ai_check_valid_api_key(self.api_key)
-
- def list_llm_models(self):
- from letta.llm_api.google_ai_client import google_ai_get_model_list
-
- model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
- model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
- model_options = [str(m["name"]) for m in model_options]
-
- # filter by model names
- model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
-
- # Add support for all gemini models
- model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
-
- configs = []
- for model in model_options:
- configs.append(
- LLMConfig(
- model=model,
- model_endpoint_type="google_ai",
- model_endpoint=self.base_url,
- context_window=self.get_model_context_window(model),
- handle=self.get_handle(model),
- max_tokens=8192,
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
-
- return configs
-
- async def list_llm_models_async(self):
- import asyncio
-
- from letta.llm_api.google_ai_client import google_ai_get_model_list_async
-
- # Get and filter the model list
- model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
- model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
- model_options = [str(m["name"]) for m in model_options]
-
- # filter by model names
- model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
-
- # Add support for all gemini models
- model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
-
- # Prepare tasks for context window lookups in parallel
- async def create_config(model):
- context_window = await self.get_model_context_window_async(model)
- return LLMConfig(
- model=model,
- model_endpoint_type="google_ai",
- model_endpoint=self.base_url,
- context_window=context_window,
- handle=self.get_handle(model),
- max_tokens=8192,
- provider_name=self.name,
- provider_category=self.provider_category,
- )
-
- # Execute all config creation tasks concurrently
- configs = await asyncio.gather(*[create_config(model) for model in model_options])
-
- return configs
-
- def list_embedding_models(self):
- from letta.llm_api.google_ai_client import google_ai_get_model_list
-
- # TODO: use base_url instead
- model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
- return self._list_embedding_models(model_options)
-
- async def list_embedding_models_async(self):
- from letta.llm_api.google_ai_client import google_ai_get_model_list_async
-
- # TODO: use base_url instead
- model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
- return self._list_embedding_models(model_options)
-
- def _list_embedding_models(self, model_options):
- # filter by 'generateContent' models
- model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
- model_options = [str(m["name"]) for m in model_options]
- model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
-
- configs = []
- for model in model_options:
- configs.append(
- EmbeddingConfig(
- embedding_model=model,
- embedding_endpoint_type="google_ai",
- embedding_endpoint=self.base_url,
- embedding_dim=768,
- embedding_chunk_size=300, # NOTE: max is 2048
- handle=self.get_handle(model, is_embedding=True),
- batch_size=1024,
- )
- )
- return configs
-
- def get_model_context_window(self, model_name: str) -> Optional[int]:
- from letta.llm_api.google_ai_client import google_ai_get_model_context_window
-
- if model_name in LLM_MAX_TOKENS:
- return LLM_MAX_TOKENS[model_name]
- else:
- return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
-
- async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
- from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async
-
- if model_name in LLM_MAX_TOKENS:
- return LLM_MAX_TOKENS[model_name]
- else:
- return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name)
-
-
-class GoogleVertexProvider(Provider):
- provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
- google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
-
- def list_llm_models(self) -> List[LLMConfig]:
- from letta.llm_api.google_constants import GOOGLE_MODEL_TO_CONTEXT_LENGTH
-
- configs = []
- for model, context_length in GOOGLE_MODEL_TO_CONTEXT_LENGTH.items():
- configs.append(
- LLMConfig(
- model=model,
- model_endpoint_type="google_vertex",
- model_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}",
- context_window=context_length,
- handle=self.get_handle(model),
- max_tokens=8192,
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- from letta.llm_api.google_constants import GOOGLE_EMBEDING_MODEL_TO_DIM
-
- configs = []
- for model, dim in GOOGLE_EMBEDING_MODEL_TO_DIM.items():
- configs.append(
- EmbeddingConfig(
- embedding_model=model,
- embedding_endpoint_type="google_vertex",
- embedding_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}",
- embedding_dim=dim,
- embedding_chunk_size=300, # NOTE: max is 2048
- handle=self.get_handle(model, is_embedding=True),
- batch_size=1024,
- )
- )
- return configs
-
-
-class AzureProvider(Provider):
- provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
- base_url: str = Field(
- ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
- )
- api_key: str = Field(..., description="API key for the Azure API.")
- api_version: str = Field(latest_api_version, description="API version for the Azure API")
-
- @model_validator(mode="before")
- def set_default_api_version(cls, values):
- """
- This ensures that api_version is always set to the default if None is passed in.
- """
- if values.get("api_version") is None:
- values["api_version"] = cls.model_fields["latest_api_version"].default
- return values
-
- def list_llm_models(self) -> List[LLMConfig]:
- from letta.llm_api.azure_openai import azure_openai_get_chat_completion_model_list
-
- model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
- configs = []
- for model_option in model_options:
- model_name = model_option["id"]
- context_window_size = self.get_model_context_window(model_name)
- model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
- configs.append(
- LLMConfig(
- model=model_name,
- model_endpoint_type="azure",
- model_endpoint=model_endpoint,
- context_window=context_window_size,
- handle=self.get_handle(model_name),
- provider_name=self.name,
- provider_category=self.provider_category,
- ),
- )
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list
-
- model_options = azure_openai_get_embeddings_model_list(
- self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True
- )
- configs = []
- for model_option in model_options:
- model_name = model_option["id"]
- model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version)
- configs.append(
- EmbeddingConfig(
- embedding_model=model_name,
- embedding_endpoint_type="azure",
- embedding_endpoint=model_endpoint,
- embedding_dim=768,
- embedding_chunk_size=300, # NOTE: max is 2048
- handle=self.get_handle(model_name),
- batch_size=1024,
- ),
- )
- return configs
-
- def get_model_context_window(self, model_name: str) -> Optional[int]:
- """
- This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
- """
- context_window = AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, None)
- if context_window is None:
- context_window = LLM_MAX_TOKENS.get(model_name, 4096)
- return context_window
-
-
-class VLLMChatCompletionsProvider(Provider):
- """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
-
- # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
- provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- base_url: str = Field(..., description="Base URL for the vLLM API.")
-
- def list_llm_models(self) -> List[LLMConfig]:
- # not supported with vLLM
- from letta.llm_api.openai import openai_get_model_list
-
- assert self.base_url, "base_url is required for vLLM provider"
- response = openai_get_model_list(self.base_url, api_key=None)
-
- configs = []
- for model in response["data"]:
- configs.append(
- LLMConfig(
- model=model["id"],
- model_endpoint_type="openai",
- model_endpoint=self.base_url,
- context_window=model["max_model_len"],
- handle=self.get_handle(model["id"]),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- # not supported with vLLM
- return []
-
-
-class VLLMCompletionsProvider(Provider):
- """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
-
- # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
- provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- base_url: str = Field(..., description="Base URL for the vLLM API.")
- default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
-
- def list_llm_models(self) -> List[LLMConfig]:
- # not supported with vLLM
- from letta.llm_api.openai import openai_get_model_list
-
- response = openai_get_model_list(self.base_url, api_key=None)
-
- configs = []
- for model in response["data"]:
- configs.append(
- LLMConfig(
- model=model["id"],
- model_endpoint_type="vllm",
- model_endpoint=self.base_url,
- model_wrapper=self.default_prompt_formatter,
- context_window=model["max_model_len"],
- handle=self.get_handle(model["id"]),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
- return configs
-
- def list_embedding_models(self) -> List[EmbeddingConfig]:
- # not supported with vLLM
- return []
-
-
-class CohereProvider(OpenAIProvider):
- pass
-
-
-class BedrockProvider(Provider):
- provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
- provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
- region: str = Field(..., description="AWS region for Bedrock")
-
- def check_api_key(self):
- """Check if the Bedrock credentials are valid"""
- from letta.errors import LLMAuthenticationError
- from letta.llm_api.aws_bedrock import bedrock_get_model_list
-
- try:
- # For BYOK providers, use the custom credentials
- if self.provider_category == ProviderCategory.byok:
- # If we can list models, the credentials are valid
- bedrock_get_model_list(
- region_name=self.region,
- access_key_id=self.access_key,
- secret_access_key=self.api_key, # api_key stores the secret access key
- )
- else:
- # For base providers, use default credentials
- bedrock_get_model_list(region_name=self.region)
- except Exception as e:
- raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}")
-
- def list_llm_models(self):
- from letta.llm_api.aws_bedrock import bedrock_get_model_list
-
- models = bedrock_get_model_list(self.region)
-
- configs = []
- for model_summary in models:
- model_arn = model_summary["inferenceProfileArn"]
- configs.append(
- LLMConfig(
- model=model_arn,
- model_endpoint_type=self.provider_type.value,
- model_endpoint=None,
- context_window=self.get_model_context_window(model_arn),
- handle=self.get_handle(model_arn),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
- return configs
-
- async def list_llm_models_async(self) -> List[LLMConfig]:
- from letta.llm_api.aws_bedrock import bedrock_get_model_list_async
-
- models = await bedrock_get_model_list_async(
- self.access_key,
- self.api_key,
- self.region,
- )
-
- configs = []
- for model_summary in models:
- model_arn = model_summary["inferenceProfileArn"]
- configs.append(
- LLMConfig(
- model=model_arn,
- model_endpoint_type=self.provider_type.value,
- model_endpoint=None,
- context_window=self.get_model_context_window(model_arn),
- handle=self.get_handle(model_arn),
- provider_name=self.name,
- provider_category=self.provider_category,
- )
- )
-
- return configs
-
- def list_embedding_models(self):
- return []
-
- def get_model_context_window(self, model_name: str) -> Optional[int]:
- # Context windows for Claude models
- from letta.llm_api.aws_bedrock import bedrock_get_model_context_window
-
- return bedrock_get_model_context_window(model_name)
-
- def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str:
- print(model_name)
- model = model_name.split(".")[-1]
- return f"{self.name}/{model}"
diff --git a/letta/schemas/providers/__init__.py b/letta/schemas/providers/__init__.py
new file mode 100644
index 00000000..cc3cbd69
--- /dev/null
+++ b/letta/schemas/providers/__init__.py
@@ -0,0 +1,47 @@
+# Provider base classes and utilities
+# Provider implementations
+from .anthropic import AnthropicProvider
+from .azure import AzureProvider
+from .base import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate
+from .bedrock import BedrockProvider
+from .cerebras import CerebrasProvider
+from .cohere import CohereProvider
+from .deepseek import DeepSeekProvider
+from .google_gemini import GoogleAIProvider
+from .google_vertex import GoogleVertexProvider
+from .groq import GroqProvider
+from .letta import LettaProvider
+from .lmstudio import LMStudioOpenAIProvider
+from .mistral import MistralProvider
+from .ollama import OllamaProvider
+from .openai import OpenAIProvider
+from .together import TogetherProvider
+from .vllm import VLLMProvider
+from .xai import XAIProvider
+
+__all__ = [
+ # Base classes
+ "Provider",
+ "ProviderBase",
+ "ProviderCreate",
+ "ProviderUpdate",
+ "ProviderCheck",
+ # Provider implementations
+ "AnthropicProvider",
+ "AzureProvider",
+ "BedrockProvider",
+ "CerebrasProvider", # NEW
+ "CohereProvider",
+ "DeepSeekProvider",
+ "GoogleAIProvider",
+ "GoogleVertexProvider",
+ "GroqProvider",
+ "LettaProvider",
+ "LMStudioOpenAIProvider",
+ "MistralProvider",
+ "OllamaProvider",
+ "OpenAIProvider",
+ "TogetherProvider",
+ "VLLMProvider", # Replaces ChatCompletions and Completions
+ "XAIProvider",
+]
diff --git a/letta/schemas/providers/anthropic.py b/letta/schemas/providers/anthropic.py
new file mode 100644
index 00000000..eac4c90d
--- /dev/null
+++ b/letta/schemas/providers/anthropic.py
@@ -0,0 +1,78 @@
+import warnings
+from typing import Literal
+
+from pydantic import Field
+
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.base import Provider
+
+
+class AnthropicProvider(Provider):
+ provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ api_key: str = Field(..., description="API key for the Anthropic API.")
+ base_url: str = "https://api.anthropic.com/v1"
+
+ async def check_api_key(self):
+ from letta.llm_api.anthropic import anthropic_check_valid_api_key
+
+ anthropic_check_valid_api_key(self.api_key)
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.anthropic import anthropic_get_model_list_async
+
+ models = await anthropic_get_model_list_async(api_key=self.api_key)
+ return self._list_llm_models(models)
+
+ def _list_llm_models(self, models) -> list[LLMConfig]:
+ from letta.llm_api.anthropic import MODEL_LIST
+
+ configs = []
+ for model in models:
+ if any((model.get("type") != "model", "id" not in model, model.get("id").startswith("claude-2"))):
+ continue
+
+ # Anthropic doesn't return the context window in their API
+ if "context_window" not in model:
+ # Remap list to name: context_window
+ model_library = {m["name"]: m["context_window"] for m in MODEL_LIST}
+ # Attempt to look it up in a hardcoded list
+ if model["id"] in model_library:
+ model["context_window"] = model_library[model["id"]]
+ else:
+ # On fallback, we can set 200k (generally safe), but we should warn the user
+ warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000")
+ model["context_window"] = 200000
+
+ max_tokens = 8192
+ if "claude-3-opus" in model["id"]:
+ max_tokens = 4096
+ if "claude-3-haiku" in model["id"]:
+ max_tokens = 4096
+ # TODO: set for 3-7 extended thinking mode
+
+ # NOTE: from 2025-02
+ # We set this to false by default, because Anthropic can
+ # natively support tags inside of content fields
+ # However, putting COT inside of tool calls can make it more
+ # reliable for tool calling (no chance of a non-tool call step)
+ # Since tool_choice_type 'any' doesn't work with in-content COT
+ # NOTE For Haiku, it can be flaky if we don't enable this by default
+ # inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False
+ inner_thoughts_in_kwargs = True # we no longer support thinking tags
+
+ configs.append(
+ LLMConfig(
+ model=model["id"],
+ model_endpoint_type="anthropic",
+ model_endpoint=self.base_url,
+ context_window=model["context_window"],
+ handle=self.get_handle(model["id"]),
+ put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
+ max_tokens=max_tokens,
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+ return configs
diff --git a/letta/schemas/providers/azure.py b/letta/schemas/providers/azure.py
new file mode 100644
index 00000000..e51c1775
--- /dev/null
+++ b/letta/schemas/providers/azure.py
@@ -0,0 +1,80 @@
+from typing import ClassVar, Literal
+
+from pydantic import Field, field_validator
+
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS
+from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
+from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.base import Provider
+
+
+class AzureProvider(Provider):
+ LATEST_API_VERSION: ClassVar[str] = "2024-09-01-preview"
+
+ provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ # Note: 2024-09-01-preview was set here until 2025-07-16.
+ # set manually, see: https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
+ latest_api_version: str = "2025-04-01-preview"
+ base_url: str = Field(
+ ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
+ )
+ api_key: str = Field(..., description="API key for the Azure API.")
+ api_version: str = Field(default=LATEST_API_VERSION, description="API version for the Azure API")
+
+ @field_validator("api_version", mode="before")
+ def replace_none_with_default(cls, v):
+ return v if v is not None else cls.LATEST_API_VERSION
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ # TODO (cliandy): asyncify
+ from letta.llm_api.azure_openai import azure_openai_get_chat_completion_model_list
+
+ model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
+ configs = []
+ for model_option in model_options:
+ model_name = model_option["id"]
+ context_window_size = self.get_model_context_window(model_name)
+ model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
+ configs.append(
+ LLMConfig(
+ model=model_name,
+ model_endpoint_type="azure",
+ model_endpoint=model_endpoint,
+ context_window=context_window_size,
+ handle=self.get_handle(model_name),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+ return configs
+
+ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+ # TODO (cliandy): asyncify dependent function calls
+ from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list
+
+ model_options = azure_openai_get_embeddings_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
+ configs = []
+ for model_option in model_options:
+ model_name = model_option["id"]
+ model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version)
+ configs.append(
+ EmbeddingConfig(
+ embedding_model=model_name,
+ embedding_endpoint_type="azure",
+ embedding_endpoint=model_endpoint,
+ embedding_dim=768, # TODO generated 1536?
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, # old note: max is 2048
+ handle=self.get_handle(model_name, is_embedding=True),
+ batch_size=1024,
+ )
+ )
+ return configs
+
+ def get_model_context_window(self, model_name: str) -> int | None:
+ # Hard coded as there are no API endpoints for this
+ llm_default = LLM_MAX_TOKENS.get(model_name, 4096)
+ return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default)
diff --git a/letta/schemas/providers/base.py b/letta/schemas/providers/base.py
new file mode 100644
index 00000000..eef2cb39
--- /dev/null
+++ b/letta/schemas/providers/base.py
@@ -0,0 +1,201 @@
+from datetime import datetime
+
+from pydantic import BaseModel, Field, model_validator
+
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.letta_base import LettaBase
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
+from letta.settings import model_settings
+
+
+class ProviderBase(LettaBase):
+ __id_prefix__ = "provider"
+
+
+class Provider(ProviderBase):
+ id: str | None = Field(None, description="The id of the provider, lazily created by the database manager.")
+ name: str = Field(..., description="The name of the provider")
+ provider_type: ProviderType = Field(..., description="The type of the provider")
+ provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
+ api_key: str | None = Field(None, description="API key or secret key used for requests to the provider.")
+ base_url: str | None = Field(None, description="Base URL for the provider.")
+ access_key: str | None = Field(None, description="Access key used for requests to the provider.")
+ region: str | None = Field(None, description="Region used for requests to the provider.")
+ organization_id: str | None = Field(None, description="The organization id of the user")
+ updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
+
+ @model_validator(mode="after")
+ def default_base_url(self):
+ if self.provider_type == ProviderType.openai and self.base_url is None:
+ self.base_url = model_settings.openai_api_base
+ return self
+
+ def resolve_identifier(self):
+ if not self.id:
+ self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
+
+ async def check_api_key(self):
+ """Check if the API key is valid for the provider"""
+ raise NotImplementedError
+
+ def list_llm_models(self) -> list[LLMConfig]:
+ """List available LLM models (deprecated: use list_llm_models_async)"""
+ import asyncio
+ import warnings
+
+ warnings.warn("list_llm_models is deprecated, use list_llm_models_async instead", DeprecationWarning, stacklevel=2)
+
+ # Simplified asyncio handling - just use asyncio.run()
+ # This works in most contexts and avoids complex event loop detection
+ try:
+ return asyncio.run(self.list_llm_models_async())
+ except RuntimeError as e:
+ # If we're in an active event loop context, use a thread pool
+ if "cannot be called from a running event loop" in str(e):
+ import concurrent.futures
+
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future = executor.submit(asyncio.run, self.list_llm_models_async())
+ return future.result()
+ else:
+ raise
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ return []
+
+ def list_embedding_models(self) -> list[EmbeddingConfig]:
+ """List available embedding models (deprecated: use list_embedding_models_async)"""
+ import asyncio
+ import warnings
+
+ warnings.warn("list_embedding_models is deprecated, use list_embedding_models_async instead", DeprecationWarning, stacklevel=2)
+
+ # Simplified asyncio handling - just use asyncio.run()
+ # This works in most contexts and avoids complex event loop detection
+ try:
+ return asyncio.run(self.list_embedding_models_async())
+ except RuntimeError as e:
+ # If we're in an active event loop context, use a thread pool
+ if "cannot be called from a running event loop" in str(e):
+ import concurrent.futures
+
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future = executor.submit(asyncio.run, self.list_embedding_models_async())
+ return future.result()
+ else:
+ raise
+
+ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+ """List available embedding models. The following do not have support for embedding models:
+ Anthropic, Bedrock, Cerebras, Deepseek, Groq, Mistral, xAI
+ """
+ return []
+
+ def get_model_context_window(self, model_name: str) -> int | None:
+ raise NotImplementedError
+
+ async def get_model_context_window_async(self, model_name: str) -> int | None:
+ raise NotImplementedError
+
+ def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str:
+ """
+ Get the handle for a model, with support for custom overrides.
+
+ Args:
+ model_name (str): The name of the model.
+ is_embedding (bool, optional): Whether the handle is for an embedding model. Defaults to False.
+
+ Returns:
+ str: The handle for the model.
+ """
+ base_name = base_name if base_name else self.name
+
+ overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES
+ if base_name in overrides and model_name in overrides[base_name]:
+ model_name = overrides[base_name][model_name]
+
+ return f"{base_name}/{model_name}"
+
+ def cast_to_subtype(self):
+ # Import here to avoid circular imports
+ from letta.schemas.providers import (
+ AnthropicProvider,
+ AzureProvider,
+ BedrockProvider,
+ CerebrasProvider,
+ CohereProvider,
+ DeepSeekProvider,
+ GoogleAIProvider,
+ GoogleVertexProvider,
+ GroqProvider,
+ LettaProvider,
+ LMStudioOpenAIProvider,
+ MistralProvider,
+ OllamaProvider,
+ OpenAIProvider,
+ TogetherProvider,
+ VLLMProvider,
+ XAIProvider,
+ )
+
+ match self.provider_type:
+ case ProviderType.letta:
+ return LettaProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.openai:
+ return OpenAIProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.anthropic:
+ return AnthropicProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.google_ai:
+ return GoogleAIProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.google_vertex:
+ return GoogleVertexProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.azure:
+ return AzureProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.groq:
+ return GroqProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.together:
+ return TogetherProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.ollama:
+ return OllamaProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.vllm:
+ return VLLMProvider(**self.model_dump(exclude_none=True)) # Removed support for CompletionsProvider
+ case ProviderType.mistral:
+ return MistralProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.deepseek:
+ return DeepSeekProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.cerebras:
+ return CerebrasProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.xai:
+ return XAIProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.lmstudio_openai:
+ return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.bedrock:
+ return BedrockProvider(**self.model_dump(exclude_none=True))
+ case ProviderType.cohere:
+ return CohereProvider(**self.model_dump(exclude_none=True))
+ case _:
+ raise ValueError(f"Unknown provider type: {self.provider_type}")
+
+
+class ProviderCreate(ProviderBase):
+ name: str = Field(..., description="The name of the provider.")
+ provider_type: ProviderType = Field(..., description="The type of the provider.")
+ api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
+ access_key: str | None = Field(None, description="Access key used for requests to the provider.")
+ region: str | None = Field(None, description="Region used for requests to the provider.")
+
+
+class ProviderUpdate(ProviderBase):
+ api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
+ access_key: str | None = Field(None, description="Access key used for requests to the provider.")
+ region: str | None = Field(None, description="Region used for requests to the provider.")
+
+
+class ProviderCheck(BaseModel):
+ provider_type: ProviderType = Field(..., description="The type of the provider.")
+ api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
+ access_key: str | None = Field(None, description="Access key used for requests to the provider.")
+ region: str | None = Field(None, description="Region used for requests to the provider.")
diff --git a/letta/schemas/providers/bedrock.py b/letta/schemas/providers/bedrock.py
new file mode 100644
index 00000000..d7d8437f
--- /dev/null
+++ b/letta/schemas/providers/bedrock.py
@@ -0,0 +1,78 @@
+"""
+Note that this formally only supports Anthropic Bedrock.
+TODO (cliandy): determine what other providers are supported and what is needed to add support.
+"""
+
+from typing import Literal
+
+from pydantic import Field
+
+from letta.log import get_logger
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.base import Provider
+
+logger = get_logger(__name__)
+
+
+class BedrockProvider(Provider):
+ provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ region: str = Field(..., description="AWS region for Bedrock")
+
+ async def check_api_key(self):
+ """Check if the Bedrock credentials are valid"""
+ from letta.errors import LLMAuthenticationError
+ from letta.llm_api.aws_bedrock import bedrock_get_model_list_async
+
+ try:
+ # For BYOK providers, use the custom credentials
+ if self.provider_category == ProviderCategory.byok:
+ # If we can list models, the credentials are valid
+ await bedrock_get_model_list_async(
+ access_key_id=self.access_key,
+ secret_access_key=self.api_key, # api_key stores the secret access key
+ region_name=self.region,
+ )
+ else:
+ # For base providers, use default credentials
+ bedrock_get_model_list(region_name=self.region)
+ except Exception as e:
+ raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}")
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.aws_bedrock import bedrock_get_model_list_async
+
+ models = await bedrock_get_model_list_async(
+ self.access_key,
+ self.api_key,
+ self.region,
+ )
+
+ configs = []
+ for model_summary in models:
+ model_arn = model_summary["inferenceProfileArn"]
+ configs.append(
+ LLMConfig(
+ model=model_arn,
+ model_endpoint_type=self.provider_type.value,
+ model_endpoint=None,
+ context_window=self.get_model_context_window(model_arn),
+ handle=self.get_handle(model_arn),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+
+ return configs
+
+ def get_model_context_window(self, model_name: str) -> int | None:
+ # Context windows for Claude models
+ from letta.llm_api.aws_bedrock import bedrock_get_model_context_window
+
+ return bedrock_get_model_context_window(model_name)
+
+ def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str:
+ logger.debug("Getting handle for model_name: %s", model_name)
+ model = model_name.split(".")[-1]
+ return f"{self.name}/{model}"
diff --git a/letta/schemas/providers/cerebras.py b/letta/schemas/providers/cerebras.py
new file mode 100644
index 00000000..173dc4ba
--- /dev/null
+++ b/letta/schemas/providers/cerebras.py
@@ -0,0 +1,79 @@
+import warnings
+from typing import Literal
+
+from pydantic import Field
+
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.openai import OpenAIProvider
+
+
+class CerebrasProvider(OpenAIProvider):
+ """
+ Cerebras Inference API is OpenAI-compatible and focuses on ultra-fast inference.
+
+ Available Models (as of 2025):
+ - llama-4-scout-17b-16e-instruct: Llama 4 Scout (109B params, 10M context, ~2600 tokens/s)
+ - llama3.1-8b: Llama 3.1 8B (8B params, 128K context, ~2200 tokens/s)
+ - llama-3.3-70b: Llama 3.3 70B (70B params, 128K context, ~2100 tokens/s)
+ - qwen-3-32b: Qwen 3 32B (32B params, 131K context, ~2100 tokens/s)
+ - deepseek-r1-distill-llama-70b: DeepSeek R1 Distill (70B params, 128K context, ~1700 tokens/s)
+ """
+
+ provider_type: Literal[ProviderType.cerebras] = Field(ProviderType.cerebras, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ base_url: str = Field("https://api.cerebras.ai/v1", description="Base URL for the Cerebras API.")
+ api_key: str = Field(..., description="API key for the Cerebras API.")
+
+ def get_model_context_window_size(self, model_name: str) -> int | None:
+ """Cerebras has limited context window sizes.
+
+ see https://inference-docs.cerebras.ai/support/pricing for details by plan
+ """
+ is_free_tier = True
+ if is_free_tier:
+ return 8192
+ return 128000
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.openai import openai_get_model_list_async
+
+ response = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
+
+ if "data" in response:
+ data = response["data"]
+ else:
+ data = response
+
+ configs = []
+ for model in data:
+ assert "id" in model, f"Cerebras model missing 'id' field: {model}"
+ model_name = model["id"]
+
+ # Check if model has context_length in response
+ if "context_length" in model:
+ context_window_size = model["context_length"]
+ else:
+ context_window_size = self.get_model_context_window_size(model_name)
+
+ if not context_window_size:
+ warnings.warn(f"Couldn't find context window size for model {model_name}")
+ continue
+
+ # Cerebras supports function calling
+ put_inner_thoughts_in_kwargs = True
+
+ configs.append(
+ LLMConfig(
+ model=model_name,
+ model_endpoint_type="openai", # Cerebras uses OpenAI-compatible endpoint
+ model_endpoint=self.base_url,
+ context_window=context_window_size,
+ handle=self.get_handle(model_name),
+ put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+
+ return configs
diff --git a/letta/schemas/providers/cohere.py b/letta/schemas/providers/cohere.py
new file mode 100644
index 00000000..ce3d6150
--- /dev/null
+++ b/letta/schemas/providers/cohere.py
@@ -0,0 +1,18 @@
+from typing import Literal
+
+from pydantic import Field
+
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.openai import OpenAIProvider
+
+
+# TODO (cliandy): this needs to be implemented
+class CohereProvider(OpenAIProvider):
+ provider_type: Literal[ProviderType.cohere] = Field(ProviderType.cohere, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ base_url: str = ""
+ api_key: str = Field(..., description="API key for the Cohere API.")
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ raise NotImplementedError
diff --git a/letta/schemas/providers/deepseek.py b/letta/schemas/providers/deepseek.py
new file mode 100644
index 00000000..0c1ae0c2
--- /dev/null
+++ b/letta/schemas/providers/deepseek.py
@@ -0,0 +1,63 @@
+from typing import Literal
+
+from pydantic import Field
+
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.openai import OpenAIProvider
+
+
+class DeepSeekProvider(OpenAIProvider):
+ """
+ DeepSeek ChatCompletions API is similar to OpenAI's reasoning API,
+ but with slight differences:
+ * For example, DeepSeek's API requires perfect interleaving of user/assistant
+ * It also does not support native function calling
+ """
+
+ provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
+ api_key: str = Field(..., description="API key for the DeepSeek API.")
+
+ # TODO (cliandy): this may need to be updated to reflect current models
+ def get_model_context_window_size(self, model_name: str) -> int | None:
+ # DeepSeek doesn't return context window in the model listing,
+ # so these are hardcoded from their website
+ if model_name == "deepseek-reasoner":
+ return 64000
+ elif model_name == "deepseek-chat":
+ return 64000
+ else:
+ return None
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.openai import openai_get_model_list_async
+
+ response = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
+ data = response.get("data", response)
+
+ configs = []
+ for model in data:
+ check = self._do_model_checks_for_name_and_context_size(model)
+ if check is None:
+ continue
+ model_name, context_window_size = check
+
+ # Not used for deepseek-reasoner, but otherwise is true
+ put_inner_thoughts_in_kwargs = False if model_name == "deepseek-reasoner" else True
+
+ configs.append(
+ LLMConfig(
+ model=model_name,
+ model_endpoint_type="deepseek",
+ model_endpoint=self.base_url,
+ context_window=context_window_size,
+ handle=self.get_handle(model_name),
+ put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+
+ return configs
diff --git a/letta/schemas/providers/google_gemini.py b/letta/schemas/providers/google_gemini.py
new file mode 100644
index 00000000..6404e0fc
--- /dev/null
+++ b/letta/schemas/providers/google_gemini.py
@@ -0,0 +1,102 @@
+import asyncio
+from typing import Literal
+
+from pydantic import Field
+
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.base import Provider
+
+
+class GoogleAIProvider(Provider):
+ provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ api_key: str = Field(..., description="API key for the Google AI API.")
+ base_url: str = "https://generativelanguage.googleapis.com"
+
+ async def check_api_key(self):
+ from letta.llm_api.google_ai_client import google_ai_check_valid_api_key
+
+ google_ai_check_valid_api_key(self.api_key)
+
+ async def list_llm_models_async(self):
+ from letta.llm_api.google_ai_client import google_ai_get_model_list_async
+
+ # Get and filter the model list
+ model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
+ model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
+ model_options = [str(m["name"]) for m in model_options]
+
+ # filter by model names
+ model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
+
+ # Add support for all gemini models
+ model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
+
+ # Prepare tasks for context window lookups in parallel
+ async def create_config(model):
+ context_window = await self.get_model_context_window_async(model)
+ return LLMConfig(
+ model=model,
+ model_endpoint_type="google_ai",
+ model_endpoint=self.base_url,
+ context_window=context_window,
+ handle=self.get_handle(model),
+ max_tokens=8192,
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+
+ # Execute all config creation tasks concurrently
+ configs = await asyncio.gather(*[create_config(model) for model in model_options])
+
+ return configs
+
+ async def list_embedding_models_async(self):
+ from letta.llm_api.google_ai_client import google_ai_get_model_list_async
+
+ # TODO: use base_url instead
+ model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
+ return self._list_embedding_models(model_options)
+
+ def _list_embedding_models(self, model_options):
+ # filter by 'generateContent' models
+ model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
+ model_options = [str(m["name"]) for m in model_options]
+ model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
+
+ configs = []
+ for model in model_options:
+ configs.append(
+ EmbeddingConfig(
+ embedding_model=model,
+ embedding_endpoint_type="google_ai",
+ embedding_endpoint=self.base_url,
+ embedding_dim=768,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, # NOTE: max is 2048
+ handle=self.get_handle(model, is_embedding=True),
+ batch_size=1024,
+ )
+ )
+ return configs
+
+ def get_model_context_window(self, model_name: str) -> int | None:
+ import warnings
+
+ warnings.warn("This is deprecated, use get_model_context_window_async when possible.", DeprecationWarning)
+ from letta.llm_api.google_ai_client import google_ai_get_model_context_window
+
+ if model_name in LLM_MAX_TOKENS:
+ return LLM_MAX_TOKENS[model_name]
+ else:
+ return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
+
+ async def get_model_context_window_async(self, model_name: str) -> int | None:
+ from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async
+
+ if model_name in LLM_MAX_TOKENS:
+ return LLM_MAX_TOKENS[model_name]
+ else:
+ return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name)
diff --git a/letta/schemas/providers/google_vertex.py b/letta/schemas/providers/google_vertex.py
new file mode 100644
index 00000000..0ed68541
--- /dev/null
+++ b/letta/schemas/providers/google_vertex.py
@@ -0,0 +1,54 @@
+from typing import Literal
+
+from pydantic import Field
+
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.base import Provider
+
+
+# TODO (cliandy): GoogleVertexProvider uses hardcoded models vs Gemini fetches from API
+class GoogleVertexProvider(Provider):
+ provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
+ google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.google_constants import GOOGLE_MODEL_TO_CONTEXT_LENGTH
+
+ configs = []
+ for model, context_length in GOOGLE_MODEL_TO_CONTEXT_LENGTH.items():
+ configs.append(
+ LLMConfig(
+ model=model,
+ model_endpoint_type="google_vertex",
+ model_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}",
+ context_window=context_length,
+ handle=self.get_handle(model),
+ max_tokens=8192,
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+ return configs
+
+ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+ from letta.llm_api.google_constants import GOOGLE_EMBEDING_MODEL_TO_DIM
+
+ configs = []
+ for model, dim in GOOGLE_EMBEDING_MODEL_TO_DIM.items():
+ configs.append(
+ EmbeddingConfig(
+ embedding_model=model,
+ embedding_endpoint_type="google_vertex",
+ embedding_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}",
+ embedding_dim=dim,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, # NOTE: max is 2048
+ handle=self.get_handle(model, is_embedding=True),
+ batch_size=1024,
+ )
+ )
+ return configs
diff --git a/letta/schemas/providers/groq.py b/letta/schemas/providers/groq.py
new file mode 100644
index 00000000..18b4cb31
--- /dev/null
+++ b/letta/schemas/providers/groq.py
@@ -0,0 +1,35 @@
+from typing import Literal
+
+from pydantic import Field
+
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.openai import OpenAIProvider
+
+
+class GroqProvider(OpenAIProvider):
+ provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ base_url: str = "https://api.groq.com/openai/v1"
+ api_key: str = Field(..., description="API key for the Groq API.")
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.openai import openai_get_model_list_async
+
+ response = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
+ configs = []
+ for model in response["data"]:
+ if "context_window" not in model:
+ continue
+ configs.append(
+ LLMConfig(
+ model=model["id"],
+ model_endpoint_type="groq",
+ model_endpoint=self.base_url,
+ context_window=model["context_window"],
+ handle=self.get_handle(model["id"]),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+ return configs
diff --git a/letta/schemas/providers/letta.py b/letta/schemas/providers/letta.py
new file mode 100644
index 00000000..37763884
--- /dev/null
+++ b/letta/schemas/providers/letta.py
@@ -0,0 +1,39 @@
+from typing import Literal
+
+from pydantic import Field
+
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.base import Provider
+
+
+class LettaProvider(Provider):
+ provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ return [
+ LLMConfig(
+ model="letta-free", # NOTE: renamed
+ model_endpoint_type="openai",
+ model_endpoint=LETTA_MODEL_ENDPOINT,
+ context_window=30000,
+ handle=self.get_handle("letta-free"),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ ]
+
+ async def list_embedding_models_async(self):
+ return [
+ EmbeddingConfig(
+ embedding_model="letta-free", # NOTE: renamed
+ embedding_endpoint_type="hugging-face",
+ embedding_endpoint="https://embeddings.memgpt.ai",
+ embedding_dim=1024,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
+ handle=self.get_handle("letta-free", is_embedding=True),
+ )
+ ]
diff --git a/letta/schemas/providers/lmstudio.py b/letta/schemas/providers/lmstudio.py
new file mode 100644
index 00000000..6e4a639a
--- /dev/null
+++ b/letta/schemas/providers/lmstudio.py
@@ -0,0 +1,97 @@
+import warnings
+from typing import Literal
+
+from pydantic import Field
+
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.openai import OpenAIProvider
+
+
+class LMStudioOpenAIProvider(OpenAIProvider):
+ provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
+ api_key: str | None = Field(None, description="API key for the LMStudio API.")
+
+ @property
+ def model_endpoint_url(self):
+ # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
+ return f"{self.base_url.strip('/v1')}/api/v0"
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.openai import openai_get_model_list_async
+
+ response = await openai_get_model_list_async(self.model_endpoint_url)
+
+ if "data" not in response:
+ warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
+ return []
+
+ configs = []
+ for model in response["data"]:
+ model_type = model.get("type")
+ if not model_type:
+ warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
+ continue
+ if model_type not in ("vlm", "llm"):
+ continue
+
+ # TODO (cliandy): previously we didn't get the backup context size, is this valid?
+ check = self._do_model_checks_for_name_and_context_size(model)
+ if check is None:
+ continue
+ model_name, context_window_size = check
+
+ configs.append(
+ LLMConfig(
+ model=model_name,
+ model_endpoint_type="openai",
+ model_endpoint=self.base_url,
+ context_window=context_window_size,
+ handle=self.get_handle(model_name),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+
+ return configs
+
+ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+ from letta.llm_api.openai import openai_get_model_list_async
+
+ response = await openai_get_model_list_async(self.model_endpoint_url)
+
+ if "data" not in response:
+ warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
+ return []
+
+ configs = []
+ for model in response["data"]:
+ model_type = model.get("type")
+ if not model_type:
+ warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
+ continue
+ if model_type not in ("embeddings"):
+ continue
+
+ # TODO (cliandy): previously we didn't get the backup context size, is this valid?
+ check = self._do_model_checks_for_name_and_context_size(model, length_key="max_context_length")
+ if check is None:
+ continue
+ model_name, context_window_size = check
+
+ configs.append(
+ EmbeddingConfig(
+ embedding_model=model_name,
+ embedding_endpoint_type="openai",
+ embedding_endpoint=self.base_url,
+ embedding_dim=768, # Default embedding dimension, not context window
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, # NOTE: max is 2048
+ handle=self.get_handle(model_name),
+ ),
+ )
+
+ return configs
diff --git a/letta/schemas/providers/mistral.py b/letta/schemas/providers/mistral.py
new file mode 100644
index 00000000..2eeb3a23
--- /dev/null
+++ b/letta/schemas/providers/mistral.py
@@ -0,0 +1,41 @@
+from typing import Literal
+
+from pydantic import Field
+
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.base import Provider
+
+
+class MistralProvider(Provider):
+ provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ api_key: str = Field(..., description="API key for the Mistral API.")
+ base_url: str = "https://api.mistral.ai/v1"
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.mistral import mistral_get_model_list_async
+
+ # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
+ # See: https://openrouter.ai/docs/requests
+ response = await mistral_get_model_list_async(self.base_url, api_key=self.api_key)
+
+ assert "data" in response, f"Mistral model query response missing 'data' field: {response}"
+
+ configs = []
+ for model in response["data"]:
+ # If model has chat completions and function calling enabled
+ if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]:
+ configs.append(
+ LLMConfig(
+ model=model["id"],
+ model_endpoint_type="openai",
+ model_endpoint=self.base_url,
+ context_window=model["max_context_length"],
+ handle=self.get_handle(model["id"]),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+
+ return configs
diff --git a/letta/schemas/providers/ollama.py b/letta/schemas/providers/ollama.py
new file mode 100644
index 00000000..b9ddaa2c
--- /dev/null
+++ b/letta/schemas/providers/ollama.py
@@ -0,0 +1,151 @@
+from typing import Literal
+
+import aiohttp
+import requests
+from pydantic import Field
+
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
+from letta.log import get_logger
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.openai import OpenAIProvider
+
+logger = get_logger(__name__)
+
+
+class OllamaProvider(OpenAIProvider):
+ """Ollama provider that uses the native /api/generate endpoint
+
+ See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
+ """
+
+ provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ base_url: str = Field(..., description="Base URL for the Ollama API.")
+ api_key: str | None = Field(None, description="API key for the Ollama API (default: `None`).")
+ default_prompt_formatter: str = Field(
+ ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
+ )
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ """List available LLM Models from Ollama
+
+ https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models"""
+ endpoint = f"{self.base_url}/api/tags"
+ async with aiohttp.ClientSession() as session:
+ async with session.get(endpoint) as response:
+ if response.status != 200:
+ raise Exception(f"Failed to list Ollama models: {response.text}")
+ response_json = await response.json()
+
+ configs = []
+ for model in response_json["models"]:
+ context_window = self.get_model_context_window(model["name"])
+ if context_window is None:
+ print(f"Ollama model {model['name']} has no context window")
+ continue
+ configs.append(
+ LLMConfig(
+ model=model["name"],
+ model_endpoint_type="ollama",
+ model_endpoint=self.base_url,
+ model_wrapper=self.default_prompt_formatter,
+ context_window=context_window,
+ handle=self.get_handle(model["name"]),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+ return configs
+
+ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+ """List available embedding models from Ollama
+
+ https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+ """
+ endpoint = f"{self.base_url}/api/tags"
+ async with aiohttp.ClientSession() as session:
+ async with session.get(endpoint) as response:
+ if response.status != 200:
+ raise Exception(f"Failed to list Ollama models: {response.text}")
+ response_json = await response.json()
+
+ configs = []
+ for model in response_json["models"]:
+ embedding_dim = await self._get_model_embedding_dim_async(model["name"])
+ if not embedding_dim:
+ print(f"Ollama model {model['name']} has no embedding dimension")
+ continue
+ configs.append(
+ EmbeddingConfig(
+ embedding_model=model["name"],
+ embedding_endpoint_type="ollama",
+ embedding_endpoint=self.base_url,
+ embedding_dim=embedding_dim,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
+ handle=self.get_handle(model["name"], is_embedding=True),
+ )
+ )
+ return configs
+
+ def get_model_context_window(self, model_name: str) -> int | None:
+ """Gets model context window for Ollama. As this can look different based on models,
+ we use the following for guidance:
+
+ "llama.context_length": 8192,
+ "llama.embedding_length": 4096,
+ source: https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information
+
+ FROM 2024-10-08
+ Notes from vLLM around keys
+ source: https://github.com/vllm-project/vllm/blob/72ad2735823e23b4e1cc79b7c73c3a5f3c093ab0/vllm/config.py#L3488
+
+ possible_keys = [
+ # OPT
+ "max_position_embeddings",
+ # GPT-2
+ "n_positions",
+ # MPT
+ "max_seq_len",
+ # ChatGLM2
+ "seq_length",
+ # Command-R
+ "model_max_length",
+ # Whisper
+ "max_target_positions",
+ # Others
+ "max_sequence_length",
+ "max_seq_length",
+ "seq_len",
+ ]
+ max_position_embeddings
+ parse model cards: nous, dolphon, llama
+ """
+ endpoint = f"{self.base_url}/api/show"
+ payload = {"name": model_name, "verbose": True}
+ response = requests.post(endpoint, json=payload)
+ if response.status_code != 200:
+ return None
+
+ try:
+ model_info = response.json()
+ # Try to extract context window from model parameters
+ if "model_info" in model_info and "llama.context_length" in model_info["model_info"]:
+ return int(model_info["model_info"]["llama.context_length"])
+ except Exception:
+ pass
+ logger.warning(f"Failed to get model context window for {model_name}")
+ return None
+
+ async def _get_model_embedding_dim_async(self, model_name: str):
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response:
+ response_json = await response.json()
+
+ if "model_info" not in response_json:
+ if "error" in response_json:
+ logger.warning("Ollama fetch model info error for %s: %s", model_name, response_json["error"])
+ return None
+
+ return response_json["model_info"].get("embedding_length")
diff --git a/letta/schemas/providers/openai.py b/letta/schemas/providers/openai.py
new file mode 100644
index 00000000..2a3bc0b8
--- /dev/null
+++ b/letta/schemas/providers/openai.py
@@ -0,0 +1,241 @@
+from typing import Literal
+
+from pydantic import Field
+
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS
+from letta.log import get_logger
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.base import Provider
+
+logger = get_logger(__name__)
+
+ALLOWED_PREFIXES = {"gpt-4", "o1", "o3", "o4"}
+DISALLOWED_KEYWORDS = {"transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"}
+DEFAULT_EMBEDDING_BATCH_SIZE = 1024
+
+
+class OpenAIProvider(Provider):
+ provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ api_key: str = Field(..., description="API key for the OpenAI API.")
+ base_url: str = Field(..., description="Base URL for the OpenAI API.")
+
+ async def check_api_key(self):
+ from letta.llm_api.openai import openai_check_valid_api_key
+
+ openai_check_valid_api_key(self.base_url, self.api_key)
+
+ async def _get_models_async(self) -> list[dict]:
+ from letta.llm_api.openai import openai_get_model_list_async
+
+ # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
+ # See: https://openrouter.ai/docs/requests
+ extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
+
+ # Similar to Nebius
+ extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
+
+ response = await openai_get_model_list_async(
+ self.base_url,
+ api_key=self.api_key,
+ extra_params=extra_params,
+ # fix_url=True, # NOTE: make sure together ends with /v1
+ )
+
+ # TODO (cliandy): this is brittle as TogetherAI seems to result in a list instead of having a 'data' field
+ data = response.get("data", response)
+ assert isinstance(data, list)
+ return data
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ data = await self._get_models_async()
+ return self._list_llm_models(data)
+
+ def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
+ """
+ This handles filtering out LLM Models by provider that meet Letta's requirements.
+ """
+ configs = []
+ for model in data:
+ check = self._do_model_checks_for_name_and_context_size(model)
+ if check is None:
+ continue
+ model_name, context_window_size = check
+
+ # ===== Provider filtering =====
+ # TogetherAI: includes the type, which we can use to filter out embedding models
+ if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url:
+ if "type" in model and model["type"] not in ["chat", "language"]:
+ continue
+
+ # for TogetherAI, we need to skip the models that don't support JSON mode / function calling
+ # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: {
+ # "error": {
+ # "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling",
+ # "type": "invalid_request_error",
+ # "param": null,
+ # "code": "constraints_model"
+ # }
+ # }
+ if "config" not in model:
+ continue
+
+ # Nebius: includes the type, which we can use to filter for text models
+ if "nebius.com" in self.base_url:
+ model_type = model.get("architecture", {}).get("modality")
+ if model_type not in ["text->text", "text+image->text"]:
+ continue
+
+ # OpenAI
+ # NOTE: o1-mini and o1-preview do not support tool calling
+ # NOTE: o1-mini does not support system messages
+ # NOTE: o1-pro is only available in Responses API
+ if self.base_url == "https://api.openai.com/v1":
+ if any(keyword in model_name for keyword in DISALLOWED_KEYWORDS) or not any(
+ model_name.startswith(prefix) for prefix in ALLOWED_PREFIXES
+ ):
+ continue
+
+ # We'll set the model endpoint based on the base URL
+ # Note: openai-proxy just means that the model is using the OpenAIProvider
+ if self.base_url != "https://api.openai.com/v1":
+ handle = self.get_handle(model_name, base_name="openai-proxy")
+ else:
+ handle = self.get_handle(model_name)
+
+ config = LLMConfig(
+ model=model_name,
+ model_endpoint_type="openai",
+ model_endpoint=self.base_url,
+ context_window=context_window_size,
+ handle=handle,
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+
+ config = self._set_model_parameter_tuned_defaults(model_name, config)
+ configs.append(config)
+
+ # for OpenAI, sort in reverse order
+ if self.base_url == "https://api.openai.com/v1":
+ configs.sort(key=lambda x: x.model, reverse=True)
+ return configs
+
+ def _do_model_checks_for_name_and_context_size(self, model: dict, length_key: str = "context_length") -> tuple[str, int] | None:
+ if "id" not in model:
+ logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model)
+ return None
+
+ model_name = model["id"]
+ context_window_size = model.get(length_key) or self.get_model_context_window_size(model_name)
+
+ if not context_window_size:
+ logger.info("No context window size found for model: %s", model_name)
+ return None
+
+ return model_name, context_window_size
+
+ @staticmethod
+ def _set_model_parameter_tuned_defaults(model_name: str, llm_config: LLMConfig):
+ """This function is used to tune LLMConfig parameters to improve model performance."""
+
+ # gpt-4o-mini has started to regress with pretty bad emoji spam loops (2025-07)
+ if "gpt-4o-mini" in model_name or "gpt-4.1-mini" in model_name:
+ llm_config.frequency_penalty = 1.0
+ return llm_config
+
+ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+ if self.base_url == "https://api.openai.com/v1":
+ # TODO: actually automatically list models for OpenAI
+ return [
+ EmbeddingConfig(
+ embedding_model="text-embedding-ada-002",
+ embedding_endpoint_type="openai",
+ embedding_endpoint=self.base_url,
+ embedding_dim=1536,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
+ handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
+ batch_size=DEFAULT_EMBEDDING_BATCH_SIZE,
+ ),
+ EmbeddingConfig(
+ embedding_model="text-embedding-3-small",
+ embedding_endpoint_type="openai",
+ embedding_endpoint=self.base_url,
+ embedding_dim=2000,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
+ handle=self.get_handle("text-embedding-3-small", is_embedding=True),
+ batch_size=DEFAULT_EMBEDDING_BATCH_SIZE,
+ ),
+ EmbeddingConfig(
+ embedding_model="text-embedding-3-large",
+ embedding_endpoint_type="openai",
+ embedding_endpoint=self.base_url,
+ embedding_dim=2000,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
+ handle=self.get_handle("text-embedding-3-large", is_embedding=True),
+ batch_size=DEFAULT_EMBEDDING_BATCH_SIZE,
+ ),
+ ]
+ else:
+ # TODO: this has filtering that doesn't apply for embedding models, fix this.
+ data = await self._get_models_async()
+ return self._list_embedding_models(data)
+
+ def _list_embedding_models(self, data) -> list[EmbeddingConfig]:
+ configs = []
+ for model in data:
+ check = self._do_model_checks_for_name_and_context_size(model)
+ if check is None:
+ continue
+ model_name, context_window_size = check
+
+ # ===== Provider filtering =====
+ # TogetherAI: includes the type, which we can use to filter for embedding models
+ if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url:
+ if "type" in model and model["type"] not in ["embedding"]:
+ continue
+ # Nebius: includes the type, which we can use to filter for text models
+ elif "nebius.com" in self.base_url:
+ model_type = model.get("architecture", {}).get("modality")
+ if model_type not in ["text->embedding"]:
+ continue
+ else:
+ logger.info(
+ f"Skipping embedding models for %s by default, as we don't assume embeddings are supported."
+ "Please open an issue on GitHub if support is required.",
+ self.base_url,
+ )
+ continue
+
+ configs.append(
+ EmbeddingConfig(
+ embedding_model=model_name,
+ embedding_endpoint_type=self.provider_type,
+ embedding_endpoint=self.base_url,
+ embedding_dim=context_window_size,
+ embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
+ handle=self.get_handle(model, is_embedding=True),
+ )
+ )
+
+ return configs
+
+ def get_model_context_window_size(self, model_name: str) -> int | None:
+ if model_name in LLM_MAX_TOKENS:
+ return LLM_MAX_TOKENS[model_name]
+ else:
+ logger.debug(
+ f"Model %s on %s for provider %s not found in LLM_MAX_TOKENS. Using default of {{LLM_MAX_TOKENS['DEFAULT']}}",
+ model_name,
+ self.base_url,
+ self.__class__.__name__,
+ )
+ return LLM_MAX_TOKENS["DEFAULT"]
+
+ def get_model_context_window(self, model_name: str) -> int | None:
+ return self.get_model_context_window_size(model_name)
+
+ async def get_model_context_window_async(self, model_name: str) -> int | None:
+ return self.get_model_context_window_size(model_name)
diff --git a/letta/schemas/providers/together.py b/letta/schemas/providers/together.py
new file mode 100644
index 00000000..8a7a000e
--- /dev/null
+++ b/letta/schemas/providers/together.py
@@ -0,0 +1,85 @@
+"""
+Note: this supports completions (deprecated by openai) and chat completions via the OpenAI API.
+"""
+
+from typing import Literal
+
+from pydantic import Field
+
+from letta.constants import MIN_CONTEXT_WINDOW
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.openai import OpenAIProvider
+
+
+class TogetherProvider(OpenAIProvider):
+ provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ base_url: str = "https://api.together.xyz/v1"
+ api_key: str = Field(..., description="API key for the Together API.")
+ default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.openai import openai_get_model_list_async
+
+ models = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
+ return self._list_llm_models(models)
+
+ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+ import warnings
+
+ warnings.warn(
+ "Letta does not currently support listing embedding models for Together. Please "
+ "contact support or reach out via GitHub or Discord to get support."
+ )
+ return []
+
+ # TODO (cliandy): verify this with openai
+ def _list_llm_models(self, models) -> list[LLMConfig]:
+ pass
+
+ # TogetherAI's response is missing the 'data' field
+ # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+ if "data" in models:
+ data = models["data"]
+ else:
+ data = models
+
+ configs = []
+ for model in data:
+ assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
+ model_name = model["id"]
+
+ if "context_length" in model:
+ # Context length is returned in OpenRouter as "context_length"
+ context_window_size = model["context_length"]
+ else:
+ context_window_size = self.get_model_context_window_size(model_name)
+
+ # We need the context length for embeddings too
+ if not context_window_size:
+ continue
+
+ # Skip models that are too small for Letta
+ if context_window_size <= MIN_CONTEXT_WINDOW:
+ continue
+
+ # TogetherAI includes the type, which we can use to filter for embedding models
+ if "type" in model and model["type"] not in ["chat", "language"]:
+ continue
+
+ configs.append(
+ LLMConfig(
+ model=model_name,
+ model_endpoint_type="together",
+ model_endpoint=self.base_url,
+ model_wrapper=self.default_prompt_formatter,
+ context_window=context_window_size,
+ handle=self.get_handle(model_name),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+
+ return configs
diff --git a/letta/schemas/providers/vllm.py b/letta/schemas/providers/vllm.py
new file mode 100644
index 00000000..2f261c3e
--- /dev/null
+++ b/letta/schemas/providers/vllm.py
@@ -0,0 +1,57 @@
+"""
+Note: this consolidates the vLLM provider for completions (deprecated by openai)
+and chat completions. Support is provided primarily for the chat completions endpoint,
+but to utilize the completions endpoint, set the proper `base_url` and
+`default_prompt_formatter`.
+"""
+
+from typing import Literal
+
+from pydantic import Field
+
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.base import Provider
+
+
+class VLLMProvider(Provider):
+ provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ base_url: str = Field(..., description="Base URL for the vLLM API.")
+ api_key: str | None = Field(None, description="API key for the vLLM API.")
+ default_prompt_formatter: str | None = Field(
+ default=None, description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
+ )
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.openai import openai_get_model_list_async
+
+ # TODO (cliandy): previously unsupported with vLLM; confirm if this is still the case or not
+ response = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
+
+ data = response.get("data", response)
+
+ configs = []
+ for model in data:
+ model_name = model["id"]
+
+ configs.append(
+ LLMConfig(
+ model=model_name,
+ model_endpoint_type="openai", # TODO (cliandy): this was previous vllm for the completions provider, why?
+ model_endpoint=self.base_url,
+ model_wrapper=self.default_prompt_formatter,
+ context_window=model["max_model_len"],
+ handle=self.get_handle(model_name),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+
+ return configs
+
+ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+ # Note: vLLM technically can support embedding models though may require multiple instances
+ # for now, we will not support embedding models for vLLM.
+ return []
diff --git a/letta/schemas/providers/xai.py b/letta/schemas/providers/xai.py
new file mode 100644
index 00000000..d042aad0
--- /dev/null
+++ b/letta/schemas/providers/xai.py
@@ -0,0 +1,66 @@
+import warnings
+from typing import Literal
+
+from pydantic import Field
+
+from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.providers.openai import OpenAIProvider
+
+MODEL_CONTEXT_WINDOWS = {
+ "grok-3-fast": 131_072,
+ "grok-3": 131_072,
+ "grok-3-mini": 131_072,
+ "grok-3-mini-fast": 131_072,
+ "grok-4-0709": 256_000,
+}
+
+
+class XAIProvider(OpenAIProvider):
+ """https://docs.x.ai/docs/api-reference"""
+
+ provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
+ provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+ api_key: str = Field(..., description="API key for the xAI/Grok API.")
+ base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
+
+ def get_model_context_window_size(self, model_name: str) -> int | None:
+ # xAI doesn't return context window in the model listing,
+ # this is hardcoded from https://docs.x.ai/docs/models
+ return MODEL_CONTEXT_WINDOWS.get(model_name)
+
+ async def list_llm_models_async(self) -> list[LLMConfig]:
+ from letta.llm_api.openai import openai_get_model_list_async
+
+ response = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
+
+ data = response.get("data", response)
+
+ configs = []
+ for model in data:
+ assert "id" in model, f"xAI/Grok model missing 'id' field: {model}"
+ model_name = model["id"]
+
+ # In case xAI starts supporting it in the future:
+ if "context_length" in model:
+ context_window_size = model["context_length"]
+ else:
+ context_window_size = self.get_model_context_window_size(model_name)
+
+ if not context_window_size:
+ warnings.warn(f"Couldn't find context window size for model {model_name}")
+ continue
+
+ configs.append(
+ LLMConfig(
+ model=model_name,
+ model_endpoint_type="xai",
+ model_endpoint=self.base_url,
+ context_window=context_window_size,
+ handle=self.get_handle(model_name),
+ provider_name=self.name,
+ provider_category=self.provider_category,
+ )
+ )
+
+ return configs
diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py
index 3c1cbdfb..208613ec 100644
--- a/letta/server/rest_api/app.py
+++ b/letta/server/rest_api/app.py
@@ -407,9 +407,10 @@ def start_server(
address=host or "127.0.0.1", # Note granian address must be an ip address
port=port or REST_DEFAULT_PORT,
workers=settings.uvicorn_workers,
- # threads=
+ # runtime_blocking_threads=
+ # runtime_threads=
reload=reload or settings.uvicorn_reload,
- reload_ignore_patterns=["openapi_letta.json"],
+ reload_paths=["letta/"],
reload_ignore_worker_failure=True,
reload_tick=4000, # set to 4s to prevent crashing on weird state
# log_level="info"
@@ -451,7 +452,7 @@ def start_server(
# runtime_blocking_threads=
# runtime_threads=
reload=reload or settings.uvicorn_reload,
- reload_paths=["../letta/"],
+ reload_paths=["letta/"],
reload_ignore_worker_failure=True,
reload_tick=4000, # set to 4s to prevent crashing on weird state
# log_level="info"
diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py
index 66cb36ec..6f3f30d2 100644
--- a/letta/server/rest_api/routers/v1/providers.py
+++ b/letta/server/rest_api/routers/v1/providers.py
@@ -71,12 +71,12 @@ async def modify_provider(
@router.get("/check", response_model=None, operation_id="check_provider")
-def check_provider(
+async def check_provider(
request: ProviderCheck = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
try:
- server.provider_manager.check_provider_api_key(provider_check=request)
+ await server.provider_manager.check_provider_api_key(provider_check=request)
return JSONResponse(
status_code=status.HTTP_200_OK, content={"message": f"Valid api key for provider_type={request.provider_type.value}"}
)
diff --git a/letta/server/server.py b/letta/server/server.py
index fbc6db1e..cda9f8e0 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -68,8 +68,6 @@ from letta.schemas.providers import (
OpenAIProvider,
Provider,
TogetherProvider,
- VLLMChatCompletionsProvider,
- VLLMCompletionsProvider,
XAIProvider,
)
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxConfigCreate, SandboxType
@@ -360,7 +358,7 @@ class SyncServer(Server):
)
)
# NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
- # see: https://docs.vllm.ai/en/latest/getting_started/examples/openai_chat_completion_client_with_tools.html
+ # see: https://docs.vllm.ai/en/stable/features/tool_calling.html
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
self._enabled_providers.append(
VLLMChatCompletionsProvider(
@@ -460,7 +458,7 @@ class SyncServer(Server):
# Determine whether or not to token stream based on the capability of the interface
token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False
- logger.debug(f"Starting agent step")
+ logger.debug("Starting agent step")
if interface:
metadata = interface.metadata if hasattr(interface, "metadata") else None
else:
@@ -534,7 +532,7 @@ class SyncServer(Server):
letta_agent.interface.print_messages_raw(letta_agent.messages)
elif command.lower() == "memory":
- ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.passage_manager)}"
+ ret_str = "\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.passage_manager)}"
return ret_str
elif command.lower() == "pop" or command.lower().startswith("pop "):
@@ -554,7 +552,7 @@ class SyncServer(Server):
elif command.lower() == "retry":
# TODO this needs to also modify the persistence manager
- logger.debug(f"Retrying for another answer")
+ logger.debug("Retrying for another answer")
while len(letta_agent.messages) > 0:
if letta_agent.messages[-1].get("role") == "user":
# we want to pop up to the last user message and send it again
@@ -770,6 +768,7 @@ class SyncServer(Server):
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
return self._embedding_config_cache[key]
+ # @async_redis_cache(key_func=lambda (actor, **kwargs): actor.id + hash(kwargs))
@trace_method
async def get_cached_embedding_config_async(self, actor: User, **kwargs):
key = make_key(**kwargs)
@@ -782,9 +781,9 @@ class SyncServer(Server):
self,
request: CreateAgent,
actor: User,
- # interface
- interface: Union[AgentInterface, None] = None,
+ interface: AgentInterface | None = None,
) -> AgentState:
+ warnings.warn("This method is deprecated, use create_agent_async where possible.", DeprecationWarning, stacklevel=2)
if request.llm_config is None:
if request.model is None:
raise ValueError("Must specify either model or llm_config in request")
@@ -1320,7 +1319,6 @@ class SyncServer(Server):
# TODO: delete data from agent passage stores (?)
async def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
-
# update job
job = await self.job_manager.get_job_by_id_async(job_id, actor=actor)
job.status = JobStatus.running
@@ -1564,7 +1562,6 @@ class SyncServer(Server):
# Add extra metadata to the sources
sources_with_metadata = []
for source in sources:
-
# count number of passages
num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id)
@@ -2118,7 +2115,6 @@ class SyncServer(Server):
mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
if os.path.exists(mcp_config_path):
with open(mcp_config_path, "r") as f:
-
try:
mcp_config = json.load(f)
except Exception as e:
@@ -2130,7 +2126,6 @@ class SyncServer(Server):
# with the value being the schema from StdioServerParameters
if MCP_CONFIG_TOPLEVEL_KEY in mcp_config:
for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items():
-
# No support for duplicate server names
if server_name in mcp_server_list:
logger.error(f"Duplicate MCP server name found (skipping): {server_name}")
@@ -2301,7 +2296,6 @@ class SyncServer(Server):
# For streaming response
try:
-
# TODO: move this logic into server.py
# Get the generator object off of the agent's streaming interface
@@ -2441,9 +2435,9 @@ class SyncServer(Server):
if not stream_steps and stream_tokens:
raise ValueError("stream_steps must be 'true' if stream_tokens is 'true'")
- group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
+ group = await self.group_manager.retrieve_group_async(group_id=group_id, actor=actor)
agent_state_id = group.manager_agent_id or (group.agent_ids[0] if len(group.agent_ids) > 0 else None)
- agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_state_id, actor=actor) if agent_state_id else None
+ agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_state_id, actor=actor) if agent_state_id else None
letta_multi_agent = load_multi_agent(group=group, agent_state=agent_state, actor=actor)
llm_config = letta_multi_agent.agent_state.llm_config
diff --git a/letta/services/file_processor/file_processor.py b/letta/services/file_processor/file_processor.py
index f972eb6a..60de7c71 100644
--- a/letta/services/file_processor/file_processor.py
+++ b/letta/services/file_processor/file_processor.py
@@ -58,18 +58,30 @@ class FileProcessor:
for page in ocr_response.pages:
chunks = text_chunker.chunk_text(page)
if not chunks:
- log_event("file_processor.chunking_failed", {"filename": filename, "page_index": ocr_response.pages.index(page)})
+ log_event(
+ "file_processor.chunking_failed",
+ {
+ "filename": filename,
+ "page_index": ocr_response.pages.index(page),
+ },
+ )
raise ValueError("No chunks created from text")
all_chunks.extend(chunks)
all_passages = await self.embedder.generate_embedded_passages(
- file_id=file_metadata.id, source_id=source_id, chunks=all_chunks, actor=self.actor
+ file_id=file_metadata.id,
+ source_id=source_id,
+ chunks=all_chunks,
+ actor=self.actor,
)
return all_passages
except Exception as e:
logger.warning(f"Failed to chunk/embed with file-specific chunker for {filename}: {str(e)}. Retrying with default chunker.")
- log_event("file_processor.embedding_failed_retrying", {"filename": filename, "error": str(e), "error_type": type(e).__name__})
+ log_event(
+ "file_processor.embedding_failed_retrying",
+ {"filename": filename, "error": str(e), "error_type": type(e).__name__},
+ )
# Retry with default chunker
try:
@@ -80,31 +92,50 @@ class FileProcessor:
chunks = text_chunker.default_chunk_text(page)
if not chunks:
log_event(
- "file_processor.default_chunking_failed", {"filename": filename, "page_index": ocr_response.pages.index(page)}
+ "file_processor.default_chunking_failed",
+ {
+ "filename": filename,
+ "page_index": ocr_response.pages.index(page),
+ },
)
raise ValueError("No chunks created from text with default chunker")
all_chunks.extend(chunks)
all_passages = await self.embedder.generate_embedded_passages(
- file_id=file_metadata.id, source_id=source_id, chunks=all_chunks, actor=self.actor
+ file_id=file_metadata.id,
+ source_id=source_id,
+ chunks=all_chunks,
+ actor=self.actor,
)
logger.info(f"Successfully generated passages with default chunker for {filename}")
- log_event("file_processor.default_chunking_success", {"filename": filename, "total_chunks": len(all_chunks)})
+ log_event(
+ "file_processor.default_chunking_success",
+ {"filename": filename, "total_chunks": len(all_chunks)},
+ )
return all_passages
except Exception as fallback_error:
logger.error("Default chunking also failed for %s: %s", filename, fallback_error)
log_event(
"file_processor.default_chunking_also_failed",
- {"filename": filename, "fallback_error": str(fallback_error), "fallback_error_type": type(fallback_error).__name__},
+ {
+ "filename": filename,
+ "fallback_error": str(fallback_error),
+ "fallback_error_type": type(fallback_error).__name__,
+ },
)
raise fallback_error
# TODO: Factor this function out of SyncServer
@trace_method
async def process(
- self, server: SyncServer, agent_states: List[AgentState], source_id: str, content: bytes, file_metadata: FileMetadata
- ) -> List[Passage]:
+ self,
+ server: SyncServer,
+ agent_states: list[AgentState],
+ source_id: str,
+ content: bytes,
+ file_metadata: FileMetadata,
+ ) -> list[Passage]:
filename = file_metadata.file_name
# Create file as early as possible with no content
@@ -170,18 +201,28 @@ class FileProcessor:
raise ValueError("No text extracted from PDF")
logger.info("Chunking extracted text")
- log_event("file_processor.chunking_started", {"filename": filename, "pages_to_process": len(ocr_response.pages)})
+ log_event(
+ "file_processor.chunking_started",
+ {"filename": filename, "pages_to_process": len(ocr_response.pages)},
+ )
# Chunk and embed with fallback logic
all_passages = await self._chunk_and_embed_with_fallback(
- file_metadata=file_metadata, ocr_response=ocr_response, source_id=source_id
+ file_metadata=file_metadata,
+ ocr_response=ocr_response,
+ source_id=source_id,
)
if not self.using_pinecone:
all_passages = await self.passage_manager.create_many_source_passages_async(
- passages=all_passages, file_metadata=file_metadata, actor=self.actor
+ passages=all_passages,
+ file_metadata=file_metadata,
+ actor=self.actor,
+ )
+ log_event(
+ "file_processor.passages_created",
+ {"filename": filename, "total_passages": len(all_passages)},
)
- log_event("file_processor.passages_created", {"filename": filename, "total_passages": len(all_passages)})
logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")
log_event(
@@ -197,11 +238,16 @@ class FileProcessor:
# update job status
if not self.using_pinecone:
await self.file_manager.update_file_status(
- file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED
+ file_id=file_metadata.id,
+ actor=self.actor,
+ processing_status=FileProcessingStatus.COMPLETED,
)
else:
await self.file_manager.update_file_status(
- file_id=file_metadata.id, actor=self.actor, total_chunks=len(all_passages), chunks_embedded=0
+ file_id=file_metadata.id,
+ actor=self.actor,
+ total_chunks=len(all_passages),
+ chunks_embedded=0,
)
return all_passages
@@ -310,7 +356,10 @@ class FileProcessor:
},
)
await self.file_manager.update_file_status(
- file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.ERROR, error_message=str(e)
+ file_id=file_metadata.id,
+ actor=self.actor,
+ processing_status=FileProcessingStatus.ERROR,
+ error_message=str(e),
)
return []
diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py
index 96684271..91e15d50 100644
--- a/letta/services/group_manager.py
+++ b/letta/services/group_manager.py
@@ -220,6 +220,13 @@ class GroupManager:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
group.hard_delete(session)
+ @enforce_types
+ @trace_method
+ async def delete_group_async(self, group_id: str, actor: PydanticUser) -> None:
+ async with db_registry.async_session() as session:
+ group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
+ await group.hard_delete_async(session)
+
@enforce_types
@trace_method
def list_group_messages(
diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py
index 610ffb2e..675ff763 100644
--- a/letta/services/provider_manager.py
+++ b/letta/services/provider_manager.py
@@ -207,7 +207,7 @@ class ProviderManager:
@enforce_types
@trace_method
- def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
+ async def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
provider = PydanticProvider(
name=provider_check.provider_type.value,
provider_type=provider_check.provider_type,
@@ -221,4 +221,4 @@ class ProviderManager:
if not provider.api_key:
raise ValueError("API key is required")
- provider.check_api_key()
+ await provider.check_api_key()
diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py
index d3528597..d1f599aa 100644
--- a/tests/integration_test_async_tool_sandbox.py
+++ b/tests/integration_test_async_tool_sandbox.py
@@ -1,4 +1,4 @@
-import asyncio
+import os
import secrets
import string
import uuid
@@ -18,6 +18,7 @@ from letta.schemas.organization import Organization
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate
from letta.schemas.user import User
+from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
@@ -32,6 +33,48 @@ namespace = uuid.NAMESPACE_DNS
org_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-org"))
user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user"))
+# Set environment variable immediately to prevent pooling issues
+os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true"
+
+# Recreate settings instance to pick up the environment variable
+import letta.settings
+
+# Force settings reload after setting environment variable
+from letta.settings import Settings
+
+letta.settings.settings = Settings()
+
+
+# Disable SQLAlchemy connection pooling for tests to prevent event loop issues
+@pytest.fixture(scope="session", autouse=True)
+def disable_db_pooling_for_tests():
+ """Disable database connection pooling for the entire test session."""
+ # Environment variable is already set above and settings reloaded
+ yield
+ # Clean up environment variable after tests
+ if "LETTA_DISABLE_SQLALCHEMY_POOLING" in os.environ:
+ del os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"]
+
+
+@pytest.fixture(autouse=True)
+async def cleanup_db_connections():
+ """Cleanup database connections after each test."""
+ yield
+
+ # Dispose async engines in the current event loop
+ try:
+ if hasattr(db_registry, "_async_engines"):
+ for engine in db_registry._async_engines.values():
+ if engine:
+ await engine.dispose()
+ # Reset async initialization to force fresh connections
+ db_registry._initialized["async"] = False
+ db_registry._async_engines.clear()
+ db_registry._async_session_factories.clear()
+ except Exception as e:
+ # Log the error but don't fail the test
+ print(f"Warning: Failed to cleanup database connections: {e}")
+
# Fixtures
@pytest.fixture(scope="module")
@@ -50,14 +93,14 @@ def server():
@pytest.fixture(autouse=True)
-def clear_tables():
+async def clear_tables():
"""Fixture to clear the organization table before each test."""
- from letta.server.db import db_context
+ from letta.server.db import db_registry
- with db_context() as session:
- session.execute(delete(SandboxEnvironmentVariable))
- session.execute(delete(SandboxConfig))
- session.commit() # Commit the deletion
+ async with db_registry.async_session() as session:
+ await session.execute(delete(SandboxEnvironmentVariable))
+ await session.execute(delete(SandboxConfig))
+ await session.commit() # Commit the deletion
@pytest.fixture
@@ -208,9 +251,9 @@ def external_codebase_tool(test_user):
@pytest.fixture
-def agent_state(server):
+async def agent_state(server):
actor = server.user_manager.get_user_or_default()
- agent_state = server.create_agent(
+ agent_state = await server.create_agent_async(
CreateAgent(
memory_blocks=[
CreateBlock(
@@ -471,12 +514,7 @@ def async_complex_tool(test_user):
yield tool
-@pytest.fixture(scope="session")
-def event_loop(request):
- """Create an instance of the default event loop for each test case."""
- loop = asyncio.get_event_loop_policy().new_event_loop()
- yield loop
- loop.close()
+# Removed custom event_loop fixture to avoid conflicts with pytest-asyncio
# Local sandbox tests
@@ -484,7 +522,7 @@ def event_loop(request):
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, test_user, event_loop):
+async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, test_user):
args = {"x": 10, "y": 5}
# Mock and assert correct pathway was invoked
@@ -501,7 +539,7 @@ async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, tes
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memory_tool, test_user, agent_state, event_loop):
+async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memory_tool, test_user, agent_state):
args = {}
sandbox = AsyncToolSandboxLocal(clear_core_memory_tool.name, args, user=test_user)
result = await sandbox.run(agent_state=agent_state)
@@ -513,7 +551,7 @@ async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memor
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user, event_loop):
+async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user):
sandbox = AsyncToolSandboxLocal(list_tool.name, {}, user=test_user)
result = await sandbox.run()
assert len(result.func_return) == 5
@@ -521,7 +559,7 @@ async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_u
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user, event_loop):
+async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user):
manager = SandboxConfigManager()
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir).model_dump())
@@ -540,7 +578,7 @@ async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user, e
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, agent_state, test_user, event_loop):
+async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, agent_state, test_user):
manager = SandboxConfigManager()
key = "secret_word"
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
@@ -562,7 +600,7 @@ async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, ag
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_external_codebase_with_venv(
- disable_e2b_api_key, custom_test_sandbox_config, external_codebase_tool, test_user, event_loop
+ disable_e2b_api_key, custom_test_sandbox_config, external_codebase_tool, test_user
):
args = {"percentage": 10}
sandbox = AsyncToolSandboxLocal(external_codebase_tool.name, args, user=test_user)
@@ -574,7 +612,7 @@ async def test_local_sandbox_external_codebase_with_venv(
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_with_venv_and_warnings_does_not_error(
- disable_e2b_api_key, custom_test_sandbox_config, get_warning_tool, test_user, event_loop
+ disable_e2b_api_key, custom_test_sandbox_config, get_warning_tool, test_user
):
sandbox = AsyncToolSandboxLocal(get_warning_tool.name, {}, user=test_user)
result = await sandbox.run()
@@ -583,7 +621,7 @@ async def test_local_sandbox_with_venv_and_warnings_does_not_error(
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user, event_loop):
+async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user):
sandbox = AsyncToolSandboxLocal(always_err_tool.name, {}, user=test_user)
result = await sandbox.run()
assert len(result.stdout) != 0
@@ -594,7 +632,7 @@ async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_s
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, cowsay_tool, test_user, event_loop):
+async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, cowsay_tool, test_user):
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(
config=LocalSandboxConfig(use_venv=True, pip_requirements=[PipRequirement(name="cowsay")]).model_dump()
@@ -614,7 +652,7 @@ async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, c
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_with_tool_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop):
+async def test_local_sandbox_with_tool_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user):
"""Test that local sandbox installs tool-specific pip requirements."""
manager = SandboxConfigManager()
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
@@ -634,7 +672,7 @@ async def test_local_sandbox_with_tool_pip_requirements(disable_e2b_api_key, too
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_with_mixed_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop):
+async def test_local_sandbox_with_mixed_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user):
"""Test that local sandbox installs both sandbox and tool pip requirements."""
manager = SandboxConfigManager()
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
@@ -658,7 +696,7 @@ async def test_local_sandbox_with_mixed_pip_requirements(disable_e2b_api_key, to
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user, event_loop):
+async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user):
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=LocalSandboxConfig(use_venv=True).model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
@@ -689,7 +727,7 @@ async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user, event_loop):
+async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user):
args = {"x": 10, "y": 5}
# Mock and assert correct pathway was invoked
@@ -706,7 +744,7 @@ async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user, event_loop):
+async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user):
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=E2BSandboxConfig(pip_requirements=["cowsay"]).model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
@@ -726,7 +764,7 @@ async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state, event_loop):
+async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state):
sandbox = AsyncToolSandboxE2B(clear_core_memory_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert result.agent_state.memory.get_block("human").value == ""
@@ -736,7 +774,7 @@ async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user, event_loop):
+async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user):
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
@@ -760,7 +798,7 @@ async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set,
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_state, test_user, event_loop):
+async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_state, test_user):
manager = SandboxConfigManager()
key = "secret_word"
wrong_val = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20))
@@ -784,7 +822,7 @@ async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, age
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user, event_loop):
+async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user):
sandbox = AsyncToolSandboxE2B(list_tool.name, {}, user=test_user)
result = await sandbox.run()
assert len(result.func_return) == 5
@@ -792,7 +830,7 @@ async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_us
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_with_tool_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop):
+async def test_e2b_sandbox_with_tool_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user):
"""Test that E2B sandbox installs tool-specific pip requirements."""
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
@@ -809,7 +847,7 @@ async def test_e2b_sandbox_with_tool_pip_requirements(check_e2b_key_is_set, tool
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop):
+async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user):
"""Test that E2B sandbox installs both sandbox and tool pip requirements."""
manager = SandboxConfigManager()
@@ -829,7 +867,7 @@ async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, too
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_with_broken_tool_pip_requirements_error_handling(
- check_e2b_key_is_set, tool_with_broken_pip_requirements, test_user, event_loop
+ check_e2b_key_is_set, tool_with_broken_pip_requirements, test_user
):
"""Test that E2B sandbox provides informative error messages for broken tool pip requirements."""
manager = SandboxConfigManager()
@@ -894,7 +932,7 @@ def test_async_template_selection(add_integers_tool, async_add_integers_tool, te
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_async_function_execution(disable_e2b_api_key, async_add_integers_tool, test_user, event_loop):
+async def test_local_sandbox_async_function_execution(disable_e2b_api_key, async_add_integers_tool, test_user):
"""Test that async functions execute correctly in local sandbox"""
args = {"x": 15, "y": 25}
@@ -905,7 +943,7 @@ async def test_local_sandbox_async_function_execution(disable_e2b_api_key, async
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_async_function_execution(check_e2b_key_is_set, async_add_integers_tool, test_user, event_loop):
+async def test_e2b_sandbox_async_function_execution(check_e2b_key_is_set, async_add_integers_tool, test_user):
"""Test that async functions execute correctly in E2B sandbox"""
args = {"x": 20, "y": 30}
@@ -916,7 +954,7 @@ async def test_e2b_sandbox_async_function_execution(check_e2b_key_is_set, async_
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_async_complex_computation(disable_e2b_api_key, async_complex_tool, test_user, event_loop):
+async def test_local_sandbox_async_complex_computation(disable_e2b_api_key, async_complex_tool, test_user):
"""Test complex async computation with multiple awaits in local sandbox"""
args = {"iterations": 2}
@@ -932,7 +970,7 @@ async def test_local_sandbox_async_complex_computation(disable_e2b_api_key, asyn
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_async_complex_computation(check_e2b_key_is_set, async_complex_tool, test_user, event_loop):
+async def test_e2b_sandbox_async_complex_computation(check_e2b_key_is_set, async_complex_tool, test_user):
"""Test complex async computation with multiple awaits in E2B sandbox"""
args = {"iterations": 2}
@@ -949,7 +987,7 @@ async def test_e2b_sandbox_async_complex_computation(check_e2b_key_is_set, async
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_async_list_return(disable_e2b_api_key, async_list_tool, test_user, event_loop):
+async def test_local_sandbox_async_list_return(disable_e2b_api_key, async_list_tool, test_user):
"""Test async function returning list in local sandbox"""
sandbox = AsyncToolSandboxLocal(async_list_tool.name, {}, user=test_user)
result = await sandbox.run()
@@ -958,7 +996,7 @@ async def test_local_sandbox_async_list_return(disable_e2b_api_key, async_list_t
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_async_list_return(check_e2b_key_is_set, async_list_tool, test_user, event_loop):
+async def test_e2b_sandbox_async_list_return(check_e2b_key_is_set, async_list_tool, test_user):
"""Test async function returning list in E2B sandbox"""
sandbox = AsyncToolSandboxE2B(async_list_tool.name, {}, user=test_user)
result = await sandbox.run()
@@ -967,7 +1005,7 @@ async def test_e2b_sandbox_async_list_return(check_e2b_key_is_set, async_list_to
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_async_with_env_vars(disable_e2b_api_key, async_get_env_tool, test_user, event_loop):
+async def test_local_sandbox_async_with_env_vars(disable_e2b_api_key, async_get_env_tool, test_user):
"""Test async function with environment variables in local sandbox"""
manager = SandboxConfigManager()
@@ -991,7 +1029,7 @@ async def test_local_sandbox_async_with_env_vars(disable_e2b_api_key, async_get_
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_async_with_env_vars(check_e2b_key_is_set, async_get_env_tool, test_user, event_loop):
+async def test_e2b_sandbox_async_with_env_vars(check_e2b_key_is_set, async_get_env_tool, test_user):
"""Test async function with environment variables in E2B sandbox"""
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
@@ -1012,7 +1050,7 @@ async def test_e2b_sandbox_async_with_env_vars(check_e2b_key_is_set, async_get_e
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_async_with_agent_state(disable_e2b_api_key, async_stateful_tool, test_user, agent_state, event_loop):
+async def test_local_sandbox_async_with_agent_state(disable_e2b_api_key, async_stateful_tool, test_user, agent_state):
"""Test async function with agent state in local sandbox"""
sandbox = AsyncToolSandboxLocal(async_stateful_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
@@ -1025,7 +1063,7 @@ async def test_local_sandbox_async_with_agent_state(disable_e2b_api_key, async_s
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_async_with_agent_state(check_e2b_key_is_set, async_stateful_tool, test_user, agent_state, event_loop):
+async def test_e2b_sandbox_async_with_agent_state(check_e2b_key_is_set, async_stateful_tool, test_user, agent_state):
"""Test async function with agent state in E2B sandbox"""
sandbox = AsyncToolSandboxE2B(async_stateful_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
@@ -1037,7 +1075,7 @@ async def test_e2b_sandbox_async_with_agent_state(check_e2b_key_is_set, async_st
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_async_error_handling(disable_e2b_api_key, async_error_tool, test_user, event_loop):
+async def test_local_sandbox_async_error_handling(disable_e2b_api_key, async_error_tool, test_user):
"""Test async function error handling in local sandbox"""
sandbox = AsyncToolSandboxLocal(async_error_tool.name, {}, user=test_user)
result = await sandbox.run()
@@ -1051,7 +1089,7 @@ async def test_local_sandbox_async_error_handling(disable_e2b_api_key, async_err
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_async_error_handling(check_e2b_key_is_set, async_error_tool, test_user, event_loop):
+async def test_e2b_sandbox_async_error_handling(check_e2b_key_is_set, async_error_tool, test_user):
"""Test async function error handling in E2B sandbox"""
sandbox = AsyncToolSandboxE2B(async_error_tool.name, {}, user=test_user)
result = await sandbox.run()
@@ -1065,7 +1103,7 @@ async def test_e2b_sandbox_async_error_handling(check_e2b_key_is_set, async_erro
@pytest.mark.asyncio
@pytest.mark.local_sandbox
-async def test_local_sandbox_async_per_agent_env(disable_e2b_api_key, async_get_env_tool, agent_state, test_user, event_loop):
+async def test_local_sandbox_async_per_agent_env(disable_e2b_api_key, async_get_env_tool, agent_state, test_user):
"""Test async function with per-agent environment variables in local sandbox"""
manager = SandboxConfigManager()
key = "secret_word"
@@ -1087,7 +1125,7 @@ async def test_local_sandbox_async_per_agent_env(disable_e2b_api_key, async_get_
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
-async def test_e2b_sandbox_async_per_agent_env(check_e2b_key_is_set, async_get_env_tool, agent_state, test_user, event_loop):
+async def test_e2b_sandbox_async_per_agent_env(check_e2b_key_is_set, async_get_env_tool, agent_state, test_user):
"""Test async function with per-agent environment variables in E2B sandbox"""
manager = SandboxConfigManager()
key = "secret_word"
diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py
index 2dda36a9..b4fa4eb9 100644
--- a/tests/test_multi_agent.py
+++ b/tests/test_multi_agent.py
@@ -1,8 +1,8 @@
+import os
+
import pytest
-from sqlalchemy import delete
from letta.config import LettaConfig
-from letta.orm import Provider, ProviderTrace, Step
from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.group import (
@@ -19,6 +19,37 @@ from letta.server.db import db_registry
from letta.server.server import SyncServer
+# Disable SQLAlchemy connection pooling for tests to prevent event loop issues
+@pytest.fixture(scope="session", autouse=True)
+def disable_db_pooling_for_tests():
+ """Disable database connection pooling for the entire test session."""
+ os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true"
+ yield
+ # Clean up environment variable after tests
+ if "LETTA_DISABLE_SQLALCHEMY_POOLING" in os.environ:
+ del os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"]
+
+
+@pytest.fixture(autouse=True)
+async def cleanup_db_connections():
+ """Cleanup database connections after each test."""
+ yield
+
+ # Dispose async engines in the current event loop
+ try:
+ if hasattr(db_registry, "_async_engines"):
+ for engine in db_registry._async_engines.values():
+ if engine:
+ await engine.dispose()
+ # Reset async initialization to force fresh connections
+ db_registry._initialized["async"] = False
+ db_registry._async_engines.clear()
+ db_registry._async_session_factories.clear()
+ except Exception as e:
+ # Log the error but don't fail the test
+ print(f"Warning: Failed to cleanup database connections: {e}")
+
+
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
@@ -30,33 +61,21 @@ def server():
return server
-@pytest.fixture(scope="module")
-def org_id(server):
- org = server.organization_manager.create_default_organization()
-
- yield org.id
-
- # cleanup
- with db_registry.session() as session:
- session.execute(delete(ProviderTrace))
- session.execute(delete(Step))
- session.execute(delete(Provider))
- session.commit()
- server.organization_manager.delete_organization_by_id(org.id)
+@pytest.fixture
+async def default_organization(server: SyncServer):
+ """Fixture to create and return the default organization."""
+ yield await server.organization_manager.create_default_organization_async()
-@pytest.fixture(scope="module")
-def actor(server, org_id):
- user = server.user_manager.create_default_user()
- yield user
-
- # cleanup
- server.user_manager.delete_user_by_id(user.id)
+@pytest.fixture
+async def default_user(server: SyncServer, default_organization):
+ """Fixture to create and return the default user within the default organization."""
+ yield await server.user_manager.create_default_actor_async(org_id=default_organization.id)
-@pytest.fixture(scope="module")
-def participant_agents(server, actor):
- agent_fred = server.create_agent(
+@pytest.fixture
+async def four_participant_agents(server, default_user):
+ agent_fred = await server.create_agent_async(
request=CreateAgent(
name="fred",
memory_blocks=[
@@ -68,9 +87,9 @@ def participant_agents(server, actor):
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
),
- actor=actor,
+ actor=default_user,
)
- agent_velma = server.create_agent(
+ agent_velma = await server.create_agent_async(
request=CreateAgent(
name="velma",
memory_blocks=[
@@ -82,9 +101,9 @@ def participant_agents(server, actor):
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
),
- actor=actor,
+ actor=default_user,
)
- agent_daphne = server.create_agent(
+ agent_daphne = await server.create_agent_async(
request=CreateAgent(
name="daphne",
memory_blocks=[
@@ -96,9 +115,9 @@ def participant_agents(server, actor):
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
),
- actor=actor,
+ actor=default_user,
)
- agent_shaggy = server.create_agent(
+ agent_shaggy = await server.create_agent_async(
request=CreateAgent(
name="shaggy",
memory_blocks=[
@@ -110,20 +129,14 @@ def participant_agents(server, actor):
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
),
- actor=actor,
+ actor=default_user,
)
yield [agent_fred, agent_velma, agent_daphne, agent_shaggy]
- # cleanup
- server.agent_manager.delete_agent(agent_fred.id, actor=actor)
- server.agent_manager.delete_agent(agent_velma.id, actor=actor)
- server.agent_manager.delete_agent(agent_daphne.id, actor=actor)
- server.agent_manager.delete_agent(agent_shaggy.id, actor=actor)
-
-@pytest.fixture(scope="module")
-def manager_agent(server, actor):
- agent_scooby = server.create_agent(
+@pytest.fixture
+async def manager_agent(server, default_user):
+ agent_scooby = await server.create_agent_async(
request=CreateAgent(
name="scooby",
memory_blocks=[
@@ -139,27 +152,24 @@ def manager_agent(server, actor):
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
),
- actor=actor,
+ actor=default_user,
)
yield agent_scooby
- # cleanup
- server.agent_manager.delete_agent(agent_scooby.id, actor=actor)
-
-@pytest.mark.asyncio(loop_scope="module")
-async def test_empty_group(server, actor):
- group = server.group_manager.create_group(
+@pytest.mark.asyncio
+async def test_empty_group(server, default_user):
+ group = await server.group_manager.create_group_async(
group=GroupCreate(
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
agent_ids=[],
),
- actor=actor,
+ actor=default_user,
)
with pytest.raises(ValueError, match="Empty group"):
await server.send_group_message_to_agent(
group_id=group.id,
- actor=actor,
+ actor=default_user,
input_messages=[
MessageCreate(
role="user",
@@ -169,17 +179,17 @@ async def test_empty_group(server, actor):
stream_steps=False,
stream_tokens=False,
)
- server.group_manager.delete_group(group_id=group.id, actor=actor)
+ await server.group_manager.delete_group_async(group_id=group.id, actor=default_user)
-@pytest.mark.asyncio(loop_scope="module")
-async def test_modify_group_pattern(server, actor, participant_agents, manager_agent):
- group = server.group_manager.create_group(
+@pytest.mark.asyncio
+async def test_modify_group_pattern(server, default_user, four_participant_agents, manager_agent):
+ group = await server.group_manager.create_group_async(
group=GroupCreate(
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
- agent_ids=[agent.id for agent in participant_agents],
+ agent_ids=[agent.id for agent in four_participant_agents],
),
- actor=actor,
+ actor=default_user,
)
with pytest.raises(ValueError, match="Cannot change group pattern"):
await server.group_manager.modify_group_async(
@@ -190,64 +200,64 @@ async def test_modify_group_pattern(server, actor, participant_agents, manager_a
manager_agent_id=manager_agent.id,
),
),
- actor=actor,
+ actor=default_user,
)
- server.group_manager.delete_group(group_id=group.id, actor=actor)
+ await server.group_manager.delete_group_async(group_id=group.id, actor=default_user)
-@pytest.mark.asyncio(loop_scope="module")
-async def test_list_agent_groups(server, actor, participant_agents):
- group_a = server.group_manager.create_group(
+@pytest.mark.asyncio
+async def test_list_agent_groups(server, default_user, four_participant_agents):
+ group_a = await server.group_manager.create_group_async(
group=GroupCreate(
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
- agent_ids=[agent.id for agent in participant_agents],
+ agent_ids=[agent.id for agent in four_participant_agents],
),
- actor=actor,
+ actor=default_user,
)
- group_b = server.group_manager.create_group(
+ group_b = await server.group_manager.create_group_async(
group=GroupCreate(
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
- agent_ids=[participant_agents[0].id],
+ agent_ids=[four_participant_agents[0].id],
),
- actor=actor,
+ actor=default_user,
)
- agent_a_groups = server.agent_manager.list_groups(agent_id=participant_agents[0].id, actor=actor)
+ agent_a_groups = server.agent_manager.list_groups(agent_id=four_participant_agents[0].id, actor=default_user)
assert sorted([group.id for group in agent_a_groups]) == sorted([group_a.id, group_b.id])
- agent_b_groups = server.agent_manager.list_groups(agent_id=participant_agents[1].id, actor=actor)
+ agent_b_groups = server.agent_manager.list_groups(agent_id=four_participant_agents[1].id, actor=default_user)
assert [group.id for group in agent_b_groups] == [group_a.id]
- server.group_manager.delete_group(group_id=group_a.id, actor=actor)
- server.group_manager.delete_group(group_id=group_b.id, actor=actor)
+ await server.group_manager.delete_group_async(group_id=group_a.id, actor=default_user)
+ await server.group_manager.delete_group_async(group_id=group_b.id, actor=default_user)
-@pytest.mark.asyncio(loop_scope="module")
-async def test_round_robin(server, actor, participant_agents):
+@pytest.mark.asyncio
+async def test_round_robin(server, default_user, four_participant_agents):
description = (
"This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries."
)
- group = server.group_manager.create_group(
+ group = await server.group_manager.create_group_async(
group=GroupCreate(
description=description,
- agent_ids=[agent.id for agent in participant_agents],
+ agent_ids=[agent.id for agent in four_participant_agents],
),
- actor=actor,
+ actor=default_user,
)
# verify group creation
assert group.manager_type == ManagerType.round_robin
assert group.description == description
- assert group.agent_ids == [agent.id for agent in participant_agents]
+ assert group.agent_ids == [agent.id for agent in four_participant_agents]
assert group.max_turns is None
assert group.manager_agent_id is None
assert group.termination_token is None
try:
- server.group_manager.reset_messages(group_id=group.id, actor=actor)
+ server.group_manager.reset_messages(group_id=group.id, actor=default_user)
response = await server.send_group_message_to_agent(
group_id=group.id,
- actor=actor,
+ actor=default_user,
input_messages=[
MessageCreate(
role="user",
@@ -261,11 +271,11 @@ async def test_round_robin(server, actor, participant_agents):
assert len(response.messages) == response.usage.step_count * 2
for i, message in enumerate(response.messages):
assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message"
- assert message.name == participant_agents[i // 2].name
+ assert message.name == four_participant_agents[i // 2].name
for agent_id in group.agent_ids:
agent_messages = server.get_agent_recall(
- user_id=actor.id,
+ user_id=default_user.id,
agent_id=agent_id,
group_id=group.id,
reverse=True,
@@ -276,7 +286,7 @@ async def test_round_robin(server, actor, participant_agents):
# TODO: filter this to return a clean conversation history
messages = server.group_manager.list_group_messages(
group_id=group.id,
- actor=actor,
+ actor=default_user,
)
assert len(messages) == (len(group.agent_ids) + 2) * len(group.agent_ids)
@@ -284,25 +294,25 @@ async def test_round_robin(server, actor, participant_agents):
group = await server.group_manager.modify_group_async(
group_id=group.id,
group_update=GroupUpdate(
- agent_ids=[agent.id for agent in participant_agents][::-1],
+ agent_ids=[agent.id for agent in four_participant_agents][::-1],
manager_config=RoundRobinManagerUpdate(
max_turns=max_turns,
),
),
- actor=actor,
+ actor=default_user,
)
assert group.manager_type == ManagerType.round_robin
assert group.description == description
- assert group.agent_ids == [agent.id for agent in participant_agents][::-1]
+ assert group.agent_ids == [agent.id for agent in four_participant_agents][::-1]
assert group.max_turns == max_turns
assert group.manager_agent_id is None
assert group.termination_token is None
- server.group_manager.reset_messages(group_id=group.id, actor=actor)
+ server.group_manager.reset_messages(group_id=group.id, actor=default_user)
response = await server.send_group_message_to_agent(
group_id=group.id,
- actor=actor,
+ actor=default_user,
input_messages=[
MessageCreate(
role="user",
@@ -317,11 +327,11 @@ async def test_round_robin(server, actor, participant_agents):
for i, message in enumerate(response.messages):
assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message"
- assert message.name == participant_agents[::-1][i // 2].name
+ assert message.name == four_participant_agents[::-1][i // 2].name
for i in range(len(group.agent_ids)):
agent_messages = server.get_agent_recall(
- user_id=actor.id,
+ user_id=default_user.id,
agent_id=group.agent_ids[i],
group_id=group.id,
reverse=True,
@@ -331,12 +341,12 @@ async def test_round_robin(server, actor, participant_agents):
assert len(agent_messages) == expected_message_count
finally:
- server.group_manager.delete_group(group_id=group.id, actor=actor)
+ await server.group_manager.delete_group_async(group_id=group.id, actor=default_user)
-@pytest.mark.asyncio(loop_scope="module")
-async def test_supervisor(server, actor, participant_agents):
- agent_scrappy = server.create_agent(
+@pytest.mark.asyncio
+async def test_supervisor(server, default_user, four_participant_agents):
+ agent_scrappy = await server.create_agent_async(
request=CreateAgent(
name="shaggy",
memory_blocks=[
@@ -352,23 +362,23 @@ async def test_supervisor(server, actor, participant_agents):
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
),
- actor=actor,
+ actor=default_user,
)
- group = server.group_manager.create_group(
+ group = await server.group_manager.create_group_async(
group=GroupCreate(
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
- agent_ids=[agent.id for agent in participant_agents],
+ agent_ids=[agent.id for agent in four_participant_agents],
manager_config=SupervisorManager(
manager_agent_id=agent_scrappy.id,
),
),
- actor=actor,
+ actor=default_user,
)
try:
response = await server.send_group_message_to_agent(
group_id=group.id,
- actor=actor,
+ actor=default_user,
input_messages=[
MessageCreate(
role="user",
@@ -388,32 +398,33 @@ async def test_supervisor(server, actor, participant_agents):
and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group"
)
assert response.messages[2].message_type == "tool_return_message" and len(eval(response.messages[2].tool_return)) == len(
- participant_agents
+ four_participant_agents
)
assert response.messages[3].message_type == "reasoning_message"
assert response.messages[4].message_type == "assistant_message"
finally:
- server.group_manager.delete_group(group_id=group.id, actor=actor)
- server.agent_manager.delete_agent(agent_id=agent_scrappy.id, actor=actor)
+ await server.group_manager.delete_group_async(group_id=group.id, actor=default_user)
+ server.agent_manager.delete_agent(agent_id=agent_scrappy.id, actor=default_user)
-@pytest.mark.asyncio(loop_scope="module")
-async def test_dynamic_group_chat(server, actor, manager_agent, participant_agents):
+@pytest.mark.asyncio
+@pytest.mark.flaky(max_runs=2)
+async def test_dynamic_group_chat(server, default_user, manager_agent, four_participant_agents):
description = (
"This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries."
)
# error on duplicate agent in participant list
with pytest.raises(ValueError, match="Duplicate agent ids"):
- server.group_manager.create_group(
+ await server.group_manager.create_group_async(
group=GroupCreate(
description=description,
- agent_ids=[agent.id for agent in participant_agents] + [participant_agents[0].id],
+ agent_ids=[agent.id for agent in four_participant_agents] + [four_participant_agents[0].id],
manager_config=DynamicManager(
manager_agent_id=manager_agent.id,
),
),
- actor=actor,
+ actor=default_user,
)
# error on duplicate agent names
duplicate_agent_shaggy = server.create_agent(
@@ -422,43 +433,43 @@ async def test_dynamic_group_chat(server, actor, manager_agent, participant_agen
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
),
- actor=actor,
+ actor=default_user,
)
with pytest.raises(ValueError, match="Duplicate agent names"):
- server.group_manager.create_group(
+ await server.group_manager.create_group_async(
group=GroupCreate(
description=description,
- agent_ids=[agent.id for agent in participant_agents] + [duplicate_agent_shaggy.id],
+ agent_ids=[agent.id for agent in four_participant_agents] + [duplicate_agent_shaggy.id],
manager_config=DynamicManager(
manager_agent_id=manager_agent.id,
),
),
- actor=actor,
+ actor=default_user,
)
- server.agent_manager.delete_agent(duplicate_agent_shaggy.id, actor=actor)
+ server.agent_manager.delete_agent(duplicate_agent_shaggy.id, actor=default_user)
- group = server.group_manager.create_group(
+ group = await server.group_manager.create_group_async(
group=GroupCreate(
description=description,
- agent_ids=[agent.id for agent in participant_agents],
+ agent_ids=[agent.id for agent in four_participant_agents],
manager_config=DynamicManager(
manager_agent_id=manager_agent.id,
),
),
- actor=actor,
+ actor=default_user,
)
try:
response = await server.send_group_message_to_agent(
group_id=group.id,
- actor=actor,
+ actor=default_user,
input_messages=[
MessageCreate(role="user", content="what is everyone up to for the holidays?"),
],
stream_steps=False,
stream_tokens=False,
)
- assert response.usage.step_count == len(participant_agents) * 2
+ assert response.usage.step_count == len(four_participant_agents) * 2
assert len(response.messages) == response.usage.step_count * 2
finally:
- server.group_manager.delete_group(group_id=group.id, actor=actor)
+ await server.group_manager.delete_group_async(group_id=group.id, actor=default_user)
diff --git a/tests/test_providers.py b/tests/test_providers.py
index 439d3417..71230218 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -10,7 +10,7 @@ from letta.schemas.providers import (
OllamaProvider,
OpenAIProvider,
TogetherProvider,
- VLLMChatCompletionsProvider,
+ VLLMProvider,
)
from letta.settings import model_settings
@@ -46,25 +46,8 @@ async def test_openai_async():
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
-def test_deepseek():
- provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)
- models = provider.list_llm_models()
- assert len(models) > 0
- assert models[0].handle == f"{provider.name}/{models[0].model}"
-
-
-def test_anthropic():
- provider = AnthropicProvider(
- name="anthropic",
- api_key=model_settings.anthropic_api_key,
- )
- models = provider.list_llm_models()
- assert len(models) > 0
- assert models[0].handle == f"{provider.name}/{models[0].model}"
-
-
@pytest.mark.asyncio
-async def test_anthropic_async():
+async def test_anthropic():
provider = AnthropicProvider(
name="anthropic",
api_key=model_settings.anthropic_api_key,
@@ -74,67 +57,8 @@ async def test_anthropic_async():
assert models[0].handle == f"{provider.name}/{models[0].model}"
-def test_groq():
- provider = GroqProvider(
- name="groq",
- api_key=model_settings.groq_api_key,
- )
- models = provider.list_llm_models()
- assert len(models) > 0
- assert models[0].handle == f"{provider.name}/{models[0].model}"
-
-
-def test_azure():
- provider = AzureProvider(
- name="azure",
- api_key=model_settings.azure_api_key,
- base_url=model_settings.azure_base_url,
- api_version=model_settings.azure_api_version,
- )
- models = provider.list_llm_models()
- assert len(models) > 0
- assert models[0].handle == f"{provider.name}/{models[0].model}"
-
- embedding_models = provider.list_embedding_models()
- assert len(embedding_models) > 0
- assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
-
-
-@pytest.mark.skipif(model_settings.ollama_base_url is None, reason="Only run if OLLAMA_BASE_URL is set.")
-def test_ollama():
- provider = OllamaProvider(
- name="ollama",
- base_url=model_settings.ollama_base_url,
- api_key=None,
- default_prompt_formatter=model_settings.default_prompt_formatter,
- )
- models = provider.list_llm_models()
- assert len(models) > 0
- assert models[0].handle == f"{provider.name}/{models[0].model}"
-
- embedding_models = provider.list_embedding_models()
- assert len(embedding_models) > 0
- assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
-
-
-def test_googleai():
- api_key = model_settings.gemini_api_key
- assert api_key is not None
- provider = GoogleAIProvider(
- name="google_ai",
- api_key=api_key,
- )
- models = provider.list_llm_models()
- assert len(models) > 0
- assert models[0].handle == f"{provider.name}/{models[0].model}"
-
- embedding_models = provider.list_embedding_models()
- assert len(embedding_models) > 0
- assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
-
-
@pytest.mark.asyncio
-async def test_googleai_async():
+async def test_googleai():
api_key = model_settings.gemini_api_key
assert api_key is not None
provider = GoogleAIProvider(
@@ -150,42 +74,64 @@ async def test_googleai_async():
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
-def test_google_vertex():
+@pytest.mark.asyncio
+async def test_google_vertex():
provider = GoogleVertexProvider(
name="google_vertex",
google_cloud_project=model_settings.google_cloud_project,
google_cloud_location=model_settings.google_cloud_location,
)
- models = provider.list_llm_models()
+ models = await provider.list_llm_models_async()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
- embedding_models = provider.list_embedding_models()
+ embedding_models = await provider.list_embedding_models_async()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
-def test_together():
- provider = TogetherProvider(
- name="together",
- api_key=model_settings.together_api_key,
- default_prompt_formatter=model_settings.default_prompt_formatter,
- )
- models = provider.list_llm_models()
- assert len(models) > 0
- # Handle may be different from raw model name due to LLM_HANDLE_OVERRIDES
- assert models[0].handle.startswith(f"{provider.name}/")
- # Verify the handle is properly constructed via get_handle method
- assert models[0].handle == provider.get_handle(models[0].model)
-
- # TODO: We don't have embedding models on together for CI
- # embedding_models = provider.list_embedding_models()
- # assert len(embedding_models) > 0
- # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
-
-
+@pytest.mark.skipif(model_settings.deepseek_api_key is None, reason="Only run if DEEPSEEK_API_KEY is set.")
@pytest.mark.asyncio
-async def test_together_async():
+async def test_deepseek():
+ provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)
+ models = await provider.list_llm_models_async()
+ assert len(models) > 0
+ assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+
+@pytest.mark.skipif(model_settings.groq_api_key is None, reason="Only run if GROQ_API_KEY is set.")
+@pytest.mark.asyncio
+async def test_groq():
+ provider = GroqProvider(
+ name="groq",
+ api_key=model_settings.groq_api_key,
+ )
+ models = await provider.list_llm_models_async()
+ assert len(models) > 0
+ assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+
+@pytest.mark.skipif(model_settings.azure_api_key is None, reason="Only run if AZURE_API_KEY is set.")
+@pytest.mark.asyncio
+async def test_azure():
+ provider = AzureProvider(
+ name="azure",
+ api_key=model_settings.azure_api_key,
+ base_url=model_settings.azure_base_url,
+ api_version=model_settings.azure_api_version,
+ )
+ models = await provider.list_llm_models_async()
+ assert len(models) > 0
+ assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+ embedding_models = await provider.list_embedding_models_async()
+ assert len(embedding_models) > 0
+ assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
+
+
+@pytest.mark.skipif(model_settings.together_api_key is None, reason="Only run if TOGETHER_API_KEY is set.")
+@pytest.mark.asyncio
+async def test_together():
provider = TogetherProvider(
name="together",
api_key=model_settings.together_api_key,
@@ -204,6 +150,33 @@ async def test_together_async():
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
+# ===== Local Models =====
+@pytest.mark.skipif(model_settings.ollama_base_url is None, reason="Only run if OLLAMA_BASE_URL is set.")
+@pytest.mark.asyncio
+async def test_ollama():
+ provider = OllamaProvider(
+ name="ollama",
+ base_url=model_settings.ollama_base_url,
+ api_key=None,
+ default_prompt_formatter=model_settings.default_prompt_formatter,
+ )
+ models = await provider.list_llm_models_async()
+ assert len(models) > 0
+ assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+ embedding_models = await provider.list_embedding_models_async()
+ assert len(embedding_models) > 0
+ assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
+
+
+@pytest.mark.skipif(model_settings.vllm_api_base is None, reason="Only run if VLLM_API_BASE is set.")
+@pytest.mark.asyncio
+async def test_vllm():
+ provider = VLLMProvider(base_url=model_settings.vllm_api_base)
+ models = await provider.list_llm_models_async()
+ print(models)
+
+
# TODO: Add back in, difficulty adding this to CI properly, need boto credentials
# def test_anthropic_bedrock():
# from letta.settings import model_settings
@@ -218,20 +191,122 @@ async def test_together_async():
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
-def test_custom_anthropic():
+async def test_custom_anthropic():
provider = AnthropicProvider(
name="custom_anthropic",
api_key=model_settings.anthropic_api_key,
)
- models = provider.list_llm_models()
+ models = await provider.list_llm_models_async()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
-@pytest.mark.skipif(model_settings.vllm_api_base is None, reason="Only run if VLLM_API_BASE is set.")
-def test_vllm():
- provider = VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base)
- models = provider.list_llm_models()
- print(models)
+def test_provider_context_window():
+ """Test that providers implement context window methods correctly."""
+ provider = OpenAIProvider(
+ name="openai",
+ api_key=model_settings.openai_api_key,
+ base_url=model_settings.openai_api_base,
+ )
- provider.list_embedding_models()
+ # Test both sync and async context window methods
+ context_window = provider.get_model_context_window("gpt-4")
+ assert context_window is not None
+ assert isinstance(context_window, int)
+ assert context_window > 0
+
+
+@pytest.mark.asyncio
+async def test_provider_context_window_async():
+ """Test that providers implement async context window methods correctly."""
+ provider = OpenAIProvider(
+ name="openai",
+ api_key=model_settings.openai_api_key,
+ base_url=model_settings.openai_api_base,
+ )
+
+ context_window = await provider.get_model_context_window_async("gpt-4")
+ assert context_window is not None
+ assert isinstance(context_window, int)
+ assert context_window > 0
+
+
+def test_provider_handle_generation():
+ """Test that providers generate handles correctly."""
+ provider = OpenAIProvider(
+ name="test_openai",
+ api_key="test_key",
+ base_url="https://api.openai.com/v1",
+ )
+
+ # Test LLM handle
+ llm_handle = provider.get_handle("gpt-4")
+ assert llm_handle == "test_openai/gpt-4"
+
+ # Test embedding handle
+ embedding_handle = provider.get_handle("text-embedding-ada-002", is_embedding=True)
+ assert embedding_handle == "test_openai/text-embedding-ada-002"
+
+
+def test_provider_casting():
+ """Test that providers can be cast to their specific subtypes."""
+ from letta.schemas.enums import ProviderCategory, ProviderType
+ from letta.schemas.providers.base import Provider
+
+ base_provider = Provider(
+ name="test_provider",
+ provider_type=ProviderType.openai,
+ provider_category=ProviderCategory.base,
+ api_key="test_key",
+ base_url="https://api.openai.com/v1",
+ )
+
+ cast_provider = base_provider.cast_to_subtype()
+ assert isinstance(cast_provider, OpenAIProvider)
+ assert cast_provider.name == "test_provider"
+ assert cast_provider.api_key == "test_key"
+
+
+@pytest.mark.asyncio
+async def test_provider_embedding_models_consistency():
+ """Test that providers return consistent embedding model formats."""
+ provider = OpenAIProvider(
+ name="openai",
+ api_key=model_settings.openai_api_key,
+ base_url=model_settings.openai_api_base,
+ )
+
+ embedding_models = await provider.list_embedding_models_async()
+ if embedding_models: # Only test if provider supports embedding models
+ for model in embedding_models:
+ assert hasattr(model, "embedding_model")
+ assert hasattr(model, "embedding_endpoint_type")
+ assert hasattr(model, "embedding_endpoint")
+ assert hasattr(model, "embedding_dim")
+ assert hasattr(model, "handle")
+ assert model.handle.startswith(f"{provider.name}/")
+
+
+@pytest.mark.asyncio
+async def test_provider_llm_models_consistency():
+ """Test that providers return consistent LLM model formats."""
+ provider = OpenAIProvider(
+ name="openai",
+ api_key=model_settings.openai_api_key,
+ base_url=model_settings.openai_api_base,
+ )
+
+ models = await provider.list_llm_models_async()
+ assert len(models) > 0
+
+ for model in models:
+ assert hasattr(model, "model")
+ assert hasattr(model, "model_endpoint_type")
+ assert hasattr(model, "model_endpoint")
+ assert hasattr(model, "context_window")
+ assert hasattr(model, "handle")
+ assert hasattr(model, "provider_name")
+ assert hasattr(model, "provider_category")
+ assert model.handle.startswith(f"{provider.name}/")
+ assert model.provider_name == provider.name
+ assert model.context_window > 0
diff --git a/tests/test_server.py b/tests/test_server.py
index 23756386..ec9007cc 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -991,7 +991,8 @@ def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_m
assert len(agent_state.tool_rules) == len(base_tools + base_memory_tools)
-def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools, base_memory_tools):
+@pytest.mark.asyncio
+async def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = server.user_manager.get_user_or_default(user_id)
@@ -1055,12 +1056,12 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
# Add all the base tools
request.tool_ids = [b.id for b in base_tools]
- agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
+ agent_state = await server.agent_manager.update_agent_async(agent_state.id, agent_update=request, actor=actor)
assert len(agent_state.tools) == len(base_tools)
# Remove one base tool
request.tool_ids = [b.id for b in base_tools[:-2]]
- agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
+ agent_state = await server.agent_manager.update_agent_async(agent_state.id, agent_update=request, actor=actor)
assert len(agent_state.tools) == len(base_tools) - 2
diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py
index a228e250..0e5f4ed5 100644
--- a/tests/test_tool_rule_solver.py
+++ b/tests/test_tool_rule_solver.py
@@ -59,7 +59,7 @@ def test_get_allowed_tool_names_no_matching_rule_error():
solver = ToolRulesSolver(tool_rules=[init_rule])
solver.register_tool_call(UNRECOGNIZED_TOOL)
- with pytest.raises(ValueError, match=f"No valid tools found based on tool rules."):
+ with pytest.raises(ValueError, match="No valid tools found based on tool rules."):
solver.get_allowed_tool_names(set(), error_on_empty=True)