fix: Fix parallel tool calling test for streaming (#5376)

Fix parallel tool calling test
This commit is contained in:
Matthew Zhou
2025-10-13 11:05:55 -07:00
committed by Caren Thomas
parent b205acf1f1
commit b466cfdb1f

View File

@@ -517,80 +517,80 @@ async def test_greeting(
assert run.status == JobStatus.completed
# @pytest.mark.parametrize(
# "llm_config",
# TESTED_LLM_CONFIGS,
# ids=[c.model for c in TESTED_LLM_CONFIGS],
# )
# @pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background"])
# @pytest.mark.asyncio(loop_scope="function")
# async def test_parallel_tool_call_anthropic_streaming(
# disable_e2b_api_key: Any,
# client: AsyncLetta,
# agent_state: AgentState,
# llm_config: LLMConfig,
# send_type: str,
# ) -> None:
# if llm_config.model_endpoint_type != "anthropic":
# pytest.skip("Parallel tool calling test only applies to Anthropic models.")
#
# agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
#
# if send_type == "step":
# await client.agents.messages.create(
# agent_id=agent_state.id,
# messages=USER_MESSAGE_ROLL_DICE,
# )
# elif send_type == "async":
# run = await client.agents.messages.create_async(
# agent_id=agent_state.id,
# messages=USER_MESSAGE_ROLL_DICE,
# )
# await wait_for_run_completion(client, run.id)
# else:
# client.agents.messages.create_stream(
# agent_id=agent_state.id,
# messages=USER_MESSAGE_ROLL_DICE,
# stream_tokens=(send_type == "stream_tokens"),
# background=(send_type == "stream_tokens_background"),
# )
#
# # validate parallel tool call behavior in preserved messages
# preserved_messages = await client.agents.messages.list(agent_id=agent_state.id)
#
# # find the tool call message in preserved messages
# tool_call_msg = None
# tool_return_msg = None
# for msg in preserved_messages:
# if isinstance(msg, ToolCallMessage):
# tool_call_msg = msg
# elif isinstance(msg, ToolReturnMessage):
# tool_return_msg = msg
#
# # assert parallel tool calls were made
# assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages"
# assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage"
# assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}"
#
# # verify each tool call
# for tc in tool_call_msg.tool_calls:
# assert tc["name"] == "roll_dice"
# assert tc["tool_call_id"].startswith("toolu_")
# assert "num_sides" in tc["arguments"]
#
# # assert tool returns match the tool calls
# assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages"
# assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
# assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}"
#
# # verify each tool return
# tool_call_ids = {tc["tool_call_id"] for tc in tool_call_msg.tool_calls}
# for tr in tool_return_msg.tool_returns:
# assert tr["type"] == "tool"
# assert tr["status"] == "success"
# assert tr["tool_call_id"] in tool_call_ids, f"tool_call_id {tr['tool_call_id']} not found in tool calls"
# assert int(tr["tool_return"]) >= 1 and int(tr["tool_return"]) <= 6
#
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
# Only token streaming is parametrized at the moment; the other send types
# ("step", "stream_steps", "stream_tokens_background") are left out for now.
@pytest.mark.parametrize("send_type", ["stream_tokens"])
@pytest.mark.asyncio(loop_scope="function")
async def test_parallel_tool_call_anthropic_streaming(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
    send_type: str,
) -> None:
    """Check that a parallel tool call request is preserved as three calls and returns.

    The test is skipped unless ``llm_config.model_endpoint_type`` is
    ``"anthropic"``. Otherwise the agent is switched to ``llm_config``,
    ``USER_MESSAGE_PARALLEL_TOOL_CALL`` is sent using the transport selected by
    ``send_type``, and the agent's stored messages are inspected:

    * the last ``ToolCallMessage`` must carry exactly three ``roll_dice`` calls
      whose ids start with ``"toolu_"`` and whose arguments mention
      ``num_sides``;
    * the last ``ToolReturnMessage`` must carry exactly three successful tool
      returns, each tied to one of those call ids and holding an integer in
      the range 1..6.
    """
    if llm_config.model_endpoint_type != "anthropic":
        pytest.skip("Parallel tool calling test only applies to Anthropic models.")

    agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    agent_id = agent_state.id

    # Deliver the prompt through the requested transport.
    if send_type == "step":
        await client.agents.messages.create(
            agent_id=agent_id,
            messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
        )
    elif send_type == "async":
        # NOTE(review): "async" is not among the parametrized send types, so
        # this branch is not exercised by the current parametrization.
        async_run = await client.agents.messages.create_async(
            agent_id=agent_id,
            messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
        )
        await wait_for_run_completion(client, async_run.id)
    else:
        stream = client.agents.messages.create_stream(
            agent_id=agent_id,
            messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
            stream_tokens=(send_type == "stream_tokens"),
            background=(send_type == "stream_tokens_background"),
        )
        # Consume the stream before inspecting the stored messages.
        await accumulate_chunks(stream)

    # Validate parallel tool call behavior in the preserved messages.
    stored_messages = await client.agents.messages.list(agent_id=agent_id)

    # Keep the last message of each kind; later matches overwrite earlier ones.
    call_message = None
    return_message = None
    for stored in stored_messages:
        if isinstance(stored, ToolCallMessage):
            call_message = stored
        elif isinstance(stored, ToolReturnMessage):
            return_message = stored

    # The tool call message must contain the three parallel calls.
    assert call_message is not None, "ToolCallMessage not found in preserved messages"
    assert hasattr(call_message, "tool_calls"), "tool_calls field not found in ToolCallMessage"
    assert len(call_message.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(call_message.tool_calls)}"

    for call in call_message.tool_calls:
        assert call["name"] == "roll_dice"
        assert call["tool_call_id"].startswith("toolu_")
        assert "num_sides" in call["arguments"]

    # The tool return message must answer each of those calls.
    assert return_message is not None, "ToolReturnMessage not found in preserved messages"
    assert hasattr(return_message, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
    assert len(return_message.tool_returns) == 3, f"Expected 3 tool returns, got {len(return_message.tool_returns)}"

    expected_ids = {call["tool_call_id"] for call in call_message.tool_calls}
    for result in return_message.tool_returns:
        assert result["type"] == "tool"
        assert result["status"] == "success"
        assert result["tool_call_id"] in expected_ids, f"tool_call_id {result['tool_call_id']} not found in tool calls"
        assert 1 <= int(result["tool_return"]) <= 6
@pytest.mark.parametrize(