Merge branch 'main' into bump-10-0

This commit is contained in:
Caren Thomas
2025-07-31 17:53:29 -07:00
30 changed files with 8862 additions and 67 deletions

View File

@@ -11,20 +11,25 @@ assignees: ''
A clear and concise description of what the bug is.
**Please describe your setup**
- [ ] How did you install letta?
- `pip install letta`? `pip install letta-nightly`? `git clone`?
- [ ] How are you running Letta?
- Docker
- pip (legacy)
- From source
- Desktop
- [ ] Describe your setup
- What's your OS (Windows/MacOS/Linux)?
- How are you running `letta`? (`cmd.exe`/Powershell/Anaconda Shell/Terminal)
- What is your `docker run ...` command (if applicable)
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
- What model you are using
**Agent File (optional)**
Please attach your `.af` file, as this helps with reproducing issues.
**Letta Config**
Please attach your `~/.letta/config` file or copy paste it below.
---

286
.github/scripts/model-sweep/conftest.py vendored Normal file
View File

@@ -0,0 +1,286 @@
import logging
import os
import socket
import threading
import time
from datetime import datetime, timezone
from typing import Generator
import pytest
import requests
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts
from dotenv import load_dotenv
from letta_client import AsyncLetta, Letta
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
from letta.services.organization_manager import OrganizationManager
from letta.services.user_manager import UserManager
from letta.settings import tool_settings
def pytest_configure(config):
    # Pytest hook: force DEBUG logging for the whole session so server/client
    # traffic is visible in captured output.
    logging.basicConfig(level=logging.DEBUG)
@pytest.fixture
def disable_e2b_api_key() -> Generator[None, None, None]:
    """
    Temporarily disables the E2B API key by setting `tool_settings.e2b_api_key` to None
    for the duration of the test. Restores the original value afterward.

    The restore is wrapped in try/finally so the key is put back even when the
    generator is closed abnormally (test error, interrupted teardown).
    """
    from letta.settings import tool_settings

    original_api_key = tool_settings.e2b_api_key
    tool_settings.e2b_api_key = None
    try:
        yield
    finally:
        tool_settings.e2b_api_key = original_api_key
@pytest.fixture
def check_e2b_key_is_set():
    """Guard fixture: fail immediately when no E2B API key is configured."""
    from letta.settings import tool_settings

    assert tool_settings.e2b_api_key is not None, "Missing e2b key! Cannot execute these tests."
    yield
@pytest.fixture
def default_organization():
    """Create and yield the default organization."""
    yield OrganizationManager().create_default_organization()
@pytest.fixture
def default_user(default_organization):
    """Create and yield the default user scoped to the default organization."""
    yield UserManager().create_default_user(org_id=default_organization.id)
@pytest.fixture
def check_composio_key_set():
    """Guard fixture: fail immediately when no Composio API key is configured."""
    assert tool_settings.composio_api_key is not None, "Missing composio key! Cannot execute this test."
    yield
# --- Tool Fixtures ---
@pytest.fixture
def weather_tool_func():
    """Yield the source function for a weather-lookup tool.

    The inner docstring is parsed into the tool's schema, so it is kept
    verbatim.
    """

    def get_weather(location: str) -> str:
        """
        Fetches the current weather for a given location.
        Parameters:
            location (str): The location to get the weather for.
        Returns:
            str: A formatted string describing the weather in the given location.
        Raises:
            RuntimeError: If the request to fetch weather data fails.
        """
        import requests

        url = f"https://wttr.in/{location}?format=%C+%t"
        # Bounded timeout so an unreachable wttr.in can't hang the test run.
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            weather_data = response.text
            return f"The weather in {location} is {weather_data}."
        else:
            raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")

    yield get_weather
@pytest.fixture
def print_tool_func():
    """Fixture to create a tool with default settings and clean up after the test."""

    # NOTE: the inner docstring doubles as the tool's schema description, so it
    # must stay in this Args/Returns format.
    def print_tool(message: str):
        """
        Args:
            message (str): The message to print.
        Returns:
            str: The message that was printed.
        """
        print(message)
        return message

    yield print_tool
@pytest.fixture
def roll_dice_tool_func():
    """Yield the source function for a dice-rolling test tool."""

    def roll_dice():
        """
        Rolls a 6 sided die.
        Returns:
            str: The roll result.
        """
        import time

        # Slow the tool down so streaming/async code paths get exercised.
        time.sleep(1)
        # NOTE(review): returns a value impossible for a 6-sided die —
        # presumably intentional, so tests can verify the model reports the
        # tool's actual output rather than inventing one. Confirm.
        return "Rolled a 10!"

    yield roll_dice
@pytest.fixture
def dummy_beta_message_batch() -> BetaMessageBatch:
    """Build a static Anthropic batch object for unit tests."""
    # All lifecycle timestamps share the same fixed, timezone-aware instant.
    ts = datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc)
    counts = BetaMessageBatchRequestCounts(
        canceled=10,
        errored=30,
        expired=10,
        processing=100,
        succeeded=50,
    )
    return BetaMessageBatch(
        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
        archived_at=ts,
        cancel_initiated_at=ts,
        created_at=ts,
        ended_at=ts,
        expires_at=ts,
        processing_status="in_progress",
        request_counts=counts,
        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
        type="message_batch",
    )
# --- Model Sweep ---
# Global flag to track server state
_server_started = False
_server_url = None
def _start_server_once() -> str:
    """Start the Letta REST server exactly once per session and return its URL.

    Idempotent: subsequent calls return the cached URL. If something is already
    listening at the configured address, that server is reused.

    Returns:
        str: The base URL of the reachable server.
    Raises:
        RuntimeError: If the server does not become healthy within 30s.
    """
    global _server_started, _server_url
    if _server_started and _server_url:
        return _server_url

    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    # Probe the address derived from the URL (previously this checked a
    # hard-coded localhost:8283 even when LETTA_SERVER_URL pointed elsewhere).
    from urllib.parse import urlparse

    parsed = urlparse(url)
    host = parsed.hostname or "localhost"
    port = parsed.port or 8283
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        if s.connect_ex((host, port)) == 0:
            _server_started = True
            _server_url = url
            return url

    # No external URL configured: launch the server in a daemon thread.
    if not os.getenv("LETTA_SERVER_URL"):

        def _run_server():
            load_dotenv()
            from letta.server.rest_api.app import start_server

            start_server(debug=True)

        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

    # Poll the health endpoint until the server responds or we time out.
    timeout_seconds = 30
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        try:
            # Per-request timeout so a single stuck connection can't eat the
            # whole startup budget.
            resp = requests.get(url + "/v1/health", timeout=5)
            if resp.status_code < 500:
                break
        except requests.exceptions.RequestException:
            pass
        time.sleep(0.1)
    else:
        raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    _server_started = True
    _server_url = url
    return url
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
    """URL of the shared (lazily started) test server."""
    return _start_server_once()
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
    """Synchronous Letta REST client bound to the shared test server."""
    yield Letta(base_url=server_url)
@pytest.fixture(scope="function")
def async_client(server_url: str) -> AsyncLetta:
    """Asynchronous Letta REST client bound to the shared test server."""
    yield AsyncLetta(base_url=server_url)
@pytest.fixture(scope="module")
def agent_state(client: Letta) -> AgentState:
    """
    Creates and returns an agent state for testing with a pre-configured agent.
    The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
    """
    # Ensure built-in tools exist server-side before looking one up.
    client.tools.upsert_base_tools()
    send_message_tool = client.tools.list(name="send_message")[0]
    # NOTE(review): despite the docstring, only send_message is attached here —
    # roll_dice is not. Confirm which is intended.
    agent_state_instance = client.agents.create(
        name="supervisor",
        include_base_tools=False,
        tool_ids=[send_message_tool.id],
        model="openai/gpt-4o",
        embedding="letta/letta-free",
        tags=["supervisor"],
    )
    yield agent_state_instance
    # Teardown: delete the agent so module-scoped state doesn't leak.
    client.agents.delete(agent_state_instance.id)
@pytest.fixture(scope="module")
def all_available_llm_configs(client: Letta) -> list[LLMConfig]:
    """
    Returns a list of all available LLM configs.

    The previous annotation `[LLMConfig]` was a list *instance*, not a valid
    type hint; `list[LLMConfig]` is the correct form.
    """
    return client.models.list()
# Build a throwaway client against the (possibly just-started) server.
def get_available_llm_configs() -> list[LLMConfig]:
    """Get configs, starting server if needed.

    The previous annotation `[LLMConfig]` was a list instance, not a type hint.
    """
    server_url = _start_server_once()
    temp_client = Letta(base_url=server_url)
    return temp_client.models.list()
# Dynamically insert the llm_config parameter at collection time.
def pytest_generate_tests(metafunc):
    """Dynamically parametrize tests that need llm_config."""
    if "llm_config" in metafunc.fixturenames:
        configs = get_available_llm_configs()
        if configs:
            # One test id per model name, e.g. test_x[gpt-4o].
            metafunc.parametrize("llm_config", configs, ids=[c.model for c in configs])

View File

@@ -0,0 +1,24 @@
{
"Basic": [
"test_greeting_with_assistant_message",
"test_greeting_without_assistant_message",
"test_async_greeting_with_assistant_message",
"test_agent_loop_error",
"test_step_stream_agent_loop_error",
"test_step_streaming_greeting_with_assistant_message",
"test_step_streaming_greeting_without_assistant_message",
"test_step_streaming_tool_call",
"test_tool_call",
"test_auto_summarize"
],
"Token Streaming": [
"test_token_streaming_greeting_with_assistant_message",
"test_token_streaming_greeting_without_assistant_message",
"test_token_streaming_agent_loop_error",
"test_token_streaming_tool_call"
],
"Multimodal": [
"test_base64_image_input",
"test_url_image_input"
]
}

View File

@@ -0,0 +1,495 @@
#!/usr/bin/env python3
import json
import os
import sys
from collections import defaultdict
from datetime import datetime
def load_feature_mappings(config_file=None):
    """Load feature mappings from config file."""
    if config_file is None:
        # Fall back to feature_mappings.json located next to this script.
        here = os.path.dirname(os.path.abspath(__file__))
        config_file = os.path.join(here, "feature_mappings.json")
    try:
        with open(config_file, "r") as fh:
            return json.load(fh)
    except FileNotFoundError:
        print(f"Error: Could not find feature mappings config file '{config_file}'")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in feature mappings config file '{config_file}'")
        sys.exit(1)
def get_support_status(passed_tests, feature_tests):
    """Determine support status for a feature category.

    NOTE(review): several branches return "" — the status symbols (likely
    emoji such as full/none/unknown markers) appear to have been stripped in
    transit; only the partial-support "⚠️" survives. Confirm against the
    original file before relying on these markers.
    """
    if not feature_tests:
        return ""  # Unknown - no tests for this feature
    # Filter out error tests when checking for support
    non_error_tests = [test for test in feature_tests if not test.endswith("_error")]
    error_tests = [test for test in feature_tests if test.endswith("_error")]
    # Check which non-error tests passed
    passed_non_error_tests = [test for test in non_error_tests if test in passed_tests]
    # If there are no non-error tests, only error tests, treat as unknown
    if not non_error_tests:
        return ""  # Only error tests available
    # Support is based only on non-error tests
    if len(passed_non_error_tests) == len(non_error_tests):
        return ""  # Full support
    elif len(passed_non_error_tests) == 0:
        return ""  # No support
    else:
        return "⚠️"  # Partial support
def categorize_tests(all_test_names, feature_mapping):
    """Bucket test names by feature; the first matching feature wins."""
    buckets = {feature: [] for feature in feature_mapping}
    for name in all_test_names:
        for feature, patterns in feature_mapping.items():
            if name in patterns:
                buckets[feature].append(name)
                break
    return buckets
