chore: Add agent mass update test (#1798)
This commit is contained in:
220
performance_tests/test_agent_mass_update.py
Normal file
220
performance_tests/test_agent_mass_update.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
from tqdm import tqdm
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
# --- Server Management --- #
|
||||
|
||||
|
||||
def _run_server():
|
||||
"""Starts the Letta server in a background thread."""
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server_url():
|
||||
"""Ensures a server is running and returns its base URL."""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(2) # Allow server startup time
|
||||
|
||||
return url
|
||||
|
||||
|
||||
# --- Client Setup --- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def client(server_url):
|
||||
"""Creates a REST client for testing."""
|
||||
client = Letta(base_url=server_url)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def roll_dice_tool(client):
|
||||
def roll_dice():
|
||||
"""
|
||||
Rolls a 6 sided die.
|
||||
|
||||
Returns:
|
||||
str: The roll result.
|
||||
"""
|
||||
return "Rolled a 10!"
|
||||
|
||||
tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def rethink_tool(client):
|
||||
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
|
||||
"""
|
||||
Re-evaluate the memory in block_name, integrating new and updated facts.
|
||||
Replace outdated information with the most likely truths, avoiding redundancy with original memories.
|
||||
Ensure consistency with other memory blocks.
|
||||
|
||||
Args:
|
||||
new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block.
|
||||
target_block_label (str): The name of the block to write to.
|
||||
Returns:
|
||||
str: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
|
||||
return None
|
||||
|
||||
tool = client.tools.upsert_from_function(func=rethink_memory)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def weather_tool(client):
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
|
||||
Parameters:
|
||||
location (str): The location to get the weather for.
|
||||
|
||||
Returns:
|
||||
str: A formatted string describing the weather in the given location.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the request to fetch weather data fails.
|
||||
"""
|
||||
import requests
|
||||
|
||||
url = f"https://wttr.in/{location}?format=%C+%t"
|
||||
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
weather_data = response.text
|
||||
return f"The weather in {location} is {weather_data}."
|
||||
else:
|
||||
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
|
||||
|
||||
tool = client.tools.upsert_from_function(func=get_weather)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
# --- Load Test --- #
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_parallel_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
|
||||
# 1) Create 30 agents WITHOUT the rethink_tool initially
|
||||
agent_ids = []
|
||||
for i in range(5):
|
||||
agent = client.agents.create(
|
||||
name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}",
|
||||
tool_ids=[roll_dice_tool.id, weather_tool.id],
|
||||
include_base_tools=False,
|
||||
memory_blocks=[
|
||||
{"label": "human", "value": "Name: Matt"},
|
||||
{"label": "persona", "value": "Friendly agent"},
|
||||
],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
)
|
||||
agent_ids.append(agent.id)
|
||||
|
||||
# 2) Pre-create 10 new blocks *per* agent
|
||||
per_agent_blocks = {}
|
||||
for aid in agent_ids:
|
||||
block_ids = []
|
||||
for j in range(10):
|
||||
blk = client.blocks.create(
|
||||
label=f"{aid[:6]}_blk{j}",
|
||||
value="Precreated block content",
|
||||
description="Load-test block",
|
||||
limit=500,
|
||||
metadata={"idx": str(j)},
|
||||
)
|
||||
block_ids.append(blk.id)
|
||||
per_agent_blocks[aid] = block_ids
|
||||
|
||||
# 3) Dispatch 100 updates per agent in parallel
|
||||
total_updates = len(agent_ids) * 100
|
||||
latencies = []
|
||||
|
||||
def do_update(agent_id: str):
|
||||
start = time.time()
|
||||
if random.random() < 0.5:
|
||||
client.agents.modify(agent_id=agent_id, tool_ids=[rethink_tool.id])
|
||||
else:
|
||||
bid = random.choice(per_agent_blocks[agent_id])
|
||||
client.agents.modify(agent_id=agent_id, block_ids=[bid])
|
||||
return time.time() - start
|
||||
|
||||
with ThreadPoolExecutor(max_workers=50) as executor:
|
||||
futures = [executor.submit(do_update, aid) for aid in agent_ids for _ in range(10)]
|
||||
for future in tqdm(as_completed(futures), total=total_updates, desc="Complex updates"):
|
||||
latencies.append(future.result())
|
||||
|
||||
# 4) Cleanup
|
||||
for aid in agent_ids:
|
||||
client.agents.delete(aid)
|
||||
|
||||
# 5) Plot latency distribution
|
||||
df = pd.DataFrame({"latency": latencies})
|
||||
plt.figure(figsize=(12, 6))
|
||||
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.hist(df["latency"], bins=30, edgecolor="black")
|
||||
plt.title("Update Latency Distribution")
|
||||
plt.xlabel("Latency (seconds)")
|
||||
plt.ylabel("Frequency")
|
||||
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.boxplot(df["latency"], vert=False)
|
||||
plt.title("Update Latency Boxplot")
|
||||
plt.xlabel("Latency (seconds)")
|
||||
|
||||
plt.tight_layout()
|
||||
plot_file = f"complex_update_latency_{int(time.time())}.png"
|
||||
plt.savefig(plot_file)
|
||||
plt.close()
|
||||
|
||||
# 6) Report summary
|
||||
mean = df["latency"].mean()
|
||||
median = df["latency"].median()
|
||||
minimum = df["latency"].min()
|
||||
maximum = df["latency"].max()
|
||||
stdev = df["latency"].std()
|
||||
|
||||
print("\n===== Complex Update Latency Statistics =====")
|
||||
print(f"Total updates: {len(latencies)}")
|
||||
print(f"Mean: {mean:.3f}s")
|
||||
print(f"Median: {median:.3f}s")
|
||||
print(f"Min: {minimum:.3f}s")
|
||||
print(f"Max: {maximum:.3f}s")
|
||||
print(f"Std: {stdev:.3f}s")
|
||||
print(f"Plot saved to: {plot_file}")
|
||||
|
||||
# Sanity assertion
|
||||
assert median < 2.0, f"Median update latency too high: {median:.3f}s"
|
||||
Reference in New Issue
Block a user