feat: Override heartbeat request when system forces step exit (#3015)

This commit is contained in:
Matthew Zhou
2025-06-24 19:50:00 -07:00
committed by GitHub
parent aa02da3bb3
commit a31826d7a5
7 changed files with 184 additions and 111 deletions

View File

@@ -1,12 +1,15 @@
import json
import uuid
import xml.etree.ElementTree as ET
from typing import List, Optional, Tuple
from letta.helpers import ToolRulesSolver
from letta.schemas.agent import AgentState
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message, MessageCreate
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_input_messages
@@ -205,3 +208,28 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]:
def generate_step_id():
    """Mint a unique identifier for an agent step (format: ``step-<uuid4>``)."""
    return "step-" + str(uuid.uuid4())
def _safe_load_dict(raw: str) -> dict:
"""Lenient JSON → dict with fallback to eval on assertion failure."""
if "}{" in raw: # strip accidental parallel calls
raw = raw.split("}{", 1)[0] + "}"
try:
data = json.loads(raw)
if not isinstance(data, dict):
raise AssertionError
return data
except (json.JSONDecodeError, AssertionError):
return json.loads(raw) if raw else {}
def _pop_heartbeat(tool_args: dict) -> bool:
hb = tool_args.pop("request_heartbeat", False)
return str(hb).lower() == "true" if isinstance(hb, str) else bool(hb)
def _build_rule_violation_result(tool_name: str, valid: list[str], solver: ToolRulesSolver) -> ToolExecutionResult:
    """Build an error ``ToolExecutionResult`` for a tool call that violates the active tool rules.

    Appends a hint section listing the rules the solver guesses were violated,
    when any are available.
    """
    message = f"[ToolConstraintError] Cannot call {tool_name}, valid tools include: {valid}."
    violations = solver.guess_rule_violation(tool_name)
    if violations:
        bullets = "\n".join(f"\t- {v}" for v in violations)
        message += f"\n** Hint: Possible rules that were violated:\n{bullets}"
    return ToolExecutionResult(status="error", func_return=message)

View File

