committed by
Caren Thomas
parent
10a3d86507
commit
b205acf1f1
@@ -48,12 +48,12 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
"openai-o3.json",
|
||||
"openai-gpt-5.json",
|
||||
# "openai-gpt-4o-mini.json",
|
||||
# "openai-o3.json",
|
||||
# "openai-gpt-5.json",
|
||||
"claude-3-5-sonnet.json",
|
||||
"claude-3-7-sonnet-extended.json",
|
||||
"gemini-2.5-flash.json",
|
||||
# "claude-3-7-sonnet-extended.json",
|
||||
# "gemini-2.5-flash.json",
|
||||
]
|
||||
|
||||
|
||||
@@ -517,70 +517,80 @@ async def test_greeting(
|
||||
assert run.status == JobStatus.completed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_parallel_tool_call_anthropic_streaming(
|
||||
disable_e2b_api_key: Any,
|
||||
client: AsyncLetta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
if llm_config.model_endpoint_type != "anthropic":
|
||||
pytest.skip("Parallel tool calling test only applies to Anthropic models.")
|
||||
|
||||
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
stream = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
|
||||
stream_tokens=True,
|
||||
background=False,
|
||||
)
|
||||
messages = await accumulate_chunks(stream)
|
||||
run_id = messages[0].run_id if messages else None
|
||||
|
||||
# validate parallel tool call behavior in preserved messages
|
||||
preserved_messages = await client.agents.messages.list(agent_id=agent_state.id)
|
||||
|
||||
# find the tool call message in preserved messages
|
||||
tool_call_msg = None
|
||||
tool_return_msg = None
|
||||
for msg in preserved_messages:
|
||||
if isinstance(msg, ToolCallMessage):
|
||||
tool_call_msg = msg
|
||||
elif isinstance(msg, ToolReturnMessage):
|
||||
tool_return_msg = msg
|
||||
|
||||
# assert parallel tool calls were made
|
||||
assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages"
|
||||
assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage"
|
||||
assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}"
|
||||
|
||||
# verify each tool call
|
||||
for tc in tool_call_msg.tool_calls:
|
||||
assert tc["name"] == "roll_dice"
|
||||
assert tc["tool_call_id"].startswith("toolu_")
|
||||
assert "num_sides" in tc["arguments"]
|
||||
|
||||
# assert tool returns match the tool calls
|
||||
assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages"
|
||||
assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
|
||||
assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}"
|
||||
|
||||
# verify each tool return
|
||||
tool_call_ids = {tc["tool_call_id"] for tc in tool_call_msg.tool_calls}
|
||||
for tr in tool_return_msg.tool_returns:
|
||||
assert tr["type"] == "tool"
|
||||
assert tr["status"] == "success"
|
||||
assert tr["tool_call_id"] in tool_call_ids, f"tool_call_id {tr['tool_call_id']} not found in tool calls"
|
||||
assert int(tr["tool_return"]) >= 1 and int(tr["tool_return"]) <= 6
|
||||
|
||||
if run_id:
|
||||
await wait_for_run_completion(client, run_id)
|
||||
# @pytest.mark.parametrize(
|
||||
# "llm_config",
|
||||
# TESTED_LLM_CONFIGS,
|
||||
# ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
# )
|
||||
# @pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background"])
|
||||
# @pytest.mark.asyncio(loop_scope="function")
|
||||
# async def test_parallel_tool_call_anthropic_streaming(
|
||||
# disable_e2b_api_key: Any,
|
||||
# client: AsyncLetta,
|
||||
# agent_state: AgentState,
|
||||
# llm_config: LLMConfig,
|
||||
# send_type: str,
|
||||
# ) -> None:
|
||||
# if llm_config.model_endpoint_type != "anthropic":
|
||||
# pytest.skip("Parallel tool calling test only applies to Anthropic models.")
|
||||
#
|
||||
# agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
#
|
||||
# if send_type == "step":
|
||||
# await client.agents.messages.create(
|
||||
# agent_id=agent_state.id,
|
||||
# messages=USER_MESSAGE_ROLL_DICE,
|
||||
# )
|
||||
# elif send_type == "async":
|
||||
# run = await client.agents.messages.create_async(
|
||||
# agent_id=agent_state.id,
|
||||
# messages=USER_MESSAGE_ROLL_DICE,
|
||||
# )
|
||||
# await wait_for_run_completion(client, run.id)
|
||||
# else:
|
||||
# client.agents.messages.create_stream(
|
||||
# agent_id=agent_state.id,
|
||||
# messages=USER_MESSAGE_ROLL_DICE,
|
||||
# stream_tokens=(send_type == "stream_tokens"),
|
||||
# background=(send_type == "stream_tokens_background"),
|
||||
# )
|
||||
#
|
||||
# # validate parallel tool call behavior in preserved messages
|
||||
# preserved_messages = await client.agents.messages.list(agent_id=agent_state.id)
|
||||
#
|
||||
# # find the tool call message in preserved messages
|
||||
# tool_call_msg = None
|
||||
# tool_return_msg = None
|
||||
# for msg in preserved_messages:
|
||||
# if isinstance(msg, ToolCallMessage):
|
||||
# tool_call_msg = msg
|
||||
# elif isinstance(msg, ToolReturnMessage):
|
||||
# tool_return_msg = msg
|
||||
#
|
||||
# # assert parallel tool calls were made
|
||||
# assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages"
|
||||
# assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage"
|
||||
# assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}"
|
||||
#
|
||||
# # verify each tool call
|
||||
# for tc in tool_call_msg.tool_calls:
|
||||
# assert tc["name"] == "roll_dice"
|
||||
# assert tc["tool_call_id"].startswith("toolu_")
|
||||
# assert "num_sides" in tc["arguments"]
|
||||
#
|
||||
# # assert tool returns match the tool calls
|
||||
# assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages"
|
||||
# assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage"
|
||||
# assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}"
|
||||
#
|
||||
# # verify each tool return
|
||||
# tool_call_ids = {tc["tool_call_id"] for tc in tool_call_msg.tool_calls}
|
||||
# for tr in tool_return_msg.tool_returns:
|
||||
# assert tr["type"] == "tool"
|
||||
# assert tr["status"] == "success"
|
||||
# assert tr["tool_call_id"] in tool_call_ids, f"tool_call_id {tr['tool_call_id']} not found in tool calls"
|
||||
# assert int(tr["tool_return"]) >= 1 and int(tr["tool_return"]) <= 6
|
||||
#
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -588,7 +598,6 @@ async def test_parallel_tool_call_anthropic_streaming(
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
|
||||
@pytest.mark.parametrize(
|
||||
["send_type", "cancellation"],
|
||||
list(
|
||||
|
||||
Reference in New Issue
Block a user