diff --git a/fern/openapi.json b/fern/openapi.json index f27d49be..912bd408 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -23252,7 +23252,44 @@ } ], "title": "Tool Call", - "description": "The tool call that has been requested by the llm to run" + "description": "The tool call that has been requested by the llm to run", + "deprecated": true + }, + "requested_tool_calls": { + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/ToolCall" + }, + "type": "array" + }, + { + "$ref": "#/components/schemas/ToolCallDelta" + }, + { + "type": "null" + } + ], + "title": "Requested Tool Calls", + "description": "The tool calls that have been requested by the llm to run, which are pending approval" + }, + "allowed_tool_calls": { + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/ToolCall" + }, + "type": "array" + }, + { + "$ref": "#/components/schemas/ToolCallDelta" + }, + { + "type": "null" + } + ], + "title": "Allowed Tool Calls", + "description": "Any tool calls returned by the llm during the same turn that do not require approvals, which will execute once this approval request is handled regardless of approval or denial. 
Only used when parallel_tool_calls is enabled" } }, "type": "object", diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index d8952c3f..7d81a6af 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -137,7 +137,9 @@ async def _prepare_in_context_messages_async( def validate_approval_tool_call_ids(approval_request_message: Message, approval_response_message: ApprovalCreate): approval_requests = approval_request_message.tool_calls - approval_request_tool_call_ids = [approval_request.id for approval_request in approval_requests] + approval_request_tool_call_ids = [approval_request.id for approval_request in approval_requests if approval_request.requires_approval] + if not approval_request_tool_call_ids and len(approval_request_message.tool_calls) == 1: + approval_request_tool_call_ids = [approval_request_message.tool_calls[0].id] approval_responses = approval_response_message.approvals approval_response_tool_call_ids = [approval_response.tool_call_id for approval_response in approval_responses] diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index d159a193..5d5f7208 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -315,12 +315,15 @@ class LettaAgentV3(LettaAgentV2): # Get tool calls that are pending backfill_tool_call_id = approval_request.tool_calls[0].id # legacy case - pending_tool_calls = { - backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id: a + approved_tool_call_ids = [ + backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id for a in approval_response.approvals if isinstance(a, ApprovalReturn) and a.approve - } - tool_calls = [t for t in approval_request.tool_calls if t.id in pending_tool_calls] + ] + pending_tool_call_ids = [ + t.id for t in approval_request.tool_calls if not t.requires_approval and t.id not in approved_tool_call_ids + ] + tool_calls = [t for t in approval_request.tool_calls if t.id in 
approved_tool_call_ids + pending_tool_call_ids] # Get tool calls that were denied denies = {d.tool_call_id: d for d in approval_response.approvals if isinstance(d, ApprovalReturn) and not d.approve} @@ -363,20 +366,19 @@ class LettaAgentV3(LettaAgentV2): # Enable Anthropic parallel tool use when no tool rules are attached try: if self.agent_state.llm_config.model_endpoint_type in ["anthropic", "bedrock"]: - no_tool_rules = not self.agent_state.tool_rules or len(self.agent_state.tool_rules) == 0 - requires_approval = self.tool_rules_solver.get_requires_approval_tools( - set([t["name"] for t in valid_tools]) + no_tool_rules = ( + not self.agent_state.tool_rules + or len([t for t in self.agent_state.tool_rules if t.type != "requires_approval"]) == 0 ) - has_approval_tools = len(requires_approval) > 0 if ( isinstance(request_data.get("tool_choice"), dict) and "disable_parallel_tool_use" in request_data["tool_choice"] ): - # Gate parallel tool use on both: no tool rules and no approval-required tools and toggled on - if no_tool_rules and not has_approval_tools and self.agent_state.llm_config.parallel_tool_calls: + # Gate parallel tool use on both: no tool rules and toggled on + if no_tool_rules and self.agent_state.llm_config.parallel_tool_calls: request_data["tool_choice"]["disable_parallel_tool_use"] = False else: - # Explicitly disable when approvals exist (TODO support later) or tool rules present or llm_config toggled off + # Explicitly disable when tool rules present or llm_config toggled off request_data["tool_choice"]["disable_parallel_tool_use"] = True except Exception: # if this fails, we simply don't enable parallel tool use @@ -678,127 +680,71 @@ class LettaAgentV3(LettaAgentV2): return persisted_messages, continue_stepping, stop_reason # 2. 
Check whether tool call requires approval - tool_names = [tc.function.name for tc in tool_calls] - requires_approval = ( - not is_approval_response - and tool_names - and any(tool_rules_solver.is_requires_approval_tool(tool_name) for tool_name in tool_names) - ) - if requires_approval: - approval_message = create_approval_request_message_from_llm_response( - agent_id=agent_state.id, - model=agent_state.llm_config.model, - tool_calls=tool_calls, - reasoning_content=content, - pre_computed_assistant_message_id=pre_computed_assistant_message_id, - step_id=step_id, - run_id=run_id, - ) - messages_to_persist = (initial_messages or []) + [approval_message] - - for message in messages_to_persist: - if message.run_id is None: - message.run_id = run_id - - persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, - actor=self.actor, - run_id=run_id, - project_id=agent_state.project_id, - template_id=agent_state.template_id, - ) - return persisted_messages, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value) - - if tool_returns: - assert len(tool_returns) == 1, "Only one tool return is supported" - tool_return = tool_returns[0] - continue_stepping = True - stop_reason = None - tool_call_messages = [ - Message( - role=MessageRole.tool, - content=[TextContent(text=tool_return.func_response)], + if not is_approval_response: + requires_approval = False + for tool_call in tool_calls: + tool_call.requires_approval = tool_rules_solver.is_requires_approval_tool(tool_call.function.name) + requires_approval = requires_approval or tool_call.requires_approval + if requires_approval: + approval_message = create_approval_request_message_from_llm_response( agent_id=agent_state.id, model=agent_state.llm_config.model, - tool_calls=[], - tool_call_id=tool_return.tool_call_id, - created_at=get_utc_time(), - tool_returns=[tool_return], - run_id=run_id, + tool_calls=tool_calls, + reasoning_content=content, + 
pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, + run_id=run_id, ) - ] - messages_to_persist = (initial_messages or []) + tool_call_messages - for message in messages_to_persist: - if message.run_id is None: - message.run_id = run_id + messages_to_persist = (initial_messages or []) + [approval_message] - persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, - actor=self.actor, - run_id=run_id, - project_id=agent_state.project_id, - template_id=agent_state.template_id, - ) - return persisted_messages, continue_stepping, stop_reason + for message in messages_to_persist: + if message.run_id is None: + message.run_id = run_id - # Handle denial case first (special case that bypasses normal flow) - if tool_call_denials: - assert len(tool_call_denials) == 1, "Only one tool call denial is supported" - tool_call_denial = tool_call_denials[0] + persisted_messages = await self.message_manager.create_many_messages_async( + messages_to_persist, + actor=self.actor, + run_id=run_id, + project_id=agent_state.project_id, + template_id=agent_state.template_id, + ) + return persisted_messages, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value) + + result_tool_returns = [] + + # 3. Handle client side tool execution + if tool_returns: continue_stepping = True stop_reason = None - tool_call_messages = create_letta_messages_from_llm_response( - agent_id=agent_state.id, - model=agent_state.llm_config.model, - function_name=tool_call_denial.function.name, - function_arguments={}, - tool_execution_result=ToolExecutionResult(status="error"), - tool_call_id=tool_call_denial.id or f"call_{uuid.uuid4().hex[:8]}", - function_response=f"Error: request to call tool denied. 
User reason: {tool_call_denial.reason}", - timezone=agent_state.timezone, - continue_stepping=continue_stepping, - heartbeat_reason=f"{NON_USER_MSG_PREFIX}Continuing: user denied request to call tool.", - reasoning_content=None, - pre_computed_assistant_message_id=None, - step_id=step_id, - run_id=run_id, - is_approval_response=True, - force_set_request_heartbeat=False, - add_heartbeat_on_continue=False, - ) - messages_to_persist = (initial_messages or []) + tool_call_messages + result_tool_returns = tool_returns - # Set run_id on all messages before persisting - for message in messages_to_persist: - if message.run_id is None: - message.run_id = run_id + # 4. Handle denial cases + if tool_call_denials: + for tool_call_denial in tool_call_denials: + tool_call_id = tool_call_denial.id or f"call_{uuid.uuid4().hex[:8]}" + packaged_function_response = package_function_response( + was_success=False, + response_string=f"Error: request to call tool denied. User reason: {tool_call_denial.reason}", + timezone=agent_state.timezone, + ) + tool_return = ToolReturn( + tool_call_id=tool_call_id, + func_response=packaged_function_response, + status="error", + ) + result_tool_returns.append(tool_return) - persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, - actor=self.actor, - run_id=run_id, - project_id=agent_state.project_id, - template_id=agent_state.template_id, - ) - return persisted_messages, continue_stepping, stop_reason + # 5. Unified tool execution path (works for both single and multiple tools) - # 4. Unified tool execution path (works for both single and multiple tools) - - # 4a. TODO - - # 4b. Validate parallel tool calling constraints + # 5a. 
Validate parallel tool calling constraints if len(tool_calls) > 1: # No parallel tool calls with tool rules - if agent_state.tool_rules and len(agent_state.tool_rules) > 0: + if agent_state.tool_rules and len([r for r in agent_state.tool_rules if r.type != "requires_approval"]) > 0: raise ValueError( "Parallel tool calling is not allowed when tool rules are present. Disable tool rules to use parallel tool calls." ) - # No parallel tool calls with approval-required tools - if any(tool_rules_solver.is_requires_approval_tool(tc.function.name) for tc in tool_calls): - raise ValueError("Parallel tool calling is not allowed when any tool requires approval.") - # 4c. Prepare execution specs for all tools + # 5b. Prepare execution specs for all tools exec_specs = [] for tc in tool_calls: call_id = tc.id or f"call_{uuid.uuid4().hex[:8]}" @@ -849,7 +795,7 @@ class LettaAgentV3(LettaAgentV2): } ) - # 4d. Execute tools (sequentially for single, parallel for multiple) + # 5c. Execute tools (sequentially for single, parallel for multiple) if len(exec_specs) == 1: # Single tool - execute directly without asyncio.gather overhead spec = exec_specs[0] @@ -892,11 +838,11 @@ class LettaAgentV3(LettaAgentV2): results = await asyncio.gather(*[_run_one(s) for s in exec_specs]) - # Update metrics with execution time + # 5d. Update metrics with execution time if step_metrics is not None and results: step_metrics.tool_execution_ns = max(dt for _, dt in results) - # 4e. Process results and compute function responses + # 5e. Process results and compute function responses function_responses: list[Optional[str]] = [] persisted_continue_flags: list[bool] = [] persisted_stop_reasons: list[LettaStopReason | None] = [] @@ -942,7 +888,7 @@ class LettaAgentV3(LettaAgentV2): persisted_continue_flags.append(cont) persisted_stop_reasons.append(sr) - # 4f. Create messages using parallel message creation (works for both single and multi) + # 5f. 
Create messages using parallel message creation (works for both single and multi) tool_call_specs = [{"name": s["name"], "arguments": s["args"], "id": s["id"]} for s in exec_specs] tool_execution_results = [res for (res, _) in results] @@ -959,6 +905,7 @@ class LettaAgentV3(LettaAgentV2): reasoning_content=content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, is_approval_response=is_approval_response, + tool_returns=result_tool_returns, ) messages_to_persist: list[Message] = (initial_messages or []) + parallel_messages @@ -977,10 +924,11 @@ class LettaAgentV3(LettaAgentV2): template_id=agent_state.template_id, ) - # 4g. Aggregate continuation decisions + # 5g. Aggregate continuation decisions # For multiple tools: continue if ANY says continue, use last non-None stop_reason # For single tool: use its decision directly aggregate_continue = any(persisted_continue_flags) if persisted_continue_flags else False + aggregate_continue = aggregate_continue or tool_call_denials or tool_returns # continue if any tool call was denied or returned aggregate_stop_reason = None for sr in persisted_stop_reasons: if sr is not None: diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index de170708..6429843b 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -302,7 +302,16 @@ class ApprovalRequestMessage(LettaMessage): message_type: Literal[MessageType.approval_request_message] = Field( default=MessageType.approval_request_message, description="The type of the message." 
def _convert_approval_request_message(self) -> ApprovalRequestMessage:
    """Convert this approval-request message into an ApprovalRequestMessage.

    The stored tool calls are split into two groups:

    * ``requested_tool_calls`` -- the calls waiting on an approval decision.
    * ``allowed_tool_calls`` -- calls from the same message that are not
      flagged ``requires_approval``.

    A message that carries exactly one tool call always reports it as the
    requested call, regardless of its ``requires_approval`` flag.

    Returns:
        ApprovalRequestMessage: carries the single ``tool_call`` field (the
        first requested call, kept for backwards compatibility) together
        with both lists.

    Raises:
        ValueError: if no tool call can be identified as pending approval, in
            which case the ``tool_call`` field could not be populated.
    """
    tool_calls = self.tool_calls or []
    allowed_tool_calls = []

    # The three cases are mutually exclusive, so they must be chained. As two
    # separate statements, the empty case fell through to the final `else` and
    # the fallback value was overwritten with an empty list.
    if not tool_calls:
        # NOTE(review): `tool_call` is not declared anywhere in the visible
        # code; read defensively until the attribute is confirmed to exist.
        legacy_tool_call = getattr(self, "tool_call", None)
        requested_tool_calls = [legacy_tool_call] if legacy_tool_call else []
    elif len(tool_calls) == 1:
        requested_tool_calls = list(tool_calls)
    else:
        requested_tool_calls = [t for t in tool_calls if t.requires_approval]
        allowed_tool_calls = [t for t in tool_calls if not t.requires_approval]

    if not requested_tool_calls:
        # Fail with a descriptive error instead of an IndexError on `[0]` below.
        raise ValueError(f"Approval request message {self.id} has no tool call pending approval")

    def _convert_tool_call(tool_call):
        # Map a stored tool call (id + function name/arguments) to the ToolCall payload.
        return ToolCall(
            name=tool_call.function.name,
            arguments=tool_call.function.arguments,
            tool_call_id=tool_call.id,
        )

    return ApprovalRequestMessage(
        id=self.id,
        date=self.created_at,
        otid=self.otid,
        sender_id=self.sender_id,
        step_id=self.step_id,
        run_id=self.run_id,
        tool_call=_convert_tool_call(requested_tool_calls[0]),  # backwards compatibility
        requested_tool_calls=[_convert_tool_call(tc) for tc in requested_tool_calls],
        allowed_tool_calls=[_convert_tool_call(tc) for tc in allowed_tool_calls],
        name=self.name,
    )
