chore: Add testing around base tools (#2268)

This commit is contained in:
Matthew Zhou
2024-12-17 15:46:05 -08:00
committed by GitHub
parent 6fb2968006
commit 7fb8f16155
9 changed files with 135 additions and 174 deletions

View File

@@ -298,8 +298,6 @@ class Agent(BaseAgent):
self.agent_manager = AgentManager()
# State needed for heartbeat pausing
self.pause_heartbeats_start = None
self.pause_heartbeats_minutes = 0
self.first_message_verify_mono = first_message_verify_mono
@@ -1259,17 +1257,6 @@ class Agent(BaseAgent):
printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}")
def heartbeat_is_paused(self):
"""Check if there's a requested pause on timed heartbeats"""
# Check if the pause has been initiated
if self.pause_heartbeats_start is None:
return False
# Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start
elapsed_time = get_utc_time() - self.pause_heartbeats_start
return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60
def _swap_system_message_in_buffer(self, new_system_message: str):
"""Update the system message (NOT prompt) of the Agent (requires updating the internal buffer)"""
assert isinstance(new_system_message, str)
@@ -1394,7 +1381,7 @@ class Agent(BaseAgent):
agent_manager: AgentManager,
):
"""Attach a source to the agent using the SourcesAgents ORM relationship.
Args:
user: User performing the action
source_id: ID of the source to attach

View File

@@ -38,7 +38,8 @@ DEFAULT_HUMAN = "basic"
DEFAULT_PRESET = "memgpt_chat"
# Base tools that cannot be edited, as they access agent state directly
BASE_TOOLS = ["send_message", "conversation_search", "conversation_search_date", "archival_memory_insert", "archival_memory_search"]
# Note that we don't include "conversation_search_date" for now
BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"]
O1_BASE_TOOLS = ["send_thinking_message", "send_final_message"]
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]

View File

@@ -1,16 +1,6 @@
from datetime import datetime
from typing import Optional
from letta.agent import Agent
from letta.constants import MAX_PAUSE_HEARTBEATS
from letta.services.agent_manager import AgentManager
# import math
# from letta.utils import json_dumps
### Functions / tools the agent can use
# All functions should return a response string (or None)
# If the function fails, throw an exception
def send_message(self: "Agent", message: str) -> Optional[str]:
@@ -28,36 +18,6 @@ def send_message(self: "Agent", message: str) -> Optional[str]:
return None
# Construct the docstring dynamically (since it should use the external constants)
pause_heartbeats_docstring = f"""
Temporarily ignore timed heartbeats. You may still receive messages from manual heartbeats and other events.
Args:
minutes (int): Number of minutes to ignore heartbeats for. Max value of {MAX_PAUSE_HEARTBEATS} minutes ({MAX_PAUSE_HEARTBEATS // 60} hours).
Returns:
str: Function status response
"""
def pause_heartbeats(self: "Agent", minutes: int) -> Optional[str]:
import datetime
from letta.constants import MAX_PAUSE_HEARTBEATS
minutes = min(MAX_PAUSE_HEARTBEATS, minutes)
# Record the current time
self.pause_heartbeats_start = datetime.datetime.now(datetime.timezone.utc)
# And record how long the pause should go for
self.pause_heartbeats_minutes = int(minutes)
return f"Pausing timed heartbeats for {minutes} min"
pause_heartbeats.__doc__ = pause_heartbeats_docstring
def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using case-insensitive string matching.
@@ -84,19 +44,19 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
# TODO: add paging by page number. currently cursor only works with strings.
# original: start=page * count
results = self.message_manager.list_user_messages_for_agent(
messages = self.message_manager.list_user_messages_for_agent(
agent_id=self.agent_state.id,
actor=self.user,
query_text=query,
limit=count,
)
total = len(results)
total = len(messages)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
if len(messages) == 0:
results_str = f"No results found."
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):"
results_formatted = [message.text for message in messages]
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str
@@ -114,6 +74,7 @@ def conversation_search_date(self: "Agent", start_date: str, end_date: str, page
str: Query result string
"""
import math
from datetime import datetime
from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.utils import json_dumps
@@ -142,7 +103,6 @@ def conversation_search_date(self: "Agent", start_date: str, end_date: str, page
start_date=start_datetime,
end_date=end_datetime,
limit=count,
# start_date=start_date, end_date=end_date, limit=count, start=page * count
)
total = len(results)
num_pages = math.ceil(total / count) - 1 # 0 index
@@ -186,10 +146,8 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
Returns:
str: Query result string
"""
import math
from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.utils import json_dumps
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
@@ -198,7 +156,7 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
except:
raise ValueError(f"'page' argument must be an integer")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
try:
# Get results using passage manager
all_results = self.agent_manager.list_passages(
@@ -207,7 +165,7 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
query_text=query,
limit=count + start, # Request enough results to handle offset
embedding_config=self.agent_state.embedding_config,
embed_query=True
embed_query=True,
)
# Apply pagination
@@ -215,13 +173,7 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
paged_results = all_results[start:end]
# Format results to match previous implementation
formatted_results = [
{
"timestamp": str(result.created_at),
"content": result.text
}
for result in paged_results
]
formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results]
return formatted_results, len(formatted_results)

View File

@@ -386,7 +386,7 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
# append the heartbeat
# TODO: don't hard-code
# TODO: if terminal, don't include this
if function.__name__ not in ["send_message", "pause_heartbeats"]:
if function.__name__ not in ["send_message"]:
schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",

View File

@@ -3,7 +3,7 @@ import json
from letta.utils import json_dumps, json_loads
NO_HEARTBEAT_FUNCS = ["send_message", "pause_heartbeats"]
NO_HEARTBEAT_FUNCS = ["send_message"]
def insert_heartbeat(message):

View File

@@ -40,13 +40,3 @@ for org in orgs:
),
actor=fake_user,
)
ToolManager().create_or_update_tool(
Tool(
name="pause_heartbeats",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)

View File

@@ -22,3 +22,12 @@ def mock_e2b_api_key_none():
# Restore the original value of e2b_api_key
tool_settings.e2b_api_key = original_api_key
@pytest.fixture
def check_e2b_key_is_set():
from letta.settings import tool_settings
original_api_key = tool_settings.e2b_api_key
assert original_api_key is not None, "Missing e2b key! Cannot execute these tests."
yield

View File

@@ -8,7 +8,7 @@ import pytest
from sqlalchemy import delete
from letta import create_client
from letta.functions.function_sets.base import core_memory_replace
from letta.functions.function_sets.base import core_memory_append, core_memory_replace
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
@@ -57,13 +57,6 @@ def clear_tables():
Sandbox.connect(sandbox.sandbox_id).kill()
@pytest.fixture
def check_e2b_key_is_set():
original_api_key = tool_settings.e2b_api_key
assert original_api_key is not None, "Missing e2b key! Cannot execute these tests."
yield
@pytest.fixture
def check_composio_key_set():
original_api_key = tool_settings.composio_api_key
@@ -217,13 +210,6 @@ def clear_core_memory_tool(test_user):
yield tool
@pytest.fixture
def core_memory_replace_tool(test_user):
tool = create_tool_from_func(core_memory_replace)
tool = ToolManager().create_or_update_tool(tool, test_user)
yield tool
@pytest.fixture
def external_codebase_tool(test_user):
from tests.test_tool_sandbox.restaurant_management_system.adjust_menu_prices import (
@@ -273,6 +259,21 @@ def custom_test_sandbox_config(test_user):
return manager, local_sandbox_config
# Tool-specific fixtures
@pytest.fixture
def core_memory_tools(test_user):
"""Create all base tools for testing."""
tools = {}
for func in [
core_memory_replace,
core_memory_append,
]:
tool = create_tool_from_func(func)
tool = ToolManager().create_or_update_tool(tool, test_user)
tools[func.__name__] = tool
yield tools
# Local sandbox tests
@@ -303,30 +304,6 @@ def test_local_sandbox_stateful_tool(mock_e2b_api_key_none, clear_core_memory_to
assert result.func_return is None
@pytest.mark.local_sandbox
def test_local_sandbox_core_memory_replace(mock_e2b_api_key_none, core_memory_replace_tool, test_user, agent_state):
new_name = "Matt"
args = {"label": "human", "old_content": "Chad", "new_content": new_name}
sandbox = ToolExecutionSandbox(core_memory_replace_tool.name, args, user_id=test_user.id)
# run the sandbox
result = sandbox.run(agent_state=agent_state)
assert new_name in result.agent_state.memory.get_block("human").value
assert result.func_return is None
@pytest.mark.local_sandbox
def test_local_sandbox_core_memory_replace_errors(mock_e2b_api_key_none, core_memory_replace_tool, test_user, agent_state):
nonexistent_name = "Alexander Wang"
args = {"label": "human", "old_content": nonexistent_name, "new_content": "Matt"}
sandbox = ToolExecutionSandbox(core_memory_replace_tool.name, args, user_id=test_user.id)
# run the sandbox
result = sandbox.run(agent_state=agent_state)
assert len(result.stderr) != 0, "stderr not empty"
assert f"ValueError: Old content '{nonexistent_name}' not found in memory block 'human'" in result.stderr[0], "stderr contains expected error"
@pytest.mark.local_sandbox
def test_local_sandbox_with_list_rv(mock_e2b_api_key_none, list_tool, test_user):
sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id)
@@ -474,42 +451,6 @@ def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool,
assert result.func_return is None
@pytest.mark.e2b_sandbox
def test_e2b_sandbox_core_memory_replace(check_e2b_key_is_set, core_memory_replace_tool, test_user, agent_state):
new_name = "Matt"
args = {"label": "human", "old_content": "Chad", "new_content": new_name}
sandbox = ToolExecutionSandbox(core_memory_replace_tool.name, args, user_id=test_user.id)
# run the sandbox
result = sandbox.run(agent_state=agent_state)
assert new_name in result.agent_state.memory.get_block("human").value
assert result.func_return is None
@pytest.mark.e2b_sandbox
def test_e2b_sandbox_escape_strings_in_args(check_e2b_key_is_set, core_memory_replace_tool, test_user, agent_state):
new_name = "Matt"
args = {"label": "human", "old_content": "Chad", "new_content": new_name + "\n"}
sandbox = ToolExecutionSandbox(core_memory_replace_tool.name, args, user_id=test_user.id)
# run the sandbox
result = sandbox.run(agent_state=agent_state)
assert new_name in result.agent_state.memory.get_block("human").value
assert result.func_return is None
@pytest.mark.e2b_sandbox
def test_e2b_sandbox_core_memory_replace_errors(check_e2b_key_is_set, core_memory_replace_tool, test_user, agent_state):
nonexistent_name = "Alexander Wang"
args = {"label": "human", "old_content": nonexistent_name, "new_content": "Matt"}
sandbox = ToolExecutionSandbox(core_memory_replace_tool.name, args, user_id=test_user.id)
# run the sandbox
result = sandbox.run(agent_state=agent_state)
assert len(result.stderr) != 0, "stderr not empty"
assert f"ValueError: Old content '{nonexistent_name}' not found in memory block 'human'" in result.stderr[0], "stderr contains expected error"
@pytest.mark.e2b_sandbox
def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user):
manager = SandboxConfigManager(tool_settings)
@@ -585,3 +526,79 @@ def test_e2b_e2e_composio_star_github(check_e2b_key_is_set, check_composio_key_s
result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user_id=test_user.id).run()
assert result.func_return["details"] == "Action executed successfully"
# Core memory integration tests
class TestCoreMemoryTools:
"""
Tests for core memory manipulation tools.
Tests run in both local sandbox and e2b environments.
"""
# Local sandbox tests
@pytest.mark.local_sandbox
def test_core_memory_replace_local(self, mock_e2b_api_key_none, core_memory_tools, test_user, agent_state):
"""Test successful replacement of content in core memory - local sandbox."""
new_name = "Charles"
args = {"label": "human", "old_content": "Chad", "new_content": new_name}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id)
result = sandbox.run(agent_state=agent_state)
assert new_name in result.agent_state.memory.get_block("human").value
assert result.func_return is None
@pytest.mark.local_sandbox
def test_core_memory_append_local(self, mock_e2b_api_key_none, core_memory_tools, test_user, agent_state):
"""Test successful appending of content to core memory - local sandbox."""
append_text = "\nLikes coffee"
args = {"label": "human", "content": append_text}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user_id=test_user.id)
result = sandbox.run(agent_state=agent_state)
assert append_text in result.agent_state.memory.get_block("human").value
assert result.func_return is None
@pytest.mark.local_sandbox
def test_core_memory_replace_error_local(self, mock_e2b_api_key_none, core_memory_tools, test_user, agent_state):
"""Test error handling when trying to replace non-existent content - local sandbox."""
nonexistent_name = "Alexander Wang"
args = {"label": "human", "old_content": nonexistent_name, "new_content": "Charles"}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id)
result = sandbox.run(agent_state=agent_state)
assert len(result.stderr) != 0
assert f"ValueError: Old content '{nonexistent_name}' not found in memory block 'human'" in result.stderr[0]
# E2B sandbox tests
@pytest.mark.e2b_sandbox
def test_core_memory_replace_e2b(self, check_e2b_key_is_set, core_memory_tools, test_user, agent_state):
"""Test successful replacement of content in core memory - e2b sandbox."""
new_name = "Charles"
args = {"label": "human", "old_content": "Chad", "new_content": new_name}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id)
result = sandbox.run(agent_state=agent_state)
assert new_name in result.agent_state.memory.get_block("human").value
assert result.func_return is None
@pytest.mark.e2b_sandbox
def test_core_memory_append_e2b(self, check_e2b_key_is_set, core_memory_tools, test_user, agent_state):
"""Test successful appending of content to core memory - e2b sandbox."""
append_text = "\nLikes coffee"
args = {"label": "human", "content": append_text}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user_id=test_user.id)
result = sandbox.run(agent_state=agent_state)
assert append_text in result.agent_state.memory.get_block("human").value
assert result.func_return is None
@pytest.mark.e2b_sandbox
def test_core_memory_replace_error_e2b(self, check_e2b_key_is_set, core_memory_tools, test_user, agent_state):
"""Test error handling when trying to replace non-existent content - e2b sandbox."""
nonexistent_name = "Alexander Wang"
args = {"label": "human", "old_content": nonexistent_name, "new_content": "Charles"}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id)
result = sandbox.run(agent_state=agent_state)
assert len(result.stderr) != 0
assert f"ValueError: Old content '{nonexistent_name}' not found in memory block 'human'" in result.stderr[0]

View File

@@ -1,28 +1,25 @@
import pytest
import letta.functions.function_sets.base as base_functions
from letta import create_client
from letta import LocalClient, create_client
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from .utils import wipe_config
# test_agent_id = "test_agent"
client = None
@pytest.fixture(scope="module")
def agent_obj():
"""Create a test agent that we can call functions on"""
wipe_config()
global client
def client():
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
@pytest.fixture(scope="module")
def agent_obj(client: LocalClient):
"""Create a test agent that we can call functions on"""
agent_state = client.create_agent()
global agent_obj
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
yield agent_obj
@@ -88,7 +85,15 @@ def test_archival(agent_obj):
pass
def test_recall(agent_obj):
base_functions.conversation_search(agent_obj, "banana")
base_functions.conversation_search(agent_obj, "banana", page=0)
base_functions.conversation_search_date(agent_obj, start_date="2022-01-01", end_date="2022-01-02")
def test_recall(client, agent_obj):
# keyword
keyword = "banana"
# Send messages to agent
client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="hello")
client.send_message(agent_id=agent_obj.agent_state.id, role="user", message=keyword)
client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="tell me a fun fact")
# Conversation search
result = base_functions.conversation_search(agent_obj, "banana")
assert keyword in result