feat: add support for approval request creation (#4313)

This commit is contained in:
cthomas
2025-08-29 15:23:02 -07:00
committed by GitHub
parent bfd6030f72
commit b8c2f42d33
6 changed files with 290 additions and 70 deletions

View File

@@ -48,7 +48,7 @@ from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
from letta.server.rest_api.utils import create_approval_request_message_from_llm_response, create_letta_messages_from_llm_response
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
@@ -1543,78 +1543,97 @@ class LettaAgent(BaseAgent):
request_heartbeat=request_heartbeat,
)
# 2. Execute the tool (or synthesize an error result if disallowed)
tool_rule_violated = tool_call_name not in valid_tool_names
if tool_rule_violated:
tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
else:
# Track tool execution time
tool_start_time = get_utc_timestamp_ns()
tool_execution_result = await self._execute_tool(
tool_name=tool_call_name,
tool_args=tool_args,
agent_state=agent_state,
agent_step_span=agent_step_span,
if tool_rules_solver.is_requires_approval_tool(tool_call_name):
approval_message = create_approval_request_message_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call_name,
function_arguments=tool_args,
tool_call_id=tool_call_id,
actor=self.actor,
continue_stepping=request_heartbeat,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
step_id=step_id,
)
tool_end_time = get_utc_timestamp_ns()
messages_to_persist = (initial_messages or []) + [approval_message]
continue_stepping = False
stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
else:
# 2. Execute the tool (or synthesize an error result if disallowed)
tool_rule_violated = tool_call_name not in valid_tool_names
if tool_rule_violated:
tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
else:
# Track tool execution time
tool_start_time = get_utc_timestamp_ns()
tool_execution_result = await self._execute_tool(
tool_name=tool_call_name,
tool_args=tool_args,
agent_state=agent_state,
agent_step_span=agent_step_span,
step_id=step_id,
)
tool_end_time = get_utc_timestamp_ns()
# Store tool execution time in metrics
step_metrics.tool_execution_ns = tool_end_time - tool_start_time
# Store tool execution time in metrics
step_metrics.tool_execution_ns = tool_end_time - tool_start_time
log_telemetry(
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
)
log_telemetry(
self.logger,
"_handle_ai_response execute tool finish",
tool_execution_result=tool_execution_result,
tool_call_id=tool_call_id,
)
# 3. Prepare the function-response payload
truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
return_char_limit = next(
(t.return_char_limit for t in agent_state.tools if t.name == tool_call_name),
None,
)
function_response_string = validate_function_response(
tool_execution_result.func_return,
return_char_limit=return_char_limit,
truncate=truncate,
)
self.last_function_response = package_function_response(
was_success=tool_execution_result.success_flag,
response_string=function_response_string,
timezone=agent_state.timezone,
)
# 3. Prepare the function-response payload
truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
return_char_limit = next(
(t.return_char_limit for t in agent_state.tools if t.name == tool_call_name),
None,
)
function_response_string = validate_function_response(
tool_execution_result.func_return,
return_char_limit=return_char_limit,
truncate=truncate,
)
self.last_function_response = package_function_response(
was_success=tool_execution_result.success_flag,
response_string=function_response_string,
timezone=agent_state.timezone,
)
# 4. Decide whether to keep stepping (<<< focal section simplified)
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
agent_state=agent_state,
request_heartbeat=request_heartbeat,
tool_call_name=tool_call_name,
tool_rule_violated=tool_rule_violated,
tool_rules_solver=tool_rules_solver,
is_final_step=is_final_step,
)
# 4. Decide whether to keep stepping (focal section simplified)
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
agent_state=agent_state,
request_heartbeat=request_heartbeat,
tool_call_name=tool_call_name,
tool_rule_violated=tool_rule_violated,
tool_rules_solver=tool_rules_solver,
is_final_step=is_final_step,
)
# 5. Create messages (step was already created at the beginning)
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call_name,
function_arguments=tool_args,
tool_execution_result=tool_execution_result,
tool_call_id=tool_call_id,
function_call_success=tool_execution_result.success_flag,
function_response=function_response_string,
timezone=agent_state.timezone,
actor=self.actor,
continue_stepping=continue_stepping,
heartbeat_reason=heartbeat_reason,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
step_id=step_id,
)
# 5. Create messages (step was already created at the beginning)
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call_name,
function_arguments=tool_args,
tool_execution_result=tool_execution_result,
tool_call_id=tool_call_id,
function_call_success=tool_execution_result.success_flag,
function_response=function_response_string,
timezone=agent_state.timezone,
actor=self.actor,
continue_stepping=continue_stepping,
heartbeat_reason=heartbeat_reason,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
step_id=step_id,
)
messages_to_persist = (initial_messages or []) + tool_call_messages
persisted_messages = await self.message_manager.create_many_messages_async(
(initial_messages or []) + tool_call_messages, actor=self.actor
)
persisted_messages = await self.message_manager.create_many_messages_async(messages_to_persist, actor=self.actor)
if run_id:
await self.job_manager.add_messages_to_job_async(

View File

@@ -32,6 +32,7 @@ class MessageRole(str, Enum):
tool = "tool"
function = "function"
system = "system"
approval = "approval"
class OptionState(str, Enum):

View File

@@ -15,6 +15,7 @@ class StopReasonType(str, Enum):
no_tool_call = "no_tool_call"
tool_rule = "tool_rule"
cancelled = "cancelled"
requires_approval = "requires_approval"
@property
def run_status(self) -> JobStatus:
@@ -22,6 +23,7 @@ class StopReasonType(str, Enum):
StopReasonType.end_turn,
StopReasonType.max_steps,
StopReasonType.tool_rule,
StopReasonType.requires_approval,
):
return JobStatus.completed
elif self in (

View File

@@ -20,6 +20,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageRole
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import (
ApprovalRequestMessage,
AssistantMessage,
HiddenReasoningMessage,
LettaMessage,
@@ -204,7 +205,7 @@ class Message(BaseMessage):
@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
    """Validate that the message role is one of the supported values.

    "approval" marks an approval-request message awaiting user sign-off
    before the associated tool call is executed.
    """
    roles = ["system", "assistant", "user", "tool", "approval"]
    # Raise ValueError (not assert): asserts are stripped under `python -O`,
    # and pydantic converts ValueError into a ValidationError for callers.
    if v not in roles:
        raise ValueError(f"Role must be one of {roles}")
    return v
@@ -275,8 +276,8 @@ class Message(BaseMessage):
include_err: Optional[bool] = None,
) -> List[LettaMessage]:
"""Convert message object (in DB format) to the style used by the original Letta API"""
messages = []
if self.role == MessageRole.assistant:
messages = []
if self.content:
messages.extend(self._convert_reasoning_messages())
if self.tool_calls is not None:
@@ -289,11 +290,19 @@ class Message(BaseMessage):
),
)
elif self.role == MessageRole.tool:
messages = [self._convert_tool_return_message()]
messages.append(self._convert_tool_return_message())
elif self.role == MessageRole.user:
messages = [self._convert_user_message()]
messages.append(self._convert_user_message())
elif self.role == MessageRole.system:
messages = [self._convert_system_message()]
messages.append(self._convert_system_message())
elif self.role == MessageRole.approval:
if self.content:
messages.extend(self._convert_reasoning_messages())
if self.tool_calls is not None:
tool_calls = self._convert_tool_call_messages()
assert len(tool_calls) == 1
approval_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
messages.append(approval_message)
else:
raise ValueError(f"Unknown role: {self.role}")

