feat: Improve tool renaming based on json schema (#3676)
This commit is contained in:
@@ -49,6 +49,17 @@ class LettaToolCreateError(LettaError):
|
||||
super().__init__(message=message or self.default_error_message)
|
||||
|
||||
|
||||
class LettaToolNameConflictError(LettaError):
|
||||
"""Error raised when a tool name already exists."""
|
||||
|
||||
def __init__(self, tool_name: str):
|
||||
super().__init__(
|
||||
message=f"Tool with name '{tool_name}' already exists in your organization",
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
details={"tool_name": tool_name},
|
||||
)
|
||||
|
||||
|
||||
class LettaConfigurationError(LettaError):
|
||||
"""Error raised when there are configuration-related issues."""
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from httpx import HTTPStatusError
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.errors import LettaToolCreateError, LettaToolNameConflictError
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.mcp_client.exceptions import MCPTimeoutError
|
||||
from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
@@ -191,6 +191,10 @@ async def modify_tool(
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor)
|
||||
except LettaToolNameConflictError as e:
|
||||
# HTTP 409 == Conflict
|
||||
print(f"Tool name conflict during update: {e}")
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except LettaToolCreateError as e:
|
||||
# HTTP 400 == Bad Request
|
||||
print(f"Error occurred during tool update: {e}")
|
||||
|
||||
@@ -19,6 +19,7 @@ from letta.constants import (
|
||||
LOCAL_ONLY_MULTI_AGENT_TOOLS,
|
||||
MCP_TOOL_TAG_NAME_PREFIX,
|
||||
)
|
||||
from letta.errors import LettaToolNameConflictError
|
||||
from letta.functions.functions import derive_openai_json_schema, load_function_set
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
@@ -299,6 +300,16 @@ class ToolManager:
|
||||
count = result.scalar()
|
||||
return count > 0
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def tool_name_exists_async(self, tool_name: str, actor: PydanticUser) -> bool:
|
||||
"""Check if a tool with the given name exists in the user's organization (lightweight check)."""
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(func.count(ToolModel.id)).where(ToolModel.name == tool_name, ToolModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(query)
|
||||
count = result.scalar()
|
||||
return count > 0
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_tools_async(
|
||||
@@ -379,22 +390,39 @@ class ToolManager:
|
||||
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
|
||||
) -> PydanticTool:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
# First, check if source code update would cause a name conflict
|
||||
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
|
||||
new_name = None
|
||||
new_schema = None
|
||||
|
||||
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
|
||||
# Derive the new schema and name from the source code
|
||||
new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
||||
new_name = new_schema["name"]
|
||||
|
||||
# Get current tool to check if name is changing
|
||||
current_tool = self.get_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
|
||||
# Check if the name is changing and if so, verify it doesn't conflict
|
||||
if new_name != current_tool.name:
|
||||
# Check if a tool with the new name already exists
|
||||
existing_tool = self.get_tool_by_name(tool_name=new_name, actor=actor)
|
||||
if existing_tool:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
|
||||
# Now perform the update within the session
|
||||
with db_registry.session() as session:
|
||||
# Fetch the tool by ID
|
||||
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
|
||||
|
||||
# Update tool attributes with only the fields that were explicitly set
|
||||
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(tool, key, value)
|
||||
|
||||
# If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema
|
||||
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
|
||||
pydantic_tool = tool.to_pydantic()
|
||||
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code)
|
||||
|
||||
# If we already computed the new schema, apply it
|
||||
if new_schema is not None:
|
||||
tool.json_schema = new_schema
|
||||
tool.name = new_schema["name"]
|
||||
tool.name = new_name
|
||||
|
||||
if updated_tool_type:
|
||||
tool.tool_type = updated_tool_type
|
||||
@@ -408,22 +436,39 @@ class ToolManager:
|
||||
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
|
||||
) -> PydanticTool:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
# First, check if source code update would cause a name conflict
|
||||
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
|
||||
new_name = None
|
||||
new_schema = None
|
||||
|
||||
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
|
||||
# Derive the new schema and name from the source code
|
||||
new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
||||
new_name = new_schema["name"]
|
||||
|
||||
# Get current tool to check if name is changing
|
||||
current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor)
|
||||
|
||||
# Check if the name is changing and if so, verify it doesn't conflict
|
||||
if new_name != current_tool.name:
|
||||
# Check if a tool with the new name already exists
|
||||
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
|
||||
if name_exists:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
|
||||
# Now perform the update within the session
|
||||
async with db_registry.async_session() as session:
|
||||
# Fetch the tool by ID
|
||||
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
|
||||
|
||||
# Update tool attributes with only the fields that were explicitly set
|
||||
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(tool, key, value)
|
||||
|
||||
# If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema
|
||||
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
|
||||
pydantic_tool = tool.to_pydantic()
|
||||
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code)
|
||||
|
||||
# If we already computed the new schema, apply it
|
||||
if new_schema is not None:
|
||||
tool.json_schema = new_schema
|
||||
tool.name = new_schema["name"]
|
||||
tool.name = new_name
|
||||
|
||||
if updated_tool_type:
|
||||
tool.tool_type = updated_tool_type
|
||||
|
||||
@@ -1085,3 +1085,142 @@ def test_agent_tools_list(client: LettaSDKClient):
|
||||
finally:
|
||||
# Clean up
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
def test_update_tool_source_code_changes_name(client: LettaSDKClient):
|
||||
"""Test that updating a tool's source code correctly changes its name"""
|
||||
import textwrap
|
||||
|
||||
# Create initial tool
|
||||
def initial_tool(x: int) -> int:
|
||||
"""
|
||||
Multiply a number by 2
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
Returns:
|
||||
The input multiplied by 2
|
||||
"""
|
||||
return x * 2
|
||||
|
||||
# Create the tool
|
||||
tool = client.tools.upsert_from_function(func=initial_tool)
|
||||
assert tool.name == "initial_tool"
|
||||
|
||||
try:
|
||||
# Define new function source code with different name
|
||||
new_source_code = textwrap.dedent(
|
||||
"""
|
||||
def updated_tool(x: int, y: int) -> int:
|
||||
'''
|
||||
Add two numbers together
|
||||
|
||||
Args:
|
||||
x: First number
|
||||
y: Second number
|
||||
Returns:
|
||||
Sum of x and y
|
||||
'''
|
||||
return x + y
|
||||
"""
|
||||
).strip()
|
||||
|
||||
# Update the tool's source code
|
||||
updated = client.tools.modify(tool_id=tool.id, source_code=new_source_code)
|
||||
|
||||
# Verify the name changed
|
||||
assert updated.name == "updated_tool"
|
||||
assert updated.source_code == new_source_code
|
||||
|
||||
# Verify the schema was updated for the new parameters
|
||||
assert updated.json_schema is not None
|
||||
assert updated.json_schema["name"] == "updated_tool"
|
||||
assert updated.json_schema["description"] == "Add two numbers together"
|
||||
|
||||
# Check parameters
|
||||
params = updated.json_schema.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
assert "x" in properties
|
||||
assert "y" in properties
|
||||
assert properties["x"]["type"] == "integer"
|
||||
assert properties["y"]["type"] == "integer"
|
||||
assert properties["x"]["description"] == "First number"
|
||||
assert properties["y"]["description"] == "Second number"
|
||||
assert params["required"] == ["x", "y"]
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
client.tools.delete(tool_id=tool.id)
|
||||
|
||||
|
||||
def test_update_tool_source_code_duplicate_name_error(client: LettaSDKClient):
|
||||
"""Test that updating a tool's source code to have the same name as another existing tool raises an error"""
|
||||
import textwrap
|
||||
|
||||
# Create first tool
|
||||
def first_tool(x: int) -> int:
|
||||
"""
|
||||
Multiply a number by 2
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
|
||||
Returns:
|
||||
The input multiplied by 2
|
||||
"""
|
||||
return x * 2
|
||||
|
||||
# Create second tool
|
||||
def second_tool(x: int) -> int:
|
||||
"""
|
||||
Multiply a number by 3
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
|
||||
Returns:
|
||||
The input multiplied by 3
|
||||
"""
|
||||
return x * 3
|
||||
|
||||
# Create both tools
|
||||
tool1 = client.tools.upsert_from_function(func=first_tool)
|
||||
tool2 = client.tools.upsert_from_function(func=second_tool)
|
||||
|
||||
assert tool1.name == "first_tool"
|
||||
assert tool2.name == "second_tool"
|
||||
|
||||
try:
|
||||
# Try to update second_tool to have the same name as first_tool
|
||||
new_source_code = textwrap.dedent(
|
||||
"""
|
||||
def first_tool(x: int) -> int:
|
||||
'''
|
||||
Multiply a number by 4
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
|
||||
Returns:
|
||||
The input multiplied by 4
|
||||
'''
|
||||
return x * 4
|
||||
"""
|
||||
).strip()
|
||||
|
||||
# This should raise an error since first_tool already exists
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
client.tools.modify(tool_id=tool2.id, source_code=new_source_code)
|
||||
|
||||
# Verify the error message indicates duplicate name
|
||||
error_message = str(exc_info.value)
|
||||
assert "already exists" in error_message.lower() or "duplicate" in error_message.lower() or "conflict" in error_message.lower()
|
||||
|
||||
# Verify that tool2 was not modified
|
||||
tool2_check = client.tools.retrieve(tool_id=tool2.id)
|
||||
assert tool2_check.name == "second_tool" # Name should remain unchanged
|
||||
|
||||
finally:
|
||||
# Clean up both tools
|
||||
client.tools.delete(tool_id=tool1.id)
|
||||
client.tools.delete(tool_id=tool2.id)
|
||||
|
||||
Reference in New Issue
Block a user