diff --git a/letta/errors.py b/letta/errors.py index c3b218fc..d2e643e3 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -49,6 +49,17 @@ class LettaToolCreateError(LettaError): super().__init__(message=message or self.default_error_message) +class LettaToolNameConflictError(LettaError): + """Error raised when a tool name already exists.""" + + def __init__(self, tool_name: str): + super().__init__( + message=f"Tool with name '{tool_name}' already exists in your organization", + code=ErrorCode.INVALID_ARGUMENT, + details={"tool_name": tool_name}, + ) + + class LettaConfigurationError(LettaError): """Error raised when there are configuration-related issues.""" diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 9e9467f7..6d33aaa6 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -16,7 +16,7 @@ from httpx import HTTPStatusError from pydantic import BaseModel, Field from starlette.responses import StreamingResponse -from letta.errors import LettaToolCreateError +from letta.errors import LettaToolCreateError, LettaToolNameConflictError from letta.functions.functions import derive_openai_json_schema from letta.functions.mcp_client.exceptions import MCPTimeoutError from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig @@ -191,6 +191,10 @@ async def modify_tool( try: actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) return await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor) + except LettaToolNameConflictError as e: + # HTTP 409 == Conflict + print(f"Tool name conflict during update: {e}") + raise HTTPException(status_code=409, detail=str(e)) except LettaToolCreateError as e: # HTTP 400 == Bad Request print(f"Error occurred during tool update: {e}") diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 5071f79c..43bf244b 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -19,6 +19,7 @@ from letta.constants import ( LOCAL_ONLY_MULTI_AGENT_TOOLS, MCP_TOOL_TAG_NAME_PREFIX, ) +from letta.errors import LettaToolNameConflictError from letta.functions.functions import derive_openai_json_schema, load_function_set from letta.log import get_logger from letta.orm.enums import ToolType @@ -299,6 +300,16 @@ class ToolManager: count = result.scalar() return count > 0 + @enforce_types + @trace_method + async def tool_name_exists_async(self, tool_name: str, actor: PydanticUser) -> bool: + """Check if a tool with the given name exists in the user's organization (lightweight check).""" + async with db_registry.async_session() as session: + query = select(func.count(ToolModel.id)).where(ToolModel.name == tool_name, ToolModel.organization_id == actor.organization_id) + result = await session.execute(query) + count = result.scalar() + return count > 0 + @enforce_types @trace_method async def list_tools_async( @@ -379,22 +390,39 @@ class ToolManager: self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None ) -> PydanticTool: """Update a tool by its ID with the given ToolUpdate object.""" + # First, check if source code update would cause a name conflict + update_data = tool_update.model_dump(to_orm=True, exclude_none=True) + new_name = None + new_schema = None + + if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): + # Derive the new schema and name from the source code + new_schema = derive_openai_json_schema(source_code=update_data["source_code"]) + new_name = new_schema["name"] + + # Get current tool to check if name is changing + current_tool = self.get_tool_by_id(tool_id=tool_id, actor=actor) + + # Check if the name is changing and if so, verify it doesn't conflict + if new_name != current_tool.name: + # Check if a tool with the new name already exists + existing_tool = self.get_tool_by_name(tool_name=new_name, actor=actor) + if existing_tool: + raise LettaToolNameConflictError(tool_name=new_name) + + # Now perform the update within the session with db_registry.session() as session: # Fetch the tool by ID tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) # Update tool attributes with only the fields that were explicitly set - update_data = tool_update.model_dump(to_orm=True, exclude_none=True) for key, value in update_data.items(): setattr(tool, key, value) - # If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema - if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): - pydantic_tool = tool.to_pydantic() - new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code) - + # If we already computed the new schema, apply it + if new_schema is not None: tool.json_schema = new_schema - tool.name = new_schema["name"] + tool.name = new_name if updated_tool_type: tool.tool_type = updated_tool_type @@ -408,22 +436,39 @@ class ToolManager: self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None ) -> PydanticTool: """Update a tool by its ID with the given ToolUpdate object.""" + # First, check if source code update would cause a name conflict + update_data = tool_update.model_dump(to_orm=True, exclude_none=True) + new_name = None + new_schema = None + + if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): + # Derive the new schema and name from the source code + new_schema = derive_openai_json_schema(source_code=update_data["source_code"]) + new_name = new_schema["name"] + + # Get current tool to check if name is changing + current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor) + + # Check if the name is changing and if so, verify it doesn't conflict + if new_name != current_tool.name: + # Check if a tool with the new name already exists + name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor) + if name_exists: + raise LettaToolNameConflictError(tool_name=new_name) + + # Now perform the update within the session async with db_registry.async_session() as session: # Fetch the tool by ID tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) # Update tool attributes with only the fields that were explicitly set - update_data = tool_update.model_dump(to_orm=True, exclude_none=True) for key, value in update_data.items(): setattr(tool, key, value) - # If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema - if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): - pydantic_tool = tool.to_pydantic() - new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code) - + # If we already computed the new schema, apply it + if new_schema is not None: tool.json_schema = new_schema - tool.name = new_schema["name"] + tool.name = new_name if updated_tool_type: tool.tool_type = updated_tool_type diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index d4bb7376..86111b25 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -1085,3 +1085,142 @@ def test_agent_tools_list(client: LettaSDKClient): finally: # Clean up client.agents.delete(agent_id=agent_state.id) + + +def test_update_tool_source_code_changes_name(client: LettaSDKClient): + """Test that updating a tool's source code correctly changes its name""" + import textwrap + + # Create initial tool + def initial_tool(x: int) -> int: + """ + Multiply a number by 2 + + Args: + x: The input number + Returns: + The input multiplied by 2 + """ + return x * 2 + + # Create the tool + tool = client.tools.upsert_from_function(func=initial_tool) + assert tool.name == "initial_tool" + + try: + # Define new function source code with different name + new_source_code = textwrap.dedent( + """ + def updated_tool(x: int, y: int) -> int: + ''' + Add two numbers together + + Args: + x: First number + y: Second number + Returns: + Sum of x and y + ''' + return x + y + """ + ).strip() + + # Update the tool's source code + updated = client.tools.modify(tool_id=tool.id, source_code=new_source_code) + + # Verify the name changed + assert updated.name == "updated_tool" + assert updated.source_code == new_source_code + + # Verify the schema was updated for the new parameters + assert updated.json_schema is not None + assert updated.json_schema["name"] == "updated_tool" + assert updated.json_schema["description"] == "Add two numbers together" + + # Check parameters + params = updated.json_schema.get("parameters", {}) + properties = params.get("properties", {}) + assert "x" in properties + assert "y" in properties + assert properties["x"]["type"] == "integer" + assert properties["y"]["type"] == "integer" + assert properties["x"]["description"] == "First number" + assert properties["y"]["description"] == "Second number" + assert params["required"] == ["x", "y"] + + finally: + # Clean up + client.tools.delete(tool_id=tool.id) + + +def test_update_tool_source_code_duplicate_name_error(client: LettaSDKClient): + """Test that updating a tool's source code to have the same name as another existing tool raises an error""" + import textwrap + + # Create first tool + def first_tool(x: int) -> int: + """ + Multiply a number by 2 + + Args: + x: The input number + + Returns: + The input multiplied by 2 + """ + return x * 2 + + # Create second tool + def second_tool(x: int) -> int: + """ + Multiply a number by 3 + + Args: + x: The input number + + Returns: + The input multiplied by 3 + """ + return x * 3 + + # Create both tools + tool1 = client.tools.upsert_from_function(func=first_tool) + tool2 = client.tools.upsert_from_function(func=second_tool) + + assert tool1.name == "first_tool" + assert tool2.name == "second_tool" + + try: + # Try to update second_tool to have the same name as first_tool + new_source_code = textwrap.dedent( + """ + def first_tool(x: int) -> int: + ''' + Multiply a number by 4 + + Args: + x: The input number + + Returns: + The input multiplied by 4 + ''' + return x * 4 + """ + ).strip() + + # This should raise an error since first_tool already exists + with pytest.raises(Exception) as exc_info: + client.tools.modify(tool_id=tool2.id, source_code=new_source_code) + + # Verify the error message indicates duplicate name + error_message = str(exc_info.value) + assert "already exists" in error_message.lower() or "duplicate" in error_message.lower() or "conflict" in error_message.lower() + + # Verify that tool2 was not modified + tool2_check = client.tools.retrieve(tool_id=tool2.id) + assert tool2_check.name == "second_tool" # Name should remain unchanged + + finally: + # Clean up both tools + client.tools.delete(tool_id=tool1.id) + client.tools.delete(tool_id=tool2.id)