feat: add POST route for testing tool execution via tool_id (#2139)

This commit is contained in:
Charles Packer
2024-12-02 18:57:04 -08:00
committed by GitHub
parent 3b1f579aba
commit 1edd4ab4ff
5 changed files with 287 additions and 12 deletions

View File

@@ -201,3 +201,15 @@ class ToolUpdate(LettaBase):
class Config:
extra = "ignore" # Allows extra fields without validation errors
# TODO: Remove this, and clean usage of ToolUpdate everywhere else
# Request payload for running an existing tool, looked up by its ID.
class ToolRun(LettaBase):
    # Both fields are required (Ellipsis default); `args` is a JSON document
    # serialized to a string, not an already-parsed mapping.
    id: str = Field(default=..., description="The ID of the tool to run.")
    args: str = Field(default=..., description="The arguments to pass to the tool (as stringified JSON).")
# Request payload for building a throwaway tool from raw source code and
# running it once on the supplied arguments.
class ToolRunFromSource(LettaBase):
    # Required: JSON document serialized to a string.
    args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
    # Optional: callers may omit the name entirely, so the default must be
    # None rather than Ellipsis (which would make the field mandatory).
    name: Optional[str] = Field(None, description="The name of the tool to run.")
    # Required: there is nothing to run without source, and the annotation is a
    # plain `str`, so a None default would contradict the declared type.
    source_code: str = Field(..., description="The source code of the function.")
    # Optional: describes the language of `source_code`.
    source_type: Optional[str] = Field(None, description="The type of the source code.")

View File

@@ -5,7 +5,8 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.errors import LettaToolCreateError
from letta.orm.errors import UniqueConstraintViolationError
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.schemas.letta_message import FunctionReturn
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
@@ -159,6 +160,41 @@ def add_base_tools(
return server.tool_manager.add_base_tools(actor=actor)
# NOTE: can re-enable if needed
# @router.post("/{tool_id}/run", response_model=FunctionReturn, operation_id="run_tool")
# def run_tool(
# server: SyncServer = Depends(get_letta_server),
# request: ToolRun = Body(...),
# user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
# ):
# """
# Run an existing tool on provided arguments
# """
# actor = server.get_user_or_default(user_id=user_id)
# return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id)
@router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source")
def run_tool_from_source(
    server: SyncServer = Depends(get_letta_server),
    request: ToolRunFromSource = Body(...),
    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Attempt to build a tool from source, then run it on the provided arguments
    """
    # Resolve the acting user first; the server falls back to its default user
    # when the header is absent.
    acting_user = server.get_user_or_default(user_id=user_id)

    # Hand the raw request fields to the server, which does the parsing and
    # the sandboxed execution.
    outcome = server.run_tool_from_source(
        tool_source=request.source_code,
        tool_source_type=request.source_type,
        tool_args=request.args,
        tool_name=request.name,
        user_id=acting_user.id,
    )
    return outcome
# Specific routes for Composio

View File

@@ -1,4 +1,5 @@
# inspecting tools
import json
import os
import traceback
import warnings
@@ -56,7 +57,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message import FunctionReturn, LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import (
ArchivalMemorySummary,
@@ -78,9 +79,10 @@ from letta.services.organization_manager import OrganizationManager
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import create_random_username, json_dumps, json_loads
from letta.utils import create_random_username, get_utc_time, json_dumps, json_loads
logger = get_logger(__name__)
@@ -1764,6 +1766,110 @@ class SyncServer(Server):
return block
return None
# def run_tool(self, tool_id: str, tool_args: str, user_id: str) -> FunctionReturn:
# """Run a tool using the sandbox and return the result"""
# try:
# tool_args_dict = json.loads(tool_args)
# except json.JSONDecodeError:
# raise ValueError("Invalid JSON string for tool_args")
# # Get the tool by ID
# user = self.user_manager.get_user_by_id(user_id=user_id)
# tool = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=user)
# if tool.name is None:
# raise ValueError(f"Tool with id {tool_id} does not have a name")
# # TODO eventually allow using agent state in tools
# agent_state = None
# try:
# sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id).run(agent_state=agent_state)
# if sandbox_run_result is None:
# raise ValueError(f"Tool with id {tool_id} returned execution with None")
# function_response = str(sandbox_run_result.func_return)
# return FunctionReturn(
# id="null",
# function_call_id="null",
# date=get_utc_time(),
# status="success",
# function_return=function_response,
# )
# except Exception as e:
# # same as agent.py
# from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
# error_msg = f"Error executing tool {tool.name}: {e}"
# if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
# error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
# return FunctionReturn(
# id="null",
# function_call_id="null",
# date=get_utc_time(),
# status="error",
# function_return=error_msg,
# )
def run_tool_from_source(
    self,
    user_id: str,
    tool_args: str,
    tool_source: str,
    tool_source_type: Optional[str] = None,
    tool_name: Optional[str] = None,
) -> FunctionReturn:
    """Build a temporary tool from source code and run it in the tool sandbox.

    The tool object is created in memory only and is not persisted to the DB.

    Args:
        user_id: ID of the user on whose behalf the sandbox is run.
        tool_args: Arguments for the tool, as a stringified JSON object.
        tool_source: Source code of the tool.
        tool_source_type: Language of ``tool_source``; only ``"python"`` (or
            ``None``) is accepted.
        tool_name: Name of the function to run. May be ``None``; presumably
            the ``Tool`` schema then derives a name from the source
            (TODO confirm against ``Tool``).

    Returns:
        A ``FunctionReturn`` with status ``"success"`` and the stringified
        return value, or status ``"error"`` and a (possibly truncated) error
        message if anything raised while running the sandbox.

    Raises:
        ValueError: If ``tool_args`` is not valid JSON or does not decode to a
            JSON object, if the source type is unsupported, or if no tool name
            could be determined.
    """
    try:
        tool_args_dict = json.loads(tool_args)
    except json.JSONDecodeError as e:
        # Chain the cause so the position of the JSON error is not lost.
        raise ValueError("Invalid JSON string for tool_args") from e

    # Valid JSON is not necessarily an object ("[1, 2]", "3", "null", ...);
    # the sandbox expects a mapping of argument names to values.
    if not isinstance(tool_args_dict, dict):
        raise ValueError("tool_args must be a JSON object mapping argument names to values")

    if tool_source_type is not None and tool_source_type != "python":
        raise ValueError("Only Python source code is supported at this time")

    # NOTE: we're creating a floating Tool object and NOT persisting to DB
    tool = Tool(
        name=tool_name,
        source_code=tool_source,
    )
    # Explicit raise instead of `assert`: assertions are stripped under -O.
    if tool.name is None:
        raise ValueError("Failed to create tool object")

    # TODO eventually allow using agent state in tools
    agent_state = None

    # Next, attempt to run the tool with the sandbox
    try:
        sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state)
        if sandbox_run_result is None:
            raise ValueError(f"Tool with id {tool.id} returned execution with None")
        function_response = str(sandbox_run_result.func_return)
        return FunctionReturn(
            id="null",
            function_call_id="null",
            date=get_utc_time(),
            status="success",
            function_return=function_response,
        )
    except Exception as e:
        # Deliberately broad: any failure while running user-supplied code is
        # reported back to the caller as an "error" result (same as agent.py).
        from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

        error_msg = f"Error executing tool {tool.name}: {e}"
        if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
            error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
        return FunctionReturn(
            id="null",
            function_call_id="null",
            date=get_utc_time(),
            status="error",
            function_return=error_msg,
        )
# Composio wrappers
def get_composio_apps(self) -> List["AppModel"]:
"""Get a list of all Composio apps with actions"""

View File

@@ -11,6 +11,7 @@ from typing import Any, Optional
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
@@ -27,7 +28,7 @@ class ToolExecutionSandbox:
# We make this a long random string to avoid collisions with any variables in the user's code
LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"
def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False):
def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False, tool_object: Optional[Tool] = None):
self.tool_name = tool_name
self.args = args
@@ -36,14 +37,18 @@ class ToolExecutionSandbox:
# agent_state is the state of the agent that invoked this run
self.user = UserManager().get_user_by_id(user_id=user_id)
# Get the tool
# TODO: So in theory, it's possible this retrieves a tool not provisioned to the agent
# TODO: That would probably imply that agent_state is incorrectly configured
self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user)
if not self.tool:
raise ValueError(
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
)
# If a tool object is provided, we use it directly, otherwise pull via name
if tool_object is not None:
self.tool = tool_object
else:
# Get the tool via name
# TODO: So in theory, it's possible this retrieves a tool not provisioned to the agent
# TODO: That would probably imply that agent_state is incorrectly configured
self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user)
if not self.tool:
raise ValueError(
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
)
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
self.force_recreate = force_recreate

View File

@@ -541,6 +541,122 @@ def test_get_messages_letta_format(server, user_id, agent_id):
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
# Source text of a minimal single-function tool (`ingest`), passed as
# `tool_source` to `server.run_tool_from_source` in `test_tool_run`.
EXAMPLE_TOOL_SOURCE = '''
def ingest(message: str):
"""
Ingest a message into the system.
Args:
message (str): The message to ingest into the system.
Returns:
str: The result of ingesting the message.
"""
return f"Ingested message {message}"
'''
# Source text containing two functions: a helper (`util_do_nothing`) defined
# first and the `ingest` tool defined last. `test_tool_run` uses it to check
# which function is run when no name is given and when either name is given.
# NOTE(review): the helper's embedded docstring says it returns a str, but it
# only prints; the test expects its result to stringify as "None".
EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR = '''
def util_do_nothing():
"""
A util function that does nothing.
Returns:
str: Dummy output.
"""
print("I'm a distractor")
def ingest(message: str):
"""
Ingest a message into the system.
Args:
message (str): The message to ingest into the system.
Returns:
str: The result of ingesting the message.
"""
util_do_nothing()
return f"Ingested message {message}"
'''
def test_tool_run(server, user_id, agent_id):
    """Exercise `server.run_tool_from_source` on success, error and name-selection cases."""

    def run_from_source(source, payload, **extra):
        # Every case shares the user and the source type; only the source, the
        # arguments and (optionally) the explicit tool name vary.
        outcome = server.run_tool_from_source(
            user_id=user_id,
            tool_source=source,
            tool_source_type="python",
            tool_args=json.dumps(payload),
            **extra,
        )
        print(outcome)
        return outcome

    # Plain successful run, no explicit tool name
    result = run_from_source(EXAMPLE_TOOL_SOURCE, {"message": "Hello, world!"})
    assert result.status == "success"
    assert result.function_return == "Ingested message Hello, world!", result.function_return

    # Same tool, different argument value
    result = run_from_source(EXAMPLE_TOOL_SOURCE, {"message": "Well well well"})
    assert result.status == "success"
    assert result.function_return == "Ingested message Well well well", result.function_return

    # Wrong argument name: the failure is reported as an error result
    result = run_from_source(EXAMPLE_TOOL_SOURCE, {"bad_arg": "oh no"})
    assert result.status == "error"
    assert "Error" in result.function_return, result.function_return
    assert "missing 1 required positional argument" in result.function_return, result.function_return

    # Test that we can still pull the tool out by default (pulls the last tool in the source)
    result = run_from_source(EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, {"message": "Well well well"})
    assert result.status == "success"
    assert result.function_return == "Ingested message Well well well", result.function_return

    # Test that we can pull the tool out by name
    result = run_from_source(EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, {"message": "Well well well"}, tool_name="ingest")
    assert result.status == "success"
    assert result.function_return == "Ingested message Well well well", result.function_return

    # Test that we can pull a different tool out by name
    result = run_from_source(EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, {}, tool_name="util_do_nothing")
    assert result.status == "success"
    assert result.function_return == str(None), result.function_return
def test_composio_client_simple(server):
apps = server.get_composio_apps()
# Assert there's some amount of apps returned