fix: Add comprehensive testing for tool creation paths (#3255)
This commit is contained in:
@@ -129,7 +129,8 @@ def get_function_name_and_docstring(source_code: str, name: Optional[str] = None
|
||||
raise LettaToolCreateError("Could not determine function name")
|
||||
|
||||
if not docstring:
|
||||
raise LettaToolCreateError("Docstring is missing")
|
||||
# For tools with args_json_schema, the docstring is optional
|
||||
docstring = f"The {function_name} tool"
|
||||
|
||||
return function_name, docstring
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ast
|
||||
import base64
|
||||
import pickle
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, SEND_MESSAGE_TOOL_NAME
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -9,7 +9,7 @@ from letta.schemas.response_format import ResponseFormatType, ResponseFormatUnio
|
||||
from letta.types import JsonDict, JsonValue
|
||||
|
||||
|
||||
def parse_stdout_best_effort(text: str | bytes) -> tuple[Any, AgentState | None]:
|
||||
def parse_stdout_best_effort(text: Union[str, bytes]) -> tuple[Any, AgentState | None]:
|
||||
"""
|
||||
Decode and unpickle the result from the function execution if possible.
|
||||
Returns (function_return_value, agent_state).
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic.config import JsonDict
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
@@ -23,6 +24,8 @@ from letta.services.tool_sandbox.base import AsyncToolSandboxBase
|
||||
from letta.settings import tool_settings
|
||||
from letta.utils import get_friendly_error_msg, parse_stderr_error_msg
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
METADATA_CONFIG_STATE_KEY = "config_state"
|
||||
@@ -240,9 +243,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
if isinstance(e, TimeoutError):
|
||||
raise e
|
||||
|
||||
print(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}")
|
||||
print(e.__class__.__name__)
|
||||
print(e.__traceback__)
|
||||
logger.error(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}")
|
||||
logger.error(e.__class__.__name__)
|
||||
logger.error(e.__traceback__)
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=self.tool_name,
|
||||
exception_name=type(e).__name__,
|
||||
|
||||
@@ -24,8 +24,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl
|
||||
{{ tool_source_code }}
|
||||
|
||||
{# Invoke the function and store the result in a global variable #}
|
||||
_function_result = {{ invoke_function_call }}
|
||||
|
||||
{# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #}
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
class _TempResultWrapper(BaseModel):
|
||||
result: Any
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
_wrapped = _TempResultWrapper(result=_function_result)
|
||||
_serialized_result = _wrapped.model_dump()['result']
|
||||
except ImportError:
|
||||
# Pydantic not available in sandbox, fall back to string conversion
|
||||
print("Pydantic not available in sandbox environment, falling back to string conversion")
|
||||
_serialized_result = str(_function_result)
|
||||
except Exception as e:
|
||||
# If wrapping fails, print the error and stringify the result
|
||||
print(f"Failed to serialize result with Pydantic wrapper: {e}")
|
||||
_serialized_result = str(_function_result)
|
||||
|
||||
{{ local_sandbox_result_var_name }} = {
|
||||
"results": {{ invoke_function_call }},
|
||||
"results": _serialized_result,
|
||||
"agent_state": agent_state
|
||||
}
|
||||
|
||||
|
||||
@@ -26,9 +26,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl
|
||||
|
||||
{# Async wrapper to handle the function call and store the result #}
|
||||
async def _async_wrapper():
|
||||
result = await {{ invoke_function_call }}
|
||||
_function_result = await {{ invoke_function_call }}
|
||||
|
||||
{# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #}
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
class _TempResultWrapper(BaseModel):
|
||||
result: Any
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
_wrapped = _TempResultWrapper(result=_function_result)
|
||||
_serialized_result = _wrapped.model_dump()['result']
|
||||
except ImportError:
|
||||
# Pydantic not available in sandbox, fall back to string conversion
|
||||
print("Pydantic not available in sandbox environment, falling back to string conversion")
|
||||
_serialized_result = str(_function_result)
|
||||
except Exception as e:
|
||||
# If wrapping fails, print the error and stringify the result
|
||||
print(f"Failed to serialize result with Pydantic wrapper: {e}")
|
||||
_serialized_result = str(_function_result)
|
||||
|
||||
return {
|
||||
"results": result,
|
||||
"results": _serialized_result,
|
||||
"agent_state": agent_state
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,32 @@ def disable_e2b_api_key() -> Generator[None, None, None]:
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
@pytest.fixture
def e2b_sandbox_mode(request) -> Generator[None, None, None]:
    """
    Toggle the E2B sandbox for a single test via indirect parametrization.

    The fixture reads its parameter from ``request.param``. A falsy value
    sets ``tool_settings.e2b_api_key`` to ``None`` for the duration of the
    test; a truthy value leaves the configured key untouched. In both cases
    the key that was present before the test is written back afterwards.

    Usage:
        @pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
        def test_function(e2b_sandbox_mode, ...):
            # Runs twice: once with the key left as-is, once with it cleared
    """
    from letta.settings import tool_settings

    use_e2b = request.param
    saved_api_key = tool_settings.e2b_api_key

    if not use_e2b:
        # Clearing the key is how this fixture turns E2B off.
        tool_settings.e2b_api_key = None

    yield

    # Teardown: put back whatever key was configured before the test ran.
    tool_settings.e2b_api_key = saved_api_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_pinecone() -> Generator[None, None, None]:
|
||||
"""
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Type
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import MessageCreate
|
||||
from letta_client.client import BaseTool
|
||||
from letta_client.core import ApiError
|
||||
from letta_client.types import AgentState, ToolReturnMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
@@ -762,3 +766,204 @@ def test_base_tools_upsert_on_list(client: LettaSDKClient):
|
||||
final_tool_names = {tool.name for tool in final_tools}
|
||||
for deleted_tool in tools_to_delete:
|
||||
assert deleted_tool.name in final_tool_names, f"Deleted tool {deleted_tool.name} was not properly restored"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKClient):
    """
    Register a ``BaseTool`` subclass with a nested Pydantic args schema, have
    an agent call it, and check the tool call arguments and the tool return.

    Parametrized over ``e2b_sandbox_mode`` so the test runs twice, once per
    fixture parameter. The agent and the tool are removed in ``finally``
    blocks so that a failing assertion does not leave them behind.
    """

    class InventoryItem(BaseModel):
        sku: str
        name: str
        price: float
        category: str

    class InventoryEntry(BaseModel):
        timestamp: int
        item: InventoryItem
        transaction_id: str

    class InventoryEntryData(BaseModel):
        data: InventoryEntry
        quantity_change: int

    class ManageInventoryTool(BaseTool):
        name: str = "manage_inventory"
        args_schema: Type[BaseModel] = InventoryEntryData
        description: str = "Update inventory catalogue with a new data entry"
        tags: List[str] = ["inventory", "shop"]

        def run(self, data: InventoryEntry, quantity_change: int) -> bool:
            print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}")
            return True

    tool = client.tools.add(
        tool=ManageInventoryTool(),
    )
    # Stays None until the agent exists, so cleanup knows what to delete.
    temp_agent = None

    try:
        assert tool is not None
        assert tool.name == "manage_inventory"
        assert "inventory" in tool.tags
        assert "shop" in tool.tags

        temp_agent = client.agents.create(
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You are a helpful inventory management assistant.",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-3-small",
            tool_ids=[tool.id],
            include_base_tools=False,
        )

        response = client.agents.messages.create(
            agent_id=temp_agent.id,
            messages=[
                MessageCreate(
                    role="user",
                    content="Update the inventory for product 'iPhone 15' with SKU 'IPH15-001', price $999.99, category 'Electronics', transaction ID 'TXN-12345', timestamp 1640995200, with a quantity change of +10",
                ),
            ],
        )

        assert response is not None

        # The agent must have invoked the tool with the full nested structure.
        tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"]
        assert len(tool_call_messages) > 0, "Expected at least one tool call message"

        first_tool_call = tool_call_messages[0]
        assert first_tool_call.tool_call.name == "manage_inventory"

        args = json.loads(first_tool_call.tool_call.arguments)
        assert "data" in args
        assert "quantity_change" in args
        assert "item" in args["data"]
        assert "name" in args["data"]["item"]
        assert "sku" in args["data"]["item"]
        assert "price" in args["data"]["item"]
        assert "category" in args["data"]["item"]
        assert "transaction_id" in args["data"]
        assert "timestamp" in args["data"]

        # The tool must have run successfully and produced the expected output.
        tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"]
        assert len(tool_return_messages) > 0, "Expected at least one tool return message"

        first_tool_return = tool_return_messages[0]
        assert first_tool_return.status == "success"
        assert first_tool_return.tool_return == "True"
        assert "Updated inventory for iPhone 15 with a quantity change of 10" in "\n".join(first_tool_return.stdout)
    finally:
        # Delete the agent first, then the tool; the nested finally makes sure
        # the tool is still removed if deleting the agent raises.
        try:
            if temp_agent is not None:
                client.agents.delete(temp_agent.id)
        finally:
            if tool is not None:
                client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
    """
    Upsert a tool from a plain function plus a Pydantic ``args_schema``, have
    an agent call it, and check the call arguments and the return status.

    Parametrized over ``e2b_sandbox_mode`` so the test runs twice, once per
    fixture parameter. The agent and the tool are removed in ``finally``
    blocks so that a failing assertion does not leave them behind.
    """

    class Step(BaseModel):
        name: str = Field(..., description="Name of the step.")
        description: str = Field(..., description="An exhaustive description of what this step is trying to achieve.")

    class StepsList(BaseModel):
        steps: List[Step] = Field(..., description="List of steps to add to the task plan.")
        explanation: str = Field(..., description="Explanation for the list of steps.")

    def create_task_plan(steps, explanation):
        """Creates a task plan for the current task."""
        print(f"Created task plan with {len(steps)} steps: {explanation}")
        return steps

    tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"])
    # Stays None until the agent exists, so cleanup knows what to delete.
    temp_agent = None

    try:
        assert tool is not None
        assert tool.name == "create_task_plan"
        assert "planning" in tool.tags
        assert "task" in tool.tags

        temp_agent = client.agents.create(
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You are a helpful task planning assistant.",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-3-small",
            tool_ids=[tool.id],
            include_base_tools=False,
        )

        response = client.agents.messages.create(
            agent_id=temp_agent.id,
            messages=[
                MessageCreate(
                    role="user",
                    content="Create a task plan for organizing a team meeting with 3 steps: 1) Schedule meeting (find available time slots), 2) Send invitations (notify all team members), 3) Prepare agenda (outline discussion topics). Explanation: This plan ensures a well-organized team meeting.",
                ),
            ],
        )

        assert response is not None
        assert hasattr(response, "messages")
        assert len(response.messages) > 0

        # The agent must have invoked the tool with a list of structured steps.
        tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"]
        assert len(tool_call_messages) > 0, "Expected at least one tool call message"

        first_tool_call = tool_call_messages[0]
        assert first_tool_call.tool_call.name == "create_task_plan"

        args = json.loads(first_tool_call.tool_call.arguments)
        assert "steps" in args
        assert "explanation" in args
        assert isinstance(args["steps"], list)
        assert len(args["steps"]) > 0

        for step in args["steps"]:
            assert "name" in step
            assert "description" in step

        # The tool execution itself must have succeeded.
        tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"]
        assert len(tool_return_messages) > 0, "Expected at least one tool return message"

        first_tool_return = tool_return_messages[0]
        assert first_tool_return.status == "success"
    finally:
        # Delete the agent first, then the tool; the nested finally makes sure
        # the tool is still removed if deleting the agent raises.
        try:
            if temp_agent is not None:
                client.agents.delete(temp_agent.id)
        finally:
            if tool is not None:
                client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
