fix: allow for send_message to be unterminated if the user requests it (#4169)

This commit is contained in:
Charles Packer
2025-08-25 21:11:14 -07:00
committed by GitHub
parent 9871bffdf2
commit bb57f9cca4
3 changed files with 15 additions and 5 deletions

View File

@@ -331,8 +331,13 @@ class Agent(BaseAgent):
return None
allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
# Extract terminal tool names from tool rules
terminal_tool_names = {rule.tool_name for rule in self.tool_rules_solver.terminal_tool_rules}
allowed_functions = runtime_override_tool_json_schema(
tool_list=allowed_functions, response_format=self.agent_state.response_format, request_heartbeat=True
tool_list=allowed_functions,
response_format=self.agent_state.response_format,
request_heartbeat=True,
terminal_tools=terminal_tool_names,
)
# For the first message, force the initial tool if one is specified

View File

@@ -1459,8 +1459,10 @@ class LettaAgent(BaseAgent):
force_tool_call = valid_tool_names[0]
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
# Extract terminal tool names from tool rules
terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules}
allowed_tools = runtime_override_tool_json_schema(
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True, terminal_tools=terminal_tool_names
)
return (

View File

@@ -70,13 +70,16 @@ def runtime_override_tool_json_schema(
tool_list: list[JsonDict],
response_format: ResponseFormatUnion | None,
request_heartbeat: bool = True,
terminal_tools: set[str] | None = None,
) -> list[JsonDict]:
"""Override the tool JSON schemas at runtime if certain conditions are met.
Cases:
1. We will inject `send_message` tool calls with `response_format` if provided
2. Tools will have an additional `request_heartbeat` parameter added.
2. Tools will have an additional `request_heartbeat` parameter added (except for terminal tools).
"""
if terminal_tools is None:
terminal_tools = set()
for tool_json in tool_list:
if tool_json["name"] == SEND_MESSAGE_TOOL_NAME and response_format and response_format.type != ResponseFormatType.text:
if response_format.type == ResponseFormatType.json_schema:
@@ -89,8 +92,8 @@ def runtime_override_tool_json_schema(
"properties": {},
}
if request_heartbeat:
# TODO (cliandy): see support for tool control loop parameters
if tool_json["name"] != SEND_MESSAGE_TOOL_NAME:
# Only add request_heartbeat to non-terminal tools
if tool_json["name"] not in terminal_tools:
tool_json["parameters"]["properties"][REQUEST_HEARTBEAT_PARAM] = {
"type": "boolean",
"description": REQUEST_HEARTBEAT_DESCRIPTION,