feat: Override heartbeat request when system forces step exit (#3015)
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
import json
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import create_input_messages
|
||||
@@ -205,3 +208,28 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]:
|
||||
|
||||
def generate_step_id():
|
||||
return f"step-{uuid.uuid4()}"
|
||||
|
||||
|
||||
def _safe_load_dict(raw: str) -> dict:
|
||||
"""Lenient JSON → dict with fallback to eval on assertion failure."""
|
||||
if "}{" in raw: # strip accidental parallel calls
|
||||
raw = raw.split("}{", 1)[0] + "}"
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
if not isinstance(data, dict):
|
||||
raise AssertionError
|
||||
return data
|
||||
except (json.JSONDecodeError, AssertionError):
|
||||
return json.loads(raw) if raw else {}
|
||||
|
||||
|
||||
def _pop_heartbeat(tool_args: dict) -> bool:
|
||||
hb = tool_args.pop("request_heartbeat", False)
|
||||
return str(hb).lower() == "true" if isinstance(hb, str) else bool(hb)
|
||||
|
||||
|
||||
def _build_rule_violation_result(tool_name: str, valid: list[str], solver: ToolRulesSolver) -> ToolExecutionResult:
|
||||
hint_lines = solver.guess_rule_violation(tool_name)
|
||||
hint_txt = ("\n** Hint: Possible rules that were violated:\n" + "\n".join(f"\t- {h}" for h in hint_lines)) if hint_lines else ""
|
||||
msg = f"[ToolConstraintError] Cannot call {tool_name}, " f"valid tools include: {valid}.{hint_txt}"
|
||||
return ToolExecutionResult(status="error", func_return=msg)
|
||||
|
||||
@@ -10,7 +10,14 @@ from opentelemetry.trace import Span
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
||||
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_no_persist_async, generate_step_id
|
||||
from letta.agents.helpers import (
|
||||
_build_rule_violation_result,
|
||||
_create_letta_response,
|
||||
_pop_heartbeat,
|
||||
_prepare_in_context_messages_no_persist_async,
|
||||
_safe_load_dict,
|
||||
generate_step_id,
|
||||
)
|
||||
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
@@ -931,45 +938,15 @@ class LettaAgent(BaseAgent):
|
||||
run_id: Optional[str] = None,
|
||||
) -> Tuple[List[Message], bool, Optional[LettaStopReason]]:
|
||||
"""
|
||||
Now that streaming is done, handle the final AI response.
|
||||
This might yield additional SSE tokens if we do stalling.
|
||||
At the end, set self._continue_execution accordingly.
|
||||
Handle the final AI response once streaming completes, execute / validate the
|
||||
tool call, decide whether we should keep stepping, and persist state.
|
||||
"""
|
||||
stop_reason = None
|
||||
# Check if the called tool is allowed by tool name:
|
||||
tool_call_name = tool_call.function.name
|
||||
tool_call_args_str = tool_call.function.arguments
|
||||
|
||||
# Temp hack to gracefully handle parallel tool calling attempt, only take first one
|
||||
if "}{" in tool_call_args_str:
|
||||
tool_call_args_str = tool_call_args_str.split("}{", 1)[0] + "}"
|
||||
|
||||
try:
|
||||
tool_args = json.loads(tool_call_args_str)
|
||||
assert isinstance(tool_args, dict), "tool_args must be a dict"
|
||||
except json.JSONDecodeError:
|
||||
tool_args = {}
|
||||
except AssertionError:
|
||||
tool_args = json.loads(tool_args)
|
||||
|
||||
# Get request heartbeats and coerce to bool
|
||||
request_heartbeat = tool_args.pop("request_heartbeat", False)
|
||||
if is_final_step:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
|
||||
self.logger.info("Agent has reached max steps.")
|
||||
request_heartbeat = False
|
||||
else:
|
||||
# Pre-emptively pop out inner_thoughts
|
||||
tool_args.pop(INNER_THOUGHTS_KWARG, "")
|
||||
|
||||
# So this is necessary, because sometimes non-structured outputs makes mistakes
|
||||
if not isinstance(request_heartbeat, bool):
|
||||
if isinstance(request_heartbeat, str):
|
||||
request_heartbeat = request_heartbeat.lower() == "true"
|
||||
else:
|
||||
request_heartbeat = bool(request_heartbeat)
|
||||
|
||||
tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
|
||||
# 1. Parse and validate the tool-call envelope
|
||||
tool_call_name: str = tool_call.function.name
|
||||
tool_call_id: str = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
|
||||
tool_args = _safe_load_dict(tool_call.function.arguments)
|
||||
request_heartbeat: bool = _pop_heartbeat(tool_args)
|
||||
tool_args.pop(INNER_THOUGHTS_KWARG, None)
|
||||
|
||||
log_telemetry(
|
||||
self.logger,
|
||||
@@ -979,16 +956,11 @@ class LettaAgent(BaseAgent):
|
||||
tool_call_id=tool_call_id,
|
||||
request_heartbeat=request_heartbeat,
|
||||
)
|
||||
# Check if tool rule is violated - if so, we'll force continuation
|
||||
tool_rule_violated = tool_call_name not in valid_tool_names
|
||||
|
||||
# 2. Execute the tool (or synthesize an error result if disallowed)
|
||||
tool_rule_violated = tool_call_name not in valid_tool_names
|
||||
if tool_rule_violated:
|
||||
base_error_message = f"[ToolConstraintError] Cannot call {tool_call_name}, valid tools to call include: {valid_tool_names}."
|
||||
violated_rule_messages = tool_rules_solver.guess_rule_violation(tool_call_name)
|
||||
if violated_rule_messages:
|
||||
bullet_points = "\n".join(f"\t- {msg}" for msg in violated_rule_messages)
|
||||
base_error_message += f"\n** Hint: Possible rules that were violated:\n{bullet_points}"
|
||||
tool_execution_result = ToolExecutionResult(status="error", func_return=base_error_message)
|
||||
tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
|
||||
else:
|
||||
tool_execution_result = await self._execute_tool(
|
||||
tool_name=tool_call_name,
|
||||
@@ -997,66 +969,38 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span=agent_step_span,
|
||||
step_id=step_id,
|
||||
)
|
||||
|
||||
log_telemetry(
|
||||
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
|
||||
)
|
||||
|
||||
if tool_call_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
|
||||
# with certain functions we rely on the paging mechanism to handle overflow
|
||||
truncate = False
|
||||
else:
|
||||
# but by default, we add a truncation safeguard to prevent bad functions from
|
||||
# overflow the agent context window
|
||||
truncate = True
|
||||
|
||||
# get the function response limit
|
||||
target_tool = next((x for x in agent_state.tools if x.name == tool_call_name), None)
|
||||
return_char_limit = target_tool.return_char_limit if target_tool else None
|
||||
function_response_string = validate_function_response(
|
||||
tool_execution_result.func_return, return_char_limit=return_char_limit, truncate=truncate
|
||||
# 3. Prepare the function-response payload
|
||||
truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
|
||||
return_char_limit = next(
|
||||
(t.return_char_limit for t in agent_state.tools if t.name == tool_call_name),
|
||||
None,
|
||||
)
|
||||
function_response = package_function_response(
|
||||
function_response_string = validate_function_response(
|
||||
tool_execution_result.func_return,
|
||||
return_char_limit=return_char_limit,
|
||||
truncate=truncate,
|
||||
)
|
||||
self.last_function_response = package_function_response(
|
||||
was_success=tool_execution_result.success_flag,
|
||||
response_string=function_response_string,
|
||||
timezone=agent_state.timezone,
|
||||
)
|
||||
|
||||
# 4. Register tool call with tool rule solver
|
||||
# Resolve whether or not to continue stepping
|
||||
continue_stepping = request_heartbeat
|
||||
# 4. Decide whether to keep stepping (<<< focal section simplified)
|
||||
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
||||
request_heartbeat=request_heartbeat,
|
||||
tool_call_name=tool_call_name,
|
||||
tool_rule_violated=tool_rule_violated,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
is_final_step=is_final_step,
|
||||
)
|
||||
|
||||
# Force continuation if tool rule was violated to give the model another chance
|
||||
if tool_rule_violated:
|
||||
continue_stepping = True
|
||||
else:
|
||||
tool_rules_solver.register_tool_call(tool_name=tool_call_name)
|
||||
if tool_rules_solver.is_terminal_tool(tool_name=tool_call_name):
|
||||
if continue_stepping:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
|
||||
continue_stepping = False
|
||||
elif tool_rules_solver.has_children_tools(tool_name=tool_call_name):
|
||||
continue_stepping = True
|
||||
elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
|
||||
continue_stepping = True
|
||||
|
||||
# Check if required-before-exit tools have been called before allowing exit
|
||||
heartbeat_reason = None # Default
|
||||
uncalled_required_tools = tool_rules_solver.get_uncalled_required_tools()
|
||||
if not continue_stepping and uncalled_required_tools:
|
||||
continue_stepping = True
|
||||
heartbeat_reason = (
|
||||
f"{NON_USER_MSG_PREFIX}Cannot finish, still need to call the following required tools: {', '.join(uncalled_required_tools)}"
|
||||
)
|
||||
|
||||
# TODO: @caren is this right?
|
||||
# reset stop reason since we ain't stopping!
|
||||
stop_reason = None
|
||||
self.logger.info(f"RequiredBeforeExitToolRule: Forcing agent continuation. Missing required tools: {uncalled_required_tools}")
|
||||
|
||||
# 5a. Persist Steps to DB
|
||||
# Following agent loop to persist this before messages
|
||||
# TODO (cliandy): determine what should match old loop w/provider_id
|
||||
# TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
|
||||
# 5. Persist step + messages and propagate to jobs
|
||||
logged_step = await self.step_manager.log_step_async(
|
||||
actor=self.actor,
|
||||
agent_id=agent_state.id,
|
||||
@@ -1071,7 +1015,6 @@ class LettaAgent(BaseAgent):
|
||||
step_id=step_id,
|
||||
)
|
||||
|
||||
# 5b. Persist Messages to DB
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
@@ -1083,27 +1026,72 @@ class LettaAgent(BaseAgent):
|
||||
function_response=function_response_string,
|
||||
timezone=agent_state.timezone,
|
||||
actor=self.actor,
|
||||
add_heartbeat_request_system_message=continue_stepping,
|
||||
continue_stepping=continue_stepping,
|
||||
heartbeat_reason=heartbeat_reason,
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops
|
||||
step_id=logged_step.id if logged_step else None,
|
||||
)
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
(initial_messages or []) + tool_call_messages, actor=self.actor
|
||||
)
|
||||
self.last_function_response = function_response
|
||||
|
||||
if run_id:
|
||||
await self.job_manager.add_messages_to_job_async(
|
||||
job_id=run_id,
|
||||
message_ids=[message.id for message in persisted_messages if message.role != "user"],
|
||||
message_ids=[m.id for m in persisted_messages if m.role != "user"],
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
return persisted_messages, continue_stepping, stop_reason
|
||||
|
||||
def _decide_continuation(
|
||||
self,
|
||||
request_heartbeat: bool,
|
||||
tool_call_name: str,
|
||||
tool_rule_violated: bool,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
is_final_step: bool | None,
|
||||
) -> tuple[bool, str | None, LettaStopReason | None]:
|
||||
|
||||
continue_stepping = request_heartbeat
|
||||
heartbeat_reason: str | None = None
|
||||
stop_reason: LettaStopReason | None = None
|
||||
|
||||
if tool_rule_violated:
|
||||
continue_stepping = True
|
||||
heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation."
|
||||
else:
|
||||
tool_rules_solver.register_tool_call(tool_call_name)
|
||||
|
||||
if tool_rules_solver.is_terminal_tool(tool_call_name):
|
||||
if continue_stepping:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
|
||||
continue_stepping = False
|
||||
|
||||
elif tool_rules_solver.has_children_tools(tool_call_name):
|
||||
continue_stepping = True
|
||||
heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: child tool rule."
|
||||
|
||||
elif tool_rules_solver.is_continue_tool(tool_call_name):
|
||||
continue_stepping = True
|
||||
heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: continue tool rule."
|
||||
|
||||
# – hard stop overrides –
|
||||
if is_final_step:
|
||||
continue_stepping = False
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
|
||||
else:
|
||||
uncalled = tool_rules_solver.get_uncalled_required_tools()
|
||||
if not continue_stepping and uncalled:
|
||||
continue_stepping = True
|
||||
heartbeat_reason = f"{NON_USER_MSG_PREFIX}Missing required tools: " f"{', '.join(uncalled)}"
|
||||
|
||||
stop_reason = None # reset – we’re still going
|
||||
|
||||
return continue_stepping, heartbeat_reason, stop_reason
|
||||
|
||||
@trace_method
|
||||
async def _execute_tool(
|
||||
self,
|
||||
|
||||
@@ -550,7 +550,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
tool_execution_result=tool_exec_result_obj,
|
||||
timezone=agent_state.timezone,
|
||||
actor=self.actor,
|
||||
add_heartbeat_request_system_message=False,
|
||||
continue_stepping=False,
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=None,
|
||||
llm_batch_item_id=llm_batch_item_id,
|
||||
|
||||
@@ -277,7 +277,7 @@ class VoiceAgent(BaseAgent):
|
||||
tool_execution_result=tool_execution_result,
|
||||
timezone=agent_state.timezone,
|
||||
actor=self.actor,
|
||||
add_heartbeat_request_system_message=True,
|
||||
continue_stepping=True,
|
||||
)
|
||||
letta_message_db_queue.extend(tool_call_messages)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class ChildToolRule(BaseToolRule):
|
||||
type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools
|
||||
children: List[str] = Field(..., description="The children tools that can be invoked.")
|
||||
prompt_template: Optional[str] = Field(
|
||||
default="<tool_rule>\nAfter using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}\n</tool_rule>",
|
||||
default="<tool_rule>\nAfter using {{ tool_name }}, you must use one of these tools: {{ children | join(', ') }}\n</tool_rule>",
|
||||
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
|
||||
)
|
||||
|
||||
@@ -61,7 +61,7 @@ class ChildToolRule(BaseToolRule):
|
||||
return set(self.children) if last_tool == self.tool_name else available_tools
|
||||
|
||||
def _get_default_template(self) -> Optional[str]:
|
||||
return "<tool_rule>\nAfter using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}\n</tool_rule>"
|
||||
return "<tool_rule>\nAfter using {{ tool_name }}, you must use one of these tools: {{ children | join(', ') }}\n</tool_rule>"
|
||||
|
||||
|
||||
class ParentToolRule(BaseToolRule):
|
||||
|
||||
@@ -13,7 +13,13 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
|
||||
from letta.constants import (
|
||||
DEFAULT_MESSAGE_TOOL,
|
||||
DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE,
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
REQUEST_HEARTBEAT_PARAM,
|
||||
)
|
||||
from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
||||
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
@@ -194,7 +200,7 @@ def create_letta_messages_from_llm_response(
|
||||
function_response: Optional[str],
|
||||
timezone: str,
|
||||
actor: User,
|
||||
add_heartbeat_request_system_message: bool = False,
|
||||
continue_stepping: bool = False,
|
||||
heartbeat_reason: Optional[str] = None,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
@@ -202,9 +208,9 @@ def create_letta_messages_from_llm_response(
|
||||
step_id: str | None = None,
|
||||
) -> List[Message]:
|
||||
messages = []
|
||||
|
||||
# Construct the tool call with the assistant's message
|
||||
function_arguments["request_heartbeat"] = True
|
||||
# Force set request_heartbeat in tool_args to calculated continue_stepping
|
||||
function_arguments[REQUEST_HEARTBEAT_PARAM] = continue_stepping
|
||||
tool_call = OpenAIToolCall(
|
||||
id=tool_call_id,
|
||||
function=OpenAIFunction(
|
||||
@@ -254,7 +260,7 @@ def create_letta_messages_from_llm_response(
|
||||
)
|
||||
messages.append(tool_message)
|
||||
|
||||
if add_heartbeat_request_system_message:
|
||||
if continue_stepping:
|
||||
heartbeat_system_message = create_heartbeat_system_message(
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
@@ -323,7 +329,7 @@ def create_assistant_messages_from_openai_response(
|
||||
function_response=None,
|
||||
timezone=timezone,
|
||||
actor=actor,
|
||||
add_heartbeat_request_system_message=False,
|
||||
continue_stepping=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
@@ -996,3 +997,53 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_
|
||||
assert len(send_message_calls) == 1, "Should call send_message exactly once"
|
||||
|
||||
print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.asyncio
|
||||
async def test_terminal_tool_rule_send_message_request_heartbeat_false(server, disable_e2b_api_key, default_user):
|
||||
"""Test that when there's a terminal tool rule on send_message, the tool call has request_heartbeat=False."""
|
||||
agent_name = "terminal_send_message_heartbeat_test"
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
# Set up tool rules with terminal rule on send_message
|
||||
tool_rules = [
|
||||
TerminalToolRule(tool_name="send_message"),
|
||||
]
|
||||
|
||||
# Create agent
|
||||
agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[], tool_rules=tool_rules)
|
||||
|
||||
# Send message that should trigger send_message tool call
|
||||
response = await run_agent_step(
|
||||
server=server,
|
||||
agent_id=agent_state.id,
|
||||
input_messages=[MessageCreate(role="user", content="Please send me a simple message.")],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert_sanity_checks(response)
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
|
||||
# Find the send_message tool call and check request_heartbeat is False
|
||||
send_message_call = None
|
||||
for message in response.messages:
|
||||
if isinstance(message, ToolCallMessage) and message.tool_call.name == "send_message":
|
||||
send_message_call = message
|
||||
break
|
||||
|
||||
assert send_message_call is not None, "send_message tool call should be found"
|
||||
|
||||
# Parse the arguments and check request_heartbeat
|
||||
try:
|
||||
arguments = json.loads(send_message_call.tool_call.arguments)
|
||||
except json.JSONDecodeError:
|
||||
pytest.fail("Failed to parse tool call arguments as JSON")
|
||||
|
||||
assert "request_heartbeat" in arguments, "request_heartbeat should be present in send_message arguments"
|
||||
assert arguments["request_heartbeat"] is False, "request_heartbeat should be False for terminal tool rule"
|
||||
|
||||
print(f"✓ Agent '{agent_name}' correctly set request_heartbeat=False for terminal send_message")
|
||||
|
||||
cleanup(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
Reference in New Issue
Block a user