diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 6343764f..90fbf67e 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -48,7 +48,7 @@ from letta.schemas.step_metrics import StepMetrics from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User -from letta.server.rest_api.utils import create_letta_messages_from_llm_response +from letta.server.rest_api.utils import create_approval_request_message_from_llm_response, create_letta_messages_from_llm_response from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema @@ -1543,78 +1543,97 @@ class LettaAgent(BaseAgent): request_heartbeat=request_heartbeat, ) - # 2. Execute the tool (or synthesize an error result if disallowed) - tool_rule_violated = tool_call_name not in valid_tool_names - if tool_rule_violated: - tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver) - else: - # Track tool execution time - tool_start_time = get_utc_timestamp_ns() - tool_execution_result = await self._execute_tool( - tool_name=tool_call_name, - tool_args=tool_args, - agent_state=agent_state, - agent_step_span=agent_step_span, + if tool_rules_solver.is_requires_approval_tool(tool_call_name): + approval_message = create_approval_request_message_from_llm_response( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + function_name=tool_call_name, + function_arguments=tool_args, + tool_call_id=tool_call_id, + actor=self.actor, + continue_stepping=request_heartbeat, + reasoning_content=reasoning_content, + pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, ) - tool_end_time = get_utc_timestamp_ns() + messages_to_persist = (initial_messages or []) + 
[approval_message] + continue_stepping = False + stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value) + else: + # 2. Execute the tool (or synthesize an error result if disallowed) + tool_rule_violated = tool_call_name not in valid_tool_names + if tool_rule_violated: + tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver) + else: + # Track tool execution time + tool_start_time = get_utc_timestamp_ns() + tool_execution_result = await self._execute_tool( + tool_name=tool_call_name, + tool_args=tool_args, + agent_state=agent_state, + agent_step_span=agent_step_span, + step_id=step_id, + ) + tool_end_time = get_utc_timestamp_ns() - # Store tool execution time in metrics - step_metrics.tool_execution_ns = tool_end_time - tool_start_time + # Store tool execution time in metrics + step_metrics.tool_execution_ns = tool_end_time - tool_start_time - log_telemetry( - self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id - ) + log_telemetry( + self.logger, + "_handle_ai_response execute tool finish", + tool_execution_result=tool_execution_result, + tool_call_id=tool_call_id, + ) - # 3. Prepare the function-response payload - truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"} - return_char_limit = next( - (t.return_char_limit for t in agent_state.tools if t.name == tool_call_name), - None, - ) - function_response_string = validate_function_response( - tool_execution_result.func_return, - return_char_limit=return_char_limit, - truncate=truncate, - ) - self.last_function_response = package_function_response( - was_success=tool_execution_result.success_flag, - response_string=function_response_string, - timezone=agent_state.timezone, - ) + # 3. 
Prepare the function-response payload + truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"} + return_char_limit = next( + (t.return_char_limit for t in agent_state.tools if t.name == tool_call_name), + None, + ) + function_response_string = validate_function_response( + tool_execution_result.func_return, + return_char_limit=return_char_limit, + truncate=truncate, + ) + self.last_function_response = package_function_response( + was_success=tool_execution_result.success_flag, + response_string=function_response_string, + timezone=agent_state.timezone, + ) - # 4. Decide whether to keep stepping (<<< focal section simplified) - continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation( - agent_state=agent_state, - request_heartbeat=request_heartbeat, - tool_call_name=tool_call_name, - tool_rule_violated=tool_rule_violated, - tool_rules_solver=tool_rules_solver, - is_final_step=is_final_step, - ) + # 4. Decide whether to keep stepping (focal section simplified) + continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation( + agent_state=agent_state, + request_heartbeat=request_heartbeat, + tool_call_name=tool_call_name, + tool_rule_violated=tool_rule_violated, + tool_rules_solver=tool_rules_solver, + is_final_step=is_final_step, + ) - # 5. 
Create messages (step was already created at the beginning) - tool_call_messages = create_letta_messages_from_llm_response( - agent_id=agent_state.id, - model=agent_state.llm_config.model, - function_name=tool_call_name, - function_arguments=tool_args, - tool_execution_result=tool_execution_result, - tool_call_id=tool_call_id, - function_call_success=tool_execution_result.success_flag, - function_response=function_response_string, - timezone=agent_state.timezone, - actor=self.actor, - continue_stepping=continue_stepping, - heartbeat_reason=heartbeat_reason, - reasoning_content=reasoning_content, - pre_computed_assistant_message_id=pre_computed_assistant_message_id, - step_id=step_id, - ) + # 5. Create messages (step was already created at the beginning) + tool_call_messages = create_letta_messages_from_llm_response( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + function_name=tool_call_name, + function_arguments=tool_args, + tool_execution_result=tool_execution_result, + tool_call_id=tool_call_id, + function_call_success=tool_execution_result.success_flag, + function_response=function_response_string, + timezone=agent_state.timezone, + actor=self.actor, + continue_stepping=continue_stepping, + heartbeat_reason=heartbeat_reason, + reasoning_content=reasoning_content, + pre_computed_assistant_message_id=pre_computed_assistant_message_id, + step_id=step_id, + ) + messages_to_persist = (initial_messages or []) + tool_call_messages - persisted_messages = await self.message_manager.create_many_messages_async( - (initial_messages or []) + tool_call_messages, actor=self.actor - ) + persisted_messages = await self.message_manager.create_many_messages_async(messages_to_persist, actor=self.actor) if run_id: await self.job_manager.add_messages_to_job_async( diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index da8182bb..394ee3c0 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -32,6 +32,7 @@ class MessageRole(str, Enum): tool 
@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
    """Reject any role outside the set this message schema accepts.

    Args:
        v: Candidate value for the ``role`` field.

    Returns:
        ``v`` unchanged when it is one of the accepted roles.

    Raises:
        ValueError: If ``v`` is not one of the accepted roles.
    """
    roles = ["system", "assistant", "user", "tool", "approval"]
    # Raise explicitly rather than ``assert``: assert statements are removed
    # when Python runs with -O, which would silently disable this check.
    if v not in roles:
        raise ValueError(f"Role must be one of {roles}")
    return v
