feat: latest hitl + parallel tool call changes (#5565)

This commit is contained in:
cthomas
2025-10-18 22:27:51 -07:00
committed by Caren Thomas
parent de0896c547
commit 73dcc0d4b7
12 changed files with 215 additions and 142 deletions

View File

@@ -23210,7 +23210,7 @@
"description": "The tool call that has been requested by the llm to run",
"deprecated": true
},
"requested_tool_calls": {
"tool_calls": {
"anyOf": [
{
"items": {
@@ -23225,26 +23225,8 @@
"type": "null"
}
],
"title": "Requested Tool Calls",
"title": "Tool Calls",
"description": "The tool calls that have been requested by the llm to run, which are pending approval"
},
"allowed_tool_calls": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/ToolCall"
},
"type": "array"
},
{
"$ref": "#/components/schemas/ToolCallDelta"
},
{
"type": "null"
}
],
"title": "Allowed Tool Calls",
"description": "Any tool calls returned by the llm during the same turn that do not require approvals, which will execute once this approval request is handled regardless of approval or denial. Only used when parallel_tool_calls is enabled"
}
},
"type": "object",

View File

@@ -137,9 +137,7 @@ async def _prepare_in_context_messages_async(
def validate_approval_tool_call_ids(approval_request_message: Message, approval_response_message: ApprovalCreate):
approval_requests = approval_request_message.tool_calls
approval_request_tool_call_ids = [approval_request.id for approval_request in approval_requests if approval_request.requires_approval]
if not approval_request_tool_call_ids and len(approval_request_message.tool_calls) == 1:
approval_request_tool_call_ids = [approval_request_message.tool_calls[0].id]
approval_request_tool_call_ids = [approval_request.id for approval_request in approval_requests]
approval_responses = approval_response_message.approvals
approval_response_tool_call_ids = [approval_response.tool_call_id for approval_response in approval_responses]
@@ -418,3 +416,19 @@ def _maybe_get_approval_messages(messages: list[Message]) -> Tuple[Message | Non
if maybe_approval_request.role == "approval" and maybe_approval_response.role == "approval":
return maybe_approval_request, maybe_approval_response
return None, None
def _maybe_get_pending_tool_call_message(messages: list[Message]) -> Message | None:
    """Return the assistant message carrying already-allowed tool calls, if one exists.

    Only used in the case where hitl (human-in-the-loop) is invoked with parallel
    tool calling: the agent calls some tools that require approval and others that
    don't. In that layout the third-from-last message is expected to be the
    assistant message holding the non-approval tool calls.
    """
    if len(messages) < 3:
        return None
    candidate = messages[-3]
    is_assistant = candidate.role == "assistant"
    has_tool_calls = candidate.tool_calls is not None and len(candidate.tool_calls) > 0
    return candidate if is_assistant and has_tool_calls else None

View File

@@ -1700,15 +1700,17 @@ class LettaAgent(BaseAgent):
)
if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name):
tool_args[REQUEST_HEARTBEAT_PARAM] = request_heartbeat
approval_message = create_approval_request_message_from_llm_response(
approval_messages = create_approval_request_message_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
tool_calls=[ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))],
requested_tool_calls=[
ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))
],
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
step_id=step_id,
)
messages_to_persist = (initial_messages or []) + [approval_message]
messages_to_persist = (initial_messages or []) + approval_messages
continue_stepping = False
stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
else:

View File

@@ -935,16 +935,18 @@ class LettaAgentV2(BaseAgentV2):
if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name):
tool_args[REQUEST_HEARTBEAT_PARAM] = request_heartbeat
approval_message = create_approval_request_message_from_llm_response(
approval_messages = create_approval_request_message_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
tool_calls=[ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))],
requested_tool_calls=[
ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))
],
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
step_id=step_id,
run_id=run_id,
)
messages_to_persist = (initial_messages or []) + [approval_message]
messages_to_persist = (initial_messages or []) + approval_messages
continue_stepping = False
stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
else:

View File

