fix: Add comprehensive testing for tool creation paths (#3255)

This commit is contained in:
Matthew Zhou
2025-07-09 15:51:50 -07:00
committed by GitHub
parent fcb894a4e3
commit fc07b2b2c2
7 changed files with 291 additions and 9 deletions

View File

@@ -129,7 +129,8 @@ def get_function_name_and_docstring(source_code: str, name: Optional[str] = None
raise LettaToolCreateError("Could not determine function name")
if not docstring:
raise LettaToolCreateError("Docstring is missing")
# For tools with args_json_schema, the docstring is optional
docstring = f"The {function_name} tool"
return function_name, docstring

View File

@@ -1,7 +1,7 @@
import ast
import base64
import pickle
from typing import Any
from typing import Any, Union
from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, SEND_MESSAGE_TOOL_NAME
from letta.schemas.agent import AgentState
@@ -9,7 +9,7 @@ from letta.schemas.response_format import ResponseFormatType, ResponseFormatUnio
from letta.types import JsonDict, JsonValue
def parse_stdout_best_effort(text: str | bytes) -> tuple[Any, AgentState | None]:
def parse_stdout_best_effort(text: Union[str, bytes]) -> tuple[Any, AgentState | None]:
"""
Decode and unpickle the result from the function execution if possible.
Returns (function_return_value, agent_state).

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict, Optional
from pydantic.config import JsonDict
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
@@ -23,6 +24,8 @@ from letta.services.tool_sandbox.base import AsyncToolSandboxBase
from letta.settings import tool_settings
from letta.utils import get_friendly_error_msg, parse_stderr_error_msg
logger = get_logger(__name__)
class AsyncToolSandboxLocal(AsyncToolSandboxBase):
METADATA_CONFIG_STATE_KEY = "config_state"
@@ -240,9 +243,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
if isinstance(e, TimeoutError):
raise e
print(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}")
print(e.__class__.__name__)
print(e.__traceback__)
logger.error(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}")
logger.error(e.__class__.__name__)
logger.error(e.__traceback__)
func_return = get_friendly_error_msg(
function_name=self.tool_name,
exception_name=type(e).__name__,

View File

@@ -24,8 +24,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl
{{ tool_source_code }}
{# Invoke the function and store the result in a global variable #}
_function_result = {{ invoke_function_call }}
{# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #}
try:
from pydantic import BaseModel
from typing import Any
class _TempResultWrapper(BaseModel):
result: Any
class Config:
arbitrary_types_allowed = True
_wrapped = _TempResultWrapper(result=_function_result)
_serialized_result = _wrapped.model_dump()['result']
except ImportError:
# Pydantic not available in sandbox, fall back to string conversion
print("Pydantic not available in sandbox environment, falling back to string conversion")
_serialized_result = str(_function_result)
except Exception as e:
# If wrapping fails, print the error and stringify the result
print(f"Failed to serialize result with Pydantic wrapper: {e}")
_serialized_result = str(_function_result)
{{ local_sandbox_result_var_name }} = {
"results": {{ invoke_function_call }},
"results": _serialized_result,
"agent_state": agent_state
}

View File

@@ -26,9 +26,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl
{# Async wrapper to handle the function call and store the result #}
async def _async_wrapper():
result = await {{ invoke_function_call }}
_function_result = await {{ invoke_function_call }}
{# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #}
try:
from pydantic import BaseModel
from typing import Any
class _TempResultWrapper(BaseModel):
result: Any
class Config:
arbitrary_types_allowed = True
_wrapped = _TempResultWrapper(result=_function_result)
_serialized_result = _wrapped.model_dump()['result']
except ImportError:
# Pydantic not available in sandbox, fall back to string conversion
print("Pydantic not available in sandbox environment, falling back to string conversion")
_serialized_result = str(_function_result)
except Exception as e:
# If wrapping fails, print the error and stringify the result
print(f"Failed to serialize result with Pydantic wrapper: {e}")
_serialized_result = str(_function_result)
return {
"results": result,
"results": _serialized_result,
"agent_state": agent_state
}

View File

@@ -28,6 +28,32 @@ def disable_e2b_api_key() -> Generator[None, None, None]:
tool_settings.e2b_api_key = original_api_key
@pytest.fixture
def e2b_sandbox_mode(request) -> Generator[None, None, None]:
    """
    Toggle the E2B sandbox for the duration of a single test.

    The flag is supplied through indirect parametrization (``request.param``).
    A falsy flag sets ``tool_settings.e2b_api_key`` to ``None`` while the test
    runs; a truthy flag leaves the configured key untouched. In both cases the
    key that was present on entry is written back after the ``yield``.

    Example:
        @pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
        def test_function(e2b_sandbox_mode, ...):
            # Runs twice: once with the flag True, once with it False
    """
    from letta.settings import tool_settings

    use_e2b = request.param
    saved_api_key = tool_settings.e2b_api_key

    # Only the "disabled" case needs a change; "enabled" keeps whatever key is configured.
    if not use_e2b:
        tool_settings.e2b_api_key = None

    yield

    # Teardown: put back the key captured before the test body ran.
    tool_settings.e2b_api_key = saved_api_key
@pytest.fixture
def disable_pinecone() -> Generator[None, None, None]:
"""

View File

@@ -1,15 +1,19 @@
import json
import os
import threading
import time
import uuid
from typing import List, Type
import pytest
from dotenv import load_dotenv
from letta_client import CreateBlock
from letta_client import Letta as LettaSDKClient
from letta_client import MessageCreate
from letta_client.client import BaseTool
from letta_client.core import ApiError
from letta_client.types import AgentState, ToolReturnMessage
from pydantic import BaseModel, Field
# Constants
SERVER_PORT = 8283
@@ -762,3 +766,204 @@ def test_base_tools_upsert_on_list(client: LettaSDKClient):
final_tool_names = {tool.name for tool in final_tools}
for deleted_tool in tools_to_delete:
assert deleted_tool.name in final_tool_names, f"Deleted tool {deleted_tool.name} was not properly restored"
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKClient):
    """
    Register a ``BaseTool`` subclass whose arguments are nested Pydantic models,
    let an agent invoke it, and verify both the tool call arguments and the
    tool return.

    Parametrized over the ``e2b_sandbox_mode`` fixture with True and False.

    The agent and the tool are deleted in ``finally`` blocks so that a failing
    assertion or API call does not leave them behind on the server.
    """

    class InventoryItem(BaseModel):
        sku: str
        name: str
        price: float
        category: str

    class InventoryEntry(BaseModel):
        timestamp: int
        item: InventoryItem
        transaction_id: str

    class InventoryEntryData(BaseModel):
        data: InventoryEntry
        quantity_change: int

    class ManageInventoryTool(BaseTool):
        name: str = "manage_inventory"
        args_schema: Type[BaseModel] = InventoryEntryData
        description: str = "Update inventory catalogue with a new data entry"
        tags: List[str] = ["inventory", "shop"]

        def run(self, data: InventoryEntry, quantity_change: int) -> bool:
            print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}")
            return True

    tool = client.tools.add(
        tool=ManageInventoryTool(),
    )
    # Checked before entering the try block: the cleanup below needs tool.id.
    assert tool is not None

    try:
        assert tool.name == "manage_inventory"
        assert "inventory" in tool.tags
        assert "shop" in tool.tags

        temp_agent = client.agents.create(
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You are a helpful inventory management assistant.",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-3-small",
            tool_ids=[tool.id],
            include_base_tools=False,
        )

        try:
            response = client.agents.messages.create(
                agent_id=temp_agent.id,
                messages=[
                    MessageCreate(
                        role="user",
                        content="Update the inventory for product 'iPhone 15' with SKU 'IPH15-001', price $999.99, category 'Electronics', transaction ID 'TXN-12345', timestamp 1640995200, with a quantity change of +10",
                    ),
                ],
            )

            assert response is not None

            # The agent must have called the tool with the full nested argument structure.
            tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"]
            assert len(tool_call_messages) > 0, "Expected at least one tool call message"

            first_tool_call = tool_call_messages[0]
            assert first_tool_call.tool_call.name == "manage_inventory"

            args = json.loads(first_tool_call.tool_call.arguments)
            assert "data" in args
            assert "quantity_change" in args
            assert "item" in args["data"]
            assert "name" in args["data"]["item"]
            assert "sku" in args["data"]["item"]
            assert "price" in args["data"]["item"]
            assert "category" in args["data"]["item"]
            assert "transaction_id" in args["data"]
            assert "timestamp" in args["data"]

            # The tool must have executed successfully and printed its confirmation line.
            tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"]
            assert len(tool_return_messages) > 0, "Expected at least one tool return message"

            first_tool_return = tool_return_messages[0]
            assert first_tool_return.status == "success"
            assert first_tool_return.tool_return == "True"
            assert "Updated inventory for iPhone 15 with a quantity change of 10" in "\n".join(first_tool_return.stdout)
        finally:
            # Same order as before: agent first, then (in the outer finally) the tool.
            client.agents.delete(temp_agent.id)
    finally:
        client.tools.delete(tool.id)
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
    """
    Upsert a tool from a plain function plus a Pydantic ``args_schema`` that
    contains a list of nested models, let an agent invoke it, and verify the
    tool call arguments and that the tool return reports success.

    Parametrized over the ``e2b_sandbox_mode`` fixture with True and False.

    The agent and the tool are deleted in ``finally`` blocks so that a failing
    assertion or API call does not leave them behind on the server.
    """

    class Step(BaseModel):
        name: str = Field(..., description="Name of the step.")
        description: str = Field(..., description="An exhaustive description of what this step is trying to achieve.")

    class StepsList(BaseModel):
        steps: List[Step] = Field(..., description="List of steps to add to the task plan.")
        explanation: str = Field(..., description="Explanation for the list of steps.")

    def create_task_plan(steps, explanation):
        """Creates a task plan for the current task."""
        print(f"Created task plan with {len(steps)} steps: {explanation}")
        return steps

    tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"])
    # Checked before entering the try block: the cleanup below needs tool.id.
    assert tool is not None

    try:
        assert tool.name == "create_task_plan"
        assert "planning" in tool.tags
        assert "task" in tool.tags

        temp_agent = client.agents.create(
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You are a helpful task planning assistant.",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-3-small",
            tool_ids=[tool.id],
            include_base_tools=False,
        )

        try:
            response = client.agents.messages.create(
                agent_id=temp_agent.id,
                messages=[
                    MessageCreate(
                        role="user",
                        content="Create a task plan for organizing a team meeting with 3 steps: 1) Schedule meeting (find available time slots), 2) Send invitations (notify all team members), 3) Prepare agenda (outline discussion topics). Explanation: This plan ensures a well-organized team meeting.",
                    ),
                ],
            )

            assert response is not None
            assert hasattr(response, "messages")
            assert len(response.messages) > 0

            # The agent must have called the tool with a non-empty list of well-formed steps.
            tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"]
            assert len(tool_call_messages) > 0, "Expected at least one tool call message"

            first_tool_call = tool_call_messages[0]
            assert first_tool_call.tool_call.name == "create_task_plan"

            args = json.loads(first_tool_call.tool_call.arguments)
            assert "steps" in args
            assert "explanation" in args
            assert isinstance(args["steps"], list)
            assert len(args["steps"]) > 0

            for step in args["steps"]:
                assert "name" in step
                assert "description" in step

            # Only the status is checked here; the returned value itself is not asserted.
            tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"]
            assert len(tool_return_messages) > 0, "Expected at least one tool return message"

            first_tool_return = tool_return_messages[0]
            assert first_tool_return.status == "success"
        finally:
            # Same order as before: agent first, then (in the outer finally) the tool.
            client.agents.delete(temp_agent.id)
    finally:
        client.tools.delete(tool.id)
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
def test_create_tool_from_function_with_docstring(e2b_sandbox_mode, client: LettaSDKClient):
    """
    Test creating a tool from a function with a docstring using create_from_function.

    Verifies the created tool's name, description and source code, and that
    the tool appears in the tool listing.

    Parametrized over the ``e2b_sandbox_mode`` fixture with True and False.
    The tool is deleted in a ``finally`` block so that a failed assertion in
    one run does not leave a ``roll_dice`` tool behind for the next run.
    """

    def roll_dice() -> str:
        """
        Simulate the roll of a 20-sided die (d20).
        This function generates a random integer between 1 and 20, inclusive,
        which represents the outcome of a single roll of a d20.
        Returns:
            str: The result of the die roll.
        """
        import random

        dice_role_outcome = random.randint(1, 20)
        output_string = f"You rolled a {dice_role_outcome}"
        return output_string

    tool = client.tools.create_from_function(func=roll_dice)
    # Checked before entering the try block: the cleanup below needs tool.id.
    assert tool is not None

    try:
        assert tool.name == "roll_dice"
        assert "Simulate the roll of a 20-sided die" in tool.description
        assert tool.source_code is not None
        assert "random.randint(1, 20)" in tool.source_code

        # The new tool must also be visible through the listing endpoint.
        all_tools = client.tools.list()
        tool_names = [t.name for t in all_tools]
        assert "roll_dice" in tool_names
    finally:
        client.tools.delete(tool.id)