chore: patch update tools (#4090)

* patch update tools

* update tool patch

* fallback to generation for legacy tools

* avoid re-parsing source if json schema exists

* fix more tests

* remove asssert

* fix

* update

* update

* update

* Fix tests

---------

Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
Sarah Wooders
2025-08-28 11:34:36 -07:00
committed by GitHub
parent 2d6727673a
commit b4fc7012cc
6 changed files with 279 additions and 569 deletions

View File

@@ -60,6 +60,15 @@ class LettaToolNameConflictError(LettaError):
)
class LettaToolNameSchemaMismatchError(LettaToolCreateError):
"""Error raised when a tool name our source codedoes not match the name in the JSON schema."""
def __init__(self, tool_name: str, json_schema_name: str, source_code: str):
super().__init__(
message=f"Tool name '{tool_name}' does not match the name in the JSON schema '{json_schema_name}' or in the source code `{source_code}`",
)
class LettaConfigurationError(LettaError):
"""Error raised when there are configuration-related issues."""

View File

@@ -80,7 +80,8 @@ class Tool(BaseTool):
"""
from letta.functions.helpers import generate_model_from_args_json_schema
if self.tool_type == ToolType.CUSTOM:
if self.tool_type == ToolType.CUSTOM and not self.json_schema:
# attempt various fallbacks to get the JSON schema
if not self.source_code:
logger.error("Custom tool with id=%s is missing source_code field", self.id)
raise ValueError(f"Custom tool with id={self.id} is missing source_code field.")
@@ -157,7 +158,7 @@ class Tool(BaseTool):
class ToolCreate(LettaBase):
description: Optional[str] = Field(None, description="The description of the tool.")
tags: List[str] = Field([], description="Metadata tags.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
source_code: str = Field(..., description="The source code of the function.")
source_type: str = Field("python", description="The source type of the function.")
json_schema: Optional[Dict] = Field(

View File

@@ -130,7 +130,7 @@ async def create_tool(
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
tool = Tool(**request.model_dump())
tool = Tool(**request.model_dump(exclude_unset=True))
return await server.tool_manager.create_tool_async(pydantic_tool=tool, actor=actor)
except UniqueConstraintViolationError as e:
# Log or print the full exception here for debugging
@@ -162,7 +162,9 @@ async def upsert_tool(
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
tool = await server.tool_manager.create_or_update_tool_async(pydantic_tool=Tool(**request.model_dump()), actor=actor)
tool = await server.tool_manager.create_or_update_tool_async(
pydantic_tool=Tool(**request.model_dump(exclude_unset=True)), actor=actor
)
return tool
except UniqueConstraintViolationError as e:
# Log the error and raise a conflict exception
@@ -190,18 +192,17 @@ async def modify_tool(
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor)
tool = await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor)
print("FINAL TOOL", tool)
return tool
except LettaToolNameConflictError as e:
# HTTP 409 == Conflict
print(f"Tool name conflict during update: {e}")
raise HTTPException(status_code=409, detail=str(e))
except LettaToolCreateError as e:
# HTTP 400 == Bad Request
print(f"Error occurred during tool update: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
# Catch other unexpected errors and raise an internal server error
print(f"Unexpected error occurred: {e}")
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")

View File

@@ -18,7 +18,7 @@ from letta.constants import (
LOCAL_ONLY_MULTI_AGENT_TOOLS,
MCP_TOOL_TAG_NAME_PREFIX,
)
from letta.errors import LettaToolNameConflictError
from letta.errors import LettaToolNameConflictError, LettaToolNameSchemaMismatchError
from letta.functions.functions import derive_openai_json_schema, load_function_set
from letta.log import get_logger
@@ -403,6 +403,7 @@ class ToolManager:
updated_tool_type: Optional[ToolType] = None,
bypass_name_check: bool = False,
) -> PydanticTool:
# TODO: remove this (legacy non-async)
"""
Update a tool with complex validation and schema derivation logic.
@@ -519,55 +520,36 @@ class ToolManager:
# Fetch current tool early to allow conditional logic based on tool type
current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor)
# For MCP tools, do NOT derive schema from Python source. Trust provided JSON schema.
if current_tool.tool_type == ToolType.EXTERNAL_MCP:
# Prefer provided json_schema; fall back to current
if "json_schema" in update_data:
new_schema = update_data["json_schema"].copy()
new_name = new_schema.get("name", current_tool.name)
else:
new_schema = current_tool.json_schema
new_name = current_tool.name
# Ensure we don't trigger derive
update_data.pop("source_code", None)
# If name changes, enforce uniqueness
if new_name != current_tool.name:
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
if name_exists:
raise LettaToolNameConflictError(tool_name=new_name)
# Do NOT derive schema from Python source. Trust provided JSON schema.
# Prefer provided json_schema; fall back to current
if "json_schema" in update_data:
new_schema = update_data["json_schema"].copy()
new_name = new_schema.get("name", current_tool.name)
else:
# For non-MCP tools, preserve existing behavior
# TODO: Consider this behavior...is this what we want?
# TODO: I feel like it's bad if json_schema strays from source code so
# if source code is provided, always derive the name from it
if "source_code" in update_data.keys() and not bypass_name_check:
# Check source type to use appropriate parser
source_type = update_data.get("source_type", current_tool.source_type)
if source_type == "typescript":
from letta.functions.typescript_parser import derive_typescript_json_schema
new_schema = current_tool.json_schema
new_name = current_tool.name
derived_schema = derive_typescript_json_schema(source_code=update_data["source_code"])
else:
# Default to Python for backwards compatibility
derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
new_name = derived_schema["name"]
# original tool may no have a JSON schema at all for legacy reasons
# in this case, fallback to dangerous schema generation
if new_schema is None:
if source_type == "typescript":
from letta.functions.typescript_parser import derive_typescript_json_schema
# if json_schema wasn't provided, use the derived schema
if "json_schema" not in update_data.keys():
new_schema = derived_schema
else:
# if json_schema was provided, update only its name to match the source code
new_schema = update_data["json_schema"].copy()
new_schema["name"] = new_name
# update the json_schema in update_data so it gets applied in the loop
update_data["json_schema"] = new_schema
new_schema = derive_typescript_json_schema(source_code=update_data["source_code"])
else:
new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
# check if the name is changing and if so, verify it doesn't conflict
if new_name != current_tool.name:
# check if a tool with the new name already exists
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
if name_exists:
raise LettaToolNameConflictError(tool_name=new_name)
# If name changes, enforce uniqueness
if new_name != current_tool.name:
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
if name_exists:
raise LettaToolNameConflictError(tool_name=new_name)
# NOTE: EXTREMELEY HACKY, we need to stop making assumptions about the source_code
if "source_code" in update_data and f"def {new_name}" not in update_data.get("source_code", ""):
raise LettaToolNameSchemaMismatchError(
tool_name=new_name, json_schema_name=new_schema.get("name"), source_code=update_data.get("source_code")
)
# Now perform the update within the session
async with db_registry.async_session() as session:

