chore: merge oss (#3712)
15  .github/ISSUE_TEMPLATE/bug_report.md (vendored)
@@ -11,20 +11,25 @@ assignees: ''
A clear and concise description of what the bug is.

**Please describe your setup**
- [ ] How did you install letta?
  - `pip install letta`? `pip install letta-nightly`? `git clone`?
- [ ] How are you running Letta?
  - Docker
  - pip (legacy)
  - From source
  - Desktop
- [ ] Describe your setup
  - What's your OS (Windows/MacOS/Linux)?
  - How are you running `letta`? (`cmd.exe`/Powershell/Anaconda Shell/Terminal)
  - What is your `docker run ...` command (if applicable)

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Additional context**
Add any other context about the problem here.
- What model you are using

**Agent File (optional)**
Please attach your `.af` file, as this helps with reproducing issues.

**Letta Config**
Please attach your `~/.letta/config` file or copy paste it below.

---
286  .github/scripts/model-sweep/conftest.py (vendored, new file)
@@ -0,0 +1,286 @@
import logging
import os
import socket
import threading
import time
from datetime import datetime, timezone
from typing import Generator, List

import pytest
import requests
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts
from dotenv import load_dotenv
from letta_client import AsyncLetta, Letta

from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
from letta.services.organization_manager import OrganizationManager
from letta.services.user_manager import UserManager
from letta.settings import tool_settings


def pytest_configure(config):
    logging.basicConfig(level=logging.DEBUG)


@pytest.fixture
def disable_e2b_api_key() -> Generator[None, None, None]:
    """
    Temporarily disables the E2B API key by setting `tool_settings.e2b_api_key` to None
    for the duration of the test. Restores the original value afterward.
    """
    from letta.settings import tool_settings

    original_api_key = tool_settings.e2b_api_key
    tool_settings.e2b_api_key = None
    yield
    tool_settings.e2b_api_key = original_api_key


@pytest.fixture
def check_e2b_key_is_set():
    from letta.settings import tool_settings

    original_api_key = tool_settings.e2b_api_key
    assert original_api_key is not None, "Missing e2b key! Cannot execute these tests."
    yield


@pytest.fixture
def default_organization():
    """Fixture to create and return the default organization."""
    manager = OrganizationManager()
    org = manager.create_default_organization()
    yield org


@pytest.fixture
def default_user(default_organization):
    """Fixture to create and return the default user within the default organization."""
    manager = UserManager()
    user = manager.create_default_user(org_id=default_organization.id)
    yield user


@pytest.fixture
def check_composio_key_set():
    original_api_key = tool_settings.composio_api_key
    assert original_api_key is not None, "Missing composio key! Cannot execute this test."
    yield


# --- Tool Fixtures ---
@pytest.fixture
def weather_tool_func():
    def get_weather(location: str) -> str:
        """
        Fetches the current weather for a given location.

        Parameters:
            location (str): The location to get the weather for.

        Returns:
            str: A formatted string describing the weather in the given location.

        Raises:
            RuntimeError: If the request to fetch weather data fails.
        """
        import requests

        url = f"https://wttr.in/{location}?format=%C+%t"

        response = requests.get(url)
        if response.status_code == 200:
            weather_data = response.text
            return f"The weather in {location} is {weather_data}."
        else:
            raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")

    yield get_weather


@pytest.fixture
def print_tool_func():
    """Fixture to create a tool with default settings and clean up after the test."""

    def print_tool(message: str):
        """
        Args:
            message (str): The message to print.

        Returns:
            str: The message that was printed.
        """
        print(message)
        return message

    yield print_tool


@pytest.fixture
def roll_dice_tool_func():
    def roll_dice():
        """
        Rolls a 6 sided die.

        Returns:
            str: The roll result.
        """
        import time

        time.sleep(1)
        return "Rolled a 10!"

    yield roll_dice


@pytest.fixture
def dummy_beta_message_batch() -> BetaMessageBatch:
    return BetaMessageBatch(
        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        processing_status="in_progress",
        request_counts=BetaMessageBatchRequestCounts(
            canceled=10,
            errored=30,
            expired=10,
            processing=100,
            succeeded=50,
        ),
        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
        type="message_batch",
    )


# --- Model Sweep ---
# Global flags to track server state
_server_started = False
_server_url = None


def _start_server_once() -> str:
    """Start the server exactly once and return its URL."""
    global _server_started, _server_url

    if _server_started and _server_url:
        return _server_url

    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    # Check if a server is already running
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        if s.connect_ex(("localhost", 8283)) == 0:
            _server_started = True
            _server_url = url
            return url

    # Start the server in a background thread
    if not os.getenv("LETTA_SERVER_URL"):

        def _run_server():
            load_dotenv()
            from letta.server.rest_api.app import start_server

            start_server(debug=True)

        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll until the server is up
        timeout_seconds = 30
        deadline = time.time() + timeout_seconds
        while time.time() < deadline:
            try:
                resp = requests.get(url + "/v1/health")
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                pass
            time.sleep(0.1)
        else:
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    _server_started = True
    _server_url = url
    return url


# ------------------------------
# Fixtures
# ------------------------------


@pytest.fixture(scope="module")
def server_url() -> str:
    """Return the URL of the already-started server."""
    return _start_server_once()


@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
    """
    Creates and returns a synchronous Letta REST client for testing.
    """
    client_instance = Letta(base_url=server_url)
    yield client_instance


@pytest.fixture(scope="function")
def async_client(server_url: str) -> AsyncLetta:
    """
    Creates and returns an asynchronous Letta REST client for testing.
    """
    async_client_instance = AsyncLetta(base_url=server_url)
    yield async_client_instance


@pytest.fixture(scope="module")
def agent_state(client: Letta) -> AgentState:
    """
    Creates and returns an agent state for testing with a pre-configured agent.
    The agent is named 'supervisor' and is configured with only the send_message tool.
    """
    client.tools.upsert_base_tools()

    send_message_tool = client.tools.list(name="send_message")[0]
    agent_state_instance = client.agents.create(
        name="supervisor",
        include_base_tools=False,
        tool_ids=[send_message_tool.id],
        model="openai/gpt-4o",
        embedding="letta/letta-free",
        tags=["supervisor"],
    )
    yield agent_state_instance

    client.agents.delete(agent_state_instance.id)


@pytest.fixture(scope="module")
def all_available_llm_configs(client: Letta) -> List[LLMConfig]:
    """
    Returns a list of all available LLM configs.
    """
    llm_configs = client.models.list()
    return llm_configs


# Create a temporary client against the already-started server
def get_available_llm_configs() -> List[LLMConfig]:
    """Get configs, starting the server if needed."""
    server_url = _start_server_once()
    temp_client = Letta(base_url=server_url)
    return temp_client.models.list()


# Dynamically insert the llm_config parameter at collection time
def pytest_generate_tests(metafunc):
    """Dynamically parametrize tests that need llm_config."""
    if "llm_config" in metafunc.fixturenames:
        configs = get_available_llm_configs()
        if configs:
            metafunc.parametrize("llm_config", configs, ids=[c.model for c in configs])
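Note: with the `pytest_generate_tests` hook above, any test in this directory that declares an `llm_config` parameter is automatically parametrized over every model the running server reports. A minimal sketch of a consuming test (hypothetical test name; it mirrors the pattern used by the model_sweep.py tests below):

    def test_example(client, agent_state, llm_config):
        # llm_config is injected once per model returned by client.models.list()
        client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)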
21  .github/scripts/model-sweep/feature_mappings.json (vendored, new file)
@@ -0,0 +1,21 @@
{
  "Basic": [
    "test_greeting_with_assistant_message",
    "test_greeting_without_assistant_message",
    "test_async_greeting_with_assistant_message",
    "test_agent_loop_error",
    "test_step_stream_agent_loop_error",
    "test_step_streaming_greeting_with_assistant_message",
    "test_step_streaming_greeting_without_assistant_message",
    "test_step_streaming_tool_call",
    "test_tool_call",
    "test_auto_summarize"
  ],
  "Token Streaming": [
    "test_token_streaming_greeting_with_assistant_message",
    "test_token_streaming_greeting_without_assistant_message",
    "test_token_streaming_agent_loop_error",
    "test_token_streaming_tool_call"
  ],
  "Multimodal": ["test_base64_image_input", "test_url_image_input"]
}
495  .github/scripts/model-sweep/generate_model_sweep_markdown.py (vendored, new file)
@@ -0,0 +1,495 @@
#!/usr/bin/env python3
import json
import os
import sys
from collections import defaultdict
from datetime import datetime


def load_feature_mappings(config_file=None):
    """Load feature mappings from config file."""
    if config_file is None:
        # Default to feature_mappings.json in the same directory as this script
        script_dir = os.path.dirname(os.path.abspath(__file__))
        config_file = os.path.join(script_dir, "feature_mappings.json")

    try:
        with open(config_file, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Could not find feature mappings config file '{config_file}'")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in feature mappings config file '{config_file}'")
        sys.exit(1)


def get_support_status(passed_tests, feature_tests):
    """Determine support status for a feature category."""
    if not feature_tests:
        return "❓"  # Unknown - no tests for this feature

    # Filter out error tests when checking for support
    non_error_tests = [test for test in feature_tests if not test.endswith("_error")]
    error_tests = [test for test in feature_tests if test.endswith("_error")]

    # Check which non-error tests passed
    passed_non_error_tests = [test for test in non_error_tests if test in passed_tests]

    # If there are no non-error tests, only error tests, treat as unknown
    if not non_error_tests:
        return "❓"  # Only error tests available

    # Support is based only on non-error tests
    if len(passed_non_error_tests) == len(non_error_tests):
        return "✅"  # Full support
    elif len(passed_non_error_tests) == 0:
        return "❌"  # No support
    else:
        return "⚠️"  # Partial support


def categorize_tests(all_test_names, feature_mapping):
    """Categorize test names into feature buckets."""
    categorized = {feature: [] for feature in feature_mapping.keys()}

    for test_name in all_test_names:
        for feature, test_patterns in feature_mapping.items():
            if test_name in test_patterns:
                categorized[feature].append(test_name)
                break

    return categorized


def calculate_support_score(feature_support, feature_order):
    """Calculate a numeric support score for ranking models.

    For partial support, the score is weighted by the position of the feature
    in the feature_order list (earlier features get higher weight).
    """
    score = 0
    max_features = len(feature_order)

    for feature, status in feature_support.items():
        # Get position weight (earlier features get higher weight)
        if feature in feature_order:
            position_weight = (max_features - feature_order.index(feature)) / max_features
        else:
            position_weight = 0.5  # Default weight for unmapped features

        if status == "✅":  # Full support
            score += 10 * position_weight
        elif status == "⚠️":  # Partial support - weighted by column position
            score += 5 * position_weight
        elif status == "❌":  # No support
            score += 1 * position_weight
        # Unknown (❓) gets 0 points
    return score


def calculate_provider_support_score(models_data, feature_order):
    """Calculate a provider-level support score based on all models' support scores."""
    if not models_data:
        return 0

    # Calculate the average support score across all models in the provider
    total_score = sum(model["support_score"] for model in models_data)
    return total_score / len(models_data)


def get_test_function_line_numbers(test_file_path):
    """Extract line numbers for test functions from the test file."""
    test_line_numbers = {}

    try:
        with open(test_file_path, "r") as f:
            lines = f.readlines()

        for i, line in enumerate(lines, 1):
            if "def test_" in line and line.strip().startswith("def test_"):
                # Extract function name
                func_name = line.strip().split("def ")[1].split("(")[0]
                test_line_numbers[func_name] = i
    except FileNotFoundError:
        print(f"Warning: Could not find test file at {test_file_path}")

    return test_line_numbers


def get_github_repo_info():
    """Get GitHub repository information from git remote."""
    try:
        # Try to get the GitHub repo URL from git remote
        import subprocess

        result = subprocess.run(["git", "remote", "get-url", "origin"], capture_output=True, text=True, cwd=os.path.dirname(__file__))
        if result.returncode == 0:
            remote_url = result.stdout.strip()
            # Parse GitHub URL
            if "github.com" in remote_url:
                if remote_url.startswith("https://"):
                    # https://github.com/user/repo.git -> user/repo
                    repo_path = remote_url.replace("https://github.com/", "").replace(".git", "")
                elif remote_url.startswith("git@"):
                    # git@github.com:user/repo.git -> user/repo
                    repo_path = remote_url.split(":")[1].replace(".git", "")
                else:
                    return None
                return repo_path
    except:
        pass

    # Default fallback
    return "letta-ai/letta"


def generate_test_details(model_info, feature_mapping):
    """Generate detailed test results for a model."""
    details = []

    # Get test function line numbers
    script_dir = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(script_dir, "model_sweep.py")
    test_line_numbers = get_test_function_line_numbers(test_file_path)

    # Use the main branch GitHub URL
    base_github_url = "https://github.com/letta-ai/letta/blob/main/.github/scripts/model-sweep/model_sweep.py"

    for feature, tests in model_info["categorized_tests"].items():
        if not tests:
            continue

        details.append(f"### {feature}")
        details.append("")

        for test in sorted(tests):
            if test in model_info["passed_tests"]:
                status = "✅"
            elif test in model_info["failed_tests"]:
                status = "❌"
            else:
                status = "❓"

            # Create GitHub link if we have line number info
            if test in test_line_numbers:
                line_num = test_line_numbers[test]
                github_link = f"{base_github_url}#L{line_num}"
                details.append(f"- {status} [`{test}`]({github_link})")
            else:
                details.append(f"- {status} `{test}`")
        details.append("")

    return details


def calculate_column_widths(all_provider_data, feature_mapping):
    """Calculate the maximum width needed for each column across all providers."""
    widths = {"model": len("Model"), "context_window": len("Context Window"), "last_scanned": len("Last Scanned")}

    # Feature column widths
    for feature in feature_mapping.keys():
        widths[feature] = len(feature)

    # Check all model data for maximum widths
    for provider_data in all_provider_data.values():
        for model_info in provider_data:
            # Model name width (including backticks)
            model_width = len(f"`{model_info['name']}`")
            widths["model"] = max(widths["model"], model_width)

            # Context window width (with commas)
            context_width = len(f"{model_info['context_window']:,}")
            widths["context_window"] = max(widths["context_window"], context_width)

            # Last scanned width
            widths["last_scanned"] = max(widths["last_scanned"], len(str(model_info["last_scanned"])))

            # Feature support symbols are always 2 chars, so no need to check

    return widths


def process_model_sweep_report(input_file, output_file, config_file=None, debug=False):
    """Convert model sweep JSON data to MDX report."""

    # Load feature mappings from config file
    feature_mapping = load_feature_mappings(config_file)

    # if debug:
    #     print("DEBUG: Feature mappings loaded:")
    #     for feature, tests in feature_mapping.items():
    #         print(f"  {feature}: {tests}")
    #     print()

    # Read the JSON data
    with open(input_file, "r") as f:
        data = json.load(f)

    tests = data.get("tests", [])

    # if debug:
    #     print("DEBUG: Tests loaded:")
    #     print([test['outcome'] for test in tests if 'haiku' in test['nodeid']])

    # Calculate summary statistics
    providers = set(test["metadata"]["llm_config"]["provider_name"] for test in tests)
    models = set(test["metadata"]["llm_config"]["model"] for test in tests)
    total_tests = len(tests)

    # Start building the MDX
    mdx_lines = [
        "---",
        "title: Support Models",
        f"generated: {datetime.now().isoformat()}",
        "---",
        "",
        "# Supported Models",
        "",
        "## Overview",
        "",
        "Letta routinely runs automated scans against available providers and models. These are the results of the latest scan.",
        "",
        f"Ran {total_tests} tests against {len(models)} models across {len(providers)} providers on {datetime.now().strftime('%B %dth, %Y')}",
        "",
        "",
    ]

    # Group tests by provider
    provider_groups = defaultdict(list)
    for test in tests:
        provider_name = test["metadata"]["llm_config"]["provider_name"]
        provider_groups[provider_name].append(test)

    # Process all providers first to collect model data
    all_provider_data = {}
    provider_support_scores = {}

    for provider_name in provider_groups.keys():
        provider_tests = provider_groups[provider_name]

        # Group tests by model within this provider
        model_groups = defaultdict(list)
        for test in provider_tests:
            model_name = test["metadata"]["llm_config"]["model"]
            model_groups[model_name].append(test)

        # Process all models to calculate support scores for ranking
        model_data = []
        for model_name in model_groups.keys():
            model_tests = model_groups[model_name]

            # if debug:
            #     print(f"DEBUG: Processing model '{model_name}' in provider '{provider_name}'")

            # Extract unique test names for passed and failed tests
            passed_tests = set()
            failed_tests = set()
            all_test_names = set()

            for test in model_tests:
                # Extract test name from nodeid (split on :: and [)
                test_name = test["nodeid"].split("::")[1].split("[")[0]
                all_test_names.add(test_name)

                # if debug:
                #     print(f"  Test name: {test_name}")
                #     print(f"  Outcome: {test}")
                if test["outcome"] == "passed":
                    passed_tests.add(test_name)
                elif test["outcome"] == "failed":
                    failed_tests.add(test_name)

            # if debug:
            #     print(f"  All test names found: {sorted(all_test_names)}")
            #     print(f"  Passed tests: {sorted(passed_tests)}")
            #     print(f"  Failed tests: {sorted(failed_tests)}")

            # Categorize tests into features
            categorized_tests = categorize_tests(all_test_names, feature_mapping)

            # if debug:
            #     print(f"  Categorized tests:")
            #     for feature, tests in categorized_tests.items():
            #         print(f"    {feature}: {tests}")

            # Determine support status for each feature
            feature_support = {}
            for feature_name in feature_mapping.keys():
                feature_support[feature_name] = get_support_status(passed_tests, categorized_tests[feature_name])

            # if debug:
            #     print(f"  Feature support:")
            #     for feature, status in feature_support.items():
            #         print(f"    {feature}: {status}")
            #     print()

            # Get context window and last scanned time
            context_window = model_tests[0]["metadata"]["llm_config"]["context_window"]

            # Try to get time_last_scanned from metadata, fallback to current time
            try:
                last_scanned = model_tests[0]["metadata"].get(
                    "time_last_scanned", model_tests[0]["metadata"].get("timestamp", datetime.now().isoformat())
                )
                # Format timestamp if it's a full ISO string
                if "T" in str(last_scanned):
                    last_scanned = str(last_scanned).split("T")[0]  # Just the date part
            except:
                last_scanned = "Unknown"

            # Calculate support score for ranking
            feature_order = list(feature_mapping.keys())
            support_score = calculate_support_score(feature_support, feature_order)

            # Store model data for sorting
            model_data.append(
                {
                    "name": model_name,
                    "feature_support": feature_support,
                    "context_window": context_window,
                    "last_scanned": last_scanned,
                    "support_score": support_score,
                    "failed_tests": failed_tests,
                    "passed_tests": passed_tests,
                    "categorized_tests": categorized_tests,
                }
            )

        # Sort models by support score (descending) then by name (ascending)
        model_data.sort(key=lambda x: (-x["support_score"], x["name"]))

        # Store provider data
        all_provider_data[provider_name] = model_data
        provider_support_scores[provider_name] = calculate_provider_support_score(model_data, list(feature_mapping.keys()))

    # Calculate column widths for consistent formatting (add details column)
    column_widths = calculate_column_widths(all_provider_data, feature_mapping)
    column_widths["details"] = len("Details")

    # Sort providers by support score (descending) then by name (ascending)
    sorted_providers = sorted(provider_support_scores.keys(), key=lambda x: (-provider_support_scores[x], x))

    # Generate tables for all providers first
    for provider_name in sorted_providers:
        model_data = all_provider_data[provider_name]
        support_score = provider_support_scores[provider_name]

        # Create dynamic headers with proper padding and centering
        feature_names = list(feature_mapping.keys())

        # Build header row with left-aligned first column, centered others
        header_parts = [f"{'Model':<{column_widths['model']}}"]
        for feature in feature_names:
            header_parts.append(f"{feature:^{column_widths[feature]}}")
        header_parts.extend(
            [
                f"{'Context Window':^{column_widths['context_window']}}",
                f"{'Last Scanned':^{column_widths['last_scanned']}}",
                f"{'Details':^{column_widths['details']}}",
            ]
        )
        header_row = "| " + " | ".join(header_parts) + " |"

        # Build separator row with left-aligned first column, centered others
        separator_parts = [f"{'-' * column_widths['model']}"]
        for feature in feature_names:
            separator_parts.append(f":{'-' * (column_widths[feature] - 2)}:")
        separator_parts.extend(
            [
                f":{'-' * (column_widths['context_window'] - 2)}:",
                f":{'-' * (column_widths['last_scanned'] - 2)}:",
                f":{'-' * (column_widths['details'] - 2)}:",
            ]
        )
        separator_row = "|" + "|".join(separator_parts) + "|"

        # Add provider section without percentage
        mdx_lines.extend([f"## {provider_name}", "", header_row, separator_row])

        # Generate table rows for sorted models with proper padding
        for model_info in model_data:
            # Create anchor for model details
            model_anchor = model_info["name"].replace("/", "_").replace(":", "_").replace("-", "_").lower()
            details_anchor = f"{provider_name.lower().replace(' ', '_')}_{model_anchor}_details"

            # Build row with left-aligned first column, centered others
            row_parts = [f"`{model_info['name']}`".ljust(column_widths["model"])]
            for feature in feature_names:
                row_parts.append(f"{model_info['feature_support'][feature]:^{column_widths[feature]}}")
            row_parts.extend(
                [
                    f"{model_info['context_window']:,}".center(column_widths["context_window"]),
                    f"{model_info['last_scanned']}".center(column_widths["last_scanned"]),
                    f"[View](#{details_anchor})".center(column_widths["details"]),
                ]
            )
            row = "| " + " | ".join(row_parts) + " |"
            mdx_lines.append(row)

        # Add spacing between provider tables
        mdx_lines.extend(["", ""])

    # Add detailed test results section after all tables
    mdx_lines.extend(["---", "", "# Detailed Test Results", ""])

    for provider_name in sorted_providers:
        model_data = all_provider_data[provider_name]
        mdx_lines.extend([f"## {provider_name}", ""])

        for model_info in model_data:
            model_anchor = model_info["name"].replace("/", "_").replace(":", "_").replace("-", "_").lower()
            details_anchor = f"{provider_name.lower().replace(' ', '_')}_{model_anchor}_details"
            mdx_lines.append(f'<a id="{details_anchor}"></a>')
            mdx_lines.append(f"### {model_info['name']}")
            mdx_lines.append("")

            # Add test details
            test_details = generate_test_details(model_info, feature_mapping)
            mdx_lines.extend(test_details)

        # Add spacing between providers in details section
        mdx_lines.extend(["", ""])

    # Write the MDX file
    with open(output_file, "w") as f:
        f.write("\n".join(mdx_lines))

    print(f"Model sweep report saved to {output_file}")


def main():
    input_file = "model_sweep_report.json"
    output_file = "model_sweep_report.mdx"
    config_file = None
    debug = False

    # Allow command line arguments
    if len(sys.argv) > 1:
        # Use the file located in the same directory as this script
        script_dir = os.path.dirname(os.path.abspath(__file__))
        input_file = os.path.join(script_dir, sys.argv[1])
    if len(sys.argv) > 2:
        # Use the file located in the same directory as this script
        script_dir = os.path.dirname(os.path.abspath(__file__))
        output_file = os.path.join(script_dir, sys.argv[2])
    if len(sys.argv) > 3:
        config_file = sys.argv[3]
    if len(sys.argv) > 4 and sys.argv[4] == "--debug":
        debug = True

    try:
        process_model_sweep_report(input_file, output_file, config_file, debug)
    except FileNotFoundError:
        print(f"Error: Could not find input file '{input_file}'")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in file '{input_file}'")
        sys.exit(1)
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
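Usage sketch for the generator above (argument positions as parsed in main(); the input and output paths are resolved relative to the script directory, and the debug flag is only recognized as the fourth argument):

    python generate_model_sweep_markdown.py model_sweep_report.json model_sweep_report.mdx feature_mappings.json --debug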
786  .github/scripts/model-sweep/model_sweep.py (vendored, new file)
@@ -0,0 +1,786 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta, MessageCreate, Run
|
||||
from letta_client.core.api_error import ApiError
|
||||
from letta_client.types import (
|
||||
AssistantMessage,
|
||||
Base64Image,
|
||||
ImageContent,
|
||||
LettaUsageStatistics,
|
||||
ReasoningMessage,
|
||||
TextContent,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
UrlImage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
# ------------------------------
|
||||
# Helper Functions and Constants
|
||||
# ------------------------------
|
||||
|
||||
|
||||
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
|
||||
filename = os.path.join(llm_config_dir, filename)
|
||||
config_data = json.load(open(filename, "r"))
|
||||
llm_config = LLMConfig(**config_data)
|
||||
return llm_config
|
||||
|
||||
|
||||
def roll_dice(num_sides: int) -> int:
|
||||
"""
|
||||
Returns a random number between 1 and num_sides.
|
||||
Args:
|
||||
num_sides (int): The number of sides on the die.
|
||||
Returns:
|
||||
int: A random integer between 1 and num_sides, representing the die roll.
|
||||
"""
|
||||
import random
|
||||
|
||||
return random.randint(1, num_sides)
|
||||
|
||||
|
||||
USER_MESSAGE_OTID = str(uuid.uuid4())
|
||||
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
|
||||
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=f"This is an automated test message. Call the send_message tool with the message '{USER_MESSAGE_RESPONSE}'.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="This is an automated test message. Call the roll_dice tool with 16 sides and tell me the outcome.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
URL_IMAGE = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
|
||||
USER_MESSAGE_URL_IMAGE: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[
|
||||
ImageContent(source=UrlImage(url=URL_IMAGE)),
|
||||
TextContent(text="What is in this image?"),
|
||||
],
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
BASE64_IMAGE = base64.standard_b64encode(httpx.get(URL_IMAGE).content).decode("utf-8")
|
||||
USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[
|
||||
ImageContent(source=Base64Image(data=BASE64_IMAGE, media_type="image/jpeg")),
|
||||
TextContent(text="What is in this image?"),
|
||||
],
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
# "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop
|
||||
"claude-3-5-sonnet.json",
|
||||
"claude-3-7-sonnet.json",
|
||||
"claude-3-7-sonnet-extended.json",
|
||||
"gemini-1.5-pro.json",
|
||||
"gemini-2.5-flash-vertex.json",
|
||||
"gemini-2.5-pro-vertex.json",
|
||||
"together-qwen-2.5-72b-instruct.json",
|
||||
"ollama.json",
|
||||
]
|
||||
requested = os.getenv("LLM_CONFIG_FILE")
|
||||
filenames = [requested] if requested else all_configs
|
||||
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
|
||||
|
||||
|
||||
def assert_greeting_with_assistant_message_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
token_streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
expected_message_count = 3 if streaming or from_db else 2
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
if not token_streaming:
|
||||
assert USER_MESSAGE_RESPONSE in messages[index].content
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
assert messages[index].prompt_tokens > 0
|
||||
assert messages[index].completion_tokens > 0
|
||||
assert messages[index].total_tokens > 0
|
||||
assert messages[index].step_count > 0
|
||||
|
||||
|
||||
def assert_greeting_without_assistant_message_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
token_streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> ToolCallMessage -> ToolReturnMessage.
|
||||
"""
|
||||
expected_message_count = 4 if streaming or from_db else 3
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], ToolCallMessage)
|
||||
assert messages[index].tool_call.name == "send_message"
|
||||
if not token_streaming:
|
||||
assert USER_MESSAGE_RESPONSE in messages[index].tool_call.arguments
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
# Agent Step 2
|
||||
assert isinstance(messages[index], ToolReturnMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
|
||||
|
||||
def assert_tool_call_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> ToolCallMessage -> ToolReturnMessage ->
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
expected_message_count = 6 if streaming else 7 if from_db else 5
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], ToolCallMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
# Agent Step 2
|
||||
assert isinstance(messages[index], ToolReturnMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
# Hidden User Message
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert "request_heartbeat=true" in messages[index].content
|
||||
index += 1
|
||||
|
||||
# Agent Step 3
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
|
||||
|
||||
def assert_image_input_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
token_streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
expected_message_count = 3 if streaming or from_db else 2
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
assert messages[index].prompt_tokens > 0
|
||||
assert messages[index].completion_tokens > 0
|
||||
assert messages[index].total_tokens > 0
|
||||
assert messages[index].step_count > 0
|
||||
|
||||
|
||||
def accumulate_chunks(chunks: List[Any]) -> List[Any]:
|
||||
"""
|
||||
Accumulates chunks into a list of messages.
|
||||
"""
|
||||
messages = []
|
||||
current_message = None
|
||||
prev_message_type = None
|
||||
for chunk in chunks:
|
||||
current_message_type = chunk.message_type
|
||||
if prev_message_type != current_message_type:
|
||||
messages.append(current_message)
|
||||
current_message = None
|
||||
if current_message is None:
|
||||
current_message = chunk
|
||||
else:
|
||||
pass # TODO: actually accumulate the chunks. For now we only care about the count
|
||||
prev_message_type = current_message_type
|
||||
messages.append(current_message)
|
||||
return [m for m in messages if m is not None]
|
||||
|
||||
|
||||
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
|
||||
start = time.time()
|
||||
while True:
|
||||
run = client.runs.retrieve(run_id)
|
||||
if run.status == "completed":
|
||||
return run
|
||||
if run.status == "failed":
|
||||
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Asserts that a list of message dictionaries contains the expected types and statuses.
|
||||
|
||||
Expected order:
|
||||
1. reasoning_message
|
||||
2. tool_call_message
|
||||
3. tool_return_message (with status 'success')
|
||||
4. reasoning_message
|
||||
5. assistant_message
|
||||
"""
|
||||
assert isinstance(messages, list)
|
||||
assert messages[0]["message_type"] == "reasoning_message"
|
||||
assert messages[1]["message_type"] == "assistant_message"
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Test Cases
|
||||
# ------------------------------
|
||||
|
||||
# def test_that_ci_workflow_works(
|
||||
# disable_e2b_api_key: Any,
|
||||
# client: Letta,
|
||||
# agent_state: AgentState,
|
||||
# llm_config: LLMConfig,
|
||||
# json_metadata: pytest.FixtureRequest,
|
||||
# ) -> None:
|
||||
# """
|
||||
# Tests that the CI workflow works.
|
||||
# """
|
||||
# json_metadata["test_type"] = "debug"
|
||||
|
||||
|
||||
def test_greeting_with_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
assert_greeting_with_assistant_message_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_greeting_without_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
)
|
||||
assert_greeting_without_assistant_message_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
|
||||
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_tool_call(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
dice_tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
)
|
||||
assert_tool_call_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_url_image_input(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_URL_IMAGE,
|
||||
)
|
||||
assert_image_input_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_image_input_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_base64_image_input(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_BASE64_IMAGE,
|
||||
)
|
||||
assert_image_input_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_image_input_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_agent_loop_error(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that no new messages are persisted on error.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
tools = agent_state.tools
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
|
||||
with pytest.raises(ApiError):
|
||||
client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert len(messages_from_db) == 0
|
||||
client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
|
||||
|
||||
|
||||
def test_step_streaming_greeting_with_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_with_assistant_message_response(messages, streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_step_streaming_greeting_without_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_without_assistant_message_response(messages, streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
|
||||
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_step_streaming_tool_call(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
dice_tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_tool_call_response(messages, streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_step_stream_agent_loop_error(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that no new messages are persisted on error.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
tools = agent_state.tools
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
|
||||
with pytest.raises(ApiError):
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
list(response)
|
||||
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert len(messages_from_db) == 0
|
||||
client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
|
||||
|
||||
|
||||
def test_token_streaming_greeting_with_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_token_streaming_greeting_without_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
|
||||
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_token_streaming_tool_call(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
dice_tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_tool_call_response(messages, streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_token_streaming_agent_loop_error(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Verifies that no new messages are persisted on error.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
tools = agent_state.tools
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
|
||||
try:
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
stream_tokens=True,
|
||||
)
|
||||
list(response)
|
||||
except:
|
||||
pass # only some models throw an error TODO: make this consistent
|
||||
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert len(messages_from_db) == 0
|
||||
client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
|
||||
|
||||
|
||||
def test_async_greeting_with_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message as an asynchronous job using the synchronous client.
|
||||
Waits for job completion and asserts that the result messages are as expected.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
run = client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
run = wait_for_run_completion(client, run.id)
|
||||
|
||||
result = run.metadata.get("result")
|
||||
assert result is not None, "Run metadata missing 'result' key"
|
||||
|
||||
messages = result["messages"]
|
||||
assert_tool_response_dict_messages(messages)
|
||||
|
||||
|
||||
def test_auto_summarize(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""Test that summarization is automatically triggered."""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
|
||||
# pydantic prevents us from overriding the context window parameter in the passed LLMConfig
|
||||
new_llm_config = llm_config.model_dump()
|
||||
new_llm_config["context_window"] = 3000
|
||||
pinned_context_window_llm_config = LLMConfig(**new_llm_config)
|
||||
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
temp_agent_state = client.agents.create(
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id],
|
||||
llm_config=pinned_context_window_llm_config,
|
||||
embedding="letta/letta-free",
|
||||
tags=["supervisor"],
|
||||
)
|
||||
|
||||
philosophical_question = """
|
||||
You know, sometimes I wonder if the entire structure of our lives is built on a series of unexamined assumptions we just silently agreed to somewhere along the way—like how we all just decided that five days a week of work and two days of “rest” constitutes balance, or how 9-to-5 became the default rhythm of a meaningful life, or even how the idea of “success” got boiled down to job titles and property ownership and productivity metrics on a LinkedIn profile, when maybe none of that is actually what makes a life feel full, or grounded, or real. And then there’s the weird paradox of ambition, how we're taught to chase it like a finish line that keeps moving, constantly redefining itself right as you’re about to grasp it—because even when you get the job, or the degree, or the validation, there's always something next, something more, like a treadmill with invisible settings you didn’t realize were turned up all the way.
|
||||
|
||||
And have you noticed how we rarely stop to ask who set those definitions for us? Like was there ever a council that decided, yes, owning a home by thirty-five and retiring by sixty-five is the universal template for fulfillment? Or did it just accumulate like cultural sediment over generations, layered into us so deeply that questioning it feels uncomfortable, even dangerous? And isn’t it strange that we spend so much of our lives trying to optimize things—our workflows, our diets, our sleep, our morning routines—as though the point of life is to operate more efficiently rather than to experience it more richly? We build these intricate systems, these rulebooks for being a “high-functioning” human, but where in all of that is the space for feeling lost, for being soft, for wandering without a purpose just because it’s a sunny day and your heart is tugging you toward nowhere in particular?
|
||||
|
||||
Sometimes I lie awake at night and wonder if all the noise we wrap around ourselves—notifications, updates, performance reviews, even our internal monologues—might be crowding out the questions we were meant to live into slowly, like how to love better, or how to forgive ourselves, or what the hell we’re even doing here in the first place. And when you strip it all down—no goals, no KPIs, no curated identity—what’s actually left of us? Are we just a sum of the roles we perform, or is there something quieter underneath that we've forgotten how to hear?
|
||||
|
||||
And if there is something underneath all of it—something real, something worth listening to—then how do we begin to uncover it, gently, without rushing or reducing it to another task on our to-do list?
|
||||
"""
|
||||
|
||||
MAX_ATTEMPTS = 10
|
||||
prev_length = None
|
||||
|
||||
for attempt in range(MAX_ATTEMPTS):
|
||||
client.agents.messages.create(
|
||||
agent_id=temp_agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=philosophical_question)],
|
||||
)
|
||||
|
||||
temp_agent_state = client.agents.retrieve(agent_id=temp_agent_state.id)
|
||||
message_ids = temp_agent_state.message_ids
|
||||
current_length = len(message_ids)
|
||||
|
||||
print("LENGTH OF IN_CONTEXT_MESSAGES:", current_length)
|
||||
|
||||
if prev_length is not None and current_length <= prev_length:
|
||||
# TODO: Add more stringent checks here
|
||||
print(f"Summarization was triggered, detected current_length {current_length} is at least prev_length {prev_length}.")
|
||||
break
|
||||
|
||||
prev_length = current_length
|
||||
else:
|
||||
raise AssertionError("Summarization was not triggered after 10 messages")
|
||||
4551
.github/scripts/model-sweep/supported-models.mdx
vendored
Normal file
File diff suppressed because it is too large
144
.github/workflows/model-sweep.yaml
vendored
Normal file
@@ -0,0 +1,144 @@
|
||||
name: Model Sweep
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch-name:
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
model-sweep:
|
||||
runs-on: [self-hosted, medium]
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
ports:
|
||||
- 6333:6333
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg17
|
||||
ports:
|
||||
- 5432:5432
|
||||
env:
|
||||
POSTGRES_HOST_AUTH_METHOD: trust
|
||||
POSTGRES_DB: postgres
|
||||
POSTGRES_USER: postgres
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
steps:
|
||||
- name: Check if gh is installed
|
||||
run: |
|
||||
if ! command -v gh >/dev/null 2>&1
|
||||
then
|
||||
echo "gh could not be found, installing now..."
|
||||
# install gh cli
|
||||
(type -p wget >/dev/null || (sudo apt update && sudo apt-get install wget -y)) \
|
||||
&& sudo mkdir -p -m 755 /etc/apt/keyrings \
|
||||
&& out=$(mktemp) && wget -nv -O$out https://cli.github.com/packages/githubcli-archive-keyring.gpg \
|
||||
&& cat $out | sudo tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
|
||||
&& sudo chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | sudo tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
|
||||
&& sudo apt update \
|
||||
&& sudo apt install gh -y
|
||||
fi
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Inject env vars into environment
|
||||
run: |
|
||||
# Get secrets and mask them before adding to environment
|
||||
while IFS= read -r line || [[ -n "$line" ]]; do
|
||||
if [[ -n "$line" ]]; then
|
||||
value=$(echo "$line" | cut -d= -f2-)
|
||||
echo "::add-mask::$value"
|
||||
echo "$line" >> $GITHUB_ENV
|
||||
fi
|
||||
done < <(letta_secrets_helper --env dev --service ci)
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: poetry install --no-interaction --no-root ${{ inputs.install-args || '-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox -E google' }}
|
||||
- name: Migrate database
|
||||
env:
|
||||
LETTA_PG_PORT: 5432
|
||||
LETTA_PG_USER: postgres
|
||||
LETTA_PG_PASSWORD: postgres
|
||||
LETTA_PG_DB: postgres
|
||||
LETTA_PG_HOST: localhost
|
||||
run: |
|
||||
psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector'
|
||||
poetry run alembic upgrade head
|
||||
|
||||
- name: Run integration tests
|
||||
# if any of the 1000+ test cases fail, pytest reports exit code 1 and won't process/upload the results
|
||||
continue-on-error: true
|
||||
env:
|
||||
LETTA_PG_PORT: 5432
|
||||
LETTA_PG_USER: postgres
|
||||
LETTA_PG_PASSWORD: postgres
|
||||
LETTA_PG_DB: postgres
|
||||
LETTA_PG_HOST: localhost
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
OPENAI_API_KEY: ${{ env.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ env.ANTHROPIC_API_KEY }}
|
||||
AZURE_API_KEY: ${{ env.AZURE_API_KEY }}
|
||||
AZURE_BASE_URL: ${{ secrets.AZURE_BASE_URL }}
|
||||
GEMINI_API_KEY: ${{ env.GEMINI_API_KEY }}
|
||||
COMPOSIO_API_KEY: ${{ env.COMPOSIO_API_KEY }}
|
||||
GOOGLE_CLOUD_PROJECT: ${{ secrets.GOOGLE_CLOUD_PROJECT }}
|
||||
GOOGLE_CLOUD_LOCATION: ${{ secrets.GOOGLE_CLOUD_LOCATION }}
|
||||
DEEPSEEK_API_KEY: ${{ env.DEEPSEEK_API_KEY }}
|
||||
LETTA_USE_EXPERIMENTAL: 1
|
||||
run: |
|
||||
poetry run pytest \
|
||||
-s -vv \
|
||||
.github/scripts/model-sweep/model_sweep.py \
|
||||
--json-report --json-report-file=.github/scripts/model-sweep/model_sweep_report.json --json-report-indent=4
|
||||
|
||||
- name: Convert report to markdown
|
||||
continue-on-error: true
|
||||
# file path args to generate_model_sweep_markdown.py are relative to the script
|
||||
run: |
|
||||
poetry run python \
|
||||
.github/scripts/model-sweep/generate_model_sweep_markdown.py \
|
||||
.github/scripts/model-sweep/model_sweep_report.json \
|
||||
.github/scripts/model-sweep/supported-models.mdx
|
||||
echo "Model sweep report saved to .github/scripts/model-sweep/supported-models.mdx"
|
||||
|
||||
- id: date
|
||||
run: echo "date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: commit and open pull request
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
BRANCH_NAME=model-sweep/${{ inputs.branch-name || format('{0}', steps.date.outputs.date) }}
|
||||
gh auth setup-git
|
||||
git config --global user.name "github-actions[bot]"
|
||||
git config --global user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git checkout -b $BRANCH_NAME
|
||||
git add .github/scripts/model-sweep/supported-models.mdx
|
||||
git commit -m "Update model sweep report"
|
||||
# only push if changes were made
|
||||
if git diff main --quiet; then
|
||||
echo "No changes detected, skipping push"
|
||||
exit 0
|
||||
else
|
||||
git push origin $BRANCH_NAME
|
||||
gh pr create \
|
||||
--base main \
|
||||
--head $BRANCH_NAME \
|
||||
--title "chore: update model sweep report" \
|
||||
--body "Automated PR to update model sweep report"
|
||||
fi
|
||||
|
||||
- name: Upload model sweep report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: model-sweep-report
|
||||
path: .github/scripts/model-sweep/model_sweep_report.json
|
||||
155
.github/workflows/send-message-integration-tests.yaml
vendored
Normal file
@@ -0,0 +1,155 @@
|
||||
name: Send Message SDK Tests
|
||||
on:
|
||||
pull_request_target:
|
||||
# branches: [main] # TODO: uncomment before merge
|
||||
types: [labeled]
|
||||
paths:
|
||||
- 'letta/**'
|
||||
|
||||
jobs:
|
||||
send-messages:
|
||||
# Only run when the "safe to test" label is applied
|
||||
if: contains(github.event.pull_request.labels.*.name, 'safe to test')
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config_file:
|
||||
- "openai-gpt-4o-mini.json"
|
||||
- "azure-gpt-4o-mini.json"
|
||||
- "claude-3-5-sonnet.json"
|
||||
- "claude-3-7-sonnet.json"
|
||||
- "claude-3-7-sonnet-extended.json"
|
||||
- "gemini-pro.json"
|
||||
- "gemini-vertex.json"
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
ports:
|
||||
- 6333:6333
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg17
|
||||
ports:
|
||||
- 5432:5432
|
||||
env:
|
||||
POSTGRES_HOST_AUTH_METHOD: trust
|
||||
POSTGRES_DB: postgres
|
||||
POSTGRES_USER: postgres
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
steps:
|
||||
# Ensure secrets don't leak
|
||||
- name: Configure git to hide secrets
|
||||
run: |
|
||||
git config --global core.logAllRefUpdates false
|
||||
git config --global log.hideCredentials true
|
||||
- name: Set up secret masking
|
||||
run: |
|
||||
# Automatically mask any environment variable ending with _KEY
|
||||
for var in $(env | grep '_KEY=' | cut -d= -f1); do
|
||||
value="${!var}"
|
||||
if [[ -n "$value" ]]; then
|
||||
# Mask the full value
|
||||
echo "::add-mask::$value"
|
||||
|
||||
# Also mask partial values (first and last several characters)
|
||||
# This helps when only parts of keys appear in logs
|
||||
if [[ ${#value} -gt 8 ]]; then
|
||||
echo "::add-mask::${value:0:8}"
|
||||
echo "::add-mask::${value:(-8)}"
|
||||
fi
|
||||
|
||||
# Also mask with common formatting changes
|
||||
# Some logs might add quotes or other characters
|
||||
echo "::add-mask::\"$value\""
|
||||
echo "::add-mask::$value\""
|
||||
echo "::add-mask::\"$value"
|
||||
|
||||
echo "Masked secret: $var (length: ${#value})"
|
||||
fi
|
||||
done
|
||||
|
||||
# Check out base repository code, not the PR's code (for security)
|
||||
- name: Checkout base repository
|
||||
uses: actions/checkout@v4 # No ref specified means it uses base branch
|
||||
|
||||
# Only extract relevant files from the PR (for security, specifically prevent modification of workflow files)
|
||||
- name: Extract PR schema files
|
||||
run: |
|
||||
# Fetch PR without checking it out
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr-${{ github.event.pull_request.number }}
|
||||
|
||||
# Extract ONLY the schema files
|
||||
git checkout pr-${{ github.event.pull_request.number }} -- letta/
|
||||
- name: Set up python 3.12
|
||||
id: setup-python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.12
|
||||
- name: Load cached Poetry Binary
|
||||
id: cached-poetry-binary
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.local
|
||||
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-1.8.3
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1
|
||||
with:
|
||||
version: 1.8.3
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
- name: Load cached venv
|
||||
id: cached-poetry-dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}${{ inputs.install-args || '-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox' }}
|
||||
# Restore cache with this prefix if not exact match with key
|
||||
# Note cache-hit returns false in this case, so the below step will run
|
||||
restore-keys: |
|
||||
venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-
|
||||
- name: Install dependencies
|
||||
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: poetry install --no-interaction --no-root ${{ inputs.install-args || '-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox -E google' }}
|
||||
- name: Install letta packages via Poetry
|
||||
run: |
|
||||
poetry run pip install --upgrade letta-client letta
|
||||
- name: Migrate database
|
||||
env:
|
||||
LETTA_PG_PORT: 5432
|
||||
LETTA_PG_USER: postgres
|
||||
LETTA_PG_PASSWORD: postgres
|
||||
LETTA_PG_DB: postgres
|
||||
LETTA_PG_HOST: localhost
|
||||
run: |
|
||||
psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector'
|
||||
poetry run alembic upgrade head
|
||||
- name: Run integration tests for ${{ matrix.config_file }}
|
||||
env:
|
||||
LLM_CONFIG_FILE: ${{ matrix.config_file }}
|
||||
LETTA_PG_PORT: 5432
|
||||
LETTA_PG_USER: postgres
|
||||
LETTA_PG_PASSWORD: postgres
|
||||
LETTA_PG_DB: postgres
|
||||
LETTA_PG_HOST: localhost
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
AZURE_BASE_URL: ${{ secrets.AZURE_BASE_URL }}
|
||||
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||
COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }}
|
||||
DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }}
|
||||
GOOGLE_CLOUD_PROJECT: ${{ secrets.GOOGLE_CLOUD_PROJECT }}
|
||||
GOOGLE_CLOUD_LOCATION: ${{ secrets.GOOGLE_CLOUD_LOCATION }}
|
||||
run: |
|
||||
poetry run pytest \
|
||||
-s -vv \
|
||||
tests/integration_test_send_message.py \
|
||||
--maxfail=1 --durations=10
|
||||
@@ -28,7 +28,7 @@ First, install Poetry using [the official instructions here](https://python-poet
|
||||
Once Poetry is installed, navigate to the letta directory and install the Letta project with Poetry:
|
||||
```shell
|
||||
cd letta
|
||||
poetry shell
|
||||
eval $(poetry env activate)
|
||||
poetry install --all-extras
|
||||
```
|
||||
#### Setup PostgreSQL environment (optional)
|
||||
|
||||
17
README.md
@@ -8,26 +8,13 @@
|
||||
|
||||
<div align="center">
|
||||
<h1>Letta (previously MemGPT)</h1>
|
||||
|
||||
**☄️ New release: Letta Agent Development Environment (_read more [here](#-access-the-ade-agent-development-environment)_) ☄️**
|
||||
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/letta-ai/letta/refs/heads/main/assets/example_ade_screenshot.png">
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/letta-ai/letta/refs/heads/main/assets/example_ade_screenshot_light.png">
|
||||
<img alt="Letta logo" src="https://raw.githubusercontent.com/letta-ai/letta/refs/heads/main/assets/example_ade_screenshot.png" width="800">
|
||||
</picture>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
<h3>
|
||||
|
||||
[Homepage](https://letta.com) // [Documentation](https://docs.letta.com) // [ADE](https://docs.letta.com/agent-development-environment) // [Letta Cloud](https://forms.letta.com/early-access)
|
||||
|
||||
</h3>
|
||||
|
||||
**👾 Letta** is an open source framework for building stateful LLM applications. You can use Letta to build **stateful agents** with advanced reasoning capabilities and transparent long-term memory. The Letta framework is white box and model-agnostic.
|
||||
**👾 Letta** is an open source framework for building **stateful agents** with advanced reasoning capabilities and transparent long-term memory. The Letta framework is white box and model-agnostic.
|
||||
|
||||
[](https://discord.gg/letta)
|
||||
[](https://twitter.com/Letta_AI)
|
||||
@@ -157,7 +144,7 @@ No, the data in your Letta server database stays on your machine. The Letta ADE
|
||||
|
||||
> _"Do I have to use your ADE? Can I build my own?"_
|
||||
|
||||
The ADE is built on top of the (fully open source) Letta server and Letta Agents API. You can build your own application like the ADE on top of the REST API (view the documention [here](https://docs.letta.com/api-reference)).
|
||||
The ADE is built on top of the (fully open source) Letta server and Letta Agents API. You can build your own application like the ADE on top of the REST API (view the documentation [here](https://docs.letta.com/api-reference)).
|
||||
|
||||
> _"Can I interact with Letta agents via the CLI?"_
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ services:
|
||||
- "8083:8083"
|
||||
- "8283:8283"
|
||||
environment:
|
||||
- SERPAPI_API_KEY=${SERPAPI_API_KEY}
|
||||
- LETTA_PG_DB=${LETTA_PG_DB:-letta}
|
||||
- LETTA_PG_USER=${LETTA_PG_USER:-letta}
|
||||
- LETTA_PG_PASSWORD=${LETTA_PG_PASSWORD:-letta}
|
||||
|
||||
@@ -8,6 +8,8 @@ If you're using Letta Cloud, replace 'baseURL' with 'token'
|
||||
See: https://docs.letta.com/api-reference/overview
|
||||
|
||||
Execute this script using `poetry run python3 example.py`
|
||||
|
||||
This will install `letta_client` and other dependencies.
|
||||
"""
|
||||
client = Letta(
|
||||
base_url="http://localhost:8283",
|
||||
|
||||
34
examples/files/README.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# Letta Files and Streaming Demo
|
||||
|
||||
This demo shows how to work with Letta's file upload and streaming capabilities.
|
||||
|
||||
## Features
|
||||
|
||||
- Upload files from disk to a Letta data source
|
||||
- Create files from strings and upload them
|
||||
- Download and upload PDF files
|
||||
- Create an agent and attach data sources
|
||||
- Stream agent responses in real-time
|
||||
- Interactive chat with file-aware agent
|
||||
|
||||
## Files
|
||||
|
||||
- `main.py` - Main demo script showing file upload and streaming
|
||||
- `example-on-disk.txt` - Sample text file for upload demonstration
|
||||
- `memgpt.pdf` - MemGPT paper (downloaded automatically)
|
||||
|
||||
## Setup
|
||||
|
||||
1. Set your Letta API key: `export LETTA_API_KEY=your_key_here`
|
||||
2. Install dependencies: `pip install letta-client requests rich`
|
||||
3. Run the demo: `python main.py`
|
||||
|
||||
## Usage
|
||||
|
||||
The demo will:
|
||||
1. Create a data source called "Example Source"
|
||||
2. Upload the example text file and PDF
|
||||
3. Create an agent named "Clippy"
|
||||
4. Start an interactive chat session
|
||||
|
||||
Type 'quit' or 'exit' to end the conversation.
|
||||
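For a quick orientation before reading `main.py`, here is a minimal sketch of the demo's core flow. It uses the same `letta_client` calls as the script below; the folder name, model, memory block, and prompt are illustrative placeholders rather than part of the demo itself.

```python
import os
from letta_client import Letta

client = Letta(token=os.environ["LETTA_API_KEY"])

# Create a folder and upload a file from disk
folder = client.folders.create(name="Example Folder", description="Demo folder")
with open("example-on-disk.txt", "rb") as f:
    client.folders.files.upload(folder_id=folder.id, file=f, duplicate_handling="skip")

# Create an agent and attach the folder so it can see the uploaded files
agent = client.agents.create(
    model="openai/gpt-4o-mini",
    memory_blocks=[{"label": "persona", "value": "My name is Clippy, I answer questions about files."}],
)
client.agents.folders.attach(agent_id=agent.id, folder_id=folder.id)

# Stream a response about the uploaded files
for chunk in client.agents.messages.create_stream(
    agent_id=agent.id,
    messages=[{"role": "user", "content": "What files can you see?"}],
):
    print(chunk)
```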
2
examples/files/example-on-disk.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
Hey, you're looking at a different example.
|
||||
This password is "stateful agents".
|
||||
190
examples/files/main.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Letta Filesystem
|
||||
|
||||
This demo shows how to:
|
||||
1. Create a folder and upload files (both from disk and from strings)
|
||||
2. Create an agent and attach the data folder
|
||||
3. Stream the agent's responses
|
||||
4. Query the agent about the uploaded files
|
||||
|
||||
The demo uploads:
|
||||
- A text file from disk (example-on-disk.txt)
|
||||
- A text file created from a string (containing a password)
|
||||
- The MemGPT paper PDF from arXiv
|
||||
|
||||
Then asks the agent to summarize the paper and find passwords in the files.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import requests
|
||||
from letta_client import Letta
|
||||
from letta_client.core.api_error import ApiError
|
||||
from rich import print
|
||||
|
||||
LETTA_API_KEY = os.getenv("LETTA_API_KEY")
|
||||
if LETTA_API_KEY is None:
|
||||
raise ValueError("LETTA_API_KEY is not set")
|
||||
|
||||
FOLDER_NAME = "Example Folder"
|
||||
|
||||
# Connect to our Letta server
|
||||
client = Letta(token=LETTA_API_KEY)
|
||||
|
||||
# get an available embedding_config
|
||||
embedding_configs = client.embedding_models.list()
|
||||
embedding_config = embedding_configs[0]
|
||||
|
||||
# Check if the folder already exists
|
||||
try:
|
||||
folder_id = client.folders.retrieve_by_name(FOLDER_NAME)
|
||||
|
||||
# We got an API error. Check if it's a 404, meaning the folder doesn't exist.
|
||||
except ApiError as e:
|
||||
if e.status_code == 404:
|
||||
# Create a new folder
|
||||
folder = client.folders.create(
|
||||
name=FOLDER_NAME,
|
||||
description="This is an example folder",
|
||||
instructions="Use this data folder to see how Letta works.",
|
||||
)
|
||||
folder_id = folder.id
|
||||
else:
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
# Something else went wrong
|
||||
raise e
|
||||
|
||||
|
||||
#
|
||||
# There are two ways to upload a file to a folder.
|
||||
#
|
||||
# 1. From an existing file
|
||||
# 2. From a string by encoding it into a base64 string
|
||||
#
|
||||
#
|
||||
|
||||
# 1. From an existing file
|
||||
# "rb" means "read binary"
|
||||
file = open("example-on-disk.txt", "rb")
|
||||
|
||||
# Upload the file to the folder
|
||||
file = client.folders.files.upload(
|
||||
folder_id=folder_id,
|
||||
file=file,
|
||||
duplicate_handling="skip"
|
||||
)
|
||||
|
||||
# 2. From a string by encoding it into a base64 string
|
||||
import io
|
||||
|
||||
content = """
|
||||
This is an example file. If you can read this,
|
||||
the password is 'letta'.
|
||||
"""
|
||||
|
||||
# Encode the string into bytes, and then create a file-like object
|
||||
# that exists only in memory.
|
||||
file_object = io.BytesIO(content.encode("utf-8"))
|
||||
|
||||
# Set the name of the file
|
||||
file_object.name = "example.txt"
|
||||
|
||||
# Upload the file to the folder
|
||||
file = client.folders.files.upload(
|
||||
folder_id=folder_id,
|
||||
file=file_object,
|
||||
duplicate_handling="skip"
|
||||
)
|
||||
|
||||
#
|
||||
# You can also upload PDFs!
|
||||
# Letta extracts text from PDFs using OCR.
|
||||
#
|
||||
|
||||
# Download the PDF to the local directory if it doesn't exist
|
||||
if not os.path.exists("memgpt.pdf"):
|
||||
# Download the PDF
|
||||
print("Downloading memgpt.pdf")
|
||||
response = requests.get("https://arxiv.org/pdf/2310.08560")
|
||||
with open("memgpt.pdf", "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
# Upload the PDF to the folder
|
||||
file = client.folders.files.upload(
|
||||
folder_id=folder_id,
|
||||
file=open("memgpt.pdf", "rb"),
|
||||
duplicate_handling="skip"
|
||||
)
|
||||
|
||||
#
|
||||
# Now we need to create an agent that can use this folder
|
||||
#
|
||||
|
||||
# Create an agent
|
||||
agent = client.agents.create(
|
||||
model="openai/gpt-4o-mini",
|
||||
name="Example Agent",
|
||||
description="This agent looks at files and answers questions about them.",
|
||||
memory_blocks = [
|
||||
{
|
||||
"label": "human",
|
||||
"value": "The human wants to know about the files."
|
||||
},
|
||||
{
|
||||
"label": "persona",
|
||||
"value": "My name is Clippy, I answer questions about files."
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Attach the data folder to the agent.
|
||||
# Once the folder is attached, the agent will be able to see all
|
||||
# files in the folder.
|
||||
client.agents.folders.attach(
|
||||
agent_id=agent.id,
|
||||
folder_id=folder_id
|
||||
)
|
||||
|
||||
########################################################
|
||||
# This code makes a simple chatbot interface to the agent
|
||||
########################################################
|
||||
|
||||
# Wrap this in a try/finally block so the agent is removed even if an error occurs
|
||||
try:
|
||||
print(f"🤖 Connected to agent: {agent.name}")
|
||||
print("💡 Type 'quit' or 'exit' to end the conversation")
|
||||
print("=" * 50)
|
||||
|
||||
while True:
|
||||
# Get user input
|
||||
try:
|
||||
user_input = input("\n👤 You: ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
|
||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||
print("👋 Goodbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Stream the agent's response
|
||||
stream = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_input
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
print(chunk)
|
||||
|
||||
finally:
|
||||
client.agents.delete(agent.id)
|
||||
@@ -2,22 +2,33 @@ from pprint import pprint
|
||||
|
||||
from letta_client import Letta
|
||||
|
||||
# Connect to Letta server
|
||||
client = Letta(base_url="http://localhost:8283")
|
||||
|
||||
# Use the "everything" mcp server:
|
||||
# https://github.com/modelcontextprotocol/servers/tree/main/src/everything
|
||||
mcp_server_name = "everything"
|
||||
mcp_tool_name = "echo"
|
||||
|
||||
# List all McpTool belonging to the "everything" mcp server.
|
||||
mcp_tools = client.tools.list_mcp_tools_by_server(
|
||||
mcp_server_name=mcp_server_name,
|
||||
)
|
||||
|
||||
# We can see that "echo" is one of the tools, but it's not
|
||||
# a letta tool that can be added to a client (it has no tool id).
|
||||
for tool in mcp_tools:
|
||||
pprint(tool)
|
||||
|
||||
# Create a Tool (with a tool id) using the server and tool names.
|
||||
mcp_tool = client.tools.add_mcp_tool(
|
||||
mcp_server_name=mcp_server_name,
|
||||
mcp_tool_name=mcp_tool_name
|
||||
)
|
||||
|
||||
# Create an agent with the tool, using tool.id -- note that
|
||||
# this is the ONLY tool in the agent; you typically want to
|
||||
# also include the default tools.
|
||||
agent = client.agents.create(
|
||||
memory_blocks=[
|
||||
{
|
||||
@@ -31,6 +42,7 @@ agent = client.agents.create(
|
||||
)
|
||||
print(f"Created agent id {agent.id}")
|
||||
|
||||
# Ask the agent to call the tool.
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
|
||||
@@ -253,15 +253,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"id": "7808912f-831b-4cdc-8606-40052eb809b4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Optional, List\n",
|
||||
"from typing import Optional, List, TYPE_CHECKING\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"def task_queue_push(self: \"Agent\", task_description: str):\n",
|
||||
"if TYPE_CHECKING:\n",
|
||||
" from letta import AgentState\n",
|
||||
"\n",
|
||||
"def task_queue_push(agent_state: \"AgentState\", task_description: str):\n",
|
||||
" \"\"\"\n",
|
||||
" Push to a task queue stored in core memory. \n",
|
||||
"\n",
|
||||
@@ -273,12 +276,12 @@
|
||||
" does not produce a response.\n",
|
||||
" \"\"\"\n",
|
||||
" import json\n",
|
||||
" tasks = json.loads(self.memory.get_block(\"tasks\").value)\n",
|
||||
" tasks = json.loads(agent_state.memory.get_block(\"tasks\").value)\n",
|
||||
" tasks.append(task_description)\n",
|
||||
" self.memory.update_block_value(\"tasks\", json.dumps(tasks))\n",
|
||||
" agent_state.memory.update_block_value(\"tasks\", json.dumps(tasks))\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
"def task_queue_pop(self: \"Agent\"):\n",
|
||||
"def task_queue_pop(agent_state: \"AgentState\"):\n",
|
||||
" \"\"\"\n",
|
||||
" Get the next task from the task queue \n",
|
||||
"\n",
|
||||
@@ -288,12 +291,12 @@
|
||||
" None (the task queue is empty)\n",
|
||||
" \"\"\"\n",
|
||||
" import json\n",
|
||||
" tasks = json.loads(self.memory.get_block(\"tasks\").value)\n",
|
||||
" tasks = json.loads(agent_state.memory.get_block(\"tasks\").value)\n",
|
||||
" if len(tasks) == 0: \n",
|
||||
" return None\n",
|
||||
" task = tasks[0]\n",
|
||||
" print(\"CURRENT TASKS: \", tasks)\n",
|
||||
" self.memory.update_block_value(\"tasks\", json.dumps(tasks[1:]))\n",
|
||||
" agent_state.memory.update_block_value(\"tasks\", json.dumps(tasks[1:]))\n",
|
||||
" return task\n",
|
||||
"\n",
|
||||
"push_task_tool = client.tools.upsert_from_function(func=task_queue_push)\n",
|
||||
@@ -310,7 +313,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": null,
|
||||
"id": "135fcf3e-59c4-4da3-b86b-dbffb21aa343",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -336,10 +339,12 @@
|
||||
" ),\n",
|
||||
" CreateBlock(\n",
|
||||
" label=\"tasks\",\n",
|
||||
" value=\"\",\n",
|
||||
" value=\"[]\",\n",
|
||||
" ),\n",
|
||||
" ],\n",
|
||||
" tool_ids=[push_task_tool.id, pop_task_tool.id],\n",
|
||||
" model=\"letta/letta-free\",\n",
|
||||
" embedding=\"letta/letta-free\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .main import app
|
||||
|
||||
app()
|
||||
@@ -46,7 +46,7 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
# TODO: add paging by page number. currently cursor only works with strings.
|
||||
# original: start=page * count
|
||||
messages = self.message_manager.list_user_messages_for_agent(
|
||||
messages = self.message_manager.list_messages_for_agent(
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
query_text=query,
|
||||
|
||||
@@ -242,7 +242,8 @@ class AnthropicClient(LLMClientBase):
|
||||
# Move 'system' to the top level
|
||||
if messages[0].role != "system":
|
||||
raise RuntimeError(f"First message is not a system message, instead has role {messages[0].role}")
|
||||
data["system"] = messages[0].content if isinstance(messages[0].content, str) else messages[0].content[0].text
|
||||
system_content = messages[0].content if isinstance(messages[0].content, str) else messages[0].content[0].text
|
||||
data["system"] = self._add_cache_control_to_system_message(system_content)
|
||||
data["messages"] = [
|
||||
m.to_anthropic_dict(
|
||||
inner_thoughts_xml_tag=inner_thoughts_xml_tag,
|
||||
@@ -499,6 +500,22 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
def _add_cache_control_to_system_message(self, system_content):
|
||||
"""Add cache control to system message content"""
|
||||
if isinstance(system_content, str):
|
||||
# For string content, convert to list format with cache control
|
||||
return [{"type": "text", "text": system_content, "cache_control": {"type": "ephemeral"}}]
|
||||
elif isinstance(system_content, list):
|
||||
# For list content, add cache control to the last text block
|
||||
cached_content = system_content.copy()
|
||||
for i in range(len(cached_content) - 1, -1, -1):
|
||||
if cached_content[i].get("type") == "text":
|
||||
cached_content[i]["cache_control"] = {"type": "ephemeral"}
|
||||
break
|
||||
return cached_content
|
||||
|
||||
return system_content
|
||||
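# Illustrative example of the helper above (input string assumed; output shape follows the code):
#   _add_cache_control_to_system_message("You are a helpful agent.")
#   -> [{"type": "text", "text": "You are a helpful agent.", "cache_control": {"type": "ephemeral"}}]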
|
||||
|
||||
def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
|
||||
"""See: https://docs.anthropic.com/claude/docs/tool-use
|
||||
|
||||
@@ -63,7 +63,8 @@ class LLMConfig(BaseModel):
|
||||
description="The reasoning effort to use when generating text reasoning models",
|
||||
)
|
||||
max_reasoning_tokens: int = Field(
|
||||
0, description="Configurable thinking budget for extended thinking, only used if enable_reasoner is True. Minimum value is 1024."
|
||||
0,
|
||||
description="Configurable thinking budget for extended thinking. Used for enable_reasoner and also for Google Vertex models like Gemini 2.5 Flash. Minimum value is 1024 when used with enable_reasoner.",
|
||||
)
|
||||
frequency_penalty: Optional[float] = Field(
|
||||
None, # Can also default to 0.0?
|
||||
|
||||
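# A minimal sketch of setting the extended-thinking budget described above, reusing the
# model_dump() / LLMConfig(**...) pattern from the model-sweep tests earlier in this diff.
# It assumes an existing llm_config instance and that enable_reasoner is a boolean field
# on LLMConfig, as the field description suggests; the values are illustrative.
reasoning_config = llm_config.model_dump()
reasoning_config["enable_reasoner"] = True
reasoning_config["max_reasoning_tokens"] = 1024  # minimum allowed when enable_reasoner is True
llm_config_with_reasoning = LLMConfig(**reasoning_config)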
@@ -1,7 +1,6 @@
|
||||
from typing import Literal
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
@@ -43,7 +42,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
|
||||
configs = []
|
||||
for model in response_json["models"]:
|
||||
context_window = self.get_model_context_window(model["name"])
|
||||
context_window = await self._get_model_context_window(model["name"])
|
||||
if context_window is None:
|
||||
print(f"Ollama model {model['name']} has no context window, using default 32000")
|
||||
context_window = 32000
|
||||
@@ -92,63 +91,50 @@ class OllamaProvider(OpenAIProvider):
|
||||
)
|
||||
return configs
|
||||
|
||||
def get_model_context_window(self, model_name: str) -> int | None:
|
||||
"""Gets model context window for Ollama. As this can look different based on models,
|
||||
we use the following for guidance:
|
||||
|
||||
"llama.context_length": 8192,
|
||||
"llama.embedding_length": 4096,
|
||||
source: https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information
|
||||
|
||||
FROM 2024-10-08
|
||||
Notes from vLLM around keys
|
||||
source: https://github.com/vllm-project/vllm/blob/72ad2735823e23b4e1cc79b7c73c3a5f3c093ab0/vllm/config.py#L3488
|
||||
|
||||
possible_keys = [
|
||||
# OPT
|
||||
"max_position_embeddings",
|
||||
# GPT-2
|
||||
"n_positions",
|
||||
# MPT
|
||||
"max_seq_len",
|
||||
# ChatGLM2
|
||||
"seq_length",
|
||||
# Command-R
|
||||
"model_max_length",
|
||||
# Whisper
|
||||
"max_target_positions",
|
||||
# Others
|
||||
"max_sequence_length",
|
||||
"max_seq_length",
|
||||
"seq_len",
|
||||
]
|
||||
max_position_embeddings
|
||||
parse model cards: nous, dolphin, llama
|
||||
"""
|
||||
async def _get_model_context_window(self, model_name: str) -> int | None:
|
||||
endpoint = f"{self.base_url}/api/show"
|
||||
payload = {"name": model_name, "verbose": True}
|
||||
response = requests.post(endpoint, json=payload)
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
payload = {"name": model_name}
|
||||
|
||||
try:
|
||||
model_info = response.json()
|
||||
# Try to extract context window from model parameters
|
||||
if "model_info" in model_info and "llama.context_length" in model_info["model_info"]:
|
||||
return int(model_info["model_info"]["llama.context_length"])
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning(f"Failed to get model context window for {model_name}")
|
||||
return None
|
||||
|
||||
async def _get_model_embedding_dim_async(self, model_name: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response:
|
||||
response_json = await response.json()
|
||||
|
||||
if "model_info" not in response_json:
|
||||
if "error" in response_json:
|
||||
logger.warning("Ollama fetch model info error for %s: %s", model_name, response_json["error"])
|
||||
async with session.post(endpoint, json=payload) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
|
||||
return None
|
||||
|
||||
return response_json["model_info"].get("embedding_length")
|
||||
response_json = await response.json()
|
||||
model_info = response_json.get("model_info", {})
|
||||
|
||||
if architecture := model_info.get("general.architecture"):
|
||||
if context_length := model_info.get(f"{architecture}.context_length"):
|
||||
return int(context_length)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get model context window for {model_name} with error: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_model_embedding_dim(self, model_name: str) -> int | None:
|
||||
endpoint = f"{self.base_url}/api/show"
|
||||
payload = {"name": model_name}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(endpoint, json=payload) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
|
||||
return None
|
||||
|
||||
response_json = await response.json()
|
||||
model_info = response_json.get("model_info", {})
|
||||
|
||||
if architecture := model_info.get("general.architecture"):
|
||||
if embedding_length := model_info.get(f"{architecture}.embedding_length"):
|
||||
return int(embedding_length)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get model embedding dimension for {model_name} with error: {e}")
|
||||
|
||||
return None
|
||||
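# Illustrative shape of the Ollama /api/show response that the helpers above parse
# (field names taken from the Ollama model-info docs referenced earlier; values vary per model):
# {
#   "model_info": {
#     "general.architecture": "llama",
#     "llama.context_length": 8192,
#     "llama.embedding_length": 4096
#   }
# }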
|
||||
@@ -151,6 +151,7 @@ def test_archival(agent_obj):
|
||||
def test_recall(server, agent_obj, default_user):
|
||||
"""Test that an agent can recall messages using a keyword via conversation search."""
|
||||
keyword = "banana"
|
||||
"".join(reversed(keyword))
|
||||
|
||||
# Send messages
|
||||
for msg in ["hello", keyword, "tell me a fun fact"]:
|
||||
|
||||