feat: allow list tools by tool type [PRO-870] (#4036)

* feat: allow list tools by tool type

* chore: update list

* chore: respond to comments

* chore: refactor tools hella

* Add tests to managers

* chore: branch

---------

Co-authored-by: Shubham Naik <shub@memgpt.ai>
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
Shubham Naik
2025-09-02 15:53:36 -07:00
committed by GitHub
parent 75d444c335
commit 6716f5b5e8
3 changed files with 791 additions and 20 deletions

View File

@@ -27,7 +27,7 @@ from letta.log import get_logger
from letta.orm.errors import UniqueConstraintViolationError
from letta.orm.mcp_oauth import OAuthSessionStatus
from letta.prompts.gpt_system import get_system_text
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageRole, ToolType
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
@@ -62,16 +62,94 @@ async def delete_tool(
@router.get("/count", response_model=int, operation_id="count_tools")
async def count_tools(
name: Optional[str] = None,
names: Optional[List[str]] = Query(None, description="Filter by specific tool names"),
tool_ids: Optional[List[str]] = Query(
None, description="Filter by specific tool IDs - accepts repeated params or comma-separated values"
),
search: Optional[str] = Query(None, description="Search tool names (case-insensitive partial match)"),
tool_types: Optional[List[str]] = Query(None, description="Filter by tool type(s) - accepts repeated params or comma-separated values"),
exclude_tool_types: Optional[List[str]] = Query(
None, description="Tool type(s) to exclude - accepts repeated params or comma-separated values"
),
return_only_letta_tools: Optional[bool] = Query(False, description="Count only tools with tool_type starting with 'letta_'"),
exclude_letta_tools: Optional[bool] = Query(False, description="Exclude built-in Letta tools from the count"),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
include_base_tools: Optional[bool] = Query(False, description="Include built-in Letta tools in the count"),
):
"""
Get a count of all tools available to agents belonging to the org of the user.
"""
try:
# Helper function to parse tool types - supports both repeated params and comma-separated values
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_types_input is None:
return None
# Flatten any comma-separated values and validate against ToolType enum
flattened_types = []
for item in tool_types_input:
# Split by comma in case user provided comma-separated values
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
flattened_types.extend(types_in_item)
# Validate each type against the ToolType enum
valid_types = []
valid_values = [tt.value for tt in ToolType]
for tool_type in flattened_types:
if tool_type not in valid_values:
raise HTTPException(
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
)
valid_types.append(tool_type)
return valid_types if valid_types else None
# Parse and validate tool types (same logic as list_tools)
tool_types_str = parse_tool_types(tool_types)
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.tool_manager.size_async(actor=actor, include_base_tools=include_base_tools)
# Combine single name with names list for unified processing (same logic as list_tools)
combined_names = []
if name is not None:
combined_names.append(name)
if names is not None:
combined_names.extend(names)
# Use None if no names specified, otherwise use the combined list
final_names = combined_names if combined_names else None
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_ids_input is None:
return None
# Flatten any comma-separated values
flattened_ids = []
for item in tool_ids_input:
# Split by comma in case user provided comma-separated values
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
flattened_ids.extend(ids_in_item)
return flattened_ids if flattened_ids else None
# Parse tool IDs (same logic as list_tools)
final_tool_ids = parse_tool_ids(tool_ids)
# Get the count of tools using unified query
return await server.tool_manager.count_tools_async(
actor=actor,
tool_types=tool_types_str,
exclude_tool_types=exclude_tool_types_str,
names=final_names,
tool_ids=final_tool_ids,
search=search,
return_only_letta_tools=return_only_letta_tools,
exclude_letta_tools=exclude_letta_tools,
)
except Exception as e:
print(f"Error occurred: {e}")
raise HTTPException(status_code=500, detail=str(e))
@@ -99,6 +177,16 @@ async def list_tools(
after: Optional[str] = None,
limit: Optional[int] = 50,
name: Optional[str] = None,
names: Optional[List[str]] = Query(None, description="Filter by specific tool names"),
tool_ids: Optional[List[str]] = Query(
None, description="Filter by specific tool IDs - accepts repeated params or comma-separated values"
),
search: Optional[str] = Query(None, description="Search tool names (case-insensitive partial match)"),
tool_types: Optional[List[str]] = Query(None, description="Filter by tool type(s) - accepts repeated params or comma-separated values"),
exclude_tool_types: Optional[List[str]] = Query(
None, description="Tool type(s) to exclude - accepts repeated params or comma-separated values"
),
return_only_letta_tools: Optional[bool] = Query(False, description="Return only tools with tool_type starting with 'letta_'"),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
@@ -106,13 +194,76 @@ async def list_tools(
Get a list of all tools available to agents belonging to the org of the user
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
if name is not None:
tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor)
return [tool] if tool else []
# Helper function to parse tool types - supports both repeated params and comma-separated values
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_types_input is None:
return None
# Get the list of tools
return await server.tool_manager.list_tools_async(actor=actor, after=after, limit=limit)
# Flatten any comma-separated values and validate against ToolType enum
flattened_types = []
for item in tool_types_input:
# Split by comma in case user provided comma-separated values
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
flattened_types.extend(types_in_item)
# Validate each type against the ToolType enum
valid_types = []
valid_values = [tt.value for tt in ToolType]
for tool_type in flattened_types:
if tool_type not in valid_values:
raise HTTPException(
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
)
valid_types.append(tool_type)
return valid_types if valid_types else None
# Parse and validate tool types
tool_types_str = parse_tool_types(tool_types)
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# Combine single name with names list for unified processing
combined_names = []
if name is not None:
combined_names.append(name)
if names is not None:
combined_names.extend(names)
# Use None if no names specified, otherwise use the combined list
final_names = combined_names if combined_names else None
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_ids_input is None:
return None
# Flatten any comma-separated values
flattened_ids = []
for item in tool_ids_input:
# Split by comma in case user provided comma-separated values
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
flattened_ids.extend(ids_in_item)
return flattened_ids if flattened_ids else None
# Parse tool IDs
final_tool_ids = parse_tool_ids(tool_ids)
# Get the list of tools using unified query
return await server.tool_manager.list_tools_async(
actor=actor,
after=after,
limit=limit,
tool_types=tool_types_str,
exclude_tool_types=exclude_tool_types_str,
names=final_names,
tool_ids=final_tool_ids,
search=search,
return_only_letta_tools=return_only_letta_tools,
)
except Exception as e:
# Log or print the full exception here for debugging
print(f"Error occurred: {e}")

View File

@@ -2,7 +2,7 @@ import importlib
import warnings
from typing import List, Optional, Set, Union
from sqlalchemy import func, select
from sqlalchemy import and_, func, or_, select
from letta.constants import (
BASE_FUNCTION_RETURN_CHAR_LIMIT,
@@ -319,10 +319,30 @@ class ToolManager:
@enforce_types
@trace_method
async def list_tools_async(
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, upsert_base_tools: bool = True
self,
actor: PydanticUser,
after: Optional[str] = None,
limit: Optional[int] = 50,
upsert_base_tools: bool = True,
tool_types: Optional[List[str]] = None,
exclude_tool_types: Optional[List[str]] = None,
names: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
search: Optional[str] = None,
return_only_letta_tools: bool = False,
) -> List[PydanticTool]:
"""List all tools with optional pagination."""
tools = await self._list_tools_async(actor=actor, after=after, limit=limit)
tools = await self._list_tools_async(
actor=actor,
after=after,
limit=limit,
tool_types=tool_types,
exclude_tool_types=exclude_tool_types,
names=names,
tool_ids=tool_ids,
search=search,
return_only_letta_tools=return_only_letta_tools,
)
# Check if all base tools are present if we requested all the tools w/o cursor
# TODO: This is a temporary hack to resolve this issue
@@ -337,22 +357,86 @@ class ToolManager:
logger.info(f"Missing base tools detected: {missing_base_tools}. Upserting all base tools.")
await self.upsert_base_tools_async(actor=actor)
# Re-fetch the tools list after upserting base tools
tools = await self._list_tools_async(actor=actor, after=after, limit=limit)
tools = await self._list_tools_async(
actor=actor,
after=after,
limit=limit,
tool_types=tool_types,
exclude_tool_types=exclude_tool_types,
names=names,
tool_ids=tool_ids,
search=search,
return_only_letta_tools=return_only_letta_tools,
)
return tools
@enforce_types
@trace_method
async def _list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
async def _list_tools_async(
self,
actor: PydanticUser,
after: Optional[str] = None,
limit: Optional[int] = 50,
tool_types: Optional[List[str]] = None,
exclude_tool_types: Optional[List[str]] = None,
names: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
search: Optional[str] = None,
return_only_letta_tools: bool = False,
) -> List[PydanticTool]:
"""List all tools with optional pagination."""
tools_to_delete = []
async with db_registry.async_session() as session:
tools = await ToolModel.list_async(
db_session=session,
after=after,
limit=limit,
organization_id=actor.organization_id,
)
# Use SQLAlchemy directly for all cases - more control and consistency
# Start with base query
query = select(ToolModel).where(ToolModel.organization_id == actor.organization_id)
# Apply tool_types filter
if tool_types is not None:
query = query.where(ToolModel.tool_type.in_(tool_types))
# Apply names filter
if names is not None:
query = query.where(ToolModel.name.in_(names))
# Apply tool_ids filter
if tool_ids is not None:
query = query.where(ToolModel.id.in_(tool_ids))
# Apply search filter (ILIKE for case-insensitive partial match)
if search is not None:
query = query.where(ToolModel.name.ilike(f"%{search}%"))
# Apply exclude_tool_types filter at database level
if exclude_tool_types is not None:
query = query.where(~ToolModel.tool_type.in_(exclude_tool_types))
# Apply return_only_letta_tools filter at database level
if return_only_letta_tools:
query = query.where(ToolModel.tool_type.like("letta_%"))
# Apply pagination if specified
if after is not None:
after_tool = await session.get(ToolModel, after)
if after_tool:
query = query.where(
or_(
ToolModel.created_at < after_tool.created_at,
and_(ToolModel.created_at == after_tool.created_at, ToolModel.id < after_tool.id),
)
)
# Apply limit
if limit is not None:
query = query.limit(limit)
# Order by created_at and id for consistent pagination
query = query.order_by(ToolModel.created_at.desc(), ToolModel.id.desc())
# Execute query
result = await session.execute(query)
tools = list(result.scalars())
# Remove any malformed tools
results = []
@@ -375,6 +459,61 @@ class ToolManager:
return results
@enforce_types
@trace_method
async def count_tools_async(
self,
actor: PydanticUser,
tool_types: Optional[List[str]] = None,
exclude_tool_types: Optional[List[str]] = None,
names: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
search: Optional[str] = None,
return_only_letta_tools: bool = False,
exclude_letta_tools: bool = False,
) -> int:
"""Count tools with the same filtering logic as list_tools_async."""
async with db_registry.async_session() as session:
# Use SQLAlchemy directly with COUNT query - same filtering logic as list_tools_async
# Start with base query
query = select(func.count(ToolModel.id)).where(ToolModel.organization_id == actor.organization_id)
# Apply tool_types filter
if tool_types is not None:
query = query.where(ToolModel.tool_type.in_(tool_types))
# Apply names filter
if names is not None:
query = query.where(ToolModel.name.in_(names))
# Apply tool_ids filter
if tool_ids is not None:
query = query.where(ToolModel.id.in_(tool_ids))
# Apply search filter (ILIKE for case-insensitive partial match)
if search is not None:
query = query.where(ToolModel.name.ilike(f"%{search}%"))
# Apply exclude_tool_types filter at database level
if exclude_tool_types is not None:
query = query.where(~ToolModel.tool_type.in_(exclude_tool_types))
# Apply return_only_letta_tools filter at database level
if return_only_letta_tools:
query = query.where(ToolModel.tool_type.like("letta_%"))
# Handle exclude_letta_tools logic (if True, exclude Letta tools)
if exclude_letta_tools:
# Exclude tools that are in the LETTA_TOOL_SET
letta_tool_names = list(LETTA_TOOL_SET)
query = query.where(~ToolModel.name.in_(letta_tool_names))
# Execute count query
result = await session.execute(query)
count = result.scalar()
return count or 0
@enforce_types
@trace_method
async def size_async(

View File

@@ -4143,6 +4143,487 @@ async def test_list_tools(server: SyncServer, print_tool, default_user):
assert any(t.id == print_tool.id for t in tools)
@pytest.mark.asyncio
async def test_list_tools_with_tool_types(server: SyncServer, default_user):
"""Test filtering tools by tool_types parameter."""
# create tools with different types
def calculator_tool(a: int, b: int) -> int:
"""Add two numbers.
Args:
a: First number
b: Second number
Returns:
Sum of a and b
"""
return a + b
def weather_tool(city: str) -> str:
"""Get weather for a city.
Args:
city: Name of the city
Returns:
Weather information
"""
return f"Weather in {city}"
# create custom tools
custom_tool1 = PydanticTool(
name="calculator",
description="Math tool",
source_code=parse_source_code(calculator_tool),
source_type="python",
tool_type=ToolType.CUSTOM,
)
custom_tool1.json_schema = derive_openai_json_schema(source_code=custom_tool1.source_code, name=custom_tool1.name)
custom_tool1 = await server.tool_manager.create_or_update_tool_async(custom_tool1, actor=default_user)
custom_tool2 = PydanticTool(
name="weather",
description="Weather tool",
source_code=parse_source_code(weather_tool),
source_type="python",
tool_type=ToolType.CUSTOM,
)
custom_tool2.json_schema = derive_openai_json_schema(source_code=custom_tool2.source_code, name=custom_tool2.name)
custom_tool2 = await server.tool_manager.create_or_update_tool_async(custom_tool2, actor=default_user)
# test filtering by single tool type
tools = await server.tool_manager.list_tools_async(actor=default_user, tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False)
assert len(tools) == 2
assert all(t.tool_type == ToolType.CUSTOM for t in tools)
# test filtering by multiple tool types (should get same result since we only have CUSTOM)
tools = await server.tool_manager.list_tools_async(
actor=default_user, tool_types=[ToolType.CUSTOM.value, ToolType.LETTA_CORE.value], upsert_base_tools=False
)
assert len(tools) == 2
# test filtering by non-existent tool type
tools = await server.tool_manager.list_tools_async(
actor=default_user, tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False
)
assert len(tools) == 0
@pytest.mark.asyncio
async def test_list_tools_with_exclude_tool_types(server: SyncServer, default_user, print_tool):
"""Test excluding tools by exclude_tool_types parameter."""
# we already have print_tool which is CUSTOM type
# create a tool with a different type (simulate by updating tool type directly)
def special_tool(msg: str) -> str:
"""Special tool.
Args:
msg: Message to return
Returns:
The message
"""
return msg
special = PydanticTool(
name="special",
description="Special tool",
source_code=parse_source_code(special_tool),
source_type="python",
tool_type=ToolType.CUSTOM,
)
special.json_schema = derive_openai_json_schema(source_code=special.source_code, name=special.name)
special = await server.tool_manager.create_or_update_tool_async(special, actor=default_user)
# test excluding EXTERNAL_MCP (should get all tools since none are MCP)
tools = await server.tool_manager.list_tools_async(
actor=default_user, exclude_tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False
)
assert len(tools) == 2 # print_tool and special
# test excluding CUSTOM (should get no tools)
tools = await server.tool_manager.list_tools_async(
actor=default_user, exclude_tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False
)
assert len(tools) == 0
@pytest.mark.asyncio
async def test_list_tools_with_names(server: SyncServer, default_user):
"""Test filtering tools by names parameter."""
# create tools with specific names
def alpha_tool() -> str:
"""Alpha tool.
Returns:
Alpha string
"""
return "alpha"
def beta_tool() -> str:
"""Beta tool.
Returns:
Beta string
"""
return "beta"
def gamma_tool() -> str:
"""Gamma tool.
Returns:
Gamma string
"""
return "gamma"
alpha = PydanticTool(name="alpha_tool", description="Alpha", source_code=parse_source_code(alpha_tool), source_type="python")
alpha.json_schema = derive_openai_json_schema(source_code=alpha.source_code, name=alpha.name)
alpha = await server.tool_manager.create_or_update_tool_async(alpha, actor=default_user)
beta = PydanticTool(name="beta_tool", description="Beta", source_code=parse_source_code(beta_tool), source_type="python")
beta.json_schema = derive_openai_json_schema(source_code=beta.source_code, name=beta.name)
beta = await server.tool_manager.create_or_update_tool_async(beta, actor=default_user)
gamma = PydanticTool(name="gamma_tool", description="Gamma", source_code=parse_source_code(gamma_tool), source_type="python")
gamma.json_schema = derive_openai_json_schema(source_code=gamma.source_code, name=gamma.name)
gamma = await server.tool_manager.create_or_update_tool_async(gamma, actor=default_user)
# test filtering by single name
tools = await server.tool_manager.list_tools_async(actor=default_user, names=["alpha_tool"], upsert_base_tools=False)
assert len(tools) == 1
assert tools[0].name == "alpha_tool"
# test filtering by multiple names
tools = await server.tool_manager.list_tools_async(actor=default_user, names=["alpha_tool", "gamma_tool"], upsert_base_tools=False)
assert len(tools) == 2
assert set(t.name for t in tools) == {"alpha_tool", "gamma_tool"}
# test filtering by non-existent name
tools = await server.tool_manager.list_tools_async(actor=default_user, names=["non_existent_tool"], upsert_base_tools=False)
assert len(tools) == 0
@pytest.mark.asyncio
async def test_list_tools_with_tool_ids(server: SyncServer, default_user):
"""Test filtering tools by tool_ids parameter."""
# create multiple tools
def tool1() -> str:
"""Tool 1.
Returns:
String 1
"""
return "1"
def tool2() -> str:
"""Tool 2.
Returns:
String 2
"""
return "2"
def tool3() -> str:
"""Tool 3.
Returns:
String 3
"""
return "3"
t1 = PydanticTool(name="tool1", description="First", source_code=parse_source_code(tool1), source_type="python")
t1.json_schema = derive_openai_json_schema(source_code=t1.source_code, name=t1.name)
t1 = await server.tool_manager.create_or_update_tool_async(t1, actor=default_user)
t2 = PydanticTool(name="tool2", description="Second", source_code=parse_source_code(tool2), source_type="python")
t2.json_schema = derive_openai_json_schema(source_code=t2.source_code, name=t2.name)
t2 = await server.tool_manager.create_or_update_tool_async(t2, actor=default_user)
t3 = PydanticTool(name="tool3", description="Third", source_code=parse_source_code(tool3), source_type="python")
t3.json_schema = derive_openai_json_schema(source_code=t3.source_code, name=t3.name)
t3 = await server.tool_manager.create_or_update_tool_async(t3, actor=default_user)
# test filtering by single id
tools = await server.tool_manager.list_tools_async(actor=default_user, tool_ids=[t1.id], upsert_base_tools=False)
assert len(tools) == 1
assert tools[0].id == t1.id
# test filtering by multiple ids
tools = await server.tool_manager.list_tools_async(actor=default_user, tool_ids=[t1.id, t3.id], upsert_base_tools=False)
assert len(tools) == 2
assert set(t.id for t in tools) == {t1.id, t3.id}
# test filtering by non-existent id
tools = await server.tool_manager.list_tools_async(actor=default_user, tool_ids=["non-existent-id"], upsert_base_tools=False)
assert len(tools) == 0
@pytest.mark.asyncio
async def test_list_tools_with_search(server: SyncServer, default_user):
"""Test searching tools by partial name match."""
# create tools with searchable names
def calculator_add() -> str:
"""Calculator add.
Returns:
Add operation
"""
return "add"
def calculator_subtract() -> str:
"""Calculator subtract.
Returns:
Subtract operation
"""
return "subtract"
def weather_forecast() -> str:
"""Weather forecast.
Returns:
Forecast data
"""
return "forecast"
calc_add = PydanticTool(
name="calculator_add", description="Add numbers", source_code=parse_source_code(calculator_add), source_type="python"
)
calc_add.json_schema = derive_openai_json_schema(source_code=calc_add.source_code, name=calc_add.name)
calc_add = await server.tool_manager.create_or_update_tool_async(calc_add, actor=default_user)
calc_sub = PydanticTool(
name="calculator_subtract", description="Subtract numbers", source_code=parse_source_code(calculator_subtract), source_type="python"
)
calc_sub.json_schema = derive_openai_json_schema(source_code=calc_sub.source_code, name=calc_sub.name)
calc_sub = await server.tool_manager.create_or_update_tool_async(calc_sub, actor=default_user)
weather = PydanticTool(
name="weather_forecast", description="Weather", source_code=parse_source_code(weather_forecast), source_type="python"
)
weather.json_schema = derive_openai_json_schema(source_code=weather.source_code, name=weather.name)
weather = await server.tool_manager.create_or_update_tool_async(weather, actor=default_user)
# test searching for "calculator" (should find both calculator tools)
tools = await server.tool_manager.list_tools_async(actor=default_user, search="calculator", upsert_base_tools=False)
assert len(tools) == 2
assert all("calculator" in t.name for t in tools)
# test case-insensitive search
tools = await server.tool_manager.list_tools_async(actor=default_user, search="CALCULATOR", upsert_base_tools=False)
assert len(tools) == 2
# test partial match
tools = await server.tool_manager.list_tools_async(actor=default_user, search="calc", upsert_base_tools=False)
assert len(tools) == 2
# test search with no matches
tools = await server.tool_manager.list_tools_async(actor=default_user, search="nonexistent", upsert_base_tools=False)
assert len(tools) == 0
@pytest.mark.asyncio
async def test_list_tools_return_only_letta_tools(server: SyncServer, default_user):
"""Test filtering for only Letta tools."""
# first, upsert base tools to ensure we have Letta tools
await server.tool_manager.upsert_base_tools_async(actor=default_user)
# create a custom tool
def custom_tool() -> str:
"""Custom tool.
Returns:
Custom string
"""
return "custom"
custom = PydanticTool(
name="custom_tool",
description="Custom",
source_code=parse_source_code(custom_tool),
source_type="python",
tool_type=ToolType.CUSTOM,
)
custom.json_schema = derive_openai_json_schema(source_code=custom.source_code, name=custom.name)
custom = await server.tool_manager.create_or_update_tool_async(custom, actor=default_user)
# test without filter (should get custom tool + all letta tools)
tools = await server.tool_manager.list_tools_async(actor=default_user, return_only_letta_tools=False, upsert_base_tools=False)
# should have at least the custom tool and some letta tools
assert len(tools) > 1
assert any(t.name == "custom_tool" for t in tools)
# test with filter (should only get letta tools)
tools = await server.tool_manager.list_tools_async(actor=default_user, return_only_letta_tools=True, upsert_base_tools=False)
assert len(tools) > 0
# all tools should have tool_type starting with "letta_"
assert all(t.tool_type.value.startswith("letta_") for t in tools)
# custom tool should not be in the list
assert not any(t.name == "custom_tool" for t in tools)
@pytest.mark.asyncio
async def test_list_tools_combined_filters(server: SyncServer, default_user):
"""Test combining multiple filters."""
# create various tools
def calc_add() -> str:
"""Calculator add.
Returns:
Add result
"""
return "add"
def calc_multiply() -> str:
"""Calculator multiply.
Returns:
Multiply result
"""
return "multiply"
def weather_tool() -> str:
"""Weather tool.
Returns:
Weather data
"""
return "weather"
calc1 = PydanticTool(
name="calculator_add", description="Add", source_code=parse_source_code(calc_add), source_type="python", tool_type=ToolType.CUSTOM
)
calc1.json_schema = derive_openai_json_schema(source_code=calc1.source_code, name=calc1.name)
calc1 = await server.tool_manager.create_or_update_tool_async(calc1, actor=default_user)
calc2 = PydanticTool(
name="calculator_multiply",
description="Multiply",
source_code=parse_source_code(calc_multiply),
source_type="python",
tool_type=ToolType.CUSTOM,
)
calc2.json_schema = derive_openai_json_schema(source_code=calc2.source_code, name=calc2.name)
calc2 = await server.tool_manager.create_or_update_tool_async(calc2, actor=default_user)
weather = PydanticTool(
name="weather_current",
description="Weather",
source_code=parse_source_code(weather_tool),
source_type="python",
tool_type=ToolType.CUSTOM,
)
weather.json_schema = derive_openai_json_schema(source_code=weather.source_code, name=weather.name)
weather = await server.tool_manager.create_or_update_tool_async(weather, actor=default_user)
# combine search with tool_types
tools = await server.tool_manager.list_tools_async(
actor=default_user, search="calculator", tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False
)
assert len(tools) == 2
assert all("calculator" in t.name and t.tool_type == ToolType.CUSTOM for t in tools)
# combine names with tool_ids
tools = await server.tool_manager.list_tools_async(
actor=default_user, names=["calculator_add"], tool_ids=[calc1.id], upsert_base_tools=False
)
assert len(tools) == 1
assert tools[0].id == calc1.id
# combine search with exclude_tool_types
tools = await server.tool_manager.list_tools_async(
actor=default_user, search="calculator", exclude_tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False
)
assert len(tools) == 2
@pytest.mark.asyncio
async def test_count_tools_async(server: SyncServer, default_user):
"""Test counting tools with various filters."""
# create multiple tools
def tool_a() -> str:
"""Tool A.
Returns:
String a
"""
return "a"
def tool_b() -> str:
"""Tool B.
Returns:
String b
"""
return "b"
def search_tool() -> str:
"""Search tool.
Returns:
Search result
"""
return "search"
ta = PydanticTool(
name="tool_a", description="A", source_code=parse_source_code(tool_a), source_type="python", tool_type=ToolType.CUSTOM
)
ta.json_schema = derive_openai_json_schema(source_code=ta.source_code, name=ta.name)
ta = await server.tool_manager.create_or_update_tool_async(ta, actor=default_user)
tb = PydanticTool(
name="tool_b", description="B", source_code=parse_source_code(tool_b), source_type="python", tool_type=ToolType.CUSTOM
)
tb.json_schema = derive_openai_json_schema(source_code=tb.source_code, name=tb.name)
tb = await server.tool_manager.create_or_update_tool_async(tb, actor=default_user)
# upsert base tools to ensure we have Letta tools for counting
await server.tool_manager.upsert_base_tools_async(actor=default_user)
# count all tools (should have 2 custom tools + letta tools)
count = await server.tool_manager.count_tools_async(actor=default_user)
assert count > 2 # at least our 2 custom tools + letta tools
# count with tool_types filter
count = await server.tool_manager.count_tools_async(actor=default_user, tool_types=[ToolType.CUSTOM.value])
assert count == 2 # only our custom tools
# count with search filter
count = await server.tool_manager.count_tools_async(actor=default_user, search="tool")
# should at least find our 2 tools (tool_a, tool_b)
assert count >= 2
# count with names filter
count = await server.tool_manager.count_tools_async(actor=default_user, names=["tool_a", "tool_b"])
assert count == 2
# count with return_only_letta_tools
count = await server.tool_manager.count_tools_async(actor=default_user, return_only_letta_tools=True)
assert count > 0 # should have letta tools
# count with exclude_tool_types (exclude all letta tool types)
count = await server.tool_manager.count_tools_async(
actor=default_user,
exclude_tool_types=[
ToolType.LETTA_CORE.value,
ToolType.LETTA_MEMORY_CORE.value,
ToolType.LETTA_MULTI_AGENT_CORE.value,
ToolType.LETTA_SLEEPTIME_CORE.value,
ToolType.LETTA_VOICE_SLEEPTIME_CORE.value,
ToolType.LETTA_BUILTIN.value,
ToolType.LETTA_FILES_CORE.value,
],
)
assert count == 2 # only our custom tools
def test_update_tool_by_id(server: SyncServer, print_tool, default_user):
updated_description = "updated_description"
return_char_limit = 10000