From 1edd4ab4ffeeb2d5541e66cd67b1c6920aca4357 Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Mon, 2 Dec 2024 18:57:04 -0800
Subject: [PATCH] feat: add POST route for testing tool execution via `tool_id` (#2139)

---
 letta/schemas/tool.py                     |  16 +++
 letta/server/rest_api/routers/v1/tools.py |  38 ++++++-
 letta/server/server.py                    | 111 +++++++++++++++++++-
 letta/services/tool_execution_sandbox.py  |  23 +++--
 tests/test_server.py                      | 116 ++++++++++++++++++++++
 5 files changed, 292 insertions(+), 12 deletions(-)

diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py
index 6391d2bf..ff39ed7c 100644
--- a/letta/schemas/tool.py
+++ b/letta/schemas/tool.py
@@ -201,3 +201,19 @@ class ToolUpdate(LettaBase):
     class Config:
         extra = "ignore"  # Allows extra fields without validation errors
         # TODO: Remove this, and clean usage of ToolUpdate everywhere else
+
+
+class ToolRun(LettaBase):
+    """Request body for running an existing tool, identified by its ID, on the provided arguments."""
+
+    id: str = Field(..., description="The ID of the tool to run.")
+    args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
+
+
+class ToolRunFromSource(LettaBase):
+    """Request body for building a tool from source code and running it on the provided arguments."""
+
+    args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
+    name: Optional[str] = Field(None, description="The name of the tool to run.")
+    source_code: str = Field(..., description="The source code of the function.")
+    source_type: Optional[str] = Field(None, description="The type of the source code.")
diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py
index 41daae83..d9e24473 100644
--- a/letta/server/rest_api/routers/v1/tools.py
+++ b/letta/server/rest_api/routers/v1/tools.py
@@ -5,7 +5,8 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException
 
 from letta.errors import LettaToolCreateError
 from letta.orm.errors import UniqueConstraintViolationError
-from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
+from letta.schemas.letta_message import FunctionReturn
+from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
 from letta.server.rest_api.utils import get_letta_server
 from letta.server.server import SyncServer
 
@@ -159,6 +160,41 @@ def add_base_tools(
     return server.tool_manager.add_base_tools(actor=actor)
 
 
+# NOTE: can re-enable if needed
+# @router.post("/{tool_id}/run", response_model=FunctionReturn, operation_id="run_tool")
+# def run_tool(
+#     server: SyncServer = Depends(get_letta_server),
+#     request: ToolRun = Body(...),
+#     user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
+# ):
+#     """
+#     Run an existing tool on provided arguments
+#     """
+#     actor = server.get_user_or_default(user_id=user_id)
+
+#     return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id)
+
+
+@router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source")
+def run_tool_from_source(
+    server: SyncServer = Depends(get_letta_server),
+    request: ToolRunFromSource = Body(...),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
+):
+    """
+    Attempt to build a tool from source, then run it on the provided arguments
+    """
+    actor = server.get_user_or_default(user_id=user_id)
+
+    return server.run_tool_from_source(
+        tool_source=request.source_code,
+        tool_source_type=request.source_type,
+        tool_args=request.args,
+        tool_name=request.name,
+        user_id=actor.id,
+    )
+
+
 # Specific routes for Composio
 
 
diff --git a/letta/server/server.py b/letta/server/server.py
index 71befe05..acfb904d 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -1,4 +1,5 @@
 # inspecting tools
+import json
 import os
 import traceback
 import warnings
@@ -56,7 +57,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
 # openai schemas
 from letta.schemas.enums import JobStatus
 from letta.schemas.job import Job
-from letta.schemas.letta_message import LettaMessage
+from letta.schemas.letta_message import FunctionReturn, LettaMessage
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import (
     ArchivalMemorySummary,
@@ -78,9 +79,10 @@ from letta.services.organization_manager import OrganizationManager
 from letta.services.per_agent_lock_manager import PerAgentLockManager
 from letta.services.sandbox_config_manager import SandboxConfigManager
 from letta.services.source_manager import SourceManager
+from letta.services.tool_execution_sandbox import ToolExecutionSandbox
 from letta.services.tool_manager import ToolManager
 from letta.services.user_manager import UserManager
-from letta.utils import create_random_username, json_dumps, json_loads
+from letta.utils import create_random_username, get_utc_time, json_dumps, json_loads
 
 logger = get_logger(__name__)
 
@@ -1764,6 +1766,111 @@ class SyncServer(Server):
                 return block
         return None
 
+    # def run_tool(self, tool_id: str, tool_args: str, user_id: str) -> FunctionReturn:
+    #     """Run a tool using the sandbox and return the result"""
+
+    #     try:
+    #         tool_args_dict = json.loads(tool_args)
+    #     except json.JSONDecodeError:
+    #         raise ValueError("Invalid JSON string for tool_args")
+
+    #     # Get the tool by ID
+    #     user = self.user_manager.get_user_by_id(user_id=user_id)
+    #     tool = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=user)
+    #     if tool.name is None:
+    #         raise ValueError(f"Tool with id {tool_id} does not have a name")
+
+    #     # TODO eventually allow using agent state in tools
+    #     agent_state = None
+
+    #     try:
+    #         sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id).run(agent_state=agent_state)
+    #         if sandbox_run_result is None:
+    #             raise ValueError(f"Tool with id {tool_id} returned execution with None")
+    #         function_response = str(sandbox_run_result.func_return)
+
+    #         return FunctionReturn(
+    #             id="null",
+    #             function_call_id="null",
+    #             date=get_utc_time(),
+    #             status="success",
+    #             function_return=function_response,
+    #         )
+    #     except Exception as e:
+    #         # same as agent.py
+    #         from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
+
+    #         error_msg = f"Error executing tool {tool.name}: {e}"
+    #         if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
+    #             error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
+
+    #         return FunctionReturn(
+    #             id="null",
+    #             function_call_id="null",
+    #             date=get_utc_time(),
+    #             status="error",
+    #             function_return=error_msg,
+    #         )
+
+    def run_tool_from_source(
+        self,
+        user_id: str,
+        tool_args: str,
+        tool_source: str,
+        tool_source_type: Optional[str] = None,
+        tool_name: Optional[str] = None,
+    ) -> FunctionReturn:
+        """Build an unpersisted Tool from source code, run it in the sandbox, and report the outcome as a FunctionReturn"""
+
+        try:
+            tool_args_dict = json.loads(tool_args)
+        except json.JSONDecodeError as e:
+            raise ValueError("Invalid JSON string for tool_args") from e
+
+        if tool_source_type is not None and tool_source_type != "python":
+            raise ValueError("Only Python source code is supported at this time")
+
+        # NOTE: we're creating a floating Tool object and NOT persisting to DB
+        tool = Tool(
+            name=tool_name,
+            source_code=tool_source,
+        )
+        if tool.name is None:  # explicit check instead of `assert`, which is stripped under `python -O`
+            raise ValueError("Failed to create tool object")
+
+        # TODO eventually allow using agent state in tools
+        agent_state = None
+
+        # Next, attempt to run the tool with the sandbox
+        try:
+            sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state)
+            if sandbox_run_result is None:
+                raise ValueError(f"Tool with id {tool.id} returned execution with None")
+            function_response = str(sandbox_run_result.func_return)
+
+            return FunctionReturn(
+                id="null",
+                function_call_id="null",
+                date=get_utc_time(),
+                status="success",
+                function_return=function_response,
+            )
+        except Exception as e:
+            # same as agent.py
+            from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
+
+            error_msg = f"Error executing tool {tool.name}: {e}"
+            if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
+                error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
+
+            return FunctionReturn(
+                id="null",
+                function_call_id="null",
+                date=get_utc_time(),
+                status="error",
+                function_return=error_msg,
+            )
+
     # Composio wrappers
     def get_composio_apps(self) -> List["AppModel"]:
         """Get a list of all Composio apps with actions"""
diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py
index c1c48979..14068715 100644
--- a/letta/services/tool_execution_sandbox.py
+++ b/letta/services/tool_execution_sandbox.py
@@ -11,6 +11,7 @@ from typing import Any, Optional
 from letta.log import get_logger
 from letta.schemas.agent import AgentState
 from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
+from letta.schemas.tool import Tool
 from letta.services.sandbox_config_manager import SandboxConfigManager
 from letta.services.tool_manager import ToolManager
 from letta.services.user_manager import UserManager
@@ -27,7 +28,7 @@ class ToolExecutionSandbox:
     # We make this a long random string to avoid collisions with any variables in the user's code
     LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"
 
-    def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False):
+    def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False, tool_object: Optional[Tool] = None):
         self.tool_name = tool_name
         self.args = args
 