@@ -10,7 +10,14 @@ from opentelemetry.trace import Span
from letta.agents.base_agent import BaseAgent
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_no_persist_async, generate_step_id
from letta.agents.helpers import (
_build_rule_violation_result,
_create_letta_response,
_pop_heartbeat,
_prepare_in_context_messages_no_persist_async,
_safe_load_dict,
generate_step_id,
)
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
@@ -931,45 +938,15 @@ class LettaAgent(BaseAgent):
run_id: Optional[str] = None,
) -> Tuple[List[Message], bool, Optional[LettaStopReason]]:
"""
Now that streaming is done, handle the final AI response.
This might yield additional SSE tokens if we do stalling.
At the end, set self._continue_execution accordingly.
Handle the final AI response once streaming completes, execute / validate the
tool call, decide whether we should keep stepping, and persist state.
"""
stop_reason = None
# Check if the called tool is allowed by tool name:
tool_call_name = tool_call.function.name
tool_call_args_str = tool_call.function.arguments
# Temp hack to gracefully handle parallel tool calling attempt, only take first one
if "}{" in tool_call_args_str:
tool_call_args_str = tool_call_args_str.split("}{", 1)[0] + "}"
try:
tool_args = json.loads(tool_call_args_str)
assert isinstance(tool_args, dict), "tool_args must be a dict"
except json.JSONDecodeError:
tool_args = {}
except AssertionError:
tool_args = json.loads(tool_args)
# Get request heartbeats and coerce to bool
request_heartbeat = tool_args.pop("request_heartbeat", False)
if is_final_step:
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
self.logger.info("Agent has reached max steps.")
request_heartbeat = False
else:
# Pre-emptively pop out inner_thoughts
tool_args.pop(INNER_THOUGHTS_KWARG, "")
# So this is necessary, because sometimes non-structured outputs makes mistakes
if not isinstance(request_heartbeat, bool):
if isinstance(request_heartbeat, str):
request_heartbeat = request_heartbeat.lower() == "true"
else:
request_heartbeat = bool(request_heartbeat)
tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
# 1. Parse and validate the tool-call envelope
tool_call_name: str = tool_call.function.name
tool_call_id: str = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
tool_args = _safe_load_dict(tool_call.function.arguments)
request_heartbeat: bool = _pop_heartbeat(tool_args)
tool_args.pop(INNER_THOUGHTS_KWARG, None)
log_telemetry(
self.logger,
@@ -979,16 +956,11 @@ class LettaAgent(BaseAgent):
tool_call_id=tool_call_id,
request_heartbeat=request_heartbeat,
)
# Check if tool rule is violated - if so, we'll force continuation
tool_rule_violated = tool_call_name not in valid_tool_names
# 2. Execute the tool (or synthesize an error result if disallowed)
tool_rule_violated = tool_call_name not in valid_tool_names
if tool_rule_violated:
base_error_message = f"[ToolConstraintError] Cannot call {tool_call_name}, valid tools to call include: {valid_tool_names}."
violated_rule_messages = tool_rules_solver.guess_rule_violation(tool_call_name)
if violated_rule_messages:
bullet_points = "\n".join(f"\t- {msg}" for msg in violated_rule_messages)
base_error_message += f"\n** Hint: Possible rules that were violated:\n{bullet_points}"
tool_execution_result = ToolExecutionResult(status="error", func_return=base_error_message)
tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
else:
tool_execution_result = await self._execute_tool(
tool_name=tool_call_name,
@@ -997,66 +969,38 @@ class LettaAgent(BaseAgent):
agent_step_span=agent_step_span,
step_id=step_id,
)
log_telemetry(
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
)
if tool_call_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
# with certain functions we rely on the paging mechanism to handle overflow
truncate = False
else:
# but by default, we add a truncation safeguard to prevent bad functions from
# overflow the agent context window
truncate = True
# get the function response limit
target_tool = next((x for x in agent_state.tools if x.name == tool_call_name), None)
return_char_limit = target_tool.return_char_limit if target_tool else None
function_response_string = validate_function_response(
tool_execution_result.func_return, return_char_limit=return_char_limit, truncate=truncate
# 3. Prepare the function-response payload
truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
return_char_limit = next(
(t.return_char_limit for t in agent_state.tools if t.name == tool_call_name),
None,
)
function_response = package_function_response(
function_response_string = validate_function_response(
tool_execution_result.func_return,
return_char_limit=return_char_limit,
truncate=truncate,
)
self.last_function_response = package_function_response(
was_success=tool_execution_result.success_flag,
response_string=function_response_string,
timezone=agent_state.timezone,
)
# 4. Register tool call with tool rule solver
# Resolve whether or not to continue stepping
continue_stepping = request_heartbeat
# 4. Decide whether to keep stepping (<<< focal section simplified)
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
request_heartbeat=request_heartbeat,
tool_call_name=tool_call_name,
tool_rule_violated=tool_rule_violated,
tool_rules_solver=tool_rules_solver,
is_final_step=is_final_step,
)
# Force continuation if tool rule was violated to give the model another chance
if tool_rule_violated:
continue_stepping = True
else:
tool_rules_solver.register_tool_call(tool_name=tool_call_name)
if tool_rules_solver.is_terminal_tool(tool_name=tool_call_name):
if continue_stepping:
stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
continue_stepping = False
elif tool_rules_solver.has_children_tools(tool_name=tool_call_name):
continue_stepping = True
elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
continue_stepping = True
# Check if required-before-exit tools have been called before allowing exit
heartbeat_reason = None # Default
uncalled_required_tools = tool_rules_solver.get_uncalled_required_tools()
if not continue_stepping and uncalled_required_tools:
continue_stepping = True
heartbeat_reason = (
f"{NON_USER_MSG_PREFIX}Cannot finish, still need to call the following required tools: {', '.join(uncalled_required_tools)}"
)
# TODO: @caren is this right?
# reset stop reason since we ain't stopping!
stop_reason = None
self.logger.info(f"RequiredBeforeExitToolRule: Forcing agent continuation. Missing required tools: {uncalled_required_tools}")
# 5a. Persist Steps to DB
# Following agent loop to persist this before messages
# TODO (cliandy): determine what should match old loop w/provider_id
# TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
# 5. Persist step + messages and propagate to jobs
logged_step = await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
@@ -1071,7 +1015,6 @@ class LettaAgent(BaseAgent):
step_id=step_id,
)
# 5b. Persist Messages to DB
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
@@ -1083,27 +1026,72 @@ class LettaAgent(BaseAgent):
function_response=function_response_string,
timezone=agent_state.timezone,
actor=self.actor,
add_heartbeat_request_system_message=continue_stepping,
continue_stepping=continue_stepping,
heartbeat_reason=heartbeat_reason,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops
step_id=logged_step.id if logged_step else None,
)
persisted_messages = await self.message_manager.create_many_messages_async(
(initial_messages or []) + tool_call_messages, actor=self.actor
)
self.last_function_response = function_response
if run_id:
await self.job_manager.add_messages_to_job_async(
job_id=run_id,
message_ids=[message.id for message in persisted_messages if message.role != "user"],
message_ids=[m.id for m in persisted_messages if m.role != "user"],
actor=self.actor,
)
return persisted_messages, continue_stepping, stop_reason
def _decide_continuation(
    self,
    request_heartbeat: bool,
    tool_call_name: str,
    tool_rule_violated: bool,
    tool_rules_solver: ToolRulesSolver,
    is_final_step: bool | None,
) -> tuple[bool, str | None, LettaStopReason | None]:
    """Decide whether the agent loop should take another step after a tool call.

    Args:
        request_heartbeat: The heartbeat flag the model put in its tool args.
        tool_call_name: Name of the tool that was just called.
        tool_rule_violated: True if the call was disallowed by the tool rules.
        tool_rules_solver: Solver tracking tool-rule state for this run.
        is_final_step: True when the step budget is exhausted (hard stop).

    Returns:
        Tuple of (continue_stepping, heartbeat_reason, stop_reason):
        ``continue_stepping`` — whether to keep stepping; ``heartbeat_reason``
        — system-message text explaining a rule-forced continuation (None when
        continuation was not forced); ``stop_reason`` — populated when a rule
        or the step limit ends the loop, else None.
    """
    # Start from the model's own request; rules below may override it.
    continue_stepping = request_heartbeat
    heartbeat_reason: str | None = None
    stop_reason: LettaStopReason | None = None
    if tool_rule_violated:
        # Force another step so the model gets a chance to retry with a valid tool.
        continue_stepping = True
        heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation."
    else:
        tool_rules_solver.register_tool_call(tool_call_name)
        if tool_rules_solver.is_terminal_tool(tool_call_name):
            # A terminal tool always ends the loop; record a tool_rule stop
            # reason only if the model had asked to continue (it is overridden).
            if continue_stepping:
                stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
            continue_stepping = False
        elif tool_rules_solver.has_children_tools(tool_call_name):
            continue_stepping = True
            heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: child tool rule."
        elif tool_rules_solver.is_continue_tool(tool_call_name):
            continue_stepping = True
            heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: continue tool rule."
    # Hard-stop overrides: the step budget always wins; otherwise
    # required-before-exit tools can veto an exit.
    if is_final_step:
        continue_stepping = False
        stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
    else:
        uncalled = tool_rules_solver.get_uncalled_required_tools()
        if not continue_stepping and uncalled:
            continue_stepping = True
            heartbeat_reason = f"{NON_USER_MSG_PREFIX}Missing required tools: " f"{', '.join(uncalled)}"
            stop_reason = None  # reset: we're still going
    return continue_stepping, heartbeat_reason, stop_reason