def roll_dice_tool(num_sides: int) -> str:
    """
    A tool that returns a random number between 1 and num_sides.
    Args:
        num_sides (int): The number of sides on the die.
    Returns:
        str: The random number between 1 and num_sides.
    """
    # The import stays inside the function body, as in the original.
    from random import randint

    rolled = randint(1, num_sides)
    return str(rolled)
def test_parallel_tool_calling(
    client: Letta,
    agent: AgentState,
) -> None:
    """Exercise one approval round-trip for a turn that contains several tool calls.

    The prompt asks for three ``get_secret_code_tool`` calls and one
    ``roll_dice_tool`` call. The test asserts that the three secret-code calls
    are reported as ``requested_tool_calls`` and the dice roll as the only
    ``allowed_tool_calls`` entry. It then answers with one approval, one denial
    and one client-side tool return, and checks that the first response message
    is a tool-return message holding a result for all four calls.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
    )

    messages = response.messages

    assert messages is not None
    assert len(messages) == 3
    assert messages[0].message_type == "reasoning_message"
    assert messages[1].message_type == "assistant_message"
    assert messages[2].message_type == "approval_request_message"

    approval_request = messages[2]
    assert approval_request.tool_call is not None
    assert approval_request.tool_call.name in ("get_secret_code_tool", "roll_dice_tool")

    # The three requested calls become the approve / deny / client-side-return targets below.
    requested = approval_request.requested_tool_calls
    assert len(requested) == 3
    assert requested[0]["name"] == "get_secret_code_tool"
    assert "hello world" in requested[0]["arguments"]
    approve_tool_call_id = requested[0]["tool_call_id"]
    assert requested[1]["name"] == "get_secret_code_tool"
    assert "hello letta" in requested[1]["arguments"]
    deny_tool_call_id = requested[1]["tool_call_id"]
    assert requested[2]["name"] == "get_secret_code_tool"
    assert "hello test" in requested[2]["arguments"]
    client_side_tool_call_id = requested[2]["tool_call_id"]

    allowed = approval_request.allowed_tool_calls
    assert len(allowed) == 1
    assert allowed[0]["name"] == "roll_dice_tool"
    # The prompt asks for a 16-sided die; checking only "6" would also accept 6, 26, 60, ...
    assert "16" in allowed[0]["arguments"]
    dice_tool_call_id = allowed[0]["tool_call_id"]

    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=False,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                approvals=[
                    {
                        "type": "approval",
                        "approve": True,
                        "tool_call_id": approve_tool_call_id,
                    },
                    {
                        "type": "approval",
                        "approve": False,
                        "tool_call_id": deny_tool_call_id,
                    },
                    {
                        "type": "tool",
                        "tool_call_id": client_side_tool_call_id,
                        "tool_return": SECRET_CODE,
                        "status": "success",
                    },
                ],
            ),
        ],
    )

    messages = response.messages

    assert messages is not None
    # The follow-up may stop after the tool returns, add reasoning + assistant text,
    # or add reasoning + another tool call and its return.
    assert len(messages) in (1, 3, 4)
    assert messages[0].message_type == "tool_return_message"
    assert len(messages[0].tool_returns) == 4
    for tool_return in messages[0].tool_returns:
        call_id = tool_return["tool_call_id"]
        if call_id == approve_tool_call_id:
            assert tool_return["status"] == "success"
        elif call_id == deny_tool_call_id:
            assert tool_return["status"] == "error"
        elif call_id == client_side_tool_call_id:
            assert tool_return["status"] == "success"
            assert tool_return["tool_return"] == SECRET_CODE
        else:
            # The only remaining call is the dice roll, which was not sent for approval.
            assert call_id == dice_tool_call_id
            assert tool_return["status"] == "success"
    if len(messages) == 3:
        assert messages[1].message_type == "reasoning_message"
        assert messages[2].message_type == "assistant_message"
    elif len(messages) == 4:
        assert messages[1].message_type == "reasoning_message"
        assert messages[2].message_type == "tool_call_message"
        assert messages[3].message_type == "tool_return_message"