diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 228a2f7c..4bee75ac 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -168,15 +168,19 @@ async def _prepare_in_context_messages_no_persist_async( # Check for approval-related message validation if len(input_messages) == 1 and input_messages[0].type == "approval": # User is trying to send an approval response + approval_message = input_messages[0].approvals[0] if current_in_context_messages[-1].role != "approval": raise ValueError( "Cannot process approval response: No tool call is currently awaiting approval. " "Please send a regular message to interact with the agent." ) - if input_messages[0].approval_request_id != current_in_context_messages[-1].id: + if ( + approval_message.tool_call_id != current_in_context_messages[-1].id + and approval_message.tool_call_id != current_in_context_messages[-1].tool_calls[0].id + ): raise ValueError( - f"Invalid approval request ID. Expected '{current_in_context_messages[-1].id}' " - f"but received '{input_messages[0].approval_request_id}'." + f"Invalid tool call ID. Expected '{current_in_context_messages[-1].tool_calls[0].id}' " + f"but received '{approval_message.tool_call_id}'." ) new_in_context_messages = create_approval_response_message_from_input(agent_state=agent_state, input_message=input_messages[0]) else: diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 1df13a1c..ba590731 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -19,16 +19,17 @@ from letta.agents.letta_agent_v2 import LettaAgentV2 from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM from letta.errors import ContextWindowExceededError, LLMError from letta.helpers import ToolRulesSolver -from letta.helpers.datetime_helpers import get_utc_timestamp_ns +from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns from letta.helpers.tool_execution_helper import enable_strict_mode from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState +from letta.schemas.enums import MessageRole from letta.schemas.letta_message import LettaMessage, MessageType from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType -from letta.schemas.message import Message, MessageCreate +from letta.schemas.message import Message, MessageCreate, ToolReturn from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics from letta.schemas.step import StepProgression from letta.schemas.step_metrics import StepMetrics @@ -397,6 +398,9 @@ class LettaAgentV3(LettaAgentV2): is_approval=approval_response.approve if approval_response is not None else False, is_denial=(approval_response.approve == False) if approval_response is not None else False, denial_reason=approval_response.denial_reason if approval_response is not None else None, + tool_return=approval_response.approvals[0] + if approval_response and approval_response.approvals and isinstance(approval_response.approvals[0], ToolReturn) + else None, ) # NOTE: there is an edge case where persisted_messages is empty (the LLM did a "no-op") @@ -542,6 +546,7 @@ class LettaAgentV3(LettaAgentV2): is_approval: bool | None = None, is_denial: bool | None = None, denial_reason: str | None = None, + tool_return: ToolReturn | None = None, ) -> tuple[list[Message], bool, LettaStopReason | None]: """ Handle the final AI response once streaming completes, execute / validate the @@ -553,29 +558,45 @@ class LettaAgentV3(LettaAgentV2): else: tool_call_id: str = tool_call.id or f"call_{uuid.uuid4().hex[:8]}" - if is_denial: + if is_denial or tool_return is not None: continue_stepping = True stop_reason = None - tool_call_messages = create_letta_messages_from_llm_response( - agent_id=agent_state.id, - model=agent_state.llm_config.model, - function_name=tool_call.function.name, - function_arguments={}, - tool_execution_result=ToolExecutionResult(status="error"), - tool_call_id=tool_call_id, - function_response=f"Error: request to call tool denied. User reason: {denial_reason}", - timezone=agent_state.timezone, - continue_stepping=continue_stepping, - # NOTE: we may need to change this to not have a "heartbeat" prefix for v3? - heartbeat_reason=f"{NON_USER_MSG_PREFIX}Continuing: user denied request to call tool.", - reasoning_content=None, - pre_computed_assistant_message_id=None, - step_id=step_id, - run_id=run_id, - is_approval_response=True, - force_set_request_heartbeat=False, - add_heartbeat_on_continue=False, - ) + if tool_return is not None: + tool_call_messages = [ + Message( + role=MessageRole.tool, + content=[TextContent(text=tool_return.func_response)], + agent_id=agent_state.id, + model=agent_state.llm_config.model, + tool_calls=[], + tool_call_id=tool_return.tool_call_id, + created_at=get_utc_time(), + tool_returns=[tool_return], + run_id=run_id, + step_id=step_id, + ) + ] + else: + tool_call_messages = create_letta_messages_from_llm_response( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + function_name=tool_call.function.name, + function_arguments={}, + tool_execution_result=ToolExecutionResult(status="error"), + tool_call_id=tool_call_id, + function_response=f"Error: request to call tool denied. User reason: {denial_reason}", + timezone=agent_state.timezone, + continue_stepping=continue_stepping, + # NOTE: we may need to change this to not have a "heartbeat" prefix for v3? + heartbeat_reason=f"{NON_USER_MSG_PREFIX}Continuing: user denied request to call tool.", + reasoning_content=None, + pre_computed_assistant_message_id=None, + step_id=step_id, + run_id=run_id, + is_approval_response=True, + force_set_request_heartbeat=False, + add_heartbeat_on_continue=False, + ) messages_to_persist = (initial_messages or []) + tool_call_messages # Set run_id on all messages before persisting diff --git a/letta/schemas/message.py b/letta/schemas/message.py index f30bf1da..4bf22260 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -33,6 +33,7 @@ from letta.schemas.letta_message import ( SystemMessage, ToolCall, ToolCallMessage, + ToolReturn as LettaToolReturn, ToolReturnMessage, UserMessage, ) @@ -344,17 +345,29 @@ class Message(BaseMessage): messages.append(approval_request_message) else: if self.approvals: - first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)][0] + first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)] + + def maybe_convert_tool_return_message(maybe_tool_return: ToolReturn): + if isinstance(maybe_tool_return, LettaToolReturn): + return LettaToolReturn( + tool_call_id=maybe_tool_return.tool_call_id, + status=maybe_tool_return.status, + tool_return=maybe_tool_return.tool_return, + stdout=maybe_tool_return.stdout, + stderr=maybe_tool_return.stderr, + ) + return maybe_tool_return + approval_response_message = ApprovalResponseMessage( id=self.id, date=self.created_at, otid=self.otid, - approvals=self.approvals, + approvals=[maybe_convert_tool_return_message(approval) for approval in self.approvals], run_id=self.run_id, # TODO: temporary populate these fields for backwards compatibility - approve=first_approval.approve, - approval_request_id=first_approval.tool_call_id, - reason=first_approval.reason, + approve=first_approval[0].approve if first_approval else None, + approval_request_id=first_approval[0].tool_call_id if first_approval else None, + reason=first_approval[0].reason if first_approval else None, ) else: approval_response_message = ApprovalResponseMessage( diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index bb3b9028..1b7e5dcd 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -27,6 +27,7 @@ from letta.otel.metric_registry import MetricRegistry from letta.otel.tracing import tracer from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import ToolReturn as LettaToolReturn from letta.schemas.letta_message_content import ( OmittedReasoningContent, ReasoningContent, @@ -169,6 +170,20 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti def create_approval_response_message_from_input(agent_state: AgentState, input_message: ApprovalCreate) -> List[Message]: + def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn): + if isinstance(maybe_tool_return, LettaToolReturn): + packaged_function_response = package_function_response( + maybe_tool_return.status == "success", maybe_tool_return.tool_return, agent_state.timezone + ) + return ToolReturn( + tool_call_id=maybe_tool_return.tool_call_id, + status=maybe_tool_return.status, + func_response=packaged_function_response, + stdout=maybe_tool_return.stdout, + stderr=maybe_tool_return.stderr, + ) + return maybe_tool_return + return [ Message( role=MessageRole.approval, @@ -177,6 +192,7 @@ def create_approval_response_message_from_input(agent_state: AgentState, input_m approval_request_id=input_message.approval_request_id, approve=input_message.approve, denial_reason=input_message.reason, + approvals=[maybe_convert_tool_return_message(approval) for approval in input_message.approvals], ) ] diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 701e6170..91fe7889 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -184,7 +184,7 @@ def test_send_approval_message_with_incorrect_request_id(client, agent): messages=USER_MESSAGE_TEST_APPROVAL, ) - with pytest.raises(ApiError, match="Invalid approval request ID"): + with pytest.raises(ApiError, match="Invalid tool call ID"): client.agents.messages.create( agent_id=agent.id, messages=[ApprovalCreate(approve=True, approval_request_id=FAKE_REQUEST_ID)],