@trace_method
async def _execute_tool(
self,

View File

@@ -550,7 +550,7 @@ class LettaAgentBatch(BaseAgent):
tool_execution_result=tool_exec_result_obj,
timezone=agent_state.timezone,
actor=self.actor,
add_heartbeat_request_system_message=False,
continue_stepping=False,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=None,
llm_batch_item_id=llm_batch_item_id,

View File

@@ -277,7 +277,7 @@ class VoiceAgent(BaseAgent):
tool_execution_result=tool_execution_result,
timezone=agent_state.timezone,
actor=self.actor,
add_heartbeat_request_system_message=True,
continue_stepping=True,
)
letta_message_db_queue.extend(tool_call_messages)

View File

@@ -52,7 +52,7 @@ class ChildToolRule(BaseToolRule):
type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools
children: List[str] = Field(..., description="The children tools that can be invoked.")
prompt_template: Optional[str] = Field(
default="<tool_rule>\nAfter using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}\n</tool_rule>",
default="<tool_rule>\nAfter using {{ tool_name }}, you must use one of these tools: {{ children | join(', ') }}\n</tool_rule>",
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
)
@@ -61,7 +61,7 @@ class ChildToolRule(BaseToolRule):
return set(self.children) if last_tool == self.tool_name else available_tools
def _get_default_template(self) -> Optional[str]:
return "<tool_rule>\nAfter using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}\n</tool_rule>"
return "<tool_rule>\nAfter using {{ tool_name }}, you must use one of these tools: {{ children | join(', ') }}\n</tool_rule>"
class ParentToolRule(BaseToolRule):

