feat: support programmatic tool calling for custom tools [LET-6316] (#6369)
This commit is contained in:
committed by
Caren Thomas
parent
3e02f12dfd
commit
e349ba3bdd
@@ -10,8 +10,10 @@ import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
from letta_client.types import AgentState, MessageCreateParam, ToolReturnMessage
|
||||
from letta_client.types.agents import ToolCallMessage
|
||||
|
||||
from letta.services.tool_executor.builtin_tool_executor import LettaBuiltinToolExecutor
|
||||
from letta.settings import tool_settings
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
@@ -72,9 +74,9 @@ def agent_state(client: Letta) -> AgentState:
|
||||
"""
|
||||
client.tools.upsert_base_tools()
|
||||
|
||||
send_message_tool = list(client.tools.list(name="send_message"))[0]
|
||||
run_code_tool = list(client.tools.list(name="run_code"))[0]
|
||||
web_search_tool = list(client.tools.list(name="web_search"))[0]
|
||||
send_message_tool = client.tools.list(name="send_message").items[0]
|
||||
run_code_tool = client.tools.list(name="run_code").items[0]
|
||||
web_search_tool = client.tools.list(name="web_search").items[0]
|
||||
agent_state_instance = client.agents.create(
|
||||
name="test_builtin_tools_agent",
|
||||
include_base_tools=False,
|
||||
@@ -311,3 +313,178 @@ async def test_web_search_uses_exa():
|
||||
assert "results" in response_json
|
||||
assert response_json["query"] == "test query"
|
||||
assert len(response_json["results"]) == 1
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Programmatic Tool Calling Tests
|
||||
# ------------------------------
|
||||
|
||||
|
||||
# Source code for the custom `add` tool. It is passed verbatim as `source_code`
# when the tool is created, so it must be a self-contained, valid Python
# function definition with a complete docstring.
ADD_TOOL_SOURCE = """
def add(a: int, b: int) -> int:
    \"\"\"Add two numbers together.

    Args:
        a (int): The first number.
        b (int): The second number.

    Returns:
        int: The sum of a and b.
    \"\"\"
    return a + b
"""
|
||||
|
||||
# Source code for the custom `multiply` tool. It is passed verbatim as
# `source_code` when the tool is created, so it must be a self-contained, valid
# Python function definition with a complete docstring.
MULTIPLY_TOOL_SOURCE = """
def multiply(a: int, b: int) -> int:
    \"\"\"Multiply two numbers together.

    Args:
        a (int): The first number.
        b (int): The second number.

    Returns:
        int: The product of a and b.
    \"\"\"
    return a * b
"""
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def agent_with_custom_tools(client: Letta) -> AgentState:
    """
    Creates an agent with custom add/multiply tools and run_code tool
    to test programmatic tool calling.

    The fixture yields the created agent state. On teardown it deletes the
    agent and both custom tools. Everything created here is tracked so that a
    failure part-way through setup (for example while creating the agent) does
    not leave orphaned tools behind on the server.

    Args:
        client: Letta client used to create and delete the tools and agent.

    Yields:
        The agent state returned by ``client.agents.create``.
    """
    client.tools.upsert_base_tools()

    created_tool_ids = []
    agent_id = None
    try:
        # Create custom tools; record each id immediately so cleanup sees it
        # even if a later setup step raises.
        for tool_source in (ADD_TOOL_SOURCE, MULTIPLY_TOOL_SOURCE):
            created_tool_ids.append(client.tools.create(source_code=tool_source).id)

        # Get the built-in tools the agent needs alongside the custom ones
        run_code_tool = client.tools.list(name="run_code").items[0]
        send_message_tool = client.tools.list(name="send_message").items[0]

        agent_state_instance = client.agents.create(
            name="test_programmatic_tool_calling_agent",
            include_base_tools=False,
            tool_ids=[send_message_tool.id, run_code_tool.id, *created_tool_ids],
            model="openai/gpt-4o",
            embedding="letta/letta-free",
            tags=["test_programmatic_tool_calling"],
        )
        agent_id = agent_state_instance.id

        yield agent_state_instance
    finally:
        # Cleanup: agent first, then the tools it referenced.
        if agent_id is not None:
            client.agents.delete(agent_id)
        for tool_id in created_tool_ids:
            client.tools.delete(tool_id)
|
||||
|
||||
|
||||
def test_programmatic_tool_calling_compose_tools(
    client: Letta,
    agent_with_custom_tools: AgentState,
) -> None:
    """
    Tests that run_code can compose agent tools programmatically in a SINGLE call.

    This validates that:
    1. Tool source code is injected into the sandbox
    2. The agent composes tools in one run_code call, not multiple separate tool calls
    3. The result is computed correctly: add(multiply(4, 5), 6) = 26
    """
    # Expected result: multiply(4, 5) = 20, add(20, 6) = 26
    expected = "26"

    user_message = MessageCreateParam(
        role="user",
        content=(
            "Use the run_code tool to execute Python code that composes the add and multiply tools. "
            "Calculate add(multiply(4, 5), 6) and return the result. "
            "The add and multiply functions are already available in the code execution environment. "
            "Do this in a SINGLE run_code call - do NOT call add or multiply as separate tools."
        ),
        otid=str(uuid.uuid4()),
    )

    response = client.agents.messages.create(
        agent_id=agent_with_custom_tools.id,
        messages=[user_message],
    )

    # Extract all tool calls
    tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
    assert tool_calls, "No ToolCallMessage found for programmatic tool calling test"

    # Verify the agent used run_code to compose tools, not direct add/multiply calls
    tool_names = [m.tool_call.name for m in tool_calls]
    direct_add_calls = tool_names.count("add")
    direct_multiply_calls = tool_names.count("multiply")

    # The key assertion: tools should be composed via run_code, not called directly
    assert "run_code" in tool_names, f"Expected at least one run_code call, but got tool calls: {tool_names}"
    assert direct_add_calls == 0, (
        f"Expected no direct 'add' tool calls (should be called via run_code), but found {direct_add_calls}"
    )
    assert direct_multiply_calls == 0, (
        f"Expected no direct 'multiply' tool calls (should be called via run_code), but found {direct_multiply_calls}"
    )

    # Verify the result is correct. Each return is coerced with str() so that a
    # non-string tool_return produces the assertion message below rather than
    # a TypeError from the `in` operator.
    returns = [m.tool_return for m in response.messages if isinstance(m, ToolReturnMessage)]
    assert any(expected in str(ret) for ret in returns), f"Expected to find '{expected}' in tool_return, but got {returns!r}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="function")
async def test_run_code_injects_tool_source_code() -> None:
    """
    Unit test that verifies run_code injects agent tool source code into the sandbox.

    This test directly calls run_code with a mocked agent_state containing tools.
    The submitted code never defines ``add`` or ``multiply`` itself, so a correct
    result can only come from the executor making the tools' source available.
    The test is skipped when no E2B API key is configured.
    """
    # Skip if E2B_API_KEY is not set. Checked first so that nothing below runs
    # (and nothing below can fail) in environments without sandbox access.
    if not tool_settings.e2b_api_key:
        pytest.skip("E2B_API_KEY not set, skipping run_code test")

    from letta.schemas.tool import Tool

    # Create mock agent state with tools that have source code
    mock_agent_state = MagicMock()
    mock_agent_state.tools = [
        Tool(
            id="tool-00000001",
            name="add",
            source_code=ADD_TOOL_SOURCE.strip(),
        ),
        Tool(
            id="tool-00000002",
            name="multiply",
            source_code=MULTIPLY_TOOL_SOURCE.strip(),
        ),
    ]

    # Create executor with mock dependencies
    executor = LettaBuiltinToolExecutor(
        message_manager=MagicMock(),
        agent_manager=MagicMock(),
        block_manager=MagicMock(),
        run_manager=MagicMock(),
        passage_manager=MagicMock(),
        actor=MagicMock(),
    )

    # Execute code that composes the tools
    # Note: We don't define add/multiply in the code - they should be injected from tool source
    result = await executor.run_code(
        agent_state=mock_agent_state,
        code="print(add(multiply(4, 5), 6))",
        language="python",
    )

    response_json = json.loads(result)

    # Verify execution succeeded and returned correct result.
    # A missing "error" key and an explicit null are both treated as success.
    assert response_json.get("error") is None, f"Code execution failed: {response_json}"

    # Use tolerant lookups so a response missing "results" or "logs" fails with
    # the descriptive message below instead of a bare KeyError.
    results = response_json.get("results")
    stdout = (response_json.get("logs") or {}).get("stdout")
    assert "26" in str(results) or "26" in str(stdout), (
        f"Expected '26' in results, got: {response_json}"
    )
|
||||
|
||||
Reference in New Issue
Block a user