feat: Add ConditionalToolRules (#2279)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
@@ -371,6 +372,9 @@ class Agent(BaseAgent):
|
||||
self._append_to_messages(added_messages=init_messages_objs)
|
||||
self._validate_message_buffer_is_utc()
|
||||
|
||||
# Load last function response from message history
|
||||
self.last_function_response = self.load_last_function_response()
|
||||
|
||||
# Keep track of the total number of messages throughout all time
|
||||
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
|
||||
self.messages_total_init = len(self._messages) - 1
|
||||
@@ -389,6 +393,19 @@ class Agent(BaseAgent):
|
||||
else:
|
||||
self.supports_structured_output = True
|
||||
|
||||
def load_last_function_response(self):
|
||||
"""Load the last function response from message history"""
|
||||
for i in range(len(self._messages) - 1, -1, -1):
|
||||
msg = self._messages[i]
|
||||
if msg.role == MessageRole.tool and msg.text:
|
||||
try:
|
||||
response_json = json.loads(msg.text)
|
||||
if response_json.get("message"):
|
||||
return response_json["message"]
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
raise ValueError(f"Invalid JSON format in message: {msg.text}")
|
||||
return None
|
||||
|
||||
def update_memory_if_change(self, new_memory: Memory) -> bool:
|
||||
"""
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
@@ -586,7 +603,7 @@ class Agent(BaseAgent):
|
||||
) -> ChatCompletionResponse:
|
||||
"""Get response from LLM API with robust retry mechanism."""
|
||||
|
||||
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
|
||||
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response)
|
||||
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
|
||||
|
||||
allowed_functions = (
|
||||
@@ -826,6 +843,7 @@ class Agent(BaseAgent):
|
||||
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
|
||||
printd(error_msg_user)
|
||||
function_response = package_function_response(False, error_msg)
|
||||
self.last_function_response = function_response
|
||||
# TODO: truncate error message somehow
|
||||
messages.append(
|
||||
Message.dict_to_message(
|
||||
@@ -861,6 +879,7 @@ class Agent(BaseAgent):
|
||||
) # extend conversation with function response
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
|
||||
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1])
|
||||
self.last_function_response = function_response
|
||||
|
||||
else:
|
||||
# Standard non-function reply
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Dict, List, Optional, Set
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -6,6 +7,7 @@ from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.tool_rule import (
|
||||
BaseToolRule,
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
InitToolRule,
|
||||
TerminalToolRule,
|
||||
)
|
||||
@@ -22,7 +24,7 @@ class ToolRulesSolver(BaseModel):
|
||||
init_tool_rules: List[InitToolRule] = Field(
|
||||
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
|
||||
)
|
||||
tool_rules: List[ChildToolRule] = Field(
|
||||
tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field(
|
||||
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
|
||||
)
|
||||
terminal_tool_rules: List[TerminalToolRule] = Field(
|
||||
@@ -35,21 +37,25 @@ class ToolRulesSolver(BaseModel):
|
||||
# Separate the provided tool rules into init, standard, and terminal categories
|
||||
for rule in tool_rules:
|
||||
if rule.type == ToolRuleType.run_first:
|
||||
assert isinstance(rule, InitToolRule)
|
||||
self.init_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.constrain_child_tools:
|
||||
assert isinstance(rule, ChildToolRule)
|
||||
self.tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.conditional:
|
||||
assert isinstance(rule, ConditionalToolRule)
|
||||
self.validate_conditional_tool(rule)
|
||||
self.tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.exit_loop:
|
||||
assert isinstance(rule, TerminalToolRule)
|
||||
self.terminal_tool_rules.append(rule)
|
||||
|
||||
# Validate the tool rules to ensure they form a DAG
|
||||
if not self.validate_tool_rules():
|
||||
raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.")
|
||||
|
||||
def update_tool_usage(self, tool_name: str):
|
||||
"""Update the internal state to track the last tool called."""
|
||||
self.last_tool_name = tool_name
|
||||
|
||||
def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
|
||||
def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_response: Optional[str] = None) -> List[str]:
|
||||
"""Get a list of tool names allowed based on the last tool called."""
|
||||
if self.last_tool_name is None:
|
||||
# Use initial tool rules if no tool has been called yet
|
||||
@@ -58,18 +64,21 @@ class ToolRulesSolver(BaseModel):
|
||||
# Find a matching ToolRule for the last tool used
|
||||
current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None)
|
||||
|
||||
# Return children which must exist on ToolRule
|
||||
if current_rule:
|
||||
return current_rule.children
|
||||
|
||||
# Default to empty if no rule matches
|
||||
message = "User provided tool rules and execution state resolved to no more possible tool calls."
|
||||
if error_on_empty:
|
||||
raise RuntimeError(message)
|
||||
else:
|
||||
# warnings.warn(message)
|
||||
if current_rule is None:
|
||||
if error_on_empty:
|
||||
raise ValueError(f"No tool rule found for {self.last_tool_name}")
|
||||
return []
|
||||
|
||||
# If the current rule is a conditional tool rule, use the LLM response to
|
||||
# determine which child tool to use
|
||||
if isinstance(current_rule, ConditionalToolRule):
|
||||
if not last_function_response:
|
||||
raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use")
|
||||
next_tool = self.evaluate_conditional_tool(current_rule, last_function_response)
|
||||
return [next_tool] if next_tool else []
|
||||
|
||||
return current_rule.children if current_rule.children else []
|
||||
|
||||
def is_terminal_tool(self, tool_name: str) -> bool:
|
||||
"""Check if the tool is defined as a terminal tool in the terminal tool rules."""
|
||||
return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
|
||||
@@ -78,38 +87,60 @@ class ToolRulesSolver(BaseModel):
|
||||
"""Check if the tool has children tools"""
|
||||
return any(rule.tool_name == tool_name for rule in self.tool_rules)
|
||||
|
||||
def validate_tool_rules(self) -> bool:
|
||||
"""
|
||||
Validate that the tool rules define a directed acyclic graph (DAG).
|
||||
Returns True if valid (no cycles), otherwise False.
|
||||
"""
|
||||
# Build adjacency list for the tool graph
|
||||
adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules}
|
||||
def validate_conditional_tool(self, rule: ConditionalToolRule):
|
||||
'''
|
||||
Validate a conditional tool rule
|
||||
|
||||
# Track visited nodes
|
||||
visited: Set[str] = set()
|
||||
path_stack: Set[str] = set()
|
||||
Args:
|
||||
rule (ConditionalToolRule): The conditional tool rule to validate
|
||||
|
||||
# Define DFS helper function
|
||||
def dfs(tool_name: str) -> bool:
|
||||
if tool_name in path_stack:
|
||||
return False # Cycle detected
|
||||
if tool_name in visited:
|
||||
return True # Already validated
|
||||
Raises:
|
||||
ToolRuleValidationError: If the rule is invalid
|
||||
'''
|
||||
if len(rule.child_output_mapping) == 0:
|
||||
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
|
||||
return True
|
||||
|
||||
# Mark the node as visited in the current path
|
||||
path_stack.add(tool_name)
|
||||
for child in adjacency_list.get(tool_name, []):
|
||||
if not dfs(child):
|
||||
return False # Cycle detected in DFS
|
||||
path_stack.remove(tool_name) # Remove from current path
|
||||
visited.add(tool_name)
|
||||
return True
|
||||
def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
|
||||
'''
|
||||
Parse function response to determine which child tool to use based on the mapping
|
||||
|
||||
# Run DFS from each tool in `tool_rules`
|
||||
for rule in self.tool_rules:
|
||||
if rule.tool_name not in visited:
|
||||
if not dfs(rule.tool_name):
|
||||
return False # Cycle found, invalid tool rules
|
||||
Args:
|
||||
tool (ConditionalToolRule): The conditional tool rule
|
||||
last_function_response (str): The function response in JSON format
|
||||
|
||||
return True # No cycles, valid DAG
|
||||
Returns:
|
||||
str: The name of the child tool to use next
|
||||
'''
|
||||
json_response = json.loads(last_function_response)
|
||||
function_output = json_response["message"]
|
||||
|
||||
# Try to match the function output with a mapping key
|
||||
for key in tool.child_output_mapping:
|
||||
|
||||
# Convert function output to match key type for comparison
|
||||
if isinstance(key, bool):
|
||||
typed_output = function_output.lower() == "true"
|
||||
elif isinstance(key, int):
|
||||
try:
|
||||
typed_output = int(function_output)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
elif isinstance(key, float):
|
||||
try:
|
||||
typed_output = float(function_output)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
else: # string
|
||||
if function_output == "True" or function_output == "False":
|
||||
typed_output = function_output.lower()
|
||||
elif function_output == "None":
|
||||
typed_output = None
|
||||
else:
|
||||
typed_output = function_output
|
||||
|
||||
if typed_output == key:
|
||||
return tool.child_output_mapping[key]
|
||||
|
||||
# If no match found, use default
|
||||
return tool.default_child
|
||||
|
||||
@@ -9,7 +9,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||
|
||||
|
||||
class EmbeddingConfigColumn(TypeDecorator):
|
||||
@@ -80,7 +80,7 @@ class ToolRulesColumn(TypeDecorator):
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
|
||||
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
|
||||
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
|
||||
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
|
||||
if rule_type == ToolRuleType.run_first:
|
||||
@@ -90,6 +90,9 @@ class ToolRulesColumn(TypeDecorator):
|
||||
elif rule_type == ToolRuleType.constrain_child_tools:
|
||||
rule = ChildToolRule(**data)
|
||||
return rule
|
||||
elif rule_type == ToolRuleType.conditional:
|
||||
rule = ConditionalToolRule(**data)
|
||||
return rule
|
||||
else:
|
||||
raise ValueError(f"Unknown tool rule type: {rule_type}")
|
||||
|
||||
|
||||
@@ -45,5 +45,6 @@ class ToolRuleType(str, Enum):
|
||||
run_first = "InitToolRule"
|
||||
exit_loop = "TerminalToolRule" # reasoning loop should exit
|
||||
continue_loop = "continue_loop" # reasoning loop should continue
|
||||
conditional = "conditional"
|
||||
constrain_child_tools = "ToolRule"
|
||||
require_parent_tools = "require_parent_tools"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -21,6 +21,16 @@ class ChildToolRule(BaseToolRule):
|
||||
children: List[str] = Field(..., description="The children tools that can be invoked.")
|
||||
|
||||
|
||||
class ConditionalToolRule(BaseToolRule):
|
||||
"""
|
||||
A ToolRule that conditionally maps to different child tools based on the output.
|
||||
"""
|
||||
type: ToolRuleType = ToolRuleType.conditional
|
||||
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
|
||||
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
|
||||
require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")
|
||||
|
||||
|
||||
class InitToolRule(BaseToolRule):
|
||||
"""
|
||||
Represents the initial tool rule configuration.
|
||||
@@ -37,4 +47,4 @@ class TerminalToolRule(BaseToolRule):
|
||||
type: ToolRuleType = ToolRuleType.exit_loop
|
||||
|
||||
|
||||
ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule]
|
||||
ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, List, Optional, Sequence, Union
|
||||
|
||||
from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs
|
||||
from letta.schemas.tool_rule import BaseToolRule
|
||||
@@ -373,7 +373,7 @@ def assert_sanity_checks(response: LettaResponse):
|
||||
assert len(response.messages) > 0, response
|
||||
|
||||
|
||||
def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str, case_sensitive: bool = False) -> None:
|
||||
def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], keyword: str, case_sensitive: bool = False) -> None:
|
||||
# Find first instance of send_message
|
||||
target_message = None
|
||||
for message in messages:
|
||||
@@ -406,7 +406,7 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo
|
||||
raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")
|
||||
|
||||
|
||||
def assert_invoked_function_call(messages: List[LettaMessage], function_name: str) -> None:
|
||||
def assert_invoked_function_call(messages: Sequence[LettaMessage], function_name: str) -> None:
|
||||
for message in messages:
|
||||
if isinstance(message, ToolCallMessage) and message.tool_call.name == function_name:
|
||||
# Found it, do nothing
|
||||
|
||||
@@ -4,7 +4,12 @@ import uuid
|
||||
import pytest
|
||||
from letta import create_client
|
||||
from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import (
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
InitToolRule,
|
||||
TerminalToolRule,
|
||||
)
|
||||
from tests.helpers.endpoints_helper import (
|
||||
assert_invoked_function_call,
|
||||
assert_invoked_send_message_with_keyword,
|
||||
@@ -68,6 +73,57 @@ def fourth_secret_word(prev_secret_word: str):
|
||||
return "banana"
|
||||
|
||||
|
||||
def flip_coin():
|
||||
"""
|
||||
Call this to retrieve the password to the secret word, which you will need to output in a send_message later.
|
||||
If it returns an empty string, try flipping again!
|
||||
|
||||
Returns:
|
||||
str: The password or an empty string
|
||||
"""
|
||||
import random
|
||||
|
||||
# Flip a coin with 50% chance
|
||||
if random.random() < 0.5:
|
||||
return ""
|
||||
return "hj2hwibbqm"
|
||||
|
||||
|
||||
def flip_coin_hard():
|
||||
"""
|
||||
Call this to retrieve the password to the secret word, which you will need to output in a send_message later.
|
||||
If it returns an empty string, try flipping again!
|
||||
|
||||
Returns:
|
||||
str: The password or an empty string
|
||||
"""
|
||||
import random
|
||||
|
||||
# Flip a coin with 50% chance
|
||||
result = random.random()
|
||||
if result < 0.5:
|
||||
return ""
|
||||
if result < 0.75:
|
||||
return "START_OVER"
|
||||
return "hj2hwibbqm"
|
||||
|
||||
|
||||
def can_play_game():
|
||||
"""
|
||||
Call this to start the tool chain.
|
||||
"""
|
||||
import random
|
||||
|
||||
return random.random() < 0.5
|
||||
|
||||
|
||||
def return_none():
|
||||
"""
|
||||
Really simple function
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
def auto_error():
|
||||
"""
|
||||
If you call this function, it will throw an error automatically.
|
||||
@@ -201,6 +257,7 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=t1_name),
|
||||
ChildToolRule(tool_name=t1_name, children=[t2_name]),
|
||||
TerminalToolRule(tool_name=t2_name)
|
||||
]
|
||||
tools = [t1, t2]
|
||||
|
||||
@@ -259,26 +316,331 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none):
|
||||
]
|
||||
|
||||
for config in config_files:
|
||||
agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search")
|
||||
max_retries = 3
|
||||
last_error = None
|
||||
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search")
|
||||
|
||||
# Assert the tools were called
|
||||
assert_invoked_function_call(response.messages, "archival_memory_search")
|
||||
assert_invoked_function_call(response.messages, "archival_memory_insert")
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Check ordering of tool calls
|
||||
tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]]
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage):
|
||||
# Check that it's equal to the first one
|
||||
assert m.tool_call.name == tool_names[0]
|
||||
# Assert the tools were called
|
||||
assert_invoked_function_call(response.messages, "archival_memory_search")
|
||||
assert_invoked_function_call(response.messages, "archival_memory_insert")
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
|
||||
# Pop out first one
|
||||
tool_names = tool_names[1:]
|
||||
# Check ordering of tool calls
|
||||
tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]]
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage):
|
||||
# Check that it's equal to the first one
|
||||
assert m.tool_call.name == tool_names[0]
|
||||
|
||||
# Pop out first one
|
||||
tool_names = tool_names[1:]
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
break # Test passed, exit retry loop
|
||||
|
||||
except AssertionError as e:
|
||||
last_error = e
|
||||
print(f"Attempt {attempt + 1} failed, retrying..." if attempt < max_retries - 1 else f"All {max_retries} attempts failed")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
continue
|
||||
|
||||
if last_error and attempt == max_retries - 1:
|
||||
raise last_error # Re-raise the last error if all retries failed
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
|
||||
def test_agent_conditional_tool_easy(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test the agent with a conditional tool that has a child tool.
|
||||
|
||||
Tool Flow:
|
||||
|
||||
-------
|
||||
| |
|
||||
| v
|
||||
-- flip_coin
|
||||
|
|
||||
v
|
||||
reveal_secret_word
|
||||
"""
|
||||
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
coin_flip_name = "flip_coin"
|
||||
secret_word_tool = "fourth_secret_word"
|
||||
flip_coin_tool = client.create_or_update_tool(flip_coin, name=coin_flip_name)
|
||||
reveal_secret = client.create_or_update_tool(fourth_secret_word, name=secret_word_tool)
|
||||
|
||||
# Make tool rules
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=coin_flip_name),
|
||||
ConditionalToolRule(
|
||||
tool_name=coin_flip_name,
|
||||
default_child=coin_flip_name,
|
||||
child_output_mapping={
|
||||
"hj2hwibbqm": secret_word_tool,
|
||||
}
|
||||
),
|
||||
TerminalToolRule(tool_name=secret_word_tool),
|
||||
]
|
||||
tools = [flip_coin_tool, reveal_secret]
|
||||
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")
|
||||
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Assert the tools were called
|
||||
assert_invoked_function_call(response.messages, "flip_coin")
|
||||
assert_invoked_function_call(response.messages, "fourth_secret_word")
|
||||
|
||||
# Check ordering of tool calls
|
||||
found_secret_word = False
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage):
|
||||
if m.tool_call.name == secret_word_tool:
|
||||
# Should be the last tool call
|
||||
found_secret_word = True
|
||||
else:
|
||||
# Before finding secret_word, only flip_coin should be called
|
||||
assert m.tool_call.name == coin_flip_name
|
||||
assert not found_secret_word
|
||||
|
||||
# Ensure we found the secret word exactly once
|
||||
assert found_secret_word
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.timeout(90) # Longer timeout since this test has more steps
|
||||
def test_agent_conditional_tool_hard(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test the agent with a complex conditional tool graph
|
||||
|
||||
Tool Flow:
|
||||
|
||||
can_play_game <---+
|
||||
| |
|
||||
v |
|
||||
flip_coin -----+
|
||||
|
|
||||
v
|
||||
fourth_secret_word
|
||||
"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Create tools
|
||||
play_game = "can_play_game"
|
||||
coin_flip_name = "flip_coin_hard"
|
||||
final_tool = "fourth_secret_word"
|
||||
play_game_tool = client.create_or_update_tool(can_play_game, name=play_game)
|
||||
flip_coin_tool = client.create_or_update_tool(flip_coin_hard, name=coin_flip_name)
|
||||
reveal_secret = client.create_or_update_tool(fourth_secret_word, name=final_tool)
|
||||
|
||||
# Make tool rules - chain them together with conditional rules
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=play_game),
|
||||
ConditionalToolRule(
|
||||
tool_name=play_game,
|
||||
default_child=play_game, # Keep trying if we can't play
|
||||
child_output_mapping={
|
||||
True: coin_flip_name # Only allow access when can_play_game returns True
|
||||
}
|
||||
),
|
||||
ConditionalToolRule(
|
||||
tool_name=coin_flip_name,
|
||||
default_child=coin_flip_name,
|
||||
child_output_mapping={
|
||||
"hj2hwibbqm": final_tool, "START_OVER": play_game
|
||||
}
|
||||
),
|
||||
TerminalToolRule(tool_name=final_tool),
|
||||
]
|
||||
|
||||
# Setup agent with all tools
|
||||
tools = [play_game_tool, flip_coin_tool, reveal_secret]
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
|
||||
agent_state = setup_agent(
|
||||
client,
|
||||
config_file,
|
||||
agent_uuid=agent_uuid,
|
||||
tool_ids=[t.id for t in tools],
|
||||
tool_rules=tool_rules
|
||||
)
|
||||
|
||||
# Ask agent to try to get all secret words
|
||||
response = client.user_message(agent_id=agent_state.id, message="hi")
|
||||
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Assert all tools were called
|
||||
assert_invoked_function_call(response.messages, play_game)
|
||||
assert_invoked_function_call(response.messages, final_tool)
|
||||
|
||||
# Check ordering of tool calls
|
||||
found_words = []
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage):
|
||||
name = m.tool_call.name
|
||||
if name in [play_game, coin_flip_name]:
|
||||
# Before finding secret_word, only can_play_game and flip_coin should be called
|
||||
assert name in [play_game, coin_flip_name]
|
||||
else:
|
||||
# Should find secret words in order
|
||||
expected_word = final_tool
|
||||
assert name == expected_word, f"Found {name} but expected {expected_word}"
|
||||
found_words.append(name)
|
||||
|
||||
# Ensure we found all secret words in order
|
||||
assert found_words == [final_tool]
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test the agent with a conditional tool that allows any child tool to be called if a function returns None.
|
||||
|
||||
Tool Flow:
|
||||
|
||||
return_none
|
||||
|
|
||||
v
|
||||
any tool... <-- When output doesn't match mapping, agent can call any tool
|
||||
"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Create tools - we'll make several available to the agent
|
||||
tool_name = "return_none"
|
||||
|
||||
tool = client.create_or_update_tool(return_none, name=tool_name)
|
||||
secret_word = client.create_or_update_tool(first_secret_word, name="first_secret_word")
|
||||
|
||||
# Make tool rules - only map one output, let others be free choice
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=tool_name),
|
||||
ConditionalToolRule(
|
||||
tool_name=tool_name,
|
||||
default_child=None, # Allow any tool to be called if output doesn't match
|
||||
child_output_mapping={
|
||||
"anything but none": "first_secret_word"
|
||||
}
|
||||
)
|
||||
]
|
||||
tools = [tool, secret_word]
|
||||
|
||||
# Setup agent with all tools
|
||||
agent_state = setup_agent(
|
||||
client,
|
||||
config_file,
|
||||
agent_uuid=agent_uuid,
|
||||
tool_ids=[t.id for t in tools],
|
||||
tool_rules=tool_rules
|
||||
)
|
||||
|
||||
# Ask agent to try different tools based on the game output
|
||||
response = client.user_message(
|
||||
agent_id=agent_state.id,
|
||||
message="call a function, any function. then call send_message"
|
||||
)
|
||||
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Assert return_none was called
|
||||
assert_invoked_function_call(response.messages, tool_name)
|
||||
|
||||
# Assert any base function called afterward
|
||||
found_any_tool = False
|
||||
found_return_none = False
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage):
|
||||
if m.tool_call.name == tool_name:
|
||||
found_return_none = True
|
||||
elif found_return_none and m.tool_call.name:
|
||||
found_any_tool = True
|
||||
break
|
||||
|
||||
assert found_any_tool, "Should have called any tool after return_none"
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_agent_reload_remembers_function_response(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test that when an agent is reloaded, it remembers the last function response for conditional tool chaining.
|
||||
|
||||
Tool Flow:
|
||||
|
||||
flip_coin
|
||||
|
|
||||
v
|
||||
fourth_secret_word <-- Should remember coin flip result after reload
|
||||
"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Create tools
|
||||
flip_coin_name = "flip_coin"
|
||||
secret_word = "fourth_secret_word"
|
||||
flip_coin_tool = client.create_or_update_tool(flip_coin, name=flip_coin_name)
|
||||
secret_word_tool = client.create_or_update_tool(fourth_secret_word, name=secret_word)
|
||||
|
||||
# Make tool rules - map coin flip to fourth_secret_word
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=flip_coin_name),
|
||||
ConditionalToolRule(
|
||||
tool_name=flip_coin_name,
|
||||
default_child=flip_coin_name, # Allow any tool to be called if output doesn't match
|
||||
child_output_mapping={
|
||||
"hj2hwibbqm": secret_word
|
||||
}
|
||||
),
|
||||
TerminalToolRule(tool_name=secret_word)
|
||||
]
|
||||
tools = [flip_coin_tool, secret_word_tool]
|
||||
|
||||
# Setup initial agent
|
||||
agent_state = setup_agent(
|
||||
client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules
|
||||
)
|
||||
|
||||
# Call flip_coin first
|
||||
response = client.user_message(agent_id=agent_state.id, message="flip a coin")
|
||||
assert_invoked_function_call(response.messages, flip_coin_name)
|
||||
assert_invoked_function_call(response.messages, secret_word)
|
||||
found_fourth_secret = False
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage) and m.tool_call.name == secret_word:
|
||||
found_fourth_secret = True
|
||||
break
|
||||
|
||||
assert found_fourth_secret, "Reloaded agent should remember coin flip result and call fourth_secret_word if True"
|
||||
|
||||
# Reload the agent
|
||||
reloaded_agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
assert reloaded_agent.last_function_response is not None
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
@@ -74,8 +74,8 @@ def test_ripple_edit(client, mock_e2b_api_key_none):
|
||||
|
||||
assert set(conversation_agent.memory.list_block_labels()) == {"persona", "human", "fact_block", "rethink_memory_block"}
|
||||
|
||||
rethink_memory_tool = client.create_tool(rethink_memory)
|
||||
finish_rethinking_memory_tool = client.create_tool(finish_rethinking_memory)
|
||||
rethink_memory_tool = client.create_or_update_tool(rethink_memory)
|
||||
finish_rethinking_memory_tool = client.create_or_update_tool(finish_rethinking_memory)
|
||||
offline_memory_agent = client.create_agent(
|
||||
name="offline_memory_agent",
|
||||
agent_type=AgentType.offline_memory_agent,
|
||||
|
||||
@@ -529,6 +529,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can update the details of a message"""
|
||||
import json
|
||||
|
||||
# create a message
|
||||
message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
|
||||
@@ -537,7 +538,7 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
|
||||
assert isinstance(message_response.messages[-1], ToolReturnMessage)
|
||||
message = message_response.messages[-1]
|
||||
|
||||
new_text = "This exact string would never show up in the message???"
|
||||
new_text = json.dumps({"message": "This exact string would never show up in the message???"})
|
||||
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
|
||||
assert new_message.text == new_text
|
||||
|
||||
|
||||
@@ -2,7 +2,12 @@ import pytest
|
||||
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.tool_rule_solver import ToolRuleValidationError
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import (
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
InitToolRule,
|
||||
TerminalToolRule
|
||||
)
|
||||
|
||||
# Constants for tool names used in the tests
|
||||
START_TOOL = "start_tool"
|
||||
@@ -60,7 +65,7 @@ def test_get_allowed_tool_names_no_matching_rule_warning():
|
||||
# Action: Set last tool to an unrecognized tool and check warnings
|
||||
solver.update_tool_usage(UNRECOGNIZED_TOOL)
|
||||
|
||||
# NOTE: removed for now since this warning is getting triggered on every LLM call
|
||||
# # NOTE: removed for now since this warning is getting triggered on every LLM call
|
||||
# with warnings.catch_warnings(record=True) as w:
|
||||
# allowed_tools = solver.get_allowed_tool_names()
|
||||
|
||||
@@ -75,9 +80,9 @@ def test_get_allowed_tool_names_no_matching_rule_error():
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[])
|
||||
|
||||
# Action & Assert: Set last tool to an unrecognized tool and expect RuntimeError when error_on_empty=True
|
||||
# Action & Assert: Set last tool to an unrecognized tool and expect ValueError
|
||||
solver.update_tool_usage(UNRECOGNIZED_TOOL)
|
||||
with pytest.raises(RuntimeError, match="resolved to no more possible tool calls"):
|
||||
with pytest.raises(ValueError, match=f"No tool rule found for {UNRECOGNIZED_TOOL}"):
|
||||
solver.get_allowed_tool_names(error_on_empty=True)
|
||||
|
||||
|
||||
@@ -104,7 +109,46 @@ def test_update_tool_usage_and_get_allowed_tool_names_combined():
|
||||
assert solver.is_terminal_tool(FINAL_TOOL) is True, "Should recognize 'final_tool' as terminal"
|
||||
|
||||
|
||||
def test_tool_rules_with_cycle_detection():
|
||||
def test_conditional_tool_rule():
|
||||
# Setup: Define a conditional tool rule
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
rule = ConditionalToolRule(
|
||||
tool_name=START_TOOL,
|
||||
default_child=None,
|
||||
child_output_mapping={True: END_TOOL, False: START_TOOL}
|
||||
)
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule])
|
||||
|
||||
# Action & Assert: Verify the rule properties
|
||||
# Step 1: Initially allowed tools
|
||||
assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'"
|
||||
|
||||
# Step 2: After using 'start_tool'
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [END_TOOL], "After 'start_tool' returns true, should allow 'end_tool'"
|
||||
assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [START_TOOL], "After 'start_tool' returns false, should allow 'start_tool'"
|
||||
|
||||
# Step 3: After using 'end_tool'
|
||||
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"
|
||||
|
||||
|
||||
def test_invalid_conditional_tool_rule():
|
||||
# Setup: Define an invalid conditional tool rule
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
invalid_rule_1 = ConditionalToolRule(
|
||||
tool_name=START_TOOL,
|
||||
default_child=END_TOOL,
|
||||
child_output_mapping={}
|
||||
)
|
||||
|
||||
# Test 1: Missing child output mapping
|
||||
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."):
|
||||
ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule])
|
||||
|
||||
|
||||
def test_tool_rules_with_invalid_path():
|
||||
# Setup: Define tool rules with both connected, disconnected nodes and a cycle
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
|
||||
@@ -113,15 +157,12 @@ def test_tool_rules_with_cycle_detection():
|
||||
rule_4 = ChildToolRule(tool_name=FINAL_TOOL, children=[END_TOOL]) # Disconnected rule, no cycle here
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
|
||||
# Action & Assert: Attempt to create the ToolRulesSolver with a cycle should raise ValidationError
|
||||
with pytest.raises(ToolRuleValidationError, match="Tool rules contain cycles"):
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule])
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule])
|
||||
|
||||
# Extra setup: Define tool rules without a cycle but with hanging nodes
|
||||
rule_5 = ChildToolRule(tool_name=PREP_TOOL, children=[FINAL_TOOL]) # Hanging node with no connection to start_tool
|
||||
|
||||
# Assert that a configuration without cycles does not raise an error
|
||||
try:
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_4, rule_5, terminal_rule])
|
||||
except ToolRuleValidationError:
|
||||
pytest.fail("ToolRulesSolver raised ValidationError unexpectedly on a valid DAG with hanging nodes")
|
||||
# Now: add a path from the start tool to the final tool
|
||||
rule_5 = ConditionalToolRule(
|
||||
tool_name=HELPER_TOOL,
|
||||
default_child=FINAL_TOOL,
|
||||
child_output_mapping={True: START_TOOL, False: FINAL_TOOL},
|
||||
)
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, rule_5, terminal_rule])
|
||||
|
||||
Reference in New Issue
Block a user