View File

@@ -13,7 +13,13 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open
from openai.types.chat.completion_create_params import CompletionCreateParams
from pydantic import BaseModel
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
from letta.constants import (
DEFAULT_MESSAGE_TOOL,
DEFAULT_MESSAGE_TOOL_KWARG,
FUNC_FAILED_HEARTBEAT_MESSAGE,
REQ_HEARTBEAT_MESSAGE,
REQUEST_HEARTBEAT_PARAM,
)
from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms
from letta.helpers.message_helper import convert_message_creates_to_messages
@@ -194,7 +200,7 @@ def create_letta_messages_from_llm_response(
function_response: Optional[str],
timezone: str,
actor: User,
add_heartbeat_request_system_message: bool = False,
continue_stepping: bool = False,
heartbeat_reason: Optional[str] = None,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
@@ -202,9 +208,9 @@ def create_letta_messages_from_llm_response(
step_id: str | None = None,
) -> List[Message]:
messages = []
# Construct the tool call with the assistant's message
function_arguments["request_heartbeat"] = True
# Force set request_heartbeat in tool_args to calculated continue_stepping
function_arguments[REQUEST_HEARTBEAT_PARAM] = continue_stepping
tool_call = OpenAIToolCall(
id=tool_call_id,
function=OpenAIFunction(
@@ -254,7 +260,7 @@ def create_letta_messages_from_llm_response(
)
messages.append(tool_message)
if add_heartbeat_request_system_message:
if continue_stepping:
heartbeat_system_message = create_heartbeat_system_message(
agent_id=agent_id,
model=model,
@@ -323,7 +329,7 @@ def create_assistant_messages_from_openai_response(
function_response=None,
timezone=timezone,
actor=actor,
add_heartbeat_request_system_message=False,
continue_stepping=False,
)

View File

@@ -1,4 +1,5 @@
import asyncio
import json
import uuid
import pytest
@@ -996,3 +997,53 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_
assert len(send_message_calls) == 1, "Should call send_message exactly once"
print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_terminal_tool_rule_send_message_request_heartbeat_false(server, disable_e2b_api_key, default_user):
    """Test that when there's a terminal tool rule on send_message, the tool call has request_heartbeat=False."""
    agent_name = "terminal_send_message_heartbeat_test"
    config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"

    # send_message is terminal: calling it must end the step loop.
    agent_state = setup_agent(
        server,
        config_file,
        agent_uuid=agent_name,
        tool_ids=[],
        tool_rules=[TerminalToolRule(tool_name="send_message")],
    )

    # Send a message that should trigger a send_message tool call.
    response = await run_agent_step(
        server=server,
        agent_id=agent_state.id,
        input_messages=[MessageCreate(role="user", content="Please send me a simple message.")],
        actor=default_user,
    )

    assert_sanity_checks(response)
    assert_invoked_function_call(response.messages, "send_message")

    # Locate the send_message tool call among the response messages.
    send_message_call = next(
        (m for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == "send_message"),
        None,
    )
    assert send_message_call is not None, "send_message tool call should be found"

    # Parse the arguments and verify request_heartbeat was forced off.
    try:
        arguments = json.loads(send_message_call.tool_call.arguments)
    except json.JSONDecodeError:
        pytest.fail("Failed to parse tool call arguments as JSON")
    assert "request_heartbeat" in arguments, "request_heartbeat should be present in send_message arguments"
    assert arguments["request_heartbeat"] is False, "request_heartbeat should be False for terminal tool rule"

    print(f"✓ Agent '{agent_name}' correctly set request_heartbeat=False for terminal send_message")
    cleanup(server=server, agent_uuid=agent_name, actor=default_user)