@@ -12,6 +12,7 @@ from letta.agents.helpers import (
_build_rule_violation_result,
_load_last_function_response,
_maybe_get_approval_messages,
_maybe_get_pending_tool_call_message,
_prepare_in_context_messages_no_persist_async,
_safe_load_tool_call_str,
generate_step_id,
@@ -315,15 +316,15 @@ class LettaAgentV3(LettaAgentV2):
# Get tool calls that are pending
backfill_tool_call_id = approval_request.tool_calls[0].id # legacy case
approved_tool_call_ids = [
approved_tool_call_ids = {
backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id
for a in approval_response.approvals
if isinstance(a, ApprovalReturn) and a.approve
]
pending_tool_call_ids = [
t.id for t in approval_request.tool_calls if not t.requires_approval and t.id not in approved_tool_call_ids
]
tool_calls = [t for t in approval_request.tool_calls if t.id in approved_tool_call_ids + pending_tool_call_ids]
}
tool_calls = [tool_call for tool_call in approval_request.tool_calls if tool_call.id in approved_tool_call_ids]
pending_tool_call_message = _maybe_get_pending_tool_call_message(messages)
if pending_tool_call_message:
tool_calls.extend(pending_tool_call_message.tool_calls)
# Get tool calls that were denied
denies = {d.tool_call_id: d for d in approval_response.approvals if isinstance(d, ApprovalReturn) and not d.approve}
@@ -681,21 +682,20 @@ class LettaAgentV3(LettaAgentV2):
# 2. Check whether tool call requires approval
if not is_approval_response:
requires_approval = False
for tool_call in tool_calls:
tool_call.requires_approval = tool_rules_solver.is_requires_approval_tool(tool_call.function.name)
requires_approval = requires_approval or tool_call.requires_approval
if requires_approval:
approval_message = create_approval_request_message_from_llm_response(
requested_tool_calls = [t for t in tool_calls if tool_rules_solver.is_requires_approval_tool(t.function.name)]
allowed_tool_calls = [t for t in tool_calls if not tool_rules_solver.is_requires_approval_tool(t.function.name)]
if requested_tool_calls:
approval_messages = create_approval_request_message_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
tool_calls=tool_calls,
requested_tool_calls=requested_tool_calls,
allowed_tool_calls=allowed_tool_calls,
reasoning_content=content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
step_id=step_id,
run_id=run_id,
)
messages_to_persist = (initial_messages or []) + [approval_message]
messages_to_persist = (initial_messages or []) + approval_messages
for message in messages_to_persist:
if message.run_id is None:

View File

@@ -181,6 +181,7 @@ def deserialize_tool_calls(data: Optional[List[Dict]]) -> List[OpenAIToolCall]:
calls = []
for item in data:
item.pop("requires_approval", None) # legacy field
func_data = item.pop("function", None)
tool_call_function = OpenAIFunction(**func_data)
calls.append(OpenAIToolCall(function=tool_call_function, **item))

View File

@@ -39,6 +39,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
from letta.server.rest_api.utils import increment_message_uuid
logger = get_logger(__name__)
@@ -282,14 +283,12 @@ class SimpleAnthropicStreamingInterface:
call_id = content.id
# Initialize arguments from the start event's input (often {}) to avoid undefined in UIs
if name in self.requires_approval_tools:
if prev_message_type and prev_message_type != "approval_request_message":
message_index += 1
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
id=increment_message_uuid(self.letta_message_id),
# Do not emit placeholder arguments here to avoid UI duplicates
tool_call=ToolCallDelta(name=name, tool_call_id=call_id),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
otid=Message.generate_otid_from_id(increment_message_uuid(self.letta_message_id), message_index),
run_id=self.run_id,
step_id=self.step_id,
)
@@ -306,7 +305,7 @@ class SimpleAnthropicStreamingInterface:
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = tool_call_msg.message_type
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
elif isinstance(content, BetaThinkingBlock):
@@ -382,13 +381,11 @@ class SimpleAnthropicStreamingInterface:
call_id = ctx.get("id")
if name in self.requires_approval_tools:
if prev_message_type and prev_message_type != "approval_request_message":
message_index += 1
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
id=increment_message_uuid(self.letta_message_id),
tool_call=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
otid=Message.generate_otid_from_id(increment_message_uuid(self.letta_message_id), message_index),
run_id=self.run_id,
step_id=self.step_id,
)
@@ -404,7 +401,7 @@ class SimpleAnthropicStreamingInterface:
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
elif isinstance(delta, BetaThinkingDelta):

