fix: Allow ChildToolRule to work without support for structured outputs (#2270)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -604,6 +604,9 @@ class Agent(BaseAgent):
|
||||
and len(self.tool_rules_solver.init_tool_rules) > 0
|
||||
):
|
||||
force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name
|
||||
# Force a tool call if exactly one tool is specified
|
||||
elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1:
|
||||
force_tool_call = allowed_tool_names[0]
|
||||
|
||||
for attempt in range(1, empty_response_retry_limit + 1):
|
||||
try:
|
||||
|
||||
@@ -262,10 +262,24 @@ def convert_anthropic_response_to_chatcompletion(
|
||||
),
|
||||
)
|
||||
]
|
||||
else:
|
||||
# Just inner mono
|
||||
content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag)
|
||||
tool_calls = None
|
||||
elif len(response_json["content"]) == 1:
|
||||
if response_json["content"][0]["type"] == "tool_use":
|
||||
# function call only
|
||||
content = None
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=response_json["content"][0]["id"],
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=response_json["content"][0]["name"],
|
||||
arguments=json.dumps(response_json["content"][0]["input"], indent=2),
|
||||
),
|
||||
)
|
||||
]
|
||||
else:
|
||||
# inner mono only
|
||||
content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag)
|
||||
tool_calls = None
|
||||
else:
|
||||
raise RuntimeError("Unexpected type for content in response_json.")
|
||||
|
||||
@@ -327,6 +341,14 @@ def anthropic_chat_completions_request(
|
||||
if anthropic_tools is not None:
|
||||
data["tools"] = anthropic_tools
|
||||
|
||||
# TODO: Add support for other tool_choice options like "auto", "any"
|
||||
if len(anthropic_tools) == 1:
|
||||
data["tool_choice"] = {
|
||||
"type": "tool", # Changed from "function" to "tool"
|
||||
"name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object
|
||||
"disable_parallel_tool_use": True # Force single tool use
|
||||
}
|
||||
|
||||
# Move 'system' to the top level
|
||||
# 'messages: Unexpected role "system". The Messages API accepts a top-level `system` parameter, not "system" as an input message role.'
|
||||
assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}"
|
||||
@@ -362,7 +384,6 @@ def anthropic_chat_completions_request(
|
||||
data.pop("top_p", None)
|
||||
data.pop("presence_penalty", None)
|
||||
data.pop("user", None)
|
||||
data.pop("tool_choice", None)
|
||||
|
||||
response_json = make_post_request(url, headers, data)
|
||||
return convert_anthropic_response_to_chatcompletion(response_json=response_json, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
|
||||
|
||||
@@ -64,6 +64,7 @@ def setup_agent(
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
agent_uuid: str = agent_uuid,
|
||||
include_base_tools: bool = True,
|
||||
) -> AgentState:
|
||||
config_data = json.load(open(filename, "r"))
|
||||
llm_config = LLMConfig(**config_data)
|
||||
@@ -77,7 +78,7 @@ def setup_agent(
|
||||
|
||||
memory = ChatMemory(human=memory_human_str, persona=memory_persona_str)
|
||||
agent_state = client.create_agent(
|
||||
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules
|
||||
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules, include_base_tools=include_base_tools,
|
||||
)
|
||||
|
||||
return agent_state
|
||||
|
||||
@@ -234,3 +234,51 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
|
||||
if i < 2:
|
||||
backoff_time = 10 * (2 ** i)
|
||||
time.sleep(backoff_time)
|
||||
|
||||
@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
|
||||
def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none):
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user)
|
||||
archival_memory_search = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_search", actor=client.user)
|
||||
archival_memory_insert = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_insert", actor=client.user)
|
||||
|
||||
# Make tool rules
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name="archival_memory_search"),
|
||||
ChildToolRule(tool_name="archival_memory_search", children=["archival_memory_insert"]),
|
||||
ChildToolRule(tool_name="archival_memory_insert", children=["send_message"]),
|
||||
TerminalToolRule(tool_name="send_message"),
|
||||
]
|
||||
tools = [send_message, archival_memory_search, archival_memory_insert]
|
||||
|
||||
config_files = [
|
||||
"tests/configs/llm_model_configs/claude-3-sonnet-20240229.json",
|
||||
"tests/configs/llm_model_configs/openai-gpt-4o.json",
|
||||
]
|
||||
|
||||
for config in config_files:
|
||||
agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search")
|
||||
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Assert the tools were called
|
||||
assert_invoked_function_call(response.messages, "archival_memory_search")
|
||||
assert_invoked_function_call(response.messages, "archival_memory_insert")
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
|
||||
# Check ordering of tool calls
|
||||
tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]]
|
||||
for m in response.messages:
|
||||
if isinstance(m, FunctionCallMessage):
|
||||
# Check that it's equal to the first one
|
||||
assert m.function_call.name == tool_names[0]
|
||||
|
||||
# Pop out first one
|
||||
tool_names = tool_names[1:]
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
Reference in New Issue
Block a user