chore: patch update tools (#4090)
* patch update tools * update tool patch * fallback to generation for legacy tools * avoid re-parsing source if json schema exists * fix more tests * remove assert * fix * update * update * update * Fix tests --------- Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -11,7 +11,7 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import LettaRequest, MessageCreate, TextContent
|
||||
from letta_client import LettaRequest, MessageCreate, TerminalToolRule, TextContent
|
||||
from letta_client.client import BaseTool
|
||||
from letta_client.core import ApiError
|
||||
from letta_client.types import AgentState, ToolReturnMessage
|
||||
@@ -942,10 +942,22 @@ def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKCl
|
||||
print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}")
|
||||
return True
|
||||
|
||||
# test creation
|
||||
tool = client.tools.add(
|
||||
tool=ManageInventoryTool(),
|
||||
)
|
||||
|
||||
# test that upserting also works
|
||||
new_description = "NEW"
|
||||
|
||||
class ManageInventoryToolModified(ManageInventoryTool):
|
||||
description: str = new_description
|
||||
|
||||
tool = client.tools.add(
|
||||
tool=ManageInventoryToolModified(),
|
||||
)
|
||||
assert tool.description == new_description
|
||||
|
||||
assert tool is not None
|
||||
assert tool.name == "manage_inventory"
|
||||
assert "inventory" in tool.tags
|
||||
@@ -1005,7 +1017,7 @@ def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKCl
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [False], indirect=True)
|
||||
def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
|
||||
|
||||
class Step(BaseModel):
|
||||
@@ -1021,7 +1033,18 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
|
||||
print(f"Created task plan with {len(steps)} steps: {explanation}")
|
||||
return steps
|
||||
|
||||
tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"])
|
||||
# test creation
|
||||
client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"])
|
||||
|
||||
# test upsert
|
||||
new_steps_description = "NEW"
|
||||
|
||||
class StepsListModified(BaseModel):
|
||||
steps: List[Step] = Field(..., description=new_steps_description)
|
||||
explanation: str = Field(..., description="Explanation for the list of steps.")
|
||||
|
||||
tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsListModified, description=new_steps_description)
|
||||
assert tool.description == new_steps_description
|
||||
|
||||
assert tool is not None
|
||||
assert tool.name == "create_task_plan"
|
||||
@@ -1039,6 +1062,9 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
|
||||
embedding="openai/text-embedding-3-small",
|
||||
tool_ids=[tool.id],
|
||||
include_base_tools=False,
|
||||
tool_rules=[
|
||||
TerminalToolRule(tool_name=tool.name),
|
||||
],
|
||||
)
|
||||
|
||||
response = client.agents.messages.create(
|
||||
@@ -1062,6 +1088,7 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
|
||||
assert first_tool_call.tool_call.name == "create_task_plan"
|
||||
|
||||
args = json.loads(first_tool_call.tool_call.arguments)
|
||||
|
||||
assert "steps" in args
|
||||
assert "explanation" in args
|
||||
assert isinstance(args["steps"], list)
|
||||
@@ -1224,145 +1251,6 @@ def test_agent_tools_list(client: LettaSDKClient):
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
def test_update_tool_source_code_changes_name(client: LettaSDKClient):
    """Replacing a tool's source code should rename the tool after the new function.

    The test registers ``initial_tool``, swaps its source for a function called
    ``updated_tool`` via ``client.tools.modify`` and then checks the returned
    tool's name, source code and JSON schema. The tool is deleted afterwards
    regardless of the outcome.
    """
    import textwrap

    # Function whose source is registered as the starting tool.
    def initial_tool(x: int) -> int:
        """
        Multiply a number by 2

        Args:
            x: The input number
        Returns:
            The input multiplied by 2
        """
        return x * 2

    original = client.tools.upsert_from_function(func=initial_tool)
    assert original.name == "initial_tool"

    try:
        # Replacement source: a differently named function with two parameters.
        replacement_source = textwrap.dedent(
            """
            def updated_tool(x: int, y: int) -> int:
                '''
                Add two numbers together

                Args:
                    x: First number
                    y: Second number
                Returns:
                    Sum of x and y
                '''
                return x + y
            """
        ).strip()

        renamed = client.tools.modify(tool_id=original.id, source_code=replacement_source)

        # Name and stored source must reflect the replacement.
        assert renamed.name == "updated_tool"
        assert renamed.source_code == replacement_source

        # The schema is expected to be rebuilt from the replacement source.
        schema = renamed.json_schema
        assert schema is not None
        assert schema["name"] == "updated_tool"
        assert schema["description"] == "Add two numbers together"

        # Both parameters must be present, typed as integers and documented.
        parameters = schema.get("parameters", {})
        props = parameters.get("properties", {})
        expected_docs = {"x": "First number", "y": "Second number"}
        for arg_name in expected_docs:
            assert arg_name in props
        for arg_name in expected_docs:
            assert props[arg_name]["type"] == "integer"
        for arg_name, arg_doc in expected_docs.items():
            assert props[arg_name]["description"] == arg_doc
        assert parameters["required"] == ["x", "y"]

    finally:
        # Always remove the tool created by this test.
        client.tools.delete(tool_id=original.id)
|
||||
|
||||
|
||||
def test_update_tool_source_code_duplicate_name_error(client: LettaSDKClient):
    """Test that updating a tool's source code to have the same name as another existing tool raises an error.

    Two tools (``first_tool`` and ``second_tool``) are registered, then the
    second one is modified with source code that defines ``first_tool``. The
    modify call must raise, and the second tool must keep its original name.

    Cleanup is registered on an ``ExitStack`` immediately after each tool is
    created, so a failure while creating the second tool still removes the
    first, and a failing delete does not prevent the other delete from running.
    """
    import textwrap
    from contextlib import ExitStack

    # Create first tool
    def first_tool(x: int) -> int:
        """
        Multiply a number by 2

        Args:
            x: The input number

        Returns:
            The input multiplied by 2
        """
        return x * 2

    # Create second tool
    def second_tool(x: int) -> int:
        """
        Multiply a number by 3

        Args:
            x: The input number

        Returns:
            The input multiplied by 3
        """
        return x * 3

    with ExitStack() as cleanup:
        # Create both tools, scheduling each one's deletion as soon as it exists.
        tool1 = client.tools.upsert_from_function(func=first_tool)
        cleanup.callback(client.tools.delete, tool_id=tool1.id)
        tool2 = client.tools.upsert_from_function(func=second_tool)
        cleanup.callback(client.tools.delete, tool_id=tool2.id)

        assert tool1.name == "first_tool"
        assert tool2.name == "second_tool"

        # Try to update second_tool to have the same name as first_tool
        new_source_code = textwrap.dedent(
            """
            def first_tool(x: int) -> int:
                '''
                Multiply a number by 4

                Args:
                    x: The input number

                Returns:
                    The input multiplied by 4
                '''
                return x * 4
            """
        ).strip()

        # This should raise an error since first_tool already exists
        with pytest.raises(Exception) as exc_info:
            client.tools.modify(tool_id=tool2.id, source_code=new_source_code)

        # Verify the error message indicates duplicate name
        error_message = str(exc_info.value).lower()
        assert "already exists" in error_message or "duplicate" in error_message or "conflict" in error_message

        # Verify that tool2 was not modified
        tool2_check = client.tools.retrieve(tool_id=tool2.id)
        assert tool2_check.name == "second_tool"  # Name should remain unchanged
|
||||
|
||||
|
||||
def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient):
|
||||
"""Test adding a tool with multiple functions in the source code"""
|
||||
import textwrap
|
||||
@@ -1445,143 +1333,144 @@ def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient)
|
||||
client.tools.delete(tool_id=tool.id)
|
||||
|
||||
|
||||
def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
    """Tool name should follow the last function when multi-function source changes.

    A tool is created from source containing three functions, then modified
    twice (first with two functions, then with one). After each step the test
    checks the tool's name, stored source and JSON schema. The tool is deleted
    afterwards regardless of the outcome.
    """
    import textwrap

    # Three functions; the test expects the last one to name the tool.
    three_function_source = textwrap.dedent(
        """
        def helper_function(x: int) -> int:
            '''
            Helper function that doubles the input

            Args:
                x: The input number

            Returns:
                The input multiplied by 2
            '''
            return x * 2

        def another_helper(text: str) -> str:
            '''
            Another helper that uppercases text

            Args:
                text: The input text to uppercase

            Returns:
                The uppercased text
            '''
            return text.upper()

        def main_function(x: int, y: int) -> int:
            '''
            Main function that uses the helper

            Args:
                x: First number
                y: Second number

            Returns:
                Result of (x * 2) + y
            '''
            doubled_x = helper_function(x)
            return doubled_x + y
        """
    ).strip()

    created = client.tools.create(
        source_code=three_function_source,
    )

    try:
        # Creation: named after the last function, source stored verbatim.
        assert created is not None
        assert created.name == "main_function"
        assert created.source_code == three_function_source

        # Second revision: two unrelated functions in a different order.
        two_function_source = textwrap.dedent(
            """
            def process_data(data: str, count: int) -> str:
                '''
                Process data by repeating it

                Args:
                    data: The input data
                    count: Number of times to repeat

                Returns:
                    The processed data
                '''
                return data * count

            def helper_utility(x: float) -> float:
                '''
                Helper utility function

                Args:
                    x: Input value

                Returns:
                    Squared value
                '''
                return x * x
            """
        ).strip()

        after_first_edit = client.tools.modify(tool_id=created.id, source_code=two_function_source)

        # Name and source now track the last function of the new source.
        assert after_first_edit.name == "helper_utility"
        assert after_first_edit.source_code == two_function_source

        # Schema must describe helper_utility.
        first_schema = after_first_edit.json_schema
        assert first_schema is not None
        assert first_schema["name"] == "helper_utility"
        assert first_schema["description"] == "Helper utility function"

        # Single float parameter, reported as a JSON "number".
        parameters = first_schema.get("parameters", {})
        props = parameters.get("properties", {})
        assert "x" in props
        assert props["x"]["type"] == "number"  # float maps to number
        assert parameters["required"] == ["x"]

        # Third revision: source with exactly one function.
        one_function_source = textwrap.dedent(
            """
            def calculate_total(items: list, tax_rate: float) -> float:
                '''
                Calculate total with tax

                Args:
                    items: List of item prices
                    tax_rate: Tax rate as decimal

                Returns:
                    Total including tax
                '''
                subtotal = sum(items)
                return subtotal * (1 + tax_rate)
            """
        ).strip()

        after_second_edit = client.tools.modify(tool_id=created.id, source_code=one_function_source)

        # Name, source and description follow the single function.
        assert after_second_edit.name == "calculate_total"
        assert after_second_edit.source_code == one_function_source
        assert after_second_edit.json_schema["description"] == "Calculate total with tax"

    finally:
        # Always remove the tool created by this test.
        client.tools.delete(tool_id=created.id)