def calculate_support_score(feature_support, feature_order):
    """Calculate a numeric support score for ranking models.
    For partial support, the score is weighted by the position of the feature
    in the feature_order list (earlier features get higher weight).

    NOTE(review): the constants compared against "" look like stripped emoji
    (full support / no support). As written, the first `status == ""` branch
    always wins, making the "No support" arm unreachable — confirm against the
    original source.
    """
    score = 0
    max_features = len(feature_order)
    for feature, status in feature_support.items():
        # Get position weight (earlier features get higher weight)
        if feature in feature_order:
            position_weight = (max_features - feature_order.index(feature)) / max_features
        else:
            position_weight = 0.5  # Default weight for unmapped features
        if status == "":  # Full support
            score += 10 * position_weight
        elif status == "⚠️":  # Partial support - weighted by column position
            score += 5 * position_weight
        elif status == "":  # No support
            score += 1 * position_weight
        # Unknown (❓) gets 0 points
    return score
def calculate_provider_support_score(models_data, feature_order):
    """Average the per-model support scores into a provider-level score."""
    if not models_data:
        return 0
    scores = [model["support_score"] for model in models_data]
    return sum(scores) / len(scores)
def get_test_function_line_numbers(test_file_path):
    """Extract 1-based line numbers for test functions from the test file.

    Returns:
        dict[str, int]: Mapping of `test_*` function name to its line number;
        empty (with a warning printed) when the file is missing.
    """
    test_line_numbers = {}
    try:
        with open(test_file_path, "r") as f:
            lines = f.readlines()
        for i, line in enumerate(lines, 1):
            # The stripped startswith check already implies the substring test
            # the original also performed, so a single check suffices.
            if line.strip().startswith("def test_"):
                # Extract function name
                func_name = line.strip().split("def ")[1].split("(")[0]
                test_line_numbers[func_name] = i
    except FileNotFoundError:
        print(f"Warning: Could not find test file at {test_file_path}")
    return test_line_numbers
def get_github_repo_info():
    """Get GitHub repository information ("user/repo") from the git remote.

    Returns:
        str | None: The repo path for a recognized GitHub remote, None for a
        GitHub remote with an unrecognized scheme, or the "letta-ai/letta"
        fallback when git is unavailable or the lookup fails.
    """
    try:
        # Try to get the GitHub repo URL from git remote
        import subprocess

        result = subprocess.run(["git", "remote", "get-url", "origin"], capture_output=True, text=True, cwd=os.path.dirname(__file__))
        if result.returncode == 0:
            remote_url = result.stdout.strip()
            # Parse GitHub URL
            if "github.com" in remote_url:
                if remote_url.startswith("https://"):
                    # https://github.com/user/repo.git -> user/repo
                    repo_path = remote_url.replace("https://github.com/", "").replace(".git", "")
                elif remote_url.startswith("git@"):
                    # git@github.com:user/repo.git -> user/repo
                    repo_path = remote_url.split(":")[1].replace(".git", "")
                else:
                    return None
                return repo_path
    except Exception:
        # Previously a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; narrowed to Exception.
        pass
    # Default fallback
    return "letta-ai/letta"
def generate_test_details(model_info, feature_mapping):
    """Generate per-feature markdown lines detailing a model's test results.

    NOTE(review): the pass/fail/unknown status markers are all "" here —
    almost certainly stripped emoji; confirm against the original source.
    """
    details = []
    # Get test function line numbers
    script_dir = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(script_dir, "model_sweep.py")
    test_line_numbers = get_test_function_line_numbers(test_file_path)
    # Use the main branch GitHub URL
    base_github_url = "https://github.com/letta-ai/letta/blob/main/.github/scripts/model-sweep/model_sweep.py"
    for feature, tests in model_info["categorized_tests"].items():
        if not tests:
            continue
        details.append(f"### {feature}")
        details.append("")
        for test in sorted(tests):
            if test in model_info["passed_tests"]:
                status = ""
            elif test in model_info["failed_tests"]:
                status = ""
            else:
                status = ""
            # Create GitHub link if we have line number info
            if test in test_line_numbers:
                line_num = test_line_numbers[test]
                github_link = f"{base_github_url}#L{line_num}"
                details.append(f"- {status} [`{test}`]({github_link})")
            else:
                details.append(f"- {status} `{test}`")
        details.append("")
    return details
def calculate_column_widths(all_provider_data, feature_mapping):
    """Compute the maximum display width per table column across all providers."""
    widths = {
        "model": len("Model"),
        "context_window": len("Context Window"),
        "last_scanned": len("Last Scanned"),
    }
    # Feature columns start at their header width; the support symbols that
    # fill them are at most two characters, so the header is the floor.
    for feature in feature_mapping:
        widths[feature] = len(feature)
    for provider_models in all_provider_data.values():
        for info in provider_models:
            # Model names are rendered wrapped in backticks.
            widths["model"] = max(widths["model"], len(f"`{info['name']}`"))
            # Context windows are rendered with thousands separators.
            widths["context_window"] = max(widths["context_window"], len(f"{info['context_window']:,}"))
            widths["last_scanned"] = max(widths["last_scanned"], len(str(info["last_scanned"])))
    return widths
def process_model_sweep_report(input_file, output_file, config_file=None, debug=False):
    """Convert model sweep JSON data to an MDX report.

    Args:
        input_file: pytest-json report path with a top-level "tests" list.
        output_file: destination .mdx path.
        config_file: optional feature-mappings JSON (defaults to the file next
            to this script).
        debug: retained for CLI compatibility; currently unused.
    Raises:
        FileNotFoundError / json.JSONDecodeError: propagated from reading
        input_file (handled by main()).
    """
    # Load feature mappings from config file
    feature_mapping = load_feature_mappings(config_file)
    # Read the JSON data
    with open(input_file, "r") as f:
        data = json.load(f)
    tests = data.get("tests", [])
    # Calculate summary statistics
    providers = set(test["metadata"]["llm_config"]["provider_name"] for test in tests)
    models = set(test["metadata"]["llm_config"]["model"] for test in tests)
    total_tests = len(tests)
    # Build a correctly-suffixed ordinal date: strftime("%B %dth, %Y")
    # produced strings like "July 31th, 2025".
    now = datetime.now()
    day = now.day
    ordinal = "th" if 11 <= day % 100 <= 13 else {1: "st", 2: "nd", 3: "rd"}.get(day % 10, "th")
    # Start building the MDX (frontmatter title aligned with the H1 below;
    # previously "Support Models").
    mdx_lines = [
        "---",
        "title: Supported Models",
        f"generated: {now.isoformat()}",
        "---",
        "",
        "# Supported Models",
        "",
        "## Overview",
        "",
        "Letta routinely runs automated scans against available providers and models. These are the results of the latest scan.",
        "",
        f"Ran {total_tests} tests against {len(models)} models across {len(providers)} providers on {now.strftime('%B')} {day}{ordinal}, {now.year}",
        "",
        "",
    ]
    # Group tests by provider
    provider_groups = defaultdict(list)
    for test in tests:
        provider_name = test["metadata"]["llm_config"]["provider_name"]
        provider_groups[provider_name].append(test)
    # Process all providers first to collect model data
    all_provider_data = {}
    provider_support_scores = {}
    for provider_name in provider_groups.keys():
        provider_tests = provider_groups[provider_name]
        # Group tests by model within this provider
        model_groups = defaultdict(list)
        for test in provider_tests:
            model_name = test["metadata"]["llm_config"]["model"]
            model_groups[model_name].append(test)
        # Process all models to calculate support scores for ranking
        model_data = []
        for model_name in model_groups.keys():
            model_tests = model_groups[model_name]
            # Extract unique test names for passed and failed tests
            passed_tests = set()
            failed_tests = set()
            all_test_names = set()
            for test in model_tests:
                # Extract test name from nodeid (split on :: and [)
                test_name = test["nodeid"].split("::")[1].split("[")[0]
                all_test_names.add(test_name)
                if test["outcome"] == "passed":
                    passed_tests.add(test_name)
                elif test["outcome"] == "failed":
                    failed_tests.add(test_name)
            # Categorize tests into features
            categorized_tests = categorize_tests(all_test_names, feature_mapping)
            # Determine support status for each feature
            feature_support = {}
            for feature_name in feature_mapping.keys():
                feature_support[feature_name] = get_support_status(passed_tests, categorized_tests[feature_name])
            # Get context window and last scanned time
            context_window = model_tests[0]["metadata"]["llm_config"]["context_window"]
            # Try to get time_last_scanned from metadata, fallback to current time
            try:
                last_scanned = model_tests[0]["metadata"].get(
                    "time_last_scanned", model_tests[0]["metadata"].get("timestamp", datetime.now().isoformat())
                )
                # Format timestamp if it's a full ISO string
                if "T" in str(last_scanned):
                    last_scanned = str(last_scanned).split("T")[0]  # Just the date part
            except Exception:
                # Previously a bare `except:`; narrowed so SystemExit and
                # KeyboardInterrupt are not swallowed while the report stays
                # resilient to malformed metadata.
                last_scanned = "Unknown"
            # Calculate support score for ranking
            feature_order = list(feature_mapping.keys())
            support_score = calculate_support_score(feature_support, feature_order)
            # Store model data for sorting
            model_data.append(
                {
                    "name": model_name,
                    "feature_support": feature_support,
                    "context_window": context_window,
                    "last_scanned": last_scanned,
                    "support_score": support_score,
                    "failed_tests": failed_tests,
                    "passed_tests": passed_tests,
                    "categorized_tests": categorized_tests,
                }
            )
        # Sort models by support score (descending) then by name (ascending)
        model_data.sort(key=lambda x: (-x["support_score"], x["name"]))
        # Store provider data
        all_provider_data[provider_name] = model_data
        provider_support_scores[provider_name] = calculate_provider_support_score(model_data, list(feature_mapping.keys()))
    # Calculate column widths for consistent formatting (add details column)
    column_widths = calculate_column_widths(all_provider_data, feature_mapping)
    column_widths["details"] = len("Details")
    # Sort providers by support score (descending) then by name (ascending)
    sorted_providers = sorted(provider_support_scores.keys(), key=lambda x: (-provider_support_scores[x], x))
    # Generate tables for all providers first
    for provider_name in sorted_providers:
        model_data = all_provider_data[provider_name]
        # Create dynamic headers with proper padding and centering
        feature_names = list(feature_mapping.keys())
        # Build header row with left-aligned first column, centered others
        header_parts = [f"{'Model':<{column_widths['model']}}"]
        for feature in feature_names:
            header_parts.append(f"{feature:^{column_widths[feature]}}")
        header_parts.extend(
            [
                f"{'Context Window':^{column_widths['context_window']}}",
                f"{'Last Scanned':^{column_widths['last_scanned']}}",
                f"{'Details':^{column_widths['details']}}",
            ]
        )
        header_row = "| " + " | ".join(header_parts) + " |"
        # Build separator row with left-aligned first column, centered others
        separator_parts = [f"{'-' * column_widths['model']}"]
        for feature in feature_names:
            separator_parts.append(f":{'-' * (column_widths[feature] - 2)}:")
        separator_parts.extend(
            [
                f":{'-' * (column_widths['context_window'] - 2)}:",
                f":{'-' * (column_widths['last_scanned'] - 2)}:",
                f":{'-' * (column_widths['details'] - 2)}:",
            ]
        )
        separator_row = "|" + "|".join(separator_parts) + "|"
        # Add provider section without percentage
        mdx_lines.extend([f"## {provider_name}", "", header_row, separator_row])
        # Generate table rows for sorted models with proper padding
        for model_info in model_data:
            # Create anchor for model details
            model_anchor = model_info["name"].replace("/", "_").replace(":", "_").replace("-", "_").lower()
            details_anchor = f"{provider_name.lower().replace(' ', '_')}_{model_anchor}_details"
            # Build row with left-aligned first column, centered others
            row_parts = [f"`{model_info['name']}`".ljust(column_widths["model"])]
            for feature in feature_names:
                row_parts.append(f"{model_info['feature_support'][feature]:^{column_widths[feature]}}")
            row_parts.extend(
                [
                    f"{model_info['context_window']:,}".center(column_widths["context_window"]),
                    f"{model_info['last_scanned']}".center(column_widths["last_scanned"]),
                    f"[View](#{details_anchor})".center(column_widths["details"]),
                ]
            )
            row = "| " + " | ".join(row_parts) + " |"
            mdx_lines.append(row)
        # Add spacing between provider tables
        mdx_lines.extend(["", ""])
    # Add detailed test results section after all tables
    mdx_lines.extend(["---", "", "# Detailed Test Results", ""])
    for provider_name in sorted_providers:
        model_data = all_provider_data[provider_name]
        mdx_lines.extend([f"## {provider_name}", ""])
        for model_info in model_data:
            model_anchor = model_info["name"].replace("/", "_").replace(":", "_").replace("-", "_").lower()
            details_anchor = f"{provider_name.lower().replace(' ', '_')}_{model_anchor}_details"
            mdx_lines.append(f'<a id="{details_anchor}"></a>')
            mdx_lines.append(f"### {model_info['name']}")
            mdx_lines.append("")
            # Add test details
            test_details = generate_test_details(model_info, feature_mapping)
            mdx_lines.extend(test_details)
        # Add spacing between providers in details section
        mdx_lines.extend(["", ""])
    # Write the MDX file
    with open(output_file, "w") as f:
        f.write("\n".join(mdx_lines))
    print(f"Model sweep report saved to {output_file}")
