From 62d5ae1828b489e44e058330103f90d288fc9468 Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 17 Oct 2025 11:21:03 -0700 Subject: [PATCH] feat: separate out hitl cases (#5531) --- letta/agents/letta_agent_v3.py | 44 ++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index bb2b8dfc..28516ef6 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -46,6 +46,10 @@ from letta.system import package_function_response from letta.utils import log_telemetry, validate_function_response +class ToolCallDenial(ToolCall): + reason: Optional[str] = None + + class LettaAgentV3(LettaAgentV2): """ Similar to V2, but stripped down / simplified, while also generalized: @@ -304,9 +308,29 @@ class LettaAgentV3(LettaAgentV2): self._require_tool_call = require_tool_call approval_request, approval_response = _maybe_get_approval_messages(messages) + tool_call_denials, tool_returns = [], [] if approval_request and approval_response: - tool_calls = approval_request.tool_calls content = approval_request.content + + # Get tool calls that are pending + backfill_tool_call_id = approval_request.tool_calls[0].id # legacy case + pending_tool_calls = { + backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id: a + for a in approval_response.approvals + if a.type == "approval" + } + tool_calls = [t for t in approval_request.tool_calls if t.id in pending_tool_calls] + + # Get tool calls that were denied + denies = {d.tool_call_id: d for d in approval_response.approvals if d.type == "approval" and not d.approve} + tool_call_denials = [ + ToolCallDenial(**t.model_dump(), reason=denies.get(t.id).reason) for t in approval_request.tool_calls if t.id in denies + ] + + # Get tool calls that were executed client side + if approval_response.approvals: + tool_returns = [r for r in approval_response.approvals if r.type == "tool"] + step_id = approval_request.step_id step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor) else: @@ -412,12 +436,6 @@ class LettaAgentV3(LettaAgentV2): tool_calls = [llm_adapter.tool_call] aggregated_persisted: list[Message] = [] - tool_return_payload = ( - approval_response.approvals[0] - if approval_response and approval_response.approvals and isinstance(approval_response.approvals[0], ToolReturn) - else None - ) - persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response( tool_calls=tool_calls, valid_tool_names=[tool["name"] for tool in valid_tools], @@ -436,10 +454,10 @@ class LettaAgentV3(LettaAgentV2): is_final_step=(remaining_turns == 0), run_id=run_id, step_metrics=step_metrics, - is_approval=approval_response.approve if approval_response is not None else False, - is_denial=(approval_response.approve == False) if approval_response is not None else False, - denial_reason=approval_response.denial_reason if approval_response is not None else None, - tool_return=tool_return_payload, + is_approval=approval_response is not None, + is_denial=(len(tool_call_denials) > 0), + denial_reason=tool_call_denials[0].reason if tool_call_denials else None, + tool_return=tool_returns[0] if tool_returns else None, ) aggregated_persisted.extend(persisted_messages) # NOTE: there is an edge case where persisted_messages is empty (the LLM did a "no-op") @@ -449,8 +467,8 @@ class LettaAgentV3(LettaAgentV2): if llm_adapter.supports_token_streaming(): # Stream each tool return if tools were executed - tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"] - for tr in tool_returns: + response_tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"] + for tr in response_tool_returns: # Skip streaming for aggregated parallel tool returns (no per-call tool_call_id) if tr.tool_call_id is None and tr.tool_returns: continue