def test_create_tool_from_function_with_docstring(e2b_sandbox_mode, client: LettaSDKClient):
    """
    Test creating a tool from a function with a docstring using create_from_function.

    Checks the returned tool's name, description and source code, and that
    the tool shows up in ``client.tools.list()``. The tool is deleted in a
    ``finally`` block so a failing assertion does not leave it behind.
    """

    def roll_dice() -> str:
        """
        Simulate the roll of a 20-sided die (d20).

        This function generates a random integer between 1 and 20, inclusive,
        which represents the outcome of a single roll of a d20.

        Returns:
            str: The result of the die roll.
        """
        import random

        dice_role_outcome = random.randint(1, 20)
        output_string = f"You rolled a {dice_role_outcome}"
        return output_string

    tool = client.tools.create_from_function(func=roll_dice)

    try:
        assert tool is not None
        assert tool.name == "roll_dice"
        assert "Simulate the roll of a 20-sided die" in tool.description
        assert tool.source_code is not None
        assert "random.randint(1, 20)" in tool.source_code

        # The new tool must also be visible through the listing endpoint.
        all_tools = client.tools.list()
        tool_names = [t.name for t in all_tools]
        assert "roll_dice" in tool_names
    finally:
        # Guarded so a None result fails on the assertion above, not here.
        if tool is not None:
            client.tools.delete(tool.id)
|
||||
|
||||
Reference in New Issue
Block a user