From 39bc9edb9ced8db6e1c75141fc4ff22f1246dafa Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 24 Jun 2025 19:50:00 -0700 Subject: [PATCH] feat: Override heartbeat request when system forces step exit (#3015) --- letta/agents/helpers.py | 28 +++ letta/agents/letta_agent.py | 190 ++++++++++----------- letta/agents/letta_agent_batch.py | 2 +- letta/agents/voice_agent.py | 2 +- letta/schemas/tool_rule.py | 4 +- letta/server/rest_api/utils.py | 18 +- tests/integration_test_agent_tool_graph.py | 51 ++++++ 7 files changed, 184 insertions(+), 111 deletions(-) diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index d98d6555..96011147 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -1,12 +1,15 @@ +import json import uuid import xml.etree.ElementTree as ET from typing import List, Optional, Tuple +from letta.helpers import ToolRulesSolver from letta.schemas.agent import AgentState from letta.schemas.letta_message import MessageType from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message, MessageCreate +from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.server.rest_api.utils import create_input_messages @@ -205,3 +208,28 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]: def generate_step_id(): return f"step-{uuid.uuid4()}" + + +def _safe_load_dict(raw: str) -> dict: + """Lenient JSON → dict with fallback to eval on assertion failure.""" + if "}{" in raw: # strip accidental parallel calls + raw = raw.split("}{", 1)[0] + "}" + try: + data = json.loads(raw) + if not isinstance(data, dict): + raise AssertionError + return data + except (json.JSONDecodeError, AssertionError): + return json.loads(raw) if raw else {} + + +def _pop_heartbeat(tool_args: dict) -> bool: + hb = tool_args.pop("request_heartbeat", False) + return str(hb).lower() == "true" if isinstance(hb, str) else bool(hb) + + +def _build_rule_violation_result(tool_name: str, valid: list[str], solver: ToolRulesSolver) -> ToolExecutionResult: + hint_lines = solver.guess_rule_violation(tool_name) + hint_txt = ("\n** Hint: Possible rules that were violated:\n" + "\n".join(f"\t- {h}" for h in hint_lines)) if hint_lines else "" + msg = f"[ToolConstraintError] Cannot call {tool_name}, " f"valid tools include: {valid}.{hint_txt}" + return ToolExecutionResult(status="error", func_return=msg) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index d80dc436..d6d3ce0c 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -10,7 +10,14 @@ from opentelemetry.trace import Span from letta.agents.base_agent import BaseAgent from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent -from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_no_persist_async, generate_step_id +from letta.agents.helpers import ( + _build_rule_violation_result, + _create_letta_response, + _pop_heartbeat, + _prepare_in_context_messages_no_persist_async, + _safe_load_dict, + generate_step_id, +) from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX from letta.errors import ContextWindowExceededError from letta.helpers import ToolRulesSolver @@ -931,45 +938,15 @@ class LettaAgent(BaseAgent): run_id: Optional[str] = None, ) -> Tuple[List[Message], bool, Optional[LettaStopReason]]: """ - Now that streaming is done, handle the final AI response. - This might yield additional SSE tokens if we do stalling. - At the end, set self._continue_execution accordingly. + Handle the final AI response once streaming completes, execute / validate the + tool call, decide whether we should keep stepping, and persist state. """ - stop_reason = None - # Check if the called tool is allowed by tool name: - tool_call_name = tool_call.function.name - tool_call_args_str = tool_call.function.arguments - - # Temp hack to gracefully handle parallel tool calling attempt, only take first one - if "}{" in tool_call_args_str: - tool_call_args_str = tool_call_args_str.split("}{", 1)[0] + "}" - - try: - tool_args = json.loads(tool_call_args_str) - assert isinstance(tool_args, dict), "tool_args must be a dict" - except json.JSONDecodeError: - tool_args = {} - except AssertionError: - tool_args = json.loads(tool_args) - - # Get request heartbeats and coerce to bool - request_heartbeat = tool_args.pop("request_heartbeat", False) - if is_final_step: - stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value) - self.logger.info("Agent has reached max steps.") - request_heartbeat = False - else: - # Pre-emptively pop out inner_thoughts - tool_args.pop(INNER_THOUGHTS_KWARG, "") - - # So this is necessary, because sometimes non-structured outputs makes mistakes - if not isinstance(request_heartbeat, bool): - if isinstance(request_heartbeat, str): - request_heartbeat = request_heartbeat.lower() == "true" - else: - request_heartbeat = bool(request_heartbeat) - - tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}" + # 1. Parse and validate the tool-call envelope + tool_call_name: str = tool_call.function.name + tool_call_id: str = tool_call.id or f"call_{uuid.uuid4().hex[:8]}" + tool_args = _safe_load_dict(tool_call.function.arguments) + request_heartbeat: bool = _pop_heartbeat(tool_args) + tool_args.pop(INNER_THOUGHTS_KWARG, None) log_telemetry( self.logger, @@ -979,16 +956,11 @@ class LettaAgent(BaseAgent): tool_call_id=tool_call_id, request_heartbeat=request_heartbeat, ) - # Check if tool rule is violated - if so, we'll force continuation - tool_rule_violated = tool_call_name not in valid_tool_names + # 2. Execute the tool (or synthesize an error result if disallowed) + tool_rule_violated = tool_call_name not in valid_tool_names if tool_rule_violated: - base_error_message = f"[ToolConstraintError] Cannot call {tool_call_name}, valid tools to call include: {valid_tool_names}." - violated_rule_messages = tool_rules_solver.guess_rule_violation(tool_call_name) - if violated_rule_messages: - bullet_points = "\n".join(f"\t- {msg}" for msg in violated_rule_messages) - base_error_message += f"\n** Hint: Possible rules that were violated:\n{bullet_points}" - tool_execution_result = ToolExecutionResult(status="error", func_return=base_error_message) + tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver) else: tool_execution_result = await self._execute_tool( tool_name=tool_call_name, @@ -997,66 +969,38 @@ class LettaAgent(BaseAgent): agent_step_span=agent_step_span, step_id=step_id, ) + log_telemetry( self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id ) - if tool_call_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: - # with certain functions we rely on the paging mechanism to handle overflow - truncate = False - else: - # but by default, we add a truncation safeguard to prevent bad functions from - # overflow the agent context window - truncate = True - - # get the function response limit - target_tool = next((x for x in agent_state.tools if x.name == tool_call_name), None) - return_char_limit = target_tool.return_char_limit if target_tool else None - function_response_string = validate_function_response( - tool_execution_result.func_return, return_char_limit=return_char_limit, truncate=truncate + # 3. Prepare the function-response payload + truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"} + return_char_limit = next( + (t.return_char_limit for t in agent_state.tools if t.name == tool_call_name), + None, ) - function_response = package_function_response( + function_response_string = validate_function_response( + tool_execution_result.func_return, + return_char_limit=return_char_limit, + truncate=truncate, + ) + self.last_function_response = package_function_response( was_success=tool_execution_result.success_flag, response_string=function_response_string, timezone=agent_state.timezone, ) - # 4. Register tool call with tool rule solver - # Resolve whether or not to continue stepping - continue_stepping = request_heartbeat + # 4. Decide whether to keep stepping (<<< focal section simplified) + continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation( + request_heartbeat=request_heartbeat, + tool_call_name=tool_call_name, + tool_rule_violated=tool_rule_violated, + tool_rules_solver=tool_rules_solver, + is_final_step=is_final_step, + ) - # Force continuation if tool rule was violated to give the model another chance - if tool_rule_violated: - continue_stepping = True - else: - tool_rules_solver.register_tool_call(tool_name=tool_call_name) - if tool_rules_solver.is_terminal_tool(tool_name=tool_call_name): - if continue_stepping: - stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value) - continue_stepping = False - elif tool_rules_solver.has_children_tools(tool_name=tool_call_name): - continue_stepping = True - elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name): - continue_stepping = True - - # Check if required-before-exit tools have been called before allowing exit - heartbeat_reason = None # Default - uncalled_required_tools = tool_rules_solver.get_uncalled_required_tools() - if not continue_stepping and uncalled_required_tools: - continue_stepping = True - heartbeat_reason = ( - f"{NON_USER_MSG_PREFIX}Cannot finish, still need to call the following required tools: {', '.join(uncalled_required_tools)}" - ) - - # TODO: @caren is this right? - # reset stop reason since we ain't stopping! - stop_reason = None - self.logger.info(f"RequiredBeforeExitToolRule: Forcing agent continuation. Missing required tools: {uncalled_required_tools}") - - # 5a. Persist Steps to DB - # Following agent loop to persist this before messages - # TODO (cliandy): determine what should match old loop w/provider_id - # TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same. + # 5. Persist step + messages and propagate to jobs logged_step = await self.step_manager.log_step_async( actor=self.actor, agent_id=agent_state.id, @@ -1071,7 +1015,6 @@ class LettaAgent(BaseAgent): step_id=step_id, ) - # 5b. Persist Messages to DB tool_call_messages = create_letta_messages_from_llm_response( agent_id=agent_state.id, model=agent_state.llm_config.model, @@ -1083,27 +1026,72 @@ class LettaAgent(BaseAgent): function_response=function_response_string, timezone=agent_state.timezone, actor=self.actor, - add_heartbeat_request_system_message=continue_stepping, + continue_stepping=continue_stepping, heartbeat_reason=heartbeat_reason, reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, - step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops + step_id=logged_step.id if logged_step else None, ) persisted_messages = await self.message_manager.create_many_messages_async( (initial_messages or []) + tool_call_messages, actor=self.actor ) - self.last_function_response = function_response if run_id: await self.job_manager.add_messages_to_job_async( job_id=run_id, - message_ids=[message.id for message in persisted_messages if message.role != "user"], + message_ids=[m.id for m in persisted_messages if m.role != "user"], actor=self.actor, ) return persisted_messages, continue_stepping, stop_reason + def _decide_continuation( + self, + request_heartbeat: bool, + tool_call_name: str, + tool_rule_violated: bool, + tool_rules_solver: ToolRulesSolver, + is_final_step: bool | None, + ) -> tuple[bool, str | None, LettaStopReason | None]: + + continue_stepping = request_heartbeat + heartbeat_reason: str | None = None + stop_reason: LettaStopReason | None = None + + if tool_rule_violated: + continue_stepping = True + heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation." + else: + tool_rules_solver.register_tool_call(tool_call_name) + + if tool_rules_solver.is_terminal_tool(tool_call_name): + if continue_stepping: + stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value) + continue_stepping = False + + elif tool_rules_solver.has_children_tools(tool_call_name): + continue_stepping = True + heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: child tool rule." + + elif tool_rules_solver.is_continue_tool(tool_call_name): + continue_stepping = True + heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: continue tool rule." + + # – hard stop overrides – + if is_final_step: + continue_stepping = False + stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value) + else: + uncalled = tool_rules_solver.get_uncalled_required_tools() + if not continue_stepping and uncalled: + continue_stepping = True + heartbeat_reason = f"{NON_USER_MSG_PREFIX}Missing required tools: " f"{', '.join(uncalled)}" + + stop_reason = None # reset – we’re still going + + return continue_stepping, heartbeat_reason, stop_reason + @trace_method async def _execute_tool( self, diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index d3343a03..da94511f 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -550,7 +550,7 @@ class LettaAgentBatch(BaseAgent): tool_execution_result=tool_exec_result_obj, timezone=agent_state.timezone, actor=self.actor, - add_heartbeat_request_system_message=False, + continue_stepping=False, reasoning_content=reasoning_content, pre_computed_assistant_message_id=None, llm_batch_item_id=llm_batch_item_id, diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 633f9925..edc18fd2 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -277,7 +277,7 @@ class VoiceAgent(BaseAgent): tool_execution_result=tool_execution_result, timezone=agent_state.timezone, actor=self.actor, - add_heartbeat_request_system_message=True, + continue_stepping=True, ) letta_message_db_queue.extend(tool_call_messages) diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index 6d78c810..30950d8d 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -52,7 +52,7 @@ class ChildToolRule(BaseToolRule): type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools children: List[str] = Field(..., description="The children tools that can be invoked.") prompt_template: Optional[str] = Field( - default="\nAfter using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}\n", + default="\nAfter using {{ tool_name }}, you must use one of these tools: {{ children | join(', ') }}\n", description="Optional Jinja2 template for generating agent prompt about this tool rule.", ) @@ -61,7 +61,7 @@ class ChildToolRule(BaseToolRule): return set(self.children) if last_tool == self.tool_name else available_tools def _get_default_template(self) -> Optional[str]: - return "\nAfter using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}\n" + return "\nAfter using {{ tool_name }}, you must use one of these tools: {{ children | join(', ') }}\n" class ParentToolRule(BaseToolRule): diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index c6a4d902..394193e3 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -13,7 +13,13 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open from openai.types.chat.completion_create_params import CompletionCreateParams from pydantic import BaseModel -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE +from letta.constants import ( + DEFAULT_MESSAGE_TOOL, + DEFAULT_MESSAGE_TOOL_KWARG, + FUNC_FAILED_HEARTBEAT_MESSAGE, + REQ_HEARTBEAT_MESSAGE, + REQUEST_HEARTBEAT_PARAM, +) from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms from letta.helpers.message_helper import convert_message_creates_to_messages @@ -194,7 +200,7 @@ def create_letta_messages_from_llm_response( function_response: Optional[str], timezone: str, actor: User, - add_heartbeat_request_system_message: bool = False, + continue_stepping: bool = False, heartbeat_reason: Optional[str] = None, reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, @@ -202,9 +208,9 @@ def create_letta_messages_from_llm_response( step_id: str | None = None, ) -> List[Message]: messages = [] - # Construct the tool call with the assistant's message - function_arguments["request_heartbeat"] = True + # Force set request_heartbeat in tool_args to calculated continue_stepping + function_arguments[REQUEST_HEARTBEAT_PARAM] = continue_stepping tool_call = OpenAIToolCall( id=tool_call_id, function=OpenAIFunction( @@ -254,7 +260,7 @@ def create_letta_messages_from_llm_response( ) messages.append(tool_message) - if add_heartbeat_request_system_message: + if continue_stepping: heartbeat_system_message = create_heartbeat_system_message( agent_id=agent_id, model=model, @@ -323,7 +329,7 @@ def create_assistant_messages_from_openai_response( function_response=None, timezone=timezone, actor=actor, - add_heartbeat_request_system_message=False, + continue_stepping=False, ) diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 5482bce7..57445d68 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -1,4 +1,5 @@ import asyncio +import json import uuid import pytest @@ -996,3 +997,53 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_ assert len(send_message_calls) == 1, "Should call send_message exactly once" print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally") + + +@pytest.mark.timeout(60) +@pytest.mark.asyncio +async def test_terminal_tool_rule_send_message_request_heartbeat_false(server, disable_e2b_api_key, default_user): + """Test that when there's a terminal tool rule on send_message, the tool call has request_heartbeat=False.""" + agent_name = "terminal_send_message_heartbeat_test" + config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json" + + # Set up tool rules with terminal rule on send_message + tool_rules = [ + TerminalToolRule(tool_name="send_message"), + ] + + # Create agent + agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[], tool_rules=tool_rules) + + # Send message that should trigger send_message tool call + response = await run_agent_step( + server=server, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="Please send me a simple message.")], + actor=default_user, + ) + + # Assertions + assert_sanity_checks(response) + assert_invoked_function_call(response.messages, "send_message") + + # Find the send_message tool call and check request_heartbeat is False + send_message_call = None + for message in response.messages: + if isinstance(message, ToolCallMessage) and message.tool_call.name == "send_message": + send_message_call = message + break + + assert send_message_call is not None, "send_message tool call should be found" + + # Parse the arguments and check request_heartbeat + try: + arguments = json.loads(send_message_call.tool_call.arguments) + except json.JSONDecodeError: + pytest.fail("Failed to parse tool call arguments as JSON") + + assert "request_heartbeat" in arguments, "request_heartbeat should be present in send_message arguments" + assert arguments["request_heartbeat"] is False, "request_heartbeat should be False for terminal tool rule" + + print(f"✓ Agent '{agent_name}' correctly set request_heartbeat=False for terminal send_message") + + cleanup(server=server, agent_uuid=agent_name, actor=default_user)