diff --git a/fern/openapi.json b/fern/openapi.json index c1861dd9..e0ab9604 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -23210,7 +23210,7 @@ "description": "The tool call that has been requested by the llm to run", "deprecated": true }, - "requested_tool_calls": { + "tool_calls": { "anyOf": [ { "items": { @@ -23225,26 +23225,8 @@ "type": "null" } ], - "title": "Requested Tool Calls", + "title": "Tool Calls", "description": "The tool calls that have been requested by the llm to run, which are pending approval" - }, - "allowed_tool_calls": { - "anyOf": [ - { - "items": { - "$ref": "#/components/schemas/ToolCall" - }, - "type": "array" - }, - { - "$ref": "#/components/schemas/ToolCallDelta" - }, - { - "type": "null" - } - ], - "title": "Allowed Tool Calls", - "description": "Any tool calls returned by the llm during the same turn that do not require approvals, which will execute once this approval request is handled regardless of approval or denial. Only used when parallel_tool_calls is enabled" } }, "type": "object", diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index fa6e22cc..c5050b25 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -137,9 +137,7 @@ async def _prepare_in_context_messages_async( def validate_approval_tool_call_ids(approval_request_message: Message, approval_response_message: ApprovalCreate): approval_requests = approval_request_message.tool_calls - approval_request_tool_call_ids = [approval_request.id for approval_request in approval_requests if approval_request.requires_approval] - if not approval_request_tool_call_ids and len(approval_request_message.tool_calls) == 1: - approval_request_tool_call_ids = [approval_request_message.tool_calls[0].id] + approval_request_tool_call_ids = [approval_request.id for approval_request in approval_requests] approval_responses = approval_response_message.approvals approval_response_tool_call_ids = [approval_response.tool_call_id for approval_response in approval_responses] @@ -418,3 +416,19 @@ def _maybe_get_approval_messages(messages: list[Message]) -> Tuple[Message | Non if maybe_approval_request.role == "approval" and maybe_approval_response.role == "approval": return maybe_approval_request, maybe_approval_response return None, None + + +def _maybe_get_pending_tool_call_message(messages: list[Message]) -> Message | None: + """ + Only used in the case where hitl is invoked with parallel tool calling, + where agent calls some tools that require approval, and others that don't. + """ + if len(messages) >= 3: + maybe_tool_call_message = messages[-3] + if ( + maybe_tool_call_message.role == "assistant" + and maybe_tool_call_message.tool_calls is not None + and len(maybe_tool_call_message.tool_calls) > 0 + ): + return maybe_tool_call_message + return None diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 81f38488..0860a9dc 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -1700,15 +1700,17 @@ class LettaAgent(BaseAgent): ) if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name): tool_args[REQUEST_HEARTBEAT_PARAM] = request_heartbeat - approval_message = create_approval_request_message_from_llm_response( + approval_messages = create_approval_request_message_from_llm_response( agent_id=agent_state.id, model=agent_state.llm_config.model, - tool_calls=[ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))], + requested_tool_calls=[ + ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args))) + ], reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, ) - messages_to_persist = (initial_messages or []) + [approval_message] + messages_to_persist = (initial_messages or []) + approval_messages continue_stepping = False stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value) else: diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index fea764f3..3b9281d1 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -935,16 +935,18 @@ class LettaAgentV2(BaseAgentV2): if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name): tool_args[REQUEST_HEARTBEAT_PARAM] = request_heartbeat - approval_message = create_approval_request_message_from_llm_response( + approval_messages = create_approval_request_message_from_llm_response( agent_id=agent_state.id, model=agent_state.llm_config.model, - tool_calls=[ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))], + requested_tool_calls=[ + ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args))) + ], reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, run_id=run_id, ) - messages_to_persist = (initial_messages or []) + [approval_message] + messages_to_persist = (initial_messages or []) + approval_messages continue_stepping = False stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value) else: diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 81a54e46..7dc28402 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -12,6 +12,7 @@ from letta.agents.helpers import ( _build_rule_violation_result, _load_last_function_response, _maybe_get_approval_messages, + _maybe_get_pending_tool_call_message, _prepare_in_context_messages_no_persist_async, _safe_load_tool_call_str, generate_step_id, @@ -315,15 +316,15 @@ class LettaAgentV3(LettaAgentV2): # Get tool calls that are pending backfill_tool_call_id = approval_request.tool_calls[0].id # legacy case - approved_tool_call_ids = [ + approved_tool_call_ids = { backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id for a in approval_response.approvals if isinstance(a, ApprovalReturn) and a.approve - ] - pending_tool_call_ids = [ - t.id for t in approval_request.tool_calls if not t.requires_approval and t.id not in approved_tool_call_ids - ] - tool_calls = [t for t in approval_request.tool_calls if t.id in approved_tool_call_ids + pending_tool_call_ids] + } + tool_calls = [tool_call for tool_call in approval_request.tool_calls if tool_call.id in approved_tool_call_ids] + pending_tool_call_message = _maybe_get_pending_tool_call_message(messages) + if pending_tool_call_message: + tool_calls.extend(pending_tool_call_message.tool_calls) # Get tool calls that were denied denies = {d.tool_call_id: d for d in approval_response.approvals if isinstance(d, ApprovalReturn) and not d.approve} @@ -681,21 +682,20 @@ class LettaAgentV3(LettaAgentV2): # 2. Check whether tool call requires approval if not is_approval_response: - requires_approval = False - for tool_call in tool_calls: - tool_call.requires_approval = tool_rules_solver.is_requires_approval_tool(tool_call.function.name) - requires_approval = requires_approval or tool_call.requires_approval - if requires_approval: - approval_message = create_approval_request_message_from_llm_response( + requested_tool_calls = [t for t in tool_calls if tool_rules_solver.is_requires_approval_tool(t.function.name)] + allowed_tool_calls = [t for t in tool_calls if not tool_rules_solver.is_requires_approval_tool(t.function.name)] + if requested_tool_calls: + approval_messages = create_approval_request_message_from_llm_response( agent_id=agent_state.id, model=agent_state.llm_config.model, - tool_calls=tool_calls, + requested_tool_calls=requested_tool_calls, + allowed_tool_calls=allowed_tool_calls, reasoning_content=content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, run_id=run_id, ) - messages_to_persist = (initial_messages or []) + [approval_message] + messages_to_persist = (initial_messages or []) + approval_messages for message in messages_to_persist: if message.run_id is None: diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index c067eff9..9ca236fc 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -181,6 +181,7 @@ def deserialize_tool_calls(data: Optional[List[Dict]]) -> List[OpenAIToolCall]: calls = [] for item in data: + item.pop("requires_approval", None) # legacy field func_data = item.pop("function", None) tool_call_function = OpenAIFunction(**func_data) calls.append(OpenAIToolCall(function=tool_call_function, **item)) diff --git a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py index c18f8e5e..12bcec02 100644 --- a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py +++ b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py @@ -39,6 +39,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser +from letta.server.rest_api.utils import increment_message_uuid logger = get_logger(__name__) @@ -282,14 +283,12 @@ class SimpleAnthropicStreamingInterface: call_id = content.id # Initialize arguments from the start event's input (often {}) to avoid undefined in UIs if name in self.requires_approval_tools: - if prev_message_type and prev_message_type != "approval_request_message": - message_index += 1 tool_call_msg = ApprovalRequestMessage( - id=self.letta_message_id, + id=increment_message_uuid(self.letta_message_id), # Do not emit placeholder arguments here to avoid UI duplicates tool_call=ToolCallDelta(name=name, tool_call_id=call_id), date=datetime.now(timezone.utc).isoformat(), - otid=Message.generate_otid_from_id(self.letta_message_id, message_index), + otid=Message.generate_otid_from_id(increment_message_uuid(self.letta_message_id), message_index), run_id=self.run_id, step_id=self.step_id, ) @@ -306,7 +305,7 @@ class SimpleAnthropicStreamingInterface: run_id=self.run_id, step_id=self.step_id, ) - prev_message_type = tool_call_msg.message_type + prev_message_type = tool_call_msg.message_type yield tool_call_msg elif isinstance(content, BetaThinkingBlock): @@ -382,13 +381,11 @@ class SimpleAnthropicStreamingInterface: call_id = ctx.get("id") if name in self.requires_approval_tools: - if prev_message_type and prev_message_type != "approval_request_message": - message_index += 1 tool_call_msg = ApprovalRequestMessage( - id=self.letta_message_id, + id=increment_message_uuid(self.letta_message_id), tool_call=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json), date=datetime.now(timezone.utc).isoformat(), - otid=Message.generate_otid_from_id(self.letta_message_id, message_index), + otid=Message.generate_otid_from_id(increment_message_uuid(self.letta_message_id), message_index), run_id=self.run_id, step_id=self.step_id, ) @@ -404,7 +401,7 @@ class SimpleAnthropicStreamingInterface: run_id=self.run_id, step_id=self.step_id, ) - + prev_message_type = tool_call_msg.message_type yield tool_call_msg elif isinstance(delta, BetaThinkingDelta): diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index 6429843b..1f5b5155 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -305,13 +305,9 @@ class ApprovalRequestMessage(LettaMessage): tool_call: Union[ToolCall, ToolCallDelta] = Field( ..., description="The tool call that has been requested by the llm to run", deprecated=True ) - requested_tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field( + tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field( None, description="The tool calls that have been requested by the llm to run, which are pending approval" ) - allowed_tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field( - None, - description="Any tool calls returned by the llm during the same turn that do not require approvals, which will execute once this approval request is handled regardless of approval or denial. Only used when parallel_tool_calls is enabled", - ) class ApprovalResponseMessage(LettaMessage): diff --git a/letta/schemas/message.py b/letta/schemas/message.py index d881d1cb..937593d0 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -10,6 +10,7 @@ from datetime import datetime, timezone from enum import Enum from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from letta_client import LettaMessageUnion from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction from openai.types.responses import ResponseReasoningItem from pydantic import BaseModel, Field, field_validator, model_validator @@ -813,14 +814,6 @@ class Message(BaseMessage): def _convert_approval_request_message(self) -> ApprovalRequestMessage: """Convert approval request message to ApprovalRequestMessage""" - allowed_tool_calls = [] - if len(self.tool_calls) == 0 and self.tool_call: - requested_tool_calls = [self.tool_call] - if len(self.tool_calls) == 1: - requested_tool_calls = self.tool_calls - else: - requested_tool_calls = [t for t in self.tool_calls if t.requires_approval] - allowed_tool_calls = [t for t in self.tool_calls if not t.requires_approval] def _convert_tool_call(tool_call): return ToolCall( @@ -836,9 +829,8 @@ class Message(BaseMessage): sender_id=self.sender_id, step_id=self.step_id, run_id=self.run_id, - tool_call=_convert_tool_call(requested_tool_calls[0]), # backwards compatibility - requested_tool_calls=[_convert_tool_call(tc) for tc in requested_tool_calls], - allowed_tool_calls=[_convert_tool_call(tc) for tc in allowed_tool_calls], + tool_call=_convert_tool_call(self.tool_calls[0]), # backwards compatibility + tool_calls=[_convert_tool_call(tc) for tc in self.tool_calls], name=self.name, ) @@ -1819,7 +1811,40 @@ class Message(BaseMessage): # Filter last message if it is a lone approval request without a response - this only occurs for token counting if messages[-1].role == "approval" and messages[-1].tool_calls is not None and len(messages[-1].tool_calls) > 0: messages.remove(messages[-1]) + # Also filter pending tool call message if this turn invoked parallel tool calling + if messages and messages[-1].role == "assistant" and messages[-1].tool_calls is not None and len(messages[-1].tool_calls) > 0: + messages.remove(messages[-1]) + # Filter last message if it is a lone reasoning message without assistant message or tool call + if ( + messages[-1].role == "assistant" + and messages[-1].tool_calls is None + and (not messages[-1].content or all(not isinstance(content_part, TextContent) for content_part in messages[-1].content)) + ): + messages.remove(messages[-1]) + + # Collapse adjacent tool call and approval messages + messages = Message.collapse_tool_call_messages_for_llm_api(messages) + + return messages + + @staticmethod + def collapse_tool_call_messages_for_llm_api( + messages: List[Message], + ) -> List[Message]: + adjacent_tool_call_approval_messages = [] + for i in range(len(messages) - 1): + if ( + messages[i].role == MessageRole.assistant + and messages[i].tool_calls is not None + and messages[i + 1].role == MessageRole.approval + and messages[i + 1].tool_calls is not None + ): + adjacent_tool_call_approval_messages.append(i) + for i in reversed(adjacent_tool_call_approval_messages): + messages[i].content = messages[i].content + messages[i + 1].content + messages[i].tool_calls = messages[i].tool_calls + messages[i + 1].tool_calls + messages.remove(messages[i + 1]) return messages @staticmethod diff --git a/letta/schemas/openai/chat_completion_response.py b/letta/schemas/openai/chat_completion_response.py index 4be96243..63224cc4 100644 --- a/letta/schemas/openai/chat_completion_response.py +++ b/letta/schemas/openai/chat_completion_response.py @@ -19,7 +19,6 @@ class ToolCall(BaseModel): type: Literal["function"] = "function" # function: ToolCallFunction function: FunctionCall - requires_approval: Optional[bool] = None class LogProbToken(BaseModel): diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 1f7fb82f..048dd3ae 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -203,12 +203,40 @@ def create_approval_response_message_from_input( def create_approval_request_message_from_llm_response( agent_id: str, model: str, - tool_calls: List[OpenAIToolCall], + requested_tool_calls: List[OpenAIToolCall], + allowed_tool_calls: List[OpenAIToolCall] = [], reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, step_id: str | None = None, run_id: str = None, ) -> Message: + messages = [] + if allowed_tool_calls: + oai_tool_calls = [ + OpenAIToolCall( + id=tool_call.id, + function=OpenAIFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + type="function", + ) + for tool_call in allowed_tool_calls + ] + tool_message = Message( + role=MessageRole.assistant, + content=reasoning_content if reasoning_content else [], + agent_id=agent_id, + model=model, + tool_calls=oai_tool_calls, + tool_call_id=allowed_tool_calls[0].id, + created_at=get_utc_time(), + step_id=step_id, + run_id=run_id, + ) + if pre_computed_assistant_message_id: + tool_message.id = pre_computed_assistant_message_id + messages.append(tool_message) # Construct the tool call with the assistant's message oai_tool_calls = [ OpenAIToolCall( @@ -218,15 +246,14 @@ def create_approval_request_message_from_llm_response( arguments=tool_call.function.arguments, ), type="function", - requires_approval=tool_call.requires_approval, ) - for tool_call in tool_calls + for tool_call in requested_tool_calls ] # TODO: Use ToolCallContent instead of tool_calls # TODO: This helps preserve ordering approval_message = Message( role=MessageRole.approval, - content=reasoning_content if reasoning_content else [], + content=reasoning_content if reasoning_content and not allowed_tool_calls else [], agent_id=agent_id, model=model, tool_calls=oai_tool_calls, @@ -236,8 +263,17 @@ def create_approval_request_message_from_llm_response( run_id=run_id, ) if pre_computed_assistant_message_id: - approval_message.id = pre_computed_assistant_message_id - return approval_message + approval_message.id = increment_message_uuid(pre_computed_assistant_message_id) + messages.append(approval_message) + return messages + + +def increment_message_uuid(message_id: str): + message_uuid = uuid.UUID(message_id.split("-", maxsplit=1)[1]) + uuid_as_int = message_uuid.int + incremented_int = uuid_as_int + 1 + incremented_uuid = uuid.UUID(int=incremented_int) + return "message-" + str(incremented_uuid) def create_letta_messages_from_llm_response( diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 8ab5fdf8..3fed4676 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -87,6 +87,25 @@ def accumulate_chunks(stream): return messages +def approve_tool_call(client: Letta, agent_id: str, tool_call_id: str): + client.agents.messages.create( + agent_id=agent_id, + messages=[ + ApprovalCreate( + approve=False, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], + ), + ], + ) + + # ------------------------------ # Fixtures # ------------------------------ @@ -216,7 +235,7 @@ def test_send_approval_without_pending_request(client, agent): def test_send_user_message_with_pending_request(client, agent): - client.agents.messages.create( + response = client.agents.messages.create( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) @@ -227,9 +246,11 @@ def test_send_user_message_with_pending_request(client, agent): messages=[MessageCreate(role="user", content="hi")], ) + approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + def test_send_approval_message_with_incorrect_request_id(client, agent): - client.agents.messages.create( + response = client.agents.messages.create( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) @@ -252,6 +273,8 @@ def test_send_approval_message_with_incorrect_request_id(client, agent): ], ) + approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + # ------------------------------ # Request Test Cases @@ -276,9 +299,9 @@ def test_invoke_approval_request( assert messages[2].message_type == "approval_request_message" assert messages[2].tool_call is not None assert messages[2].tool_call.name == "get_secret_code_tool" - assert messages[2].requested_tool_calls is not None - assert len(messages[2].requested_tool_calls) == 1 - assert messages[2].requested_tool_calls[0]["name"] == "get_secret_code_tool" + assert messages[2].tool_calls is not None + assert len(messages[2].tool_calls) == 1 + assert messages[2].tool_calls[0]["name"] == "get_secret_code_tool" # v3/v1 path: approval request tool args must not include request_heartbeat import json as _json @@ -286,6 +309,10 @@ def test_invoke_approval_request( _args = _json.loads(messages[2].tool_call.arguments) assert "request_heartbeat" not in _args + client.agents.context.retrieve(agent_id=agent.id) + + approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + def test_invoke_approval_request_stream( client: Letta, @@ -309,43 +336,9 @@ def test_invoke_approval_request_stream( assert messages[3].message_type == "stop_reason" assert messages[4].message_type == "usage_statistics" + client.agents.context.retrieve(agent_id=agent.id) -def test_invoke_approval_request_with_context_check( - client: Letta, - agent: AgentState, -) -> None: - response = client.agents.messages.create( - agent_id=agent.id, - messages=USER_MESSAGE_TEST_APPROVAL, - ) - tool_call_id = response.messages[2].tool_call.tool_call_id - - response = client.agents.messages.create_stream( - agent_id=agent.id, - messages=[ - ApprovalCreate( - approve=False, # legacy (passing incorrect value to ensure it is overridden) - approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) - approvals=[ - { - "type": "approval", - "approve": True, - "tool_call_id": tool_call_id, - }, - ], - ), - ], - stream_tokens=True, - ) - - messages = accumulate_chunks(response) - - try: - client.agents.context.retrieve(agent_id=agent.id) - except Exception as e: - if len(messages) > 4: - raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") - raise e + approve_tool_call(client, agent.id, messages[2].tool_call.tool_call_id) def test_invoke_tool_after_turning_off_requires_approval( @@ -357,7 +350,6 @@ def test_invoke_tool_after_turning_off_requires_approval( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id tool_call_id = response.messages[2].tool_call.tool_call_id response = client.agents.messages.create_stream( @@ -432,13 +424,7 @@ def test_approve_tool_call_request( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id tool_call_id = response.messages[2].tool_call.tool_call_id - # Ensure no request_heartbeat on approval request - # import json as _json - - # _args = _json.loads(response.messages[0].tool_call.arguments) - # assert "request_heartbeat" not in _args response = client.agents.messages.create_stream( agent_id=agent.id, @@ -1175,6 +1161,7 @@ def test_parallel_tool_calling( client: Letta, agent: AgentState, ) -> None: + last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id response = client.agents.messages.create( agent_id=agent.id, messages=USER_MESSAGE_PARALLEL_TOOL_CALL, @@ -1183,28 +1170,32 @@ def test_parallel_tool_calling( messages = response.messages assert messages is not None - assert len(messages) == 3 + assert len(messages) == 4 assert messages[0].message_type == "reasoning_message" assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "approval_request_message" - assert messages[2].tool_call is not None - assert messages[2].tool_call.name == "get_secret_code_tool" or messages[2].tool_call.name == "roll_dice_tool" + assert messages[2].message_type == "tool_call_message" + assert len(messages[2].tool_calls) == 1 + assert messages[2].tool_calls[0]["name"] == "roll_dice_tool" + assert "6" in messages[2].tool_calls[0]["arguments"] + dice_tool_call_id = messages[2].tool_calls[0]["tool_call_id"] - assert len(messages[2].requested_tool_calls) == 3 - assert messages[2].requested_tool_calls[0]["name"] == "get_secret_code_tool" - assert "hello world" in messages[2].requested_tool_calls[0]["arguments"] - approve_tool_call_id = messages[2].requested_tool_calls[0]["tool_call_id"] - assert messages[2].requested_tool_calls[1]["name"] == "get_secret_code_tool" - assert "hello letta" in messages[2].requested_tool_calls[1]["arguments"] - deny_tool_call_id = messages[2].requested_tool_calls[1]["tool_call_id"] - assert messages[2].requested_tool_calls[2]["name"] == "get_secret_code_tool" - assert "hello test" in messages[2].requested_tool_calls[2]["arguments"] - client_side_tool_call_id = messages[2].requested_tool_calls[2]["tool_call_id"] + assert messages[3].message_type == "approval_request_message" + assert messages[3].tool_call is not None + assert messages[3].tool_call.name == "get_secret_code_tool" - assert len(messages[2].allowed_tool_calls) == 1 - assert messages[2].allowed_tool_calls[0]["name"] == "roll_dice_tool" - assert "6" in messages[2].allowed_tool_calls[0]["arguments"] - dice_tool_call_id = messages[2].allowed_tool_calls[0]["tool_call_id"] + assert len(messages[3].tool_calls) == 3 + assert messages[3].tool_calls[0]["name"] == "get_secret_code_tool" + assert "hello world" in messages[3].tool_calls[0]["arguments"] + approve_tool_call_id = messages[3].tool_calls[0]["tool_call_id"] + assert messages[3].tool_calls[1]["name"] == "get_secret_code_tool" + assert "hello letta" in messages[3].tool_calls[1]["arguments"] + deny_tool_call_id = messages[3].tool_calls[1]["tool_call_id"] + assert messages[3].tool_calls[2]["name"] == "get_secret_code_tool" + assert "hello test" in messages[3].tool_calls[2]["arguments"] + client_side_tool_call_id = messages[3].tool_calls[2]["tool_call_id"] + + # ensure context is not bricked + client.agents.context.retrieve(agent_id=agent.id) response = client.agents.messages.create( agent_id=agent.id, @@ -1258,3 +1249,31 @@ def test_parallel_tool_calling( assert messages[1].message_type == "reasoning_message" assert messages[2].message_type == "tool_call_message" assert messages[3].message_type == "tool_return_message" + + # ensure context is not bricked + client.agents.context.retrieve(agent_id=agent.id) + + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) + assert len(messages) > 6 + assert messages[0].message_type == "user_message" + assert messages[1].message_type == "reasoning_message" + assert messages[2].message_type == "assistant_message" + assert messages[3].message_type == "tool_call_message" + assert messages[4].message_type == "approval_request_message" + assert messages[5].message_type == "approval_response_message" + assert messages[6].message_type == "tool_return_message" + + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + assert len(messages) == 4 + assert messages[0].message_type == "reasoning_message" + assert messages[1].message_type == "assistant_message" + assert messages[2].message_type == "stop_reason" + assert messages[3].message_type == "usage_statistics"