View File

@@ -1,215 +0,0 @@
import asyncio
import json
import os
import threading
import time
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage
from letta_client.core import RequestOptions
from tests.helpers.utils import upload_test_agentfile_from_disk_async
REASONING_THROTTLE_MS = 100
TEST_USER_MESSAGE = "What products or services does 11x AI sell?"
@pytest.fixture(scope="module")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
return url
@pytest.fixture(scope="function")
def client(server_url: str):
"""
Creates and returns an asynchronous Letta REST client for testing.
"""
async_client_instance = AsyncLetta(base_url=server_url)
yield async_client_instance
async def test_pinecone_tool(client: AsyncLetta, server_url: str) -> None:
"""
Test the Pinecone tool integration with the Letta client.
"""
response = await upload_test_agentfile_from_disk_async(client, "knowledge-base.af")
agent_id = response.agent_ids[0]
agent = await client.agents.modify(
agent_id=agent_id,
tool_exec_environment_variables={
"PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"),
"PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"),
"PINECONE_NAMESPACE": os.getenv("PINECONE_NAMESPACE"),
},
)
last_message = await client.agents.messages.list(
agent_id=agent.id,
limit=1,
)
curr_message_type = None
messages = []
reasoning_content = []
last_reasoning_update_ms = 0
tool_call_content = ""
tool_return_content = ""
summary = None
pinecone_results = None
queries = []
try:
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
MessageCreate(
role="user",
content=TEST_USER_MESSAGE,
),
],
stream_tokens=True,
request_options=RequestOptions(
timeout_in_seconds=1000,
),
)
async for chunk in response:
if chunk.message_type != curr_message_type:
messages.append(chunk)
curr_message_type = chunk.message_type
if curr_message_type == "reasoning_message":
reasoning_content = []
if curr_message_type == "tool_call_message":
tool_call_content = ""
if chunk.message_type == "reasoning_message":
now_ms = time.time_ns() // 1_000_000
if now_ms - last_reasoning_update_ms < REASONING_THROTTLE_MS:
await asyncio.sleep(REASONING_THROTTLE_MS / 1000)
last_reasoning_update_ms = now_ms
if len(reasoning_content) == 0:
reasoning_content = [chunk.reasoning]
else:
reasoning_content[-1] += chunk.reasoning
message_dict = messages[-1].model_dump()
message_dict["reasoning"] = "".join(reasoning_content).strip()
messages[-1] = ReasoningMessage(**message_dict)
if chunk.message_type == "tool_return_message":
tool_return_content += chunk.tool_return
if chunk.status == "success":
try:
if chunk.name == "summarize_pinecone_results":
json_response = json.loads(chunk.tool_return)
summary = json_response.get("summary", None)
pinecone_results = json_response.get("pinecone_results", None)
tool_return_content = ""
elif chunk.name == "craft_queries":
queries.append(chunk.tool_return)
tool_return_content = ""
except Exception as e:
print(f"Error parsing JSON response: {str(e)}. {chunk.tool_return}\n")
tool_return_content = ""
if chunk.message_type == "tool_call_message":
if chunk.tool_call.arguments is not None:
tool_call_content += chunk.tool_call.arguments
message_dict = messages[-1].model_dump()
message_dict["tool_call"]["arguments"] = tool_call_content
messages[-1] = ToolCallMessage(**message_dict)
except Exception as e:
print(f"Failed to fetch knowledge base response: {str(e)}\n")
print(tool_call_content)
raise e
assert len(messages) > 0, "No messages received from the agent."
assert len(reasoning_content) > 0, "No reasoning content received from the agent."
assert summary is not None, "No summary received from the agent."
assert pinecone_results is not None, "No Pinecone results received from the agent."
assert len(queries) > 0, "No queries received from the agent."
assert messages[-2].message_type == "stop_reason", "Penultimate message in stream must be stop reason."
assert messages[-1].message_type == "usage_statistics", "Last message in stream must be usage stats."
response_messages_from_stream = [m for m in messages if m.message_type not in ["stop_reason", "usage_statistics"]]
response_message_types_from_stream = [m.message_type for m in response_messages_from_stream]
messages_from_db = await client.agents.messages.list(
agent_id=agent.id,
after=last_message[0].id,
limit=100,
)
response_messages_from_db = [m for m in messages_from_db if m.message_type != "user_message"]
response_message_types_from_db = [m.message_type for m in response_messages_from_db]
assert len(response_messages_from_stream) == len(response_messages_from_db)
assert response_message_types_from_stream == response_message_types_from_db
for idx in range(len(response_messages_from_stream)):
stream_message = response_messages_from_stream[idx]
db_message = response_messages_from_db[idx]
assert stream_message.message_type == db_message.message_type
assert stream_message.id == db_message.id
assert stream_message.otid == db_message.otid
if stream_message.message_type == "reasoning_message":
assert stream_message.reasoning == db_message.reasoning
if stream_message.message_type == "tool_call_message":
assert stream_message.tool_call.tool_call_id == db_message.tool_call.tool_call_id
assert stream_message.tool_call.name == db_message.tool_call.name
if stream_message.tool_call.name == "craft_queries":
assert "queries" in stream_message.tool_call.arguments
assert "queries" in db_message.tool_call.arguments
if stream_message.tool_call.name == "search_and_store_pinecone_records":
assert "query_text" in stream_message.tool_call.arguments
assert "query_text" in db_message.tool_call.arguments
if stream_message.tool_call.name == "summarize_pinecone_results":
assert "summary" in stream_message.tool_call.arguments
assert "summary" in db_message.tool_call.arguments
assert "inner_thoughts" not in stream_message.tool_call.arguments
assert "inner_thoughts" not in db_message.tool_call.arguments
if stream_message.message_type == "tool_return_message":
assert stream_message.tool_return == db_message.tool_return
await client.agents.delete(agent_id=agent.id)