def main():
    """CLI entry point.

    Positional args: input file, output file, config file, then an optional
    literal "--debug" flag. Input/output paths are resolved relative to this
    script's directory; the config path is used as given.
    """
    input_file = "model_sweep_report.json"
    output_file = "model_sweep_report.mdx"
    config_file = None
    debug = False
    argv = sys.argv
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if len(argv) > 1:
        input_file = os.path.join(script_dir, argv[1])
    if len(argv) > 2:
        output_file = os.path.join(script_dir, argv[2])
    if len(argv) > 3:
        config_file = argv[3]
    if len(argv) > 4 and argv[4] == "--debug":
        debug = True
    try:
        process_model_sweep_report(input_file, output_file, config_file, debug)
    except FileNotFoundError:
        print(f"Error: Could not find input file '{input_file}'")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in file '{input_file}'")
        sys.exit(1)
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,786 @@
import base64
import json
import os
import socket
import threading
import time
import uuid
from typing import Any, Dict, List
import httpx
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta, MessageCreate, Run
from letta_client.core.api_error import ApiError
from letta_client.types import (
AssistantMessage,
Base64Image,
ImageContent,
LettaUsageStatistics,
ReasoningMessage,
TextContent,
ToolCallMessage,
ToolReturnMessage,
UrlImage,
UserMessage,
)
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
# ------------------------------
# Helper Functions and Constants
# ------------------------------
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    """Load an LLMConfig from a JSON file under llm_config_dir.

    Uses a context manager so the config file handle is closed promptly —
    the previous `json.load(open(...))` leaked the handle.
    """
    path = os.path.join(llm_config_dir, filename)
    with open(path, "r") as f:
        config_data = json.load(f)
    return LLMConfig(**config_data)
def roll_dice(num_sides: int) -> int:
    """
    Returns a random number between 1 and num_sides.
    Args:
        num_sides (int): The number of sides on the die.
    Returns:
        int: A random integer between 1 and num_sides, representing the die roll.
    """
    from random import randint

    return randint(1, num_sides)
