From d61b2f9545e4421ed44dffccac4d8728cbcc678d Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Mon, 9 Dec 2024 18:55:18 -0800 Subject: [PATCH] feat: enable configuration of `return_char_limit` for tools (#2207) --- ...f_add_column_to_tools_table_to_contain_.py | 39 ++++++++++++++++ letta/agent.py | 8 +++- letta/client/client.py | 44 ++++++++++++++----- letta/orm/tool.py | 14 +++--- letta/schemas/tool.py | 5 +++ letta/utils.py | 11 +++-- tests/test_client.py | 39 +++++++++++++++- 7 files changed, 133 insertions(+), 27 deletions(-) create mode 100644 alembic/versions/a91994b9752f_add_column_to_tools_table_to_contain_.py diff --git a/alembic/versions/a91994b9752f_add_column_to_tools_table_to_contain_.py b/alembic/versions/a91994b9752f_add_column_to_tools_table_to_contain_.py new file mode 100644 index 00000000..f8da3856 --- /dev/null +++ b/alembic/versions/a91994b9752f_add_column_to_tools_table_to_contain_.py @@ -0,0 +1,39 @@ +"""add column to tools table to contain function return limit return_char_limit + +Revision ID: a91994b9752f +Revises: e1a625072dbf +Create Date: 2024-12-09 18:27:25.650079 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op +from letta.constants import FUNCTION_RETURN_CHAR_LIMIT + +# revision identifiers, used by Alembic. +revision: str = "a91994b9752f" +down_revision: Union[str, None] = "e1a625072dbf" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("tools", sa.Column("return_char_limit", sa.Integer(), nullable=True)) + + # Populate `return_char_limit` column + op.execute( + f""" + UPDATE tools + SET return_char_limit = {FUNCTION_RETURN_CHAR_LIMIT} + """ + ) + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("tools", "return_char_limit") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index a09fd04c..81924c2e 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -801,7 +801,13 @@ class Agent(BaseAgent): # but by default, we add a truncation safeguard to prevent bad functions from # overflow the agent context window truncate = True - function_response_string = validate_function_response(function_response, truncate=truncate) + + # get the function response limit + tool_obj = [tool for tool in self.agent_state.tools if tool.name == function_name][0] + return_char_limit = tool_obj.return_char_limit + function_response_string = validate_function_response( + function_response, return_char_limit=return_char_limit, truncate=truncate + ) function_args.pop("self", None) function_response = package_function_response(True, function_response_string) function_failed = False diff --git a/letta/client/client.py b/letta/client/client.py index 97f36cd9..6456aa3f 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -11,6 +11,7 @@ from letta.constants import ( BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, + FUNCTION_RETURN_CHAR_LIMIT, ) from letta.data_sources.connectors import DataConnector from letta.functions.functions import parse_source_code @@ -200,18 +201,12 @@ class AbstractClient(object): raise NotImplementedError def create_tool( - self, - func, - name: Optional[str] = None, - tags: Optional[List[str]] = None, + self, func, name: Optional[str] = None, tags: Optional[List[str]] = None, return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT ) -> Tool: raise NotImplementedError def create_or_update_tool( - self, - func, - name: Optional[str] = None, - tags: Optional[List[str]] = None, + self, func, name: Optional[str] = None, tags: Optional[List[str]] = None, return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT ) -> Tool: raise NotImplementedError @@ -222,6 +217,7 @@ class AbstractClient(object): description: Optional[str] = 
None, func: Optional[Callable] = None, tags: Optional[List[str]] = None, + return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, ) -> Tool: raise NotImplementedError @@ -1465,6 +1461,7 @@ class RESTClient(AbstractClient): func: Callable, name: Optional[str] = None, tags: Optional[List[str]] = None, + return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, ) -> Tool: """ Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. @@ -1473,6 +1470,7 @@ class RESTClient(AbstractClient): func (callable): The function to create a tool for. name: (str): Name of the tool (must be unique per-user.) tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. + return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. Returns: tool (Tool): The created tool. @@ -1481,7 +1479,9 @@ class RESTClient(AbstractClient): source_type = "python" # call server function - request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags) + request = ToolCreate(source_type=source_type, source_code=source_code, name=name, return_char_limit=return_char_limit) + if tags: + request.tags = tags response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create tool: {response.text}") @@ -1492,6 +1492,7 @@ class RESTClient(AbstractClient): func: Callable, name: Optional[str] = None, tags: Optional[List[str]] = None, + return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, ) -> Tool: """ Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. 
@@ -1500,6 +1501,7 @@ class RESTClient(AbstractClient): func (callable): The function to create a tool for. name: (str): Name of the tool (must be unique per-user.) tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. + return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. Returns: tool (Tool): The created tool. @@ -1508,7 +1510,9 @@ class RESTClient(AbstractClient): source_type = "python" # call server function - request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags) + request = ToolCreate(source_type=source_type, source_code=source_code, name=name, return_char_limit=return_char_limit) + if tags: + request.tags = tags response = requests.put(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create tool: {response.text}") @@ -1521,6 +1525,7 @@ class RESTClient(AbstractClient): description: Optional[str] = None, func: Optional[Callable] = None, tags: Optional[List[str]] = None, + return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, ) -> Tool: """ Update a tool with provided parameters (name, func, tags) @@ -1530,6 +1535,7 @@ class RESTClient(AbstractClient): name (str): Name of the tool func (callable): Function to wrap in a tool tags (List[str]): Tags for the tool + return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. 
Returns: tool (Tool): Updated tool @@ -1541,7 +1547,14 @@ class RESTClient(AbstractClient): source_type = "python" - request = ToolUpdate(description=description, source_type=source_type, source_code=source_code, tags=tags, name=name) + request = ToolUpdate( + description=description, + source_type=source_type, + source_code=source_code, + tags=tags, + name=name, + return_char_limit=return_char_limit, + ) response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update tool: {response.text}") @@ -2726,6 +2739,7 @@ class LocalClient(AbstractClient): name: Optional[str] = None, tags: Optional[List[str]] = None, description: Optional[str] = None, + return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, ) -> Tool: """ Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. @@ -2735,6 +2749,7 @@ class LocalClient(AbstractClient): name: (str): Name of the tool (must be unique per-user.) tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. description (str, optional): The description. + return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. Returns: tool (Tool): The created tool. @@ -2755,6 +2770,7 @@ class LocalClient(AbstractClient): name=name, tags=tags, description=description, + return_char_limit=return_char_limit, ), actor=self.user, ) @@ -2765,6 +2781,7 @@ class LocalClient(AbstractClient): name: Optional[str] = None, tags: Optional[List[str]] = None, description: Optional[str] = None, + return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, ) -> Tool: """ Creates or updates a tool. 
This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. @@ -2774,6 +2791,7 @@ class LocalClient(AbstractClient): name: (str): Name of the tool (must be unique per-user.) tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. description (str, optional): The description. + return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. Returns: tool (Tool): The created tool. @@ -2791,6 +2809,7 @@ class LocalClient(AbstractClient): name=name, tags=tags, description=description, + return_char_limit=return_char_limit, ), actor=self.user, ) @@ -2802,6 +2821,7 @@ class LocalClient(AbstractClient): description: Optional[str] = None, func: Optional[callable] = None, tags: Optional[List[str]] = None, + return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, ) -> Tool: """ Update a tool with provided parameters (name, func, tags) @@ -2811,6 +2831,7 @@ class LocalClient(AbstractClient): name (str): Name of the tool func (callable): Function to wrap in a tool tags (List[str]): Tags for the tool + return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. 
Returns: tool (Tool): Updated tool @@ -2821,6 +2842,7 @@ class LocalClient(AbstractClient): "tags": tags, "name": name, "description": description, + "return_char_limit": return_char_limit, } # Filter out any None values from the dictionary diff --git a/letta/orm/tool.py b/letta/orm/tool.py index 00038fe0..8f1ac46a 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -30,6 +30,7 @@ class Tool(SqlalchemyBase, OrganizationMixin): __table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),) name: Mapped[str] = mapped_column(doc="The display name of the tool.") + return_char_limit: Mapped[int] = mapped_column(nullable=True, doc="The maximum number of characters the tool can return.") description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.") tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.") source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json) @@ -45,19 +46,16 @@ class Tool(SqlalchemyBase, OrganizationMixin): # Add event listener to update tool_name in ToolsAgents when Tool name changes -@event.listens_for(Tool, 'before_update') +@event.listens_for(Tool, "before_update") def update_tool_name_in_tools_agents(mapper, connection, target): """Update tool_name in ToolsAgents when Tool name changes.""" state = target._sa_instance_state - history = state.get_history('name', passive=True) + history = state.get_history("name", passive=True) if not history.has_changes(): return - + # Get the new name and update all associated ToolsAgents records new_name = target.name from letta.orm.tools_agents import ToolsAgents - connection.execute( - ToolsAgents.__table__.update().where( - ToolsAgents.tool_id == target.id - ).values(tool_name=new_name) - ) + + connection.execute(ToolsAgents.__table__.update().where(ToolsAgents.tool_id == target.id).values(tool_name=new_name)) diff --git 
a/letta/schemas/tool.py b/letta/schemas/tool.py index ed31f9d6..997965ab 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional from pydantic import Field, model_validator +from letta.constants import FUNCTION_RETURN_CHAR_LIMIT from letta.functions.functions import derive_openai_json_schema from letta.functions.helpers import ( generate_composio_tool_wrapper, @@ -41,6 +42,9 @@ class Tool(BaseTool): source_code: str = Field(..., description="The source code of the function.") json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.") + # tool configuration + return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.") + # metadata fields created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") @@ -91,6 +95,7 @@ class ToolCreate(LettaBase): json_schema: Optional[Dict] = Field( None, description="The JSON schema of the function (auto-generated from source_code if not provided)" ) + return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.") @classmethod def from_composio(cls, action_name: str, api_key: Optional[str] = None) -> "ToolCreate": diff --git a/letta/utils.py b/letta/utils.py index 71915420..ad666885 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -28,7 +28,6 @@ from letta.constants import ( CLI_WARNING_PREFIX, CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT, - FUNCTION_RETURN_CHAR_LIMIT, LETTA_DIR, MAX_FILENAME_LENGTH, TOOL_CALL_ID_MAX_LEN, @@ -906,8 +905,8 @@ def parse_json(string) -> dict: raise e -def validate_function_response(function_response_string: any, strict: bool = False, truncate: bool = True) -> str: - """Check to make sure that a function used by Letta returned a valid 
response +def validate_function_response(function_response_string: any, return_char_limit: int, strict: bool = False, truncate: bool = True) -> str: + """Check to make sure that a function used by Letta returned a valid response. Truncates to return_char_limit if necessary. Responses need to be strings (or None) that fall under a certain text count limit. """ @@ -943,11 +942,11 @@ def validate_function_response(function_response_string: any, strict: bool = Fal # Now check the length and make sure it doesn't go over the limit # TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window) - if truncate and len(function_response_string) > FUNCTION_RETURN_CHAR_LIMIT: + if truncate and len(function_response_string) > return_char_limit: print( - f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT}) and was truncated" + f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {return_char_limit}) and was truncated" ) - function_response_string = f"{function_response_string[:FUNCTION_RETURN_CHAR_LIMIT]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT})]" + function_response_string = f"{function_response_string[:return_char_limit]}... 
[NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {return_char_limit})]" return function_response_string diff --git a/tests/test_client.py b/tests/test_client.py index 61a95f3a..2c92ef95 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,5 @@ import asyncio +import json import os import threading import time @@ -15,6 +16,7 @@ from letta.schemas.agent import AgentState from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.job import JobStatus +from letta.schemas.letta_message import FunctionReturn from letta.schemas.llm_config import LLMConfig from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType from letta.utils import create_random_username @@ -40,7 +42,8 @@ def run_server(): @pytest.fixture( - params=[{"server": False}, {"server": True}], # whether to use REST API server + # params=[{"server": False}, {"server": True}], # whether to use REST API server + params=[{"server": False}], # whether to use REST API server scope="module", ) def client(request): @@ -299,6 +302,40 @@ def test_send_system_message(client: Union[LocalClient, RESTClient], agent: Agen assert send_system_message_response, "Sending message failed" +def test_function_return_limit(client: Union[LocalClient, RESTClient]): + """Test to see if the function return limit works""" + + def big_return(): + """ + Always call this tool. 
+ + Returns: + important_data (str): Important data + """ + return "x" * 100000 + + padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50 + tool = client.create_or_update_tool(func=big_return, return_char_limit=1000) + agent = client.create_agent(name="agent1", tools=[tool.name]) + # get function response + response = client.send_message(agent_id=agent.id, message="call the big_return function", role="user") + print(response.messages) + + response_message = None + for message in response.messages: + if isinstance(message, FunctionReturn): + response_message = message + break + + assert response_message, "FunctionReturn message not found in response" + res = response_message.function_return + assert "function output was truncated " in res + res_json = json.loads(res) + assert ( + len(res_json["message"]) <= 1000 + padding + ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}" + + @pytest.mark.asyncio async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request): """