fix: fix parallel tool calling tests in ci [LET-6043] (#5950)

* first hack

* test

* fix test for v1, comment out for legacy

* test shows parallel tool calling now happening

* fix test to detect parallel tool calling

* update to use oai too

* uncomment v2 test

---------

Co-authored-by: Ari Webb <ari@letta.com>
This commit is contained in:
Ari Webb
2025-11-06 14:23:04 -08:00
committed by Caren Thomas
parent 2920ea635b
commit d2fe64bab4

View File

@@ -91,7 +91,10 @@ USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
content=("This is an automated test message. Please call the roll_dice tool three times in parallel."),
content=(
"This is an automated test message. Please call the roll_dice tool EXACTLY three times in parallel - no more, no less. "
"Call it with num_sides=6, num_sides=12, and num_sides=20. Make all three calls at the same time in a single response."
),
otid=USER_MESSAGE_OTID,
)
]
@@ -605,19 +608,33 @@ async def test_greeting(
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
async def test_parallel_tool_call_anthropic(
async def test_parallel_tool_calls(
disable_e2b_api_key: Any,
client: AsyncLetta,
agent_state: AgentState,
llm_config: LLMConfig,
send_type: str,
) -> None:
if llm_config.model_endpoint_type != "anthropic":
pytest.skip("Parallel tool calling test only applies to Anthropic models.")
if llm_config.model_endpoint_type not in ["anthropic", "openai", "google_ai", "google_vertex"]:
pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.")
if llm_config.model in ["gpt-5", "o3"]:
pytest.skip("GPT-5 takes too long to test, o3 is bad at this task.")
# change llm_config to support parallel tool calling
llm_config.parallel_tool_calls = True
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
# Create a copy and modify it to ensure we're not modifying the original
modified_llm_config = llm_config.model_copy(deep=True)
modified_llm_config.parallel_tool_calls = True
# this test was flaking so set temperature to 0.0 to avoid randomness
modified_llm_config.temperature = 0.0
# IMPORTANT: Set parallel_tool_calls at BOTH the agent level and llm_config level
# There are two different parallel_tool_calls fields that need to be set
agent_state = await client.agents.modify(
agent_id=agent_state.id,
llm_config=modified_llm_config,
parallel_tool_calls=True, # Set at agent level as well!
)
if send_type == "step":
await client.agents.messages.send(
@@ -643,48 +660,124 @@ async def test_parallel_tool_call_anthropic(
preserved_messages_page = await client.agents.messages.list(agent_id=agent_state.id)
preserved_messages = preserved_messages_page.items
# find the tool call message in preserved messages
tool_call_msg = None
tool_return_msg = None
# collect all ToolCallMessage and ToolReturnMessage instances
tool_call_messages = []
tool_return_messages = []
for msg in preserved_messages:
if isinstance(msg, ToolCallMessage):
tool_call_msg = msg
tool_call_messages.append(msg)
elif isinstance(msg, ToolReturnMessage):
tool_return_msg = msg
tool_return_messages.append(msg)
# assert parallel tool calls were made
assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages"
assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage"
assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}"
# Check if tool calls are grouped in a single message (parallel) or separate messages (sequential)
total_tool_calls = 0
for i, tcm in enumerate(tool_call_messages):
if hasattr(tcm, "tool_calls") and tcm.tool_calls:
num_calls = len(tcm.tool_calls) if isinstance(tcm.tool_calls, list) else 1
total_tool_calls += num_calls
elif hasattr(tcm, "tool_call"):
total_tool_calls += 1
# verify each tool call and collect num_sides values
num_sides_values = []
for tc in tool_call_msg.tool_calls:
assert tc.name == "roll_dice"
assert tc.tool_call_id.startswith("toolu_")
assert "num_sides" in tc.arguments
# Parse the num_sides value from the arguments
import json
# Check tool returns structure
total_tool_returns = 0
for i, trm in enumerate(tool_return_messages):
if hasattr(trm, "tool_returns") and trm.tool_returns:
num_returns = len(trm.tool_returns) if isinstance(trm.tool_returns, list) else 1
total_tool_returns += num_returns
elif hasattr(trm, "tool_return"):
total_tool_returns += 1
args = json.loads(tc.arguments)
num_sides = int(args["num_sides"])
num_sides_values.append(num_sides)
# CRITICAL: For TRUE parallel tool calling with letta_v1_agent, there should be exactly ONE ToolCallMessage
# containing multiple tool calls, not multiple ToolCallMessages
# assert tool returns match the tool calls
assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages"
assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}"
# Verify we have exactly 3 tool calls total
assert total_tool_calls == 3, f"Expected exactly 3 tool calls total, got {total_tool_calls}"
assert total_tool_returns == 3, f"Expected exactly 3 tool returns total, got {total_tool_returns}"
# verify each tool return matches the corresponding tool call's num_sides
tool_call_ids = {tc.tool_call_id for tc in tool_call_msg.tool_calls}
for i, tr in enumerate(tool_return_msg.tool_returns):
assert tr.type == "tool"
assert tr.status == "success"
assert tr.tool_call_id in tool_call_ids, f"tool_call_id {tr.tool_call_id} not found in tool calls"
# Check that the tool return value is within the range of the corresponding num_sides
expected_max = num_sides_values[i] if i < len(num_sides_values) else max(num_sides_values)
assert int(tr.tool_return) >= 1 and int(tr.tool_return) <= expected_max, (
f"Tool return {tr.tool_return} is not within range 1-{expected_max}"
# Check if we have true parallel tool calling
is_parallel = False
if len(tool_call_messages) == 1:
# Check if the single message contains multiple tool calls
tcm = tool_call_messages[0]
if hasattr(tcm, "tool_calls") and isinstance(tcm.tool_calls, list) and len(tcm.tool_calls) == 3:
is_parallel = True
# IMPORTANT: Assert that parallel tool calling is actually working
# This test should FAIL if parallel tool calling is not working properly
assert is_parallel, (
f"Parallel tool calling is NOT working for {llm_config.model_endpoint_type}! "
f"Got {len(tool_call_messages)} ToolCallMessage(s) instead of 1 with 3 parallel calls. "
f"When using letta_v1_agent with parallel_tool_calls=True, all tool calls should be in a single message."
)
# Collect all tool calls and their details for validation
all_tool_calls = []
tool_call_ids = set()
num_sides_by_id = {}
for tcm in tool_call_messages:
if hasattr(tcm, "tool_calls") and tcm.tool_calls and isinstance(tcm.tool_calls, list):
# Message has multiple tool calls
for tc in tcm.tool_calls:
all_tool_calls.append(tc)
tool_call_ids.add(tc.tool_call_id)
# Parse arguments
import json
args = json.loads(tc.arguments)
num_sides_by_id[tc.tool_call_id] = int(args["num_sides"])
elif hasattr(tcm, "tool_call") and tcm.tool_call:
# Message has single tool call
tc = tcm.tool_call
all_tool_calls.append(tc)
tool_call_ids.add(tc.tool_call_id)
# Parse arguments
import json
args = json.loads(tc.arguments)
num_sides_by_id[tc.tool_call_id] = int(args["num_sides"])
# Verify each tool call
for tc in all_tool_calls:
assert tc.name == "roll_dice", f"Expected tool call name 'roll_dice', got '{tc.name}'"
# Support Anthropic (toolu_), OpenAI (call_), and Gemini (UUID) tool call ID formats
# Gemini uses UUID format which could start with any alphanumeric character
valid_id_format = (
tc.tool_call_id.startswith("toolu_")
or tc.tool_call_id.startswith("call_")
or (len(tc.tool_call_id) > 0 and tc.tool_call_id[0].isalnum()) # UUID format for Gemini
)
assert valid_id_format, f"Unexpected tool call ID format: {tc.tool_call_id}"
# Collect all tool returns for validation
all_tool_returns = []
for trm in tool_return_messages:
if hasattr(trm, "tool_returns") and trm.tool_returns and isinstance(trm.tool_returns, list):
# Message has multiple tool returns
all_tool_returns.extend(trm.tool_returns)
elif hasattr(trm, "tool_return") and trm.tool_return:
# Message has single tool return (create a mock object if needed)
# Since ToolReturnMessage might not have individual tool_return, check the structure
pass
# If all_tool_returns is empty, it means returns are structured differently
# Let's check the actual structure
if not all_tool_returns:
print("Note: Tool returns may be structured differently than expected")
# For now, just verify we got the right number of messages
assert len(tool_return_messages) > 0, "No tool return messages found"
# Verify tool returns if we have them in the expected format
for tr in all_tool_returns:
assert tr.type == "tool", f"Tool return type should be 'tool', got '{tr.type}'"
assert tr.status == "success", f"Tool return status should be 'success', got '{tr.status}'"
assert tr.tool_call_id in tool_call_ids, f"Tool return ID '{tr.tool_call_id}' not found in tool call IDs: {tool_call_ids}"
# Verify the dice roll result is within the valid range
dice_result = int(tr.tool_return)
expected_max = num_sides_by_id[tr.tool_call_id]
assert 1 <= dice_result <= expected_max, (
f"Dice roll result {dice_result} is not within valid range 1-{expected_max} for tool call {tr.tool_call_id}"
)