From a274261b649c4ca439f34a40ce240e2c580c62e2 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 6 Nov 2024 19:14:13 -0800 Subject: [PATCH] chore: Add tool rules example (#1998) Co-authored-by: Sarah Wooders --- examples/docs/agent_advanced.py | 11 ++- examples/docs/agent_basic.py | 15 +--- examples/docs/tools.py | 33 ++++---- examples/tool_rule_usage.py | 132 ++++++++++++++++++++++++++++++++ letta/agent.py | 11 ++- letta/services/tool_manager.py | 8 +- tests/test_managers.py | 2 +- 7 files changed, 173 insertions(+), 39 deletions(-) create mode 100644 examples/tool_rule_usage.py diff --git a/examples/docs/agent_advanced.py b/examples/docs/agent_advanced.py index a42f5722..311aa92a 100644 --- a/examples/docs/agent_advanced.py +++ b/examples/docs/agent_advanced.py @@ -5,14 +5,18 @@ client = create_client() # create a new agent agent_state = client.create_agent( + # agent's name (unique per-user, autogenerated if not provided) name="agent_name", + # in-context memory representation with human/persona blocks memory=ChatMemory(human="Name: Sarah", persona="You are a helpful assistant that loves emojis"), + # LLM model & endpoint configuration llm_config=LLMConfig( model="gpt-4", model_endpoint_type="openai", model_endpoint="https://api.openai.com/v1", - context_window=8000, + context_window=8000, # set to <= max context window ), + # embedding model & endpoint configuration (cannot be changed) embedding_config=EmbeddingConfig( embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", @@ -20,9 +24,12 @@ agent_state = client.create_agent( embedding_dim=1536, embedding_chunk_size=300, ), + # system instructions for the agent (defaults to `memgpt_chat`) system=gpt_system.get_system_text("memgpt_chat"), - tools=[], + # whether to include base letta tools (default: True) include_base_tools=True, + # list of additional tools (by name) to add to the agent + tools=[], ) print(f"Created agent with name {agent_state.name} and unique ID 
{agent_state.id}") diff --git a/examples/docs/agent_basic.py b/examples/docs/agent_basic.py index 6f2195e7..d472f39d 100644 --- a/examples/docs/agent_basic.py +++ b/examples/docs/agent_basic.py @@ -3,19 +3,8 @@ from letta import EmbeddingConfig, LLMConfig, create_client client = create_client() # set automatic defaults for LLM/embedding config -client.set_default_llm_config( - LLMConfig(model="gpt-4o-mini", model_endpoint_type="openai", model_endpoint="https://api.openai.com/v1", context_window=128000) -) -client.set_default_embedding_config( - EmbeddingConfig( - embedding_endpoint_type="openai", - embedding_endpoint="https://api.openai.com/v1", - embedding_model="text-embedding-ada-002", - embedding_dim=1536, - embedding_chunk_size=300, - ) -) - +client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4")) +client.set_default_embedding_config(EmbeddingConfig.default_config(model_name="text-embedding-ada-002")) # create a new agent agent_state = client.create_agent() diff --git a/examples/docs/tools.py b/examples/docs/tools.py index 14f3a0e0..382e4520 100644 --- a/examples/docs/tools.py +++ b/examples/docs/tools.py @@ -1,23 +1,14 @@ from letta import EmbeddingConfig, LLMConfig, create_client +from letta.schemas.tool_rule import TerminalToolRule client = create_client() # set automatic defaults for LLM/embedding config -client.set_default_llm_config( - LLMConfig(model="gpt-4", model_endpoint_type="openai", model_endpoint="https://api.openai.com/v1", context_window=8000) -) -client.set_default_embedding_config( - EmbeddingConfig( - embedding_endpoint_type="openai", - embedding_endpoint="https://api.openai.com/v1", - embedding_model="text-embedding-ada-002", - embedding_dim=1536, - embedding_chunk_size=300, - ) -) +client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4")) +client.set_default_embedding_config(EmbeddingConfig.default_config(model_name="text-embedding-ada-002")) # define a function with a docstring -def roll_d20() -> str: 
+def roll_d20(self) -> str: """ Simulate the roll of a 20-sided die (d20). @@ -38,10 +29,22 @@ def roll_d20() -> str: return output_string -tool = client.create_tool(roll_d20, name="roll_dice") +# create a tool from the function +tool = client.create_tool(roll_d20) +print(f"Created tool with name {tool.name}") # create a new agent -agent_state = client.create_agent(tools=[tool.name]) +agent_state = client.create_agent( + # create the agent with an additional tool + tools=[tool.name], + # add tool rules that terminate execution after specific tools + tool_rules=[ + # exit after roll_d20 is called + TerminalToolRule(tool_name=tool.name), + # exit after send_message is called (default behavior) + TerminalToolRule(tool_name="send_message"), + ], +) print(f"Created agent with name {agent_state.name} with tools {agent_state.tools}") # Message an agent diff --git a/examples/tool_rule_usage.py b/examples/tool_rule_usage.py new file mode 100644 index 00000000..b408b1d0 --- /dev/null +++ b/examples/tool_rule_usage.py @@ -0,0 +1,132 @@ +import os +import uuid + +from letta import create_client +from letta.schemas.letta_message import FunctionCallMessage +from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule +from tests.helpers.endpoints_helper import ( + assert_invoked_send_message_with_keyword, + setup_agent, +) +from tests.helpers.utils import cleanup +from tests.test_endpoints import llm_config_dir + +""" +This example shows how you can constrain tool calls in your agent. + +Please note that this currently only works reliably for models with Structured Outputs (e.g. gpt-4o). + +Start by downloading the dependencies. 
+```
+poetry install --all-extras
+```
+"""
+
+# Tools for this example
+# Generate uuid for agent name for this example
+namespace = uuid.NAMESPACE_DNS
+agent_uuid = str(uuid.uuid5(namespace, "agent_tool_graph"))
+config_file = os.path.join(llm_config_dir, "openai-gpt-4o.json")
+
+"""Contrived tools for this test case"""
+
+
+def first_secret_word(self: "Agent"):
+    """
+    Call this to retrieve the first secret word, which you will need for the second_secret_word function.
+    """
+    return "v0iq020i0g"
+
+
+def second_secret_word(self: "Agent", prev_secret_word: str):
+    """
+    Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error.
+
+    Args:
+        prev_secret_word (str): The secret word retrieved from calling first_secret_word.
+    """
+    if prev_secret_word != "v0iq020i0g":
+        raise RuntimeError(f"Expected secret v0iq020i0g, got {prev_secret_word}")
+
+    return "4rwp2b4gxq"
+
+
+def third_secret_word(self: "Agent", prev_secret_word: str):
+    """
+    Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error.
+
+    Args:
+        prev_secret_word (str): The secret word retrieved from calling second_secret_word.
+    """
+    if prev_secret_word != "4rwp2b4gxq":
+        raise RuntimeError(f"Expected secret 4rwp2b4gxq, got {prev_secret_word}")
+
+    return "hj2hwibbqm"
+
+
+def fourth_secret_word(self: "Agent", prev_secret_word: str):
+    """
+    Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error.
+
+    Args:
+        prev_secret_word (str): The secret word retrieved from calling third_secret_word.
+    """
+    if prev_secret_word != "hj2hwibbqm":
+        raise RuntimeError(f"Expected secret hj2hwibbqm, got {prev_secret_word}")
+
+    return "banana"
+
+
+def auto_error(self: "Agent"):
+    """
+    If you call this function, it will throw an error automatically.
+    """
+    raise RuntimeError("This should never be called.")
+
+
+def main():
+    # 1. Set up the client
+    client = create_client()
+    cleanup(client=client, agent_uuid=agent_uuid)
+
+    # 2. Add all the tools to the client
+    functions = [first_secret_word, second_secret_word, third_secret_word, fourth_secret_word, auto_error]
+    tools = []
+    for func in functions:
+        tool = client.create_tool(func)
+        tools.append(tool)
+    tool_names = [t.name for t in tools[:-1]]
+
+    # 3. Create the tool rules. It must be called in this order, or there will be an error thrown.
+    tool_rules = [
+        InitToolRule(tool_name="first_secret_word"),
+        ToolRule(tool_name="first_secret_word", children=["second_secret_word"]),
+        ToolRule(tool_name="second_secret_word", children=["third_secret_word"]),
+        ToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]),
+        ToolRule(tool_name="fourth_secret_word", children=["send_message"]),
+        TerminalToolRule(tool_name="send_message"),
+    ]
+
+    # 4. Create the agent
+    agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tools=[t.name for t in tools], tool_rules=tool_rules)
+
+    # 5. Ask for the final secret word
+    response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?")
+
+    # 6.
Here, we thoroughly check the correctness of the response + tool_names += ["send_message"] # Add send message because we expect this to be called at the end + for m in response.messages: + if isinstance(m, FunctionCallMessage): + # Check that it's equal to the first one + assert m.function_call.name == tool_names[0] + # Pop out first one + tool_names = tool_names[1:] + + # Check final send message contains "banana" + assert_invoked_send_message_with_keyword(response.messages, "banana") + print(f"Got successful response from client: \n\n{response}") + cleanup(client=client, agent_uuid=agent_uuid) + + +if __name__ == "__main__": + main() diff --git a/letta/agent.py b/letta/agent.py index 85daaa51..f49351d7 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -253,7 +253,14 @@ class Agent(BaseAgent): if agent_state.tool_rules is None: agent_state.tool_rules = [] - agent_state.tool_rules.append(TerminalToolRule(tool_name="send_message")) + # Define the rule to add + send_message_terminal_rule = TerminalToolRule(tool_name="send_message") + # Check if an equivalent rule is already present + if not any( + isinstance(rule, TerminalToolRule) and rule.tool_name == send_message_terminal_rule.tool_name for rule in agent_state.tool_rules + ): + agent_state.tool_rules.append(send_message_terminal_rule) + self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules) # gpt-4, gpt-3.5-turbo, ... 
@@ -395,7 +402,6 @@ class Agent(BaseAgent): exec(tool.module, env) else: exec(tool.source_code, env) - self.functions_python[tool.json_schema["name"]] = env[tool.json_schema["name"]] self.functions.append(tool.json_schema) except Exception as e: @@ -787,7 +793,6 @@ class Agent(BaseAgent): # Update ToolRulesSolver state with last called function self.tool_rules_solver.update_tool_usage(function_name) - # Update heartbeat request according to provided tool rules if self.tool_rules_solver.has_children_tools(function_name): heartbeat_request = True diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 778e5307..c60b8ee1 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -35,9 +35,7 @@ class ToolManager: def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" # Derive json_schema - derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema( - source_code=pydantic_tool.source_code, name=pydantic_tool.name - ) + derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema(source_code=pydantic_tool.source_code) derived_name = pydantic_tool.name or derived_json_schema["name"] try: @@ -120,8 +118,8 @@ class ToolManager: if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): pydantic_tool = tool.to_pydantic() - name = update_data["name"] if "name" in update_data.keys() else None - new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code, name=name) + update_data["name"] if "name" in update_data.keys() else None + new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code) tool.json_schema = new_schema diff --git a/tests/test_managers.py b/tests/test_managers.py index 6bfbf548..64c6b3be 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -367,7 +367,7 @@ def 
test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_ og_json_schema = tool_fixture["tool_create"].json_schema source_code = parse_source_code(counter_tool) - name = "test_function_name_explicit" + name = "counter_tool" # Create a ToolUpdate object to modify the tool's source_code tool_update = ToolUpdate(name=name, source_code=source_code)