@@ -36,14 +37,18 @@ class ToolExecutionSandbox:
         # agent_state is the state of the agent that invoked this run
         self.user = UserManager().get_user_by_id(user_id=user_id)
 
-        # Get the tool
-        # TODO: So in theory, it's possible this retrieves a tool not provisioned to the agent
-        # TODO: That would probably imply that agent_state is incorrectly configured
-        self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user)
-        if not self.tool:
-            raise ValueError(
-                f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
-            )
+        # If a tool object is provided, we use it directly, otherwise pull via name
+        if tool_object is not None:
+            self.tool = tool_object
+        else:
+            # Get the tool via name
+            # TODO: So in theory, it's possible this retrieves a tool not provisioned to the agent
+            # TODO: That would probably imply that agent_state is incorrectly configured
+            self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user)
+            if not self.tool:
+                raise ValueError(
+                    f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
+                )
 
         self.sandbox_config_manager = SandboxConfigManager(tool_settings)
         self.force_recreate = force_recreate
diff --git a/tests/test_server.py b/tests/test_server.py
index ab36f555..43443b23 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -541,6 +541,122 @@ def test_get_messages_letta_format(server, user_id, agent_id):
         _test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
 
 
+EXAMPLE_TOOL_SOURCE = '''
+def ingest(message: str):
+    """
+    Ingest a message into the system.
+
+    Args:
+        message (str): The message to ingest into the system.
+
+    Returns:
+        str: The result of ingesting the message.
+    """
+    return f"Ingested message {message}"
+
+'''
+
+
+EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR = '''
+def util_do_nothing():
+    """
+    A util function that does nothing.
+
+    Returns:
+        str: Dummy output.
+    """
+    print("I'm a distractor")
+
+def ingest(message: str):
+    """
+    Ingest a message into the system.
+
+    Args:
+        message (str): The message to ingest into the system.
+
+    Returns:
+        str: The result of ingesting the message.
+    """
+    util_do_nothing()
+    return f"Ingested message {message}"
+
+'''
+
+
+def test_tool_run(server, user_id, agent_id):
+    """Test that the server can run tools"""
+
+    result = server.run_tool_from_source(
+        user_id=user_id,
+        tool_source=EXAMPLE_TOOL_SOURCE,
+        tool_source_type="python",
+        tool_args=json.dumps({"message": "Hello, world!"}),
+        # tool_name="ingest",
+    )
+    print(result)
+    assert result.status == "success"
+    assert result.function_return == "Ingested message Hello, world!", result.function_return
+
+    result = server.run_tool_from_source(
+        user_id=user_id,
+        tool_source=EXAMPLE_TOOL_SOURCE,
+        tool_source_type="python",
+        tool_args=json.dumps({"message": "Well well well"}),
+        # tool_name="ingest",
+    )
+    print(result)
+    assert result.status == "success"
+    assert result.function_return == "Ingested message Well well well", result.function_return
+
+    result = server.run_tool_from_source(
+        user_id=user_id,
+        tool_source=EXAMPLE_TOOL_SOURCE,
+        tool_source_type="python",
+        tool_args=json.dumps({"bad_arg": "oh no"}),
+        # tool_name="ingest",
+    )
+    print(result)
+    assert result.status == "error"
+    assert "Error" in result.function_return, result.function_return
+    assert "missing 1 required positional argument" in result.function_return, result.function_return
+
+    # Test that we can still pull the tool out by default (pulls the last tool in the source)
+    result = server.run_tool_from_source(
+        user_id=user_id,
+        tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
+        tool_source_type="python",
+        tool_args=json.dumps({"message": "Well well well"}),
+        # tool_name="ingest",
+    )
+    print(result)
+    assert result.status == "success"
+    assert result.function_return == "Ingested message Well well well", result.function_return
+
+    # Test that we can pull the tool out by name
+    result = server.run_tool_from_source(
+        user_id=user_id,
+        tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
+        tool_source_type="python",
+        tool_args=json.dumps({"message": "Well well well"}),
+        tool_name="ingest",
+    )
+    print(result)
+    assert result.status == "success"
+    assert result.function_return == "Ingested message Well well well", result.function_return
+
+    # Test that we can pull a different tool out by name
+    result = server.run_tool_from_source(
+        user_id=user_id,
+        tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
+        tool_source_type="python",
+        tool_args=json.dumps({}),
+        tool_name="util_do_nothing",
+    )
+    print(result)
+    assert result.status == "success"
+    assert result.function_return == str(None), result.function_return
+
+
 def test_composio_client_simple(server):
     apps = server.get_composio_apps()
     # Assert there's some amount of apps returned