test: Add basic parallel tool calling test to send_message v2 for anthropic [LET-5362] (#5355)
Add basic parallel tool calling test to send_message v2 for anthropic
This commit is contained in:
committed by
Caren Thomas
parent
681f4903fd
commit
10a3d86507
@@ -99,6 +99,13 @@ USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
# Single user turn that asks the agent to issue three roll_dice tool calls
# in parallel; consumed by the parallel tool calling test.
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content="This is an automated test message. Please call the roll_dice tool three times in parallel.",
        otid=USER_MESSAGE_OTID,
    ),
]
|
||||
|
||||
|
||||
def assert_greeting_response(
|
||||
@@ -515,6 +522,73 @@ async def test_greeting(
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
@pytest.mark.asyncio(loop_scope="function")
async def test_parallel_tool_call_anthropic_streaming(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """Verify that a token-streamed request persists three parallel ``roll_dice`` calls.

    The test is skipped unless ``llm_config.model_endpoint_type`` is
    ``"anthropic"``. It sends ``USER_MESSAGE_PARALLEL_TOOL_CALL`` through
    ``create_stream`` (``stream_tokens=True``, ``background=False``), waits for
    the run to complete, and then inspects the agent's persisted messages:

    * the last ``ToolCallMessage`` must carry exactly three tool calls, each
      named ``roll_dice``, with an id starting with ``toolu_`` and a
      ``num_sides`` argument;
    * the last ``ToolReturnMessage`` must carry exactly three successful
      returns whose ``tool_call_id`` values appear among the tool calls.

    Args:
        disable_e2b_api_key: Fixture requested for its side effect only; it is
            not referenced in the body.
        client: Async Letta client used for all API calls.
        agent_state: Agent under test; it is re-bound to the modified agent.
        llm_config: Parametrized model configuration applied to the agent.
    """
    if llm_config.model_endpoint_type != "anthropic":
        pytest.skip("Parallel tool calling test only applies to Anthropic models.")

    agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    stream = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
        stream_tokens=True,
        background=False,
    )
    messages = await accumulate_chunks(stream)
    run_id = messages[0].run_id if messages else None

    # Wait for the run to finish BEFORE reading persisted messages. Doing this
    # after the assertions (as before) neither guards the read below nor runs
    # at all when an assertion fails.
    if run_id:
        await wait_for_run_completion(client, run_id)

    # validate parallel tool call behavior in preserved messages
    preserved_messages = await client.agents.messages.list(agent_id=agent_state.id)

    # Keep the LAST message of each kind: later matches overwrite earlier ones.
    tool_call_msg = None
    tool_return_msg = None
    for msg in preserved_messages:
        if isinstance(msg, ToolCallMessage):
            tool_call_msg = msg
        elif isinstance(msg, ToolReturnMessage):
            tool_return_msg = msg

    # assert parallel tool calls were made
    assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages"
    assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage"
    assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}"

    # verify each tool call
    # NOTE(review): entries are accessed like mappings here; confirm the client
    # returns dict-like tool calls rather than attribute-style objects.
    for tc in tool_call_msg.tool_calls:
        assert tc["name"] == "roll_dice"
        assert tc["tool_call_id"].startswith("toolu_")
        assert "num_sides" in tc["arguments"]

    # assert tool returns match the tool calls
    assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages"
    assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
    assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}"

    # verify each tool return
    tool_call_ids = {tc["tool_call_id"] for tc in tool_call_msg.tool_calls}
    for tr in tool_return_msg.tool_returns:
        assert tr["type"] == "tool"
        assert tr["status"] == "success"
        assert tr["tool_call_id"] in tool_call_ids, f"tool_call_id {tr['tool_call_id']} not found in tool calls"
        # NOTE(review): the 1..6 bound presumably assumes a six-sided roll even
        # though the model supplies num_sides itself; confirm against roll_dice.
        assert 1 <= int(tr["tool_return"]) <= 6
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
|
||||
@pytest.mark.parametrize(
|
||||
["send_type", "cancellation"],
|
||||
list(
|
||||
|
||||
Reference in New Issue
Block a user