feat: latest hitl + parallel tool call changes (#5565)
This commit is contained in:
@@ -23210,7 +23210,7 @@
|
||||
"description": "The tool call that has been requested by the llm to run",
|
||||
"deprecated": true
|
||||
},
|
||||
"requested_tool_calls": {
|
||||
"tool_calls": {
|
||||
"anyOf": [
|
||||
{
|
||||
"items": {
|
||||
@@ -23225,26 +23225,8 @@
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Requested Tool Calls",
|
||||
"title": "Tool Calls",
|
||||
"description": "The tool calls that have been requested by the llm to run, which are pending approval"
|
||||
},
|
||||
"allowed_tool_calls": {
|
||||
"anyOf": [
|
||||
{
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolCall"
|
||||
},
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ToolCallDelta"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Allowed Tool Calls",
|
||||
"description": "Any tool calls returned by the llm during the same turn that do not require approvals, which will execute once this approval request is handled regardless of approval or denial. Only used when parallel_tool_calls is enabled"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
|
||||
@@ -137,9 +137,7 @@ async def _prepare_in_context_messages_async(
|
||||
|
||||
def validate_approval_tool_call_ids(approval_request_message: Message, approval_response_message: ApprovalCreate):
|
||||
approval_requests = approval_request_message.tool_calls
|
||||
approval_request_tool_call_ids = [approval_request.id for approval_request in approval_requests if approval_request.requires_approval]
|
||||
if not approval_request_tool_call_ids and len(approval_request_message.tool_calls) == 1:
|
||||
approval_request_tool_call_ids = [approval_request_message.tool_calls[0].id]
|
||||
approval_request_tool_call_ids = [approval_request.id for approval_request in approval_requests]
|
||||
|
||||
approval_responses = approval_response_message.approvals
|
||||
approval_response_tool_call_ids = [approval_response.tool_call_id for approval_response in approval_responses]
|
||||
@@ -418,3 +416,19 @@ def _maybe_get_approval_messages(messages: list[Message]) -> Tuple[Message | Non
|
||||
if maybe_approval_request.role == "approval" and maybe_approval_response.role == "approval":
|
||||
return maybe_approval_request, maybe_approval_response
|
||||
return None, None
|
||||
|
||||
|
||||
def _maybe_get_pending_tool_call_message(messages: list[Message]) -> Message | None:
|
||||
"""
|
||||
Only used in the case where hitl is invoked with parallel tool calling,
|
||||
where agent calls some tools that require approval, and others that don't.
|
||||
"""
|
||||
if len(messages) >= 3:
|
||||
maybe_tool_call_message = messages[-3]
|
||||
if (
|
||||
maybe_tool_call_message.role == "assistant"
|
||||
and maybe_tool_call_message.tool_calls is not None
|
||||
and len(maybe_tool_call_message.tool_calls) > 0
|
||||
):
|
||||
return maybe_tool_call_message
|
||||
return None
|
||||
|
||||
@@ -1700,15 +1700,17 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name):
|
||||
tool_args[REQUEST_HEARTBEAT_PARAM] = request_heartbeat
|
||||
approval_message = create_approval_request_message_from_llm_response(
|
||||
approval_messages = create_approval_request_message_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
tool_calls=[ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))],
|
||||
requested_tool_calls=[
|
||||
ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))
|
||||
],
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + [approval_message]
|
||||
messages_to_persist = (initial_messages or []) + approval_messages
|
||||
continue_stepping = False
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
|
||||
else:
|
||||
|
||||
@@ -935,16 +935,18 @@ class LettaAgentV2(BaseAgentV2):
|
||||
|
||||
if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name):
|
||||
tool_args[REQUEST_HEARTBEAT_PARAM] = request_heartbeat
|
||||
approval_message = create_approval_request_message_from_llm_response(
|
||||
approval_messages = create_approval_request_message_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
tool_calls=[ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))],
|
||||
requested_tool_calls=[
|
||||
ToolCall(id=tool_call_id, function=FunctionCall(name=tool_call_name, arguments=json.dumps(tool_args)))
|
||||
],
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
step_id=step_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + [approval_message]
|
||||
messages_to_persist = (initial_messages or []) + approval_messages
|
||||
continue_stepping = False
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
|
||||
else:
|
||||
|
||||
@@ -12,6 +12,7 @@ from letta.agents.helpers import (
|
||||
_build_rule_violation_result,
|
||||
_load_last_function_response,
|
||||
_maybe_get_approval_messages,
|
||||
_maybe_get_pending_tool_call_message,
|
||||
_prepare_in_context_messages_no_persist_async,
|
||||
_safe_load_tool_call_str,
|
||||
generate_step_id,
|
||||
@@ -315,15 +316,15 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
# Get tool calls that are pending
|
||||
backfill_tool_call_id = approval_request.tool_calls[0].id # legacy case
|
||||
approved_tool_call_ids = [
|
||||
approved_tool_call_ids = {
|
||||
backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id
|
||||
for a in approval_response.approvals
|
||||
if isinstance(a, ApprovalReturn) and a.approve
|
||||
]
|
||||
pending_tool_call_ids = [
|
||||
t.id for t in approval_request.tool_calls if not t.requires_approval and t.id not in approved_tool_call_ids
|
||||
]
|
||||
tool_calls = [t for t in approval_request.tool_calls if t.id in approved_tool_call_ids + pending_tool_call_ids]
|
||||
}
|
||||
tool_calls = [tool_call for tool_call in approval_request.tool_calls if tool_call.id in approved_tool_call_ids]
|
||||
pending_tool_call_message = _maybe_get_pending_tool_call_message(messages)
|
||||
if pending_tool_call_message:
|
||||
tool_calls.extend(pending_tool_call_message.tool_calls)
|
||||
|
||||
# Get tool calls that were denied
|
||||
denies = {d.tool_call_id: d for d in approval_response.approvals if isinstance(d, ApprovalReturn) and not d.approve}
|
||||
@@ -681,21 +682,20 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
# 2. Check whether tool call requires approval
|
||||
if not is_approval_response:
|
||||
requires_approval = False
|
||||
for tool_call in tool_calls:
|
||||
tool_call.requires_approval = tool_rules_solver.is_requires_approval_tool(tool_call.function.name)
|
||||
requires_approval = requires_approval or tool_call.requires_approval
|
||||
if requires_approval:
|
||||
approval_message = create_approval_request_message_from_llm_response(
|
||||
requested_tool_calls = [t for t in tool_calls if tool_rules_solver.is_requires_approval_tool(t.function.name)]
|
||||
allowed_tool_calls = [t for t in tool_calls if not tool_rules_solver.is_requires_approval_tool(t.function.name)]
|
||||
if requested_tool_calls:
|
||||
approval_messages = create_approval_request_message_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
tool_calls=tool_calls,
|
||||
requested_tool_calls=requested_tool_calls,
|
||||
allowed_tool_calls=allowed_tool_calls,
|
||||
reasoning_content=content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
step_id=step_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + [approval_message]
|
||||
messages_to_persist = (initial_messages or []) + approval_messages
|
||||
|
||||
for message in messages_to_persist:
|
||||
if message.run_id is None:
|
||||
|
||||
@@ -181,6 +181,7 @@ def deserialize_tool_calls(data: Optional[List[Dict]]) -> List[OpenAIToolCall]:
|
||||
|
||||
calls = []
|
||||
for item in data:
|
||||
item.pop("requires_approval", None) # legacy field
|
||||
func_data = item.pop("function", None)
|
||||
tool_call_function = OpenAIFunction(**func_data)
|
||||
calls.append(OpenAIToolCall(function=tool_call_function, **item))
|
||||
|
||||
@@ -39,6 +39,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
|
||||
from letta.server.rest_api.utils import increment_message_uuid
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -282,14 +283,12 @@ class SimpleAnthropicStreamingInterface:
|
||||
call_id = content.id
|
||||
# Initialize arguments from the start event's input (often {}) to avoid undefined in UIs
|
||||
if name in self.requires_approval_tools:
|
||||
if prev_message_type and prev_message_type != "approval_request_message":
|
||||
message_index += 1
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
id=increment_message_uuid(self.letta_message_id),
|
||||
# Do not emit placeholder arguments here to avoid UI duplicates
|
||||
tool_call=ToolCallDelta(name=name, tool_call_id=call_id),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
otid=Message.generate_otid_from_id(increment_message_uuid(self.letta_message_id), message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
@@ -306,7 +305,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
|
||||
elif isinstance(content, BetaThinkingBlock):
|
||||
@@ -382,13 +381,11 @@ class SimpleAnthropicStreamingInterface:
|
||||
call_id = ctx.get("id")
|
||||
|
||||
if name in self.requires_approval_tools:
|
||||
if prev_message_type and prev_message_type != "approval_request_message":
|
||||
message_index += 1
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
id=increment_message_uuid(self.letta_message_id),
|
||||
tool_call=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
otid=Message.generate_otid_from_id(increment_message_uuid(self.letta_message_id), message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
@@ -404,7 +401,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
|
||||
elif isinstance(delta, BetaThinkingDelta):
|
||||
|
||||
@@ -305,13 +305,9 @@ class ApprovalRequestMessage(LettaMessage):
|
||||
tool_call: Union[ToolCall, ToolCallDelta] = Field(
|
||||
..., description="The tool call that has been requested by the llm to run", deprecated=True
|
||||
)
|
||||
requested_tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field(
|
||||
tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field(
|
||||
None, description="The tool calls that have been requested by the llm to run, which are pending approval"
|
||||
)
|
||||
allowed_tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field(
|
||||
None,
|
||||
description="Any tool calls returned by the llm during the same turn that do not require approvals, which will execute once this approval request is handled regardless of approval or denial. Only used when parallel_tool_calls is enabled",
|
||||
)
|
||||
|
||||
|
||||
class ApprovalResponseMessage(LettaMessage):
|
||||
|
||||
@@ -10,6 +10,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from letta_client import LettaMessageUnion
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
||||
from openai.types.responses import ResponseReasoningItem
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
@@ -813,14 +814,6 @@ class Message(BaseMessage):
|
||||
|
||||
def _convert_approval_request_message(self) -> ApprovalRequestMessage:
|
||||
"""Convert approval request message to ApprovalRequestMessage"""
|
||||
allowed_tool_calls = []
|
||||
if len(self.tool_calls) == 0 and self.tool_call:
|
||||
requested_tool_calls = [self.tool_call]
|
||||
if len(self.tool_calls) == 1:
|
||||
requested_tool_calls = self.tool_calls
|
||||
else:
|
||||
requested_tool_calls = [t for t in self.tool_calls if t.requires_approval]
|
||||
allowed_tool_calls = [t for t in self.tool_calls if not t.requires_approval]
|
||||
|
||||
def _convert_tool_call(tool_call):
|
||||
return ToolCall(
|
||||
@@ -836,9 +829,8 @@ class Message(BaseMessage):
|
||||
sender_id=self.sender_id,
|
||||
step_id=self.step_id,
|
||||
run_id=self.run_id,
|
||||
tool_call=_convert_tool_call(requested_tool_calls[0]), # backwards compatibility
|
||||
requested_tool_calls=[_convert_tool_call(tc) for tc in requested_tool_calls],
|
||||
allowed_tool_calls=[_convert_tool_call(tc) for tc in allowed_tool_calls],
|
||||
tool_call=_convert_tool_call(self.tool_calls[0]), # backwards compatibility
|
||||
tool_calls=[_convert_tool_call(tc) for tc in self.tool_calls],
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
@@ -1819,7 +1811,40 @@ class Message(BaseMessage):
|
||||
# Filter last message if it is a lone approval request without a response - this only occurs for token counting
|
||||
if messages[-1].role == "approval" and messages[-1].tool_calls is not None and len(messages[-1].tool_calls) > 0:
|
||||
messages.remove(messages[-1])
|
||||
# Also filter pending tool call message if this turn invoked parallel tool calling
|
||||
if messages and messages[-1].role == "assistant" and messages[-1].tool_calls is not None and len(messages[-1].tool_calls) > 0:
|
||||
messages.remove(messages[-1])
|
||||
|
||||
# Filter last message if it is a lone reasoning message without assistant message or tool call
|
||||
if (
|
||||
messages[-1].role == "assistant"
|
||||
and messages[-1].tool_calls is None
|
||||
and (not messages[-1].content or all(not isinstance(content_part, TextContent) for content_part in messages[-1].content))
|
||||
):
|
||||
messages.remove(messages[-1])
|
||||
|
||||
# Collapse adjacent tool call and approval messages
|
||||
messages = Message.collapse_tool_call_messages_for_llm_api(messages)
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
def collapse_tool_call_messages_for_llm_api(
    messages: List[Message],
) -> List[Message]:
    """Fold each approval message into the assistant tool-call message directly before it.

    A pair qualifies when ``messages[i]`` has the assistant role with
    ``tool_calls`` set and ``messages[i + 1]`` has the approval role with
    ``tool_calls`` set. For every such pair the approval message's content and
    tool calls are appended to the assistant message, and the approval message
    is dropped from the list.

    Args:
        messages: Conversation messages in order. The list and the matching
            assistant messages are modified in place.

    Returns:
        The same ``messages`` list object, after collapsing.
    """
    # Walk from the end so that deleting index i + 1 never shifts a pair that
    # has not been examined yet. Qualifying pairs cannot overlap (the second
    # element must be an approval, the first an assistant), so one pass is
    # enough.
    for i in range(len(messages) - 2, -1, -1):
        current = messages[i]
        following = messages[i + 1]
        if (
            current.role == MessageRole.assistant
            and current.tool_calls is not None
            and following.role == MessageRole.approval
            and following.tool_calls is not None
        ):
            # Either side may carry no content; treat None as an empty list
            # instead of failing on `None + list`.
            current.content = (current.content or []) + (following.content or [])
            current.tool_calls = current.tool_calls + following.tool_calls
            # Delete by position: list.remove() matches by equality and could
            # drop an earlier message that merely compares equal.
            del messages[i + 1]
    return messages
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -19,7 +19,6 @@ class ToolCall(BaseModel):
|
||||
type: Literal["function"] = "function"
|
||||
# function: ToolCallFunction
|
||||
function: FunctionCall
|
||||
requires_approval: Optional[bool] = None
|
||||
|
||||
|
||||
class LogProbToken(BaseModel):
|
||||
|
||||
@@ -203,12 +203,40 @@ def create_approval_response_message_from_input(
|
||||
def create_approval_request_message_from_llm_response(
|
||||
agent_id: str,
|
||||
model: str,
|
||||
tool_calls: List[OpenAIToolCall],
|
||||
requested_tool_calls: List[OpenAIToolCall],
|
||||
allowed_tool_calls: List[OpenAIToolCall] = [],
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
step_id: str | None = None,
|
||||
run_id: str = None,
|
||||
) -> Message:
|
||||
messages = []
|
||||
if allowed_tool_calls:
|
||||
oai_tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id=tool_call.id,
|
||||
function=OpenAIFunction(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
for tool_call in allowed_tool_calls
|
||||
]
|
||||
tool_message = Message(
|
||||
role=MessageRole.assistant,
|
||||
content=reasoning_content if reasoning_content else [],
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
tool_calls=oai_tool_calls,
|
||||
tool_call_id=allowed_tool_calls[0].id,
|
||||
created_at=get_utc_time(),
|
||||
step_id=step_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
if pre_computed_assistant_message_id:
|
||||
tool_message.id = pre_computed_assistant_message_id
|
||||
messages.append(tool_message)
|
||||
# Construct the tool call with the assistant's message
|
||||
oai_tool_calls = [
|
||||
OpenAIToolCall(
|
||||
@@ -218,15 +246,14 @@ def create_approval_request_message_from_llm_response(
|
||||
arguments=tool_call.function.arguments,
|
||||
),
|
||||
type="function",
|
||||
requires_approval=tool_call.requires_approval,
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
for tool_call in requested_tool_calls
|
||||
]
|
||||
# TODO: Use ToolCallContent instead of tool_calls
|
||||
# TODO: This helps preserve ordering
|
||||
approval_message = Message(
|
||||
role=MessageRole.approval,
|
||||
content=reasoning_content if reasoning_content else [],
|
||||
content=reasoning_content if reasoning_content and not allowed_tool_calls else [],
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
tool_calls=oai_tool_calls,
|
||||
@@ -236,8 +263,17 @@ def create_approval_request_message_from_llm_response(
|
||||
run_id=run_id,
|
||||
)
|
||||
if pre_computed_assistant_message_id:
|
||||
approval_message.id = pre_computed_assistant_message_id
|
||||
return approval_message
|
||||
approval_message.id = increment_message_uuid(pre_computed_assistant_message_id)
|
||||
messages.append(approval_message)
|
||||
return messages
|
||||
|
||||
|
||||
def increment_message_uuid(message_id: str) -> str:
    """Return the message ID whose UUID is one greater than that of ``message_id``.

    This gives a deterministic sibling ID for a second message derived from
    the same pre-computed ID.

    Args:
        message_id: An ID of the form ``"<prefix>-<uuid>"``; everything after
            the first hyphen is parsed as the UUID.

    Returns:
        ``"message-"`` followed by the incremented UUID in canonical form.

    Raises:
        ValueError: If ``message_id`` contains no hyphen, or the part after
            the first hyphen is not a valid UUID.
    """
    _prefix, separator, raw_uuid = message_id.partition("-")
    if not separator:
        raise ValueError(f"Expected an ID of the form '<prefix>-<uuid>', got {message_id!r}")

    # uuid.UUID(int=...) only accepts values below 2**128, so wrap the
    # all-ones UUID around to zero instead of raising an out-of-range error.
    incremented_int = (uuid.UUID(raw_uuid).int + 1) % (1 << 128)
    return "message-" + str(uuid.UUID(int=incremented_int))
|
||||
|
||||
|
||||
def create_letta_messages_from_llm_response(
|
||||
|
||||
@@ -87,6 +87,25 @@ def accumulate_chunks(stream):
|
||||
return messages
|
||||
|
||||
|
||||
def approve_tool_call(client: Letta, agent_id: str, tool_call_id: str):
    """Approve the pending tool call ``tool_call_id`` for the given agent.

    Test helper: posts a single ``ApprovalCreate`` message whose ``approvals``
    list approves the tool call, so a test does not leave the agent with an
    outstanding approval request.
    """
    approval_entry = {
        "type": "approval",
        "approve": True,
        "tool_call_id": tool_call_id,
    }
    approval_message = ApprovalCreate(
        # The legacy top-level fields are deliberately given incorrect values
        # to ensure they are overridden by the `approvals` entries.
        approve=False,
        approval_request_id=FAKE_REQUEST_ID,
        approvals=[approval_entry],
    )
    client.agents.messages.create(
        agent_id=agent_id,
        messages=[approval_message],
    )
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
# ------------------------------
|
||||
@@ -216,7 +235,7 @@ def test_send_approval_without_pending_request(client, agent):
|
||||
|
||||
|
||||
def test_send_user_message_with_pending_request(client, agent):
|
||||
client.agents.messages.create(
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
@@ -227,9 +246,11 @@ def test_send_user_message_with_pending_request(client, agent):
|
||||
messages=[MessageCreate(role="user", content="hi")],
|
||||
)
|
||||
|
||||
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
|
||||
|
||||
|
||||
def test_send_approval_message_with_incorrect_request_id(client, agent):
|
||||
client.agents.messages.create(
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
@@ -252,6 +273,8 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
|
||||
],
|
||||
)
|
||||
|
||||
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Request Test Cases
|
||||
@@ -276,9 +299,9 @@ def test_invoke_approval_request(
|
||||
assert messages[2].message_type == "approval_request_message"
|
||||
assert messages[2].tool_call is not None
|
||||
assert messages[2].tool_call.name == "get_secret_code_tool"
|
||||
assert messages[2].requested_tool_calls is not None
|
||||
assert len(messages[2].requested_tool_calls) == 1
|
||||
assert messages[2].requested_tool_calls[0]["name"] == "get_secret_code_tool"
|
||||
assert messages[2].tool_calls is not None
|
||||
assert len(messages[2].tool_calls) == 1
|
||||
assert messages[2].tool_calls[0]["name"] == "get_secret_code_tool"
|
||||
|
||||
# v3/v1 path: approval request tool args must not include request_heartbeat
|
||||
import json as _json
|
||||
@@ -286,6 +309,10 @@ def test_invoke_approval_request(
|
||||
_args = _json.loads(messages[2].tool_call.arguments)
|
||||
assert "request_heartbeat" not in _args
|
||||
|
||||
client.agents.context.retrieve(agent_id=agent.id)
|
||||
|
||||
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
|
||||
|
||||
|
||||
def test_invoke_approval_request_stream(
|
||||
client: Letta,
|
||||
@@ -309,43 +336,9 @@ def test_invoke_approval_request_stream(
|
||||
assert messages[3].message_type == "stop_reason"
|
||||
assert messages[4].message_type == "usage_statistics"
|
||||
|
||||
client.agents.context.retrieve(agent_id=agent.id)
|
||||
|
||||
def test_invoke_approval_request_with_context_check(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
) -> None:
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
tool_call_id = response.messages[2].tool_call.tool_call_id
|
||||
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
ApprovalCreate(
|
||||
approve=False, # legacy (passing incorrect value to ensure it is overridden)
|
||||
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
|
||||
approvals=[
|
||||
{
|
||||
"type": "approval",
|
||||
"approve": True,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
],
|
||||
),
|
||||
],
|
||||
stream_tokens=True,
|
||||
)
|
||||
|
||||
messages = accumulate_chunks(response)
|
||||
|
||||
try:
|
||||
client.agents.context.retrieve(agent_id=agent.id)
|
||||
except Exception as e:
|
||||
if len(messages) > 4:
|
||||
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
|
||||
raise e
|
||||
approve_tool_call(client, agent.id, messages[2].tool_call.tool_call_id)
|
||||
|
||||
|
||||
def test_invoke_tool_after_turning_off_requires_approval(
|
||||
@@ -357,7 +350,6 @@ def test_invoke_tool_after_turning_off_requires_approval(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
approval_request_id = response.messages[0].id
|
||||
tool_call_id = response.messages[2].tool_call.tool_call_id
|
||||
|
||||
response = client.agents.messages.create_stream(
|
||||
@@ -432,13 +424,7 @@ def test_approve_tool_call_request(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
approval_request_id = response.messages[0].id
|
||||
tool_call_id = response.messages[2].tool_call.tool_call_id
|
||||
# Ensure no request_heartbeat on approval request
|
||||
# import json as _json
|
||||
|
||||
# _args = _json.loads(response.messages[0].tool_call.arguments)
|
||||
# assert "request_heartbeat" not in _args
|
||||
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
@@ -1175,6 +1161,7 @@ def test_parallel_tool_calling(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
) -> None:
|
||||
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
|
||||
@@ -1183,28 +1170,32 @@ def test_parallel_tool_calling(
|
||||
messages = response.messages
|
||||
|
||||
assert messages is not None
|
||||
assert len(messages) == 3
|
||||
assert len(messages) == 4
|
||||
assert messages[0].message_type == "reasoning_message"
|
||||
assert messages[1].message_type == "assistant_message"
|
||||
assert messages[2].message_type == "approval_request_message"
|
||||
assert messages[2].tool_call is not None
|
||||
assert messages[2].tool_call.name == "get_secret_code_tool" or messages[2].tool_call.name == "roll_dice_tool"
|
||||
assert messages[2].message_type == "tool_call_message"
|
||||
assert len(messages[2].tool_calls) == 1
|
||||
assert messages[2].tool_calls[0]["name"] == "roll_dice_tool"
|
||||
assert "6" in messages[2].tool_calls[0]["arguments"]
|
||||
dice_tool_call_id = messages[2].tool_calls[0]["tool_call_id"]
|
||||
|
||||
assert len(messages[2].requested_tool_calls) == 3
|
||||
assert messages[2].requested_tool_calls[0]["name"] == "get_secret_code_tool"
|
||||
assert "hello world" in messages[2].requested_tool_calls[0]["arguments"]
|
||||
approve_tool_call_id = messages[2].requested_tool_calls[0]["tool_call_id"]
|
||||
assert messages[2].requested_tool_calls[1]["name"] == "get_secret_code_tool"
|
||||
assert "hello letta" in messages[2].requested_tool_calls[1]["arguments"]
|
||||
deny_tool_call_id = messages[2].requested_tool_calls[1]["tool_call_id"]
|
||||
assert messages[2].requested_tool_calls[2]["name"] == "get_secret_code_tool"
|
||||
assert "hello test" in messages[2].requested_tool_calls[2]["arguments"]
|
||||
client_side_tool_call_id = messages[2].requested_tool_calls[2]["tool_call_id"]
|
||||
assert messages[3].message_type == "approval_request_message"
|
||||
assert messages[3].tool_call is not None
|
||||
assert messages[3].tool_call.name == "get_secret_code_tool"
|
||||
|
||||
assert len(messages[2].allowed_tool_calls) == 1
|
||||
assert messages[2].allowed_tool_calls[0]["name"] == "roll_dice_tool"
|
||||
assert "6" in messages[2].allowed_tool_calls[0]["arguments"]
|
||||
dice_tool_call_id = messages[2].allowed_tool_calls[0]["tool_call_id"]
|
||||
assert len(messages[3].tool_calls) == 3
|
||||
assert messages[3].tool_calls[0]["name"] == "get_secret_code_tool"
|
||||
assert "hello world" in messages[3].tool_calls[0]["arguments"]
|
||||
approve_tool_call_id = messages[3].tool_calls[0]["tool_call_id"]
|
||||
assert messages[3].tool_calls[1]["name"] == "get_secret_code_tool"
|
||||
assert "hello letta" in messages[3].tool_calls[1]["arguments"]
|
||||
deny_tool_call_id = messages[3].tool_calls[1]["tool_call_id"]
|
||||
assert messages[3].tool_calls[2]["name"] == "get_secret_code_tool"
|
||||
assert "hello test" in messages[3].tool_calls[2]["arguments"]
|
||||
client_side_tool_call_id = messages[3].tool_calls[2]["tool_call_id"]
|
||||
|
||||
# ensure context is not bricked
|
||||
client.agents.context.retrieve(agent_id=agent.id)
|
||||
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
@@ -1258,3 +1249,31 @@ def test_parallel_tool_calling(
|
||||
assert messages[1].message_type == "reasoning_message"
|
||||
assert messages[2].message_type == "tool_call_message"
|
||||
assert messages[3].message_type == "tool_return_message"
|
||||
|
||||
# ensure context is not bricked
|
||||
client.agents.context.retrieve(agent_id=agent.id)
|
||||
|
||||
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
|
||||
assert len(messages) > 6
|
||||
assert messages[0].message_type == "user_message"
|
||||
assert messages[1].message_type == "reasoning_message"
|
||||
assert messages[2].message_type == "assistant_message"
|
||||
assert messages[3].message_type == "tool_call_message"
|
||||
assert messages[4].message_type == "approval_request_message"
|
||||
assert messages[5].message_type == "approval_response_message"
|
||||
assert messages[6].message_type == "tool_return_message"
|
||||
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_FOLLOW_UP,
|
||||
stream_tokens=True,
|
||||
)
|
||||
|
||||
messages = accumulate_chunks(response)
|
||||
|
||||
assert messages is not None
|
||||
assert len(messages) == 4
|
||||
assert messages[0].message_type == "reasoning_message"
|
||||
assert messages[1].message_type == "assistant_message"
|
||||
assert messages[2].message_type == "stop_reason"
|
||||
assert messages[3].message_type == "usage_statistics"
|
||||
|
||||
Reference in New Issue
Block a user