From 7bb0f53c23cce88b382649a5fb891cb9068da915 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 30 Oct 2024 16:53:41 -0700 Subject: [PATCH] feat: Implement tool calling rules for agents (#1954) --- .github/workflows/tests.yml | 147 ++++++++---------- letta/agent.py | 43 ++++- letta/client/client.py | 18 ++- letta/helpers/__init__.py | 1 + letta/helpers/tool_rule_solver.py | 115 ++++++++++++++ letta/llm_api/helpers.py | 4 +- letta/llm_api/llm_api_tools.py | 3 +- letta/llm_api/openai.py | 5 + letta/metadata.py | 44 +++++- letta/schemas/agent.py | 10 +- letta/schemas/tool_rule.py | 25 +++ letta/server/server.py | 1 + .../{gpt-4.json => openai-gpt-4o.json} | 5 +- tests/helpers/endpoints_helper.py | 7 +- tests/test_agent_function_update.py | 44 ------ tests/test_agent_tool_graph.py | 130 ++++++++++++++++ tests/test_endpoints.py | 14 +- tests/test_tool_rule_solver.py | 128 +++++++++++++++ tests/test_tools.py | 2 +- 19 files changed, 588 insertions(+), 158 deletions(-) create mode 100644 letta/helpers/__init__.py create mode 100644 letta/helpers/tool_rule_solver.py create mode 100644 letta/schemas/tool_rule.py rename tests/configs/llm_model_configs/{gpt-4.json => openai-gpt-4o.json} (56%) delete mode 100644 tests/test_agent_function_update.py create mode 100644 tests/test_agent_tool_graph.py create mode 100644 tests/test_tool_rule_solver.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index de9a4792..13f83686 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,10 +14,21 @@ on: branches: [ main ] jobs: - unit-tests: + run-core-unit-tests: runs-on: ubuntu-latest timeout-minutes: 15 - + strategy: + fail-fast: false + matrix: + test_suite: + - "test_local_client.py" + - "test_client.py" + - "test_server.py" + - "test_managers.py" + - "test_tools.py" + - "test_o1_agent.py" + - "test_tool_rule_solver.py" + - "test_agent_tool_graph.py" services: qdrant: image: qdrant/qdrant @@ -25,93 +36,61 @@ jobs: - 6333:6333 steps: 
- - name: Checkout - uses: actions/checkout@v4 + - name: Checkout + uses: actions/checkout@v4 - - name: Build and run container - run: bash db/run_postgres.sh + - name: Build and run container + run: bash db/run_postgres.sh - - name: "Setup Python, Poetry and Dependencies" - uses: packetcoders/action-setup-cache-python-poetry@main - with: - python-version: "3.12" - poetry-version: "1.8.2" - install-args: "-E dev -E postgres -E milvus -E external-tools -E tests" + - name: Setup Python, Poetry, and Dependencies + uses: packetcoders/action-setup-cache-python-poetry@main + with: + python-version: "3.12" + poetry-version: "1.8.2" + install-args: "-E dev -E postgres -E milvus -E external-tools -E tests" - - name: Run LocalClient tests - env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta - LETTA_PG_HOST: localhost - LETTA_SERVER_PASS: test_server_token - run: | - poetry run pytest -s -vv tests/test_local_client.py + - name: Run core unit tests + env: + LETTA_PG_PORT: 8888 + LETTA_PG_USER: letta + LETTA_PG_PASSWORD: letta + LETTA_PG_DB: letta + LETTA_PG_HOST: localhost + LETTA_SERVER_PASS: test_server_token + run: | + poetry run pytest -s -vv tests/${{ matrix.test_suite }} - - name: Run RESTClient tests - env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta - LETTA_PG_HOST: localhost - LETTA_SERVER_PASS: test_server_token - run: | - poetry run pytest -s -vv tests/test_client.py + misc-unit-tests: + runs-on: ubuntu-latest + needs: run-core-unit-tests + services: + qdrant: + image: qdrant/qdrant + ports: + - 6333:6333 - - name: Run server tests - env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta - LETTA_PG_HOST: localhost - LETTA_SERVER_PASS: test_server_token - run: | - poetry run pytest -s -vv tests/test_server.py + steps: + - name: Checkout + uses: actions/checkout@v4 - - name: Run server manager tests - env: - LETTA_PG_PORT: 8888 - 
LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta - LETTA_PG_HOST: localhost - LETTA_SERVER_PASS: test_server_token - run: | - poetry run pytest -s -vv tests/test_managers.py + - name: Build and run container + run: bash db/run_postgres.sh - - name: Run tools tests - env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta - LETTA_PG_HOST: localhost - LETTA_SERVER_PASS: test_server_token - run: | - poetry run pytest -s -vv tests/test_tools.py + - name: Setup Python, Poetry, and Dependencies + uses: packetcoders/action-setup-cache-python-poetry@main + with: + python-version: "3.12" + poetry-version: "1.8.2" + install-args: "-E dev -E postgres -E milvus -E external-tools -E tests" - - name: Run o1 agent tests - env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta - LETTA_PG_HOST: localhost - LETTA_SERVER_PASS: test_server_token - run: | - poetry run pytest -s -vv tests/test_o1_agent.py - - - name: Run tests with pytest - env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_HOST: localhost - LETTA_PG_DB: letta - LETTA_SERVER_PASS: test_server_token - PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} - run: | - poetry run pytest -s -vv -k "not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests + - name: Run misc unit tests + env: + LETTA_PG_PORT: 8888 + LETTA_PG_USER: letta + LETTA_PG_PASSWORD: letta + LETTA_PG_DB: letta + LETTA_PG_HOST: localhost + LETTA_SERVER_PASS: test_server_token + PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} + run: | + poetry run pytest -s -vv -k "not test_single_path_agent_tool_call_graph.py and not test_tool_rule_solver.py and not 
test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests diff --git a/letta/agent.py b/letta/agent.py index e18e989a..a5297723 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -20,6 +20,7 @@ from letta.constants import ( REQ_HEARTBEAT_MESSAGE, ) from letta.errors import LLMError +from letta.helpers import ToolRulesSolver from letta.interface import AgentInterface from letta.llm_api.helpers import is_context_overflow_error from letta.llm_api.llm_api_tools import create @@ -43,6 +44,7 @@ from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.passage import Passage from letta.schemas.tool import Tool +from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics from letta.system import ( get_heartbeat, @@ -242,6 +244,14 @@ class Agent(BaseAgent): # link tools self.link_tools(tools) + # initialize a tool rules solver + if agent_state.tool_rules: + # if there are tool rules, print out a warning + warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.") + # add default rule for having send_message be a terminal tool + agent_state.tool_rules.append(TerminalToolRule(tool_name="send_message")) + self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules) + # gpt-4, gpt-3.5-turbo, ... self.model = self.agent_state.llm_config.model @@ -465,15 +475,26 @@ class Agent(BaseAgent): function_call: str = "auto", first_message: bool = False, # hint stream: bool = False, # TODO move to config? 
+ fail_on_empty_response: bool = False, + empty_response_retry_limit: int = 3, ) -> ChatCompletionResponse: """Get response from LLM API""" + # Get the allowed tools based on the ToolRulesSolver state + allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names() + + if not allowed_tool_names: + # if it's empty, any available tools are fair game + allowed_functions = self.functions + else: + allowed_functions = [func for func in self.functions if func["name"] in allowed_tool_names] + try: response = create( # agent_state=self.agent_state, llm_config=self.agent_state.llm_config, messages=message_sequence, user_id=self.agent_state.user_id, - functions=self.functions, + functions=allowed_functions, functions_python=self.functions_python, function_call=function_call, # hint @@ -484,7 +505,15 @@ class Agent(BaseAgent): ) if len(response.choices) == 0 or response.choices[0] is None: - raise Exception(f"API call didn't return a message: {response}") + empty_api_err_message = f"API call didn't return a message: {response}" + if fail_on_empty_response or empty_response_retry_limit == 0: + raise Exception(empty_api_err_message) + else: + # Decrement retry limit and try again + warnings.warn(empty_api_err_message) + return self._get_ai_reply( + message_sequence, function_call, first_message, stream, fail_on_empty_response, empty_response_retry_limit - 1 + ) # special case for 'length' if response.choices[0].finish_reason == "length": @@ -515,6 +544,7 @@ class Agent(BaseAgent): assert response_message_id.startswith("message-"), response_message_id messages = [] # append these to the history when done + function_name = None # Step 2: check if LLM wanted to call a function if response_message.function_call or (response_message.tool_calls is not None and len(response_message.tool_calls) > 0): @@ -724,6 +754,15 @@ class Agent(BaseAgent): # TODO: @charles please check this self.rebuild_memory() + # Update ToolRulesSolver state with last called function + 
self.tool_rules_solver.update_tool_usage(function_name) + + # Update heartbeat request according to provided tool rules + if self.tool_rules_solver.has_children_tools(function_name): + heartbeat_request = True + elif self.tool_rules_solver.is_terminal_tool(function_name): + heartbeat_request = False + return messages, heartbeat_request, function_failed def step( diff --git a/letta/client/client.py b/letta/client/client.py index 30b53f2d..f8005432 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -43,6 +43,7 @@ from letta.schemas.organization import Organization from letta.schemas.passage import Passage from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.tool import Tool, ToolCreate, ToolUpdate +from letta.schemas.tool_rule import BaseToolRule from letta.server.rest_api.interface import QueuingInterface from letta.server.server import SyncServer from letta.utils import get_human_text, get_persona_text @@ -140,6 +141,8 @@ class AbstractClient(object): agent_id: Optional[str] = None, name: Optional[str] = None, stream: Optional[bool] = False, + stream_steps: bool = False, + stream_tokens: bool = False, include_full_message: Optional[bool] = False, ) -> LettaResponse: raise NotImplementedError @@ -196,7 +199,6 @@ class AbstractClient(object): self, func, name: Optional[str] = None, - update: Optional[bool] = True, tags: Optional[List[str]] = None, ) -> Tool: raise NotImplementedError @@ -405,7 +407,7 @@ class RESTClient(AbstractClient): # add memory tools memory_functions = get_memory_functions(memory) for func_name, func in memory_functions.items(): - tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"], update=True) + tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"]) tool_names.append(tool.name) # check if default configs are provided @@ -1268,7 +1270,6 @@ class RESTClient(AbstractClient): self, func: Callable, name: Optional[str] = None, - update: Optional[bool] = 
True, # TODO: actually use this tags: Optional[List[str]] = None, ) -> Tool: """ @@ -1278,7 +1279,6 @@ class RESTClient(AbstractClient): func (callable): The function to create a tool for. name: (str): Name of the tool (must be unique per-user.) tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. - update (bool, optional): Update the tool if it already exists. Defaults to True. Returns: tool (Tool): The created tool. @@ -1628,6 +1628,7 @@ class LocalClient(AbstractClient): system: Optional[str] = None, # tools tools: Optional[List[str]] = None, + tool_rules: Optional[List[BaseToolRule]] = None, include_base_tools: Optional[bool] = True, # metadata metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, @@ -1642,6 +1643,7 @@ class LocalClient(AbstractClient): memory (Memory): Memory configuration system (str): System configuration tools (List[str]): List of tools + tool_rules (Optional[List[BaseToolRule]]): List of tool rules include_base_tools (bool): Include base tools metadata (Dict): Metadata description (str): Description @@ -1663,7 +1665,7 @@ class LocalClient(AbstractClient): # add memory tools memory_functions = get_memory_functions(memory) for func_name, func in memory_functions.items(): - tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"], update=True) + tool = self.create_tool(func, name=func_name, tags=["memory", "letta-base"]) tool_names.append(tool.name) self.interface.clear() @@ -1680,6 +1682,7 @@ class LocalClient(AbstractClient): metadata_=metadata, memory=memory, tools=tool_names, + tool_rules=tool_rules, system=system, agent_type=agent_type, llm_config=llm_config if llm_config else self._default_llm_config, @@ -2255,8 +2258,8 @@ class LocalClient(AbstractClient): self, func, name: Optional[str] = None, - update: Optional[bool] = True, # TODO: actually use this tags: Optional[List[str]] = None, + description: Optional[str] = None, ) -> Tool: """ Create a tool. 
This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. @@ -2265,7 +2268,7 @@ class LocalClient(AbstractClient): func (callable): The function to create a tool for. name: (str): Name of the tool (must be unique per-user.) tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. - update (bool, optional): Update the tool if it already exists. Defaults to True. + description (str, optional): The description. Returns: tool (Tool): The created tool. @@ -2285,6 +2288,7 @@ class LocalClient(AbstractClient): source_code=source_code, name=name, tags=tags, + description=description, ), actor=self.user, ) diff --git a/letta/helpers/__init__.py b/letta/helpers/__init__.py new file mode 100644 index 00000000..62e8d709 --- /dev/null +++ b/letta/helpers/__init__.py @@ -0,0 +1 @@ +from letta.helpers.tool_rule_solver import ToolRulesSolver diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py new file mode 100644 index 00000000..aafaa745 --- /dev/null +++ b/letta/helpers/tool_rule_solver.py @@ -0,0 +1,115 @@ +import warnings +from typing import Dict, List, Optional, Set + +from pydantic import BaseModel, Field + +from letta.schemas.tool_rule import ( + BaseToolRule, + InitToolRule, + TerminalToolRule, + ToolRule, +) + + +class ToolRuleValidationError(Exception): + """Custom exception for tool rule validation errors in ToolRulesSolver.""" + + def __init__(self, message: str): + super().__init__(f"ToolRuleValidationError: {message}") + + +class ToolRulesSolver(BaseModel): + init_tool_rules: List[InitToolRule] = Field( + default_factory=list, description="Initial tool rules to be used at the start of tool execution." + ) + tool_rules: List[ToolRule] = Field( + default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions." 
+ ) + terminal_tool_rules: List[TerminalToolRule] = Field( + default_factory=list, description="Terminal tool rules that end the agent loop if called." + ) + last_tool_name: Optional[str] = Field(None, description="The most recent tool used, updated with each tool call.") + + def __init__(self, tool_rules: List[BaseToolRule], **kwargs): + super().__init__(**kwargs) + # Separate the provided tool rules into init, standard, and terminal categories + for rule in tool_rules: + if isinstance(rule, InitToolRule): + self.init_tool_rules.append(rule) + elif isinstance(rule, ToolRule): + self.tool_rules.append(rule) + elif isinstance(rule, TerminalToolRule): + self.terminal_tool_rules.append(rule) + + # Validate the tool rules to ensure they form a DAG + if not self.validate_tool_rules(): + raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.") + + def update_tool_usage(self, tool_name: str): + """Update the internal state to track the last tool called.""" + self.last_tool_name = tool_name + + def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]: + """Get a list of tool names allowed based on the last tool called.""" + if self.last_tool_name is None: + # Use initial tool rules if no tool has been called yet + return [rule.tool_name for rule in self.init_tool_rules] + else: + # Find a matching ToolRule for the last tool used + current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None) + + # Return children which must exist on ToolRule + if current_rule: + return current_rule.children + + # Default to empty if no rule matches + message = "User provided tool rules and execution state resolved to no more possible tool calls." 
+ if error_on_empty: + raise RuntimeError(message) + else: + warnings.warn(message) + return [] + + def is_terminal_tool(self, tool_name: str) -> bool: + """Check if the tool is defined as a terminal tool in the terminal tool rules.""" + return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules) + + def has_children_tools(self, tool_name): + """Check if the tool has children tools""" + return any(rule.tool_name == tool_name for rule in self.tool_rules) + + def validate_tool_rules(self) -> bool: + """ + Validate that the tool rules define a directed acyclic graph (DAG). + Returns True if valid (no cycles), otherwise False. + """ + # Build adjacency list for the tool graph + adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules} + + # Track visited nodes + visited: Set[str] = set() + path_stack: Set[str] = set() + + # Define DFS helper function + def dfs(tool_name: str) -> bool: + if tool_name in path_stack: + return False # Cycle detected + if tool_name in visited: + return True # Already validated + + # Mark the node as visited in the current path + path_stack.add(tool_name) + for child in adjacency_list.get(tool_name, []): + if not dfs(child): + return False # Cycle detected in DFS + path_stack.remove(tool_name) # Remove from current path + visited.add(tool_name) + return True + + # Run DFS from each tool in `tool_rules` + for rule in self.tool_rules: + if rule.tool_name not in visited: + if not dfs(rule.tool_name): + return False # Cycle found, invalid tool rules + + return True # No cycles, valid DAG diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index 2ebc7ae1..5048af74 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -16,9 +16,11 @@ def convert_to_structured_output(openai_function: dict) -> dict: See: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas """ + description = openai_function["description"] if "description" in 
openai_function else "" + structured_output = { "name": openai_function["name"], - "description": openai_function["description"], + "description": description, "strict": True, "parameters": {"type": "object", "properties": {}, "additionalProperties": False, "required": []}, } diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 7af35721..95f0e5ac 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -106,7 +106,7 @@ def create( messages: List[Message], user_id: Optional[str] = None, # option UUID to associate request with functions: Optional[list] = None, - functions_python: Optional[list] = None, + functions_python: Optional[dict] = None, function_call: str = "auto", # hint first_message: bool = False, @@ -140,7 +140,6 @@ def create( raise ValueError(f"OpenAI key is missing from letta config file") data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens) - if stream: # Client requested token streaming data.stream = True assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance( diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index bb32662a..231cde29 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -530,7 +530,12 @@ def openai_chat_completions_request( data.pop("tools") data.pop("tool_choice", None) # extra safe, should exist always (default="auto") + if "tools" in data: + for tool in data["tools"]: + tool["function"] = convert_to_structured_output(tool["function"]) + response_json = make_post_request(url, headers, data) + return ChatCompletionResponse(**response_json) diff --git a/letta/metadata.py b/letta/metadata.py index 9c2761d2..d1b9a025 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -30,6 +30,12 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction 
from letta.schemas.source import Source +from letta.schemas.tool_rule import ( + BaseToolRule, + InitToolRule, + TerminalToolRule, + ToolRule, +) from letta.schemas.user import User from letta.settings import settings from letta.utils import enforce_types, get_utc_time, printd @@ -196,6 +202,41 @@ def generate_api_key(prefix="sk-", length=51) -> str: return new_key +class ToolRulesColumn(TypeDecorator): + """Custom type for storing a list of ToolRules as JSON""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value: List[BaseToolRule], dialect): + """Convert a list of ToolRules to JSON-serializable format.""" + if value: + return [rule.model_dump() for rule in value] + return value + + def process_result_value(self, value, dialect) -> List[BaseToolRule]: + """Convert JSON back to a list of ToolRules.""" + if value: + return [self.deserialize_tool_rule(rule_data) for rule_data in value] + return value + + @staticmethod + def deserialize_tool_rule(data: dict) -> BaseToolRule: + """Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'.""" + rule_type = data.get("type") # Remove 'type' field if it exists since it is a class var + if rule_type == "InitToolRule": + return InitToolRule(**data) + elif rule_type == "TerminalToolRule": + return TerminalToolRule(**data) + elif rule_type == "ToolRule": + return ToolRule(**data) + else: + raise ValueError(f"Unknown tool rule type: {rule_type}") + + class AgentModel(Base): """Defines data model for storing Passages (consisting of text, embedding)""" @@ -212,7 +253,6 @@ class AgentModel(Base): message_ids = Column(JSON) memory = Column(JSON) system = Column(String) - tools = Column(JSON) # configs agent_type = Column(String) @@ -224,6 +264,7 @@ class AgentModel(Base): # tools tools = Column(JSON) + tool_rules = Column(ToolRulesColumn) Index(__tablename__ + "_idx_user", user_id), @@ -241,6 +282,7 @@ 
class AgentModel(Base): memory=Memory.load(self.memory), # load dictionary system=self.system, tools=self.tools, + tool_rules=self.tool_rules, agent_type=self.agent_type, llm_config=self.llm_config, embedding_config=self.embedding_config, diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 243c894d..2fa93939 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -11,6 +11,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import UsageStatistics +from letta.schemas.tool_rule import BaseToolRule class BaseAgent(LettaBase, validate_assignment=True): @@ -61,6 +62,9 @@ class AgentState(BaseAgent, validate_assignment=True): # tools tools: List[str] = Field(..., description="The tools used by the agent.") + # tool rules + tool_rules: List[BaseToolRule] = Field(..., description="The list of tool rules.") + # system prompt system: str = Field(..., description="The system prompt used by the agent.") @@ -104,6 +108,7 @@ class CreateAgent(BaseAgent): message_ids: Optional[List[uuid.UUID]] = Field(None, description="The ids of the messages in the agent's in-context memory.") memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.") tools: Optional[List[str]] = Field(None, description="The tools used by the agent.") + tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.") system: Optional[str] = Field(None, description="The system prompt used by the agent.") agent_type: Optional[AgentType] = Field(None, description="The type of agent.") llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") @@ -156,8 +161,3 @@ class AgentStepResponse(BaseModel): ..., description="Whether the agent step ended because the in-context memory is near its limit." 
) usage: UsageStatistics = Field(..., description="Usage statistics of the LLM call during the agent's step.") - - -class RemoveToolsFromAgent(BaseModel): - agent_id: str = Field(..., description="The id of the agent.") - tool_ids: Optional[List[str]] = Field(None, description="The tools to be removed from the agent.") diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py new file mode 100644 index 00000000..9fa20dd8 --- /dev/null +++ b/letta/schemas/tool_rule.py @@ -0,0 +1,25 @@ +from typing import List + +from pydantic import Field + +from letta.schemas.letta_base import LettaBase + + +class BaseToolRule(LettaBase): + __id_prefix__ = "tool_rule" + tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.") + + +class ToolRule(BaseToolRule): + type: str = Field("ToolRule") + children: List[str] = Field(..., description="The children tools that can be invoked.") + + +class InitToolRule(BaseToolRule): + type: str = Field("InitToolRule") + """Represents the initial tool rule configuration.""" + + +class TerminalToolRule(BaseToolRule): + type: str = Field("TerminalToolRule") + """Represents a terminal tool rule configuration where if this tool gets called, it must end the agent loop.""" diff --git a/letta/server/server.py b/letta/server/server.py index 83ba2a96..52660064 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -842,6 +842,7 @@ class SyncServer(Server): name=request.name, user_id=user_id, tools=request.tools if request.tools else [], + tool_rules=request.tool_rules if request.tool_rules else [], agent_type=request.agent_type or AgentType.memgpt_agent, llm_config=llm_config, embedding_config=embedding_config, diff --git a/tests/configs/llm_model_configs/gpt-4.json b/tests/configs/llm_model_configs/openai-gpt-4o.json similarity index 56% rename from tests/configs/llm_model_configs/gpt-4.json rename to tests/configs/llm_model_configs/openai-gpt-4o.json index 
dedc8cec..8e2cd44a 100644 --- a/tests/configs/llm_model_configs/gpt-4.json +++ b/tests/configs/llm_model_configs/openai-gpt-4o.json @@ -1,8 +1,7 @@ { "context_window": 8192, - "model": "gpt-4", + "model": "gpt-4o", "model_endpoint_type": "openai", "model_endpoint": "https://api.openai.com/v1", - "model_wrapper": null, - "put_inner_thoughts_in_kwargs": false + "model_wrapper": null } diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 23f27e9b..49bb0dc1 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -4,6 +4,7 @@ import uuid from typing import Callable, List, Optional, Union from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs +from letta.schemas.tool_rule import BaseToolRule logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -61,6 +62,8 @@ def setup_agent( memory_human_str: str = get_human_text(DEFAULT_HUMAN), memory_persona_str: str = get_persona_text(DEFAULT_PERSONA), tools: Optional[List[str]] = None, + tool_rules: Optional[List[BaseToolRule]] = None, + agent_uuid: str = agent_uuid, ) -> AgentState: config_data = json.load(open(filename, "r")) llm_config = LLMConfig(**config_data) @@ -73,7 +76,9 @@ def setup_agent( config.save() memory = ChatMemory(human=memory_human_str, persona=memory_persona_str) - agent_state = client.create_agent(name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tools=tools) + agent_state = client.create_agent( + name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tools=tools, tool_rules=tool_rules + ) return agent_state diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py deleted file mode 100644 index 5adcab25..00000000 --- a/tests/test_agent_function_update.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest - -from letta import create_client -from letta.schemas.message import Message -from letta.utils import 
assistant_function_to_tool, json_dumps -from tests.utils import wipe_config - - -def hello_world(self) -> str: - """Test function for agent to gain access to - - Returns: - str: A message for the world - """ - return "hello, world!" - - -@pytest.fixture(scope="module") -def agent(): - """Create a test agent that we can call functions on""" - wipe_config() - global client - # create letta client - client = create_client() - - agent_state = client.create_agent() - - return client.server._get_or_load_agent(agent_id=agent_state.id) - - -@pytest.fixture(scope="module") -def ai_function_call(): - return Message( - **assistant_function_to_tool( - { - "role": "assistant", - "text": "I will now call hello world", # TODO: change to `content` once `Message` is updated - "function_call": { - "name": "hello_world", - "arguments": json_dumps({}), - }, - } - ) - ) diff --git a/tests/test_agent_tool_graph.py b/tests/test_agent_tool_graph.py new file mode 100644 index 00000000..2988949d --- /dev/null +++ b/tests/test_agent_tool_graph.py @@ -0,0 +1,130 @@ +import os +import uuid + +import pytest + +from letta import create_client +from letta.schemas.letta_message import FunctionCallMessage +from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule +from tests.helpers.endpoints_helper import ( + assert_invoked_function_call, + assert_invoked_send_message_with_keyword, + assert_sanity_checks, + setup_agent, +) +from tests.helpers.utils import cleanup +from tests.test_endpoints import llm_config_dir + +# Generate uuid for agent name for this example +namespace = uuid.NAMESPACE_DNS +agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph")) +config_file = os.path.join(llm_config_dir, "openai-gpt-4o.json") + +"""Contrived tools for this test case""" + + +def first_secret_word(self: "Agent"): + """ + Call this to retrieve the first secret word, which you will need for the second_secret_word function. 
+ """ + return "v0iq020i0g" + + +def second_secret_word(self: "Agent", prev_secret_word: str): + """ + Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error. + + Args: + prev_secret_word (str): The secret word retrieved from calling first_secret_word. + """ + if prev_secret_word != "v0iq020i0g": + raise RuntimeError(f"Expected secret {"v0iq020i0g"}, got {prev_secret_word}") + + return "4rwp2b4gxq" + + +def third_secret_word(self: "Agent", prev_secret_word: str): + """ + Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error. + + Args: + prev_secret_word (str): The secret word retrieved from calling second_secret_word. + """ + if prev_secret_word != "4rwp2b4gxq": + raise RuntimeError(f"Expected secret {"4rwp2b4gxq"}, got {prev_secret_word}") + + return "hj2hwibbqm" + + +def fourth_secret_word(self: "Agent", prev_secret_word: str): + """ + Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error. + + Args: + prev_secret_word (str): The secret word retrieved from calling third_secret_word. + """ + if prev_secret_word != "hj2hwibbqm": + raise RuntimeError(f"Expected secret {"hj2hwibbqm"}, got {prev_secret_word}") + + return "banana" + + +def auto_error(self: "Agent"): + """ + If you call this function, it will throw an error automatically. 
+    """
+    raise RuntimeError("This should never be called.")
+
+
+@pytest.mark.timeout(60)  # Sets a 60-second timeout for the test since this could loop infinitely
+def test_single_path_agent_tool_call_graph():
+    client = create_client()
+    cleanup(client=client, agent_uuid=agent_uuid)
+
+    # Add tools
+    t1 = client.create_tool(first_secret_word)
+    t2 = client.create_tool(second_secret_word)
+    t3 = client.create_tool(third_secret_word)
+    t4 = client.create_tool(fourth_secret_word)
+    t_err = client.create_tool(auto_error)
+    tools = [t1, t2, t3, t4, t_err]
+
+    # Make tool rules
+    tool_rules = [
+        InitToolRule(tool_name="first_secret_word"),
+        ToolRule(tool_name="first_secret_word", children=["second_secret_word"]),
+        ToolRule(tool_name="second_secret_word", children=["third_secret_word"]),
+        ToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]),
+        ToolRule(tool_name="fourth_secret_word", children=["send_message"]),
+        TerminalToolRule(tool_name="send_message"),
+    ]
+
+    # Make agent state
+    agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tools=[t.name for t in tools], tool_rules=tool_rules)
+    response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?")
+
+    # Make checks
+    assert_sanity_checks(response)
+
+    # Assert the tools were called
+    assert_invoked_function_call(response.messages, "first_secret_word")
+    assert_invoked_function_call(response.messages, "second_secret_word")
+    assert_invoked_function_call(response.messages, "third_secret_word")
+    assert_invoked_function_call(response.messages, "fourth_secret_word")
+
+    # Check ordering of tool calls
+    tool_names = [t.name for t in [t1, t2, t3, t4]]
+    tool_names += ["send_message"]
+    for m in response.messages:
+        if isinstance(m, FunctionCallMessage):
+            # Check that it's equal to the first one
+            assert m.function_call.name == tool_names[0]
+
+            # Pop out first one
+            tool_names = tool_names[1:]
+
+    # Check final send message contains "banana"
+
assert_invoked_send_message_with_keyword(response.messages, "banana") + + print(f"Got successful response from client: \n\n{response}") + cleanup(client=client, agent_uuid=agent_uuid) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 08812311..76494c09 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -59,7 +59,7 @@ def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4): # OPENAI TESTS # ====================================================================================================================== def test_openai_gpt_4_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "gpt-4.json") + filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_first_response_is_valid_for_llm_endpoint(filename) # Log out successful response print(f"Got successful response from client: \n\n{response}") @@ -67,42 +67,42 @@ def test_openai_gpt_4_returns_valid_first_message(): def test_openai_gpt_4_returns_keyword(): keyword = "banana" - filename = os.path.join(llm_config_dir, "gpt-4.json") + filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_response_contains_keyword(filename, keyword=keyword) # Log out successful response print(f"Got successful response from client: \n\n{response}") def test_openai_gpt_4_uses_external_tool(): - filename = os.path.join(llm_config_dir, "gpt-4.json") + filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_uses_external_tool(filename) # Log out successful response print(f"Got successful response from client: \n\n{response}") def test_openai_gpt_4_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "gpt-4.json") + filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_recall_chat_memory(filename) # Log out successful response print(f"Got successful response from client: \n\n{response}") def test_openai_gpt_4_archival_memory_retrieval(): - filename 
= os.path.join(llm_config_dir, "gpt-4.json") + filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_archival_memory_retrieval(filename) # Log out successful response print(f"Got successful response from client: \n\n{response}") def test_openai_gpt_4_archival_memory_insert(): - filename = os.path.join(llm_config_dir, "gpt-4.json") + filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_archival_memory_insert(filename) # Log out successful response print(f"Got successful response from client: \n\n{response}") def test_openai_gpt_4_edit_core_memory(): - filename = os.path.join(llm_config_dir, "gpt-4.json") + filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_edit_core_memory(filename) # Log out successful response print(f"Got successful response from client: \n\n{response}") diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py new file mode 100644 index 00000000..c570ab6d --- /dev/null +++ b/tests/test_tool_rule_solver.py @@ -0,0 +1,128 @@ +import warnings + +import pytest + +from letta.helpers import ToolRulesSolver +from letta.helpers.tool_rule_solver import ToolRuleValidationError +from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule + +# Constants for tool names used in the tests +START_TOOL = "start_tool" +PREP_TOOL = "prep_tool" +NEXT_TOOL = "next_tool" +HELPER_TOOL = "helper_tool" +FINAL_TOOL = "final_tool" +END_TOOL = "end_tool" +UNRECOGNIZED_TOOL = "unrecognized_tool" + + +def test_get_allowed_tool_names_with_init_rules(): + # Setup: Initial tool rule configuration + init_rule_1 = InitToolRule(tool_name=START_TOOL) + init_rule_2 = InitToolRule(tool_name=PREP_TOOL) + solver = ToolRulesSolver(init_tool_rules=[init_rule_1, init_rule_2], tool_rules=[], terminal_tool_rules=[]) + + # Action: Get allowed tool names when no tool has been called + allowed_tools = solver.get_allowed_tool_names() + + # Assert: Both init tools 
should be allowed initially + assert allowed_tools == [START_TOOL, PREP_TOOL], "Should allow only InitToolRule tools at the start" + + +def test_get_allowed_tool_names_with_subsequent_rule(): + # Setup: Tool rule sequence + init_rule = InitToolRule(tool_name=START_TOOL) + rule_1 = ToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[]) + + # Action: Update usage and get allowed tools + solver.update_tool_usage(START_TOOL) + allowed_tools = solver.get_allowed_tool_names() + + # Assert: Only children of "start_tool" should be allowed + assert allowed_tools == [NEXT_TOOL, HELPER_TOOL], "Should allow only children of the last tool used" + + +def test_is_terminal_tool(): + # Setup: Terminal tool rule configuration + init_rule = InitToolRule(tool_name=START_TOOL) + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[terminal_rule]) + + # Action & Assert: Verify terminal and non-terminal tools + assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool" + assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool" + + +def test_get_allowed_tool_names_no_matching_rule_warning(): + # Setup: Tool rules with no matching rule for the last tool + init_rule = InitToolRule(tool_name=START_TOOL) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) + + # Action: Set last tool to an unrecognized tool and check warnings + solver.update_tool_usage(UNRECOGNIZED_TOOL) + + with warnings.catch_warnings(record=True) as w: + allowed_tools = solver.get_allowed_tool_names() + + # Assert: Expecting a warning and an empty list of allowed tools + assert len(w) == 1, "Expected a warning for no matching rule" + assert "resolved to no more possible tool calls" in 
str(w[-1].message) + assert allowed_tools == [], "Should return an empty list if no matching rule" + + +def test_get_allowed_tool_names_no_matching_rule_error(): + # Setup: Tool rules with no matching rule for the last tool + init_rule = InitToolRule(tool_name=START_TOOL) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) + + # Action & Assert: Set last tool to an unrecognized tool and expect RuntimeError when error_on_empty=True + solver.update_tool_usage(UNRECOGNIZED_TOOL) + with pytest.raises(RuntimeError, match="resolved to no more possible tool calls"): + solver.get_allowed_tool_names(error_on_empty=True) + + +def test_update_tool_usage_and_get_allowed_tool_names_combined(): + # Setup: More complex rule chaining + init_rule = InitToolRule(tool_name=START_TOOL) + rule_1 = ToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) + rule_2 = ToolRule(tool_name=NEXT_TOOL, children=[FINAL_TOOL]) + terminal_rule = TerminalToolRule(tool_name=FINAL_TOOL) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule]) + + # Step 1: Initially allowed tools + assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'" + + # Step 2: After using 'start_tool' + solver.update_tool_usage(START_TOOL) + assert solver.get_allowed_tool_names() == [NEXT_TOOL], "After 'start_tool', should allow 'next_tool'" + + # Step 3: After using 'next_tool' + solver.update_tool_usage(NEXT_TOOL) + assert solver.get_allowed_tool_names() == [FINAL_TOOL], "After 'next_tool', should allow 'final_tool'" + + # Step 4: 'final_tool' should be terminal + assert solver.is_terminal_tool(FINAL_TOOL) is True, "Should recognize 'final_tool' as terminal" + + +def test_tool_rules_with_cycle_detection(): + # Setup: Define tool rules with both connected, disconnected nodes and a cycle + init_rule = InitToolRule(tool_name=START_TOOL) + rule_1 = ToolRule(tool_name=START_TOOL, 
children=[NEXT_TOOL]) + rule_2 = ToolRule(tool_name=NEXT_TOOL, children=[HELPER_TOOL]) + rule_3 = ToolRule(tool_name=HELPER_TOOL, children=[START_TOOL]) # This creates a cycle: start -> next -> helper -> start + rule_4 = ToolRule(tool_name=FINAL_TOOL, children=[END_TOOL]) # Disconnected rule, no cycle here + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + + # Action & Assert: Attempt to create the ToolRulesSolver with a cycle should raise ValidationError + with pytest.raises(ToolRuleValidationError, match="Tool rules contain cycles"): + ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule]) + + # Extra setup: Define tool rules without a cycle but with hanging nodes + rule_5 = ToolRule(tool_name=PREP_TOOL, children=[FINAL_TOOL]) # Hanging node with no connection to start_tool + + # Assert that a configuration without cycles does not raise an error + try: + ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_4, rule_5, terminal_rule]) + except ToolRuleValidationError: + pytest.fail("ToolRulesSolver raised ValidationError unexpectedly on a valid DAG with hanging nodes") diff --git a/tests/test_tools.py b/tests/test_tools.py index c5a6d2ec..f987afca 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -140,7 +140,7 @@ def test_create_agent_tool(client): return None # TODO: test attaching and using function on agent - tool = client.create_tool(core_memory_clear, tags=["extras"], update=True) + tool = client.create_tool(core_memory_clear, tags=["extras"]) print(f"Created tool", tool.name) # create agent with tool