diff --git a/letta/client/client.py b/letta/client/client.py index f39a52fa..77468f0d 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -96,6 +96,12 @@ class AbstractClient(object): ): raise NotImplementedError + def add_tool_to_agent(self, agent_id: str, tool_id: str): + raise NotImplementedError + + def remove_tool_from_agent(self, agent_id: str, tool_id: str): + raise NotImplementedError + def rename_agent(self, agent_id: str, new_name: str): raise NotImplementedError @@ -474,6 +480,39 @@ class RESTClient(AbstractClient): raise ValueError(f"Failed to update agent: {response.text}") return AgentState(**response.json()) + def add_tool_to_agent(self, agent_id: str, tool_id: str): + """ + Add tool to an existing agent + + Args: + agent_id (str): ID of the agent + tool_id (str): A tool id + + Returns: + agent_state (AgentState): State of the updated agent + """ + response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/add-tool/{tool_id}", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to update agent: {response.text}") + return AgentState(**response.json()) + + def remove_tool_from_agent(self, agent_id: str, tool_id: str): + """ + Removes tools from an existing agent + + Args: + agent_id (str): ID of the agent + tool_id (str): The tool id + + Returns: + agent_state (AgentState): State of the updated agent + """ + + response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/remove-tool/{tool_id}", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to update agent: {response.text}") + return AgentState(**response.json()) + def rename_agent(self, agent_id: str, new_name: str): """ Rename an agent @@ -1653,6 +1692,36 @@ class LocalClient(AbstractClient): ) return agent_state + def add_tool_to_agent(self, agent_id: str, tool_id: str): + """ + Add tool to an existing agent + + Args: + agent_id (str): ID of the agent + tool_id (str): A tool id + + Returns: + agent_state (AgentState): State of the updated agent + """ + self.interface.clear() + agent_state = self.server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id) + return agent_state + + def remove_tool_from_agent(self, agent_id: str, tool_id: str): + """ + Removes tools from an existing agent + + Args: + agent_id (str): ID of the agent + tool_id (str): The tool id + + Returns: + agent_state (AgentState): State of the updated agent + """ + self.interface.clear() + agent_state = self.server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id) + return agent_state + def rename_agent(self, agent_id: str, new_name: str): """ Rename an agent @@ -2081,30 +2150,37 @@ class LocalClient(AbstractClient): Returns: None """ - existing_tool_id = self.get_tool_id(tool.name) - if existing_tool_id: + if self.tool_with_name_and_user_id_exists(tool): if update: - self.server.update_tool( + return self.server.update_tool( ToolUpdate( - id=existing_tool_id, + id=tool.id, + description=tool.description, source_type=tool.source_type, source_code=tool.source_code, tags=tool.tags, json_schema=tool.json_schema, name=tool.name, - ) + ), + self.user_id, ) else: - raise ValueError(f"Tool with name {tool.name} already exists") - - # call server function - return self.server.create_tool( - ToolCreate( - source_type=tool.source_type, source_code=tool.source_code, name=tool.name, json_schema=tool.json_schema, tags=tool.tags - ), - user_id=self.user_id, - update=update, - ) + raise ValueError(f"Tool with id={tool.id} and name={tool.name}already exists") + else: + # call server function + return self.server.create_tool( + ToolCreate( + id=tool.id, + description=tool.description, + source_type=tool.source_type, + source_code=tool.source_code, + name=tool.name, + json_schema=tool.json_schema, + tags=tool.tags, + ), + user_id=self.user_id, + update=update, + ) # TODO: Use the above function `add_tool` here as there is duplicate logic def create_tool( @@ -2170,7 +2246,9 @@ class LocalClient(AbstractClient): source_type = "python" - return self.server.update_tool(ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name)) + return self.server.update_tool( + ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id + ) def list_tools(self): """ @@ -2215,7 +2293,17 @@ class LocalClient(AbstractClient): """ return self.server.get_tool_id(name, self.user_id) - # data sources + def tool_with_name_and_user_id_exists(self, tool: Tool) -> bool: + """ + Check if the tool with name and user_id exists + + Args: + tool (Tool): the tool + + Returns: + (bool): True if the id exists, False otherwise. + """ + return self.server.tool_with_name_and_user_id_exists(tool, self.user_id) def load_data(self, connector: DataConnector, source_name: str): """ diff --git a/letta/client/utils.py b/letta/client/utils.py index bcec534c..1ff28f8c 100644 --- a/letta/client/utils.py +++ b/letta/client/utils.py @@ -1,6 +1,9 @@ +import re from datetime import datetime +from typing import Optional from IPython.display import HTML, display +from sqlalchemy.testing.plugin.plugin_base import warnings from letta.local_llm.constants import ( ASSISTANT_MESSAGE_CLI_SYMBOL, @@ -64,3 +67,15 @@ def pprint(messages): html_content += "" display(HTML(html_content)) + + +def derive_function_name_regex(function_string: str) -> Optional[str]: + # Regular expression to match the function name + match = re.search(r"def\s+([a-zA-Z_]\w*)\s*\(", function_string) + + if match: + function_name = match.group(1) + return function_name + else: + warnings.warn("No function name found.") + return None diff --git a/letta/metadata.py b/letta/metadata.py index 18960bbd..b4eab3ed 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -577,7 +577,7 @@ class MetadataStore: @enforce_types def create_tool(self, tool: Tool): with self.session_maker() as session: - if self.get_tool(tool_name=tool.name, user_id=tool.user_id) is not None: + if self.get_tool(tool_id=tool.id, tool_name=tool.name, user_id=tool.user_id) is not None: raise ValueError(f"Tool with name {tool.name} already exists") session.add(ToolModel(**vars(tool))) session.commit() @@ -620,9 +620,9 @@ class MetadataStore: session.commit() @enforce_types - def update_tool(self, tool: Tool): + def update_tool(self, tool_id: str, tool: Tool): with self.session_maker() as session: - session.query(ToolModel).filter(ToolModel.id == tool.id).update(vars(tool)) + session.query(ToolModel).filter(ToolModel.id == tool_id).update(vars(tool)) session.commit() @enforce_types @@ -815,6 +815,15 @@ class MetadataStore: results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all() if user_id: results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all() + if len(results) == 0: + return None + # assert len(results) == 1, f"Expected 1 result, got {len(results)}" + return results[0].to_record() + + @enforce_types + def get_tool_with_name_and_user_id(self, tool_name: Optional[str] = None, user_id: Optional[str] = None) -> Optional[ToolModel]: + with self.session_maker() as session: + results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all() if len(results) == 0: return None assert len(results) == 1, f"Expected 1 result, got {len(results)}" diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 367765fd..40f7c833 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -128,3 +128,8 @@ class AgentStepResponse(BaseModel): ..., description="Whether the agent step ended because the in-context memory is near its limit." ) usage: UsageStatistics = Field(..., description="Usage statistics of the LLM call during the agent's step.") + + +class RemoveToolsFromAgent(BaseModel): + agent_id: str = Field(..., description="The id of the agent.") + tool_ids: Optional[List[str]] = Field(None, description="The tools to be removed from the agent.") diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 10faec4c..ee9cf156 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -176,7 +176,9 @@ class Tool(BaseTool): class ToolCreate(BaseTool): + id: Optional[str] = Field(None, description="The unique identifier of the tool. If this is not provided, it will be autogenerated.") name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).") + description: Optional[str] = Field(None, description="The description of the tool.") tags: List[str] = Field([], description="Metadata tags.") source_code: str = Field(..., description="The source code of the function.") json_schema: Optional[Dict] = Field( @@ -187,6 +189,7 @@ class ToolCreate(BaseTool): class ToolUpdate(ToolCreate): id: str = Field(..., description="The unique identifier of the tool.") + description: Optional[str] = Field(None, description="The description of the tool.") name: Optional[str] = Field(None, description="The name of the function.") tags: Optional[List[str]] = Field(None, description="Metadata tags.") source_code: Optional[str] = Field(None, description="The source code of the function.") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index b7a2c9bf..f3f8a79d 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -100,6 +100,34 @@ def update_agent( return server.update_agent(update_agent, user_id=actor.id) +@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent") +def add_tool_to_agent( + agent_id: str, + tool_id: str, + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """Add tools to an exsiting agent""" + actor = server.get_user_or_default(user_id=user_id) + + update_agent.id = agent_id + return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) + + +@router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent") +def remove_tool_from_agent( + agent_id: str, + tool_id: str, + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """Add tools to an exsiting agent""" + actor = server.get_user_or_default(user_id=user_id) + + update_agent.id = agent_id + return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) + + @router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent") def get_agent_state( agent_id: str, diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 404fabfd..0defac11 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -105,4 +105,4 @@ def update_tool( """ assert tool_id == request.id, "Tool ID in path must match tool ID in request body" # actor = server.get_user_or_default(user_id=user_id) - return server.update_tool(request) + return server.update_tool(request, user_id) diff --git a/letta/server/server.py b/letta/server/server.py index 90b68911..389bea2e 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -16,6 +16,7 @@ import letta.system as system from letta.agent import Agent, save_agent from letta.agent_store.db import attach_base from letta.agent_store.storage import StorageConnector, TableType +from letta.client.utils import derive_function_name_regex from letta.credentials import LettaCredentials from letta.data_sources.connectors import DataConnector, load_data @@ -965,6 +966,80 @@ class SyncServer(Server): # TODO: probably reload the agent somehow? return letta_agent.agent_state + def add_tool_to_agent( + self, + agent_id: str, + tool_id: str, + user_id: str, + ): + """Update the agents core memory block, return the new state""" + if self.ms.get_user(user_id=user_id) is None: + raise ValueError(f"User user_id={user_id} does not exist") + if self.ms.get_agent(agent_id=agent_id) is None: + raise ValueError(f"Agent agent_id={agent_id} does not exist") + + # Get the agent object (loaded in memory) + letta_agent = self._get_or_load_agent(agent_id=agent_id) + + # Get all the tool objects from the request + tool_objs = [] + tool_obj = self.ms.get_tool(tool_id=tool_id, user_id=user_id) + assert tool_obj, f"Tool with id={tool_id} does not exist" + tool_objs.append(tool_obj) + + for tool in letta_agent.tools: + tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id) + assert tool_obj, f"Tool with id={tool.id} does not exist" + + # If it's not the already added tool + if tool_obj.id != tool_id: + tool_objs.append(tool_obj) + + # replace the list of tool names ("ids") inside the agent state + letta_agent.agent_state.tools = [tool.name for tool in tool_objs] + + # then attempt to link the tools modules + letta_agent.link_tools(tool_objs) + + # save the agent + save_agent(letta_agent, self.ms) + return letta_agent.agent_state + + def remove_tool_from_agent( + self, + agent_id: str, + tool_id: str, + user_id: str, + ): + """Update the agents core memory block, return the new state""" + if self.ms.get_user(user_id=user_id) is None: + raise ValueError(f"User user_id={user_id} does not exist") + if self.ms.get_agent(agent_id=agent_id) is None: + raise ValueError(f"Agent agent_id={agent_id} does not exist") + + # Get the agent object (loaded in memory) + letta_agent = self._get_or_load_agent(agent_id=agent_id) + + # Get all the tool_objs + tool_objs = [] + for tool in letta_agent.tools: + tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id) + assert tool_obj, f"Tool with id={tool.id} does not exist" + + # If it's not the tool we want to remove + if tool_obj.id != tool_id: + tool_objs.append(tool_obj) + + # replace the list of tool names ("ids") inside the agent state + letta_agent.agent_state.tools = [tool.name for tool in tool_objs] + + # then attempt to link the tools modules + letta_agent.link_tools(tool_objs) + + # save the agent + save_agent(letta_agent, self.ms) + return letta_agent.agent_state + def _agent_state_to_config(self, agent_state: AgentState) -> dict: """Convert AgentState to a dict for a JSON response""" assert agent_state is not None @@ -1751,6 +1826,15 @@ class SyncServer(Server): """Get tool by ID.""" return self.ms.get_tool(tool_id=tool_id) + def tool_with_name_and_user_id_exists(self, tool: Tool, user_id: Optional[str] = None) -> bool: + """Check if tool exists""" + tool = self.ms.get_tool_with_name_and_user_id(tool_name=tool.name, user_id=user_id) + + if tool is None: + return False + else: + return True + def get_tool_id(self, name: str, user_id: str) -> Optional[str]: """Get tool ID from name and user_id.""" tool = self.ms.get_tool(tool_name=name, user_id=user_id) @@ -1758,16 +1842,27 @@ class SyncServer(Server): return None return tool.id - def update_tool( - self, - request: ToolUpdate, - ) -> Tool: + def update_tool(self, request: ToolUpdate, user_id: Optional[str] = None) -> Tool: """Update an existing tool""" - existing_tool = self.ms.get_tool(tool_id=request.id) - if not existing_tool: - raise ValueError(f"Tool does not exist") + if request.name: + existing_tool = self.ms.get_tool_with_name_and_user_id(tool_name=request.name, user_id=user_id) + if existing_tool is None: + raise ValueError(f"Tool with name={request.name}, user_id={user_id} does not exist") + else: + existing_tool = self.ms.get_tool(tool_id=request.id) + if existing_tool is None: + raise ValueError(f"Tool with id={request.id} does not exist") + + # Preserve the original tool id + # As we can override the tool id as well + # This is probably bad design if this is exposed to users... + original_id = existing_tool.id # override updated fields + if request.id: + existing_tool.id = request.id + if request.description: + existing_tool.description = request.description if request.source_code: existing_tool.source_code = request.source_code if request.source_type: @@ -1776,10 +1871,15 @@ class SyncServer(Server): existing_tool.tags = request.tags if request.json_schema: existing_tool.json_schema = request.json_schema + + # If name is explicitly provided here, overide the tool name if request.name: existing_tool.name = request.name + # Otherwise, if there's no name, and there's source code, we try to derive the name + elif request.source_code: + existing_tool.name = derive_function_name_regex(request.source_code) - self.ms.update_tool(existing_tool) + self.ms.update_tool(original_id, existing_tool) return self.ms.get_tool(tool_id=request.id) def create_tool(self, request: ToolCreate, user_id: Optional[str] = None, update: bool = True) -> Tool: # TODO: add other fields @@ -1817,15 +1917,23 @@ class SyncServer(Server): assert request.name, f"Tool name must be provided in json_schema {json_schema}. This should never happen." # check if already exists: - existing_tool = self.ms.get_tool(tool_name=request.name, user_id=user_id) + existing_tool = self.ms.get_tool(tool_id=request.id, tool_name=request.name, user_id=user_id) if existing_tool: if update: - updated_tool = self.update_tool(ToolUpdate(id=existing_tool.id, **vars(request))) + # id is an optional field, so we will fill it with the existing tool id + if not request.id: + request.id = existing_tool.id + updated_tool = self.update_tool(ToolUpdate(**vars(request)), user_id) assert updated_tool is not None, f"Failed to update tool {request.name}" return updated_tool else: raise ValueError(f"Tool {request.name} already exists and update=False") + # check for description + description = None + if request.description: + description = request.description + tool = Tool( name=request.name, source_code=request.source_code, @@ -1833,9 +1941,14 @@ class SyncServer(Server): tags=request.tags, json_schema=json_schema, user_id=user_id, + description=description, ) + + if request.id: + tool.id = request.id + self.ms.create_tool(tool) - created_tool = self.ms.get_tool(tool_name=request.name, user_id=user_id) + created_tool = self.ms.get_tool(tool_id=tool.id, user_id=user_id) return created_tool def delete_tool(self, tool_id: str): diff --git a/tests/test_new_client.py b/tests/test_new_client.py index f4bb1ed6..76e414df 100644 --- a/tests/test_new_client.py +++ b/tests/test_new_client.py @@ -1,3 +1,4 @@ +import uuid from typing import Union import pytest @@ -9,6 +10,7 @@ from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory +from letta.schemas.tool import Tool @pytest.fixture(scope="module") @@ -16,12 +18,17 @@ def client(): client = create_client() client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) + yield client @pytest.fixture(scope="module") def agent(client): - agent_state = client.create_agent(name="test_agent") + # Generate uuid for agent name for this example + namespace = uuid.NAMESPACE_DNS + agent_uuid = str(uuid.uuid5(namespace, "test_new_client_test_agent")) + + agent_state = client.create_agent(name=agent_uuid) yield agent_state client.delete_agent(agent_state.id) @@ -114,6 +121,52 @@ def test_agent(client: Union[LocalClient, RESTClient]): client.delete_agent(agent_state_test.id) +def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent): + # Create and add two tools to the client + # tool 1 + from composio_langchain import Action + + github_tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) + client.add_tool(github_tool) + # tool 2 + from crewai_tools import ScrapeWebsiteTool + + scrape_website_tool = Tool.from_crewai(ScrapeWebsiteTool(website_url="https://www.example.com")) + client.add_tool(scrape_website_tool) + + # assert both got added + tools = client.list_tools() + assert github_tool.id in [t.id for t in tools] + assert scrape_website_tool.id in [t.id for t in tools] + + # Assert that all combinations of tool_names, tool_user_ids are unique + combinations = [(t.name, t.user_id) for t in tools] + assert len(combinations) == len(set(combinations)) + + # create agent + agent_state = agent + curr_num_tools = len(agent_state.tools) + + # add both tools to agent in steps + agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=github_tool.id) + agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=scrape_website_tool.id) + + # confirm that both tools are in the agent state + curr_tools = agent_state.tools + assert len(curr_tools) == curr_num_tools + 2 + assert github_tool.name in curr_tools + assert scrape_website_tool.name in curr_tools + + # remove only the github tool + agent_state = client.remove_tool_from_agent(agent_id=agent_state.id, tool_id=github_tool.id) + + # confirm that only one tool left + curr_tools = agent_state.tools + assert len(curr_tools) == curr_num_tools + 1 + assert github_tool.name not in curr_tools + assert scrape_website_tool.name in curr_tools + + def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]): persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id) human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id) @@ -242,8 +295,7 @@ def test_tools(client: Union[LocalClient, RESTClient]): print(msg) # create tool - len(client.list_tools()) - tool = client.create_tool(print_tool, tags=["extras"]) + tool = client.create_tool(func=print_tool, tags=["extras"]) # list tools tools = client.list_tools() @@ -258,19 +310,13 @@ def test_tools(client: Union[LocalClient, RESTClient]): assert client.get_tool(tool.id).tags == extras2 # update tool: source code - client.update_tool(tool.id, name="print_tool2", func=print_tool2) + client.update_tool(tool.id, func=print_tool2) assert client.get_tool(tool.id).name == "print_tool2" - ## delete tool - # client.delete_tool(tool.id) - # assert len(client.list_tools()) == orig_tool_length - def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]): from composio_langchain import Action - from letta.schemas.tool import Tool - # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) client = create_client() @@ -292,8 +338,6 @@ def test_tools_from_crewai(client: Union[LocalClient, RESTClient]): from crewai_tools import ScrapeWebsiteTool - from letta.schemas.tool import Tool - crewai_tool = ScrapeWebsiteTool() # Translate to memGPT Tool @@ -329,8 +373,6 @@ def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]): from crewai_tools import ScrapeWebsiteTool - from letta.schemas.tool import Tool - crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com") # Translate to memGPT Tool @@ -363,8 +405,6 @@ def test_tools_from_langchain(client: Union[LocalClient, RESTClient]): from langchain_community.tools import WikipediaQueryRun from langchain_community.utilities import WikipediaAPIWrapper - from letta.schemas.tool import Tool - api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100) langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper) @@ -397,8 +437,6 @@ def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, REST from langchain_community.tools import WikipediaQueryRun from langchain_community.utilities import WikipediaAPIWrapper - from letta.schemas.tool import Tool - api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100) langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)