# One OTID shared by all scripted user messages so DB round-trips can be
# matched back to the originating message.
USER_MESSAGE_OTID = str(uuid.uuid4())
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
# Prompt that forces the agent to reply via the send_message tool.
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content=f"This is an automated test message. Call the send_message tool with the message '{USER_MESSAGE_RESPONSE}'.",
        otid=USER_MESSAGE_OTID,
    )
]
# Prompt that forces a roll_dice tool call followed by a summary reply.
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content="This is an automated test message. Call the roll_dice tool with 16 sides and tell me the outcome.",
        otid=USER_MESSAGE_OTID,
    )
]
URL_IMAGE = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
USER_MESSAGE_URL_IMAGE: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content=[
            ImageContent(source=UrlImage(url=URL_IMAGE)),
            TextContent(text="What is in this image?"),
        ],
        otid=USER_MESSAGE_OTID,
    )
]
# NOTE(review): this downloads the image over the network at import time, so
# importing the module offline fails here. Consider lazy-loading — TODO confirm.
BASE64_IMAGE = base64.standard_b64encode(httpx.get(URL_IMAGE).content).decode("utf-8")
USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content=[
            ImageContent(source=Base64Image(data=BASE64_IMAGE, media_type="image/jpeg")),
            TextContent(text="What is in this image?"),
        ],
        otid=USER_MESSAGE_OTID,
    )
]
all_configs = [
    "openai-gpt-4o-mini.json",
    # "azure-gpt-4o-mini.json",  # TODO: Re-enable on new agent loop
    "claude-3-5-sonnet.json",
    "claude-3-7-sonnet.json",
    "claude-3-7-sonnet-extended.json",
    "gemini-1.5-pro.json",
    "gemini-2.5-flash-vertex.json",
    "gemini-2.5-pro-vertex.json",
    "together-qwen-2.5-72b-instruct.json",
    "ollama.json",
]
# LLM_CONFIG_FILE narrows the sweep to a single config file when set.
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
def assert_greeting_with_assistant_message_response(
    messages: List[Any],
    streaming: bool = False,
    token_streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected sequence:
    ReasoningMessage -> AssistantMessage.

    When `from_db`, the originating UserMessage is expected first; when
    `streaming`, a trailing LettaUsageStatistics entry is expected.
    """
    # Base sequence is 2 messages; DB reads prepend the user message and
    # streaming appends usage stats (3 total in either case).
    expected_message_count = 3 if streaming or from_db else 2
    assert len(messages) == expected_message_count
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1
    # Agent Step 1
    assert isinstance(messages[index], ReasoningMessage)
    # The otid's final digit encodes the message's position within the step.
    assert messages[index].otid and messages[index].otid[-1] == "0"
    index += 1
    assert isinstance(messages[index], AssistantMessage)
    if not token_streaming:
        # Token streams carry partial content, so only check it when whole.
        assert USER_MESSAGE_RESPONSE in messages[index].content
    assert messages[index].otid and messages[index].otid[-1] == "1"
    index += 1
    if streaming:
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0
def assert_greeting_without_assistant_message_response(
    messages: List[Any],
    streaming: bool = False,
    token_streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected sequence:
    ReasoningMessage -> ToolCallMessage -> ToolReturnMessage.
    """
    expected_message_count = 4 if streaming or from_db else 3
    assert len(messages) == expected_message_count
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1
    # Agent Step 1
    assert isinstance(messages[index], ReasoningMessage)
    assert messages[index].otid and messages[index].otid[-1] == "0"
    index += 1
    # With assistant messages disabled, the reply surfaces as a raw
    # send_message tool call instead of an AssistantMessage.
    assert isinstance(messages[index], ToolCallMessage)
    assert messages[index].tool_call.name == "send_message"
    if not token_streaming:
        # Token streams deliver partial arguments, so only check when whole.
        assert USER_MESSAGE_RESPONSE in messages[index].tool_call.arguments
    assert messages[index].otid and messages[index].otid[-1] == "1"
    index += 1
    # Agent Step 2
    assert isinstance(messages[index], ToolReturnMessage)
    assert messages[index].otid and messages[index].otid[-1] == "0"
    index += 1
    if streaming:
        assert isinstance(messages[index], LettaUsageStatistics)
def assert_tool_call_response(
    messages: List[Any],
    streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected sequence:
    ReasoningMessage -> ToolCallMessage -> ToolReturnMessage ->
    ReasoningMessage -> AssistantMessage.
    """
    # NOTE(review): this parses as `6 if streaming else (7 if from_db else 5)`,
    # so `streaming` takes precedence when both flags are set — confirm that
    # streaming-from-db is never expected here.
    expected_message_count = 6 if streaming else 7 if from_db else 5
    assert len(messages) == expected_message_count
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1
    # Agent Step 1
    assert isinstance(messages[index], ReasoningMessage)
    assert messages[index].otid and messages[index].otid[-1] == "0"
    index += 1
    assert isinstance(messages[index], ToolCallMessage)
    assert messages[index].otid and messages[index].otid[-1] == "1"
    index += 1
    # Agent Step 2
    assert isinstance(messages[index], ToolReturnMessage)
    assert messages[index].otid and messages[index].otid[-1] == "0"
    index += 1
    # Hidden User Message
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert "request_heartbeat=true" in messages[index].content
        index += 1
    # Agent Step 3
    assert isinstance(messages[index], ReasoningMessage)
    assert messages[index].otid and messages[index].otid[-1] == "0"
    index += 1
    assert isinstance(messages[index], AssistantMessage)
    assert messages[index].otid and messages[index].otid[-1] == "1"
    index += 1
    if streaming:
        assert isinstance(messages[index], LettaUsageStatistics)
def assert_image_input_response(
    messages: List[Any],
    streaming: bool = False,
    token_streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected sequence:
    ReasoningMessage -> AssistantMessage.

    Same shape as the greeting assertion, but without checking the assistant
    content (image descriptions are model-dependent).
    """
    expected_message_count = 3 if streaming or from_db else 2
    assert len(messages) == expected_message_count
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1
    # Agent Step 1
    assert isinstance(messages[index], ReasoningMessage)
    assert messages[index].otid and messages[index].otid[-1] == "0"
    index += 1
    assert isinstance(messages[index], AssistantMessage)
    assert messages[index].otid and messages[index].otid[-1] == "1"
    index += 1
    if streaming:
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0
def accumulate_chunks(chunks: List[Any]) -> List[Any]:
    """
    Collapse a stream of chunks into one entry per contiguous run of the same
    ``message_type`` (the first chunk of each run is kept).
    """
    collapsed: List[Any] = []
    last_type = None
    for piece in chunks:
        piece_type = piece.message_type
        if piece_type != last_type:
            # A type transition starts a new accumulated message.
            collapsed.append(piece)
        # TODO: actually accumulate the chunks. For now we only care about the count
        last_type = piece_type
    return collapsed
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
    """
    Poll the run every ``interval`` seconds until it completes.

    Raises RuntimeError if the run reports status "failed", and TimeoutError
    if it has not completed within ``timeout`` seconds.
    """
    deadline = time.time() + timeout
    while True:
        run = client.runs.retrieve(run_id)
        if run.status == "completed":
            return run
        if run.status == "failed":
            raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
        if time.time() > deadline:
            raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
        time.sleep(interval)
def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
    """
    Asserts that a list of message dictionaries (e.g. an async run's
    ``result["messages"]`` payload) begins with the expected message types.

    Expected order:
        1. reasoning_message
        2. assistant_message

    NOTE: the docstring previously described a five-message tool-call sequence
    (reasoning, tool_call, tool_return, reasoning, assistant), but the checks
    only ever covered the two leading messages above; the documentation now
    matches the actual behavior.
    """
    assert isinstance(messages, list)
    assert messages[0]["message_type"] == "reasoning_message"
    assert messages[1]["message_type"] == "assistant_message"
# ------------------------------
# Test Cases
# ------------------------------
# def test_that_ci_workflow_works(
# disable_e2b_api_key: Any,
# client: Letta,
# agent_state: AgentState,
# llm_config: LLMConfig,
# json_metadata: pytest.FixtureRequest,
# ) -> None:
# """
# Tests that the CI workflow works.
# """
# json_metadata["test_type"] = "debug"
def test_greeting_with_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Send a plain greeting with the synchronous client and verify that both the
    live response and the database-persisted messages follow the expected order.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs, so we can
    # later list only the messages this test produced.
    cursor = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create(agent_id=agent_state.id, messages=USER_MESSAGE_FORCE_REPLY)
    assert_greeting_with_assistant_message_response(response.messages)
    # Re-read the exchange from the database (includes the user message).
    persisted = client.agents.messages.list(agent_id=agent_state.id, after=cursor[0].id)
    assert_greeting_with_assistant_message_response(persisted, from_db=True)
def test_greeting_without_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests sending a message with a synchronous client, with
    use_assistant_message=False applied to both the live response and
    the database read.
    Verifies that the response messages follow the expected order.
    """
    # Record the model under test in the JSON report.
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
        use_assistant_message=False,
    )
    assert_greeting_without_assistant_message_response(response.messages)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
    assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
def test_tool_call(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests a message that should trigger the roll_dice tool, using a
    synchronous client.
    Verifies that the response messages follow the expected order.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Ensure the dice tool exists and is attached to the agent.
    dice_tool = client.tools.upsert_from_function(func=roll_dice)
    client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_ROLL_DICE,
    )
    assert_tool_call_response(response.messages)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_tool_call_response(messages_from_db, from_db=True)
def test_url_image_input(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests sending a message containing an image referenced by URL, using a
    synchronous client.
    Verifies that the response messages follow the expected order.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_URL_IMAGE,
    )
    assert_image_input_response(response.messages)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_image_input_response(messages_from_db, from_db=True)
def test_base64_image_input(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests sending a message containing a base64-encoded image, using a
    synchronous client.
    Verifies that the response messages follow the expected order.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_BASE64_IMAGE,
    )
    assert_image_input_response(response.messages)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_image_input_response(messages_from_db, from_db=True)
def test_agent_loop_error(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests that an agent-loop failure (all tools detached) surfaces as an
    ApiError with a synchronous client.
    Verifies that no new messages are persisted on error.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Detach every tool so the agent loop fails, remembering them for restore.
    tools = agent_state.tools
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
    with pytest.raises(ApiError):
        client.agents.messages.create(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
        )
    # The failed request must not have persisted any messages.
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert len(messages_from_db) == 0
    # Restore the agent's tools for subsequent tests.
    client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
def test_step_streaming_greeting_with_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests sending a streaming message with a synchronous client.
    Checks that each chunk in the stream has the correct message types.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
    )
    # Drain the stream, then collapse chunks (one per message-type run) for counting.
    chunks = list(response)
    messages = accumulate_chunks(chunks)
    assert_greeting_with_assistant_message_response(messages, streaming=True)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
def test_step_streaming_greeting_without_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests sending a streaming message with a synchronous client, with
    use_assistant_message=False applied to both the stream and the DB read.
    Checks that each chunk in the stream has the correct message types.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
        use_assistant_message=False,
    )
    # Drain the stream, then collapse chunks (one per message-type run) for counting.
    chunks = list(response)
    messages = accumulate_chunks(chunks)
    assert_greeting_without_assistant_message_response(messages, streaming=True)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
    assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
def test_step_streaming_tool_call(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests a streaming message that should trigger the roll_dice tool, using a
    synchronous client.
    Checks that each chunk in the stream has the correct message types.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Ensure the dice tool exists and is attached to the agent.
    dice_tool = client.tools.upsert_from_function(func=roll_dice)
    agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_ROLL_DICE,
    )
    # Drain the stream, then collapse chunks (one per message-type run) for counting.
    chunks = list(response)
    messages = accumulate_chunks(chunks)
    assert_tool_call_response(messages, streaming=True)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_tool_call_response(messages_from_db, from_db=True)
def test_step_stream_agent_loop_error(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests that an agent-loop failure (all tools detached) surfaces as an
    ApiError when step-streaming with a synchronous client.
    Verifies that no new messages are persisted on error.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Detach every tool so the agent loop fails, remembering them for restore.
    tools = agent_state.tools
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
    with pytest.raises(ApiError):
        response = client.agents.messages.create_stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
        )
        # The error may only surface while consuming the stream, so drain it
        # inside the pytest.raises context.
        list(response)
    # The failed request must not have persisted any messages.
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert len(messages_from_db) == 0
    # Restore the agent's tools for subsequent tests.
    client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
def test_token_streaming_greeting_with_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests sending a token-streaming message (stream_tokens=True) with a
    synchronous client.
    Checks that each chunk in the stream has the correct message types.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
        stream_tokens=True,
    )
    # Drain the stream, then collapse chunks (one per message-type run) for counting.
    chunks = list(response)
    messages = accumulate_chunks(chunks)
    assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
def test_token_streaming_greeting_without_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests sending a token-streaming message (stream_tokens=True) with a
    synchronous client, with use_assistant_message=False applied to both the
    stream and the DB read.
    Checks that each chunk in the stream has the correct message types.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
        use_assistant_message=False,
        stream_tokens=True,
    )
    # Drain the stream, then collapse chunks (one per message-type run) for counting.
    chunks = list(response)
    messages = accumulate_chunks(chunks)
    assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
    assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
def test_token_streaming_tool_call(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests a token-streaming message (stream_tokens=True) that should trigger
    the roll_dice tool, using a synchronous client.
    Checks that each chunk in the stream has the correct message types.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Ensure the dice tool exists and is attached to the agent.
    dice_tool = client.tools.upsert_from_function(func=roll_dice)
    agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Point the agent at the parametrized LLM config under test.
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_ROLL_DICE,
        stream_tokens=True,
    )
    # Drain the stream, then collapse chunks (one per message-type run) for counting.
    chunks = list(response)
    messages = accumulate_chunks(chunks)
    assert_tool_call_response(messages, streaming=True)
    # Re-read the exchange from the database (includes the user message).
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_tool_call_response(messages_from_db, from_db=True)
def test_token_streaming_agent_loop_error(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests token streaming (stream_tokens=True) when the agent loop fails
    because all tools are detached.
    Verifies that no new messages are persisted on error.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Cursor: the latest persisted message before this test runs.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    # Detach every tool so the agent loop fails, remembering them for restore.
    tools = agent_state.tools
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
    try:
        response = client.agents.messages.create_stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
            stream_tokens=True,
        )
        list(response)
    # Was a bare `except:`, which would also swallow KeyboardInterrupt/SystemExit
    # and hide Ctrl-C during the test run; narrow to Exception.
    except Exception:
        pass  # only some models throw an error TODO: make this consistent
    # Whether or not an error surfaced, nothing new may be persisted.
    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert len(messages_from_db) == 0
    # Restore the agent's tools for subsequent tests.
    client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
def test_async_greeting_with_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    agent_state: AgentState,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """
    Tests sending a message as an asynchronous job using the synchronous client.
    Waits for job completion and asserts that the result messages are as expected.
    """
    json_metadata["llm_config"] = dict(llm_config)
    # Point the agent at the parametrized LLM config under test.
    client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    run = client.agents.messages.create_async(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
    )
    # Poll until the run reaches a terminal state (raises on failure/timeout).
    run = wait_for_run_completion(client, run.id)
    result = run.metadata.get("result")
    assert result is not None, "Run metadata missing 'result' key"
    messages = result["messages"]
    # NOTE(review): despite its name, this helper only asserts that the first
    # two entries are reasoning_message and assistant_message, which matches
    # this greeting exchange; consider a dedicated greeting assertion.
    assert_tool_response_dict_messages(messages)
def test_auto_summarize(
    disable_e2b_api_key: Any,
    client: Letta,
    llm_config: LLMConfig,
    json_metadata: pytest.FixtureRequest,
) -> None:
    """Test that summarization is automatically triggered."""
    json_metadata["llm_config"] = dict(llm_config)
    # pydantic prevents us from overriding the context window parameter in the
    # passed LLMConfig, so round-trip through a dict to pin a small context
    # window that forces summarization quickly.
    new_llm_config = llm_config.model_dump()
    new_llm_config["context_window"] = 3000
    pinned_context_window_llm_config = LLMConfig(**new_llm_config)
    send_message_tool = client.tools.list(name="send_message")[0]
    # Fresh agent with only send_message attached, using the pinned config.
    temp_agent_state = client.agents.create(
        include_base_tools=False,
        tool_ids=[send_message_tool.id],
        llm_config=pinned_context_window_llm_config,
        embedding="letta/letta-free",
        tags=["supervisor"],
    )
    # A deliberately long prompt so each turn consumes a large share of the
    # pinned 3000-token context window.
    philosophical_question = """
    You know, sometimes I wonder if the entire structure of our lives is built on a series of unexamined assumptions we just silently agreed to somewhere along the way—like how we all just decided that five days a week of work and two days of “rest” constitutes balance, or how 9-to-5 became the default rhythm of a meaningful life, or even how the idea of “success” got boiled down to job titles and property ownership and productivity metrics on a LinkedIn profile, when maybe none of that is actually what makes a life feel full, or grounded, or real. And then theres the weird paradox of ambition, how we're taught to chase it like a finish line that keeps moving, constantly redefining itself right as youre about to grasp it—because even when you get the job, or the degree, or the validation, there's always something next, something more, like a treadmill with invisible settings you didnt realize were turned up all the way.
    And have you noticed how we rarely stop to ask who set those definitions for us? Like was there ever a council that decided, yes, owning a home by thirty-five and retiring by sixty-five is the universal template for fulfillment? Or did it just accumulate like cultural sediment over generations, layered into us so deeply that questioning it feels uncomfortable, even dangerous? And isnt it strange that we spend so much of our lives trying to optimize things—our workflows, our diets, our sleep, our morning routines—as though the point of life is to operate more efficiently rather than to experience it more richly? We build these intricate systems, these rulebooks for being a “high-functioning” human, but where in all of that is the space for feeling lost, for being soft, for wandering without a purpose just because its a sunny day and your heart is tugging you toward nowhere in particular?
    Sometimes I lie awake at night and wonder if all the noise we wrap around ourselves—notifications, updates, performance reviews, even our internal monologues—might be crowding out the questions we were meant to live into slowly, like how to love better, or how to forgive ourselves, or what the hell were even doing here in the first place. And when you strip it all down—no goals, no KPIs, no curated identity—whats actually left of us? Are we just a sum of the roles we perform, or is there something quieter underneath that we've forgotten how to hear?
    And if there is something underneath all of it—something real, something worth listening to—then how do we begin to uncover it, gently, without rushing or reducing it to another task on our to-do list?
    """
    MAX_ATTEMPTS = 10
    prev_length = None
    for attempt in range(MAX_ATTEMPTS):
        client.agents.messages.create(
            agent_id=temp_agent_state.id,
            messages=[MessageCreate(role="user", content=philosophical_question)],
        )
        temp_agent_state = client.agents.retrieve(agent_id=temp_agent_state.id)
        message_ids = temp_agent_state.message_ids
        current_length = len(message_ids)
        print("LENGTH OF IN_CONTEXT_MESSAGES:", current_length)
        # Without summarization the in-context list grows every turn, so a
        # non-increasing length signals that summarization ran.
        if prev_length is not None and current_length <= prev_length:
            # TODO: Add more stringent checks here
            # NOTE(review): the wording below is inverted — current_length is at
            # *most* prev_length here, not "at least".
            print(f"Summarization was triggered, detected current_length {current_length} is at least prev_length {prev_length}.")
            break
        prev_length = current_length
    else:
        raise AssertionError("Summarization was not triggered after 10 messages")

File diff suppressed because it is too large Load Diff

144
.github/workflows/model-sweep.yaml vendored Normal file
View File

@@ -0,0 +1,144 @@
name: Model Sweep
on:
workflow_dispatch:
inputs:
branch-name:
required: true
type: string
jobs:
model-sweep:
runs-on: [self-hosted, medium]
services:
qdrant:
image: qdrant/qdrant
ports:
- 6333:6333
postgres:
image: pgvector/pgvector:pg17
ports:
- 5432:5432
env:
POSTGRES_HOST_AUTH_METHOD: trust
POSTGRES_DB: postgres
POSTGRES_USER: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Check if gh is installed
run: |
if ! command -v gh >/dev/null 2>&1
then
echo "gh could not be found, installing now..."
# install gh cli
(type -p wget >/dev/null || (sudo apt update && sudo apt-get install wget -y)) \
&& sudo mkdir -p -m 755 /etc/apt/keyrings \
&& out=$(mktemp) && wget -nv -O$out https://cli.github.com/packages/githubcli-archive-keyring.gpg \
&& cat $out | sudo tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
&& sudo chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | sudo tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
&& sudo apt update \
&& sudo apt install gh -y
fi
- name: Checkout
uses: actions/checkout@v4
- name: Inject env vars into environment
run: |
# Get secrets and mask them before adding to environment
while IFS= read -r line || [[ -n "$line" ]]; do
if [[ -n "$line" ]]; then
value=$(echo "$line" | cut -d= -f2-)
echo "::add-mask::$value"
echo "$line" >> $GITHUB_ENV
fi
done < <(letta_secrets_helper --env dev --service ci)
- name: Install dependencies
shell: bash
run: poetry install --no-interaction --no-root ${{ inputs.install-args || '-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox -E google' }}
- name: Migrate database
env:
LETTA_PG_PORT: 5432
LETTA_PG_USER: postgres
LETTA_PG_PASSWORD: postgres
LETTA_PG_DB: postgres
LETTA_PG_HOST: localhost
run: |
psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector'
poetry run alembic upgrade head
- name: Run integration tests
# if any of the 1000+ test cases fail, pytest reports exit code 1 and won't process/upload the results
continue-on-error: true
env:
LETTA_PG_PORT: 5432
LETTA_PG_USER: postgres
LETTA_PG_PASSWORD: postgres
LETTA_PG_DB: postgres
LETTA_PG_HOST: localhost
LETTA_SERVER_PASS: test_server_token
OPENAI_API_KEY: ${{ env.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ env.ANTHROPIC_API_KEY }}
AZURE_API_KEY: ${{ env.AZURE_API_KEY }}
AZURE_BASE_URL: ${{ secrets.AZURE_BASE_URL }}
GEMINI_API_KEY: ${{ env.GEMINI_API_KEY }}
COMPOSIO_API_KEY: ${{ env.COMPOSIO_API_KEY }}
GOOGLE_CLOUD_PROJECT: ${{ secrets.GOOGLE_CLOUD_PROJECT}}
GOOGLE_CLOUD_LOCATION: ${{ secrets.GOOGLE_CLOUD_LOCATION}}
DEEPSEEK_API_KEY: ${{ env.DEEPSEEK_API_KEY}}
LETTA_USE_EXPERIMENTAL: 1
run: |
poetry run pytest \
-s -vv \
.github/scripts/model-sweep/model_sweep.py \
--json-report --json-report-file=.github/scripts/model-sweep/model_sweep_report.json --json-report-indent=4
- name: Convert report to markdown
continue-on-error: true
# file path args to generate_model_sweep_markdown.py are relative to the script
run: |
poetry run python \
.github/scripts/model-sweep/generate_model_sweep_markdown.py \
.github/scripts/model-sweep/model_sweep_report.json \
.github/scripts/model-sweep/supported-models.mdx
echo "Model sweep report saved to .github/scripts/model-sweep/supported-models.mdx"
- id: date
run: echo "date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT
- name: commit and open pull request
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
BRANCH_NAME=model-sweep/${{ inputs.branch-name || format('{0}', steps.date.outputs.date) }}
gh auth setup-git
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
git checkout -b $BRANCH_NAME
git add .github/scripts/model-sweep/supported-models.mdx
git commit -m "Update model sweep report"
# only push if changes were made
if git diff main --quiet; then
echo "No changes detected, skipping push"
exit 0
else
git push origin $BRANCH_NAME
gh pr create \
--base main \
--head $BRANCH_NAME \
--title "chore: update model sweep report" \
--body "Automated PR to update model sweep report"
fi
- name: Upload model sweep report
if: always()
uses: actions/upload-artifact@v4
with:
name: model-sweep-report
path: .github/scripts/model-sweep/model_sweep_report.json

View File

@@ -1,19 +0,0 @@
name: Notify Letta Cloud
on:
push:
branches:
- main
jobs:
notify:
runs-on: ubuntu-latest
if: ${{ !contains(github.event.head_commit.message, '[sync-skip]') }}
steps:
- name: Trigger repository_dispatch
run: |
curl -X POST \
-H "Authorization: token ${{ secrets.SYNC_PAT }}" \
-H "Accept: application/vnd.github.v3+json" \
https://api.github.com/repos/letta-ai/letta-cloud/dispatches \
-d '{"event_type":"oss-update"}'

View File

@@ -0,0 +1,155 @@
name: Send Message SDK Tests
on:
pull_request_target:
# branches: [main] # TODO: uncomment before merge
types: [labeled]
paths:
- 'letta/**'
jobs:
send-messages:
# Only run when the "safe to test" label is applied
if: contains(github.event.pull_request.labels.*.name, 'safe to test')
runs-on: ubuntu-latest
timeout-minutes: 15
strategy:
fail-fast: false
matrix:
config_file:
- "openai-gpt-4o-mini.json"
- "azure-gpt-4o-mini.json"
- "claude-3-5-sonnet.json"
- "claude-3-7-sonnet.json"
- "claude-3-7-sonnet-extended.json"
- "gemini-pro.json"
- "gemini-vertex.json"
services:
qdrant:
image: qdrant/qdrant
ports:
- 6333:6333
postgres:
image: pgvector/pgvector:pg17
ports:
- 5432:5432
env:
POSTGRES_HOST_AUTH_METHOD: trust
POSTGRES_DB: postgres
POSTGRES_USER: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
# Ensure secrets don't leak
- name: Configure git to hide secrets
run: |
git config --global core.logAllRefUpdates false
git config --global log.hideCredentials true
- name: Set up secret masking
run: |
# Automatically mask any environment variable ending with _KEY
for var in $(env | grep '_KEY=' | cut -d= -f1); do
value="${!var}"
if [[ -n "$value" ]]; then
# Mask the full value
echo "::add-mask::$value"
# Also mask partial values (first and last several characters)
# This helps when only parts of keys appear in logs
if [[ ${#value} -gt 8 ]]; then
echo "::add-mask::${value:0:8}"
echo "::add-mask::${value:(-8)}"
fi
# Also mask with common formatting changes
# Some logs might add quotes or other characters
echo "::add-mask::\"$value\""
echo "::add-mask::$value\""
echo "::add-mask::\"$value"
echo "Masked secret: $var (length: ${#value})"
fi
done
# Check out base repository code, not the PR's code (for security)
- name: Checkout base repository
uses: actions/checkout@v4 # No ref specified means it uses base branch
# Only extract relevant files from the PR (for security, specifically prevent modification of workflow files)
- name: Extract PR schema files
run: |
# Fetch PR without checking it out
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr-${{ github.event.pull_request.number }}
# Extract ONLY the schema files
git checkout pr-${{ github.event.pull_request.number }} -- letta/
- name: Set up python 3.12
id: setup-python
uses: actions/setup-python@v5
with:
python-version: 3.12
- name: Load cached Poetry Binary
id: cached-poetry-binary
uses: actions/cache@v4
with:
path: ~/.local
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-1.8.3
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: 1.8.3
virtualenvs-create: true
virtualenvs-in-project: true
- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}${{ inputs.install-args || '-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox' }}
# Restore cache with this prefix if not exact match with key
# Note cache-hit returns false in this case, so the below step will run
restore-keys: |
venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
shell: bash
run: poetry install --no-interaction --no-root ${{ inputs.install-args || '-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox -E google' }}
- name: Install letta packages via Poetry
run: |
poetry run pip install --upgrade letta-client letta
- name: Migrate database
env:
LETTA_PG_PORT: 5432
LETTA_PG_USER: postgres
LETTA_PG_PASSWORD: postgres
LETTA_PG_DB: postgres
LETTA_PG_HOST: localhost
run: |
psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector'
poetry run alembic upgrade head
- name: Run integration tests for ${{ matrix.config_file }}
env:
LLM_CONFIG_FILE: ${{ matrix.config_file }}
LETTA_PG_PORT: 5432
LETTA_PG_USER: postgres
LETTA_PG_PASSWORD: postgres
LETTA_PG_DB: postgres
LETTA_PG_HOST: localhost
LETTA_SERVER_PASS: test_server_token
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_BASE_URL: ${{ secrets.AZURE_BASE_URL }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }}
DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }}
GOOGLE_CLOUD_PROJECT: ${{ secrets.GOOGLE_CLOUD_PROJECT }}
GOOGLE_CLOUD_LOCATION: ${{ secrets.GOOGLE_CLOUD_LOCATION }}
run: |
poetry run pytest \
-s -vv \
tests/integration_test_send_message.py \
--maxfail=1 --durations=10

View File

@@ -28,7 +28,7 @@ First, install Poetry using [the official instructions here](https://python-poet
Once Poetry is installed, navigate to the letta directory and install the Letta project with Poetry:
```shell
cd letta
poetry shell
eval $(poetry env activate)
poetry install --all-extras
```
#### Setup PostgreSQL environment (optional)

View File

@@ -8,26 +8,13 @@
<div align="center">
<h1>Letta (previously MemGPT)</h1>
**☄️ New release: Letta Agent Development Environment (_read more [here](#-access-the-ade-agent-development-environment)_) ☄️**
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/letta-ai/letta/refs/heads/main/assets/example_ade_screenshot.png">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/letta-ai/letta/refs/heads/main/assets/example_ade_screenshot_light.png">
<img alt="Letta logo" src="https://raw.githubusercontent.com/letta-ai/letta/refs/heads/main/assets/example_ade_screenshot.png" width="800">
</picture>
</p>
---
<h3>
[Homepage](https://letta.com) // [Documentation](https://docs.letta.com) // [ADE](https://docs.letta.com/agent-development-environment) // [Letta Cloud](https://forms.letta.com/early-access)
</h3>
**👾 Letta** is an open source framework for building stateful LLM applications. You can use Letta to build **stateful agents** with advanced reasoning capabilities and transparent long-term memory. The Letta framework is white box and model-agnostic.
**👾 Letta** is an open source framework for building **stateful agents** with advanced reasoning capabilities and transparent long-term memory. The Letta framework is white box and model-agnostic.
[![Discord](https://img.shields.io/discord/1161736243340640419?label=Discord&logo=discord&logoColor=5865F2&style=flat-square&color=5865F2)](https://discord.gg/letta)
[![Twitter Follow](https://img.shields.io/badge/Follow-%40Letta__AI-1DA1F2?style=flat-square&logo=x&logoColor=white)](https://twitter.com/Letta_AI)
@@ -157,7 +144,7 @@ No, the data in your Letta server database stays on your machine. The Letta ADE
> _"Do I have to use your ADE? Can I build my own?"_
The ADE is built on top of the (fully open source) Letta server and Letta Agents API. You can build your own application like the ADE on top of the REST API (view the documention [here](https://docs.letta.com/api-reference)).
The ADE is built on top of the (fully open source) Letta server and Letta Agents API. You can build your own application like the ADE on top of the REST API (view the documentation [here](https://docs.letta.com/api-reference)).
> _"Can I interact with Letta agents via the CLI?"_

View File

@@ -28,7 +28,6 @@ services:
- "8083:8083"
- "8283:8283"
environment:
- SERPAPI_API_KEY=${SERPAPI_API_KEY}
- LETTA_PG_DB=${LETTA_PG_DB:-letta}
- LETTA_PG_USER=${LETTA_PG_USER:-letta}
- LETTA_PG_PASSWORD=${LETTA_PG_PASSWORD:-letta}

View File

@@ -8,6 +8,7 @@ If you're using Letta Cloud, replace 'baseURL' with 'token'
See: https://docs.letta.com/api-reference/overview
Execute this script using `poetry run python3 example.py`
This will install `letta_client` and other dependencies.
"""
client = Letta(
base_url="http://localhost:8283",

View File

@@ -2,22 +2,33 @@ from pprint import pprint
from letta_client import Letta
# Connect to Letta server
client = Letta(base_url="http://localhost:8283")
# Use the "everything" mcp server:
# https://github.com/modelcontextprotocol/servers/tree/main/src/everything
mcp_server_name = "everything"
mcp_tool_name = "echo"
# List all McpTool belonging to the "everything" mcp server.
mcp_tools = client.tools.list_mcp_tools_by_server(
mcp_server_name=mcp_server_name,
)
# We can see that "echo" is one of the tools, but it's not
# a letta tool that can be added to a client (it has no tool id).
for tool in mcp_tools:
pprint(tool)
# Create a Tool (with a tool id) using the server and tool names.
mcp_tool = client.tools.add_mcp_tool(
mcp_server_name=mcp_server_name,
mcp_tool_name=mcp_tool_name
)
# Create an agent with the tool, using tool.id -- note that
# this is the ONLY tool in the agent, you typically want to
# also include the default tools.
agent = client.agents.create(
memory_blocks=[
{
@@ -31,6 +42,7 @@ agent = client.agents.create(
)
print(f"Created agent id {agent.id}")
# Ask the agent to call the tool.
response = client.agents.messages.create(
agent_id=agent.id,
messages=[

View File

@@ -253,15 +253,18 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "7808912f-831b-4cdc-8606-40052eb809b4",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional, List\n",
"from typing import Optional, List, TYPE_CHECKING\n",
"import json\n",
"\n",
"def task_queue_push(self: \"Agent\", task_description: str):\n",
"if TYPE_CHECKING:\n",
" from letta import AgentState\n",
"\n",
"def task_queue_push(agent_state: \"AgentState\", task_description: str):\n",
" \"\"\"\n",
" Push to a task queue stored in core memory. \n",
"\n",
@@ -273,12 +276,12 @@
" does not produce a response.\n",
" \"\"\"\n",
" import json\n",
" tasks = json.loads(self.memory.get_block(\"tasks\").value)\n",
" tasks = json.loads(agent_state.memory.get_block(\"tasks\").value)\n",
" tasks.append(task_description)\n",
" self.memory.update_block_value(\"tasks\", json.dumps(tasks))\n",
" agent_state.memory.update_block_value(\"tasks\", json.dumps(tasks))\n",
" return None\n",
"\n",
"def task_queue_pop(self: \"Agent\"):\n",
"def task_queue_pop(agent_state: \"AgentState\"):\n",
" \"\"\"\n",
" Get the next task from the task queue \n",
"\n",
@@ -288,12 +291,12 @@
" None (the task queue is empty)\n",
" \"\"\"\n",
" import json\n",
" tasks = json.loads(self.memory.get_block(\"tasks\").value)\n",
" tasks = json.loads(agent_state.memory.get_block(\"tasks\").value)\n",
" if len(tasks) == 0: \n",
" return None\n",
" task = tasks[0]\n",
" print(\"CURRENT TASKS: \", tasks)\n",
" self.memory.update_block_value(\"tasks\", json.dumps(tasks[1:]))\n",
" agent_state.memory.update_block_value(\"tasks\", json.dumps(tasks[1:]))\n",
" return task\n",
"\n",
"push_task_tool = client.tools.upsert_from_function(func=task_queue_push)\n",
@@ -310,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "135fcf3e-59c4-4da3-b86b-dbffb21aa343",
"metadata": {},
"outputs": [],
@@ -336,10 +339,12 @@
" ),\n",
" CreateBlock(\n",
" label=\"tasks\",\n",
" value=\"\",\n",
" value=\"[]\",\n",
" ),\n",
" ],\n",
" tool_ids=[push_task_tool.id, pop_task_tool.id],\n",
" model=\"letta/letta-free\",\n",
" embedding=\"letta/letta-free\",\n",
")"
]
},

View File

@@ -1,3 +0,0 @@
from .main import app
app()

View File

@@ -233,7 +233,9 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
if endpoint_type == "openai":
return OpenAIEmbeddings(
api_key=model_settings.openai_api_key, model=config.embedding_model, base_url=model_settings.openai_api_base
api_key=model_settings.openai_api_key,
model=config.embedding_model,
base_url=model_settings.openai_api_base,
)
elif endpoint_type == "azure":

View File

@@ -46,7 +46,7 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
# TODO: add paging by page number. currently cursor only works with strings.
# original: start=page * count
messages = self.message_manager.list_user_messages_for_agent(
messages = self.message_manager.list_messages_for_agent(
agent_id=self.agent_state.id,
actor=self.user,
query_text=query,

View File

@@ -55,7 +55,19 @@ BASE_URL = "https://api.anthropic.com/v1"
# https://docs.anthropic.com/claude/docs/models-overview
# Sadly hardcoded
MODEL_LIST = [
## Opus 3
{
"name": "claude-opus-4-20250514",
"context_window": 200000,
},
{
"name": "claude-sonnet-4-20250514",
"context_window": 200000,
},
{
"name": "claude-3-5-haiku-20241022",
"context_window": 200000,
},
## Opus
{
"name": "claude-3-opus-20240229",
"context_window": 200000,

View File

@@ -242,7 +242,8 @@ class AnthropicClient(LLMClientBase):
# Move 'system' to the top level
if messages[0].role != "system":
raise RuntimeError(f"First message is not a system message, instead has role {messages[0].role}")
data["system"] = messages[0].content if isinstance(messages[0].content, str) else messages[0].content[0].text
system_content = messages[0].content if isinstance(messages[0].content, str) else messages[0].content[0].text
data["system"] = self._add_cache_control_to_system_message(system_content)
data["messages"] = [
m.to_anthropic_dict(
inner_thoughts_xml_tag=inner_thoughts_xml_tag,
@@ -499,6 +500,22 @@ class AnthropicClient(LLMClientBase):
return chat_completion_response
def _add_cache_control_to_system_message(self, system_content):
"""Add cache control to system message content"""
if isinstance(system_content, str):
# For string content, convert to list format with cache control
return [{"type": "text", "text": system_content, "cache_control": {"type": "ephemeral"}}]
elif isinstance(system_content, list):
# For list content, add cache control to the last text block
cached_content = system_content.copy()
for i in range(len(cached_content) - 1, -1, -1):
if cached_content[i].get("type") == "text":
cached_content[i]["cache_control"] = {"type": "ephemeral"}
break
return cached_content
return system_content
def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
"""See: https://docs.anthropic.com/claude/docs/tool-use

View File

@@ -75,7 +75,8 @@ class LLMConfig(BaseModel):
description="The reasoning effort to use when generating text reasoning models",
)
max_reasoning_tokens: int = Field(
0, description="Configurable thinking budget for extended thinking, only used if enable_reasoner is True. Minimum value is 1024."
0,
description="Configurable thinking budget for extended thinking. Used for enable_reasoner and also for Google Vertex models like Gemini 2.5 Flash. Minimum value is 1024 when used with enable_reasoner.",
)
frequency_penalty: Optional[float] = Field(
None, # Can also deafult to 0.0?

1618
letta/schemas/providers.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -30,9 +30,7 @@ logger = get_logger(__name__)
responses={
200: {
"description": "Successful response",
"content": {
"text/event-stream": {"description": "Server-Sent Events stream"},
},
"content": {"text/event-stream": {}},
}
},
)

View File

@@ -25,9 +25,7 @@ logger = get_logger(__name__)
responses={
200: {
"description": "Successful response",
"content": {
"text/event-stream": {"description": "Server-Sent Events stream"},
},
"content": {"text/event-stream": {}},
}
},
)

View File

@@ -0,0 +1,685 @@
from datetime import datetime, timezone
from typing import Dict, List
from letta.errors import AgentFileExportError, AgentFileImportError
from letta.log import get_logger
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.agent_file import (
AgentFileSchema,
AgentSchema,
BlockSchema,
FileAgentSchema,
FileSchema,
GroupSchema,
ImportResult,
MessageSchema,
SourceSchema,
ToolSchema,
)
from letta.schemas.block import Block
from letta.schemas.file import FileMetadata
from letta.schemas.message import Message
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.file_manager import FileManager
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
from letta.services.file_processor.file_processor import FileProcessor
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
from letta.services.files_agents_manager import FileAgentManager
from letta.services.group_manager import GroupManager
from letta.services.mcp_manager import MCPManager
from letta.services.message_manager import MessageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.utils import get_latest_alembic_revision
logger = get_logger(__name__)
class AgentFileManager:
    """
    Manages export and import of agent files between database and AgentFileSchema format.

    Handles:
    - ID mapping between database IDs and human-readable file IDs
    - Coordination across multiple entity managers
    - Transaction safety during imports
    - Referential integrity validation
    """

    def __init__(
        self,
        agent_manager: AgentManager,
        tool_manager: ToolManager,
        source_manager: SourceManager,
        block_manager: BlockManager,
        group_manager: GroupManager,
        mcp_manager: MCPManager,
        file_manager: FileManager,
        file_agent_manager: FileAgentManager,
        message_manager: MessageManager,
        embedder: BaseEmbedder,
        file_parser: MistralFileParser,
        using_pinecone: bool = False,
    ):
        # Per-entity managers that perform the actual database reads/writes.
        self.agent_manager = agent_manager
        self.tool_manager = tool_manager
        self.source_manager = source_manager
        self.block_manager = block_manager
        self.group_manager = group_manager
        self.mcp_manager = mcp_manager
        self.file_manager = file_manager
        self.file_agent_manager = file_agent_manager
        self.message_manager = message_manager
        # Used during import to re-chunk and re-embed file content.
        self.embedder = embedder
        self.file_parser = file_parser
        self.using_pinecone = using_pinecone

        # ID mapping state for export: database ID -> generated file-level ID.
        self._db_to_file_ids: Dict[str, str] = {}

        # Counters for generating Stripe-style IDs ("<prefix>-<n>"), one per entity type.
        self._id_counters: Dict[str, int] = {
            AgentSchema.__id_prefix__: 0,
            GroupSchema.__id_prefix__: 0,
            BlockSchema.__id_prefix__: 0,
            FileSchema.__id_prefix__: 0,
            SourceSchema.__id_prefix__: 0,
            ToolSchema.__id_prefix__: 0,
            MessageSchema.__id_prefix__: 0,
            FileAgentSchema.__id_prefix__: 0,
            # MCPServerSchema.__id_prefix__: 0,
        }
def _reset_state(self):
"""Reset internal state for a new operation"""
self._db_to_file_ids.clear()
for key in self._id_counters:
self._id_counters[key] = 0
def _generate_file_id(self, entity_type: str) -> str:
"""Generate a Stripe-style ID for the given entity type"""
counter = self._id_counters[entity_type]
file_id = f"{entity_type}-{counter}"
self._id_counters[entity_type] += 1
return file_id
def _map_db_to_file_id(self, db_id: str, entity_type: str, allow_new: bool = True) -> str:
"""Map a database UUID to a file ID, creating if needed (export only)"""
if db_id in self._db_to_file_ids:
return self._db_to_file_ids[db_id]
if not allow_new:
raise AgentFileExportError(
f"Unexpected new {entity_type} ID '{db_id}' encountered during conversion. "
f"All IDs should have been mapped during agent processing."
)
file_id = self._generate_file_id(entity_type)
self._db_to_file_ids[db_id] = file_id
return file_id
def _extract_unique_tools(self, agent_states: List[AgentState]) -> List:
    """Collect the tools attached to *agent_states*, de-duplicated by tool ID.

    Later occurrences of a repeated ID win; the result is sorted by tool name.
    """
    by_id = {}
    for state in agent_states:
        if state.tools:
            by_id.update({tool.id: tool for tool in state.tools})
    return sorted(by_id.values(), key=lambda t: t.name)
def _extract_unique_blocks(self, agent_states: List[AgentState]) -> List:
    """Collect the memory blocks of *agent_states*, de-duplicated by block ID.

    Later occurrences of a repeated ID win; the result is sorted by label.
    """
    by_id = {}
    for state in agent_states:
        if state.memory and state.memory.blocks:
            by_id.update({block.id: block for block in state.memory.blocks})
    return sorted(by_id.values(), key=lambda b: b.label)
async def _extract_unique_sources_and_files_from_agents(
    self, agent_states: List[AgentState], actor: User, files_agents_cache: dict = None
) -> tuple[List[Source], List[FileMetadata]]:
    """Extract unique sources and files from agent states using bulk operations.

    Walks every agent's file-agent relationships to gather the distinct
    source and file IDs, then fetches each set with a single bulk query.

    Args:
        agent_states: Agents whose attached files/sources should be collected.
        actor: User performing the export (used for access control).
        files_agents_cache: Optional dict that, when provided, is populated
            with agent_id -> file-agent rows so later conversion steps can
            reuse them without re-querying.

    Returns:
        Tuple of (unique sources, unique files with content included).
    """
    all_source_ids = set()
    all_file_ids = set()
    for agent_state in agent_states:
        files_agents = await self.file_agent_manager.list_files_for_agent(
            agent_id=agent_state.id, actor=actor, is_open_only=False, return_as_blocks=False
        )
        # cache the results for reuse during conversion
        if files_agents_cache is not None:
            files_agents_cache[agent_state.id] = files_agents
        for file_agent in files_agents:
            all_source_ids.add(file_agent.source_id)
            all_file_ids.add(file_agent.file_id)

    sources = await self.source_manager.get_sources_by_ids_async(list(all_source_ids), actor)
    files = await self.file_manager.get_files_by_ids_async(list(all_file_ids), actor, include_content=True)
    return sources, files
async def _convert_agent_state_to_schema(self, agent_state: AgentState, actor: User, files_agents_cache: dict = None) -> AgentSchema:
    """Convert AgentState to AgentSchema with ID remapping.

    Converts the agent itself, then rewrites every embedded database ID
    (messages, in-context message refs, tools, sources, blocks, file-agent
    rows) to its file-level ID.

    Args:
        agent_state: The agent to convert.
        actor: User performing the export.
        files_agents_cache: Optional agent_id -> file-agent rows cache
            populated earlier in the export to avoid a duplicate query.
    """
    agent_file_id = self._map_db_to_file_id(agent_state.id, AgentSchema.__id_prefix__)

    # use cached file-agent data if available, otherwise fetch
    if files_agents_cache is not None and agent_state.id in files_agents_cache:
        files_agents = files_agents_cache[agent_state.id]
    else:
        files_agents = await self.file_agent_manager.list_files_for_agent(
            agent_id=agent_state.id, actor=actor, is_open_only=False, return_as_blocks=False
        )

    agent_schema = await AgentSchema.from_agent_state(
        agent_state, message_manager=self.message_manager, files_agents=files_agents, actor=actor
    )
    agent_schema.id = agent_file_id

    if agent_schema.messages:
        for message in agent_schema.messages:
            message_file_id = self._map_db_to_file_id(message.id, MessageSchema.__id_prefix__)
            message.id = message_file_id
            # Messages point back at the agent's file ID, not its database ID.
            message.agent_id = agent_file_id

    if agent_schema.in_context_message_ids:
        # allow_new=False: every in-context message must already have been mapped above.
        agent_schema.in_context_message_ids = [
            self._map_db_to_file_id(message_id, MessageSchema.__id_prefix__, allow_new=False)
            for message_id in agent_schema.in_context_message_ids
        ]

    if agent_schema.tool_ids:
        agent_schema.tool_ids = [self._map_db_to_file_id(tool_id, ToolSchema.__id_prefix__) for tool_id in agent_schema.tool_ids]

    if agent_schema.source_ids:
        agent_schema.source_ids = [
            self._map_db_to_file_id(source_id, SourceSchema.__id_prefix__) for source_id in agent_schema.source_ids
        ]

    if agent_schema.block_ids:
        agent_schema.block_ids = [self._map_db_to_file_id(block_id, BlockSchema.__id_prefix__) for block_id in agent_schema.block_ids]

    if agent_schema.files_agents:
        for file_agent in agent_schema.files_agents:
            file_agent.file_id = self._map_db_to_file_id(file_agent.file_id, FileSchema.__id_prefix__)
            file_agent.source_id = self._map_db_to_file_id(file_agent.source_id, SourceSchema.__id_prefix__)
            file_agent.agent_id = agent_file_id

    return agent_schema
def _convert_tool_to_schema(self, tool) -> ToolSchema:
    """Build a ToolSchema from *tool*, swapping its DB ID for the pre-assigned file ID."""
    # allow_new=False: the tool's ID must already have been mapped during agent processing.
    file_id = self._map_db_to_file_id(tool.id, ToolSchema.__id_prefix__, allow_new=False)
    schema = ToolSchema.from_tool(tool)
    schema.id = file_id
    return schema
def _convert_block_to_schema(self, block) -> BlockSchema:
    """Build a BlockSchema from *block*, swapping its DB ID for the pre-assigned file ID."""
    # allow_new=False: the block's ID must already have been mapped during agent processing.
    file_id = self._map_db_to_file_id(block.id, BlockSchema.__id_prefix__, allow_new=False)
    schema = BlockSchema.from_block(block)
    schema.id = file_id
    return schema
def _convert_source_to_schema(self, source) -> SourceSchema:
    """Build a SourceSchema from *source*, swapping its DB ID for the pre-assigned file ID."""
    # allow_new=False: the source's ID must already have been mapped during agent processing.
    file_id = self._map_db_to_file_id(source.id, SourceSchema.__id_prefix__, allow_new=False)
    schema = SourceSchema.from_source(source)
    schema.id = file_id
    return schema
def _convert_file_to_schema(self, file_metadata) -> FileSchema:
    """Build a FileSchema from *file_metadata*, remapping both its own ID and its source's ID."""
    # allow_new=False: both IDs must already have been mapped during agent processing.
    file_id = self._map_db_to_file_id(file_metadata.id, FileSchema.__id_prefix__, allow_new=False)
    schema = FileSchema.from_file_metadata(file_metadata)
    schema.id = file_id
    schema.source_id = self._map_db_to_file_id(file_metadata.source_id, SourceSchema.__id_prefix__, allow_new=False)
    return schema
async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema:
    """
    Export agents and their related entities to AgentFileSchema format.

    Args:
        agent_ids: List of agent UUIDs to export
        actor: User performing the export (used for access control on lookups)

    Returns:
        AgentFileSchema with all related entities

    Raises:
        AgentFileExportError: If export fails
    """
    try:
        # Fresh ID mappings/counters so repeated exports restart at "<prefix>-0".
        self._reset_state()

        agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids=agent_ids, actor=actor)

        # Validate that all requested agents were found
        if len(agent_states) != len(agent_ids):
            found_ids = {agent.id for agent in agent_states}
            missing_ids = [agent_id for agent_id in agent_ids if agent_id not in found_ids]
            raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}")

        # cache for file-agent relationships to avoid duplicate queries
        files_agents_cache = {}  # Maps agent_id to list of file_agent relationships

        # Extract unique entities across all agents
        tool_set = self._extract_unique_tools(agent_states)
        block_set = self._extract_unique_blocks(agent_states)

        # Extract sources and files from agent states BEFORE conversion (with caching)
        source_set, file_set = await self._extract_unique_sources_and_files_from_agents(agent_states, actor, files_agents_cache)

        # Convert to schemas with ID remapping (reusing cached file-agent data).
        # Agents go first so that every referenced tool/block/source/file ID is
        # registered before the allow_new=False conversions below run.
        agent_schemas = [
            await self._convert_agent_state_to_schema(agent_state, actor=actor, files_agents_cache=files_agents_cache)
            for agent_state in agent_states
        ]
        tool_schemas = [self._convert_tool_to_schema(tool) for tool in tool_set]
        block_schemas = [self._convert_block_to_schema(block) for block in block_set]
        source_schemas = [self._convert_source_to_schema(source) for source in source_set]
        file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set]

        logger.info(f"Exporting {len(agent_ids)} agents to agent file format")

        # Return AgentFileSchema with converted entities
        return AgentFileSchema(
            agents=agent_schemas,
            groups=[],  # TODO: Extract and convert groups
            blocks=block_schemas,
            files=file_schemas,
            sources=source_schemas,
            tools=tool_schemas,
            # mcp_servers=[],  # TODO: Extract and convert MCP servers
            metadata={"revision_id": await get_latest_alembic_revision()},
            created_at=datetime.now(timezone.utc),
        )

    except Exception as e:
        logger.error(f"Failed to export agent file: {e}")
        raise AgentFileExportError(f"Export failed: {e}") from e
async def import_file(self, schema: AgentFileSchema, actor: User, dry_run: bool = False) -> ImportResult:
    """
    Import AgentFileSchema into the database.

    Entities are created in dependency order (tools/blocks/sources, then
    files, agents, messages, and file-agent links), remapping file-level IDs
    to the newly created database IDs as it goes.

    Args:
        schema: The agent file schema to import
        actor: User performing the import (owns the created entities)
        dry_run: If True, validate but don't commit changes

    Returns:
        ImportResult with success status and details

    Raises:
        AgentFileImportError: If import fails
    """
    try:
        self._reset_state()

        if dry_run:
            logger.info("Starting dry run import validation")
        else:
            logger.info("Starting agent file import")

        # Validate schema first
        self._validate_schema(schema)

        if dry_run:
            return ImportResult(
                success=True,
                message="Dry run validation passed",
                imported_count=0,
            )

        # Import in dependency order
        imported_count = 0
        file_to_db_ids = {}  # Maps file IDs to new database IDs

        # in-memory cache for file metadata to avoid repeated db calls
        file_metadata_cache = {}  # Maps database file ID to FileMetadata

        # 1. Create tools first (no dependencies) - using bulk upsert for efficiency
        if schema.tools:
            # convert tool schemas to pydantic tools
            pydantic_tools = []
            for tool_schema in schema.tools:
                pydantic_tools.append(Tool(**tool_schema.model_dump(exclude={"id"})))

            # bulk upsert all tools at once
            created_tools = await self.tool_manager.bulk_upsert_tools_async(pydantic_tools, actor)

            # map file ids to database ids
            # note: tools are matched by name during upsert, so we need to match by name here too
            created_tools_by_name = {tool.name: tool for tool in created_tools}
            for tool_schema in schema.tools:
                created_tool = created_tools_by_name.get(tool_schema.name)
                if created_tool:
                    file_to_db_ids[tool_schema.id] = created_tool.id
                    imported_count += 1
                else:
                    logger.warning(f"Tool {tool_schema.name} was not created during bulk upsert")

        # 2. Create blocks (no dependencies) - using batch create for efficiency
        if schema.blocks:
            # convert block schemas to pydantic blocks (excluding IDs to create new blocks)
            pydantic_blocks = []
            for block_schema in schema.blocks:
                pydantic_blocks.append(Block(**block_schema.model_dump(exclude={"id"})))

            # batch create all blocks at once
            created_blocks = await self.block_manager.batch_create_blocks_async(pydantic_blocks, actor)

            # map file ids to database ids
            # (batch create preserves input order, so zip pairs each schema with its block)
            for block_schema, created_block in zip(schema.blocks, created_blocks):
                file_to_db_ids[block_schema.id] = created_block.id
                imported_count += 1

        # 3. Create sources (no dependencies) - using bulk upsert for efficiency
        if schema.sources:
            # convert source schemas to pydantic sources
            pydantic_sources = []
            for source_schema in schema.sources:
                source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
                pydantic_sources.append(Source(**source_data))

            # bulk upsert all sources at once
            created_sources = await self.source_manager.bulk_upsert_sources_async(pydantic_sources, actor)

            # map file ids to database ids
            # note: sources are matched by name during upsert, so we need to match by name here too
            created_sources_by_name = {source.name: source for source in created_sources}
            for source_schema in schema.sources:
                created_source = created_sources_by_name.get(source_schema.name)
                if created_source:
                    file_to_db_ids[source_schema.id] = created_source.id
                    imported_count += 1
                else:
                    logger.warning(f"Source {source_schema.name} was not created during bulk upsert")

        # 4. Create files (depends on sources)
        for file_schema in schema.files:
            # Convert FileSchema back to FileMetadata
            file_data = file_schema.model_dump(exclude={"id", "content"})
            # Remap source_id from file ID to database ID
            file_data["source_id"] = file_to_db_ids[file_schema.source_id]
            file_metadata = FileMetadata(**file_data)
            created_file = await self.file_manager.create_file(file_metadata, actor, text=file_schema.content)
            file_to_db_ids[file_schema.id] = created_file.id
            imported_count += 1

        # 5. Process files for chunking/embedding (depends on files and sources)
        file_processor = FileProcessor(
            file_parser=self.file_parser,
            embedder=self.embedder,
            actor=actor,
            using_pinecone=self.using_pinecone,
        )
        for file_schema in schema.files:
            if file_schema.content:  # Only process files with content
                file_db_id = file_to_db_ids[file_schema.id]
                source_db_id = file_to_db_ids[file_schema.source_id]

                # Get the created file metadata (with caching)
                if file_db_id not in file_metadata_cache:
                    file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
                file_metadata = file_metadata_cache[file_db_id]

                # Save the db call of fetching content again
                file_metadata.content = file_schema.content

                # Process the file for chunking/embedding
                passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_db_id)
                imported_count += len(passages)

        # 6. Create agents with empty message history
        for agent_schema in schema.agents:
            # Convert AgentSchema back to CreateAgent, remapping tool/block IDs
            agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"})

            # Remap tool_ids from file IDs to database IDs
            if agent_data.get("tool_ids"):
                agent_data["tool_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["tool_ids"]]

            # Remap block_ids from file IDs to database IDs
            if agent_data.get("block_ids"):
                agent_data["block_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["block_ids"]]

            agent_create = CreateAgent(**agent_data)
            created_agent = await self.agent_manager.create_agent_async(agent_create, actor, _init_with_no_messages=True)
            file_to_db_ids[agent_schema.id] = created_agent.id
            imported_count += 1

        # 7. Create messages and update agent message_ids
        for agent_schema in schema.agents:
            agent_db_id = file_to_db_ids[agent_schema.id]
            message_file_to_db_ids = {}

            # Create messages for this agent
            messages = []
            for message_schema in agent_schema.messages:
                # Convert MessageSchema back to Message, setting agent_id to new DB ID
                message_data = message_schema.model_dump(exclude={"id"})
                message_data["agent_id"] = agent_db_id  # Remap agent_id to new database ID
                message_obj = Message(**message_data)
                messages.append(message_obj)

                # Map file ID to the generated database ID immediately
                message_file_to_db_ids[message_schema.id] = message_obj.id

            created_messages = await self.message_manager.create_many_messages_async(pydantic_msgs=messages, actor=actor)
            imported_count += len(created_messages)

            # Remap in_context_message_ids from file IDs to database IDs
            in_context_db_ids = [message_file_to_db_ids[message_schema_id] for message_schema_id in agent_schema.in_context_message_ids]

            # Update agent with the correct message_ids
            await self.agent_manager.update_message_ids_async(agent_id=agent_db_id, message_ids=in_context_db_ids, actor=actor)

        # 8. Create file-agent relationships (depends on agents and files)
        for agent_schema in schema.agents:
            if agent_schema.files_agents:
                agent_db_id = file_to_db_ids[agent_schema.id]

                # Prepare files for bulk attachment
                files_for_agent = []
                visible_content_map = {}
                for file_agent_schema in agent_schema.files_agents:
                    file_db_id = file_to_db_ids[file_agent_schema.file_id]

                    # Use cached file metadata if available
                    if file_db_id not in file_metadata_cache:
                        file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
                    file_metadata = file_metadata_cache[file_db_id]
                    files_for_agent.append(file_metadata)

                    if file_agent_schema.visible_content:
                        visible_content_map[file_db_id] = file_agent_schema.visible_content

                # Bulk attach files to agent
                await self.file_agent_manager.attach_files_bulk(
                    agent_id=agent_db_id, files_metadata=files_for_agent, visible_content_map=visible_content_map, actor=actor
                )
                imported_count += len(files_for_agent)

        return ImportResult(
            success=True,
            message=f"Import completed successfully. Imported {imported_count} entities.",
            imported_count=imported_count,
            id_mappings=file_to_db_ids,
        )

    except Exception as e:
        logger.exception(f"Failed to import agent file: {e}")
        raise AgentFileImportError(f"Import failed: {e}") from e
def _validate_id_format(self, schema: AgentFileSchema) -> List[str]:
    """Check that every entity and message ID looks like '<prefix>-<integer>'.

    Returns:
        A list of human-readable error strings; empty when all IDs are well formed.
    """
    errors: List[str] = []

    def check_id(candidate: str, expected_prefix: str, label: str) -> None:
        # Accept exactly "<prefix>-<int>"; report whichever half is wrong.
        if not candidate.startswith(f"{expected_prefix}-"):
            errors.append(f"Invalid {label} format: {candidate} should start with '{expected_prefix}-'")
            return
        try:
            int(candidate[len(expected_prefix) + 1 :])
        except ValueError:
            errors.append(f"Invalid {label} format: {candidate} should have integer suffix")

    # Top-level entity collections paired with their required prefixes.
    for entities, expected_prefix in (
        (schema.agents, AgentSchema.__id_prefix__),
        (schema.groups, GroupSchema.__id_prefix__),
        (schema.blocks, BlockSchema.__id_prefix__),
        (schema.files, FileSchema.__id_prefix__),
        (schema.sources, SourceSchema.__id_prefix__),
        (schema.tools, ToolSchema.__id_prefix__),
    ):
        for entity in entities:
            check_id(entity.id, expected_prefix, "ID")

    # Messages are nested inside agents and carry their own prefix.
    for agent in schema.agents:
        for message in agent.messages:
            check_id(message.id, MessageSchema.__id_prefix__, "message ID")

    return errors
def _validate_duplicate_ids(self, schema: AgentFileSchema) -> List[str]:
    """Detect repeated IDs, both within each entity type and across all types.

    Returns:
        A list of error strings; empty when every ID is globally unique.
    """
    errors: List[str] = []
    all_ids = set()

    def find_dupes(ids):
        # Second and later occurrences of an ID land in `dupes`.
        seen, dupes = set(), set()
        for candidate in ids:
            (dupes if candidate in seen else seen).add(candidate)
        return dupes

    def check_global(ids):
        # Cross-type collision check against every ID seen so far.
        for candidate in ids:
            if candidate in all_ids:
                errors.append(f"Duplicate ID across entity types: {candidate}")
            all_ids.add(candidate)

    for entity_type, entities in (
        ("agents", schema.agents),
        ("groups", schema.groups),
        ("blocks", schema.blocks),
        ("files", schema.files),
        ("sources", schema.sources),
        ("tools", schema.tools),
    ):
        entity_ids = [entity.id for entity in entities]
        duplicates = find_dupes(entity_ids)
        if duplicates:
            errors.append(f"Duplicate {entity_type} IDs found: {duplicates}")
        check_global(entity_ids)

    # Message IDs live inside each agent but share the global ID namespace.
    for agent in schema.agents:
        message_ids = [message.id for message in agent.messages]
        duplicates = find_dupes(message_ids)
        if duplicates:
            errors.append(f"Duplicate message IDs in agent {agent.id}: {duplicates}")
        check_global(message_ids)

    return errors
def _validate_file_source_references(self, schema: AgentFileSchema) -> List[str]:
    """Ensure each file points at a source that exists in the schema.

    Returns:
        One error string per file whose source_id is unknown.
    """
    known_sources = {source.id for source in schema.sources}
    return [
        f"File {f.id} references non-existent source {f.source_id}"
        for f in schema.files
        if f.source_id not in known_sources
    ]
def _validate_file_agent_references(self, schema: AgentFileSchema) -> List[str]:
    """Validate that all file-agent relationships reference existing entities.

    Checks each agent's file-agent rows against the schema's file and source
    ID sets, and verifies each row's agent_id matches its owning agent.

    Returns:
        A list of error strings; empty when every reference resolves.
    """
    errors = []

    file_ids = {file.id for file in schema.files}
    source_ids = {source.id for source in schema.sources}
    # NOTE: removed a discarded `{agent.id for agent in schema.agents}` expression
    # that built a set and threw it away (dead code left over from a removed check).

    for agent in schema.agents:
        for file_agent in agent.files_agents:
            if file_agent.file_id not in file_ids:
                errors.append(f"File-agent relationship references non-existent file {file_agent.file_id}")
            if file_agent.source_id not in source_ids:
                errors.append(f"File-agent relationship references non-existent source {file_agent.source_id}")
            if file_agent.agent_id != agent.id:
                errors.append(f"File-agent relationship has mismatched agent_id {file_agent.agent_id} vs {agent.id}")

    return errors
def _validate_schema(self, schema: AgentFileSchema):
    """
    Validate the agent file schema for consistency and referential integrity.

    Runs each validation pass in order and aggregates every reported problem
    before failing, so the caller sees all issues at once.

    Args:
        schema: The schema to validate

    Raises:
        AgentFileImportError: If validation fails
    """
    validators = (
        self._validate_id_format,               # 1. ID format validation
        self._validate_duplicate_ids,           # 2. duplicate ID detection
        self._validate_file_source_references,  # 3. file -> source references
        self._validate_file_agent_references,   # 4. file-agent references
    )
    problems = []
    for validate in validators:
        problems.extend(validate(schema))
    if problems:
        raise AgentFileImportError(f"Schema validation failed: {'; '.join(problems)}")
    logger.info("Schema validation passed")

View File

@@ -1,5 +1,6 @@
import asyncio
import json
import os
import time
from typing import Any, Dict, List, Literal, Optional

View File

@@ -0,0 +1,32 @@
# Local development dependencies: Redis and Postgres (with the pgvector image).
version: '3.7'
services:
  # Redis with append-only persistence; data is kept on the host under ./data/redis.
  redis:
    image: redis:alpine
    container_name: redis
    healthcheck:
      # Healthy once the server answers PING with PONG.
      test: ['CMD-SHELL', 'redis-cli ping | grep PONG']
      interval: 1s
      timeout: 3s
      retries: 5
    ports:
      - '6379:6379'
    volumes:
      - ./data/redis:/data
    command: redis-server --appendonly yes
  # Postgres from the ankane/pgvector image (Postgres + pgvector extension).
  postgres:
    image: ankane/pgvector
    container_name: postgres
    healthcheck:
      # Healthy once the server accepts connections for the postgres user.
      test: ['CMD-SHELL', 'pg_isready -U postgres']
      interval: 1s
      timeout: 3s
      retries: 5
    ports:
      - '5432:5432'
    environment:
      POSTGRES_USER: postgres
      POSTGRES_PASSWORD: postgres
      POSTGRES_DB: letta
    volumes:
      - ./data/postgres:/var/lib/postgresql/data
      # init.sql runs once on first container start to initialize the database.
      - ./scripts/postgres-db-init/init.sql:/docker-entrypoint-initdb.d/init.sql

View File

@@ -156,6 +156,7 @@ async def test_sleeptime_group_chat(server, actor):
# 6. Verify run status after sleep
time.sleep(2)
for run_id in run_ids:
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
assert job.status == JobStatus.running or job.status == JobStatus.completed

View File

@@ -151,6 +151,7 @@ def test_archival(agent_obj):
def test_recall(server, agent_obj, default_user):
"""Test that an agent can recall messages using a keyword via conversation search."""
keyword = "banana"
"".join(reversed(keyword))
# Send messages
for msg in ["hello", keyword, "tell me a fun fact"]:

View File

@@ -8,10 +8,9 @@ def adjust_menu_prices(percentage: float) -> str:
str: A formatted string summarizing the price adjustments.
"""
import cowsay
from tqdm import tqdm
from core.menu import Menu, MenuItem # Import a class from the codebase
from core.utils import format_currency # Use a utility function to test imports
from tqdm import tqdm
if not isinstance(percentage, (int, float)):
raise TypeError("percentage must be a number")