diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py
index 67755ac4..c83a9f5e 100644
--- a/tests/integration_test_send_message_v2.py
+++ b/tests/integration_test_send_message_v2.py
@@ -517,80 +517,92 @@ async def test_greeting(
     assert run.status == JobStatus.completed
 
 
-# @pytest.mark.parametrize(
-#     "llm_config",
-#     TESTED_LLM_CONFIGS,
-#     ids=[c.model for c in TESTED_LLM_CONFIGS],
-# )
-# @pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background"])
-# @pytest.mark.asyncio(loop_scope="function")
-# async def test_parallel_tool_call_anthropic_streaming(
-#     disable_e2b_api_key: Any,
-#     client: AsyncLetta,
-#     agent_state: AgentState,
-#     llm_config: LLMConfig,
-#     send_type: str,
-# ) -> None:
-#     if llm_config.model_endpoint_type != "anthropic":
-#         pytest.skip("Parallel tool calling test only applies to Anthropic models.")
-#
-#     agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
-#
-#     if send_type == "step":
-#         await client.agents.messages.create(
-#             agent_id=agent_state.id,
-#             messages=USER_MESSAGE_ROLL_DICE,
-#         )
-#     elif send_type == "async":
-#         run = await client.agents.messages.create_async(
-#             agent_id=agent_state.id,
-#             messages=USER_MESSAGE_ROLL_DICE,
-#         )
-#         await wait_for_run_completion(client, run.id)
-#     else:
-#         client.agents.messages.create_stream(
-#             agent_id=agent_state.id,
-#             messages=USER_MESSAGE_ROLL_DICE,
-#             stream_tokens=(send_type == "stream_tokens"),
-#             background=(send_type == "stream_tokens_background"),
-#         )
-#
-#     # validate parallel tool call behavior in preserved messages
-#     preserved_messages = await client.agents.messages.list(agent_id=agent_state.id)
-#
-#     # find the tool call message in preserved messages
-#     tool_call_msg = None
-#     tool_return_msg = None
-#     for msg in preserved_messages:
-#         if isinstance(msg, ToolCallMessage):
-#             tool_call_msg = msg
-#         elif isinstance(msg, ToolReturnMessage):
-#             tool_return_msg = msg
-#
-#     # assert parallel tool calls were made
-#     assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages"
-#     assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage"
-#     assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}"
-#
-#     # verify each tool call
-#     for tc in tool_call_msg.tool_calls:
-#         assert tc["name"] == "roll_dice"
-#         assert tc["tool_call_id"].startswith("toolu_")
-#         assert "num_sides" in tc["arguments"]
-#
-#     # assert tool returns match the tool calls
-#     assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages"
-#     assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
-#     assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}"
-#
-#     # verify each tool return
-#     tool_call_ids = {tc["tool_call_id"] for tc in tool_call_msg.tool_calls}
-#     for tr in tool_return_msg.tool_returns:
-#         assert tr["type"] == "tool"
-#         assert tr["status"] == "success"
-#         assert tr["tool_call_id"] in tool_call_ids, f"tool_call_id {tr['tool_call_id']} not found in tool calls"
-#         assert int(tr["tool_return"]) >= 1 and int(tr["tool_return"]) <= 6
-#
+@pytest.mark.parametrize(
+    "llm_config",
+    TESTED_LLM_CONFIGS,
+    ids=[c.model for c in TESTED_LLM_CONFIGS],
+)
+@pytest.mark.parametrize("send_type", ["stream_tokens"])  # ["step", "stream_steps", "stream_tokens", "stream_tokens_background"])
+@pytest.mark.asyncio(loop_scope="function")
+async def test_parallel_tool_call_anthropic_streaming(
+    disable_e2b_api_key: Any,
+    client: AsyncLetta,
+    agent_state: AgentState,
+    llm_config: LLMConfig,
+    send_type: str,
+) -> None:
+    """Check that an Anthropic agent makes and records three parallel tool calls.
+
+    Sends USER_MESSAGE_PARALLEL_TOOL_CALL using the transport selected by
+    ``send_type``, then inspects the messages listed for the agent: the
+    tool-call message must carry three ``roll_dice`` calls and the tool-return
+    message three successful returns with matching ids, each an integer in 1..6.
+    """
+    if llm_config.model_endpoint_type != "anthropic":
+        pytest.skip("Parallel tool calling test only applies to Anthropic models.")
+
+    agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
+
+    if send_type == "step":
+        await client.agents.messages.create(
+            agent_id=agent_state.id,
+            messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
+        )
+    elif send_type == "async":
+        # NOTE(review): "async" is not among the parametrized send_type values,
+        # so this branch is not exercised by the current parametrization.
+        run = await client.agents.messages.create_async(
+            agent_id=agent_state.id,
+            messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
+        )
+        await wait_for_run_completion(client, run.id)
+    else:
+        response = client.agents.messages.create_stream(
+            agent_id=agent_state.id,
+            messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
+            stream_tokens=(send_type == "stream_tokens"),
+            background=(send_type == "stream_tokens_background"),
+        )
+        # accumulate_chunks is defined outside this hunk; presumably it consumes
+        # the stream so the run has finished before messages are listed -- confirm.
+        await accumulate_chunks(response)
+
+    # validate parallel tool call behavior in preserved messages
+    preserved_messages = await client.agents.messages.list(agent_id=agent_state.id)
+
+    # find the tool call message in preserved messages
+    # (when several messages of one type exist, the last one iterated is kept)
+    tool_call_msg = None
+    tool_return_msg = None
+    for msg in preserved_messages:
+        if isinstance(msg, ToolCallMessage):
+            tool_call_msg = msg
+        elif isinstance(msg, ToolReturnMessage):
+            tool_return_msg = msg
+
+    # assert parallel tool calls were made
+    assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages"
+    assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage"
+    assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}"
+
+    # verify each tool call
+    for tc in tool_call_msg.tool_calls:
+        assert tc["name"] == "roll_dice"
+        assert tc["tool_call_id"].startswith("toolu_")
+        assert "num_sides" in tc["arguments"]
+
+    # assert tool returns match the tool calls
+    assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages"
+    assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
+    assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}"
+
+    # verify each tool return
+    tool_call_ids = {tc["tool_call_id"] for tc in tool_call_msg.tool_calls}
+    for tr in tool_return_msg.tool_returns:
+        assert tr["type"] == "tool"
+        assert tr["status"] == "success"
+        assert tr["tool_call_id"] in tool_call_ids, f"tool_call_id {tr['tool_call_id']} not found in tool calls"
+        assert 1 <= int(tr["tool_return"]) <= 6
 
 
 @pytest.mark.parametrize(