|
||||
# TODO: add back once behavior is defined
|
||||
# def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
|
||||
# """Test that tool name auto-updates when source code changes with multiple functions"""
|
||||
# import textwrap
|
||||
#
|
||||
# # Initial source code with multiple functions
|
||||
# initial_source_code = textwrap.dedent(
|
||||
# """
|
||||
# def helper_function(x: int) -> int:
|
||||
# '''
|
||||
# Helper function that doubles the input
|
||||
#
|
||||
# Args:
|
||||
# x: The input number
|
||||
#
|
||||
# Returns:
|
||||
# The input multiplied by 2
|
||||
# '''
|
||||
# return x * 2
|
||||
#
|
||||
# def another_helper(text: str) -> str:
|
||||
# '''
|
||||
# Another helper that uppercases text
|
||||
#
|
||||
# Args:
|
||||
# text: The input text to uppercase
|
||||
#
|
||||
# Returns:
|
||||
# The uppercased text
|
||||
# '''
|
||||
# return text.upper()
|
||||
#
|
||||
# def main_function(x: int, y: int) -> int:
|
||||
# '''
|
||||
# Main function that uses the helper
|
||||
#
|
||||
# Args:
|
||||
# x: First number
|
||||
# y: Second number
|
||||
#
|
||||
# Returns:
|
||||
# Result of (x * 2) + y
|
||||
# '''
|
||||
# doubled_x = helper_function(x)
|
||||
# return doubled_x + y
|
||||
# """
|
||||
# ).strip()
|
||||
#
|
||||
# # Create tool with initial source code
|
||||
# tool = client.tools.create(
|
||||
# source_code=initial_source_code,
|
||||
# )
|
||||
#
|
||||
# try:
|
||||
# # Verify the tool was created with the last function's name
|
||||
# assert tool is not None
|
||||
# assert tool.name == "main_function"
|
||||
# assert tool.source_code == initial_source_code
|
||||
#
|
||||
# # Now modify the source code with a different function order
|
||||
# new_source_code = textwrap.dedent(
|
||||
# """
|
||||
# def process_data(data: str, count: int) -> str:
|
||||
# '''
|
||||
# Process data by repeating it
|
||||
#
|
||||
# Args:
|
||||
# data: The input data
|
||||
# count: Number of times to repeat
|
||||
#
|
||||
# Returns:
|
||||
# The processed data
|
||||
# '''
|
||||
# return data * count
|
||||
#
|
||||
# def helper_utility(x: float) -> float:
|
||||
# '''
|
||||
# Helper utility function
|
||||
#
|
||||
# Args:
|
||||
# x: Input value
|
||||
#
|
||||
# Returns:
|
||||
# Squared value
|
||||
# '''
|
||||
# return x * x
|
||||
# """
|
||||
# ).strip()
|
||||
#
|
||||
# # Modify the tool with new source code
|
||||
# modified_tool = client.tools.modify(name="helper_utility", tool_id=tool.id, source_code=new_source_code)
|
||||
#
|
||||
# # Verify the name automatically updated to the last function
|
||||
# assert modified_tool.name == "helper_utility"
|
||||
# assert modified_tool.source_code == new_source_code
|
||||
#
|
||||
# # Verify the JSON schema updated correctly
|
||||
# assert modified_tool.json_schema is not None
|
||||
# assert modified_tool.json_schema["name"] == "helper_utility"
|
||||
# assert modified_tool.json_schema["description"] == "Helper utility function"
|
||||
#
|
||||
# # Check parameters updated correctly
|
||||
# params = modified_tool.json_schema.get("parameters", {})
|
||||
# properties = params.get("properties", {})
|
||||
# assert "x" in properties
|
||||
# assert properties["x"]["type"] == "number" # float maps to number
|
||||
# assert params["required"] == ["x"]
|
||||
#
|
||||
# # Test one more modification with only one function
|
||||
# single_function_code = textwrap.dedent(
|
||||
# """
|
||||
# def calculate_total(items: list, tax_rate: float) -> float:
|
||||
# '''
|
||||
# Calculate total with tax
|
||||
#
|
||||
# Args:
|
||||
# items: List of item prices
|
||||
# tax_rate: Tax rate as decimal
|
||||
#
|
||||
# Returns:
|
||||
# Total including tax
|
||||
# '''
|
||||
# subtotal = sum(items)
|
||||
# return subtotal * (1 + tax_rate)
|
||||
# """
|
||||
# ).strip()
|
||||
#
|
||||
# # Modify again
|
||||
# final_tool = client.tools.modify(tool_id=tool.id, source_code=single_function_code)
|
||||
#
|
||||
# # Verify name updated again
|
||||
# assert final_tool.name == "calculate_total"
|
||||
# assert final_tool.source_code == single_function_code
|
||||
# assert final_tool.json_schema["description"] == "Calculate total with tax"
|
||||
#
|
||||
# finally:
|
||||
# # Clean up
|
||||
# client.tools.delete(tool_id=tool.id)
|
||||
|
||||
|
||||
def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
|
||||
@@ -1637,28 +1526,16 @@ def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
|
||||
},
|
||||
}
|
||||
|
||||
# Modify the tool with both new source code AND JSON schema
|
||||
modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema)
|
||||
# verify there is a 400 error when both source code and json schema are provided
|
||||
with pytest.raises(Exception) as e:
|
||||
client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema)
|
||||
assert e.value.status_code == 400
|
||||
|
||||
# Verify the name comes from the source code function name, not the JSON schema
|
||||
assert modified_tool.name == "renamed_function"
|
||||
assert modified_tool.source_code == new_source_code
|
||||
|
||||
# Verify the JSON schema was updated to match the function name from source code
|
||||
assert modified_tool.json_schema is not None
|
||||
assert modified_tool.json_schema["name"] == "renamed_function"
|
||||
|
||||
# The description should come from the source code docstring, not the JSON schema
|
||||
assert modified_tool.json_schema["description"] == "Multiply a value by a multiplier"
|
||||
|
||||
# Verify parameters are from the source code, not the custom JSON schema
|
||||
params = modified_tool.json_schema.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
assert "value" in properties
|
||||
assert "multiplier" in properties
|
||||
assert properties["value"]["type"] == "number"
|
||||
assert properties["multiplier"]["type"] == "number"
|
||||
assert params["required"] == ["value"]
|
||||
# update with consistent name and schema
|
||||
custom_json_schema["name"] = "renamed_function"
|
||||
tool = client.tools.modify(tool_id=tool.id, json_schema=custom_json_schema)
|
||||
assert tool.json_schema == custom_json_schema
|
||||
assert tool.name == "renamed_function"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
@@ -2006,3 +1883,58 @@ def test_import_agent_with_files_from_disk(client: LettaSDKClient):
|
||||
# Clean up agents and sources
|
||||
client.agents.delete(agent_id=imported_agent_id)
|
||||
client.sources.delete(source_id=imported_source.id)
|
||||
|
||||
|
||||
def test_upsert_tools(client: LettaSDKClient):
    """Test upserting tools with complex schemas.

    Registers ``write_reason`` with an explicit ``args_schema`` whose single
    field is a list of nested pydantic models, then checks the returned tool's
    name. The tool is deleted in a ``finally`` block so that a failing
    assertion does not leave it behind on the server.
    """
    from typing import List

    class WriteReasonOffer(BaseModel):
        biltMerchantId: str = Field(..., description="The merchant ID (e.g. 'MERCHANT_NETWORK-123' or 'LYFT')")
        campaignId: str = Field(
            ...,
            description="The campaign ID (e.g. '550e8400-e29b-41d4-a716-446655440000' or '550e8400-e29b-41d4-a716-446655440000_123e4567-e89b-12d3-a456-426614174000')",
        )
        reason: str = Field(
            ...,
            description="A detailed explanation of why this offer is relevant to the user. Refer to the category-specific reason_instructions_{category} block for all guidelines on creating personalized reasons.",
        )

    class WriteReasonArgs(BaseModel):
        """Arguments for the write_reason tool."""

        offer_list: List[WriteReasonOffer] = Field(
            ...,
            description="List of WriteReasonOffer objects with merchant and campaign information",
        )

    def write_reason(offer_list: List[WriteReasonOffer]):
        """
        This tool is used to write detailed reasons for a list of offers.
        It returns the essential information: biltMerchantId, campaignId, and reason.

        IMPORTANT: When generating reasons, you MUST ONLY follow the guidelines in the
        category-specific instruction block named "reason_instructions_{category}" where
        {category} is the category of the offer (e.g., dining, travel, shopping).

        These instruction blocks contain all the necessary guidelines for creating
        personalized, detailed reasons for each category. Do not rely on any other
        instructions outside of these blocks.

        Args:
            offer_list: List of WriteReasonOffer objects, each containing:
                - biltMerchantId: The merchant ID (e.g. 'MERCHANT_NETWORK-123' or 'LYFT')
                - campaignId: The campaign ID (e.g. '124', '28')
                - reason: A detailed explanation generated according to the category-specific reason_instructions_{category} block

        Returns:
            None: This function prints the offer list but does not return a value.
        """
        print(offer_list)

    tool = client.tools.upsert_from_function(func=write_reason, args_schema=WriteReasonArgs)
    try:
        assert tool is not None
        assert tool.name == "write_reason"
    finally:
        # Clean up even when an assertion above fails. If the upsert returned
        # None there is nothing to delete, and the assertion error propagates.
        if tool is not None:
            client.tools.delete(tool.id)
|
||||
|
||||
Reference in New Issue
Block a user