diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py
index d98d6555..96011147 100644
--- a/letta/agents/helpers.py
+++ b/letta/agents/helpers.py
@@ -1,12 +1,15 @@
+import json
import uuid
import xml.etree.ElementTree as ET
from typing import List, Optional, Tuple
+from letta.helpers import ToolRulesSolver
from letta.schemas.agent import AgentState
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message, MessageCreate
+from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_input_messages
@@ -205,3 +208,28 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]:
def generate_step_id():
return f"step-{uuid.uuid4()}"
+
+
+def _safe_load_dict(raw: str) -> dict:
+    """Leniently parse raw JSON into a dict; tolerates double-encoded payloads, returns {} otherwise."""
+    if "}{" in raw: # strip accidental parallel calls
+        raw = raw.split("}{", 1)[0] + "}"
+    try:
+        data = json.loads(raw)
+        if isinstance(data, str): # handle double-encoded JSON
+            data = json.loads(data)
+    except json.JSONDecodeError:
+        return {}
+    return data if isinstance(data, dict) else {}
+
+
+def _pop_heartbeat(tool_args: dict) -> bool:
+ hb = tool_args.pop("request_heartbeat", False)
+ return str(hb).lower() == "true" if isinstance(hb, str) else bool(hb)
+
+
+def _build_rule_violation_result(tool_name: str, valid: list[str], solver: ToolRulesSolver) -> ToolExecutionResult:
+ hint_lines = solver.guess_rule_violation(tool_name)
+ hint_txt = ("\n** Hint: Possible rules that were violated:\n" + "\n".join(f"\t- {h}" for h in hint_lines)) if hint_lines else ""
+    msg = f"[ToolConstraintError] Cannot call {tool_name}, " f"valid tools to call include: {valid}.{hint_txt}"
+ return ToolExecutionResult(status="error", func_return=msg)
diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py
index d80dc436..d6d3ce0c 100644
--- a/letta/agents/letta_agent.py
+++ b/letta/agents/letta_agent.py
@@ -10,7 +10,14 @@ from opentelemetry.trace import Span
from letta.agents.base_agent import BaseAgent
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
-from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_no_persist_async, generate_step_id
+from letta.agents.helpers import (
+ _build_rule_violation_result,
+ _create_letta_response,
+ _pop_heartbeat,
+ _prepare_in_context_messages_no_persist_async,
+ _safe_load_dict,
+ generate_step_id,
+)
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
@@ -931,45 +938,15 @@ class LettaAgent(BaseAgent):
run_id: Optional[str] = None,
) -> Tuple[List[Message], bool, Optional[LettaStopReason]]:
"""
- Now that streaming is done, handle the final AI response.
- This might yield additional SSE tokens if we do stalling.
- At the end, set self._continue_execution accordingly.
+ Handle the final AI response once streaming completes, execute / validate the
+ tool call, decide whether we should keep stepping, and persist state.
"""
- stop_reason = None
- # Check if the called tool is allowed by tool name:
- tool_call_name = tool_call.function.name
- tool_call_args_str = tool_call.function.arguments
-
- # Temp hack to gracefully handle parallel tool calling attempt, only take first one
- if "}{" in tool_call_args_str:
- tool_call_args_str = tool_call_args_str.split("}{", 1)[0] + "}"
-
- try:
- tool_args = json.loads(tool_call_args_str)
- assert isinstance(tool_args, dict), "tool_args must be a dict"
- except json.JSONDecodeError:
- tool_args = {}
- except AssertionError:
- tool_args = json.loads(tool_args)
-
- # Get request heartbeats and coerce to bool
- request_heartbeat = tool_args.pop("request_heartbeat", False)
- if is_final_step:
- stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
- self.logger.info("Agent has reached max steps.")
- request_heartbeat = False
- else:
- # Pre-emptively pop out inner_thoughts
- tool_args.pop(INNER_THOUGHTS_KWARG, "")
-
- # So this is necessary, because sometimes non-structured outputs makes mistakes
- if not isinstance(request_heartbeat, bool):
- if isinstance(request_heartbeat, str):
- request_heartbeat = request_heartbeat.lower() == "true"
- else:
- request_heartbeat = bool(request_heartbeat)
-
- tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
+ # 1. Parse and validate the tool-call envelope
+ tool_call_name: str = tool_call.function.name
+ tool_call_id: str = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
+ tool_args = _safe_load_dict(tool_call.function.arguments)
+ request_heartbeat: bool = _pop_heartbeat(tool_args)
+ tool_args.pop(INNER_THOUGHTS_KWARG, None)
log_telemetry(
self.logger,
@@ -979,16 +956,11 @@ class LettaAgent(BaseAgent):
tool_call_id=tool_call_id,
request_heartbeat=request_heartbeat,
)
- # Check if tool rule is violated - if so, we'll force continuation
- tool_rule_violated = tool_call_name not in valid_tool_names
+ # 2. Execute the tool (or synthesize an error result if disallowed)
+ tool_rule_violated = tool_call_name not in valid_tool_names
if tool_rule_violated:
- base_error_message = f"[ToolConstraintError] Cannot call {tool_call_name}, valid tools to call include: {valid_tool_names}."
- violated_rule_messages = tool_rules_solver.guess_rule_violation(tool_call_name)
- if violated_rule_messages:
- bullet_points = "\n".join(f"\t- {msg}" for msg in violated_rule_messages)
- base_error_message += f"\n** Hint: Possible rules that were violated:\n{bullet_points}"
- tool_execution_result = ToolExecutionResult(status="error", func_return=base_error_message)
+ tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
else:
tool_execution_result = await self._execute_tool(
tool_name=tool_call_name,
@@ -997,66 +969,38 @@ class LettaAgent(BaseAgent):
agent_step_span=agent_step_span,
step_id=step_id,
)
+
log_telemetry(
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
)
- if tool_call_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
- # with certain functions we rely on the paging mechanism to handle overflow
- truncate = False
- else:
- # but by default, we add a truncation safeguard to prevent bad functions from
- # overflow the agent context window
- truncate = True
-
- # get the function response limit
- target_tool = next((x for x in agent_state.tools if x.name == tool_call_name), None)
- return_char_limit = target_tool.return_char_limit if target_tool else None
- function_response_string = validate_function_response(
- tool_execution_result.func_return, return_char_limit=return_char_limit, truncate=truncate
+ # 3. Prepare the function-response payload
+ truncate = tool_call_name not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
+ return_char_limit = next(
+ (t.return_char_limit for t in agent_state.tools if t.name == tool_call_name),
+ None,
)
- function_response = package_function_response(
+ function_response_string = validate_function_response(
+ tool_execution_result.func_return,
+ return_char_limit=return_char_limit,
+ truncate=truncate,
+ )
+ self.last_function_response = package_function_response(
was_success=tool_execution_result.success_flag,
response_string=function_response_string,
timezone=agent_state.timezone,
)
- # 4. Register tool call with tool rule solver
- # Resolve whether or not to continue stepping
- continue_stepping = request_heartbeat
+ # 4. Decide whether to keep stepping (<<< focal section simplified)
+ continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
+ request_heartbeat=request_heartbeat,
+ tool_call_name=tool_call_name,
+ tool_rule_violated=tool_rule_violated,
+ tool_rules_solver=tool_rules_solver,
+ is_final_step=is_final_step,
+ )
- # Force continuation if tool rule was violated to give the model another chance
- if tool_rule_violated:
- continue_stepping = True
- else:
- tool_rules_solver.register_tool_call(tool_name=tool_call_name)
- if tool_rules_solver.is_terminal_tool(tool_name=tool_call_name):
- if continue_stepping:
- stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
- continue_stepping = False
- elif tool_rules_solver.has_children_tools(tool_name=tool_call_name):
- continue_stepping = True
- elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
- continue_stepping = True
-
- # Check if required-before-exit tools have been called before allowing exit
- heartbeat_reason = None # Default
- uncalled_required_tools = tool_rules_solver.get_uncalled_required_tools()
- if not continue_stepping and uncalled_required_tools:
- continue_stepping = True
- heartbeat_reason = (
- f"{NON_USER_MSG_PREFIX}Cannot finish, still need to call the following required tools: {', '.join(uncalled_required_tools)}"
- )
-
- # TODO: @caren is this right?
- # reset stop reason since we ain't stopping!
- stop_reason = None
- self.logger.info(f"RequiredBeforeExitToolRule: Forcing agent continuation. Missing required tools: {uncalled_required_tools}")
-
- # 5a. Persist Steps to DB
- # Following agent loop to persist this before messages
- # TODO (cliandy): determine what should match old loop w/provider_id
- # TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
+ # 5. Persist step + messages and propagate to jobs
logged_step = await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
@@ -1071,7 +1015,6 @@ class LettaAgent(BaseAgent):
step_id=step_id,
)
- # 5b. Persist Messages to DB
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
@@ -1083,27 +1026,72 @@ class LettaAgent(BaseAgent):
function_response=function_response_string,
timezone=agent_state.timezone,
actor=self.actor,
- add_heartbeat_request_system_message=continue_stepping,
+ continue_stepping=continue_stepping,
heartbeat_reason=heartbeat_reason,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
- step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops
+ step_id=logged_step.id if logged_step else None,
)
persisted_messages = await self.message_manager.create_many_messages_async(
(initial_messages or []) + tool_call_messages, actor=self.actor
)
- self.last_function_response = function_response
if run_id:
await self.job_manager.add_messages_to_job_async(
job_id=run_id,
- message_ids=[message.id for message in persisted_messages if message.role != "user"],
+ message_ids=[m.id for m in persisted_messages if m.role != "user"],
actor=self.actor,
)
return persisted_messages, continue_stepping, stop_reason
+ def _decide_continuation(
+ self,
+ request_heartbeat: bool,
+ tool_call_name: str,
+ tool_rule_violated: bool,
+ tool_rules_solver: ToolRulesSolver,
+ is_final_step: bool | None,
+ ) -> tuple[bool, str | None, LettaStopReason | None]:
+
+ continue_stepping = request_heartbeat
+ heartbeat_reason: str | None = None
+ stop_reason: LettaStopReason | None = None
+
+ if tool_rule_violated:
+ continue_stepping = True
+ heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation."
+ else:
+ tool_rules_solver.register_tool_call(tool_call_name)
+
+ if tool_rules_solver.is_terminal_tool(tool_call_name):
+ if continue_stepping:
+ stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
+ continue_stepping = False
+
+ elif tool_rules_solver.has_children_tools(tool_call_name):
+ continue_stepping = True
+ heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: child tool rule."
+
+ elif tool_rules_solver.is_continue_tool(tool_call_name):
+ continue_stepping = True
+ heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing: continue tool rule."
+
+ # – hard stop overrides –
+ if is_final_step:
+ continue_stepping = False
+ stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
+ else:
+ uncalled = tool_rules_solver.get_uncalled_required_tools()
+ if not continue_stepping and uncalled:
+ continue_stepping = True
+ heartbeat_reason = f"{NON_USER_MSG_PREFIX}Missing required tools: " f"{', '.join(uncalled)}"
+
+ stop_reason = None # reset – we’re still going
+
+ return continue_stepping, heartbeat_reason, stop_reason
+
@trace_method
async def _execute_tool(
self,
diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py
index d3343a03..da94511f 100644
--- a/letta/agents/letta_agent_batch.py
+++ b/letta/agents/letta_agent_batch.py
@@ -550,7 +550,7 @@ class LettaAgentBatch(BaseAgent):
tool_execution_result=tool_exec_result_obj,
timezone=agent_state.timezone,
actor=self.actor,
- add_heartbeat_request_system_message=False,
+ continue_stepping=False,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=None,
llm_batch_item_id=llm_batch_item_id,
diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py
index 633f9925..edc18fd2 100644
--- a/letta/agents/voice_agent.py
+++ b/letta/agents/voice_agent.py
@@ -277,7 +277,7 @@ class VoiceAgent(BaseAgent):
tool_execution_result=tool_execution_result,
timezone=agent_state.timezone,
actor=self.actor,
- add_heartbeat_request_system_message=True,
+ continue_stepping=True,
)
letta_message_db_queue.extend(tool_call_messages)
diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py
index 6d78c810..30950d8d 100644
--- a/letta/schemas/tool_rule.py
+++ b/letta/schemas/tool_rule.py
@@ -52,7 +52,7 @@ class ChildToolRule(BaseToolRule):
type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools
children: List[str] = Field(..., description="The children tools that can be invoked.")
prompt_template: Optional[str] = Field(
- default="\nAfter using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}\n",
+ default="\nAfter using {{ tool_name }}, you must use one of these tools: {{ children | join(', ') }}\n",
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
)
@@ -61,7 +61,7 @@ class ChildToolRule(BaseToolRule):
return set(self.children) if last_tool == self.tool_name else available_tools
def _get_default_template(self) -> Optional[str]:
- return "\nAfter using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}\n"
+ return "\nAfter using {{ tool_name }}, you must use one of these tools: {{ children | join(', ') }}\n"
class ParentToolRule(BaseToolRule):
diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py
index c6a4d902..394193e3 100644
--- a/letta/server/rest_api/utils.py
+++ b/letta/server/rest_api/utils.py
@@ -13,7 +13,13 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open
from openai.types.chat.completion_create_params import CompletionCreateParams
from pydantic import BaseModel
-from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
+from letta.constants import (
+ DEFAULT_MESSAGE_TOOL,
+ DEFAULT_MESSAGE_TOOL_KWARG,
+ FUNC_FAILED_HEARTBEAT_MESSAGE,
+ REQ_HEARTBEAT_MESSAGE,
+ REQUEST_HEARTBEAT_PARAM,
+)
from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms
from letta.helpers.message_helper import convert_message_creates_to_messages
@@ -194,7 +200,7 @@ def create_letta_messages_from_llm_response(
function_response: Optional[str],
timezone: str,
actor: User,
- add_heartbeat_request_system_message: bool = False,
+ continue_stepping: bool = False,
heartbeat_reason: Optional[str] = None,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
@@ -202,9 +208,9 @@ def create_letta_messages_from_llm_response(
step_id: str | None = None,
) -> List[Message]:
messages = []
-
# Construct the tool call with the assistant's message
- function_arguments["request_heartbeat"] = True
+ # Force set request_heartbeat in tool_args to calculated continue_stepping
+ function_arguments[REQUEST_HEARTBEAT_PARAM] = continue_stepping
tool_call = OpenAIToolCall(
id=tool_call_id,
function=OpenAIFunction(
@@ -254,7 +260,7 @@ def create_letta_messages_from_llm_response(
)
messages.append(tool_message)
- if add_heartbeat_request_system_message:
+ if continue_stepping:
heartbeat_system_message = create_heartbeat_system_message(
agent_id=agent_id,
model=model,
@@ -323,7 +329,7 @@ def create_assistant_messages_from_openai_response(
function_response=None,
timezone=timezone,
actor=actor,
- add_heartbeat_request_system_message=False,
+ continue_stepping=False,
)
diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py
index 5482bce7..57445d68 100644
--- a/tests/integration_test_agent_tool_graph.py
+++ b/tests/integration_test_agent_tool_graph.py
@@ -1,4 +1,5 @@
import asyncio
+import json
import uuid
import pytest
@@ -996,3 +997,53 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_
assert len(send_message_calls) == 1, "Should call send_message exactly once"
print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
+
+
+@pytest.mark.timeout(60)
+@pytest.mark.asyncio
+async def test_terminal_tool_rule_send_message_request_heartbeat_false(server, disable_e2b_api_key, default_user):
+ """Test that when there's a terminal tool rule on send_message, the tool call has request_heartbeat=False."""
+ agent_name = "terminal_send_message_heartbeat_test"
+ config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
+
+ # Set up tool rules with terminal rule on send_message
+ tool_rules = [
+ TerminalToolRule(tool_name="send_message"),
+ ]
+
+ # Create agent
+ agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[], tool_rules=tool_rules)
+
+ # Send message that should trigger send_message tool call
+ response = await run_agent_step(
+ server=server,
+ agent_id=agent_state.id,
+ input_messages=[MessageCreate(role="user", content="Please send me a simple message.")],
+ actor=default_user,
+ )
+
+ # Assertions
+ assert_sanity_checks(response)
+ assert_invoked_function_call(response.messages, "send_message")
+
+ # Find the send_message tool call and check request_heartbeat is False
+ send_message_call = None
+ for message in response.messages:
+ if isinstance(message, ToolCallMessage) and message.tool_call.name == "send_message":
+ send_message_call = message
+ break
+
+ assert send_message_call is not None, "send_message tool call should be found"
+
+ # Parse the arguments and check request_heartbeat
+ try:
+ arguments = json.loads(send_message_call.tool_call.arguments)
+ except json.JSONDecodeError:
+ pytest.fail("Failed to parse tool call arguments as JSON")
+
+ assert "request_heartbeat" in arguments, "request_heartbeat should be present in send_message arguments"
+ assert arguments["request_heartbeat"] is False, "request_heartbeat should be False for terminal tool rule"
+
+ print(f"✓ Agent '{agent_name}' correctly set request_heartbeat=False for terminal send_message")
+
+ cleanup(server=server, agent_uuid=agent_name, actor=default_user)