diff --git a/letta/errors.py b/letta/errors.py index 52db7c00..f3188a96 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -60,6 +60,15 @@ class LettaToolNameConflictError(LettaError): ) +class LettaToolNameSchemaMismatchError(LettaToolCreateError): + """Error raised when a tool name or source code does not match the name in the JSON schema.""" + + def __init__(self, tool_name: str, json_schema_name: str, source_code: str): + super().__init__( + message=f"Tool name '{tool_name}' does not match the name in the JSON schema '{json_schema_name}' or in the source code `{source_code}`", + ) + + class LettaConfigurationError(LettaError): """Error raised when there are configuration-related issues.""" diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 19fc6161..fb71392a 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -80,7 +80,8 @@ class Tool(BaseTool): """ from letta.functions.helpers import generate_model_from_args_json_schema - if self.tool_type == ToolType.CUSTOM: + if self.tool_type == ToolType.CUSTOM and not self.json_schema: + # attempt various fallbacks to get the JSON schema if not self.source_code: logger.error("Custom tool with id=%s is missing source_code field", self.id) raise ValueError(f"Custom tool with id={self.id} is missing source_code field.") @@ -157,7 +158,7 @@ class ToolCreate(LettaBase): description: Optional[str] = Field(None, description="The description of the tool.") - tags: List[str] = Field([], description="Metadata tags.") + tags: Optional[List[str]] = Field(None, description="Metadata tags.") source_code: str = Field(..., description="The source code of the function.") source_type: str = Field("python", description="The source type of the function.") json_schema: Optional[Dict] = Field( diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index a1a0230f..9251c04f 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ 
b/letta/server/rest_api/routers/v1/tools.py @@ -130,7 +130,7 @@ async def create_tool( """ try: actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - tool = Tool(**request.model_dump()) + tool = Tool(**request.model_dump(exclude_unset=True)) return await server.tool_manager.create_tool_async(pydantic_tool=tool, actor=actor) except UniqueConstraintViolationError as e: # Log or print the full exception here for debugging @@ -162,7 +162,9 @@ async def upsert_tool( """ try: actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - tool = await server.tool_manager.create_or_update_tool_async(pydantic_tool=Tool(**request.model_dump()), actor=actor) + tool = await server.tool_manager.create_or_update_tool_async( + pydantic_tool=Tool(**request.model_dump(exclude_unset=True)), actor=actor + ) return tool except UniqueConstraintViolationError as e: # Log the error and raise a conflict exception @@ -190,18 +192,16 @@ async def modify_tool( """ try: actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor) + tool = await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor) + return tool except LettaToolNameConflictError as e: # HTTP 409 == Conflict - print(f"Tool name conflict during update: {e}") raise HTTPException(status_code=409, detail=str(e)) except LettaToolCreateError as e: # HTTP 400 == Bad Request - print(f"Error occurred during tool update: {e}") raise HTTPException(status_code=400, detail=str(e)) except Exception as e: # Catch other unexpected errors and raise an internal server error - print(f"Unexpected error occurred: {e}") raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index d85e5c5c..25596445 
100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -18,7 +18,7 @@ from letta.constants import ( LOCAL_ONLY_MULTI_AGENT_TOOLS, MCP_TOOL_TAG_NAME_PREFIX, ) -from letta.errors import LettaToolNameConflictError +from letta.errors import LettaToolNameConflictError, LettaToolNameSchemaMismatchError from letta.functions.functions import derive_openai_json_schema, load_function_set from letta.log import get_logger @@ -403,6 +403,7 @@ class ToolManager: updated_tool_type: Optional[ToolType] = None, bypass_name_check: bool = False, ) -> PydanticTool: + # TODO: remove this (legacy non-async) """ Update a tool with complex validation and schema derivation logic. @@ -519,55 +520,39 @@ class ToolManager: # Fetch current tool early to allow conditional logic based on tool type current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor) - # For MCP tools, do NOT derive schema from Python source. Trust provided JSON schema. - if current_tool.tool_type == ToolType.EXTERNAL_MCP: - # Prefer provided json_schema; fall back to current - if "json_schema" in update_data: - new_schema = update_data["json_schema"].copy() - new_name = new_schema.get("name", current_tool.name) - else: - new_schema = current_tool.json_schema - new_name = current_tool.name - # Ensure we don't trigger derive - update_data.pop("source_code", None) - # If name changes, enforce uniqueness - if new_name != current_tool.name: - name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor) - if name_exists: - raise LettaToolNameConflictError(tool_name=new_name) + # Do NOT derive schema from Python source. Trust provided JSON schema. + # Prefer provided json_schema; fall back to current + if "json_schema" in update_data: + new_schema = update_data["json_schema"].copy() + new_name = new_schema.get("name", current_tool.name) else: - # For non-MCP tools, preserve existing behavior - # TODO: Consider this behavior...is this what we want? 
- # TODO: I feel like it's bad if json_schema strays from source code so - # if source code is provided, always derive the name from it - if "source_code" in update_data.keys() and not bypass_name_check: - # Check source type to use appropriate parser - source_type = update_data.get("source_type", current_tool.source_type) - if source_type == "typescript": - from letta.functions.typescript_parser import derive_typescript_json_schema + new_schema = current_tool.json_schema + new_name = current_tool.name - derived_schema = derive_typescript_json_schema(source_code=update_data["source_code"]) - else: - # Default to Python for backwards compatibility - derived_schema = derive_openai_json_schema(source_code=update_data["source_code"]) - new_name = derived_schema["name"] + # original tool may not have a JSON schema at all for legacy reasons + # in this case, fallback to dangerous schema generation + if new_schema is None: + # the update payload may omit either field, so fall back to the stored tool's values + source_type = update_data.get("source_type", current_tool.source_type) + source_code = update_data.get("source_code", current_tool.source_code) + if source_type == "typescript": + from letta.functions.typescript_parser import derive_typescript_json_schema - # if json_schema wasn't provided, use the derived schema - if "json_schema" not in update_data.keys(): - new_schema = derived_schema - else: - # if json_schema was provided, update only its name to match the source code - new_schema = update_data["json_schema"].copy() - new_schema["name"] = new_name - # update the json_schema in update_data so it gets applied in the loop - update_data["json_schema"] = new_schema + new_schema = derive_typescript_json_schema(source_code=source_code) + else: + new_schema = derive_openai_json_schema(source_code=source_code) - # check if the name is changing and if so, verify it doesn't conflict - if new_name != current_tool.name: - # check if a tool with the new name already exists - name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor) - if name_exists: - raise LettaToolNameConflictError(tool_name=new_name) + # If name changes, enforce uniqueness + if 
new_name != current_tool.name: + name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor) + if name_exists: + raise LettaToolNameConflictError(tool_name=new_name) + + # NOTE: EXTREMELEY HACKY, we need to stop making assumptions about the source_code + if "source_code" in update_data and f"def {new_name}" not in update_data.get("source_code", ""): + raise LettaToolNameSchemaMismatchError( + tool_name=new_name, json_schema_name=new_schema.get("name"), source_code=update_data.get("source_code") + ) # Now perform the update within the session async with db_registry.async_session() as session: diff --git a/tests/integration_test_pinecone_tool.py b/tests/integration_test_pinecone_tool.py deleted file mode 100644 index 20d9d1ee..00000000 --- a/tests/integration_test_pinecone_tool.py +++ /dev/null @@ -1,215 +0,0 @@ -import asyncio -import json -import os -import threading -import time - -import pytest -import requests -from dotenv import load_dotenv -from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage -from letta_client.core import RequestOptions - -from tests.helpers.utils import upload_test_agentfile_from_disk_async - -REASONING_THROTTLE_MS = 100 -TEST_USER_MESSAGE = "What products or services does 11x AI sell?" - - -@pytest.fixture(scope="module") -def server_url() -> str: - """ - Provides the URL for the Letta server. - If LETTA_SERVER_URL is not set, starts the server in a background thread - and polls until it's accepting connections. 
- """ - - def _run_server() -> None: - load_dotenv() - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=_run_server, daemon=True) - thread.start() - - # Poll until the server is up (or timeout) - timeout_seconds = 30 - deadline = time.time() + timeout_seconds - while time.time() < deadline: - try: - resp = requests.get(url + "/v1/health") - if resp.status_code < 500: - break - except requests.exceptions.RequestException: - pass - time.sleep(0.1) - else: - raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") - - return url - - -@pytest.fixture(scope="function") -def client(server_url: str): - """ - Creates and returns an asynchronous Letta REST client for testing. - """ - async_client_instance = AsyncLetta(base_url=server_url) - yield async_client_instance - - -async def test_pinecone_tool(client: AsyncLetta, server_url: str) -> None: - """ - Test the Pinecone tool integration with the Letta client. 
- """ - response = await upload_test_agentfile_from_disk_async(client, "knowledge-base.af") - - agent_id = response.agent_ids[0] - - agent = await client.agents.modify( - agent_id=agent_id, - tool_exec_environment_variables={ - "PINECONE_INDEX_HOST": os.getenv("PINECONE_INDEX_HOST"), - "PINECONE_API_KEY": os.getenv("PINECONE_API_KEY"), - "PINECONE_NAMESPACE": os.getenv("PINECONE_NAMESPACE"), - }, - ) - last_message = await client.agents.messages.list( - agent_id=agent.id, - limit=1, - ) - - curr_message_type = None - messages = [] - reasoning_content = [] - last_reasoning_update_ms = 0 - tool_call_content = "" - tool_return_content = "" - summary = None - pinecone_results = None - queries = [] - - try: - response = client.agents.messages.create_stream( - agent_id=agent.id, - messages=[ - MessageCreate( - role="user", - content=TEST_USER_MESSAGE, - ), - ], - stream_tokens=True, - request_options=RequestOptions( - timeout_in_seconds=1000, - ), - ) - - async for chunk in response: - if chunk.message_type != curr_message_type: - messages.append(chunk) - curr_message_type = chunk.message_type - if curr_message_type == "reasoning_message": - reasoning_content = [] - if curr_message_type == "tool_call_message": - tool_call_content = "" - - if chunk.message_type == "reasoning_message": - now_ms = time.time_ns() // 1_000_000 - if now_ms - last_reasoning_update_ms < REASONING_THROTTLE_MS: - await asyncio.sleep(REASONING_THROTTLE_MS / 1000) - - last_reasoning_update_ms = now_ms - if len(reasoning_content) == 0: - reasoning_content = [chunk.reasoning] - else: - reasoning_content[-1] += chunk.reasoning - - message_dict = messages[-1].model_dump() - message_dict["reasoning"] = "".join(reasoning_content).strip() - messages[-1] = ReasoningMessage(**message_dict) - - if chunk.message_type == "tool_return_message": - tool_return_content += chunk.tool_return - - if chunk.status == "success": - try: - if chunk.name == "summarize_pinecone_results": - json_response = 
json.loads(chunk.tool_return) - summary = json_response.get("summary", None) - pinecone_results = json_response.get("pinecone_results", None) - tool_return_content = "" - elif chunk.name == "craft_queries": - queries.append(chunk.tool_return) - tool_return_content = "" - except Exception as e: - print(f"Error parsing JSON response: {str(e)}. {chunk.tool_return}\n") - tool_return_content = "" - - if chunk.message_type == "tool_call_message": - if chunk.tool_call.arguments is not None: - tool_call_content += chunk.tool_call.arguments - message_dict = messages[-1].model_dump() - message_dict["tool_call"]["arguments"] = tool_call_content - messages[-1] = ToolCallMessage(**message_dict) - - except Exception as e: - print(f"Failed to fetch knowledge base response: {str(e)}\n") - print(tool_call_content) - raise e - - assert len(messages) > 0, "No messages received from the agent." - assert len(reasoning_content) > 0, "No reasoning content received from the agent." - assert summary is not None, "No summary received from the agent." - assert pinecone_results is not None, "No Pinecone results received from the agent." - assert len(queries) > 0, "No queries received from the agent." - - assert messages[-2].message_type == "stop_reason", "Penultimate message in stream must be stop reason." - assert messages[-1].message_type == "usage_statistics", "Last message in stream must be usage stats." 
- response_messages_from_stream = [m for m in messages if m.message_type not in ["stop_reason", "usage_statistics"]] - response_message_types_from_stream = [m.message_type for m in response_messages_from_stream] - - messages_from_db = await client.agents.messages.list( - agent_id=agent.id, - after=last_message[0].id, - limit=100, - ) - response_messages_from_db = [m for m in messages_from_db if m.message_type != "user_message"] - response_message_types_from_db = [m.message_type for m in response_messages_from_db] - - assert len(response_messages_from_stream) == len(response_messages_from_db) - assert response_message_types_from_stream == response_message_types_from_db - for idx in range(len(response_messages_from_stream)): - stream_message = response_messages_from_stream[idx] - db_message = response_messages_from_db[idx] - assert stream_message.message_type == db_message.message_type - assert stream_message.id == db_message.id - assert stream_message.otid == db_message.otid - - if stream_message.message_type == "reasoning_message": - assert stream_message.reasoning == db_message.reasoning - - if stream_message.message_type == "tool_call_message": - assert stream_message.tool_call.tool_call_id == db_message.tool_call.tool_call_id - assert stream_message.tool_call.name == db_message.tool_call.name - - if stream_message.tool_call.name == "craft_queries": - assert "queries" in stream_message.tool_call.arguments - assert "queries" in db_message.tool_call.arguments - if stream_message.tool_call.name == "search_and_store_pinecone_records": - assert "query_text" in stream_message.tool_call.arguments - assert "query_text" in db_message.tool_call.arguments - if stream_message.tool_call.name == "summarize_pinecone_results": - assert "summary" in stream_message.tool_call.arguments - assert "summary" in db_message.tool_call.arguments - - assert "inner_thoughts" not in stream_message.tool_call.arguments - assert "inner_thoughts" not in db_message.tool_call.arguments - - if 
stream_message.message_type == "tool_return_message": - assert stream_message.tool_return == db_message.tool_return - - await client.agents.delete(agent_id=agent.id) diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 5360adc2..4ed51b35 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -11,7 +11,7 @@ import pytest from dotenv import load_dotenv from letta_client import CreateBlock from letta_client import Letta as LettaSDKClient -from letta_client import LettaRequest, MessageCreate, TextContent +from letta_client import LettaRequest, MessageCreate, TerminalToolRule, TextContent from letta_client.client import BaseTool from letta_client.core import ApiError from letta_client.types import AgentState, ToolReturnMessage @@ -942,10 +942,22 @@ def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKCl print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}") return True + # test creation tool = client.tools.add( tool=ManageInventoryTool(), ) + # test that upserting also works + new_description = "NEW" + + class ManageInventoryToolModified(ManageInventoryTool): + description: str = new_description + + tool = client.tools.add( + tool=ManageInventoryToolModified(), + ) + assert tool.description == new_description + assert tool is not None assert tool.name == "manage_inventory" assert "inventory" in tool.tags @@ -1005,7 +1017,7 @@ def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKCl client.tools.delete(tool.id) -@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True) +@pytest.mark.parametrize("e2b_sandbox_mode", [False], indirect=True) def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient): class Step(BaseModel): @@ -1021,7 +1033,18 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient): print(f"Created task plan with {len(steps)} steps: {explanation}") return steps - tool = 
client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"]) + # test creation + client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"]) + + # test upsert + new_steps_description = "NEW" + + class StepsListModified(BaseModel): + steps: List[Step] = Field(..., description=new_steps_description) + explanation: str = Field(..., description="Explanation for the list of steps.") + + tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsListModified, description=new_steps_description) + assert tool.description == new_steps_description assert tool is not None assert tool.name == "create_task_plan" @@ -1039,6 +1062,9 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient): embedding="openai/text-embedding-3-small", tool_ids=[tool.id], include_base_tools=False, + tool_rules=[ + TerminalToolRule(tool_name=tool.name), + ], ) response = client.agents.messages.create( @@ -1062,6 +1088,7 @@ def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient): assert first_tool_call.tool_call.name == "create_task_plan" args = json.loads(first_tool_call.tool_call.arguments) + assert "steps" in args assert "explanation" in args assert isinstance(args["steps"], list) @@ -1224,145 +1251,6 @@ def test_agent_tools_list(client: LettaSDKClient): client.agents.delete(agent_id=agent_state.id) -def test_update_tool_source_code_changes_name(client: LettaSDKClient): - """Test that updating a tool's source code correctly changes its name""" - import textwrap - - # Create initial tool - def initial_tool(x: int) -> int: - """ - Multiply a number by 2 - - Args: - x: The input number - Returns: - The input multiplied by 2 - """ - return x * 2 - - # Create the tool - tool = client.tools.upsert_from_function(func=initial_tool) - assert tool.name == "initial_tool" - - try: - # Define new function source code 
with different name - new_source_code = textwrap.dedent( - """ - def updated_tool(x: int, y: int) -> int: - ''' - Add two numbers together - - Args: - x: First number - y: Second number - Returns: - Sum of x and y - ''' - return x + y - """ - ).strip() - - # Update the tool's source code - updated = client.tools.modify(tool_id=tool.id, source_code=new_source_code) - - # Verify the name changed - assert updated.name == "updated_tool" - assert updated.source_code == new_source_code - - # Verify the schema was updated for the new parameters - assert updated.json_schema is not None - assert updated.json_schema["name"] == "updated_tool" - assert updated.json_schema["description"] == "Add two numbers together" - - # Check parameters - params = updated.json_schema.get("parameters", {}) - properties = params.get("properties", {}) - assert "x" in properties - assert "y" in properties - assert properties["x"]["type"] == "integer" - assert properties["y"]["type"] == "integer" - assert properties["x"]["description"] == "First number" - assert properties["y"]["description"] == "Second number" - assert params["required"] == ["x", "y"] - - finally: - # Clean up - client.tools.delete(tool_id=tool.id) - - -def test_update_tool_source_code_duplicate_name_error(client: LettaSDKClient): - """Test that updating a tool's source code to have the same name as another existing tool raises an error""" - import textwrap - - # Create first tool - def first_tool(x: int) -> int: - """ - Multiply a number by 2 - - Args: - x: The input number - - Returns: - The input multiplied by 2 - """ - return x * 2 - - # Create second tool - def second_tool(x: int) -> int: - """ - Multiply a number by 3 - - Args: - x: The input number - - Returns: - The input multiplied by 3 - """ - return x * 3 - - # Create both tools - tool1 = client.tools.upsert_from_function(func=first_tool) - tool2 = client.tools.upsert_from_function(func=second_tool) - - assert tool1.name == "first_tool" - assert tool2.name == 
"second_tool" - - try: - # Try to update second_tool to have the same name as first_tool - new_source_code = textwrap.dedent( - """ - def first_tool(x: int) -> int: - ''' - Multiply a number by 4 - - Args: - x: The input number - - Returns: - The input multiplied by 4 - ''' - return x * 4 - """ - ).strip() - - # This should raise an error since first_tool already exists - with pytest.raises(Exception) as exc_info: - client.tools.modify(tool_id=tool2.id, source_code=new_source_code) - - # Verify the error message indicates duplicate name - error_message = str(exc_info.value) - assert "already exists" in error_message.lower() or "duplicate" in error_message.lower() or "conflict" in error_message.lower() - - # Verify that tool2 was not modified - tool2_check = client.tools.retrieve(tool_id=tool2.id) - assert tool2_check.name == "second_tool" # Name should remain unchanged - - finally: - # Clean up both tools - client.tools.delete(tool_id=tool1.id) - client.tools.delete(tool_id=tool2.id) - - def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient): """Test adding a tool with multiple functions in the source code""" import textwrap @@ -1445,143 +1333,144 @@ def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient) client.tools.delete(tool_id=tool.id) -def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient): - """Test that tool name auto-updates when source code changes with multiple functions""" - import textwrap - - # Initial source code with multiple functions - initial_source_code = textwrap.dedent( - """ - def helper_function(x: int) -> int: - ''' - Helper function that doubles the input - - Args: - x: The input number - - Returns: - The input multiplied by 2 - ''' - return x * 2 - - def another_helper(text: str) -> str: - ''' - Another helper that uppercases text - - Args: - text: The input text to uppercase - - Returns: - The uppercased text - ''' - return text.upper() - - def main_function(x: 
int, y: int) -> int: - ''' - Main function that uses the helper - - Args: - x: First number - y: Second number - - Returns: - Result of (x * 2) + y - ''' - doubled_x = helper_function(x) - return doubled_x + y - """ - ).strip() - - # Create tool with initial source code - tool = client.tools.create( - source_code=initial_source_code, - ) - - try: - # Verify the tool was created with the last function's name - assert tool is not None - assert tool.name == "main_function" - assert tool.source_code == initial_source_code - - # Now modify the source code with a different function order - new_source_code = textwrap.dedent( - """ - def process_data(data: str, count: int) -> str: - ''' - Process data by repeating it - - Args: - data: The input data - count: Number of times to repeat - - Returns: - The processed data - ''' - return data * count - - def helper_utility(x: float) -> float: - ''' - Helper utility function - - Args: - x: Input value - - Returns: - Squared value - ''' - return x * x - """ - ).strip() - - # Modify the tool with new source code - modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code) - - # Verify the name automatically updated to the last function - assert modified_tool.name == "helper_utility" - assert modified_tool.source_code == new_source_code - - # Verify the JSON schema updated correctly - assert modified_tool.json_schema is not None - assert modified_tool.json_schema["name"] == "helper_utility" - assert modified_tool.json_schema["description"] == "Helper utility function" - - # Check parameters updated correctly - params = modified_tool.json_schema.get("parameters", {}) - properties = params.get("properties", {}) - assert "x" in properties - assert properties["x"]["type"] == "number" # float maps to number - assert params["required"] == ["x"] - - # Test one more modification with only one function - single_function_code = textwrap.dedent( - """ - def calculate_total(items: list, tax_rate: float) -> float: - ''' - 
Calculate total with tax - - Args: - items: List of item prices - tax_rate: Tax rate as decimal - - Returns: - Total including tax - ''' - subtotal = sum(items) - return subtotal * (1 + tax_rate) - """ - ).strip() - - # Modify again - final_tool = client.tools.modify(tool_id=tool.id, source_code=single_function_code) - - # Verify name updated again - assert final_tool.name == "calculate_total" - assert final_tool.source_code == single_function_code - assert final_tool.json_schema["description"] == "Calculate total with tax" - - finally: - # Clean up - client.tools.delete(tool_id=tool.id) +# TODO: add back once behavior is defined +# def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient): +# """Test that tool name auto-updates when source code changes with multiple functions""" +# import textwrap +# +# # Initial source code with multiple functions +# initial_source_code = textwrap.dedent( +# """ +# def helper_function(x: int) -> int: +# ''' +# Helper function that doubles the input +# +# Args: +# x: The input number +# +# Returns: +# The input multiplied by 2 +# ''' +# return x * 2 +# +# def another_helper(text: str) -> str: +# ''' +# Another helper that uppercases text +# +# Args: +# text: The input text to uppercase +# +# Returns: +# The uppercased text +# ''' +# return text.upper() +# +# def main_function(x: int, y: int) -> int: +# ''' +# Main function that uses the helper +# +# Args: +# x: First number +# y: Second number +# +# Returns: +# Result of (x * 2) + y +# ''' +# doubled_x = helper_function(x) +# return doubled_x + y +# """ +# ).strip() +# +# # Create tool with initial source code +# tool = client.tools.create( +# source_code=initial_source_code, +# ) +# +# try: +# # Verify the tool was created with the last function's name +# assert tool is not None +# assert tool.name == "main_function" +# assert tool.source_code == initial_source_code +# +# # Now modify the source code with a different function order +# new_source_code = 
textwrap.dedent( +# """ +# def process_data(data: str, count: int) -> str: +# ''' +# Process data by repeating it +# +# Args: +# data: The input data +# count: Number of times to repeat +# +# Returns: +# The processed data +# ''' +# return data * count +# +# def helper_utility(x: float) -> float: +# ''' +# Helper utility function +# +# Args: +# x: Input value +# +# Returns: +# Squared value +# ''' +# return x * x +# """ +# ).strip() +# +# # Modify the tool with new source code +# modified_tool = client.tools.modify(name="helper_utility", tool_id=tool.id, source_code=new_source_code) +# +# # Verify the name automatically updated to the last function +# assert modified_tool.name == "helper_utility" +# assert modified_tool.source_code == new_source_code +# +# # Verify the JSON schema updated correctly +# assert modified_tool.json_schema is not None +# assert modified_tool.json_schema["name"] == "helper_utility" +# assert modified_tool.json_schema["description"] == "Helper utility function" +# +# # Check parameters updated correctly +# params = modified_tool.json_schema.get("parameters", {}) +# properties = params.get("properties", {}) +# assert "x" in properties +# assert properties["x"]["type"] == "number" # float maps to number +# assert params["required"] == ["x"] +# +# # Test one more modification with only one function +# single_function_code = textwrap.dedent( +# """ +# def calculate_total(items: list, tax_rate: float) -> float: +# ''' +# Calculate total with tax +# +# Args: +# items: List of item prices +# tax_rate: Tax rate as decimal +# +# Returns: +# Total including tax +# ''' +# subtotal = sum(items) +# return subtotal * (1 + tax_rate) +# """ +# ).strip() +# +# # Modify again +# final_tool = client.tools.modify(tool_id=tool.id, source_code=single_function_code) +# +# # Verify name updated again +# assert final_tool.name == "calculate_total" +# assert final_tool.source_code == single_function_code +# assert final_tool.json_schema["description"] == "Calculate 
total with tax" +# +# finally: +# # Clean up +# client.tools.delete(tool_id=tool.id) def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient): @@ -1637,28 +1526,16 @@ def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient): }, } - # Modify the tool with both new source code AND JSON schema - modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema) + # verify there is a 400 error when both source code and json schema are provided + with pytest.raises(Exception) as e: + client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema) + assert e.value.status_code == 400 - # Verify the name comes from the source code function name, not the JSON schema - assert modified_tool.name == "renamed_function" - assert modified_tool.source_code == new_source_code - - # Verify the JSON schema was updated to match the function name from source code - assert modified_tool.json_schema is not None - assert modified_tool.json_schema["name"] == "renamed_function" - - # The description should come from the source code docstring, not the JSON schema - assert modified_tool.json_schema["description"] == "Multiply a value by a multiplier" - - # Verify parameters are from the source code, not the custom JSON schema - params = modified_tool.json_schema.get("parameters", {}) - properties = params.get("properties", {}) - assert "value" in properties - assert "multiplier" in properties - assert properties["value"]["type"] == "number" - assert properties["multiplier"]["type"] == "number" - assert params["required"] == ["value"] + # update with consistent name and schema + custom_json_schema["name"] = "renamed_function" + tool = client.tools.modify(tool_id=tool.id, json_schema=custom_json_schema) + assert tool.json_schema == custom_json_schema + assert tool.name == "renamed_function" finally: # Clean up @@ -2006,3 +1883,58 @@ def 
test_import_agent_with_files_from_disk(client: LettaSDKClient): # Clean up agents and sources client.agents.delete(agent_id=imported_agent_id) client.sources.delete(source_id=imported_source.id) + + +def test_upsert_tools(client: LettaSDKClient): + """Test upserting tools with complex schemas.""" + from typing import List + + class WriteReasonOffer(BaseModel): + biltMerchantId: str = Field(..., description="The merchant ID (e.g. 'MERCHANT_NETWORK-123' or 'LYFT')") + campaignId: str = Field( + ..., + description="The campaign ID (e.g. '550e8400-e29b-41d4-a716-446655440000' or '550e8400-e29b-41d4-a716-446655440000_123e4567-e89b-12d3-a456-426614174000')", + ) + reason: str = Field( + ..., + description="A detailed explanation of why this offer is relevant to the user. Refer to the category-specific reason_instructions_{category} block for all guidelines on creating personalized reasons.", + ) + + class WriteReasonArgs(BaseModel): + """Arguments for the write_reason tool.""" + + offer_list: List[WriteReasonOffer] = Field( + ..., + description="List of WriteReasonOffer objects with merchant and campaign information", + ) + + def write_reason(offer_list: List[WriteReasonOffer]): + """ + This tool is used to write detailed reasons for a list of offers. + It returns the essential information: biltMerchantId, campaignId, and reason. + + IMPORTANT: When generating reasons, you MUST ONLY follow the guidelines in the + category-specific instruction block named "reason_instructions_{category}" where + {category} is the category of the offer (e.g., dining, travel, shopping). + + These instruction blocks contain all the necessary guidelines for creating + personalized, detailed reasons for each category. Do not rely on any other + instructions outside of these blocks. + + Args: + offer_list: List of WriteReasonOffer objects, each containing: + - biltMerchantId: The merchant ID (e.g. 'MERCHANT_NETWORK-123' or 'LYFT') + - campaignId: The campaign ID (e.g. 
'124', '28') + - reason: A detailed explanation generated according to the category-specific reason_instructions_{category} block + + Returns: + None: This function prints the offer list but does not return a value. + """ + print(offer_list) + + tool = client.tools.upsert_from_function(func=write_reason, args_schema=WriteReasonArgs) + assert tool is not None + assert tool.name == "write_reason" + + # Clean up + client.tools.delete(tool.id)