def create_approval_request_message_from_llm_response(
    agent_id: str,
    model: str,
    function_name: str,
    function_arguments: Dict,
    tool_call_id: str,
    actor: User,
    continue_stepping: bool = False,
    reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
    pre_computed_assistant_message_id: Optional[str] = None,
    step_id: str | None = None,
) -> Message:
    """Build the approval-role message that carries a pending tool call.

    Args:
        agent_id: Id stored on the resulting message.
        model: Model name stored on the resulting message.
        function_name: Name of the tool the LLM asked to call.
        function_arguments: Arguments of the tool call. Not modified; a copy
            is serialised into the tool call.
        tool_call_id: Id used both for the tool call and for the message's
            ``tool_call_id`` field.
        actor: Currently unused by this function; kept for signature parity
            with the callers.
        continue_stepping: Value written under ``REQUEST_HEARTBEAT_PARAM`` in
            the serialised arguments.
        reasoning_content: Content parts stored as the message content; an
            empty list is used when this is ``None`` or empty.
        pre_computed_assistant_message_id: When truthy, overrides the id of
            the created message.
        step_id: Step id stored on the resulting message.

    Returns:
        A ``Message`` with role ``MessageRole.approval`` holding exactly one
        tool call.
    """
    # Force the heartbeat flag in the serialised arguments to the computed
    # continue_stepping value. Work on a shallow copy so the caller's dict is
    # left untouched; key order (and therefore the JSON) is the same as an
    # in-place assignment would produce.
    arguments = dict(function_arguments)
    arguments[REQUEST_HEARTBEAT_PARAM] = continue_stepping

    tool_call = OpenAIToolCall(
        id=tool_call_id,
        function=OpenAIFunction(
            name=function_name,
            arguments=json.dumps(arguments),
        ),
        type="function",
    )

    # TODO: Use ToolCallContent instead of tool_calls
    # TODO: This helps preserve ordering
    approval_message = Message(
        role=MessageRole.approval,
        content=reasoning_content or [],
        agent_id=agent_id,
        model=model,
        tool_calls=[tool_call],
        tool_call_id=tool_call_id,
        created_at=get_utc_time(),
        step_id=step_id,
    )
    if pre_computed_assistant_message_id:
        # Reuse the id that was reserved before the message existed.
        approval_message.id = pre_computed_assistant_message_id
    return approval_message
@pytest.fixture(scope="function")
def agent(client: Letta, approval_tool_fixture) -> AgentState:
    """
    Creates an agent for testing, yields it, and deletes it afterwards.
    The agent is configured with ``send_message`` and the requires_approval_tool.
    """
    send_message_tool = client.tools.list(name="send_message")[0]
    agent_state = client.agents.create(
        name="approval_test_agent",
        include_base_tools=False,
        tool_ids=[send_message_tool.id, approval_tool_fixture.id],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
        tags=["approval_test"],
    )
    yield agent_state

    # Teardown: the fixture is function-scoped, so without this every test
    # leaves another "approval_test_agent" behind on the server.
    client.agents.delete(agent_id=agent_state.id)