View File

@@ -305,13 +305,9 @@ class ApprovalRequestMessage(LettaMessage):
tool_call: Union[ToolCall, ToolCallDelta] = Field(
..., description="The tool call that has been requested by the llm to run", deprecated=True
)
requested_tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field(
tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field(
None, description="The tool calls that have been requested by the llm to run, which are pending approval"
)
allowed_tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field(
None,
description="Any tool calls returned by the llm during the same turn that do not require approvals, which will execute once this approval request is handled regardless of approval or denial. Only used when parallel_tool_calls is enabled",
)
class ApprovalResponseMessage(LettaMessage):

View File

@@ -10,6 +10,7 @@ from datetime import datetime, timezone
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from letta_client import LettaMessageUnion
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
from openai.types.responses import ResponseReasoningItem
from pydantic import BaseModel, Field, field_validator, model_validator
@@ -813,14 +814,6 @@ class Message(BaseMessage):
def _convert_approval_request_message(self) -> ApprovalRequestMessage:
"""Convert approval request message to ApprovalRequestMessage"""
allowed_tool_calls = []
if len(self.tool_calls) == 0 and self.tool_call:
requested_tool_calls = [self.tool_call]
if len(self.tool_calls) == 1:
requested_tool_calls = self.tool_calls
else:
requested_tool_calls = [t for t in self.tool_calls if t.requires_approval]
allowed_tool_calls = [t for t in self.tool_calls if not t.requires_approval]
def _convert_tool_call(tool_call):
return ToolCall(
@@ -836,9 +829,8 @@ class Message(BaseMessage):
sender_id=self.sender_id,
step_id=self.step_id,
run_id=self.run_id,
tool_call=_convert_tool_call(requested_tool_calls[0]), # backwards compatibility
requested_tool_calls=[_convert_tool_call(tc) for tc in requested_tool_calls],
allowed_tool_calls=[_convert_tool_call(tc) for tc in allowed_tool_calls],
tool_call=_convert_tool_call(self.tool_calls[0]), # backwards compatibility
tool_calls=[_convert_tool_call(tc) for tc in self.tool_calls],
name=self.name,
)
@@ -1819,7 +1811,40 @@ class Message(BaseMessage):
# Filter last message if it is a lone approval request without a response - this only occurs for token counting
if messages[-1].role == "approval" and messages[-1].tool_calls is not None and len(messages[-1].tool_calls) > 0:
messages.remove(messages[-1])
# Also filter pending tool call message if this turn invoked parallel tool calling
if messages and messages[-1].role == "assistant" and messages[-1].tool_calls is not None and len(messages[-1].tool_calls) > 0:
messages.remove(messages[-1])
# Filter last message if it is a lone reasoning message without assistant message or tool call
if (
messages[-1].role == "assistant"
and messages[-1].tool_calls is None
and (not messages[-1].content or all(not isinstance(content_part, TextContent) for content_part in messages[-1].content))
):
messages.remove(messages[-1])
# Collapse adjacent tool call and approval messages
messages = Message.collapse_tool_call_messages_for_llm_api(messages)
return messages
@staticmethod
def collapse_tool_call_messages_for_llm_api(
    messages: List[Message],
) -> List[Message]:
    """Merge each (assistant tool-call, approval tool-call) adjacent pair into one message.

    LLM APIs expect a single assistant turn per set of tool calls, so an assistant
    message immediately followed by an approval message (both carrying tool_calls)
    is collapsed: the approval's content and tool_calls are appended onto the
    assistant message and the approval message is dropped.

    NOTE(review): mutates the input list AND the matched Message objects in place,
    and also returns the same list. `messages.remove(...)` removes the first
    element that compares equal — this is only guaranteed to drop the intended
    message if Message equality is identity-like (no earlier duplicate compares
    equal); confirm against Message.__eq__ semantics.
    """
    # First pass: collect the indices of assistant messages that are directly
    # followed by an approval message, where both carry tool calls.
    adjacent_tool_call_approval_messages = []
    for i in range(len(messages) - 1):
        if (
            messages[i].role == MessageRole.assistant
            and messages[i].tool_calls is not None
            and messages[i + 1].role == MessageRole.approval
            and messages[i + 1].tool_calls is not None
        ):
            adjacent_tool_call_approval_messages.append(i)
    # Second pass in reverse so earlier collected indices stay valid after removals.
    for i in reversed(adjacent_tool_call_approval_messages):
        messages[i].content = messages[i].content + messages[i + 1].content
        messages[i].tool_calls = messages[i].tool_calls + messages[i + 1].tool_calls
        messages.remove(messages[i + 1])
    return messages
@staticmethod

View File

@@ -19,7 +19,6 @@ class ToolCall(BaseModel):
type: Literal["function"] = "function"
# function: ToolCallFunction
function: FunctionCall
requires_approval: Optional[bool] = None
class LogProbToken(BaseModel):

View File

@@ -203,12 +203,40 @@ def create_approval_response_message_from_input(
def create_approval_request_message_from_llm_response(
agent_id: str,
model: str,
tool_calls: List[OpenAIToolCall],
requested_tool_calls: List[OpenAIToolCall],
allowed_tool_calls: List[OpenAIToolCall] = [],
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
step_id: str | None = None,
run_id: str = None,
) -> Message:
messages = []
if allowed_tool_calls:
oai_tool_calls = [
OpenAIToolCall(
id=tool_call.id,
function=OpenAIFunction(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
type="function",
)
for tool_call in allowed_tool_calls
]
tool_message = Message(
role=MessageRole.assistant,
content=reasoning_content if reasoning_content else [],
agent_id=agent_id,
model=model,
tool_calls=oai_tool_calls,
tool_call_id=allowed_tool_calls[0].id,
created_at=get_utc_time(),
step_id=step_id,
run_id=run_id,
)
if pre_computed_assistant_message_id:
tool_message.id = pre_computed_assistant_message_id
messages.append(tool_message)
# Construct the tool call with the assistant's message
oai_tool_calls = [
OpenAIToolCall(
@@ -218,15 +246,14 @@ def create_approval_request_message_from_llm_response(
arguments=tool_call.function.arguments,
),
type="function",
requires_approval=tool_call.requires_approval,
)
for tool_call in tool_calls
for tool_call in requested_tool_calls
]
# TODO: Use ToolCallContent instead of tool_calls
# TODO: This helps preserve ordering
approval_message = Message(
role=MessageRole.approval,
content=reasoning_content if reasoning_content else [],
content=reasoning_content if reasoning_content and not allowed_tool_calls else [],
agent_id=agent_id,
model=model,
tool_calls=oai_tool_calls,
@@ -236,8 +263,17 @@ def create_approval_request_message_from_llm_response(
run_id=run_id,
)
if pre_computed_assistant_message_id:
approval_message.id = pre_computed_assistant_message_id
return approval_message
approval_message.id = increment_message_uuid(pre_computed_assistant_message_id)
messages.append(approval_message)
return messages
def increment_message_uuid(message_id: str) -> str:
    """Return a new message id whose UUID portion is one greater than *message_id*'s.

    The input is expected to look like ``"message-<uuid>"``: everything before the
    first ``-`` is discarded and the remainder is parsed as a UUID.

    The increment wraps modulo 2**128 so the maximal UUID rolls over to the zero
    UUID instead of raising ``ValueError`` from ``uuid.UUID(int=...)`` (which
    rejects integers outside the 128-bit range).

    Raises:
        ValueError: if the portion after the first ``-`` is not a valid UUID.
        IndexError: if *message_id* contains no ``-`` separator at all.
    """
    # Split off the "message-" prefix; the remainder is the UUID text.
    message_uuid = uuid.UUID(message_id.split("-", maxsplit=1)[1])
    # Increment as a 128-bit integer, wrapping to stay within uuid.UUID's range.
    incremented_int = (message_uuid.int + 1) % (1 << 128)
    return "message-" + str(uuid.UUID(int=incremented_int))
def create_letta_messages_from_llm_response(

View File

@@ -87,6 +87,25 @@ def accumulate_chunks(stream):
return messages
def approve_tool_call(client: Letta, agent_id: str, tool_call_id: str):
    """Send an approval response granting the pending tool call *tool_call_id*.

    The legacy top-level ``approve``/``approval_request_id`` fields are filled
    with deliberately wrong values to verify the per-call ``approvals`` payload
    overrides them.
    """
    approval_payload = {
        "type": "approval",
        "approve": True,
        "tool_call_id": tool_call_id,
    }
    response_message = ApprovalCreate(
        approve=False,  # legacy (passing incorrect value to ensure it is overridden)
        approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
        approvals=[approval_payload],
    )
    client.agents.messages.create(agent_id=agent_id, messages=[response_message])
# ------------------------------
# Fixtures
# ------------------------------
@@ -216,7 +235,7 @@ def test_send_approval_without_pending_request(client, agent):
def test_send_user_message_with_pending_request(client, agent):
client.agents.messages.create(
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
@@ -227,9 +246,11 @@ def test_send_user_message_with_pending_request(client, agent):
messages=[MessageCreate(role="user", content="hi")],
)
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
def test_send_approval_message_with_incorrect_request_id(client, agent):
client.agents.messages.create(
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
@@ -252,6 +273,8 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
],
)
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
# ------------------------------
# Request Test Cases
@@ -276,9 +299,9 @@ def test_invoke_approval_request(
assert messages[2].message_type == "approval_request_message"
assert messages[2].tool_call is not None
assert messages[2].tool_call.name == "get_secret_code_tool"
assert messages[2].requested_tool_calls is not None
assert len(messages[2].requested_tool_calls) == 1
assert messages[2].requested_tool_calls[0]["name"] == "get_secret_code_tool"
assert messages[2].tool_calls is not None
assert len(messages[2].tool_calls) == 1
assert messages[2].tool_calls[0]["name"] == "get_secret_code_tool"
# v3/v1 path: approval request tool args must not include request_heartbeat
import json as _json
@@ -286,6 +309,10 @@ def test_invoke_approval_request(
_args = _json.loads(messages[2].tool_call.arguments)
assert "request_heartbeat" not in _args
client.agents.context.retrieve(agent_id=agent.id)
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
def test_invoke_approval_request_stream(
client: Letta,
@@ -309,43 +336,9 @@ def test_invoke_approval_request_stream(
assert messages[3].message_type == "stop_reason"
assert messages[4].message_type == "usage_statistics"
client.agents.context.retrieve(agent_id=agent.id)
def test_invoke_approval_request_with_context_check(
client: Letta,
agent: AgentState,
) -> None:
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
{
"type": "approval",
"approve": True,
"tool_call_id": tool_call_id,
},
],
),
],
stream_tokens=True,
)
messages = accumulate_chunks(response)
try:
client.agents.context.retrieve(agent_id=agent.id)
except Exception as e:
if len(messages) > 4:
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
raise e
approve_tool_call(client, agent.id, messages[2].tool_call.tool_call_id)
def test_invoke_tool_after_turning_off_requires_approval(
@@ -357,7 +350,6 @@ def test_invoke_tool_after_turning_off_requires_approval(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
approval_request_id = response.messages[0].id
tool_call_id = response.messages[2].tool_call.tool_call_id
response = client.agents.messages.create_stream(
@@ -432,13 +424,7 @@ def test_approve_tool_call_request(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
approval_request_id = response.messages[0].id
tool_call_id = response.messages[2].tool_call.tool_call_id
# Ensure no request_heartbeat on approval request
# import json as _json
# _args = _json.loads(response.messages[0].tool_call.arguments)
# assert "request_heartbeat" not in _args
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -1175,6 +1161,7 @@ def test_parallel_tool_calling(
client: Letta,
agent: AgentState,
) -> None:
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
@@ -1183,28 +1170,32 @@ def test_parallel_tool_calling(
messages = response.messages
assert messages is not None
assert len(messages) == 3
assert len(messages) == 4
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "approval_request_message"
assert messages[2].tool_call is not None
assert messages[2].tool_call.name == "get_secret_code_tool" or messages[2].tool_call.name == "roll_dice_tool"
assert messages[2].message_type == "tool_call_message"
assert len(messages[2].tool_calls) == 1
assert messages[2].tool_calls[0]["name"] == "roll_dice_tool"
assert "6" in messages[2].tool_calls[0]["arguments"]
dice_tool_call_id = messages[2].tool_calls[0]["tool_call_id"]
assert len(messages[2].requested_tool_calls) == 3
assert messages[2].requested_tool_calls[0]["name"] == "get_secret_code_tool"
assert "hello world" in messages[2].requested_tool_calls[0]["arguments"]
approve_tool_call_id = messages[2].requested_tool_calls[0]["tool_call_id"]
assert messages[2].requested_tool_calls[1]["name"] == "get_secret_code_tool"
assert "hello letta" in messages[2].requested_tool_calls[1]["arguments"]
deny_tool_call_id = messages[2].requested_tool_calls[1]["tool_call_id"]
assert messages[2].requested_tool_calls[2]["name"] == "get_secret_code_tool"
assert "hello test" in messages[2].requested_tool_calls[2]["arguments"]
client_side_tool_call_id = messages[2].requested_tool_calls[2]["tool_call_id"]
assert messages[3].message_type == "approval_request_message"
assert messages[3].tool_call is not None
assert messages[3].tool_call.name == "get_secret_code_tool"
assert len(messages[2].allowed_tool_calls) == 1
assert messages[2].allowed_tool_calls[0]["name"] == "roll_dice_tool"
assert "6" in messages[2].allowed_tool_calls[0]["arguments"]
dice_tool_call_id = messages[2].allowed_tool_calls[0]["tool_call_id"]
assert len(messages[3].tool_calls) == 3
assert messages[3].tool_calls[0]["name"] == "get_secret_code_tool"
assert "hello world" in messages[3].tool_calls[0]["arguments"]
approve_tool_call_id = messages[3].tool_calls[0]["tool_call_id"]
assert messages[3].tool_calls[1]["name"] == "get_secret_code_tool"
assert "hello letta" in messages[3].tool_calls[1]["arguments"]
deny_tool_call_id = messages[3].tool_calls[1]["tool_call_id"]
assert messages[3].tool_calls[2]["name"] == "get_secret_code_tool"
assert "hello test" in messages[3].tool_calls[2]["arguments"]
client_side_tool_call_id = messages[3].tool_calls[2]["tool_call_id"]
# ensure context is not bricked
client.agents.context.retrieve(agent_id=agent.id)
response = client.agents.messages.create(
agent_id=agent.id,
@@ -1258,3 +1249,31 @@ def test_parallel_tool_calling(
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "tool_call_message"
assert messages[3].message_type == "tool_return_message"
# ensure context is not bricked
client.agents.context.retrieve(agent_id=agent.id)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
assert len(messages) > 6
assert messages[0].message_type == "user_message"
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "assistant_message"
assert messages[3].message_type == "tool_call_message"
assert messages[4].message_type == "approval_request_message"
assert messages[5].message_type == "approval_response_message"
assert messages[6].message_type == "tool_return_message"
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 4
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "stop_reason"
assert messages[3].message_type == "usage_statistics"