View File

@@ -176,6 +176,46 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti
return messages
def create_approval_request_message_from_llm_response(
    agent_id: str,
    model: str,
    function_name: str,
    function_arguments: Dict,
    tool_call_id: str,
    actor: User,
    continue_stepping: bool = False,
    reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
    pre_computed_assistant_message_id: Optional[str] = None,
    step_id: str | None = None,
) -> Message:
    """Build an approval-request ``Message`` for a tool call that needs user sign-off.

    Args:
        agent_id: ID of the agent the message belongs to.
        model: LLM model name recorded on the message.
        function_name: Name of the tool the agent wants to call.
        function_arguments: Arguments the LLM produced for the tool call.
        tool_call_id: ID tying this request to the pending tool call.
        actor: Requesting user (kept for signature parity with the sibling
            ``create_letta_messages_from_llm_response`` helper; not read here).
        continue_stepping: Value forced into the ``request_heartbeat`` argument.
        reasoning_content: Optional reasoning parts stored as message content.
        pre_computed_assistant_message_id: If given, overrides the generated id.
        step_id: Step this message is associated with, if any.

    Returns:
        A ``Message`` with role ``approval`` carrying the serialized tool call.
    """
    # Copy before injecting REQUEST_HEARTBEAT_PARAM so the caller's dict is
    # not mutated as a side effect of building the message.
    arguments = {**function_arguments, REQUEST_HEARTBEAT_PARAM: continue_stepping}
    tool_call = OpenAIToolCall(
        id=tool_call_id,
        function=OpenAIFunction(
            name=function_name,
            arguments=json.dumps(arguments),
        ),
        type="function",
    )
    # TODO: Use ToolCallContent instead of tool_calls — this helps preserve ordering.
    approval_message = Message(
        role=MessageRole.approval,
        content=reasoning_content if reasoning_content else [],
        agent_id=agent_id,
        model=model,
        tool_calls=[tool_call],
        tool_call_id=tool_call_id,
        created_at=get_utc_time(),
        step_id=step_id,
    )
    if pre_computed_assistant_message_id:
        approval_message.id = pre_computed_assistant_message_id
    return approval_message
