From 640d4c0a2287b389715abd60106e2aa425e9205a Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Wed, 6 Nov 2024 14:06:10 -0800 Subject: [PATCH] fix: fix issue with linking tools and adding new tools (#1988) Co-authored-by: Matt Zhou --- letta/agent.py | 3 +-- letta/client/client.py | 3 +++ letta/functions/functions.py | 2 +- letta/schemas/tool.py | 6 ++++- letta/server/rest_api/routers/v1/tools.py | 2 +- letta/server/server.py | 28 +++++++++++++++-------- letta/services/tool_manager.py | 15 ++++-------- tests/test_managers.py | 3 +-- 8 files changed, 35 insertions(+), 27 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 03056ce2..85daaa51 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -3,7 +3,6 @@ import inspect import traceback import warnings from abc import ABC, abstractmethod -from lib2to3.fixer_util import is_list from typing import List, Literal, Optional, Tuple, Union from tqdm import tqdm @@ -252,7 +251,7 @@ class Agent(BaseAgent): warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.") # add default rule for having send_message be a terminal tool - if not is_list(agent_state.tool_rules): + if agent_state.tool_rules is None: agent_state.tool_rules = [] agent_state.tool_rules.append(TerminalToolRule(tool_name="send_message")) self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules) diff --git a/letta/client/client.py b/letta/client/client.py index 128ab4de..840c85b9 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -77,6 +77,7 @@ class AbstractClient(object): memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)), system: Optional[str] = None, tools: Optional[List[str]] = None, + tool_rules: Optional[List[BaseToolRule]] = None, include_base_tools: Optional[bool] = True, metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, description: Optional[str] = None, @@ -372,6 +373,7 @@ class RESTClient(AbstractClient): system: Optional[str] = None, # tools tools: Optional[List[str]] = None, + tool_rules: Optional[List[BaseToolRule]] = None, include_base_tools: Optional[bool] = True, # metadata metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, @@ -425,6 +427,7 @@ class RESTClient(AbstractClient): metadata_=metadata, memory=memory, tools=tool_names, + tool_rules=tool_rules, system=system, agent_type=agent_type, llm_config=llm_config if llm_config else self._default_llm_config, diff --git a/letta/functions/functions.py b/letta/functions/functions.py index cc2fc358..0d4ba82a 100644 --- a/letta/functions/functions.py +++ b/letta/functions/functions.py @@ -9,7 +9,7 @@ from letta.constants import CLI_WARNING_PREFIX from letta.functions.schema_generator import generate_schema -def derive_openai_json_schema(source_code: str, name: Optional[str]) -> dict: +def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> dict: # auto-generate openai schema try: # Define a custom environment with necessary imports diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 34733a18..f98ef813 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -68,7 +68,7 @@ class ToolCreate(LettaBase): tags: List[str] = Field([], description="Metadata tags.") module: Optional[str] = Field(None, description="The source code of the function.") source_code: str = Field(..., description="The source code of the function.") - source_type: str = Field(..., description="The source type of the function.") + source_type: str = Field("python", description="The source type of the function.") json_schema: Optional[Dict] = Field( None, description="The JSON schema of the function (auto-generated from source_code if not provided)" ) @@ -216,3 +216,7 @@ class ToolUpdate(LettaBase): json_schema: Optional[Dict] = Field( None, description="The JSON schema of the function (auto-generated from source_code if not provided)" ) + + class Config: + extra = "ignore" # Allows extra fields without validation errors + # TODO: Remove this, and clean usage of ToolUpdate everywhere else diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 22f3dc03..117ce38c 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -104,7 +104,7 @@ def update_tool( Update an existing tool """ actor = server.get_user_or_default(user_id=user_id) - return server.tool_manager.update_tool_by_id(tool_id, actor.id, request) + return server.tool_manager.update_tool_by_id(tool_id=tool_id, tool_update=request, actor=actor) @router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools") diff --git a/letta/server/server.py b/letta/server/server.py index 55b46135..03272a77 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -894,6 +894,7 @@ class SyncServer(Server): assert isinstance(agent.agent_state.memory, Memory), f"Invalid memory type: {type(agent_state.memory)}" # return AgentState + return agent.agent_state def update_agent( @@ -935,17 +936,26 @@ class SyncServer(Server): # Replace tools and also re-link # (1) get tools + make sure they exist - tool_objs = [] - for tool_name in request.tools: - tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) - assert tool_obj, f"Tool {tool_name} does not exist" - tool_objs.append(tool_obj) + # Current and target tools as sets of tool names + current_tools = set(letta_agent.agent_state.tools) + target_tools = set(request.tools) - # (2) replace the list of tool names ("ids") inside the agent state - letta_agent.agent_state.tools = request.tools + # Calculate tools to add and remove + tools_to_add = target_tools - current_tools + tools_to_remove = current_tools - target_tools - # (3) then attempt to link the tools modules - letta_agent.link_tools(tool_objs) + # Fetch tool objects for those to add and remove + tools_to_add = [self.tool_manager.get_tool_by_name(tool_name=tool, actor=actor) for tool in tools_to_add] + tools_to_remove = [self.tool_manager.get_tool_by_name(tool_name=tool, actor=actor) for tool in tools_to_remove] + + # update agent tool list + for tool in tools_to_remove: + self.remove_tool_from_agent(agent_id=request.id, tool_id=tool.id, user_id=actor.id) + for tool in tools_to_add: + self.add_tool_to_agent(agent_id=request.id, tool_id=tool.id, user_id=actor.id) + + # reload agent + letta_agent = self._get_or_load_agent(agent_id=request.id) # configs if request.llm_config: diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 1b85e316..94c73188 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -105,7 +105,7 @@ class ToolManager: return [tool.to_pydantic() for tool in tools] @enforce_types - def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> None: + def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool: """Update a tool by its ID with the given ToolUpdate object.""" with self.session_maker() as session: # Fetch the tool by ID @@ -116,24 +116,17 @@ class ToolManager: for key, value in update_data.items(): setattr(tool, key, value) - # If source code is changed and a new json_schema is not provided, we want to auto-refresh the name and schema - # CAUTION: This will override any name/schema values the user passed in + # If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): pydantic_tool = tool.to_pydantic() - # Decide whether or not to reset name - # If name was not explicitly passed in as part of the update, then we auto-generate a new name based on source code - name = None - if "name" in update_data.keys(): - name = update_data["name"] + name = update_data["name"] if "name" in update_data.keys() else None new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code, name=name) - # The name will either be set (if explicit) or autogenerated from the source code - tool.name = new_schema["name"] tool.json_schema = new_schema # Save the updated tool to the database - tool.update(db_session=session, actor=actor) + return tool.update(db_session=session, actor=actor) @enforce_types def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None: diff --git a/tests/test_managers.py b/tests/test_managers.py index 557c00c9..82c256d1 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -285,9 +285,8 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t assert updated_tool.source_code == source_code assert updated_tool.json_schema != og_json_schema - new_schema = derive_openai_json_schema(source_code=updated_tool.source_code, name=updated_tool.name) + new_schema = derive_openai_json_schema(source_code=updated_tool.source_code) assert updated_tool.json_schema == new_schema - assert updated_tool.name == new_schema["name"] def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture):