feat: allow list tools by tool type [PRO-870] (#4036)
* feat: allow list tools by tool type * chore: update list * chore: respond to comments * chore: refactor tools hella * Add tests to managers * chore: branch --------- Co-authored-by: Shubham Naik <shub@memgpt.ai> Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -27,7 +27,7 @@ from letta.log import get_logger
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.orm.mcp_oauth import OAuthSessionStatus
|
||||
from letta.prompts.gpt_system import get_system_text
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.enums import MessageRole, ToolType
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
|
||||
@@ -62,16 +62,94 @@ async def delete_tool(
|
||||
|
||||
@router.get("/count", response_model=int, operation_id="count_tools")
|
||||
async def count_tools(
|
||||
name: Optional[str] = None,
|
||||
names: Optional[List[str]] = Query(None, description="Filter by specific tool names"),
|
||||
tool_ids: Optional[List[str]] = Query(
|
||||
None, description="Filter by specific tool IDs - accepts repeated params or comma-separated values"
|
||||
),
|
||||
search: Optional[str] = Query(None, description="Search tool names (case-insensitive partial match)"),
|
||||
tool_types: Optional[List[str]] = Query(None, description="Filter by tool type(s) - accepts repeated params or comma-separated values"),
|
||||
exclude_tool_types: Optional[List[str]] = Query(
|
||||
None, description="Tool type(s) to exclude - accepts repeated params or comma-separated values"
|
||||
),
|
||||
return_only_letta_tools: Optional[bool] = Query(False, description="Count only tools with tool_type starting with 'letta_'"),
|
||||
exclude_letta_tools: Optional[bool] = Query(False, description="Exclude built-in Letta tools from the count"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
include_base_tools: Optional[bool] = Query(False, description="Include built-in Letta tools in the count"),
|
||||
):
|
||||
"""
|
||||
Get a count of all tools available to agents belonging to the org of the user.
|
||||
"""
|
||||
try:
|
||||
# Helper function to parse tool types - supports both repeated params and comma-separated values
|
||||
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_types_input is None:
|
||||
return None
|
||||
|
||||
# Flatten any comma-separated values and validate against ToolType enum
|
||||
flattened_types = []
|
||||
for item in tool_types_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
|
||||
flattened_types.extend(types_in_item)
|
||||
|
||||
# Validate each type against the ToolType enum
|
||||
valid_types = []
|
||||
valid_values = [tt.value for tt in ToolType]
|
||||
|
||||
for tool_type in flattened_types:
|
||||
if tool_type not in valid_values:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
|
||||
)
|
||||
valid_types.append(tool_type)
|
||||
|
||||
return valid_types if valid_types else None
|
||||
|
||||
# Parse and validate tool types (same logic as list_tools)
|
||||
tool_types_str = parse_tool_types(tool_types)
|
||||
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.tool_manager.size_async(actor=actor, include_base_tools=include_base_tools)
|
||||
|
||||
# Combine single name with names list for unified processing (same logic as list_tools)
|
||||
combined_names = []
|
||||
if name is not None:
|
||||
combined_names.append(name)
|
||||
if names is not None:
|
||||
combined_names.extend(names)
|
||||
|
||||
# Use None if no names specified, otherwise use the combined list
|
||||
final_names = combined_names if combined_names else None
|
||||
|
||||
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
|
||||
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_ids_input is None:
|
||||
return None
|
||||
|
||||
# Flatten any comma-separated values
|
||||
flattened_ids = []
|
||||
for item in tool_ids_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
|
||||
flattened_ids.extend(ids_in_item)
|
||||
|
||||
return flattened_ids if flattened_ids else None
|
||||
|
||||
# Parse tool IDs (same logic as list_tools)
|
||||
final_tool_ids = parse_tool_ids(tool_ids)
|
||||
|
||||
# Get the count of tools using unified query
|
||||
return await server.tool_manager.count_tools_async(
|
||||
actor=actor,
|
||||
tool_types=tool_types_str,
|
||||
exclude_tool_types=exclude_tool_types_str,
|
||||
names=final_names,
|
||||
tool_ids=final_tool_ids,
|
||||
search=search,
|
||||
return_only_letta_tools=return_only_letta_tools,
|
||||
exclude_letta_tools=exclude_letta_tools,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -99,6 +177,16 @@ async def list_tools(
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
name: Optional[str] = None,
|
||||
names: Optional[List[str]] = Query(None, description="Filter by specific tool names"),
|
||||
tool_ids: Optional[List[str]] = Query(
|
||||
None, description="Filter by specific tool IDs - accepts repeated params or comma-separated values"
|
||||
),
|
||||
search: Optional[str] = Query(None, description="Search tool names (case-insensitive partial match)"),
|
||||
tool_types: Optional[List[str]] = Query(None, description="Filter by tool type(s) - accepts repeated params or comma-separated values"),
|
||||
exclude_tool_types: Optional[List[str]] = Query(
|
||||
None, description="Tool type(s) to exclude - accepts repeated params or comma-separated values"
|
||||
),
|
||||
return_only_letta_tools: Optional[bool] = Query(False, description="Return only tools with tool_type starting with 'letta_'"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
@@ -106,13 +194,76 @@ async def list_tools(
|
||||
Get a list of all tools available to agents belonging to the org of the user
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
if name is not None:
|
||||
tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor)
|
||||
return [tool] if tool else []
|
||||
# Helper function to parse tool types - supports both repeated params and comma-separated values
|
||||
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_types_input is None:
|
||||
return None
|
||||
|
||||
# Get the list of tools
|
||||
return await server.tool_manager.list_tools_async(actor=actor, after=after, limit=limit)
|
||||
# Flatten any comma-separated values and validate against ToolType enum
|
||||
flattened_types = []
|
||||
for item in tool_types_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
|
||||
flattened_types.extend(types_in_item)
|
||||
|
||||
# Validate each type against the ToolType enum
|
||||
valid_types = []
|
||||
valid_values = [tt.value for tt in ToolType]
|
||||
|
||||
for tool_type in flattened_types:
|
||||
if tool_type not in valid_values:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
|
||||
)
|
||||
valid_types.append(tool_type)
|
||||
|
||||
return valid_types if valid_types else None
|
||||
|
||||
# Parse and validate tool types
|
||||
tool_types_str = parse_tool_types(tool_types)
|
||||
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# Combine single name with names list for unified processing
|
||||
combined_names = []
|
||||
if name is not None:
|
||||
combined_names.append(name)
|
||||
if names is not None:
|
||||
combined_names.extend(names)
|
||||
|
||||
# Use None if no names specified, otherwise use the combined list
|
||||
final_names = combined_names if combined_names else None
|
||||
|
||||
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
|
||||
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_ids_input is None:
|
||||
return None
|
||||
|
||||
# Flatten any comma-separated values
|
||||
flattened_ids = []
|
||||
for item in tool_ids_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
|
||||
flattened_ids.extend(ids_in_item)
|
||||
|
||||
return flattened_ids if flattened_ids else None
|
||||
|
||||
# Parse tool IDs
|
||||
final_tool_ids = parse_tool_ids(tool_ids)
|
||||
|
||||
# Get the list of tools using unified query
|
||||
return await server.tool_manager.list_tools_async(
|
||||
actor=actor,
|
||||
after=after,
|
||||
limit=limit,
|
||||
tool_types=tool_types_str,
|
||||
exclude_tool_types=exclude_tool_types_str,
|
||||
names=final_names,
|
||||
tool_ids=final_tool_ids,
|
||||
search=search,
|
||||
return_only_letta_tools=return_only_letta_tools,
|
||||
)
|
||||
except Exception as e:
|
||||
# Log or print the full exception here for debugging
|
||||
print(f"Error occurred: {e}")
|
||||
|
||||
@@ -2,7 +2,7 @@ import importlib
|
||||
import warnings
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
|
||||
from letta.constants import (
|
||||
BASE_FUNCTION_RETURN_CHAR_LIMIT,
|
||||
@@ -319,10 +319,30 @@ class ToolManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_tools_async(
|
||||
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, upsert_base_tools: bool = True
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
upsert_base_tools: bool = True,
|
||||
tool_types: Optional[List[str]] = None,
|
||||
exclude_tool_types: Optional[List[str]] = None,
|
||||
names: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
search: Optional[str] = None,
|
||||
return_only_letta_tools: bool = False,
|
||||
) -> List[PydanticTool]:
|
||||
"""List all tools with optional pagination."""
|
||||
tools = await self._list_tools_async(actor=actor, after=after, limit=limit)
|
||||
tools = await self._list_tools_async(
|
||||
actor=actor,
|
||||
after=after,
|
||||
limit=limit,
|
||||
tool_types=tool_types,
|
||||
exclude_tool_types=exclude_tool_types,
|
||||
names=names,
|
||||
tool_ids=tool_ids,
|
||||
search=search,
|
||||
return_only_letta_tools=return_only_letta_tools,
|
||||
)
|
||||
|
||||
# Check if all base tools are present if we requested all the tools w/o cursor
|
||||
# TODO: This is a temporary hack to resolve this issue
|
||||
@@ -337,22 +357,86 @@ class ToolManager:
|
||||
logger.info(f"Missing base tools detected: {missing_base_tools}. Upserting all base tools.")
|
||||
await self.upsert_base_tools_async(actor=actor)
|
||||
# Re-fetch the tools list after upserting base tools
|
||||
tools = await self._list_tools_async(actor=actor, after=after, limit=limit)
|
||||
tools = await self._list_tools_async(
|
||||
actor=actor,
|
||||
after=after,
|
||||
limit=limit,
|
||||
tool_types=tool_types,
|
||||
exclude_tool_types=exclude_tool_types,
|
||||
names=names,
|
||||
tool_ids=tool_ids,
|
||||
search=search,
|
||||
return_only_letta_tools=return_only_letta_tools,
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def _list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
|
||||
async def _list_tools_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
tool_types: Optional[List[str]] = None,
|
||||
exclude_tool_types: Optional[List[str]] = None,
|
||||
names: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
search: Optional[str] = None,
|
||||
return_only_letta_tools: bool = False,
|
||||
) -> List[PydanticTool]:
|
||||
"""List all tools with optional pagination."""
|
||||
tools_to_delete = []
|
||||
async with db_registry.async_session() as session:
|
||||
tools = await ToolModel.list_async(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
# Use SQLAlchemy directly for all cases - more control and consistency
|
||||
# Start with base query
|
||||
query = select(ToolModel).where(ToolModel.organization_id == actor.organization_id)
|
||||
|
||||
# Apply tool_types filter
|
||||
if tool_types is not None:
|
||||
query = query.where(ToolModel.tool_type.in_(tool_types))
|
||||
|
||||
# Apply names filter
|
||||
if names is not None:
|
||||
query = query.where(ToolModel.name.in_(names))
|
||||
|
||||
# Apply tool_ids filter
|
||||
if tool_ids is not None:
|
||||
query = query.where(ToolModel.id.in_(tool_ids))
|
||||
|
||||
# Apply search filter (ILIKE for case-insensitive partial match)
|
||||
if search is not None:
|
||||
query = query.where(ToolModel.name.ilike(f"%{search}%"))
|
||||
|
||||
# Apply exclude_tool_types filter at database level
|
||||
if exclude_tool_types is not None:
|
||||
query = query.where(~ToolModel.tool_type.in_(exclude_tool_types))
|
||||
|
||||
# Apply return_only_letta_tools filter at database level
|
||||
if return_only_letta_tools:
|
||||
query = query.where(ToolModel.tool_type.like("letta_%"))
|
||||
|
||||
# Apply pagination if specified
|
||||
if after is not None:
|
||||
after_tool = await session.get(ToolModel, after)
|
||||
if after_tool:
|
||||
query = query.where(
|
||||
or_(
|
||||
ToolModel.created_at < after_tool.created_at,
|
||||
and_(ToolModel.created_at == after_tool.created_at, ToolModel.id < after_tool.id),
|
||||
)
|
||||
)
|
||||
|
||||
# Apply limit
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Order by created_at and id for consistent pagination
|
||||
query = query.order_by(ToolModel.created_at.desc(), ToolModel.id.desc())
|
||||
|
||||
# Execute query
|
||||
result = await session.execute(query)
|
||||
tools = list(result.scalars())
|
||||
|
||||
# Remove any malformed tools
|
||||
results = []
|
||||
@@ -375,6 +459,61 @@ class ToolManager:
|
||||
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def count_tools_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
tool_types: Optional[List[str]] = None,
|
||||
exclude_tool_types: Optional[List[str]] = None,
|
||||
names: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
search: Optional[str] = None,
|
||||
return_only_letta_tools: bool = False,
|
||||
exclude_letta_tools: bool = False,
|
||||
) -> int:
|
||||
"""Count tools with the same filtering logic as list_tools_async."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Use SQLAlchemy directly with COUNT query - same filtering logic as list_tools_async
|
||||
# Start with base query
|
||||
query = select(func.count(ToolModel.id)).where(ToolModel.organization_id == actor.organization_id)
|
||||
|
||||
# Apply tool_types filter
|
||||
if tool_types is not None:
|
||||
query = query.where(ToolModel.tool_type.in_(tool_types))
|
||||
|
||||
# Apply names filter
|
||||
if names is not None:
|
||||
query = query.where(ToolModel.name.in_(names))
|
||||
|
||||
# Apply tool_ids filter
|
||||
if tool_ids is not None:
|
||||
query = query.where(ToolModel.id.in_(tool_ids))
|
||||
|
||||
# Apply search filter (ILIKE for case-insensitive partial match)
|
||||
if search is not None:
|
||||
query = query.where(ToolModel.name.ilike(f"%{search}%"))
|
||||
|
||||
# Apply exclude_tool_types filter at database level
|
||||
if exclude_tool_types is not None:
|
||||
query = query.where(~ToolModel.tool_type.in_(exclude_tool_types))
|
||||
|
||||
# Apply return_only_letta_tools filter at database level
|
||||
if return_only_letta_tools:
|
||||
query = query.where(ToolModel.tool_type.like("letta_%"))
|
||||
|
||||
# Handle exclude_letta_tools logic (if True, exclude Letta tools)
|
||||
if exclude_letta_tools:
|
||||
# Exclude tools that are in the LETTA_TOOL_SET
|
||||
letta_tool_names = list(LETTA_TOOL_SET)
|
||||
query = query.where(~ToolModel.name.in_(letta_tool_names))
|
||||
|
||||
# Execute count query
|
||||
result = await session.execute(query)
|
||||
count = result.scalar()
|
||||
|
||||
return count or 0
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def size_async(
|
||||
|
||||
@@ -4143,6 +4143,487 @@ async def test_list_tools(server: SyncServer, print_tool, default_user):
|
||||
assert any(t.id == print_tool.id for t in tools)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_with_tool_types(server: SyncServer, default_user):
|
||||
"""Test filtering tools by tool_types parameter."""
|
||||
|
||||
# create tools with different types
|
||||
def calculator_tool(a: int, b: int) -> int:
|
||||
"""Add two numbers.
|
||||
|
||||
Args:
|
||||
a: First number
|
||||
b: Second number
|
||||
|
||||
Returns:
|
||||
Sum of a and b
|
||||
"""
|
||||
return a + b
|
||||
|
||||
def weather_tool(city: str) -> str:
|
||||
"""Get weather for a city.
|
||||
|
||||
Args:
|
||||
city: Name of the city
|
||||
|
||||
Returns:
|
||||
Weather information
|
||||
"""
|
||||
return f"Weather in {city}"
|
||||
|
||||
# create custom tools
|
||||
custom_tool1 = PydanticTool(
|
||||
name="calculator",
|
||||
description="Math tool",
|
||||
source_code=parse_source_code(calculator_tool),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
custom_tool1.json_schema = derive_openai_json_schema(source_code=custom_tool1.source_code, name=custom_tool1.name)
|
||||
custom_tool1 = await server.tool_manager.create_or_update_tool_async(custom_tool1, actor=default_user)
|
||||
|
||||
custom_tool2 = PydanticTool(
|
||||
name="weather",
|
||||
description="Weather tool",
|
||||
source_code=parse_source_code(weather_tool),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
custom_tool2.json_schema = derive_openai_json_schema(source_code=custom_tool2.source_code, name=custom_tool2.name)
|
||||
custom_tool2 = await server.tool_manager.create_or_update_tool_async(custom_tool2, actor=default_user)
|
||||
|
||||
# test filtering by single tool type
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False)
|
||||
assert len(tools) == 2
|
||||
assert all(t.tool_type == ToolType.CUSTOM for t in tools)
|
||||
|
||||
# test filtering by multiple tool types (should get same result since we only have CUSTOM)
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, tool_types=[ToolType.CUSTOM.value, ToolType.LETTA_CORE.value], upsert_base_tools=False
|
||||
)
|
||||
assert len(tools) == 2
|
||||
|
||||
# test filtering by non-existent tool type
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False
|
||||
)
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_with_exclude_tool_types(server: SyncServer, default_user, print_tool):
|
||||
"""Test excluding tools by exclude_tool_types parameter."""
|
||||
# we already have print_tool which is CUSTOM type
|
||||
|
||||
# create a tool with a different type (simulate by updating tool type directly)
|
||||
def special_tool(msg: str) -> str:
|
||||
"""Special tool.
|
||||
|
||||
Args:
|
||||
msg: Message to return
|
||||
|
||||
Returns:
|
||||
The message
|
||||
"""
|
||||
return msg
|
||||
|
||||
special = PydanticTool(
|
||||
name="special",
|
||||
description="Special tool",
|
||||
source_code=parse_source_code(special_tool),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
special.json_schema = derive_openai_json_schema(source_code=special.source_code, name=special.name)
|
||||
special = await server.tool_manager.create_or_update_tool_async(special, actor=default_user)
|
||||
|
||||
# test excluding EXTERNAL_MCP (should get all tools since none are MCP)
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, exclude_tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False
|
||||
)
|
||||
assert len(tools) == 2 # print_tool and special
|
||||
|
||||
# test excluding CUSTOM (should get no tools)
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, exclude_tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False
|
||||
)
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_with_names(server: SyncServer, default_user):
|
||||
"""Test filtering tools by names parameter."""
|
||||
|
||||
# create tools with specific names
|
||||
def alpha_tool() -> str:
|
||||
"""Alpha tool.
|
||||
|
||||
Returns:
|
||||
Alpha string
|
||||
"""
|
||||
return "alpha"
|
||||
|
||||
def beta_tool() -> str:
|
||||
"""Beta tool.
|
||||
|
||||
Returns:
|
||||
Beta string
|
||||
"""
|
||||
return "beta"
|
||||
|
||||
def gamma_tool() -> str:
|
||||
"""Gamma tool.
|
||||
|
||||
Returns:
|
||||
Gamma string
|
||||
"""
|
||||
return "gamma"
|
||||
|
||||
alpha = PydanticTool(name="alpha_tool", description="Alpha", source_code=parse_source_code(alpha_tool), source_type="python")
|
||||
alpha.json_schema = derive_openai_json_schema(source_code=alpha.source_code, name=alpha.name)
|
||||
alpha = await server.tool_manager.create_or_update_tool_async(alpha, actor=default_user)
|
||||
|
||||
beta = PydanticTool(name="beta_tool", description="Beta", source_code=parse_source_code(beta_tool), source_type="python")
|
||||
beta.json_schema = derive_openai_json_schema(source_code=beta.source_code, name=beta.name)
|
||||
beta = await server.tool_manager.create_or_update_tool_async(beta, actor=default_user)
|
||||
|
||||
gamma = PydanticTool(name="gamma_tool", description="Gamma", source_code=parse_source_code(gamma_tool), source_type="python")
|
||||
gamma.json_schema = derive_openai_json_schema(source_code=gamma.source_code, name=gamma.name)
|
||||
gamma = await server.tool_manager.create_or_update_tool_async(gamma, actor=default_user)
|
||||
|
||||
# test filtering by single name
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, names=["alpha_tool"], upsert_base_tools=False)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "alpha_tool"
|
||||
|
||||
# test filtering by multiple names
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, names=["alpha_tool", "gamma_tool"], upsert_base_tools=False)
|
||||
assert len(tools) == 2
|
||||
assert set(t.name for t in tools) == {"alpha_tool", "gamma_tool"}
|
||||
|
||||
# test filtering by non-existent name
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, names=["non_existent_tool"], upsert_base_tools=False)
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_with_tool_ids(server: SyncServer, default_user):
|
||||
"""Test filtering tools by tool_ids parameter."""
|
||||
|
||||
# create multiple tools
|
||||
def tool1() -> str:
|
||||
"""Tool 1.
|
||||
|
||||
Returns:
|
||||
String 1
|
||||
"""
|
||||
return "1"
|
||||
|
||||
def tool2() -> str:
|
||||
"""Tool 2.
|
||||
|
||||
Returns:
|
||||
String 2
|
||||
"""
|
||||
return "2"
|
||||
|
||||
def tool3() -> str:
|
||||
"""Tool 3.
|
||||
|
||||
Returns:
|
||||
String 3
|
||||
"""
|
||||
return "3"
|
||||
|
||||
t1 = PydanticTool(name="tool1", description="First", source_code=parse_source_code(tool1), source_type="python")
|
||||
t1.json_schema = derive_openai_json_schema(source_code=t1.source_code, name=t1.name)
|
||||
t1 = await server.tool_manager.create_or_update_tool_async(t1, actor=default_user)
|
||||
|
||||
t2 = PydanticTool(name="tool2", description="Second", source_code=parse_source_code(tool2), source_type="python")
|
||||
t2.json_schema = derive_openai_json_schema(source_code=t2.source_code, name=t2.name)
|
||||
t2 = await server.tool_manager.create_or_update_tool_async(t2, actor=default_user)
|
||||
|
||||
t3 = PydanticTool(name="tool3", description="Third", source_code=parse_source_code(tool3), source_type="python")
|
||||
t3.json_schema = derive_openai_json_schema(source_code=t3.source_code, name=t3.name)
|
||||
t3 = await server.tool_manager.create_or_update_tool_async(t3, actor=default_user)
|
||||
|
||||
# test filtering by single id
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, tool_ids=[t1.id], upsert_base_tools=False)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].id == t1.id
|
||||
|
||||
# test filtering by multiple ids
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, tool_ids=[t1.id, t3.id], upsert_base_tools=False)
|
||||
assert len(tools) == 2
|
||||
assert set(t.id for t in tools) == {t1.id, t3.id}
|
||||
|
||||
# test filtering by non-existent id
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, tool_ids=["non-existent-id"], upsert_base_tools=False)
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_with_search(server: SyncServer, default_user):
|
||||
"""Test searching tools by partial name match."""
|
||||
|
||||
# create tools with searchable names
|
||||
def calculator_add() -> str:
|
||||
"""Calculator add.
|
||||
|
||||
Returns:
|
||||
Add operation
|
||||
"""
|
||||
return "add"
|
||||
|
||||
def calculator_subtract() -> str:
|
||||
"""Calculator subtract.
|
||||
|
||||
Returns:
|
||||
Subtract operation
|
||||
"""
|
||||
return "subtract"
|
||||
|
||||
def weather_forecast() -> str:
|
||||
"""Weather forecast.
|
||||
|
||||
Returns:
|
||||
Forecast data
|
||||
"""
|
||||
return "forecast"
|
||||
|
||||
calc_add = PydanticTool(
|
||||
name="calculator_add", description="Add numbers", source_code=parse_source_code(calculator_add), source_type="python"
|
||||
)
|
||||
calc_add.json_schema = derive_openai_json_schema(source_code=calc_add.source_code, name=calc_add.name)
|
||||
calc_add = await server.tool_manager.create_or_update_tool_async(calc_add, actor=default_user)
|
||||
|
||||
calc_sub = PydanticTool(
|
||||
name="calculator_subtract", description="Subtract numbers", source_code=parse_source_code(calculator_subtract), source_type="python"
|
||||
)
|
||||
calc_sub.json_schema = derive_openai_json_schema(source_code=calc_sub.source_code, name=calc_sub.name)
|
||||
calc_sub = await server.tool_manager.create_or_update_tool_async(calc_sub, actor=default_user)
|
||||
|
||||
weather = PydanticTool(
|
||||
name="weather_forecast", description="Weather", source_code=parse_source_code(weather_forecast), source_type="python"
|
||||
)
|
||||
weather.json_schema = derive_openai_json_schema(source_code=weather.source_code, name=weather.name)
|
||||
weather = await server.tool_manager.create_or_update_tool_async(weather, actor=default_user)
|
||||
|
||||
# test searching for "calculator" (should find both calculator tools)
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, search="calculator", upsert_base_tools=False)
|
||||
assert len(tools) == 2
|
||||
assert all("calculator" in t.name for t in tools)
|
||||
|
||||
# test case-insensitive search
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, search="CALCULATOR", upsert_base_tools=False)
|
||||
assert len(tools) == 2
|
||||
|
||||
# test partial match
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, search="calc", upsert_base_tools=False)
|
||||
assert len(tools) == 2
|
||||
|
||||
# test search with no matches
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, search="nonexistent", upsert_base_tools=False)
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_return_only_letta_tools(server: SyncServer, default_user):
|
||||
"""Test filtering for only Letta tools."""
|
||||
# first, upsert base tools to ensure we have Letta tools
|
||||
await server.tool_manager.upsert_base_tools_async(actor=default_user)
|
||||
|
||||
# create a custom tool
|
||||
def custom_tool() -> str:
|
||||
"""Custom tool.
|
||||
|
||||
Returns:
|
||||
Custom string
|
||||
"""
|
||||
return "custom"
|
||||
|
||||
custom = PydanticTool(
|
||||
name="custom_tool",
|
||||
description="Custom",
|
||||
source_code=parse_source_code(custom_tool),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
custom.json_schema = derive_openai_json_schema(source_code=custom.source_code, name=custom.name)
|
||||
custom = await server.tool_manager.create_or_update_tool_async(custom, actor=default_user)
|
||||
|
||||
# test without filter (should get custom tool + all letta tools)
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, return_only_letta_tools=False, upsert_base_tools=False)
|
||||
# should have at least the custom tool and some letta tools
|
||||
assert len(tools) > 1
|
||||
assert any(t.name == "custom_tool" for t in tools)
|
||||
|
||||
# test with filter (should only get letta tools)
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, return_only_letta_tools=True, upsert_base_tools=False)
|
||||
assert len(tools) > 0
|
||||
# all tools should have tool_type starting with "letta_"
|
||||
assert all(t.tool_type.value.startswith("letta_") for t in tools)
|
||||
# custom tool should not be in the list
|
||||
assert not any(t.name == "custom_tool" for t in tools)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_combined_filters(server: SyncServer, default_user):
|
||||
"""Test combining multiple filters."""
|
||||
|
||||
# create various tools
|
||||
def calc_add() -> str:
|
||||
"""Calculator add.
|
||||
|
||||
Returns:
|
||||
Add result
|
||||
"""
|
||||
return "add"
|
||||
|
||||
def calc_multiply() -> str:
|
||||
"""Calculator multiply.
|
||||
|
||||
Returns:
|
||||
Multiply result
|
||||
"""
|
||||
return "multiply"
|
||||
|
||||
def weather_tool() -> str:
|
||||
"""Weather tool.
|
||||
|
||||
Returns:
|
||||
Weather data
|
||||
"""
|
||||
return "weather"
|
||||
|
||||
calc1 = PydanticTool(
|
||||
name="calculator_add", description="Add", source_code=parse_source_code(calc_add), source_type="python", tool_type=ToolType.CUSTOM
|
||||
)
|
||||
calc1.json_schema = derive_openai_json_schema(source_code=calc1.source_code, name=calc1.name)
|
||||
calc1 = await server.tool_manager.create_or_update_tool_async(calc1, actor=default_user)
|
||||
|
||||
calc2 = PydanticTool(
|
||||
name="calculator_multiply",
|
||||
description="Multiply",
|
||||
source_code=parse_source_code(calc_multiply),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
calc2.json_schema = derive_openai_json_schema(source_code=calc2.source_code, name=calc2.name)
|
||||
calc2 = await server.tool_manager.create_or_update_tool_async(calc2, actor=default_user)
|
||||
|
||||
weather = PydanticTool(
|
||||
name="weather_current",
|
||||
description="Weather",
|
||||
source_code=parse_source_code(weather_tool),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
weather.json_schema = derive_openai_json_schema(source_code=weather.source_code, name=weather.name)
|
||||
weather = await server.tool_manager.create_or_update_tool_async(weather, actor=default_user)
|
||||
|
||||
# combine search with tool_types
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, search="calculator", tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False
|
||||
)
|
||||
assert len(tools) == 2
|
||||
assert all("calculator" in t.name and t.tool_type == ToolType.CUSTOM for t in tools)
|
||||
|
||||
# combine names with tool_ids
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, names=["calculator_add"], tool_ids=[calc1.id], upsert_base_tools=False
|
||||
)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].id == calc1.id
|
||||
|
||||
# combine search with exclude_tool_types
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, search="calculator", exclude_tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False
|
||||
)
|
||||
assert len(tools) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_tools_async(server: SyncServer, default_user):
|
||||
"""Test counting tools with various filters."""
|
||||
|
||||
# create multiple tools
|
||||
def tool_a() -> str:
|
||||
"""Tool A.
|
||||
|
||||
Returns:
|
||||
String a
|
||||
"""
|
||||
return "a"
|
||||
|
||||
def tool_b() -> str:
|
||||
"""Tool B.
|
||||
|
||||
Returns:
|
||||
String b
|
||||
"""
|
||||
return "b"
|
||||
|
||||
def search_tool() -> str:
|
||||
"""Search tool.
|
||||
|
||||
Returns:
|
||||
Search result
|
||||
"""
|
||||
return "search"
|
||||
|
||||
ta = PydanticTool(
|
||||
name="tool_a", description="A", source_code=parse_source_code(tool_a), source_type="python", tool_type=ToolType.CUSTOM
|
||||
)
|
||||
ta.json_schema = derive_openai_json_schema(source_code=ta.source_code, name=ta.name)
|
||||
ta = await server.tool_manager.create_or_update_tool_async(ta, actor=default_user)
|
||||
|
||||
tb = PydanticTool(
|
||||
name="tool_b", description="B", source_code=parse_source_code(tool_b), source_type="python", tool_type=ToolType.CUSTOM
|
||||
)
|
||||
tb.json_schema = derive_openai_json_schema(source_code=tb.source_code, name=tb.name)
|
||||
tb = await server.tool_manager.create_or_update_tool_async(tb, actor=default_user)
|
||||
|
||||
# upsert base tools to ensure we have Letta tools for counting
|
||||
await server.tool_manager.upsert_base_tools_async(actor=default_user)
|
||||
|
||||
# count all tools (should have 2 custom tools + letta tools)
|
||||
count = await server.tool_manager.count_tools_async(actor=default_user)
|
||||
assert count > 2 # at least our 2 custom tools + letta tools
|
||||
|
||||
# count with tool_types filter
|
||||
count = await server.tool_manager.count_tools_async(actor=default_user, tool_types=[ToolType.CUSTOM.value])
|
||||
assert count == 2 # only our custom tools
|
||||
|
||||
# count with search filter
|
||||
count = await server.tool_manager.count_tools_async(actor=default_user, search="tool")
|
||||
# should at least find our 2 tools (tool_a, tool_b)
|
||||
assert count >= 2
|
||||
|
||||
# count with names filter
|
||||
count = await server.tool_manager.count_tools_async(actor=default_user, names=["tool_a", "tool_b"])
|
||||
assert count == 2
|
||||
|
||||
# count with return_only_letta_tools
|
||||
count = await server.tool_manager.count_tools_async(actor=default_user, return_only_letta_tools=True)
|
||||
assert count > 0 # should have letta tools
|
||||
|
||||
# count with exclude_tool_types (exclude all letta tool types)
|
||||
count = await server.tool_manager.count_tools_async(
|
||||
actor=default_user,
|
||||
exclude_tool_types=[
|
||||
ToolType.LETTA_CORE.value,
|
||||
ToolType.LETTA_MEMORY_CORE.value,
|
||||
ToolType.LETTA_MULTI_AGENT_CORE.value,
|
||||
ToolType.LETTA_SLEEPTIME_CORE.value,
|
||||
ToolType.LETTA_VOICE_SLEEPTIME_CORE.value,
|
||||
ToolType.LETTA_BUILTIN.value,
|
||||
ToolType.LETTA_FILES_CORE.value,
|
||||
],
|
||||
)
|
||||
assert count == 2 # only our custom tools
|
||||
|
||||
|
||||
def test_update_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
updated_description = "updated_description"
|
||||
return_char_limit = 10000
|
||||
|
||||
Reference in New Issue
Block a user