feat: separate out hitl cases (#5531)

This commit is contained in:
cthomas
2025-10-17 11:21:03 -07:00
committed by Caren Thomas
parent a03263aca2
commit 62d5ae1828

View File

@@ -46,6 +46,10 @@ from letta.system import package_function_response
from letta.utils import log_telemetry, validate_function_response
class ToolCallDenial(ToolCall):
reason: Optional[str] = None
class LettaAgentV3(LettaAgentV2):
"""
Similar to V2, but stripped down / simplified, while also generalized:
@@ -304,9 +308,29 @@ class LettaAgentV3(LettaAgentV2):
self._require_tool_call = require_tool_call
approval_request, approval_response = _maybe_get_approval_messages(messages)
tool_call_denials, tool_returns = [], []
if approval_request and approval_response:
tool_calls = approval_request.tool_calls
content = approval_request.content
# Get tool calls that are pending
backfill_tool_call_id = approval_request.tool_calls[0].id # legacy case
pending_tool_calls = {
backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id: a
for a in approval_response.approvals
if a.type == "approval"
}
tool_calls = [t for t in approval_request.tool_calls if t.id in pending_tool_calls]
# Get tool calls that were denied
denies = {d.tool_call_id: d for d in approval_response.approvals if d.type == "approval" and not d.approve}
tool_call_denials = [
ToolCallDenial(**t.model_dump(), reason=denies.get(t.id).reason) for t in approval_request.tool_calls if t.id in denies
]
# Get tool calls that were executed client side
if approval_response.approvals:
tool_returns = [r for r in approval_response.approvals if r.type == "tool"]
step_id = approval_request.step_id
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
else:
@@ -412,12 +436,6 @@ class LettaAgentV3(LettaAgentV2):
tool_calls = [llm_adapter.tool_call]
aggregated_persisted: list[Message] = []
tool_return_payload = (
approval_response.approvals[0]
if approval_response and approval_response.approvals and isinstance(approval_response.approvals[0], ToolReturn)
else None
)
persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
tool_calls=tool_calls,
valid_tool_names=[tool["name"] for tool in valid_tools],
@@ -436,10 +454,10 @@ class LettaAgentV3(LettaAgentV2):
is_final_step=(remaining_turns == 0),
run_id=run_id,
step_metrics=step_metrics,
is_approval=approval_response.approve if approval_response is not None else False,
is_denial=(approval_response.approve == False) if approval_response is not None else False,
denial_reason=approval_response.denial_reason if approval_response is not None else None,
tool_return=tool_return_payload,
is_approval=approval_response is not None,
is_denial=(len(tool_call_denials) > 0),
denial_reason=tool_call_denials[0].reason if tool_call_denials else None,
tool_return=tool_returns[0] if tool_returns else None,
)
aggregated_persisted.extend(persisted_messages)
# NOTE: there is an edge case where persisted_messages is empty (the LLM did a "no-op")
@@ -449,8 +467,8 @@ class LettaAgentV3(LettaAgentV2):
if llm_adapter.supports_token_streaming():
# Stream each tool return if tools were executed
tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"]
for tr in tool_returns:
response_tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"]
for tr in response_tool_returns:
# Skip streaming for aggregated parallel tool returns (no per-call tool_call_id)
if tr.tool_call_id is None and tr.tool_returns:
continue