feat: add support for approval request creation (#4313)
This commit is contained in:
@@ -48,7 +48,7 @@ from letta.schemas.step_metrics import StepMetrics
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
|
||||
from letta.server.rest_api.utils import create_approval_request_message_from_llm_response, create_letta_messages_from_llm_response
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
@@ -1543,78 +1543,97 @@ class LettaAgent(BaseAgent):
|
||||
request_heartbeat=request_heartbeat,
|
||||
)
|
||||
|
||||
# 2. Execute the tool (or synthesize an error result if disallowed)
|
||||
tool_rule_violated = tool_call_name not in valid_tool_names
|
||||
if tool_rule_violated:
|
||||
tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
|
||||
else:
|
||||
# Track tool execution time
|
||||
tool_start_time = get_utc_timestamp_ns()
|
||||
tool_execution_result = await self._execute_tool(
|
||||
tool_name=tool_call_name,
|
||||
tool_args=tool_args,
|
||||
agent_state=agent_state,
|
||||
agent_step_span=agent_step_span,
|
||||
if tool_rules_solver.is_requires_approval_tool(tool_call_name):
|
||||
approval_message = create_approval_request_message_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call_name,
|
||||
function_arguments=tool_args,
|
||||
tool_call_id=tool_call_id,
|
||||
actor=self.actor,
|
||||
continue_stepping=request_heartbeat,
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
tool_end_time = get_utc_timestamp_ns()
|
||||
messages_to_persist = (initial_messages or []) + [approval_message]
|
||||
continue_stepping = False
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
|
||||
else:
|
||||
# 2. Execute the tool (or synthesize an error result if disallowed)
|
||||
tool_rule_violated = tool_call_name not in valid_tool_names
|
||||
if tool_rule_violated:
|
||||
tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
|
||||
else:
|
||||
# Track tool execution time
|
||||
tool_start_time = get_utc_timestamp_ns()
|
||||
tool_execution_result = await self._execute_tool(
|
||||
tool_name=tool_call_name,
|
||||
tool_args=tool_args,
|
||||
agent_state=agent_state,
|
||||
agent_step_span=agent_step_span,
|
||||
step_id=step_id,
|
||||
)
|
||||
tool_end_time = get_utc_timestamp_ns()
|
||||
|
||||
# Store tool execution time in metrics
|
||||
step_metrics.tool_execution_ns = tool_end_time - tool_start_time
|
||||
# Store tool execution time in metrics
|
||||
step_metrics.tool_execution_ns = tool_end_time - tool_start_time
|
||||
|
||||
log_telemetry(
|
||||
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
|
||||
)
|
||||
log_telemetry(
|
||||
self.logger,
|
||||
"_handle_ai_response execute tool finish",
|
||||
tool_execution_result=tool_execution_result,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
|
||||
# 3. Prepare the function-response payload
|
||||
truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
|
||||
return_char_limit = next(
|
||||
(t.return_char_limit for t in agent_state.tools if t.name == tool_call_name),
|
||||
None,
|
||||
)
|
||||
function_response_string = validate_function_response(
|
||||
tool_execution_result.func_return,
|
||||
return_char_limit=return_char_limit,
|
||||
truncate=truncate,
|
||||
)
|
||||
self.last_function_response = package_function_response(
|
||||
was_success=tool_execution_result.success_flag,
|
||||
response_string=function_response_string,
|
||||
timezone=agent_state.timezone,
|
||||
)
|
||||
# 3. Prepare the function-response payload
|
||||
truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
|
||||
return_char_limit = next(
|
||||
(t.return_char_limit for t in agent_state.tools if t.name == tool_call_name),
|
||||
None,
|
||||
)
|
||||
function_response_string = validate_function_response(
|
||||
tool_execution_result.func_return,
|
||||
return_char_limit=return_char_limit,
|
||||
truncate=truncate,
|
||||
)
|
||||
self.last_function_response = package_function_response(
|
||||
was_success=tool_execution_result.success_flag,
|
||||
response_string=function_response_string,
|
||||
timezone=agent_state.timezone,
|
||||
)
|
||||
|
||||
# 4. Decide whether to keep stepping (<<< focal section simplified)
|
||||
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
||||
agent_state=agent_state,
|
||||
request_heartbeat=request_heartbeat,
|
||||
tool_call_name=tool_call_name,
|
||||
tool_rule_violated=tool_rule_violated,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
is_final_step=is_final_step,
|
||||
)
|
||||
# 4. Decide whether to keep stepping (focal section simplified)
|
||||
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
||||
agent_state=agent_state,
|
||||
request_heartbeat=request_heartbeat,
|
||||
tool_call_name=tool_call_name,
|
||||
tool_rule_violated=tool_rule_violated,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
is_final_step=is_final_step,
|
||||
)
|
||||
|
||||
# 5. Create messages (step was already created at the beginning)
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call_name,
|
||||
function_arguments=tool_args,
|
||||
tool_execution_result=tool_execution_result,
|
||||
tool_call_id=tool_call_id,
|
||||
function_call_success=tool_execution_result.success_flag,
|
||||
function_response=function_response_string,
|
||||
timezone=agent_state.timezone,
|
||||
actor=self.actor,
|
||||
continue_stepping=continue_stepping,
|
||||
heartbeat_reason=heartbeat_reason,
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
# 5. Create messages (step was already created at the beginning)
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call_name,
|
||||
function_arguments=tool_args,
|
||||
tool_execution_result=tool_execution_result,
|
||||
tool_call_id=tool_call_id,
|
||||
function_call_success=tool_execution_result.success_flag,
|
||||
function_response=function_response_string,
|
||||
timezone=agent_state.timezone,
|
||||
actor=self.actor,
|
||||
continue_stepping=continue_stepping,
|
||||
heartbeat_reason=heartbeat_reason,
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
(initial_messages or []) + tool_call_messages, actor=self.actor
|
||||
)
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(messages_to_persist, actor=self.actor)
|
||||
|
||||
if run_id:
|
||||
await self.job_manager.add_messages_to_job_async(
|
||||
|
||||
@@ -32,6 +32,7 @@ class MessageRole(str, Enum):
|
||||
tool = "tool"
|
||||
function = "function"
|
||||
system = "system"
|
||||
approval = "approval"
|
||||
|
||||
|
||||
class OptionState(str, Enum):
|
||||
|
||||
@@ -15,6 +15,7 @@ class StopReasonType(str, Enum):
|
||||
no_tool_call = "no_tool_call"
|
||||
tool_rule = "tool_rule"
|
||||
cancelled = "cancelled"
|
||||
requires_approval = "requires_approval"
|
||||
|
||||
@property
|
||||
def run_status(self) -> JobStatus:
|
||||
@@ -22,6 +23,7 @@ class StopReasonType(str, Enum):
|
||||
StopReasonType.end_turn,
|
||||
StopReasonType.max_steps,
|
||||
StopReasonType.tool_rule,
|
||||
StopReasonType.requires_approval,
|
||||
):
|
||||
return JobStatus.completed
|
||||
elif self in (
|
||||
|
||||
@@ -20,6 +20,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import (
|
||||
ApprovalRequestMessage,
|
||||
AssistantMessage,
|
||||
HiddenReasoningMessage,
|
||||
LettaMessage,
|
||||
@@ -204,7 +205,7 @@ class Message(BaseMessage):
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, v: str) -> str:
|
||||
roles = ["system", "assistant", "user", "tool"]
|
||||
roles = ["system", "assistant", "user", "tool", "approval"]
|
||||
assert v in roles, f"Role must be one of {roles}"
|
||||
return v
|
||||
|
||||
@@ -275,8 +276,8 @@ class Message(BaseMessage):
|
||||
include_err: Optional[bool] = None,
|
||||
) -> List[LettaMessage]:
|
||||
"""Convert message object (in DB format) to the style used by the original Letta API"""
|
||||
messages = []
|
||||
if self.role == MessageRole.assistant:
|
||||
messages = []
|
||||
if self.content:
|
||||
messages.extend(self._convert_reasoning_messages())
|
||||
if self.tool_calls is not None:
|
||||
@@ -289,11 +290,19 @@ class Message(BaseMessage):
|
||||
),
|
||||
)
|
||||
elif self.role == MessageRole.tool:
|
||||
messages = [self._convert_tool_return_message()]
|
||||
messages.append(self._convert_tool_return_message())
|
||||
elif self.role == MessageRole.user:
|
||||
messages = [self._convert_user_message()]
|
||||
messages.append(self._convert_user_message())
|
||||
elif self.role == MessageRole.system:
|
||||
messages = [self._convert_system_message()]
|
||||
messages.append(self._convert_system_message())
|
||||
elif self.role == MessageRole.approval:
|
||||
if self.content:
|
||||
messages.extend(self._convert_reasoning_messages())
|
||||
if self.tool_calls is not None:
|
||||
tool_calls = self._convert_tool_call_messages()
|
||||
assert len(tool_calls) == 1
|
||||
approval_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
|
||||
messages.append(approval_message)
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {self.role}")
|
||||
|
||||
|
||||
@@ -176,6 +176,46 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti
|
||||
return messages
|
||||
|
||||
|
||||
def create_approval_request_message_from_llm_response(
|
||||
agent_id: str,
|
||||
model: str,
|
||||
function_name: str,
|
||||
function_arguments: Dict,
|
||||
tool_call_id: str,
|
||||
actor: User,
|
||||
continue_stepping: bool = False,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
step_id: str | None = None,
|
||||
) -> Message:
|
||||
# Construct the tool call with the assistant's message
|
||||
# Force set request_heartbeat in tool_args to calculated continue_stepping
|
||||
function_arguments[REQUEST_HEARTBEAT_PARAM] = continue_stepping
|
||||
tool_call = OpenAIToolCall(
|
||||
id=tool_call_id,
|
||||
function=OpenAIFunction(
|
||||
name=function_name,
|
||||
arguments=json.dumps(function_arguments),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
# TODO: Use ToolCallContent instead of tool_calls
|
||||
# TODO: This helps preserve ordering
|
||||
approval_message = Message(
|
||||
role=MessageRole.approval,
|
||||
content=reasoning_content if reasoning_content else [],
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
tool_calls=[tool_call],
|
||||
tool_call_id=tool_call_id,
|
||||
created_at=get_utc_time(),
|
||||
step_id=step_id,
|
||||
)
|
||||
if pre_computed_assistant_message_id:
|
||||
approval_message.id = pre_computed_assistant_message_id
|
||||
return approval_message
|
||||
|
||||
|
||||
def create_letta_messages_from_llm_response(
|
||||
agent_id: str,
|
||||
model: str,
|
||||
|
||||
149
tests/integration_test_human_in_the_loop.py
Normal file
149
tests/integration_test_human_in_the_loop.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta, MessageCreate
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ------------------------------
|
||||
# Helper Functions and Constants
|
||||
# ------------------------------
|
||||
|
||||
|
||||
def requires_approval_tool(input_text: str) -> str:
|
||||
"""
|
||||
A tool that requires approval before execution.
|
||||
Args:
|
||||
input_text (str): The input text to process.
|
||||
Returns:
|
||||
str: The processed text with 'APPROVED:' prefix.
|
||||
"""
|
||||
return f"APPROVED: {input_text}"
|
||||
|
||||
|
||||
USER_MESSAGE_OTID = str(uuid.uuid4())
|
||||
USER_MESSAGE_TEST_APPROVAL: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="This is an automated test message. Call the requires_approval_tool with the text 'test approval'.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
# ------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it's accepting connections.
|
||||
"""
|
||||
|
||||
def _run_server() -> None:
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server_url: str) -> Letta:
|
||||
"""
|
||||
Creates and returns a synchronous Letta REST client for testing.
|
||||
"""
|
||||
client_instance = Letta(base_url=server_url)
|
||||
yield client_instance
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def approval_tool_fixture(client: Letta):
|
||||
"""
|
||||
Creates and returns a tool that requires approval for testing.
|
||||
"""
|
||||
client.tools.upsert_base_tools()
|
||||
approval_tool = client.tools.upsert_from_function(
|
||||
func=requires_approval_tool,
|
||||
# default_requires_approval=True,
|
||||
)
|
||||
yield approval_tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent(client: Letta, approval_tool_fixture) -> AgentState:
|
||||
"""
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
The agent is configured with the requires_approval_tool.
|
||||
"""
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
agent_state = client.agents.create(
|
||||
name="approval_test_agent",
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id, approval_tool_fixture.id],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
tags=["approval_test"],
|
||||
)
|
||||
yield agent_state
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Test Cases
|
||||
# ------------------------------
|
||||
|
||||
|
||||
def test_send_message_with_approval_tool(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message to an agent with a tool that requires approval.
|
||||
This test just verifies that the agent can send a message successfully.
|
||||
The actual approval logic testing will be filled out by the user.
|
||||
"""
|
||||
# Send a simple greeting message to test basic functionality
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
|
||||
# Basic assertion that we got a response with an approval request
|
||||
assert response.messages is not None
|
||||
assert len(response.messages) == 2
|
||||
assert response.messages[0].message_type == "reasoning_message"
|
||||
assert response.messages[1].message_type == "approval_request_message"
|
||||
Reference in New Issue
Block a user