feat: enable configuration of response_char_limit for tools (#2207)

This commit is contained in:
Sarah Wooders
2024-12-09 18:55:18 -08:00
committed by GitHub
parent fc980ff654
commit d61b2f9545
7 changed files with 133 additions and 27 deletions

View File

@@ -0,0 +1,39 @@
"""add column to tools table to contain function return limit return_char_limit
Revision ID: a91994b9752f
Revises: e1a625072dbf
Create Date: 2024-12-09 18:27:25.650079
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
# revision identifiers, used by Alembic.
revision: str = "a91994b9752f"  # unique id of this migration
down_revision: Union[str, None] = "e1a625072dbf"  # parent revision in the migration chain
branch_labels: Union[str, Sequence[str], None] = None  # no named branches
depends_on: Union[str, Sequence[str], None] = None  # no cross-branch dependencies
def upgrade() -> None:
    """Add the ``return_char_limit`` column to ``tools`` and backfill existing rows.

    The column is created nullable so the ``ADD COLUMN`` succeeds on already
    populated tables; every existing row is then backfilled with the default
    ``FUNCTION_RETURN_CHAR_LIMIT`` value.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("tools", sa.Column("return_char_limit", sa.Integer(), nullable=True))

    # Populate `return_char_limit` for all pre-existing rows. Use a bound
    # parameter instead of f-string interpolation so the value is handed to
    # the database driver rather than spliced into the SQL text.
    op.execute(
        sa.text("UPDATE tools SET return_char_limit = :limit").bindparams(
            limit=FUNCTION_RETURN_CHAR_LIMIT
        )
    )
def downgrade() -> None:
    """Revert the migration by dropping ``return_char_limit`` from ``tools``.

    Removing the column discards the stored per-tool limits; no other
    cleanup is required because the backfilled values live only here.
    """
    op.drop_column("tools", "return_char_limit")

View File

@@ -801,7 +801,13 @@ class Agent(BaseAgent):
# but by default, we add a truncation safeguard to prevent bad functions from
# overflow the agent context window
truncate = True
function_response_string = validate_function_response(function_response, truncate=truncate)
# get the function response limit
tool_obj = [tool for tool in self.agent_state.tools if tool.name == function_name][0]
return_char_limit = tool_obj.return_char_limit
function_response_string = validate_function_response(
function_response, return_char_limit=return_char_limit, truncate=truncate
)
function_args.pop("self", None)
function_response = package_function_response(True, function_response_string)
function_failed = False

View File

@@ -11,6 +11,7 @@ from letta.constants import (
BASE_TOOLS,
DEFAULT_HUMAN,
DEFAULT_PERSONA,
FUNCTION_RETURN_CHAR_LIMIT,
)
from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code
@@ -200,18 +201,12 @@ class AbstractClient(object):
raise NotImplementedError
def create_tool(
self,
func,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
self, func, name: Optional[str] = None, tags: Optional[List[str]] = None, return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT
) -> Tool:
raise NotImplementedError
def create_or_update_tool(
self,
func,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
self, func, name: Optional[str] = None, tags: Optional[List[str]] = None, return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT
) -> Tool:
raise NotImplementedError
@@ -222,6 +217,7 @@ class AbstractClient(object):
description: Optional[str] = None,
func: Optional[Callable] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
raise NotImplementedError
@@ -1465,6 +1461,7 @@ class RESTClient(AbstractClient):
func: Callable,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
@@ -1473,6 +1470,7 @@ class RESTClient(AbstractClient):
func (callable): The function to create a tool for.
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): The created tool.
@@ -1481,7 +1479,9 @@ class RESTClient(AbstractClient):
source_type = "python"
# call server function
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, return_char_limit=return_char_limit)
if tags:
request.tags = tags
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
@@ -1492,6 +1492,7 @@ class RESTClient(AbstractClient):
func: Callable,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
@@ -1500,6 +1501,7 @@ class RESTClient(AbstractClient):
func (callable): The function to create a tool for.
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): The created tool.
@@ -1508,7 +1510,9 @@ class RESTClient(AbstractClient):
source_type = "python"
# call server function
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, return_char_limit=return_char_limit)
if tags:
request.tags = tags
response = requests.put(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
@@ -1521,6 +1525,7 @@ class RESTClient(AbstractClient):
description: Optional[str] = None,
func: Optional[Callable] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Update a tool with provided parameters (name, func, tags)
@@ -1530,6 +1535,7 @@ class RESTClient(AbstractClient):
name (str): Name of the tool
func (callable): Function to wrap in a tool
tags (List[str]): Tags for the tool
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): Updated tool
@@ -1541,7 +1547,14 @@ class RESTClient(AbstractClient):
source_type = "python"
request = ToolUpdate(description=description, source_type=source_type, source_code=source_code, tags=tags, name=name)
request = ToolUpdate(
description=description,
source_type=source_type,
source_code=source_code,
tags=tags,
name=name,
return_char_limit=return_char_limit,
)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update tool: {response.text}")
@@ -2726,6 +2739,7 @@ class LocalClient(AbstractClient):
name: Optional[str] = None,
tags: Optional[List[str]] = None,
description: Optional[str] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
@@ -2735,6 +2749,7 @@ class LocalClient(AbstractClient):
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
description (str, optional): The description.
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): The created tool.
@@ -2755,6 +2770,7 @@ class LocalClient(AbstractClient):
name=name,
tags=tags,
description=description,
return_char_limit=return_char_limit,
),
actor=self.user,
)
@@ -2765,6 +2781,7 @@ class LocalClient(AbstractClient):
name: Optional[str] = None,
tags: Optional[List[str]] = None,
description: Optional[str] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
@@ -2774,6 +2791,7 @@ class LocalClient(AbstractClient):
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
description (str, optional): The description.
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): The created tool.
@@ -2791,6 +2809,7 @@ class LocalClient(AbstractClient):
name=name,
tags=tags,
description=description,
return_char_limit=return_char_limit,
),
actor=self.user,
)
@@ -2802,6 +2821,7 @@ class LocalClient(AbstractClient):
description: Optional[str] = None,
func: Optional[callable] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Update a tool with provided parameters (name, func, tags)
@@ -2811,6 +2831,7 @@ class LocalClient(AbstractClient):
name (str): Name of the tool
func (callable): Function to wrap in a tool
tags (List[str]): Tags for the tool
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): Updated tool
@@ -2821,6 +2842,7 @@ class LocalClient(AbstractClient):
"tags": tags,
"name": name,
"description": description,
"return_char_limit": return_char_limit,
}
# Filter out any None values from the dictionary

