Files
letta-server/tests/integration_test_multi_agent.py
2025-05-29 11:10:13 -07:00

426 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import os
import threading
import time
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta
from letta.config import LettaConfig
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.schemas.letta_message import SystemMessage, ToolReturnMessage
from letta.schemas.tool import Tool
from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager
from letta.settings import settings
from tests.helpers.utils import retry_until_success
from tests.utils import wait_for_incoming_message
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provide the base URL of the Letta server used by this module's tests.

    If LETTA_SERVER_URL is not set (or is empty), start the server in a daemon
    background thread and poll its health endpoint until it's accepting
    connections. While the fixture is active ``settings.use_experimental`` is
    forced to True; the previous value is restored on teardown.

    Yields:
        str: The server base URL.

    Raises:
        RuntimeError: If the locally started server cannot be reached within
            the startup timeout.
    """

    def _run_server() -> None:
        load_dotenv()
        # Imported inside the thread target, i.e. only when a local server is started.
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    # Read the variable once; an empty value is treated like "unset" so we never poll an empty URL.
    configured_url = os.getenv("LETTA_SERVER_URL")
    url: str = configured_url or "http://localhost:8283"

    if not configured_url:
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll until the server is up (or timeout). A monotonic clock keeps the
        # deadline unaffected by wall-clock adjustments.
        timeout_seconds = 30
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            try:
                # Bound every probe: without a timeout a single hung request
                # could block well past the overall deadline.
                resp = requests.get(url + "/v1/health", timeout=2)
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                # Not reachable yet (connection refused, probe timed out, ...); retry shortly.
                pass
            time.sleep(0.1)
        else:
            # The loop ran out of time without ever hitting `break`.
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    temp = settings.use_experimental
    settings.use_experimental = True
    try:
        yield url
    finally:
        # Restore the flag even if the generator is closed abnormally.
        settings.use_experimental = temp
@pytest.fixture(scope="module")
def server():
    """Load and re-save the Letta config, then return a ``SyncServer`` shared across the module."""
    letta_config = LettaConfig.load()
    print("CONFIG PATH", letta_config.config_path)
    letta_config.save()
    return SyncServer()
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
    """
    Build a synchronous Letta REST client pointed at ``server_url``.

    Base tools are upserted before the client is yielded to the tests.
    """
    rest_client = Letta(base_url=server_url)
    rest_client.tools.upsert_base_tools()
    yield rest_client
@pytest.fixture(autouse=True)
def remove_stale_agents(client):
    """Delete every agent returned by ``client.agents.list(limit=300)`` during test setup."""
    for leftover in client.agents.list(limit=300):
        client.agents.delete(agent_id=leftover.id)
@pytest.fixture(scope="function")
def agent_obj(client):
    """Yield a freshly created agent carrying the send-and-wait messaging tool; delete it afterwards."""
    matching_tools = client.tools.list(name="send_message_to_agent_and_wait_for_reply")
    messaging_tool = matching_tools[0]

    sender_state = client.agents.create(
        include_base_tools=True,
        tool_ids=[messaging_tool.id],
        model="openai/gpt-4o",
        embedding="letta/letta-free",
        context_window_limit=32000,
    )

    yield sender_state

    # Teardown: remove the agent created for this test.
    client.agents.delete(sender_state.id)
@pytest.fixture(scope="function")
def other_agent_obj(client):
    """Yield a second freshly created agent, created without multi-agent tools; delete it afterwards."""
    creation_kwargs = {
        "include_base_tools": True,
        "include_multi_agent_tools": False,
        "model": "openai/gpt-4o",
        "embedding": "letta/letta-free",
        "context_window_limit": 32000,
    }
    receiver_state = client.agents.create(**creation_kwargs)

    yield receiver_state

    # Teardown: remove the agent created for this test.
    client.agents.delete(receiver_state.id)
@pytest.fixture
def roll_dice_tool(client):
    """Upsert a deterministic ``roll_dice`` tool through the client and yield the result."""

    # The inner function (including its docstring) is handed to the helpers
    # below as source, so it is kept exactly as written.
    def roll_dice():
        """
        Rolls a 6 sided die.
        Returns:
            str: The roll result.
        """
        return "Rolled a 5!"

    # Build a local Tool from the function source and derive its JSON schema.
    # NOTE(review): this local object is never used afterwards -- the yielded
    # value comes from upsert_from_function below. Confirm whether these calls
    # are meant as a sanity check or are leftover code.
    dice_source = parse_source_code(roll_dice)
    local_tool = Tool(
        description="test_description",
        tags=["test"],
        source_code=dice_source,
        source_type="python",
    )
    schema = derive_openai_json_schema(source_code=local_tool.source_code, name=local_tool.name)
    schema_name = schema["name"]
    local_tool.json_schema = schema
    local_tool.name = schema_name

    # The tool handed to tests is the one returned by the client.
    upserted_tool = client.tools.upsert_from_function(func=roll_dice)
    yield upserted_tool
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agent(client, server, agent_obj, other_agent_obj):
    """
    Ask ``agent_obj`` to pass a secret word to ``other_agent_obj`` and inspect both sides.

    Checks that the first SystemMessage encountered in the recipient's recall
    contains the secret word, and that the sender's in-context messages include
    the recipient's response snippet. Finally sends a follow-up message to the
    sender and prints the reply.
    """
    secret_word = "banana"
    actor = server.user_manager.get_user_or_default()

    # Prompt the sender to use its messaging tool on the other agent.
    share_prompt = f"Use your tool to send a message to another agent with id {other_agent_obj.id} to share the secret word: {secret_word}!"
    client.agents.messages.create(
        agent_id=agent_obj.id,
        messages=[{"role": "user", "content": share_prompt}],
    )

    # Pull the recipient's recall history.
    recipient_history = server.get_agent_recall(
        user_id=actor.id,
        agent_id=other_agent_obj.id,
        reverse=True,
        return_message_object=False,
    )

    # Only the first SystemMessage encountered is checked.
    # NOTE(review): if no SystemMessage is present, nothing is asserted here.
    for entry in reversed(recipient_history):
        print(f"\n\n {other_agent_obj.id} -> {entry.model_dump_json(indent=4)}")
        if not isinstance(entry, SystemMessage):
            continue
        assert secret_word in entry.content
        break

    # Look through the sender's context for the reply coming back from the other agent.
    sender_context = AgentManager().get_in_context_messages(agent_id=agent_obj.id, actor=actor)
    target_snippet = f"'agent_id': '{other_agent_obj.id}', 'response': ["
    found = any(target_snippet in ctx_msg.content[0].text for ctx_msg in sender_context)

    # Dump everything except the first in-context message for debugging.
    joined = "\n".join(ctx_msg.content[0].text for ctx_msg in sender_context[1:])
    print(f"In context messages of the sender agent (without system):\n\n{joined}")

    if not found:
        raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}")

    # The sender should still handle ordinary messages afterwards.
    followup = client.agents.messages.create(
        agent_id=agent_obj.id,
        messages=[{"role": "user", "content": "So what did the other agent say?"}],
    )
    print(followup.messages)
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_simple(client):
    """
    Have a manager agent broadcast a secret word to agents selected by tag.

    Two agents tagged ``user-456`` are the intended recipients and two agents
    tagged ``user-123`` are not. The test checks the size of the manager's tool
    response, that the first SystemMessage seen for each recipient contains the
    secret word, and that no SystemMessage of a non-recipient contains it.
    """
    # Local import: only needed to parse the tool response in this test.
    import ast

    worker_tags_123 = ["worker", "user-123"]
    worker_tags_456 = ["worker", "user-456"]
    secret_word = "banana"

    # Create "manager" agent
    send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id
    manager_agent_state = client.agents.create(
        name="manager_agent",
        tool_ids=[send_message_to_agents_matching_tags_tool_id],
        model="openai/gpt-4o-mini",
        embedding="letta/letta-free",
    )

    # Create 2 non-matching worker agents (these should NOT get the message)
    worker_agents_123 = []
    for idx in range(2):
        worker_agent_state = client.agents.create(
            name=f"not_worker_{idx}",
            include_multi_agent_tools=False,
            tags=worker_tags_123,
            model="openai/gpt-4o-mini",
            embedding="letta/letta-free",
        )
        worker_agents_123.append(worker_agent_state)

    # Create 2 worker agents that should get the message
    worker_agents_456 = []
    for idx in range(2):
        worker_agent_state = client.agents.create(
            name=f"worker_{idx}",
            include_multi_agent_tools=False,
            tags=worker_tags_456,
            model="openai/gpt-4o-mini",
            embedding="letta/letta-free",
        )
        worker_agents_456.append(worker_agent_state)

    # Encourage the manager to broadcast the secret word to the agents tagged for user-456
    response = client.agents.messages.create(
        agent_id=manager_agent_state.id,
        messages=[
            {
                "role": "user",
                "content": f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
            }
        ],
    )

    for m in response.messages:
        if isinstance(m, ToolReturnMessage):
            # SECURITY: this payload is produced by model-driven tool execution, so it
            # must not be passed to eval(). ast.literal_eval only accepts Python
            # literals and never executes code.
            # Assumes the "message" field is the repr of a list of literals -- TODO confirm.
            tool_response = ast.literal_eval(json.loads(m.tool_return)["message"])
            print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
            assert len(tool_response) == len(worker_agents_456)

            # We can break after this, the ToolReturnMessage after is not related
            break

    # Conversation search the worker agents
    for agent_state in worker_agents_456:
        messages = client.agents.messages.list(agent_state.id)
        # Check the first system message encountered for the secret word
        for m in reversed(messages):
            print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
            if isinstance(m, SystemMessage):
                assert secret_word in m.content
                break

    # Ensure it's NOT in the non matching worker agents
    for agent_state in worker_agents_123:
        messages = client.agents.messages.list(agent_state.id)
        # Every system message is checked here (no early break)
        for m in reversed(messages):
            print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
            if isinstance(m, SystemMessage):
                assert secret_word not in m.content

    # Test that the agent can still receive messages fine
    response = client.agents.messages.create(
        agent_id=manager_agent_state.id,
        messages=[
            {
                "role": "user",
                "content": "So what did the other agent say?",
            }
        ],
    )
    print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool):
    """
    Have a manager agent broadcast a dice-roll request to tagged worker agents.

    Two workers tagged ``dice-rollers`` are created with the ``roll_dice_tool``
    attached. The test checks that the manager's first tool response has one
    entry per worker, then sends a follow-up message to the manager and prints
    the reply.
    """
    # Local import: only needed to parse the tool response in this test.
    import ast

    # Create "manager" agent
    send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id
    manager_agent_state = client.agents.create(
        tool_ids=[send_message_to_agents_matching_tags_tool_id],
        model="openai/gpt-4o-mini",
        embedding="letta/letta-free",
    )

    # Create 2 worker agents
    worker_agents = []
    worker_tags = ["dice-rollers"]
    for _ in range(2):
        worker_agent_state = client.agents.create(
            include_multi_agent_tools=False,
            tags=worker_tags,
            tool_ids=[roll_dice_tool.id],
            model="openai/gpt-4o-mini",
            embedding="letta/letta-free",
        )
        worker_agents.append(worker_agent_state)

    # Encourage the manager to ask the tagged workers to roll a die
    broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
    response = client.agents.messages.create(
        agent_id=manager_agent_state.id,
        messages=[
            {
                "role": "user",
                "content": broadcast_message,
            }
        ],
    )

    for m in response.messages:
        if isinstance(m, ToolReturnMessage):
            # SECURITY: this payload is produced by model-driven tool execution, so it
            # must not be passed to eval(). ast.literal_eval only accepts Python
            # literals and never executes code.
            # Assumes the "message" field is the repr of a list of literals -- TODO confirm.
            tool_response = ast.literal_eval(json.loads(m.tool_return)["message"])
            print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
            assert len(tool_response) == len(worker_agents)

            # We can break after this, the ToolReturnMessage after is not related
            break

    # Test that the agent can still receive messages fine
    response = client.agents.messages.create(
        agent_id=manager_agent_state.id,
        messages=[
            {
                "role": "user",
                "content": "So what did the other agent say?",
            }
        ],
    )
    print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
# @retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_agents_async_simple(client):
    """
    Two agents equipped with `send_message_to_agent_async` exchange messages about a poem idea.

    The exchange is started by prompting "charles" to contact "sarah". The test
    then uses `wait_for_incoming_message` to check, for each agent in turn, that
    an incoming-agent-message marker shows up within the wait window.
    """
    async_tool_id = client.tools.list(name="send_message_to_agent_async")[0].id
    incoming_marker = "[Incoming message from agent with ID"

    # Agent that talks to the human and wants a poem.
    charles_blocks = [
        {
            "label": "human",
            "value": "Chad - I'm interested in hearing poem.",
        },
        {
            "label": "persona",
            "value": "You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).",
        },
    ]
    charles_state = client.agents.create(
        name="charles",
        tool_ids=[async_tool_id],
        memory_blocks=charles_blocks,
        model="openai/gpt-4o-mini",
        embedding="letta/letta-free",
    )

    # Agent that only talks to the other agent.
    sarah_blocks = [
        {
            "label": "human",
            "value": "No human - you are to only communicate with the other AI agent.",
        },
        {
            "label": "persona",
            "value": "You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.",
        },
    ]
    sarah_state = client.agents.create(
        name="sarah",
        tool_ids=[async_tool_id],
        memory_blocks=sarah_blocks,
        model="openai/gpt-4o-mini",
        embedding="letta/letta-free",
    )

    # Kick off the exchange by prompting charles.
    initial_prompt = f"I want you to talk to the other agent with ID {sarah_state.id} using `send_message_to_agent_async`. Specifically, I want you to ask him for a poem idea, and then craft a poem for me."
    client.agents.messages.create(
        agent_id=charles_state.id,
        messages=[{"role": "user", "content": initial_prompt}],
    )

    # Check charles first, then sarah; each failure carries its own message.
    expectations = (
        (charles_state, "Charles never received the system message from Sarah (timed out)."),
        (sarah_state, "Sarah never received the system message from Charles (timed out)."),
    )
    for state, failure_message in expectations:
        received = wait_for_incoming_message(
            client=client,
            agent_id=state.id,
            substring=incoming_marker,
            max_wait_seconds=10,
            sleep_interval=0.5,
        )
        assert received, failure_message