From bb57f9cca40ffdbaadab3caa5c9fb8565a85a907 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Mon, 25 Aug 2025 21:11:14 -0700 Subject: [PATCH] fix: allow for send_message to be unterminated if the user requests it (#4169) --- letta/agent.py | 7 ++++++- letta/agents/letta_agent.py | 4 +++- letta/services/helpers/tool_parser_helper.py | 9 ++++++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 4a9127c8..0dde8fef 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -331,8 +331,13 @@ class Agent(BaseAgent): return None allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names] + # Extract terminal tool names from tool rules + terminal_tool_names = {rule.tool_name for rule in self.tool_rules_solver.terminal_tool_rules} allowed_functions = runtime_override_tool_json_schema( - tool_list=allowed_functions, response_format=self.agent_state.response_format, request_heartbeat=True + tool_list=allowed_functions, + response_format=self.agent_state.response_format, + request_heartbeat=True, + terminal_tools=terminal_tool_names, ) # For the first message, force the initial tool if one is specified diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 1b57b2c2..f896be14 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -1459,8 +1459,10 @@ class LettaAgent(BaseAgent): force_tool_call = valid_tool_names[0] allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)] + # Extract terminal tool names from tool rules + terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules} allowed_tools = runtime_override_tool_json_schema( - tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True + tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True, terminal_tools=terminal_tool_names ) return ( diff --git a/letta/services/helpers/tool_parser_helper.py b/letta/services/helpers/tool_parser_helper.py index 8bc5333b..4633f91c 100644 --- a/letta/services/helpers/tool_parser_helper.py +++ b/letta/services/helpers/tool_parser_helper.py @@ -70,13 +70,16 @@ def runtime_override_tool_json_schema( tool_list: list[JsonDict], response_format: ResponseFormatUnion | None, request_heartbeat: bool = True, + terminal_tools: set[str] | None = None, ) -> list[JsonDict]: """Override the tool JSON schemas at runtime if certain conditions are met. Cases: 1. We will inject `send_message` tool calls with `response_format` if provided - 2. Tools will have an additional `request_heartbeat` parameter added. + 2. Tools will have an additional `request_heartbeat` parameter added (except for terminal tools). """ + if terminal_tools is None: + terminal_tools = set() for tool_json in tool_list: if tool_json["name"] == SEND_MESSAGE_TOOL_NAME and response_format and response_format.type != ResponseFormatType.text: if response_format.type == ResponseFormatType.json_schema: @@ -89,8 +92,8 @@ def runtime_override_tool_json_schema( "properties": {}, } if request_heartbeat: - # TODO (cliandy): see support for tool control loop parameters - if tool_json["name"] != SEND_MESSAGE_TOOL_NAME: + # Only add request_heartbeat to non-terminal tools + if tool_json["name"] not in terminal_tools: tool_json["parameters"]["properties"][REQUEST_HEARTBEAT_PARAM] = { "type": "boolean", "description": REQUEST_HEARTBEAT_DESCRIPTION,