View File

@@ -11,7 +11,7 @@ import pytest
from dotenv import load_dotenv
from letta_client import CreateBlock
from letta_client import Letta as LettaSDKClient
from letta_client import LettaRequest, MessageCreate, TextContent
from letta_client import LettaRequest, MessageCreate, TerminalToolRule, TextContent
from letta_client.client import BaseTool
from letta_client.core import ApiError
from letta_client.types import AgentState, ToolReturnMessage
@@ -942,10 +942,22 @@ def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKCl
print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}")
return True
# test creation
tool = client.tools.add(
tool=ManageInventoryTool(),
)
# test that upserting also works
new_description = "NEW"
class ManageInventoryToolModified(ManageInventoryTool):
description: str = new_description
tool = client.tools.add(
tool=ManageInventoryToolModified(),
)
assert tool.description == new_description
assert tool is not None
assert tool.name == "manage_inventory"
assert "inventory" in tool.tags
@@ -1005,7 +1017,7 @@ def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKCl
client.tools.delete(tool.id)
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
@pytest.mark.parametrize("e2b_sandbox_mode", [False], indirect=True)
def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
class Step(BaseModel):
@@ -1021,7 +1033,18 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
print(f"Created task plan with {len(steps)} steps: {explanation}")
return steps
tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"])
# test creation
client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"])
# test upsert
new_steps_description = "NEW"
class StepsListModified(BaseModel):
steps: List[Step] = Field(..., description=new_steps_description)
explanation: str = Field(..., description="Explanation for the list of steps.")
tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsListModified, description=new_steps_description)
assert tool.description == new_steps_description
assert tool is not None
assert tool.name == "create_task_plan"
@@ -1039,6 +1062,9 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
embedding="openai/text-embedding-3-small",
tool_ids=[tool.id],
include_base_tools=False,
tool_rules=[
TerminalToolRule(tool_name=tool.name),
],
)
response = client.agents.messages.create(
@@ -1062,6 +1088,7 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
assert first_tool_call.tool_call.name == "create_task_plan"
args = json.loads(first_tool_call.tool_call.arguments)
assert "steps" in args
assert "explanation" in args
assert isinstance(args["steps"], list)
@@ -1224,145 +1251,6 @@ def test_agent_tools_list(client: LettaSDKClient):
client.agents.delete(agent_id=agent_state.id)
def test_update_tool_source_code_changes_name(client: LettaSDKClient):
"""Test that updating a tool's source code correctly changes its name"""
import textwrap
# Create initial tool
def initial_tool(x: int) -> int:
"""
Multiply a number by 2
Args:
x: The input number
Returns:
The input multiplied by 2
"""
return x * 2
# Create the tool
tool = client.tools.upsert_from_function(func=initial_tool)
assert tool.name == "initial_tool"
try:
# Define new function source code with different name
new_source_code = textwrap.dedent(
"""
def updated_tool(x: int, y: int) -> int:
'''
Add two numbers together
Args:
x: First number
y: Second number
Returns:
Sum of x and y
'''
return x + y
"""
).strip()
# Update the tool's source code
updated = client.tools.modify(tool_id=tool.id, source_code=new_source_code)
# Verify the name changed
assert updated.name == "updated_tool"
assert updated.source_code == new_source_code
# Verify the schema was updated for the new parameters
assert updated.json_schema is not None
assert updated.json_schema["name"] == "updated_tool"
assert updated.json_schema["description"] == "Add two numbers together"
# Check parameters
params = updated.json_schema.get("parameters", {})
properties = params.get("properties", {})
assert "x" in properties
assert "y" in properties
assert properties["x"]["type"] == "integer"
assert properties["y"]["type"] == "integer"
assert properties["x"]["description"] == "First number"
assert properties["y"]["description"] == "Second number"
assert params["required"] == ["x", "y"]
finally:
# Clean up
client.tools.delete(tool_id=tool.id)
def test_update_tool_source_code_duplicate_name_error(client: LettaSDKClient):
"""Test that updating a tool's source code to have the same name as another existing tool raises an error"""
import textwrap
# Create first tool
def first_tool(x: int) -> int:
"""
Multiply a number by 2
Args:
x: The input number
Returns:
The input multiplied by 2
"""
return x * 2
# Create second tool
def second_tool(x: int) -> int:
"""
Multiply a number by 3
Args:
x: The input number
Returns:
The input multiplied by 3
"""
return x * 3
# Create both tools
tool1 = client.tools.upsert_from_function(func=first_tool)
tool2 = client.tools.upsert_from_function(func=second_tool)
assert tool1.name == "first_tool"
assert tool2.name == "second_tool"
try:
# Try to update second_tool to have the same name as first_tool
new_source_code = textwrap.dedent(
"""
def first_tool(x: int) -> int:
'''
Multiply a number by 4
Args:
x: The input number
Returns:
The input multiplied by 4
'''
return x * 4
"""
).strip()
# This should raise an error since first_tool already exists
with pytest.raises(Exception) as exc_info:
client.tools.modify(tool_id=tool2.id, source_code=new_source_code)
# Verify the error message indicates duplicate name
error_message = str(exc_info.value)
assert "already exists" in error_message.lower() or "duplicate" in error_message.lower() or "conflict" in error_message.lower()
# Verify that tool2 was not modified
tool2_check = client.tools.retrieve(tool_id=tool2.id)
assert tool2_check.name == "second_tool" # Name should remain unchanged
finally:
# Clean up both tools
client.tools.delete(tool_id=tool1.id)
client.tools.delete(tool_id=tool2.id)
def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient):
"""Test adding a tool with multiple functions in the source code"""
import textwrap
@@ -1445,143 +1333,144 @@ def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient)
client.tools.delete(tool_id=tool.id)
def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
"""Test that tool name auto-updates when source code changes with multiple functions"""
import textwrap
# Initial source code with multiple functions
initial_source_code = textwrap.dedent(
"""
def helper_function(x: int) -> int:
'''
Helper function that doubles the input
Args:
x: The input number
Returns:
The input multiplied by 2
'''
return x * 2
def another_helper(text: str) -> str:
'''
Another helper that uppercases text
Args:
text: The input text to uppercase
Returns:
The uppercased text
'''
return text.upper()
def main_function(x: int, y: int) -> int:
'''
Main function that uses the helper
Args:
x: First number
y: Second number
Returns:
Result of (x * 2) + y
'''
doubled_x = helper_function(x)
return doubled_x + y
"""
).strip()
# Create tool with initial source code
tool = client.tools.create(
source_code=initial_source_code,
)
try:
# Verify the tool was created with the last function's name
assert tool is not None
assert tool.name == "main_function"
assert tool.source_code == initial_source_code
# Now modify the source code with a different function order
new_source_code = textwrap.dedent(
"""
def process_data(data: str, count: int) -> str:
'''
Process data by repeating it
Args:
data: The input data
count: Number of times to repeat
Returns:
The processed data
'''
return data * count
def helper_utility(x: float) -> float:
'''
Helper utility function
Args:
x: Input value
Returns:
Squared value
'''
return x * x
"""
).strip()
# Modify the tool with new source code
modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code)
# Verify the name automatically updated to the last function
assert modified_tool.name == "helper_utility"
assert modified_tool.source_code == new_source_code
# Verify the JSON schema updated correctly
assert modified_tool.json_schema is not None
assert modified_tool.json_schema["name"] == "helper_utility"
assert modified_tool.json_schema["description"] == "Helper utility function"
# Check parameters updated correctly
params = modified_tool.json_schema.get("parameters", {})
properties = params.get("properties", {})
assert "x" in properties
assert properties["x"]["type"] == "number" # float maps to number
assert params["required"] == ["x"]
# Test one more modification with only one function
single_function_code = textwrap.dedent(
"""
def calculate_total(items: list, tax_rate: float) -> float:
'''
Calculate total with tax
Args:
items: List of item prices
tax_rate: Tax rate as decimal
Returns:
Total including tax
'''
subtotal = sum(items)
return subtotal * (1 + tax_rate)
"""
).strip()
# Modify again
final_tool = client.tools.modify(tool_id=tool.id, source_code=single_function_code)
# Verify name updated again
assert final_tool.name == "calculate_total"
assert final_tool.source_code == single_function_code
assert final_tool.json_schema["description"] == "Calculate total with tax"
finally:
# Clean up
client.tools.delete(tool_id=tool.id)
# TODO: add back once behavior is defined
# def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
# """Test that tool name auto-updates when source code changes with multiple functions"""
# import textwrap
#
# # Initial source code with multiple functions
# initial_source_code = textwrap.dedent(
# """
# def helper_function(x: int) -> int:
# '''
# Helper function that doubles the input
#
# Args:
# x: The input number
#
# Returns:
# The input multiplied by 2
# '''
# return x * 2
#
# def another_helper(text: str) -> str:
# '''
# Another helper that uppercases text
#
# Args:
# text: The input text to uppercase
#
# Returns:
# The uppercased text
# '''
# return text.upper()
#
# def main_function(x: int, y: int) -> int:
# '''
# Main function that uses the helper
#
# Args:
# x: First number
# y: Second number
#
# Returns:
# Result of (x * 2) + y
# '''
# doubled_x = helper_function(x)
# return doubled_x + y
# """
# ).strip()
#
# # Create tool with initial source code
# tool = client.tools.create(
# source_code=initial_source_code,
# )
#
# try:
# # Verify the tool was created with the last function's name
# assert tool is not None
# assert tool.name == "main_function"
# assert tool.source_code == initial_source_code
#
# # Now modify the source code with a different function order
# new_source_code = textwrap.dedent(
# """
# def process_data(data: str, count: int) -> str:
# '''
# Process data by repeating it
#
# Args:
# data: The input data
# count: Number of times to repeat
#
# Returns:
# The processed data
# '''
# return data * count
#
# def helper_utility(x: float) -> float:
# '''
# Helper utility function
#
# Args:
# x: Input value
#
# Returns:
# Squared value
# '''
# return x * x
# """
# ).strip()
#
# # Modify the tool with new source code
# modified_tool = client.tools.modify(name="helper_utility", tool_id=tool.id, source_code=new_source_code)
#
# # Verify the name automatically updated to the last function
# assert modified_tool.name == "helper_utility"
# assert modified_tool.source_code == new_source_code
#
# # Verify the JSON schema updated correctly
# assert modified_tool.json_schema is not None
# assert modified_tool.json_schema["name"] == "helper_utility"
# assert modified_tool.json_schema["description"] == "Helper utility function"
#
# # Check parameters updated correctly
# params = modified_tool.json_schema.get("parameters", {})
# properties = params.get("properties", {})
# assert "x" in properties
# assert properties["x"]["type"] == "number" # float maps to number
# assert params["required"] == ["x"]
#
# # Test one more modification with only one function
# single_function_code = textwrap.dedent(
# """
# def calculate_total(items: list, tax_rate: float) -> float:
# '''
# Calculate total with tax
#
# Args:
# items: List of item prices
# tax_rate: Tax rate as decimal
#
# Returns:
# Total including tax
# '''
# subtotal = sum(items)
# return subtotal * (1 + tax_rate)
# """
# ).strip()
#
# # Modify again
# final_tool = client.tools.modify(tool_id=tool.id, source_code=single_function_code)
#
# # Verify name updated again
# assert final_tool.name == "calculate_total"
# assert final_tool.source_code == single_function_code
# assert final_tool.json_schema["description"] == "Calculate total with tax"
#
# finally:
# # Clean up
# client.tools.delete(tool_id=tool.id)
def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
@@ -1637,28 +1526,16 @@ def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
},
}
# Modify the tool with both new source code AND JSON schema
modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema)
# verify there is a 400 error when both source code and json schema are provided
with pytest.raises(Exception) as e:
client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema)
assert e.value.status_code == 400
# Verify the name comes from the source code function name, not the JSON schema
assert modified_tool.name == "renamed_function"
assert modified_tool.source_code == new_source_code
# Verify the JSON schema was updated to match the function name from source code
assert modified_tool.json_schema is not None
assert modified_tool.json_schema["name"] == "renamed_function"
# The description should come from the source code docstring, not the JSON schema
assert modified_tool.json_schema["description"] == "Multiply a value by a multiplier"
# Verify parameters are from the source code, not the custom JSON schema
params = modified_tool.json_schema.get("parameters", {})
properties = params.get("properties", {})
assert "value" in properties
assert "multiplier" in properties
assert properties["value"]["type"] == "number"
assert properties["multiplier"]["type"] == "number"
assert params["required"] == ["value"]
# update with consistent name and schema
custom_json_schema["name"] = "renamed_function"
tool = client.tools.modify(tool_id=tool.id, json_schema=custom_json_schema)
assert tool.json_schema == custom_json_schema
assert tool.name == "renamed_function"
finally:
# Clean up
@@ -2006,3 +1883,58 @@ def test_import_agent_with_files_from_disk(client: LettaSDKClient):
# Clean up agents and sources
client.agents.delete(agent_id=imported_agent_id)
client.sources.delete(source_id=imported_source.id)
def test_upsert_tools(client: LettaSDKClient):
"""Test upserting tools with complex schemas."""
from typing import List
class WriteReasonOffer(BaseModel):
biltMerchantId: str = Field(..., description="The merchant ID (e.g. 'MERCHANT_NETWORK-123' or 'LYFT')")
campaignId: str = Field(
...,
description="The campaign ID (e.g. '550e8400-e29b-41d4-a716-446655440000' or '550e8400-e29b-41d4-a716-446655440000_123e4567-e89b-12d3-a456-426614174000')",
)
reason: str = Field(
...,
description="A detailed explanation of why this offer is relevant to the user. Refer to the category-specific reason_instructions_{category} block for all guidelines on creating personalized reasons.",
)
class WriteReasonArgs(BaseModel):
"""Arguments for the write_reason tool."""
offer_list: List[WriteReasonOffer] = Field(
...,
description="List of WriteReasonOffer objects with merchant and campaign information",
)
def write_reason(offer_list: List[WriteReasonOffer]):
"""
This tool is used to write detailed reasons for a list of offers.
It returns the essential information: biltMerchantId, campaignId, and reason.
IMPORTANT: When generating reasons, you MUST ONLY follow the guidelines in the
category-specific instruction block named "reason_instructions_{category}" where
{category} is the category of the offer (e.g., dining, travel, shopping).
These instruction blocks contain all the necessary guidelines for creating
personalized, detailed reasons for each category. Do not rely on any other
instructions outside of these blocks.
Args:
offer_list: List of WriteReasonOffer objects, each containing:
- biltMerchantId: The merchant ID (e.g. 'MERCHANT_NETWORK-123' or 'LYFT')
- campaignId: The campaign ID (e.g. '124', '28')
- reason: A detailed explanation generated according to the category-specific reason_instructions_{category} block
Returns:
None: This function prints the offer list but does not return a value.
"""
print(offer_list)
tool = client.tools.upsert_from_function(func=write_reason, args_schema=WriteReasonArgs)
assert tool is not None
assert tool.name == "write_reason"
# Clean up
client.tools.delete(tool.id)