chore: merge oss (#3712)
This commit is contained in:
286
.github/scripts/model-sweep/conftest.py
vendored
Normal file
286
.github/scripts/model-sweep/conftest.py
vendored
Normal file
@@ -0,0 +1,286 @@
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta, Letta
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_e2b_api_key() -> Generator[None, None, None]:
|
||||
"""
|
||||
Temporarily disables the E2B API key by setting `tool_settings.e2b_api_key` to None
|
||||
for the duration of the test. Restores the original value afterward.
|
||||
"""
|
||||
from letta.settings import tool_settings
|
||||
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
tool_settings.e2b_api_key = None
|
||||
yield
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def check_e2b_key_is_set():
|
||||
from letta.settings import tool_settings
|
||||
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
assert original_api_key is not None, "Missing e2b key! Cannot execute these tests."
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_organization():
|
||||
"""Fixture to create and return the default organization."""
|
||||
manager = OrganizationManager()
|
||||
org = manager.create_default_organization()
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
manager = UserManager()
|
||||
user = manager.create_default_user(org_id=default_organization.id)
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def check_composio_key_set():
|
||||
original_api_key = tool_settings.composio_api_key
|
||||
assert original_api_key is not None, "Missing composio key! Cannot execute this test."
|
||||
yield
|
||||
|
||||
|
||||
# --- Tool Fixtures ---
|
||||
@pytest.fixture
|
||||
def weather_tool_func():
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
|
||||
Parameters:
|
||||
location (str): The location to get the weather for.
|
||||
|
||||
Returns:
|
||||
str: A formatted string describing the weather in the given location.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the request to fetch weather data fails.
|
||||
"""
|
||||
import requests
|
||||
|
||||
url = f"https://wttr.in/{location}?format=%C+%t"
|
||||
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
weather_data = response.text
|
||||
return f"The weather in {location} is {weather_data}."
|
||||
else:
|
||||
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
|
||||
|
||||
yield get_weather
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def print_tool_func():
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
|
||||
def print_tool(message: str):
|
||||
"""
|
||||
Args:
|
||||
message (str): The message to print.
|
||||
|
||||
Returns:
|
||||
str: The message that was printed.
|
||||
"""
|
||||
print(message)
|
||||
return message
|
||||
|
||||
yield print_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def roll_dice_tool_func():
|
||||
def roll_dice():
|
||||
"""
|
||||
Rolls a 6 sided die.
|
||||
|
||||
Returns:
|
||||
str: The roll result.
|
||||
"""
|
||||
import time
|
||||
|
||||
time.sleep(1)
|
||||
return "Rolled a 10!"
|
||||
|
||||
yield roll_dice
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_beta_message_batch() -> BetaMessageBatch:
|
||||
return BetaMessageBatch(
|
||||
id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
|
||||
archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
|
||||
cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
|
||||
created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
|
||||
ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
|
||||
expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
|
||||
processing_status="in_progress",
|
||||
request_counts=BetaMessageBatchRequestCounts(
|
||||
canceled=10,
|
||||
errored=30,
|
||||
expired=10,
|
||||
processing=100,
|
||||
succeeded=50,
|
||||
),
|
||||
results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
|
||||
type="message_batch",
|
||||
)
|
||||
|
||||
|
||||
# --- Model Sweep ---
|
||||
# Global flag to track server state
|
||||
_server_started = False
|
||||
_server_url = None
|
||||
|
||||
|
||||
def _start_server_once() -> str:
|
||||
"""Start server exactly once, return URL"""
|
||||
global _server_started, _server_url
|
||||
|
||||
if _server_started and _server_url:
|
||||
return _server_url
|
||||
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
# Check if already running
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
if s.connect_ex(("localhost", 8283)) == 0:
|
||||
_server_started = True
|
||||
_server_url = url
|
||||
return url
|
||||
|
||||
# Start server (your existing logic)
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
|
||||
def _run_server():
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Poll until up
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
_server_started = True
|
||||
_server_url = url
|
||||
return url
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
# ------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url() -> str:
|
||||
"""Return URL of already-started server"""
|
||||
return _start_server_once()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server_url: str) -> Letta:
|
||||
"""
|
||||
Creates and returns a synchronous Letta REST client for testing.
|
||||
"""
|
||||
client_instance = Letta(base_url=server_url)
|
||||
yield client_instance
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def async_client(server_url: str) -> AsyncLetta:
|
||||
"""
|
||||
Creates and returns an asynchronous Letta REST client for testing.
|
||||
"""
|
||||
async_client_instance = AsyncLetta(base_url=server_url)
|
||||
yield async_client_instance
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent_state(client: Letta) -> AgentState:
|
||||
"""
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
|
||||
"""
|
||||
client.tools.upsert_base_tools()
|
||||
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
agent_state_instance = client.agents.create(
|
||||
name="supervisor",
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id],
|
||||
model="openai/gpt-4o",
|
||||
embedding="letta/letta-free",
|
||||
tags=["supervisor"],
|
||||
)
|
||||
yield agent_state_instance
|
||||
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def all_available_llm_configs(client: Letta) -> [LLMConfig]:
|
||||
"""
|
||||
Returns a list of all available LLM configs.
|
||||
"""
|
||||
llm_configs = client.models.list()
|
||||
return llm_configs
|
||||
|
||||
|
||||
# create a client to the started server started at
|
||||
def get_available_llm_configs() -> [LLMConfig]:
|
||||
"""Get configs, starting server if needed"""
|
||||
server_url = _start_server_once()
|
||||
temp_client = Letta(base_url=server_url)
|
||||
return temp_client.models.list()
|
||||
|
||||
|
||||
# dynamically insert llm_config paramter at collection time
|
||||
def pytest_generate_tests(metafunc):
|
||||
"""Dynamically parametrize tests that need llm_config."""
|
||||
if "llm_config" in metafunc.fixturenames:
|
||||
configs = get_available_llm_configs()
|
||||
if configs:
|
||||
metafunc.parametrize("llm_config", configs, ids=[c.model for c in configs])
|
||||
21
.github/scripts/model-sweep/feature_mappings.json
vendored
Normal file
21
.github/scripts/model-sweep/feature_mappings.json
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"Basic": [
|
||||
"test_greeting_with_assistant_message",
|
||||
"test_greeting_without_assistant_message",
|
||||
"test_async_greeting_with_assistant_message",
|
||||
"test_agent_loop_error",
|
||||
"test_step_stream_agent_loop_error",
|
||||
"test_step_streaming_greeting_with_assistant_message",
|
||||
"test_step_streaming_greeting_without_assistant_message",
|
||||
"test_step_streaming_tool_call",
|
||||
"test_tool_call",
|
||||
"test_auto_summarize"
|
||||
],
|
||||
"Token Streaming": [
|
||||
"test_token_streaming_greeting_with_assistant_message",
|
||||
"test_token_streaming_greeting_without_assistant_message",
|
||||
"test_token_streaming_agent_loop_error",
|
||||
"test_token_streaming_tool_call"
|
||||
],
|
||||
"Multimodal": ["test_base64_image_input", "test_url_image_input"]
|
||||
}
|
||||
495
.github/scripts/model-sweep/generate_model_sweep_markdown.py
vendored
Normal file
495
.github/scripts/model-sweep/generate_model_sweep_markdown.py
vendored
Normal file
@@ -0,0 +1,495 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def load_feature_mappings(config_file=None):
|
||||
"""Load feature mappings from config file."""
|
||||
if config_file is None:
|
||||
# Default to feature_mappings.json in the same directory as this script
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_file = os.path.join(script_dir, "feature_mappings.json")
|
||||
|
||||
try:
|
||||
with open(config_file, "r") as f:
|
||||
return json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Could not find feature mappings config file '{config_file}'")
|
||||
sys.exit(1)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error: Invalid JSON in feature mappings config file '{config_file}'")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_support_status(passed_tests, feature_tests):
|
||||
"""Determine support status for a feature category."""
|
||||
if not feature_tests:
|
||||
return "❓" # Unknown - no tests for this feature
|
||||
|
||||
# Filter out error tests when checking for support
|
||||
non_error_tests = [test for test in feature_tests if not test.endswith("_error")]
|
||||
error_tests = [test for test in feature_tests if test.endswith("_error")]
|
||||
|
||||
# Check which non-error tests passed
|
||||
passed_non_error_tests = [test for test in non_error_tests if test in passed_tests]
|
||||
|
||||
# If there are no non-error tests, only error tests, treat as unknown
|
||||
if not non_error_tests:
|
||||
return "❓" # Only error tests available
|
||||
|
||||
# Support is based only on non-error tests
|
||||
if len(passed_non_error_tests) == len(non_error_tests):
|
||||
return "✅" # Full support
|
||||
elif len(passed_non_error_tests) == 0:
|
||||
return "❌" # No support
|
||||
else:
|
||||
return "⚠️" # Partial support
|
||||
|
||||
|
||||
def categorize_tests(all_test_names, feature_mapping):
|
||||
"""Categorize test names into feature buckets."""
|
||||
categorized = {feature: [] for feature in feature_mapping.keys()}
|
||||
|
||||
for test_name in all_test_names:
|
||||
for feature, test_patterns in feature_mapping.items():
|
||||
if test_name in test_patterns:
|
||||
categorized[feature].append(test_name)
|
||||
break
|
||||
|
||||
return categorized
|
||||
|
||||
|
||||
def calculate_support_score(feature_support, feature_order):
|
||||
"""Calculate a numeric support score for ranking models.
|
||||
|
||||
For partial support, the score is weighted by the position of the feature
|
||||
in the feature_order list (earlier features get higher weight).
|
||||
"""
|
||||
score = 0
|
||||
max_features = len(feature_order)
|
||||
|
||||
for feature, status in feature_support.items():
|
||||
# Get position weight (earlier features get higher weight)
|
||||
if feature in feature_order:
|
||||
position_weight = (max_features - feature_order.index(feature)) / max_features
|
||||
else:
|
||||
position_weight = 0.5 # Default weight for unmapped features
|
||||
|
||||
if status == "✅": # Full support
|
||||
score += 10 * position_weight
|
||||
elif status == "⚠️": # Partial support - weighted by column position
|
||||
score += 5 * position_weight
|
||||
elif status == "❌": # No support
|
||||
score += 1 * position_weight
|
||||
# Unknown (❓) gets 0 points
|
||||
return score
|
||||
|
||||
|
||||
def calculate_provider_support_score(models_data, feature_order):
|
||||
"""Calculate a provider-level support score based on all models' support scores."""
|
||||
if not models_data:
|
||||
return 0
|
||||
|
||||
# Calculate the average support score across all models in the provider
|
||||
total_score = sum(model["support_score"] for model in models_data)
|
||||
return total_score / len(models_data)
|
||||
|
||||
|
||||
def get_test_function_line_numbers(test_file_path):
|
||||
"""Extract line numbers for test functions from the test file."""
|
||||
test_line_numbers = {}
|
||||
|
||||
try:
|
||||
with open(test_file_path, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
if "def test_" in line and line.strip().startswith("def test_"):
|
||||
# Extract function name
|
||||
func_name = line.strip().split("def ")[1].split("(")[0]
|
||||
test_line_numbers[func_name] = i
|
||||
except FileNotFoundError:
|
||||
print(f"Warning: Could not find test file at {test_file_path}")
|
||||
|
||||
return test_line_numbers
|
||||
|
||||
|
||||
def get_github_repo_info():
|
||||
"""Get GitHub repository information from git remote."""
|
||||
try:
|
||||
# Try to get the GitHub repo URL from git remote
|
||||
import subprocess
|
||||
|
||||
result = subprocess.run(["git", "remote", "get-url", "origin"], capture_output=True, text=True, cwd=os.path.dirname(__file__))
|
||||
if result.returncode == 0:
|
||||
remote_url = result.stdout.strip()
|
||||
# Parse GitHub URL
|
||||
if "github.com" in remote_url:
|
||||
if remote_url.startswith("https://"):
|
||||
# https://github.com/user/repo.git -> user/repo
|
||||
repo_path = remote_url.replace("https://github.com/", "").replace(".git", "")
|
||||
elif remote_url.startswith("git@"):
|
||||
# git@github.com:user/repo.git -> user/repo
|
||||
repo_path = remote_url.split(":")[1].replace(".git", "")
|
||||
else:
|
||||
return None
|
||||
return repo_path
|
||||
except:
|
||||
pass
|
||||
|
||||
# Default fallback
|
||||
return "letta-ai/letta"
|
||||
|
||||
|
||||
def generate_test_details(model_info, feature_mapping):
|
||||
"""Generate detailed test results for a model."""
|
||||
details = []
|
||||
|
||||
# Get test function line numbers
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
test_file_path = os.path.join(script_dir, "model_sweep.py")
|
||||
test_line_numbers = get_test_function_line_numbers(test_file_path)
|
||||
|
||||
# Use the main branch GitHub URL
|
||||
base_github_url = "https://github.com/letta-ai/letta/blob/main/.github/scripts/model-sweep/model_sweep.py"
|
||||
|
||||
for feature, tests in model_info["categorized_tests"].items():
|
||||
if not tests:
|
||||
continue
|
||||
|
||||
details.append(f"### {feature}")
|
||||
details.append("")
|
||||
|
||||
for test in sorted(tests):
|
||||
if test in model_info["passed_tests"]:
|
||||
status = "✅"
|
||||
elif test in model_info["failed_tests"]:
|
||||
status = "❌"
|
||||
else:
|
||||
status = "❓"
|
||||
|
||||
# Create GitHub link if we have line number info
|
||||
if test in test_line_numbers:
|
||||
line_num = test_line_numbers[test]
|
||||
github_link = f"{base_github_url}#L{line_num}"
|
||||
details.append(f"- {status} [`{test}`]({github_link})")
|
||||
else:
|
||||
details.append(f"- {status} `{test}`")
|
||||
details.append("")
|
||||
|
||||
return details
|
||||
|
||||
|
||||
def calculate_column_widths(all_provider_data, feature_mapping):
|
||||
"""Calculate the maximum width needed for each column across all providers."""
|
||||
widths = {"model": len("Model"), "context_window": len("Context Window"), "last_scanned": len("Last Scanned")}
|
||||
|
||||
# Feature column widths
|
||||
for feature in feature_mapping.keys():
|
||||
widths[feature] = len(feature)
|
||||
|
||||
# Check all model data for maximum widths
|
||||
for provider_data in all_provider_data.values():
|
||||
for model_info in provider_data:
|
||||
# Model name width (including backticks)
|
||||
model_width = len(f"`{model_info['name']}`")
|
||||
widths["model"] = max(widths["model"], model_width)
|
||||
|
||||
# Context window width (with commas)
|
||||
context_width = len(f"{model_info['context_window']:,}")
|
||||
widths["context_window"] = max(widths["context_window"], context_width)
|
||||
|
||||
# Last scanned width
|
||||
widths["last_scanned"] = max(widths["last_scanned"], len(str(model_info["last_scanned"])))
|
||||
|
||||
# Feature support symbols are always 2 chars, so no need to check
|
||||
|
||||
return widths
|
||||
|
||||
|
||||
def process_model_sweep_report(input_file, output_file, config_file=None, debug=False):
|
||||
"""Convert model sweep JSON data to MDX report."""
|
||||
|
||||
# Load feature mappings from config file
|
||||
feature_mapping = load_feature_mappings(config_file)
|
||||
|
||||
# if debug:
|
||||
# print("DEBUG: Feature mappings loaded:")
|
||||
# for feature, tests in feature_mapping.items():
|
||||
# print(f" {feature}: {tests}")
|
||||
# print()
|
||||
|
||||
# Read the JSON data
|
||||
with open(input_file, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
tests = data.get("tests", [])
|
||||
|
||||
# if debug:
|
||||
# print("DEBUG: Tests loaded:")
|
||||
# print([test['outcome'] for test in tests if 'haiku' in test['nodeid']])
|
||||
|
||||
# Calculate summary statistics
|
||||
providers = set(test["metadata"]["llm_config"]["provider_name"] for test in tests)
|
||||
models = set(test["metadata"]["llm_config"]["model"] for test in tests)
|
||||
total_tests = len(tests)
|
||||
|
||||
# Start building the MDX
|
||||
mdx_lines = [
|
||||
"---",
|
||||
"title: Support Models",
|
||||
f"generated: {datetime.now().isoformat()}",
|
||||
"---",
|
||||
"",
|
||||
"# Supported Models",
|
||||
"",
|
||||
"## Overview",
|
||||
"",
|
||||
"Letta routinely runs automated scans against available providers and models. These are the results of the latest scan.",
|
||||
"",
|
||||
f"Ran {total_tests} tests against {len(models)} models across {len(providers)} providers on {datetime.now().strftime('%B %dth, %Y')}",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
|
||||
# Group tests by provider
|
||||
provider_groups = defaultdict(list)
|
||||
for test in tests:
|
||||
provider_name = test["metadata"]["llm_config"]["provider_name"]
|
||||
provider_groups[provider_name].append(test)
|
||||
|
||||
# Process all providers first to collect model data
|
||||
all_provider_data = {}
|
||||
provider_support_scores = {}
|
||||
|
||||
for provider_name in provider_groups.keys():
|
||||
provider_tests = provider_groups[provider_name]
|
||||
|
||||
# Group tests by model within this provider
|
||||
model_groups = defaultdict(list)
|
||||
for test in provider_tests:
|
||||
model_name = test["metadata"]["llm_config"]["model"]
|
||||
model_groups[model_name].append(test)
|
||||
|
||||
# Process all models to calculate support scores for ranking
|
||||
model_data = []
|
||||
for model_name in model_groups.keys():
|
||||
model_tests = model_groups[model_name]
|
||||
|
||||
# if debug:
|
||||
# print(f"DEBUG: Processing model '{model_name}' in provider '{provider_name}'")
|
||||
|
||||
# Extract unique test names for passed and failed tests
|
||||
passed_tests = set()
|
||||
failed_tests = set()
|
||||
all_test_names = set()
|
||||
|
||||
for test in model_tests:
|
||||
# Extract test name from nodeid (split on :: and [)
|
||||
test_name = test["nodeid"].split("::")[1].split("[")[0]
|
||||
all_test_names.add(test_name)
|
||||
|
||||
# if debug:
|
||||
# print(f" Test name: {test_name}")
|
||||
# print(f" Outcome: {test}")
|
||||
if test["outcome"] == "passed":
|
||||
passed_tests.add(test_name)
|
||||
elif test["outcome"] == "failed":
|
||||
failed_tests.add(test_name)
|
||||
|
||||
# if debug:
|
||||
# print(f" All test names found: {sorted(all_test_names)}")
|
||||
# print(f" Passed tests: {sorted(passed_tests)}")
|
||||
# print(f" Failed tests: {sorted(failed_tests)}")
|
||||
|
||||
# Categorize tests into features
|
||||
categorized_tests = categorize_tests(all_test_names, feature_mapping)
|
||||
|
||||
# if debug:
|
||||
# print(f" Categorized tests:")
|
||||
# for feature, tests in categorized_tests.items():
|
||||
# print(f" {feature}: {tests}")
|
||||
|
||||
# Determine support status for each feature
|
||||
feature_support = {}
|
||||
for feature_name in feature_mapping.keys():
|
||||
feature_support[feature_name] = get_support_status(passed_tests, categorized_tests[feature_name])
|
||||
|
||||
# if debug:
|
||||
# print(f" Feature support:")
|
||||
# for feature, status in feature_support.items():
|
||||
# print(f" {feature}: {status}")
|
||||
# print()
|
||||
|
||||
# Get context window and last scanned time
|
||||
context_window = model_tests[0]["metadata"]["llm_config"]["context_window"]
|
||||
|
||||
# Try to get time_last_scanned from metadata, fallback to current time
|
||||
try:
|
||||
last_scanned = model_tests[0]["metadata"].get(
|
||||
"time_last_scanned", model_tests[0]["metadata"].get("timestamp", datetime.now().isoformat())
|
||||
)
|
||||
# Format timestamp if it's a full ISO string
|
||||
if "T" in str(last_scanned):
|
||||
last_scanned = str(last_scanned).split("T")[0] # Just the date part
|
||||
except:
|
||||
last_scanned = "Unknown"
|
||||
|
||||
# Calculate support score for ranking
|
||||
feature_order = list(feature_mapping.keys())
|
||||
support_score = calculate_support_score(feature_support, feature_order)
|
||||
|
||||
# Store model data for sorting
|
||||
model_data.append(
|
||||
{
|
||||
"name": model_name,
|
||||
"feature_support": feature_support,
|
||||
"context_window": context_window,
|
||||
"last_scanned": last_scanned,
|
||||
"support_score": support_score,
|
||||
"failed_tests": failed_tests,
|
||||
"passed_tests": passed_tests,
|
||||
"categorized_tests": categorized_tests,
|
||||
}
|
||||
)
|
||||
|
||||
# Sort models by support score (descending) then by name (ascending)
|
||||
model_data.sort(key=lambda x: (-x["support_score"], x["name"]))
|
||||
|
||||
# Store provider data
|
||||
all_provider_data[provider_name] = model_data
|
||||
provider_support_scores[provider_name] = calculate_provider_support_score(model_data, list(feature_mapping.keys()))
|
||||
|
||||
# Calculate column widths for consistent formatting (add details column)
|
||||
column_widths = calculate_column_widths(all_provider_data, feature_mapping)
|
||||
column_widths["details"] = len("Details")
|
||||
|
||||
# Sort providers by support score (descending) then by name (ascending)
|
||||
sorted_providers = sorted(provider_support_scores.keys(), key=lambda x: (-provider_support_scores[x], x))
|
||||
|
||||
# Generate tables for all providers first
|
||||
for provider_name in sorted_providers:
|
||||
model_data = all_provider_data[provider_name]
|
||||
support_score = provider_support_scores[provider_name]
|
||||
|
||||
# Create dynamic headers with proper padding and centering
|
||||
feature_names = list(feature_mapping.keys())
|
||||
|
||||
# Build header row with left-aligned first column, centered others
|
||||
header_parts = [f"{'Model':<{column_widths['model']}}"]
|
||||
for feature in feature_names:
|
||||
header_parts.append(f"{feature:^{column_widths[feature]}}")
|
||||
header_parts.extend(
|
||||
[
|
||||
f"{'Context Window':^{column_widths['context_window']}}",
|
||||
f"{'Last Scanned':^{column_widths['last_scanned']}}",
|
||||
f"{'Details':^{column_widths['details']}}",
|
||||
]
|
||||
)
|
||||
header_row = "| " + " | ".join(header_parts) + " |"
|
||||
|
||||
# Build separator row with left-aligned first column, centered others
|
||||
separator_parts = [f"{'-' * column_widths['model']}"]
|
||||
for feature in feature_names:
|
||||
separator_parts.append(f":{'-' * (column_widths[feature] - 2)}:")
|
||||
separator_parts.extend(
|
||||
[
|
||||
f":{'-' * (column_widths['context_window'] - 2)}:",
|
||||
f":{'-' * (column_widths['last_scanned'] - 2)}:",
|
||||
f":{'-' * (column_widths['details'] - 2)}:",
|
||||
]
|
||||
)
|
||||
separator_row = "|" + "|".join(separator_parts) + "|"
|
||||
|
||||
# Add provider section without percentage
|
||||
mdx_lines.extend([f"## {provider_name}", "", header_row, separator_row])
|
||||
|
||||
# Generate table rows for sorted models with proper padding
|
||||
for model_info in model_data:
|
||||
# Create anchor for model details
|
||||
model_anchor = model_info["name"].replace("/", "_").replace(":", "_").replace("-", "_").lower()
|
||||
details_anchor = f"{provider_name.lower().replace(' ', '_')}_{model_anchor}_details"
|
||||
|
||||
# Build row with left-aligned first column, centered others
|
||||
row_parts = [f"`{model_info['name']}`".ljust(column_widths["model"])]
|
||||
for feature in feature_names:
|
||||
row_parts.append(f"{model_info['feature_support'][feature]:^{column_widths[feature]}}")
|
||||
row_parts.extend(
|
||||
[
|
||||
f"{model_info['context_window']:,}".center(column_widths["context_window"]),
|
||||
f"{model_info['last_scanned']}".center(column_widths["last_scanned"]),
|
||||
f"[View](#{details_anchor})".center(column_widths["details"]),
|
||||
]
|
||||
)
|
||||
row = "| " + " | ".join(row_parts) + " |"
|
||||
mdx_lines.append(row)
|
||||
|
||||
# Add spacing between provider tables
|
||||
mdx_lines.extend(["", ""])
|
||||
|
||||
# Add detailed test results section after all tables
|
||||
mdx_lines.extend(["---", "", "# Detailed Test Results", ""])
|
||||
|
||||
for provider_name in sorted_providers:
|
||||
model_data = all_provider_data[provider_name]
|
||||
mdx_lines.extend([f"## {provider_name}", ""])
|
||||
|
||||
for model_info in model_data:
|
||||
model_anchor = model_info["name"].replace("/", "_").replace(":", "_").replace("-", "_").lower()
|
||||
details_anchor = f"{provider_name.lower().replace(' ', '_')}_{model_anchor}_details"
|
||||
mdx_lines.append(f'<a id="{details_anchor}"></a>')
|
||||
mdx_lines.append(f"### {model_info['name']}")
|
||||
mdx_lines.append("")
|
||||
|
||||
# Add test details
|
||||
test_details = generate_test_details(model_info, feature_mapping)
|
||||
mdx_lines.extend(test_details)
|
||||
|
||||
# Add spacing between providers in details section
|
||||
mdx_lines.extend(["", ""])
|
||||
|
||||
# Write the MDX file
|
||||
with open(output_file, "w") as f:
|
||||
f.write("\n".join(mdx_lines))
|
||||
|
||||
print(f"Model sweep report saved to {output_file}")
|
||||
|
||||
|
||||
def main():
|
||||
input_file = "model_sweep_report.json"
|
||||
output_file = "model_sweep_report.mdx"
|
||||
config_file = None
|
||||
debug = False
|
||||
|
||||
# Allow command line arguments
|
||||
if len(sys.argv) > 1:
|
||||
# Use the file located in the same directory as this script
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
input_file = os.path.join(script_dir, sys.argv[1])
|
||||
if len(sys.argv) > 2:
|
||||
# Use the file located in the same directory as this script
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
output_file = os.path.join(script_dir, sys.argv[2])
|
||||
if len(sys.argv) > 3:
|
||||
config_file = sys.argv[3]
|
||||
if len(sys.argv) > 4 and sys.argv[4] == "--debug":
|
||||
debug = True
|
||||
|
||||
try:
|
||||
process_model_sweep_report(input_file, output_file, config_file, debug)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Could not find input file '{input_file}'")
|
||||
sys.exit(1)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error: Invalid JSON in file '{input_file}'")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
786
.github/scripts/model-sweep/model_sweep.py
vendored
Normal file
786
.github/scripts/model-sweep/model_sweep.py
vendored
Normal file
@@ -0,0 +1,786 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta, MessageCreate, Run
|
||||
from letta_client.core.api_error import ApiError
|
||||
from letta_client.types import (
|
||||
AssistantMessage,
|
||||
Base64Image,
|
||||
ImageContent,
|
||||
LettaUsageStatistics,
|
||||
ReasoningMessage,
|
||||
TextContent,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
UrlImage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
# ------------------------------
|
||||
# Helper Functions and Constants
|
||||
# ------------------------------
|
||||
|
||||
|
||||
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
|
||||
filename = os.path.join(llm_config_dir, filename)
|
||||
config_data = json.load(open(filename, "r"))
|
||||
llm_config = LLMConfig(**config_data)
|
||||
return llm_config
|
||||
|
||||
|
||||
def roll_dice(num_sides: int) -> int:
|
||||
"""
|
||||
Returns a random number between 1 and num_sides.
|
||||
Args:
|
||||
num_sides (int): The number of sides on the die.
|
||||
Returns:
|
||||
int: A random integer between 1 and num_sides, representing the die roll.
|
||||
"""
|
||||
import random
|
||||
|
||||
return random.randint(1, num_sides)
|
||||
|
||||
|
||||
USER_MESSAGE_OTID = str(uuid.uuid4())
|
||||
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
|
||||
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=f"This is an automated test message. Call the send_message tool with the message '{USER_MESSAGE_RESPONSE}'.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="This is an automated test message. Call the roll_dice tool with 16 sides and tell me the outcome.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
URL_IMAGE = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
|
||||
USER_MESSAGE_URL_IMAGE: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[
|
||||
ImageContent(source=UrlImage(url=URL_IMAGE)),
|
||||
TextContent(text="What is in this image?"),
|
||||
],
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
BASE64_IMAGE = base64.standard_b64encode(httpx.get(URL_IMAGE).content).decode("utf-8")
|
||||
USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[
|
||||
ImageContent(source=Base64Image(data=BASE64_IMAGE, media_type="image/jpeg")),
|
||||
TextContent(text="What is in this image?"),
|
||||
],
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
# "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop
|
||||
"claude-3-5-sonnet.json",
|
||||
"claude-3-7-sonnet.json",
|
||||
"claude-3-7-sonnet-extended.json",
|
||||
"gemini-1.5-pro.json",
|
||||
"gemini-2.5-flash-vertex.json",
|
||||
"gemini-2.5-pro-vertex.json",
|
||||
"together-qwen-2.5-72b-instruct.json",
|
||||
"ollama.json",
|
||||
]
|
||||
requested = os.getenv("LLM_CONFIG_FILE")
|
||||
filenames = [requested] if requested else all_configs
|
||||
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
|
||||
|
||||
|
||||
def assert_greeting_with_assistant_message_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
token_streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
expected_message_count = 3 if streaming or from_db else 2
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
if not token_streaming:
|
||||
assert USER_MESSAGE_RESPONSE in messages[index].content
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
assert messages[index].prompt_tokens > 0
|
||||
assert messages[index].completion_tokens > 0
|
||||
assert messages[index].total_tokens > 0
|
||||
assert messages[index].step_count > 0
|
||||
|
||||
|
||||
def assert_greeting_without_assistant_message_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
token_streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> ToolCallMessage -> ToolReturnMessage.
|
||||
"""
|
||||
expected_message_count = 4 if streaming or from_db else 3
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], ToolCallMessage)
|
||||
assert messages[index].tool_call.name == "send_message"
|
||||
if not token_streaming:
|
||||
assert USER_MESSAGE_RESPONSE in messages[index].tool_call.arguments
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
# Agent Step 2
|
||||
assert isinstance(messages[index], ToolReturnMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
|
||||
|
||||
def assert_tool_call_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> ToolCallMessage -> ToolReturnMessage ->
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
expected_message_count = 6 if streaming else 7 if from_db else 5
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], ToolCallMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
# Agent Step 2
|
||||
assert isinstance(messages[index], ToolReturnMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
# Hidden User Message
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert "request_heartbeat=true" in messages[index].content
|
||||
index += 1
|
||||
|
||||
# Agent Step 3
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
|
||||
|
||||
def assert_image_input_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
token_streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
expected_message_count = 3 if streaming or from_db else 2
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == "1"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
assert messages[index].prompt_tokens > 0
|
||||
assert messages[index].completion_tokens > 0
|
||||
assert messages[index].total_tokens > 0
|
||||
assert messages[index].step_count > 0
|
||||
|
||||
|
||||
def accumulate_chunks(chunks: List[Any]) -> List[Any]:
|
||||
"""
|
||||
Accumulates chunks into a list of messages.
|
||||
"""
|
||||
messages = []
|
||||
current_message = None
|
||||
prev_message_type = None
|
||||
for chunk in chunks:
|
||||
current_message_type = chunk.message_type
|
||||
if prev_message_type != current_message_type:
|
||||
messages.append(current_message)
|
||||
current_message = None
|
||||
if current_message is None:
|
||||
current_message = chunk
|
||||
else:
|
||||
pass # TODO: actually accumulate the chunks. For now we only care about the count
|
||||
prev_message_type = current_message_type
|
||||
messages.append(current_message)
|
||||
return [m for m in messages if m is not None]
|
||||
|
||||
|
||||
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
|
||||
start = time.time()
|
||||
while True:
|
||||
run = client.runs.retrieve(run_id)
|
||||
if run.status == "completed":
|
||||
return run
|
||||
if run.status == "failed":
|
||||
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Asserts that a list of message dictionaries contains the expected types and statuses.
|
||||
|
||||
Expected order:
|
||||
1. reasoning_message
|
||||
2. tool_call_message
|
||||
3. tool_return_message (with status 'success')
|
||||
4. reasoning_message
|
||||
5. assistant_message
|
||||
"""
|
||||
assert isinstance(messages, list)
|
||||
assert messages[0]["message_type"] == "reasoning_message"
|
||||
assert messages[1]["message_type"] == "assistant_message"
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Test Cases
|
||||
# ------------------------------
|
||||
|
||||
# def test_that_ci_workflow_works(
|
||||
# disable_e2b_api_key: Any,
|
||||
# client: Letta,
|
||||
# agent_state: AgentState,
|
||||
# llm_config: LLMConfig,
|
||||
# json_metadata: pytest.FixtureRequest,
|
||||
# ) -> None:
|
||||
# """
|
||||
# Tests that the CI workflow works.
|
||||
# """
|
||||
# json_metadata["test_type"] = "debug"
|
||||
|
||||
|
||||
def test_greeting_with_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
assert_greeting_with_assistant_message_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_greeting_without_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
)
|
||||
assert_greeting_without_assistant_message_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
|
||||
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_tool_call(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
dice_tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
)
|
||||
assert_tool_call_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_url_image_input(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_URL_IMAGE,
|
||||
)
|
||||
assert_image_input_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_image_input_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_base64_image_input(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_BASE64_IMAGE,
|
||||
)
|
||||
assert_image_input_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_image_input_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_agent_loop_error(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that no new messages are persisted on error.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
tools = agent_state.tools
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
|
||||
with pytest.raises(ApiError):
|
||||
client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert len(messages_from_db) == 0
|
||||
client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
|
||||
|
||||
|
||||
def test_step_streaming_greeting_with_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_with_assistant_message_response(messages, streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_step_streaming_greeting_without_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_without_assistant_message_response(messages, streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
|
||||
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_step_streaming_tool_call(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
dice_tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_tool_call_response(messages, streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_step_stream_agent_loop_error(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that no new messages are persisted on error.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
tools = agent_state.tools
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
|
||||
with pytest.raises(ApiError):
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
list(response)
|
||||
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert len(messages_from_db) == 0
|
||||
client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
|
||||
|
||||
|
||||
def test_token_streaming_greeting_with_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_token_streaming_greeting_without_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
use_assistant_message=False,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
|
||||
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_token_streaming_tool_call(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
dice_tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
stream_tokens=True,
|
||||
)
|
||||
chunks = list(response)
|
||||
messages = accumulate_chunks(chunks)
|
||||
assert_tool_call_response(messages, streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
def test_token_streaming_agent_loop_error(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that no new messages are persisted on error.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
tools = agent_state.tools
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[])
|
||||
try:
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
stream_tokens=True,
|
||||
)
|
||||
list(response)
|
||||
except:
|
||||
pass # only some models throw an error TODO: make this consistent
|
||||
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert len(messages_from_db) == 0
|
||||
client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools])
|
||||
|
||||
|
||||
def test_async_greeting_with_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
agent_state: AgentState,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message as an asynchronous job using the synchronous client.
|
||||
Waits for job completion and asserts that the result messages are as expected.
|
||||
"""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
run = client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
run = wait_for_run_completion(client, run.id)
|
||||
|
||||
result = run.metadata.get("result")
|
||||
assert result is not None, "Run metadata missing 'result' key"
|
||||
|
||||
messages = result["messages"]
|
||||
assert_tool_response_dict_messages(messages)
|
||||
|
||||
|
||||
def test_auto_summarize(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
llm_config: LLMConfig,
|
||||
json_metadata: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""Test that summarization is automatically triggered."""
|
||||
json_metadata["llm_config"] = dict(llm_config)
|
||||
|
||||
# pydantic prevents us for overriding the context window paramter in the passed LLMConfig
|
||||
new_llm_config = llm_config.model_dump()
|
||||
new_llm_config["context_window"] = 3000
|
||||
pinned_context_window_llm_config = LLMConfig(**new_llm_config)
|
||||
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
temp_agent_state = client.agents.create(
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id],
|
||||
llm_config=pinned_context_window_llm_config,
|
||||
embedding="letta/letta-free",
|
||||
tags=["supervisor"],
|
||||
)
|
||||
|
||||
philosophical_question = """
|
||||
You know, sometimes I wonder if the entire structure of our lives is built on a series of unexamined assumptions we just silently agreed to somewhere along the way—like how we all just decided that five days a week of work and two days of “rest” constitutes balance, or how 9-to-5 became the default rhythm of a meaningful life, or even how the idea of “success” got boiled down to job titles and property ownership and productivity metrics on a LinkedIn profile, when maybe none of that is actually what makes a life feel full, or grounded, or real. And then there’s the weird paradox of ambition, how we're taught to chase it like a finish line that keeps moving, constantly redefining itself right as you’re about to grasp it—because even when you get the job, or the degree, or the validation, there's always something next, something more, like a treadmill with invisible settings you didn’t realize were turned up all the way.
|
||||
|
||||
And have you noticed how we rarely stop to ask who set those definitions for us? Like was there ever a council that decided, yes, owning a home by thirty-five and retiring by sixty-five is the universal template for fulfillment? Or did it just accumulate like cultural sediment over generations, layered into us so deeply that questioning it feels uncomfortable, even dangerous? And isn’t it strange that we spend so much of our lives trying to optimize things—our workflows, our diets, our sleep, our morning routines—as though the point of life is to operate more efficiently rather than to experience it more richly? We build these intricate systems, these rulebooks for being a “high-functioning” human, but where in all of that is the space for feeling lost, for being soft, for wandering without a purpose just because it’s a sunny day and your heart is tugging you toward nowhere in particular?
|
||||
|
||||
Sometimes I lie awake at night and wonder if all the noise we wrap around ourselves—notifications, updates, performance reviews, even our internal monologues—might be crowding out the questions we were meant to live into slowly, like how to love better, or how to forgive ourselves, or what the hell we’re even doing here in the first place. And when you strip it all down—no goals, no KPIs, no curated identity—what’s actually left of us? Are we just a sum of the roles we perform, or is there something quieter underneath that we've forgotten how to hear?
|
||||
|
||||
And if there is something underneath all of it—something real, something worth listening to—then how do we begin to uncover it, gently, without rushing or reducing it to another task on our to-do list?
|
||||
"""
|
||||
|
||||
MAX_ATTEMPTS = 10
|
||||
prev_length = None
|
||||
|
||||
for attempt in range(MAX_ATTEMPTS):
|
||||
client.agents.messages.create(
|
||||
agent_id=temp_agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=philosophical_question)],
|
||||
)
|
||||
|
||||
temp_agent_state = client.agents.retrieve(agent_id=temp_agent_state.id)
|
||||
message_ids = temp_agent_state.message_ids
|
||||
current_length = len(message_ids)
|
||||
|
||||
print("LENGTH OF IN_CONTEXT_MESSAGES:", current_length)
|
||||
|
||||
if prev_length is not None and current_length <= prev_length:
|
||||
# TODO: Add more stringent checks here
|
||||
print(f"Summarization was triggered, detected current_length {current_length} is at least prev_length {prev_length}.")
|
||||
break
|
||||
|
||||
prev_length = current_length
|
||||
else:
|
||||
raise AssertionError("Summarization was not triggered after 10 messages")
|
||||
4551
.github/scripts/model-sweep/supported-models.mdx
vendored
Normal file
4551
.github/scripts/model-sweep/supported-models.mdx
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user