feat: separate out hitl cases (#5531)
This commit is contained in:
@@ -46,6 +46,10 @@ from letta.system import package_function_response
|
||||
from letta.utils import log_telemetry, validate_function_response
|
||||
|
||||
|
||||
class ToolCallDenial(ToolCall):
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class LettaAgentV3(LettaAgentV2):
|
||||
"""
|
||||
Similar to V2, but stripped down / simplified, while also generalized:
|
||||
@@ -304,9 +308,29 @@ class LettaAgentV3(LettaAgentV2):
|
||||
self._require_tool_call = require_tool_call
|
||||
|
||||
approval_request, approval_response = _maybe_get_approval_messages(messages)
|
||||
tool_call_denials, tool_returns = [], []
|
||||
if approval_request and approval_response:
|
||||
tool_calls = approval_request.tool_calls
|
||||
content = approval_request.content
|
||||
|
||||
# Get tool calls that are pending
|
||||
backfill_tool_call_id = approval_request.tool_calls[0].id # legacy case
|
||||
pending_tool_calls = {
|
||||
backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id: a
|
||||
for a in approval_response.approvals
|
||||
if a.type == "approval"
|
||||
}
|
||||
tool_calls = [t for t in approval_request.tool_calls if t.id in pending_tool_calls]
|
||||
|
||||
# Get tool calls that were denied
|
||||
denies = {d.tool_call_id: d for d in approval_response.approvals if d.type == "approval" and not d.approve}
|
||||
tool_call_denials = [
|
||||
ToolCallDenial(**t.model_dump(), reason=denies.get(t.id).reason) for t in approval_request.tool_calls if t.id in denies
|
||||
]
|
||||
|
||||
# Get tool calls that were executed client side
|
||||
if approval_response.approvals:
|
||||
tool_returns = [r for r in approval_response.approvals if r.type == "tool"]
|
||||
|
||||
step_id = approval_request.step_id
|
||||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
|
||||
else:
|
||||
@@ -412,12 +436,6 @@ class LettaAgentV3(LettaAgentV2):
|
||||
tool_calls = [llm_adapter.tool_call]
|
||||
|
||||
aggregated_persisted: list[Message] = []
|
||||
tool_return_payload = (
|
||||
approval_response.approvals[0]
|
||||
if approval_response and approval_response.approvals and isinstance(approval_response.approvals[0], ToolReturn)
|
||||
else None
|
||||
)
|
||||
|
||||
persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
|
||||
tool_calls=tool_calls,
|
||||
valid_tool_names=[tool["name"] for tool in valid_tools],
|
||||
@@ -436,10 +454,10 @@ class LettaAgentV3(LettaAgentV2):
|
||||
is_final_step=(remaining_turns == 0),
|
||||
run_id=run_id,
|
||||
step_metrics=step_metrics,
|
||||
is_approval=approval_response.approve if approval_response is not None else False,
|
||||
is_denial=(approval_response.approve == False) if approval_response is not None else False,
|
||||
denial_reason=approval_response.denial_reason if approval_response is not None else None,
|
||||
tool_return=tool_return_payload,
|
||||
is_approval=approval_response is not None,
|
||||
is_denial=(len(tool_call_denials) > 0),
|
||||
denial_reason=tool_call_denials[0].reason if tool_call_denials else None,
|
||||
tool_return=tool_returns[0] if tool_returns else None,
|
||||
)
|
||||
aggregated_persisted.extend(persisted_messages)
|
||||
# NOTE: there is an edge case where persisted_messages is empty (the LLM did a "no-op")
|
||||
@@ -449,8 +467,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
if llm_adapter.supports_token_streaming():
|
||||
# Stream each tool return if tools were executed
|
||||
tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"]
|
||||
for tr in tool_returns:
|
||||
response_tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"]
|
||||
for tr in response_tool_returns:
|
||||
# Skip streaming for aggregated parallel tool returns (no per-call tool_call_id)
|
||||
if tr.tool_call_id is None and tr.tool_returns:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user