feat: enable configuration of response_char_limit for tools (#2207)
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
"""add return_char_limit column to the tools table to hold each tool's function return character limit
|
||||
|
||||
Revision ID: a91994b9752f
|
||||
Revises: e1a625072dbf
|
||||
Create Date: 2024-12-09 18:27:25.650079
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a91994b9752f"
|
||||
down_revision: Union[str, None] = "e1a625072dbf"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the ``return_char_limit`` column to ``tools`` and backfill it.

    The column is created as a nullable integer, then every row in ``tools``
    is set to ``FUNCTION_RETURN_CHAR_LIMIT`` (the UPDATE has no WHERE clause).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    limit_column = sa.Column("return_char_limit", sa.Integer(), nullable=True)
    op.add_column("tools", limit_column)

    # Populate `return_char_limit` column
    backfill_statement = f"""
        UPDATE tools
        SET return_char_limit = {FUNCTION_RETURN_CHAR_LIMIT}
        """
    op.execute(backfill_statement)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the migration by dropping ``return_char_limit`` from ``tools``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column(table_name="tools", column_name="return_char_limit")
    # ### end Alembic commands ###
|
||||
@@ -801,7 +801,13 @@ class Agent(BaseAgent):
|
||||
# but by default, we add a truncation safeguard to prevent bad functions from
|
||||
# overflowing the agent context window
|
||||
truncate = True
|
||||
function_response_string = validate_function_response(function_response, truncate=truncate)
|
||||
|
||||
# get the function response limit
|
||||
tool_obj = [tool for tool in self.agent_state.tools if tool.name == function_name][0]
|
||||
return_char_limit = tool_obj.return_char_limit
|
||||
function_response_string = validate_function_response(
|
||||
function_response, return_char_limit=return_char_limit, truncate=truncate
|
||||
)
|
||||
function_args.pop("self", None)
|
||||
function_response = package_function_response(True, function_response_string)
|
||||
function_failed = False
|
||||
|
||||
@@ -11,6 +11,7 @@ from letta.constants import (
|
||||
BASE_TOOLS,
|
||||
DEFAULT_HUMAN,
|
||||
DEFAULT_PERSONA,
|
||||
FUNCTION_RETURN_CHAR_LIMIT,
|
||||
)
|
||||
from letta.data_sources.connectors import DataConnector
|
||||
from letta.functions.functions import parse_source_code
|
||||
@@ -200,18 +201,12 @@ class AbstractClient(object):
|
||||
raise NotImplementedError
|
||||
|
||||
def create_tool(
|
||||
self,
|
||||
func,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
self, func, name: Optional[str] = None, tags: Optional[List[str]] = None, return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT
|
||||
) -> Tool:
|
||||
raise NotImplementedError
|
||||
|
||||
def create_or_update_tool(
|
||||
self,
|
||||
func,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
self, func, name: Optional[str] = None, tags: Optional[List[str]] = None, return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT
|
||||
) -> Tool:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -222,6 +217,7 @@ class AbstractClient(object):
|
||||
description: Optional[str] = None,
|
||||
func: Optional[Callable] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
|
||||
) -> Tool:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1465,6 +1461,7 @@ class RESTClient(AbstractClient):
|
||||
func: Callable,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
|
||||
) -> Tool:
|
||||
"""
|
||||
Create a tool. This stores the source code of the function on the server, so that the server can execute the function and generate an OpenAI JSON schema for it when used with an agent.
|
||||
@@ -1473,6 +1470,7 @@ class RESTClient(AbstractClient):
|
||||
func (callable): The function to create a tool for.
|
||||
name: (str): Name of the tool (must be unique per-user.)
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
|
||||
|
||||
Returns:
|
||||
tool (Tool): The created tool.
|
||||
@@ -1481,7 +1479,9 @@ class RESTClient(AbstractClient):
|
||||
source_type = "python"
|
||||
|
||||
# call server function
|
||||
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
|
||||
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, return_char_limit=return_char_limit)
|
||||
if tags:
|
||||
request.tags = tags
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
@@ -1492,6 +1492,7 @@ class RESTClient(AbstractClient):
|
||||
func: Callable,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
|
||||
) -> Tool:
|
||||
"""
|
||||
Creates or updates a tool. This stores the source code of the function on the server, so that the server can execute the function and generate an OpenAI JSON schema for it when used with an agent.
|
||||
@@ -1500,6 +1501,7 @@ class RESTClient(AbstractClient):
|
||||
func (callable): The function to create a tool for.
|
||||
name: (str): Name of the tool (must be unique per-user.)
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
|
||||
|
||||
Returns:
|
||||
tool (Tool): The created tool.
|
||||
@@ -1508,7 +1510,9 @@ class RESTClient(AbstractClient):
|
||||
source_type = "python"
|
||||
|
||||
# call server function
|
||||
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
|
||||
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, return_char_limit=return_char_limit)
|
||||
if tags:
|
||||
request.tags = tags
|
||||
response = requests.put(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
@@ -1521,6 +1525,7 @@ class RESTClient(AbstractClient):
|
||||
description: Optional[str] = None,
|
||||
func: Optional[Callable] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
|
||||
) -> Tool:
|
||||
"""
|
||||
Update a tool with provided parameters (name, func, tags)
|
||||
@@ -1530,6 +1535,7 @@ class RESTClient(AbstractClient):
|
||||
name (str): Name of the tool
|
||||
func (callable): Function to wrap in a tool
|
||||
tags (List[str]): Tags for the tool
|
||||
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
|
||||
|
||||
Returns:
|
||||
tool (Tool): Updated tool
|
||||
@@ -1541,7 +1547,14 @@ class RESTClient(AbstractClient):
|
||||
|
||||
source_type = "python"
|
||||
|
||||
request = ToolUpdate(description=description, source_type=source_type, source_code=source_code, tags=tags, name=name)
|
||||
request = ToolUpdate(
|
||||
description=description,
|
||||
source_type=source_type,
|
||||
source_code=source_code,
|
||||
tags=tags,
|
||||
name=name,
|
||||
return_char_limit=return_char_limit,
|
||||
)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update tool: {response.text}")
|
||||
@@ -2726,6 +2739,7 @@ class LocalClient(AbstractClient):
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
description: Optional[str] = None,
|
||||
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
|
||||
) -> Tool:
|
||||
"""
|
||||
Create a tool. This stores the source code of the function on the server, so that the server can execute the function and generate an OpenAI JSON schema for it when used with an agent.
|
||||
@@ -2735,6 +2749,7 @@ class LocalClient(AbstractClient):
|
||||
name: (str): Name of the tool (must be unique per-user.)
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
description (str, optional): The description.
|
||||
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
|
||||
|
||||
Returns:
|
||||
tool (Tool): The created tool.
|
||||
@@ -2755,6 +2770,7 @@ class LocalClient(AbstractClient):
|
||||
name=name,
|
||||
tags=tags,
|
||||
description=description,
|
||||
return_char_limit=return_char_limit,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
@@ -2765,6 +2781,7 @@ class LocalClient(AbstractClient):
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
description: Optional[str] = None,
|
||||
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
|
||||
) -> Tool:
|
||||
"""
|
||||
Creates or updates a tool. This stores the source code of the function on the server, so that the server can execute the function and generate an OpenAI JSON schema for it when used with an agent.
|
||||
@@ -2774,6 +2791,7 @@ class LocalClient(AbstractClient):
|
||||
name: (str): Name of the tool (must be unique per-user.)
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
description (str, optional): The description.
|
||||
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
|
||||
|
||||
Returns:
|
||||
tool (Tool): The created tool.
|
||||
@@ -2791,6 +2809,7 @@ class LocalClient(AbstractClient):
|
||||
name=name,
|
||||
tags=tags,
|
||||
description=description,
|
||||
return_char_limit=return_char_limit,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
@@ -2802,6 +2821,7 @@ class LocalClient(AbstractClient):
|
||||
description: Optional[str] = None,
|
||||
func: Optional[callable] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
|
||||
) -> Tool:
|
||||
"""
|
||||
Update a tool with provided parameters (name, func, tags)
|
||||
@@ -2811,6 +2831,7 @@ class LocalClient(AbstractClient):
|
||||
name (str): Name of the tool
|
||||
func (callable): Function to wrap in a tool
|
||||
tags (List[str]): Tags for the tool
|
||||
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
|
||||
|
||||
Returns:
|
||||
tool (Tool): Updated tool
|
||||
@@ -2821,6 +2842,7 @@ class LocalClient(AbstractClient):
|
||||
"tags": tags,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"return_char_limit": return_char_limit,
|
||||
}
|
||||
|
||||
# Filter out any None values from the dictionary
|
||||
|
||||
@@ -30,6 +30,7 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
__table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)
|
||||
|
||||
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
|
||||
return_char_limit: Mapped[int] = mapped_column(nullable=True, doc="The maximum number of characters the tool can return.")
|
||||
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
|
||||
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
|
||||
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
|
||||
@@ -45,19 +46,16 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
|
||||
# Add event listener to update tool_name in ToolsAgents when Tool name changes
|
||||
@event.listens_for(Tool, 'before_update')
|
||||
@event.listens_for(Tool, "before_update")
|
||||
def update_tool_name_in_tools_agents(mapper, connection, target):
|
||||
"""Update tool_name in ToolsAgents when Tool name changes."""
|
||||
state = target._sa_instance_state
|
||||
history = state.get_history('name', passive=True)
|
||||
history = state.get_history("name", passive=True)
|
||||
if not history.has_changes():
|
||||
return
|
||||
|
||||
|
||||
# Get the new name and update all associated ToolsAgents records
|
||||
new_name = target.name
|
||||
from letta.orm.tools_agents import ToolsAgents
|
||||
connection.execute(
|
||||
ToolsAgents.__table__.update().where(
|
||||
ToolsAgents.tool_id == target.id
|
||||
).values(tool_name=new_name)
|
||||
)
|
||||
|
||||
connection.execute(ToolsAgents.__table__.update().where(ToolsAgents.tool_id == target.id).values(tool_name=new_name))
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.helpers import (
|
||||
generate_composio_tool_wrapper,
|
||||
@@ -41,6 +42,9 @@ class Tool(BaseTool):
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
|
||||
|
||||
# tool configuration
|
||||
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
|
||||
|
||||
# metadata fields
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
@@ -91,6 +95,7 @@ class ToolCreate(LettaBase):
|
||||
json_schema: Optional[Dict] = Field(
|
||||
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
|
||||
)
|
||||
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
|
||||
|
||||
@classmethod
|
||||
def from_composio(cls, action_name: str, api_key: Optional[str] = None) -> "ToolCreate":
|
||||
|
||||
@@ -28,7 +28,6 @@ from letta.constants import (
|
||||
CLI_WARNING_PREFIX,
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT,
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
||||
FUNCTION_RETURN_CHAR_LIMIT,
|
||||
LETTA_DIR,
|
||||
MAX_FILENAME_LENGTH,
|
||||
TOOL_CALL_ID_MAX_LEN,
|
||||
@@ -906,8 +905,8 @@ def parse_json(string) -> dict:
|
||||
raise e
|
||||
|
||||
|
||||
def validate_function_response(function_response_string: any, strict: bool = False, truncate: bool = True) -> str:
|
||||
"""Check to make sure that a function used by Letta returned a valid response
|
||||
def validate_function_response(function_response_string: any, return_char_limit: int, strict: bool = False, truncate: bool = True) -> str:
|
||||
"""Check to make sure that a function used by Letta returned a valid response. Truncates to return_char_limit if necessary.
|
||||
|
||||
Responses need to be strings (or None) that fall under a certain text count limit.
|
||||
"""
|
||||
@@ -943,11 +942,11 @@ def validate_function_response(function_response_string: any, strict: bool = Fal
|
||||
|
||||
# Now check the length and make sure it doesn't go over the limit
|
||||
# TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window)
|
||||
if truncate and len(function_response_string) > FUNCTION_RETURN_CHAR_LIMIT:
|
||||
if truncate and len(function_response_string) > return_char_limit:
|
||||
print(
|
||||
f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT}) and was truncated"
|
||||
f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {return_char_limit}) and was truncated"
|
||||
)
|
||||
function_response_string = f"{function_response_string[:FUNCTION_RETURN_CHAR_LIMIT]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT})]"
|
||||
function_response_string = f"{function_response_string[:return_char_limit]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {return_char_limit})]"
|
||||
|
||||
return function_response_string
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -15,6 +16,7 @@ from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.job import JobStatus
|
||||
from letta.schemas.letta_message import FunctionReturn
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType
|
||||
from letta.utils import create_random_username
|
||||
@@ -40,7 +42,8 @@ def run_server():
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[{"server": False}, {"server": True}], # whether to use REST API server
|
||||
# params=[{"server": False}, {"server": True}], # whether to use REST API server
|
||||
params=[{"server": False}], # whether to use REST API server
|
||||
scope="module",
|
||||
)
|
||||
def client(request):
|
||||
@@ -299,6 +302,40 @@ def test_send_system_message(client: Union[LocalClient, RESTClient], agent: Agen
|
||||
assert send_system_message_response, "Sending message failed"
|
||||
|
||||
|
||||
def test_function_return_limit(client: Union[LocalClient, RESTClient]):
    """Check that a tool created with ``return_char_limit=1000`` has its oversized return truncated."""

    def big_return():
        """
        Always call this tool.

        Returns:
            important_data (str): Important data
        """
        return "x" * 100000

    # Allowance for the truncation notice appended to the kept characters, plus 50 characters of slack
    truncation_note = "[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]"
    padding = len(truncation_note) + 50

    limited_tool = client.create_or_update_tool(func=big_return, return_char_limit=1000)
    test_agent = client.create_agent(name="agent1", tools=[limited_tool.name])

    # get function response
    response = client.send_message(agent_id=test_agent.id, message="call the big_return function", role="user")
    print(response.messages)

    # First FunctionReturn message in the response, or None when there is none
    response_message = next((msg for msg in response.messages if isinstance(msg, FunctionReturn)), None)
    assert response_message, "FunctionReturn message not found in response"

    res = response_message.function_return
    assert "function output was truncated " in res

    message_length = len(json.loads(res)["message"])
    assert (
        message_length <= 1000 + padding
    ), f"Expected length to be less than or equal to 1000 + {padding}, but got {message_length}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user