View File

@@ -30,6 +30,7 @@ class Tool(SqlalchemyBase, OrganizationMixin):
__table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
return_char_limit: Mapped[int] = mapped_column(nullable=True, doc="The maximum number of characters the tool can return.")
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
@@ -45,19 +46,16 @@ class Tool(SqlalchemyBase, OrganizationMixin):
# Add event listener to update tool_name in ToolsAgents when Tool name changes
@event.listens_for(Tool, 'before_update')
@event.listens_for(Tool, "before_update")
def update_tool_name_in_tools_agents(mapper, connection, target):
"""Update tool_name in ToolsAgents when Tool name changes."""
state = target._sa_instance_state
history = state.get_history('name', passive=True)
history = state.get_history("name", passive=True)
if not history.has_changes():
return
# Get the new name and update all associated ToolsAgents records
new_name = target.name
from letta.orm.tools_agents import ToolsAgents
connection.execute(
ToolsAgents.__table__.update().where(
ToolsAgents.tool_id == target.id
).values(tool_name=new_name)
)
connection.execute(ToolsAgents.__table__.update().where(ToolsAgents.tool_id == target.id).values(tool_name=new_name))

View File

@@ -2,6 +2,7 @@ from typing import Dict, List, Optional
from pydantic import Field, model_validator
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
from letta.functions.functions import derive_openai_json_schema
from letta.functions.helpers import (
generate_composio_tool_wrapper,
@@ -41,6 +42,9 @@ class Tool(BaseTool):
source_code: str = Field(..., description="The source code of the function.")
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
# tool configuration
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
# metadata fields
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
@@ -91,6 +95,7 @@ class ToolCreate(LettaBase):
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
@classmethod
def from_composio(cls, action_name: str, api_key: Optional[str] = None) -> "ToolCreate":

View File

@@ -28,7 +28,6 @@ from letta.constants import (
CLI_WARNING_PREFIX,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
FUNCTION_RETURN_CHAR_LIMIT,
LETTA_DIR,
MAX_FILENAME_LENGTH,
TOOL_CALL_ID_MAX_LEN,
@@ -906,8 +905,8 @@ def parse_json(string) -> dict:
raise e
def validate_function_response(function_response_string: any, strict: bool = False, truncate: bool = True) -> str:
"""Check to make sure that a function used by Letta returned a valid response
def validate_function_response(function_response_string: any, return_char_limit: int, strict: bool = False, truncate: bool = True) -> str:
"""Check to make sure that a function used by Letta returned a valid response. Truncates to return_char_limit if necessary.
Responses need to be strings (or None) that fall under a certain text count limit.
"""
@@ -943,11 +942,11 @@ def validate_function_response(function_response_string: any, strict: bool = Fal
# Now check the length and make sure it doesn't go over the limit
# TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window)
if truncate and len(function_response_string) > FUNCTION_RETURN_CHAR_LIMIT:
if truncate and len(function_response_string) > return_char_limit:
print(
f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT}) and was truncated"
f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {return_char_limit}) and was truncated"
)
function_response_string = f"{function_response_string[:FUNCTION_RETURN_CHAR_LIMIT]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT})]"
function_response_string = f"{function_response_string[:return_char_limit]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {return_char_limit})]"
return function_response_string

View File

@@ -1,4 +1,5 @@
import asyncio
import json
import os
import threading
import time
@@ -15,6 +16,7 @@ from letta.schemas.agent import AgentState
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.job import JobStatus
from letta.schemas.letta_message import FunctionReturn
from letta.schemas.llm_config import LLMConfig
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType
from letta.utils import create_random_username
@@ -40,7 +42,8 @@ def run_server():
@pytest.fixture(
params=[{"server": False}, {"server": True}], # whether to use REST API server
# params=[{"server": False}, {"server": True}], # whether to use REST API server
params=[{"server": False}], # whether to use REST API server
scope="module",
)
def client(request):
@@ -299,6 +302,40 @@ def test_send_system_message(client: Union[LocalClient, RESTClient], agent: Agen
assert send_system_message_response, "Sending message failed"
def test_function_return_limit(client: Union[LocalClient, RESTClient]):
    """Test to see if the function return limit works"""

    def big_return():
        """
        Always call this tool.

        Returns:
            important_data (str): Important data
        """
        return "x" * 100000

    # Allowance for the truncation notice appended to the clipped output,
    # plus 50 chars of slack since the lengths embedded in the notice vary
    # with the actual response size.
    padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50

    # Register the tool with a 1000-char return limit and attach it to a fresh agent.
    tool = client.create_or_update_tool(func=big_return, return_char_limit=1000)
    agent = client.create_agent(name="agent1", tools=[tool.name])

    # get function response
    response = client.send_message(agent_id=agent.id, message="call the big_return function", role="user")
    print(response.messages)

    # Locate the FunctionReturn message emitted for the tool call.
    # NOTE(review): assumes the model actually calls big_return — the prompt
    # and the tool docstring both push it to; confirm flakiness in CI.
    response_message = None
    for message in response.messages:
        if isinstance(message, FunctionReturn):
            response_message = message
            break

    assert response_message, "FunctionReturn message not found in response"

    # The truncated output must carry the truncation notice and stay within
    # the configured limit plus the notice padding.
    res = response_message.function_return
    assert "function output was truncated " in res
    res_json = json.loads(res)
    assert (
        len(res_json["message"]) <= 1000 + padding
    ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}"
@pytest.mark.asyncio
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
"""