chore: patch update tools (#4090)
* patch update tools * update tool patch * fallback to generation for legacy tools * avoid re-parsing source if json schema exists * fix more tests * remove asssert * fix * update * update * update * Fix tests --------- Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -60,6 +60,15 @@ class LettaToolNameConflictError(LettaError):
|
||||
)
|
||||
|
||||
|
||||
class LettaToolNameSchemaMismatchError(LettaToolCreateError):
|
||||
"""Error raised when a tool name our source codedoes not match the name in the JSON schema."""
|
||||
|
||||
def __init__(self, tool_name: str, json_schema_name: str, source_code: str):
|
||||
super().__init__(
|
||||
message=f"Tool name '{tool_name}' does not match the name in the JSON schema '{json_schema_name}' or in the source code `{source_code}`",
|
||||
)
|
||||
|
||||
|
||||
class LettaConfigurationError(LettaError):
|
||||
"""Error raised when there are configuration-related issues."""
|
||||
|
||||
|
||||
@@ -80,7 +80,8 @@ class Tool(BaseTool):
|
||||
"""
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
|
||||
if self.tool_type == ToolType.CUSTOM:
|
||||
if self.tool_type == ToolType.CUSTOM and not self.json_schema:
|
||||
# attempt various fallbacks to get the JSON schema
|
||||
if not self.source_code:
|
||||
logger.error("Custom tool with id=%s is missing source_code field", self.id)
|
||||
raise ValueError(f"Custom tool with id={self.id} is missing source_code field.")
|
||||
@@ -157,7 +158,7 @@ class Tool(BaseTool):
|
||||
|
||||
class ToolCreate(LettaBase):
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
tags: List[str] = Field([], description="Metadata tags.")
|
||||
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
source_type: str = Field("python", description="The source type of the function.")
|
||||
json_schema: Optional[Dict] = Field(
|
||||
|
||||
@@ -130,7 +130,7 @@ async def create_tool(
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
tool = Tool(**request.model_dump())
|
||||
tool = Tool(**request.model_dump(exclude_unset=True))
|
||||
return await server.tool_manager.create_tool_async(pydantic_tool=tool, actor=actor)
|
||||
except UniqueConstraintViolationError as e:
|
||||
# Log or print the full exception here for debugging
|
||||
@@ -162,7 +162,9 @@ async def upsert_tool(
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
tool = await server.tool_manager.create_or_update_tool_async(pydantic_tool=Tool(**request.model_dump()), actor=actor)
|
||||
tool = await server.tool_manager.create_or_update_tool_async(
|
||||
pydantic_tool=Tool(**request.model_dump(exclude_unset=True)), actor=actor
|
||||
)
|
||||
return tool
|
||||
except UniqueConstraintViolationError as e:
|
||||
# Log the error and raise a conflict exception
|
||||
@@ -190,18 +192,17 @@ async def modify_tool(
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor)
|
||||
tool = await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor)
|
||||
print("FINAL TOOL", tool)
|
||||
return tool
|
||||
except LettaToolNameConflictError as e:
|
||||
# HTTP 409 == Conflict
|
||||
print(f"Tool name conflict during update: {e}")
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except LettaToolCreateError as e:
|
||||
# HTTP 400 == Bad Request
|
||||
print(f"Error occurred during tool update: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
# Catch other unexpected errors and raise an internal server error
|
||||
print(f"Unexpected error occurred: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from letta.constants import (
|
||||
LOCAL_ONLY_MULTI_AGENT_TOOLS,
|
||||
MCP_TOOL_TAG_NAME_PREFIX,
|
||||
)
|
||||
from letta.errors import LettaToolNameConflictError
|
||||
from letta.errors import LettaToolNameConflictError, LettaToolNameSchemaMismatchError
|
||||
from letta.functions.functions import derive_openai_json_schema, load_function_set
|
||||
from letta.log import get_logger
|
||||
|
||||
@@ -403,6 +403,7 @@ class ToolManager:
|
||||
updated_tool_type: Optional[ToolType] = None,
|
||||
bypass_name_check: bool = False,
|
||||
) -> PydanticTool:
|
||||
# TODO: remove this (legacy non-async)
|
||||
"""
|
||||
Update a tool with complex validation and schema derivation logic.
|
||||
|
||||
@@ -519,55 +520,36 @@ class ToolManager:
|
||||
# Fetch current tool early to allow conditional logic based on tool type
|
||||
current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor)
|
||||
|
||||
# For MCP tools, do NOT derive schema from Python source. Trust provided JSON schema.
|
||||
if current_tool.tool_type == ToolType.EXTERNAL_MCP:
|
||||
# Prefer provided json_schema; fall back to current
|
||||
if "json_schema" in update_data:
|
||||
new_schema = update_data["json_schema"].copy()
|
||||
new_name = new_schema.get("name", current_tool.name)
|
||||
else:
|
||||
new_schema = current_tool.json_schema
|
||||
new_name = current_tool.name
|
||||
# Ensure we don't trigger derive
|
||||
update_data.pop("source_code", None)
|
||||
# If name changes, enforce uniqueness
|
||||
if new_name != current_tool.name:
|
||||
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
|
||||
if name_exists:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
# Do NOT derive schema from Python source. Trust provided JSON schema.
|
||||
# Prefer provided json_schema; fall back to current
|
||||
if "json_schema" in update_data:
|
||||
new_schema = update_data["json_schema"].copy()
|
||||
new_name = new_schema.get("name", current_tool.name)
|
||||
else:
|
||||
# For non-MCP tools, preserve existing behavior
|
||||
# TODO: Consider this behavior...is this what we want?
|
||||
# TODO: I feel like it's bad if json_schema strays from source code so
|
||||
# if source code is provided, always derive the name from it
|
||||
if "source_code" in update_data.keys() and not bypass_name_check:
|
||||
# Check source type to use appropriate parser
|
||||
source_type = update_data.get("source_type", current_tool.source_type)
|
||||
if source_type == "typescript":
|
||||
from letta.functions.typescript_parser import derive_typescript_json_schema
|
||||
new_schema = current_tool.json_schema
|
||||
new_name = current_tool.name
|
||||
|
||||
derived_schema = derive_typescript_json_schema(source_code=update_data["source_code"])
|
||||
else:
|
||||
# Default to Python for backwards compatibility
|
||||
derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
||||
new_name = derived_schema["name"]
|
||||
# original tool may no have a JSON schema at all for legacy reasons
|
||||
# in this case, fallback to dangerous schema generation
|
||||
if new_schema is None:
|
||||
if source_type == "typescript":
|
||||
from letta.functions.typescript_parser import derive_typescript_json_schema
|
||||
|
||||
# if json_schema wasn't provided, use the derived schema
|
||||
if "json_schema" not in update_data.keys():
|
||||
new_schema = derived_schema
|
||||
else:
|
||||
# if json_schema was provided, update only its name to match the source code
|
||||
new_schema = update_data["json_schema"].copy()
|
||||
new_schema["name"] = new_name
|
||||
# update the json_schema in update_data so it gets applied in the loop
|
||||
update_data["json_schema"] = new_schema
|
||||
new_schema = derive_typescript_json_schema(source_code=update_data["source_code"])
|
||||
else:
|
||||
new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
||||
|
||||
# check if the name is changing and if so, verify it doesn't conflict
|
||||
if new_name != current_tool.name:
|
||||
# check if a tool with the new name already exists
|
||||
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
|
||||
if name_exists:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
# If name changes, enforce uniqueness
|
||||
if new_name != current_tool.name:
|
||||
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
|
||||
if name_exists:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
|
||||
# NOTE: EXTREMELEY HACKY, we need to stop making assumptions about the source_code
|
||||
if "source_code" in update_data and f"def {new_name}" not in update_data.get("source_code", ""):
|
||||
raise LettaToolNameSchemaMismatchError(
|
||||
tool_name=new_name, json_schema_name=new_schema.get("name"), source_code=update_data.get("source_code")
|
||||
)
|
||||
|
||||
# Now perform the update within the session
|
||||
async with db_registry.async_session() as session:
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage
|
||||
from letta_client.core import RequestOptions
|
||||
|
||||
from tests.helpers.utils import upload_test_agentfile_from_disk_async
|
||||
|
||||
REASONING_THROTTLE_MS = 100
|
||||
TEST_USER_MESSAGE = "What products or services does 11x AI sell?"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it's accepting connections.
|
||||
"""
|
||||
|
||||
def _run_server() -> None:
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(server_url: str):
|
||||
"""
|
||||
Creates and returns an asynchronous Letta REST client for testing.
|
||||
"""
|
||||
async_client_instance = AsyncLetta(base_url=server_url)
|
||||
yield async_client_instance
|
||||
|
||||
|
||||
async def test_pinecone_tool(client: AsyncLetta, server_url: str) -> None:
|
||||
"""
|
||||
Test the Pinecone tool integration with the Letta client.
|
||||
"""
|
||||
response = await upload_test_agentfile_from_disk_async(client, "knowledge-base.af")
|
||||
|
||||
agent_id = response.agent_ids[0]
|
||||
|
||||
agent = await client.agents.modify(
|
||||
agent_id=agent_id,
|
||||
tool_exec_environment_variables={
|
||||
"PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"),
|
||||
"PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"),
|
||||
"PINECONE_NAMESPACE": os.getenv("PINECONE_NAMESPACE"),
|
||||
},
|
||||
)
|
||||
last_message = await client.agents.messages.list(
|
||||
agent_id=agent.id,
|
||||
limit=1,
|
||||
)
|
||||
|
||||
curr_message_type = None
|
||||
messages = []
|
||||
reasoning_content = []
|
||||
last_reasoning_update_ms = 0
|
||||
tool_call_content = ""
|
||||
tool_return_content = ""
|
||||
summary = None
|
||||
pinecone_results = None
|
||||
queries = []
|
||||
|
||||
try:
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=TEST_USER_MESSAGE,
|
||||
),
|
||||
],
|
||||
stream_tokens=True,
|
||||
request_options=RequestOptions(
|
||||
timeout_in_seconds=1000,
|
||||
),
|
||||
)
|
||||
|
||||
async for chunk in response:
|
||||
if chunk.message_type != curr_message_type:
|
||||
messages.append(chunk)
|
||||
curr_message_type = chunk.message_type
|
||||
if curr_message_type == "reasoning_message":
|
||||
reasoning_content = []
|
||||
if curr_message_type == "tool_call_message":
|
||||
tool_call_content = ""
|
||||
|
||||
if chunk.message_type == "reasoning_message":
|
||||
now_ms = time.time_ns() // 1_000_000
|
||||
if now_ms - last_reasoning_update_ms < REASONING_THROTTLE_MS:
|
||||
await asyncio.sleep(REASONING_THROTTLE_MS / 1000)
|
||||
|
||||
last_reasoning_update_ms = now_ms
|
||||
if len(reasoning_content) == 0:
|
||||
reasoning_content = [chunk.reasoning]
|
||||
else:
|
||||
reasoning_content[-1] += chunk.reasoning
|
||||
|
||||
message_dict = messages[-1].model_dump()
|
||||
message_dict["reasoning"] = "".join(reasoning_content).strip()
|
||||
messages[-1] = ReasoningMessage(**message_dict)
|
||||
|
||||
if chunk.message_type == "tool_return_message":
|
||||
tool_return_content += chunk.tool_return
|
||||
|
||||
if chunk.status == "success":
|
||||
try:
|
||||
if chunk.name == "summarize_pinecone_results":
|
||||
json_response = json.loads(chunk.tool_return)
|
||||
summary = json_response.get("summary", None)
|
||||
pinecone_results = json_response.get("pinecone_results", None)
|
||||
tool_return_content = ""
|
||||
elif chunk.name == "craft_queries":
|
||||
queries.append(chunk.tool_return)
|
||||
tool_return_content = ""
|
||||
except Exception as e:
|
||||
print(f"Error parsing JSON response: {str(e)}. {chunk.tool_return}\n")
|
||||
tool_return_content = ""
|
||||
|
||||
if chunk.message_type == "tool_call_message":
|
||||
if chunk.tool_call.arguments is not None:
|
||||
tool_call_content += chunk.tool_call.arguments
|
||||
message_dict = messages[-1].model_dump()
|
||||
message_dict["tool_call"]["arguments"] = tool_call_content
|
||||
messages[-1] = ToolCallMessage(**message_dict)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch knowledge base response: {str(e)}\n")
|
||||
print(tool_call_content)
|
||||
raise e
|
||||
|
||||
assert len(messages) > 0, "No messages received from the agent."
|
||||
assert len(reasoning_content) > 0, "No reasoning content received from the agent."
|
||||
assert summary is not None, "No summary received from the agent."
|
||||
assert pinecone_results is not None, "No Pinecone results received from the agent."
|
||||
assert len(queries) > 0, "No queries received from the agent."
|
||||
|
||||
assert messages[-2].message_type == "stop_reason", "Penultimate message in stream must be stop reason."
|
||||
assert messages[-1].message_type == "usage_statistics", "Last message in stream must be usage stats."
|
||||
response_messages_from_stream = [m for m in messages if m.message_type not in ["stop_reason", "usage_statistics"]]
|
||||
response_message_types_from_stream = [m.message_type for m in response_messages_from_stream]
|
||||
|
||||
messages_from_db = await client.agents.messages.list(
|
||||
agent_id=agent.id,
|
||||
after=last_message[0].id,
|
||||
limit=100,
|
||||
)
|
||||
response_messages_from_db = [m for m in messages_from_db if m.message_type != "user_message"]
|
||||
response_message_types_from_db = [m.message_type for m in response_messages_from_db]
|
||||
|
||||
assert len(response_messages_from_stream) == len(response_messages_from_db)
|
||||
assert response_message_types_from_stream == response_message_types_from_db
|
||||
for idx in range(len(response_messages_from_stream)):
|
||||
stream_message = response_messages_from_stream[idx]
|
||||
db_message = response_messages_from_db[idx]
|
||||
assert stream_message.message_type == db_message.message_type
|
||||
assert stream_message.id == db_message.id
|
||||
assert stream_message.otid == db_message.otid
|
||||
|
||||
if stream_message.message_type == "reasoning_message":
|
||||
assert stream_message.reasoning == db_message.reasoning
|
||||
|
||||
if stream_message.message_type == "tool_call_message":
|
||||
assert stream_message.tool_call.tool_call_id == db_message.tool_call.tool_call_id
|
||||
assert stream_message.tool_call.name == db_message.tool_call.name
|
||||
|
||||
if stream_message.tool_call.name == "craft_queries":
|
||||
assert "queries" in stream_message.tool_call.arguments
|
||||
assert "queries" in db_message.tool_call.arguments
|
||||
if stream_message.tool_call.name == "search_and_store_pinecone_records":
|
||||
assert "query_text" in stream_message.tool_call.arguments
|
||||
assert "query_text" in db_message.tool_call.arguments
|
||||
if stream_message.tool_call.name == "summarize_pinecone_results":
|
||||
assert "summary" in stream_message.tool_call.arguments
|
||||
assert "summary" in db_message.tool_call.arguments
|
||||
|
||||
assert "inner_thoughts" not in stream_message.tool_call.arguments
|
||||
assert "inner_thoughts" not in db_message.tool_call.arguments
|
||||
|
||||
if stream_message.message_type == "tool_return_message":
|
||||
assert stream_message.tool_return == db_message.tool_return
|
||||
|
||||
await client.agents.delete(agent_id=agent.id)
|
||||
@@ -11,7 +11,7 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import LettaRequest, MessageCreate, TextContent
|
||||
from letta_client import LettaRequest, MessageCreate, TerminalToolRule, TextContent
|
||||
from letta_client.client import BaseTool
|
||||
from letta_client.core import ApiError
|
||||
from letta_client.types import AgentState, ToolReturnMessage
|
||||
@@ -942,10 +942,22 @@ def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKCl
|
||||
print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}")
|
||||
return True
|
||||
|
||||
# test creation
|
||||
tool = client.tools.add(
|
||||
tool=ManageInventoryTool(),
|
||||
)
|
||||
|
||||
# test that upserting also works
|
||||
new_description = "NEW"
|
||||
|
||||
class ManageInventoryToolModified(ManageInventoryTool):
|
||||
description: str = new_description
|
||||
|
||||
tool = client.tools.add(
|
||||
tool=ManageInventoryToolModified(),
|
||||
)
|
||||
assert tool.description == new_description
|
||||
|
||||
assert tool is not None
|
||||
assert tool.name == "manage_inventory"
|
||||
assert "inventory" in tool.tags
|
||||
@@ -1005,7 +1017,7 @@ def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKCl
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [False], indirect=True)
|
||||
def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
|
||||
|
||||
class Step(BaseModel):
|
||||
@@ -1021,7 +1033,18 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
|
||||
print(f"Created task plan with {len(steps)} steps: {explanation}")
|
||||
return steps
|
||||
|
||||
tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"])
|
||||
# test creation
|
||||
client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"])
|
||||
|
||||
# test upsert
|
||||
new_steps_description = "NEW"
|
||||
|
||||
class StepsListModified(BaseModel):
|
||||
steps: List[Step] = Field(..., description=new_steps_description)
|
||||
explanation: str = Field(..., description="Explanation for the list of steps.")
|
||||
|
||||
tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsListModified, description=new_steps_description)
|
||||
assert tool.description == new_steps_description
|
||||
|
||||
assert tool is not None
|
||||
assert tool.name == "create_task_plan"
|
||||
@@ -1039,6 +1062,9 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
|
||||
embedding="openai/text-embedding-3-small",
|
||||
tool_ids=[tool.id],
|
||||
include_base_tools=False,
|
||||
tool_rules=[
|
||||
TerminalToolRule(tool_name=tool.name),
|
||||
],
|
||||
)
|
||||
|
||||
response = client.agents.messages.create(
|
||||
@@ -1062,6 +1088,7 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
|
||||
assert first_tool_call.tool_call.name == "create_task_plan"
|
||||
|
||||
args = json.loads(first_tool_call.tool_call.arguments)
|
||||
|
||||
assert "steps" in args
|
||||
assert "explanation" in args
|
||||
assert isinstance(args["steps"], list)
|
||||
@@ -1224,145 +1251,6 @@ def test_agent_tools_list(client: LettaSDKClient):
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
def test_update_tool_source_code_changes_name(client: LettaSDKClient):
|
||||
"""Test that updating a tool's source code correctly changes its name"""
|
||||
import textwrap
|
||||
|
||||
# Create initial tool
|
||||
def initial_tool(x: int) -> int:
|
||||
"""
|
||||
Multiply a number by 2
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
Returns:
|
||||
The input multiplied by 2
|
||||
"""
|
||||
return x * 2
|
||||
|
||||
# Create the tool
|
||||
tool = client.tools.upsert_from_function(func=initial_tool)
|
||||
assert tool.name == "initial_tool"
|
||||
|
||||
try:
|
||||
# Define new function source code with different name
|
||||
new_source_code = textwrap.dedent(
|
||||
"""
|
||||
def updated_tool(x: int, y: int) -> int:
|
||||
'''
|
||||
Add two numbers together
|
||||
|
||||
Args:
|
||||
x: First number
|
||||
y: Second number
|
||||
Returns:
|
||||
Sum of x and y
|
||||
'''
|
||||
return x + y
|
||||
"""
|
||||
).strip()
|
||||
|
||||
# Update the tool's source code
|
||||
updated = client.tools.modify(tool_id=tool.id, source_code=new_source_code)
|
||||
|
||||
# Verify the name changed
|
||||
assert updated.name == "updated_tool"
|
||||
assert updated.source_code == new_source_code
|
||||
|
||||
# Verify the schema was updated for the new parameters
|
||||
assert updated.json_schema is not None
|
||||
assert updated.json_schema["name"] == "updated_tool"
|
||||
assert updated.json_schema["description"] == "Add two numbers together"
|
||||
|
||||
# Check parameters
|
||||
params = updated.json_schema.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
assert "x" in properties
|
||||
assert "y" in properties
|
||||
assert properties["x"]["type"] == "integer"
|
||||
assert properties["y"]["type"] == "integer"
|
||||
assert properties["x"]["description"] == "First number"
|
||||
assert properties["y"]["description"] == "Second number"
|
||||
assert params["required"] == ["x", "y"]
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
client.tools.delete(tool_id=tool.id)
|
||||
|
||||
|
||||
def test_update_tool_source_code_duplicate_name_error(client: LettaSDKClient):
|
||||
"""Test that updating a tool's source code to have the same name as another existing tool raises an error"""
|
||||
import textwrap
|
||||
|
||||
# Create first tool
|
||||
def first_tool(x: int) -> int:
|
||||
"""
|
||||
Multiply a number by 2
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
|
||||
Returns:
|
||||
The input multiplied by 2
|
||||
"""
|
||||
return x * 2
|
||||
|
||||
# Create second tool
|
||||
def second_tool(x: int) -> int:
|
||||
"""
|
||||
Multiply a number by 3
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
|
||||
Returns:
|
||||
The input multiplied by 3
|
||||
"""
|
||||
return x * 3
|
||||
|
||||
# Create both tools
|
||||
tool1 = client.tools.upsert_from_function(func=first_tool)
|
||||
tool2 = client.tools.upsert_from_function(func=second_tool)
|
||||
|
||||
assert tool1.name == "first_tool"
|
||||
assert tool2.name == "second_tool"
|
||||
|
||||
try:
|
||||
# Try to update second_tool to have the same name as first_tool
|
||||
new_source_code = textwrap.dedent(
|
||||
"""
|
||||
def first_tool(x: int) -> int:
|
||||
'''
|
||||
Multiply a number by 4
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
|
||||
Returns:
|
||||
The input multiplied by 4
|
||||
'''
|
||||
return x * 4
|
||||
"""
|
||||
).strip()
|
||||
|
||||
# This should raise an error since first_tool already exists
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
client.tools.modify(tool_id=tool2.id, source_code=new_source_code)
|
||||
|
||||
# Verify the error message indicates duplicate name
|
||||
error_message = str(exc_info.value)
|
||||
assert "already exists" in error_message.lower() or "duplicate" in error_message.lower() or "conflict" in error_message.lower()
|
||||
|
||||
# Verify that tool2 was not modified
|
||||
tool2_check = client.tools.retrieve(tool_id=tool2.id)
|
||||
assert tool2_check.name == "second_tool" # Name should remain unchanged
|
||||
|
||||
finally:
|
||||
# Clean up both tools
|
||||
client.tools.delete(tool_id=tool1.id)
|
||||
client.tools.delete(tool_id=tool2.id)
|
||||
|
||||
|
||||
def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient):
|
||||
"""Test adding a tool with multiple functions in the source code"""
|
||||
import textwrap
|
||||
@@ -1445,143 +1333,144 @@ def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient)
|
||||
client.tools.delete(tool_id=tool.id)
|
||||
|
||||
|
||||
def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
|
||||
"""Test that tool name auto-updates when source code changes with multiple functions"""
|
||||
import textwrap
|
||||
|
||||
# Initial source code with multiple functions
|
||||
initial_source_code = textwrap.dedent(
|
||||
"""
|
||||
def helper_function(x: int) -> int:
|
||||
'''
|
||||
Helper function that doubles the input
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
|
||||
Returns:
|
||||
The input multiplied by 2
|
||||
'''
|
||||
return x * 2
|
||||
|
||||
def another_helper(text: str) -> str:
|
||||
'''
|
||||
Another helper that uppercases text
|
||||
|
||||
Args:
|
||||
text: The input text to uppercase
|
||||
|
||||
Returns:
|
||||
The uppercased text
|
||||
'''
|
||||
return text.upper()
|
||||
|
||||
def main_function(x: int, y: int) -> int:
|
||||
'''
|
||||
Main function that uses the helper
|
||||
|
||||
Args:
|
||||
x: First number
|
||||
y: Second number
|
||||
|
||||
Returns:
|
||||
Result of (x * 2) + y
|
||||
'''
|
||||
doubled_x = helper_function(x)
|
||||
return doubled_x + y
|
||||
"""
|
||||
).strip()
|
||||
|
||||
# Create tool with initial source code
|
||||
tool = client.tools.create(
|
||||
source_code=initial_source_code,
|
||||
)
|
||||
|
||||
try:
|
||||
# Verify the tool was created with the last function's name
|
||||
assert tool is not None
|
||||
assert tool.name == "main_function"
|
||||
assert tool.source_code == initial_source_code
|
||||
|
||||
# Now modify the source code with a different function order
|
||||
new_source_code = textwrap.dedent(
|
||||
"""
|
||||
def process_data(data: str, count: int) -> str:
|
||||
'''
|
||||
Process data by repeating it
|
||||
|
||||
Args:
|
||||
data: The input data
|
||||
count: Number of times to repeat
|
||||
|
||||
Returns:
|
||||
The processed data
|
||||
'''
|
||||
return data * count
|
||||
|
||||
def helper_utility(x: float) -> float:
|
||||
'''
|
||||
Helper utility function
|
||||
|
||||
Args:
|
||||
x: Input value
|
||||
|
||||
Returns:
|
||||
Squared value
|
||||
'''
|
||||
return x * x
|
||||
"""
|
||||
).strip()
|
||||
|
||||
# Modify the tool with new source code
|
||||
modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code)
|
||||
|
||||
# Verify the name automatically updated to the last function
|
||||
assert modified_tool.name == "helper_utility"
|
||||
assert modified_tool.source_code == new_source_code
|
||||
|
||||
# Verify the JSON schema updated correctly
|
||||
assert modified_tool.json_schema is not None
|
||||
assert modified_tool.json_schema["name"] == "helper_utility"
|
||||
assert modified_tool.json_schema["description"] == "Helper utility function"
|
||||
|
||||
# Check parameters updated correctly
|
||||
params = modified_tool.json_schema.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
assert "x" in properties
|
||||
assert properties["x"]["type"] == "number" # float maps to number
|
||||
assert params["required"] == ["x"]
|
||||
|
||||
# Test one more modification with only one function
|
||||
single_function_code = textwrap.dedent(
|
||||
"""
|
||||
def calculate_total(items: list, tax_rate: float) -> float:
|
||||
'''
|
||||
Calculate total with tax
|
||||
|
||||
Args:
|
||||
items: List of item prices
|
||||
tax_rate: Tax rate as decimal
|
||||
|
||||
Returns:
|
||||
Total including tax
|
||||
'''
|
||||
subtotal = sum(items)
|
||||
return subtotal * (1 + tax_rate)
|
||||
"""
|
||||
).strip()
|
||||
|
||||
# Modify again
|
||||
final_tool = client.tools.modify(tool_id=tool.id, source_code=single_function_code)
|
||||
|
||||
# Verify name updated again
|
||||
assert final_tool.name == "calculate_total"
|
||||
assert final_tool.source_code == single_function_code
|
||||
assert final_tool.json_schema["description"] == "Calculate total with tax"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
client.tools.delete(tool_id=tool.id)
|
||||
# TODO: add back once behavior is defined
|
||||
# def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
|
||||
# """Test that tool name auto-updates when source code changes with multiple functions"""
|
||||
# import textwrap
|
||||
#
|
||||
# # Initial source code with multiple functions
|
||||
# initial_source_code = textwrap.dedent(
|
||||
# """
|
||||
# def helper_function(x: int) -> int:
|
||||
# '''
|
||||
# Helper function that doubles the input
|
||||
#
|
||||
# Args:
|
||||
# x: The input number
|
||||
#
|
||||
# Returns:
|
||||
# The input multiplied by 2
|
||||
# '''
|
||||
# return x * 2
|
||||
#
|
||||
# def another_helper(text: str) -> str:
|
||||
# '''
|
||||
# Another helper that uppercases text
|
||||
#
|
||||
# Args:
|
||||
# text: The input text to uppercase
|
||||
#
|
||||
# Returns:
|
||||
# The uppercased text
|
||||
# '''
|
||||
# return text.upper()
|
||||
#
|
||||
# def main_function(x: int, y: int) -> int:
|
||||
# '''
|
||||
# Main function that uses the helper
|
||||
#
|
||||
# Args:
|
||||
# x: First number
|
||||
# y: Second number
|
||||
#
|
||||
# Returns:
|
||||
# Result of (x * 2) + y
|
||||
# '''
|
||||
# doubled_x = helper_function(x)
|
||||
# return doubled_x + y
|
||||
# """
|
||||
# ).strip()
|
||||
#
|
||||
# # Create tool with initial source code
|
||||
# tool = client.tools.create(
|
||||
# source_code=initial_source_code,
|
||||
# )
|
||||
#
|
||||
# try:
|
||||
# # Verify the tool was created with the last function's name
|
||||
# assert tool is not None
|
||||
# assert tool.name == "main_function"
|
||||
# assert tool.source_code == initial_source_code
|
||||
#
|
||||
# # Now modify the source code with a different function order
|
||||
# new_source_code = textwrap.dedent(
|
||||
# """
|
||||
# def process_data(data: str, count: int) -> str:
|
||||
# '''
|
||||
# Process data by repeating it
|
||||
#
|
||||
# Args:
|
||||
# data: The input data
|
||||
# count: Number of times to repeat
|
||||
#
|
||||
# Returns:
|
||||
# The processed data
|
||||
# '''
|
||||
# return data * count
|
||||
#
|
||||
# def helper_utility(x: float) -> float:
|
||||
# '''
|
||||
# Helper utility function
|
||||
#
|
||||
# Args:
|
||||
# x: Input value
|
||||
#
|
||||
# Returns:
|
||||
# Squared value
|
||||
# '''
|
||||
# return x * x
|
||||
# """
|
||||
# ).strip()
|
||||
#
|
||||
# # Modify the tool with new source code
|
||||
# modified_tool = client.tools.modify(name="helper_utility", tool_id=tool.id, source_code=new_source_code)
|
||||
#
|
||||
# # Verify the name automatically updated to the last function
|
||||
# assert modified_tool.name == "helper_utility"
|
||||
# assert modified_tool.source_code == new_source_code
|
||||
#
|
||||
# # Verify the JSON schema updated correctly
|
||||
# assert modified_tool.json_schema is not None
|
||||
# assert modified_tool.json_schema["name"] == "helper_utility"
|
||||
# assert modified_tool.json_schema["description"] == "Helper utility function"
|
||||
#
|
||||
# # Check parameters updated correctly
|
||||
# params = modified_tool.json_schema.get("parameters", {})
|
||||
# properties = params.get("properties", {})
|
||||
# assert "x" in properties
|
||||
# assert properties["x"]["type"] == "number" # float maps to number
|
||||
# assert params["required"] == ["x"]
|
||||
#
|
||||
# # Test one more modification with only one function
|
||||
# single_function_code = textwrap.dedent(
|
||||
# """
|
||||
# def calculate_total(items: list, tax_rate: float) -> float:
|
||||
# '''
|
||||
# Calculate total with tax
|
||||
#
|
||||
# Args:
|
||||
# items: List of item prices
|
||||
# tax_rate: Tax rate as decimal
|
||||
#
|
||||
# Returns:
|
||||
# Total including tax
|
||||
# '''
|
||||
# subtotal = sum(items)
|
||||
# return subtotal * (1 + tax_rate)
|
||||
# """
|
||||
# ).strip()
|
||||
#
|
||||
# # Modify again
|
||||
# final_tool = client.tools.modify(tool_id=tool.id, source_code=single_function_code)
|
||||
#
|
||||
# # Verify name updated again
|
||||
# assert final_tool.name == "calculate_total"
|
||||
# assert final_tool.source_code == single_function_code
|
||||
# assert final_tool.json_schema["description"] == "Calculate total with tax"
|
||||
#
|
||||
# finally:
|
||||
# # Clean up
|
||||
# client.tools.delete(tool_id=tool.id)
|
||||
|
||||
|
||||
def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
|
||||
@@ -1637,28 +1526,16 @@ def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
|
||||
},
|
||||
}
|
||||
|
||||
# Modify the tool with both new source code AND JSON schema
|
||||
modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema)
|
||||
# verify there is a 400 error when both source code and json schema are provided
|
||||
with pytest.raises(Exception) as e:
|
||||
client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema)
|
||||
assert e.value.status_code == 400
|
||||
|
||||
# Verify the name comes from the source code function name, not the JSON schema
|
||||
assert modified_tool.name == "renamed_function"
|
||||
assert modified_tool.source_code == new_source_code
|
||||
|
||||
# Verify the JSON schema was updated to match the function name from source code
|
||||
assert modified_tool.json_schema is not None
|
||||
assert modified_tool.json_schema["name"] == "renamed_function"
|
||||
|
||||
# The description should come from the source code docstring, not the JSON schema
|
||||
assert modified_tool.json_schema["description"] == "Multiply a value by a multiplier"
|
||||
|
||||
# Verify parameters are from the source code, not the custom JSON schema
|
||||
params = modified_tool.json_schema.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
assert "value" in properties
|
||||
assert "multiplier" in properties
|
||||
assert properties["value"]["type"] == "number"
|
||||
assert properties["multiplier"]["type"] == "number"
|
||||
assert params["required"] == ["value"]
|
||||
# update with consistent name and schema
|
||||
custom_json_schema["name"] = "renamed_function"
|
||||
tool = client.tools.modify(tool_id=tool.id, json_schema=custom_json_schema)
|
||||
assert tool.json_schema == custom_json_schema
|
||||
assert tool.name == "renamed_function"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
@@ -2006,3 +1883,58 @@ def test_import_agent_with_files_from_disk(client: LettaSDKClient):
|
||||
# Clean up agents and sources
|
||||
client.agents.delete(agent_id=imported_agent_id)
|
||||
client.sources.delete(source_id=imported_source.id)
|
||||
|
||||
|
||||
def test_upsert_tools(client: LettaSDKClient):
|
||||
"""Test upserting tools with complex schemas."""
|
||||
from typing import List
|
||||
|
||||
class WriteReasonOffer(BaseModel):
|
||||
biltMerchantId: str = Field(..., description="The merchant ID (e.g. 'MERCHANT_NETWORK-123' or 'LYFT')")
|
||||
campaignId: str = Field(
|
||||
...,
|
||||
description="The campaign ID (e.g. '550e8400-e29b-41d4-a716-446655440000' or '550e8400-e29b-41d4-a716-446655440000_123e4567-e89b-12d3-a456-426614174000')",
|
||||
)
|
||||
reason: str = Field(
|
||||
...,
|
||||
description="A detailed explanation of why this offer is relevant to the user. Refer to the category-specific reason_instructions_{category} block for all guidelines on creating personalized reasons.",
|
||||
)
|
||||
|
||||
class WriteReasonArgs(BaseModel):
|
||||
"""Arguments for the write_reason tool."""
|
||||
|
||||
offer_list: List[WriteReasonOffer] = Field(
|
||||
...,
|
||||
description="List of WriteReasonOffer objects with merchant and campaign information",
|
||||
)
|
||||
|
||||
def write_reason(offer_list: List[WriteReasonOffer]):
|
||||
"""
|
||||
This tool is used to write detailed reasons for a list of offers.
|
||||
It returns the essential information: biltMerchantId, campaignId, and reason.
|
||||
|
||||
IMPORTANT: When generating reasons, you MUST ONLY follow the guidelines in the
|
||||
category-specific instruction block named "reason_instructions_{category}" where
|
||||
{category} is the category of the offer (e.g., dining, travel, shopping).
|
||||
|
||||
These instruction blocks contain all the necessary guidelines for creating
|
||||
personalized, detailed reasons for each category. Do not rely on any other
|
||||
instructions outside of these blocks.
|
||||
|
||||
Args:
|
||||
offer_list: List of WriteReasonOffer objects, each containing:
|
||||
- biltMerchantId: The merchant ID (e.g. 'MERCHANT_NETWORK-123' or 'LYFT')
|
||||
- campaignId: The campaign ID (e.g. '124', '28')
|
||||
- reason: A detailed explanation generated according to the category-specific reason_instructions_{category} block
|
||||
|
||||
Returns:
|
||||
None: This function prints the offer list but does not return a value.
|
||||
"""
|
||||
print(offer_list)
|
||||
|
||||
tool = client.tools.upsert_from_function(func=write_reason, args_schema=WriteReasonArgs)
|
||||
assert tool is not None
|
||||
assert tool.name == "write_reason"
|
||||
|
||||
# Clean up
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
Reference in New Issue
Block a user