fix: allow for send_message to be unterminated if the user requests it (#4169)
This commit is contained in:
@@ -331,8 +331,13 @@ class Agent(BaseAgent):
|
||||
return None
|
||||
|
||||
allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
|
||||
# Extract terminal tool names from tool rules
|
||||
terminal_tool_names = {rule.tool_name for rule in self.tool_rules_solver.terminal_tool_rules}
|
||||
allowed_functions = runtime_override_tool_json_schema(
|
||||
tool_list=allowed_functions, response_format=self.agent_state.response_format, request_heartbeat=True
|
||||
tool_list=allowed_functions,
|
||||
response_format=self.agent_state.response_format,
|
||||
request_heartbeat=True,
|
||||
terminal_tools=terminal_tool_names,
|
||||
)
|
||||
|
||||
# For the first message, force the initial tool if one is specified
|
||||
|
||||
@@ -1459,8 +1459,10 @@ class LettaAgent(BaseAgent):
|
||||
force_tool_call = valid_tool_names[0]
|
||||
|
||||
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
|
||||
# Extract terminal tool names from tool rules
|
||||
terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules}
|
||||
allowed_tools = runtime_override_tool_json_schema(
|
||||
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
|
||||
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True, terminal_tools=terminal_tool_names
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
@@ -70,13 +70,16 @@ def runtime_override_tool_json_schema(
|
||||
tool_list: list[JsonDict],
|
||||
response_format: ResponseFormatUnion | None,
|
||||
request_heartbeat: bool = True,
|
||||
terminal_tools: set[str] | None = None,
|
||||
) -> list[JsonDict]:
|
||||
"""Override the tool JSON schemas at runtime if certain conditions are met.
|
||||
|
||||
Cases:
|
||||
1. We will inject `send_message` tool calls with `response_format` if provided
|
||||
2. Tools will have an additional `request_heartbeat` parameter added.
|
||||
2. Tools will have an additional `request_heartbeat` parameter added (except for terminal tools).
|
||||
"""
|
||||
if terminal_tools is None:
|
||||
terminal_tools = set()
|
||||
for tool_json in tool_list:
|
||||
if tool_json["name"] == SEND_MESSAGE_TOOL_NAME and response_format and response_format.type != ResponseFormatType.text:
|
||||
if response_format.type == ResponseFormatType.json_schema:
|
||||
@@ -89,8 +92,8 @@ def runtime_override_tool_json_schema(
|
||||
"properties": {},
|
||||
}
|
||||
if request_heartbeat:
|
||||
# TODO (cliandy): see support for tool control loop parameters
|
||||
if tool_json["name"] != SEND_MESSAGE_TOOL_NAME:
|
||||
# Only add request_heartbeat to non-terminal tools
|
||||
if tool_json["name"] not in terminal_tools:
|
||||
tool_json["parameters"]["properties"][REQUEST_HEARTBEAT_PARAM] = {
|
||||
"type": "boolean",
|
||||
"description": REQUEST_HEARTBEAT_DESCRIPTION,
|
||||
|
||||
Reference in New Issue
Block a user