test: Add basic parallel tool calling test to send_message v2 for anthropic [LET-5362] (#5355)

Add basic parallel tool calling test to send_message v2 for anthropic
This commit is contained in:
Matthew Zhou
2025-10-13 10:25:49 -07:00
committed by Caren Thomas
parent 681f4903fd
commit 10a3d86507

View File

@@ -99,6 +99,13 @@ USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreate] = [
MessageCreate(
role="user",
content=("This is an automated test message. Please call the roll_dice tool three times in parallel."),
otid=USER_MESSAGE_OTID,
)
]
def assert_greeting_response(
@@ -515,6 +522,73 @@ async def test_greeting(
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.asyncio(loop_scope="function")
async def test_parallel_tool_call_anthropic_streaming(
disable_e2b_api_key: Any,
client: AsyncLetta,
agent_state: AgentState,
llm_config: LLMConfig,
) -> None:
if llm_config.model_endpoint_type != "anthropic":
pytest.skip("Parallel tool calling test only applies to Anthropic models.")
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
stream = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
stream_tokens=True,
background=False,
)
messages = await accumulate_chunks(stream)
run_id = messages[0].run_id if messages else None
# validate parallel tool call behavior in preserved messages
preserved_messages = await client.agents.messages.list(agent_id=agent_state.id)
# find the tool call message in preserved messages
tool_call_msg = None
tool_return_msg = None
for msg in preserved_messages:
if isinstance(msg, ToolCallMessage):
tool_call_msg = msg
elif isinstance(msg, ToolReturnMessage):
tool_return_msg = msg
# assert parallel tool calls were made
assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages"
assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage"
assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}"
# verify each tool call
for tc in tool_call_msg.tool_calls:
assert tc["name"] == "roll_dice"
assert tc["tool_call_id"].startswith("toolu_")
assert "num_sides" in tc["arguments"]
# assert tool returns match the tool calls
assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages"
assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}"
# verify each tool return
tool_call_ids = {tc["tool_call_id"] for tc in tool_call_msg.tool_calls}
for tr in tool_return_msg.tool_returns:
assert tr["type"] == "tool"
assert tr["status"] == "success"
assert tr["tool_call_id"] in tool_call_ids, f"tool_call_id {tr['tool_call_id']} not found in tool calls"
assert int(tr["tool_return"]) >= 1 and int(tr["tool_return"]) <= 6
if run_id:
await wait_for_run_completion(client, run_id)
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.parametrize(
["send_type", "cancellation"],
list(