From fc07b2b2c2c8b82cd918f8355a4eddb881148abc Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 9 Jul 2025 15:51:50 -0700 Subject: [PATCH] fix: Add comprehensive testing for tool creation paths (#3255) --- letta/functions/ast_parsers.py | 3 +- letta/services/helpers/tool_parser_helper.py | 4 +- letta/services/tool_sandbox/local_sandbox.py | 9 +- letta/templates/sandbox_code_file.py.j2 | 26 ++- letta/templates/sandbox_code_file_async.py.j2 | 27 ++- tests/conftest.py | 26 +++ tests/test_sdk_client.py | 205 ++++++++++++++++++ 7 files changed, 291 insertions(+), 9 deletions(-) diff --git a/letta/functions/ast_parsers.py b/letta/functions/ast_parsers.py index 57785b46..627b7fdb 100644 --- a/letta/functions/ast_parsers.py +++ b/letta/functions/ast_parsers.py @@ -129,7 +129,8 @@ def get_function_name_and_docstring(source_code: str, name: Optional[str] = None raise LettaToolCreateError("Could not determine function name") if not docstring: - raise LettaToolCreateError("Docstring is missing") + # For tools with args_json_schema, the docstring is optional + docstring = f"The {function_name} tool" return function_name, docstring diff --git a/letta/services/helpers/tool_parser_helper.py b/letta/services/helpers/tool_parser_helper.py index f38de929..8bc5333b 100644 --- a/letta/services/helpers/tool_parser_helper.py +++ b/letta/services/helpers/tool_parser_helper.py @@ -1,7 +1,7 @@ import ast import base64 import pickle -from typing import Any +from typing import Any, Union from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, SEND_MESSAGE_TOOL_NAME from letta.schemas.agent import AgentState @@ -9,7 +9,7 @@ from letta.schemas.response_format import ResponseFormatType, ResponseFormatUnio from letta.types import JsonDict, JsonValue -def parse_stdout_best_effort(text: str | bytes) -> tuple[Any, AgentState | None]: +def parse_stdout_best_effort(text: Union[str, bytes]) -> tuple[Any, AgentState | None]: """ Decode and unpickle the result from the function execution if possible. Returns (function_return_value, agent_state). diff --git a/letta/services/tool_sandbox/local_sandbox.py b/letta/services/tool_sandbox/local_sandbox.py index 3f24fca1..5056adde 100644 --- a/letta/services/tool_sandbox/local_sandbox.py +++ b/letta/services/tool_sandbox/local_sandbox.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Optional from pydantic.config import JsonDict +from letta.log import get_logger from letta.otel.tracing import log_event, trace_method from letta.schemas.agent import AgentState from letta.schemas.sandbox_config import SandboxConfig, SandboxType @@ -23,6 +24,8 @@ from letta.services.tool_sandbox.base import AsyncToolSandboxBase from letta.settings import tool_settings from letta.utils import get_friendly_error_msg, parse_stderr_error_msg +logger = get_logger(__name__) + class AsyncToolSandboxLocal(AsyncToolSandboxBase): METADATA_CONFIG_STATE_KEY = "config_state" @@ -240,9 +243,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): if isinstance(e, TimeoutError): raise e - print(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}") - print(e.__class__.__name__) - print(e.__traceback__) + logger.error(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}") + logger.error(e.__class__.__name__) + logger.error(e.__traceback__) func_return = get_friendly_error_msg( function_name=self.tool_name, exception_name=type(e).__name__, diff --git a/letta/templates/sandbox_code_file.py.j2 b/letta/templates/sandbox_code_file.py.j2 index 953b8ae8..3f4c4517 100644 --- a/letta/templates/sandbox_code_file.py.j2 +++ b/letta/templates/sandbox_code_file.py.j2 @@ -24,8 +24,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl {{ tool_source_code }} {# Invoke the function and store the result in a global variable #} +_function_result = {{ invoke_function_call }} + +{# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #} +try: + from pydantic import BaseModel + from typing import Any + + class _TempResultWrapper(BaseModel): + result: Any + + class Config: + arbitrary_types_allowed = True + + _wrapped = _TempResultWrapper(result=_function_result) + _serialized_result = _wrapped.model_dump()['result'] +except ImportError: + # Pydantic not available in sandbox, fall back to string conversion + print("Pydantic not available in sandbox environment, falling back to string conversion") + _serialized_result = str(_function_result) +except Exception as e: + # If wrapping fails, print the error and stringify the result + print(f"Failed to serialize result with Pydantic wrapper: {e}") + _serialized_result = str(_function_result) + {{ local_sandbox_result_var_name }} = { - "results": {{ invoke_function_call }}, + "results": _serialized_result, "agent_state": agent_state } diff --git a/letta/templates/sandbox_code_file_async.py.j2 b/letta/templates/sandbox_code_file_async.py.j2 index 6ed9cdbe..33c8971d 100644 --- a/letta/templates/sandbox_code_file_async.py.j2 +++ b/letta/templates/sandbox_code_file_async.py.j2 @@ -26,9 +26,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl {# Async wrapper to handle the function call and store the result #} async def _async_wrapper(): - result = await {{ invoke_function_call }} + _function_result = await {{ invoke_function_call }} + + {# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #} + try: + from pydantic import BaseModel + from typing import Any + + class _TempResultWrapper(BaseModel): + result: Any + + class Config: + arbitrary_types_allowed = True + + _wrapped = _TempResultWrapper(result=_function_result) + _serialized_result = _wrapped.model_dump()['result'] + except ImportError: + # Pydantic not available in sandbox, fall back to string conversion + print("Pydantic not available in sandbox environment, falling back to string conversion") + _serialized_result = str(_function_result) + except Exception as e: + # If wrapping fails, print the error and stringify the result + print(f"Failed to serialize result with Pydantic wrapper: {e}") + _serialized_result = str(_function_result) + return { - "results": result, + "results": _serialized_result, "agent_state": agent_state } diff --git a/tests/conftest.py b/tests/conftest.py index 0abe389d..2d5eb88e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,32 @@ def disable_e2b_api_key() -> Generator[None, None, None]: tool_settings.e2b_api_key = original_api_key +@pytest.fixture +def e2b_sandbox_mode(request) -> Generator[None, None, None]: + """ + Parametrizable fixture to enable/disable E2B sandbox mode. + + Usage: + @pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True) + def test_function(e2b_sandbox_mode, ...): + # Test runs twice - once with E2B enabled, once disabled + """ + from letta.settings import tool_settings + + enable_e2b = request.param + original_api_key = tool_settings.e2b_api_key + + if not enable_e2b: + # Disable E2B by setting API key to None + tool_settings.e2b_api_key = None + # If enable_e2b is True, leave the original API key unchanged + + yield + + # Restore original API key + tool_settings.e2b_api_key = original_api_key + + @pytest.fixture def disable_pinecone() -> Generator[None, None, None]: """ diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 46ed45e0..0bf94b31 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -1,15 +1,19 @@ +import json import os import threading import time import uuid +from typing import List, Type import pytest from dotenv import load_dotenv from letta_client import CreateBlock from letta_client import Letta as LettaSDKClient from letta_client import MessageCreate +from letta_client.client import BaseTool from letta_client.core import ApiError from letta_client.types import AgentState, ToolReturnMessage +from pydantic import BaseModel, Field # Constants SERVER_PORT = 8283 @@ -762,3 +766,204 @@ def test_base_tools_upsert_on_list(client: LettaSDKClient): final_tool_names = {tool.name for tool in final_tools} for deleted_tool in tools_to_delete: assert deleted_tool.name in final_tool_names, f"Deleted tool {deleted_tool.name} was not properly restored" + + +@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True) +def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKClient): + class InventoryItem(BaseModel): + sku: str + name: str + price: float + category: str + + class InventoryEntry(BaseModel): + timestamp: int + item: InventoryItem + transaction_id: str + + class InventoryEntryData(BaseModel): + data: InventoryEntry + quantity_change: int + + class ManageInventoryTool(BaseTool): + name: str = "manage_inventory" + args_schema: Type[BaseModel] = InventoryEntryData + description: str = "Update inventory catalogue with a new data entry" + tags: List[str] = ["inventory", "shop"] + + def run(self, data: InventoryEntry, quantity_change: int) -> bool: + print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}") + return True + + tool = client.tools.add( + tool=ManageInventoryTool(), + ) + + assert tool is not None + assert tool.name == "manage_inventory" + assert "inventory" in tool.tags + assert "shop" in tool.tags + + temp_agent = client.agents.create( + memory_blocks=[ + CreateBlock( + label="persona", + value="You are a helpful inventory management assistant.", + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + tool_ids=[tool.id], + include_base_tools=False, + ) + + response = client.agents.messages.create( + agent_id=temp_agent.id, + messages=[ + MessageCreate( + role="user", + content="Update the inventory for product 'iPhone 15' with SKU 'IPH15-001', price $999.99, category 'Electronics', transaction ID 'TXN-12345', timestamp 1640995200, with a quantity change of +10", + ), + ], + ) + + assert response is not None + + tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"] + assert len(tool_call_messages) > 0, "Expected at least one tool call message" + + first_tool_call = tool_call_messages[0] + assert first_tool_call.tool_call.name == "manage_inventory" + + args = json.loads(first_tool_call.tool_call.arguments) + assert "data" in args + assert "quantity_change" in args + assert "item" in args["data"] + assert "name" in args["data"]["item"] + assert "sku" in args["data"]["item"] + assert "price" in args["data"]["item"] + assert "category" in args["data"]["item"] + assert "transaction_id" in args["data"] + assert "timestamp" in args["data"] + + tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"] + assert len(tool_return_messages) > 0, "Expected at least one tool return message" + + first_tool_return = tool_return_messages[0] + assert first_tool_return.status == "success" + assert first_tool_return.tool_return == "True" + assert "Updated inventory for iPhone 15 with a quantity change of 10" in "\n".join(first_tool_return.stdout) + + client.agents.delete(temp_agent.id) + client.tools.delete(tool.id) + + +@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True) +def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient): + + class Step(BaseModel): + name: str = Field(..., description="Name of the step.") + description: str = Field(..., description="An exhaustive description of what this step is trying to achieve.") + + class StepsList(BaseModel): + steps: List[Step] = Field(..., description="List of steps to add to the task plan.") + explanation: str = Field(..., description="Explanation for the list of steps.") + + def create_task_plan(steps, explanation): + """Creates a task plan for the current task.""" + print(f"Created task plan with {len(steps)} steps: {explanation}") + return steps + + tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"]) + + assert tool is not None + assert tool.name == "create_task_plan" + assert "planning" in tool.tags + assert "task" in tool.tags + + temp_agent = client.agents.create( + memory_blocks=[ + CreateBlock( + label="persona", + value="You are a helpful task planning assistant.", + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + tool_ids=[tool.id], + include_base_tools=False, + ) + + response = client.agents.messages.create( + agent_id=temp_agent.id, + messages=[ + MessageCreate( + role="user", + content="Create a task plan for organizing a team meeting with 3 steps: 1) Schedule meeting (find available time slots), 2) Send invitations (notify all team members), 3) Prepare agenda (outline discussion topics). Explanation: This plan ensures a well-organized team meeting.", + ), + ], + ) + + assert response is not None + assert hasattr(response, "messages") + assert len(response.messages) > 0 + + tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"] + assert len(tool_call_messages) > 0, "Expected at least one tool call message" + + first_tool_call = tool_call_messages[0] + assert first_tool_call.tool_call.name == "create_task_plan" + + args = json.loads(first_tool_call.tool_call.arguments) + assert "steps" in args + assert "explanation" in args + assert isinstance(args["steps"], list) + assert len(args["steps"]) > 0 + + for step in args["steps"]: + assert "name" in step + assert "description" in step + + tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"] + assert len(tool_return_messages) > 0, "Expected at least one tool return message" + + first_tool_return = tool_return_messages[0] + assert first_tool_return.status == "success" + + client.agents.delete(temp_agent.id) + client.tools.delete(tool.id) + + +@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True) +def test_create_tool_from_function_with_docstring(e2b_sandbox_mode, client: LettaSDKClient): + """Test creating a tool from a function with a docstring using create_from_function""" + + def roll_dice() -> str: + """ + Simulate the roll of a 20-sided die (d20). + + This function generates a random integer between 1 and 20, inclusive, + which represents the outcome of a single roll of a d20. + + Returns: + str: The result of the die roll. + """ + import random + + dice_role_outcome = random.randint(1, 20) + output_string = f"You rolled a {dice_role_outcome}" + return output_string + + tool = client.tools.create_from_function(func=roll_dice) + + assert tool is not None + assert tool.name == "roll_dice" + assert "Simulate the roll of a 20-sided die" in tool.description + assert tool.source_code is not None + assert "random.randint(1, 20)" in tool.source_code + + all_tools = client.tools.list() + tool_names = [t.name for t in all_tools] + assert "roll_dice" in tool_names + + client.tools.delete(tool.id)