def create_letta_messages_from_llm_response(
agent_id: str,
model: str,

View File

@@ -0,0 +1,149 @@
import os
import threading
import time
import uuid
from typing import Any, List
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta, MessageCreate
from letta.log import get_logger
from letta.schemas.agent import AgentState
logger = get_logger(__name__)
# ------------------------------
# Helper Functions and Constants
# ------------------------------
def requires_approval_tool(input_text: str) -> str:
    """
    A tool that requires approval before execution.
    Args:
        input_text (str): The input text to process.
    Returns:
        str: The processed text with 'APPROVED:' prefix.
    """
    # Docstring above is kept verbatim: the tool-upsert path parses it to
    # derive the tool schema, so it is runtime-relevant text.
    prefix = "APPROVED: "
    return prefix + input_text
# OTID generated once per test session so the message can be correlated.
USER_MESSAGE_OTID = str(uuid.uuid4())
# Prompt that instructs the agent to invoke the approval-gated tool.
USER_MESSAGE_TEST_APPROVAL: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content="This is an automated test message. Call the requires_approval_tool with the text 'test approval'.",
        otid=USER_MESSAGE_OTID,
    )
]
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provides the URL for the Letta server.

    If LETTA_SERVER_URL is not set, starts the server in a background thread
    and polls until it's accepting connections.

    Returns:
        str: Base URL of the reachable Letta server.

    Raises:
        RuntimeError: If the locally started server is not reachable in time.
    """

    def _run_server() -> None:
        # Blocks forever; executed in a daemon thread so pytest can exit.
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll until the server is up (or timeout)
        timeout_seconds = 30
        deadline = time.time() + timeout_seconds
        while time.time() < deadline:
            try:
                # Per-request timeout: without it a hung connection could
                # block well past the 30 s deadline.
                resp = requests.get(url + "/v1/health", timeout=2)
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                pass
            time.sleep(0.1)
        else:
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    return url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
    """Yield a synchronous Letta REST client bound to the test server."""
    yield Letta(base_url=server_url)
@pytest.fixture(scope="function")
def approval_tool_fixture(client: Letta):
    """Yield the approval-gated tool, upserted fresh for each test."""
    # Ensure base tools exist before adding the custom one.
    client.tools.upsert_base_tools()
    tool = client.tools.upsert_from_function(
        func=requires_approval_tool,
        # default_requires_approval=True,
    )
    yield tool
@pytest.fixture(scope="function")
def agent(client: Letta, approval_tool_fixture) -> AgentState:
    """
    Yield a test agent configured with send_message plus the approval-gated
    tool (base tools are otherwise excluded).
    """
    send_message_tool = client.tools.list(name="send_message")[0]
    yield client.agents.create(
        name="approval_test_agent",
        include_base_tools=False,
        tool_ids=[send_message_tool.id, approval_tool_fixture.id],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
        tags=["approval_test"],
    )
# ------------------------------
# Test Cases
# ------------------------------
def test_send_message_with_approval_tool(
    disable_e2b_api_key: Any,
    client: Letta,
    agent: AgentState,
) -> None:
    """
    Send a message that directs the agent to call an approval-gated tool and
    verify the run stops with an approval request instead of a tool result.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )

    # Expect exactly a reasoning message followed by the approval request —
    # i.e. the tool has NOT executed yet.
    assert response.messages is not None
    assert len(response.messages) == 2
    assert response.messages[0].message_type == "reasoning_message"
    assert response.messages[1].message_type == "approval_request_message"