From 6716f5b5e8bd65c81261aa98eb0fcfdca3afc4f9 Mon Sep 17 00:00:00 2001 From: Shubham Naik Date: Tue, 2 Sep 2025 15:53:36 -0700 Subject: [PATCH] feat: allow list tools by tool type [PRO-870] (#4036) * feat: allow list tools by tool type * chore: update list * chore: respond to comments * chore: refactor tools hella * Add tests to managers * chore: branch --------- Co-authored-by: Shubham Naik Co-authored-by: Matt Zhou --- letta/server/rest_api/routers/v1/tools.py | 169 +++++++- letta/services/tool_manager.py | 161 +++++++- tests/test_managers.py | 481 ++++++++++++++++++++++ 3 files changed, 791 insertions(+), 20 deletions(-) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 5b9cc0ea..efd03b0b 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -27,7 +27,7 @@ from letta.log import get_logger from letta.orm.errors import UniqueConstraintViolationError from letta.orm.mcp_oauth import OAuthSessionStatus from letta.prompts.gpt_system import get_system_text -from letta.schemas.enums import MessageRole +from letta.schemas.enums import MessageRole, ToolType from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer @@ -62,16 +62,94 @@ async def delete_tool( @router.get("/count", response_model=int, operation_id="count_tools") async def count_tools( + name: Optional[str] = None, + names: Optional[List[str]] = Query(None, description="Filter by specific tool names"), + tool_ids: Optional[List[str]] = Query( + None, description="Filter by specific tool IDs - accepts repeated params or comma-separated values" + ), + search: Optional[str] = Query(None, description="Search tool names (case-insensitive partial match)"), + tool_types: Optional[List[str]] = Query(None, description="Filter by tool type(s) - accepts repeated params or comma-separated values"), + exclude_tool_types: Optional[List[str]] = Query( + None, description="Tool type(s) to exclude - accepts repeated params or comma-separated values" + ), + return_only_letta_tools: Optional[bool] = Query(False, description="Count only tools with tool_type starting with 'letta_'"), + exclude_letta_tools: Optional[bool] = Query(False, description="Exclude built-in Letta tools from the count"), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), - include_base_tools: Optional[bool] = Query(False, description="Include built-in Letta tools in the count"), ): """ Get a count of all tools available to agents belonging to the org of the user. """ try: + # Helper function to parse tool types - supports both repeated params and comma-separated values + def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]: + if tool_types_input is None: + return None + + # Flatten any comma-separated values and validate against ToolType enum + flattened_types = [] + for item in tool_types_input: + # Split by comma in case user provided comma-separated values + types_in_item = [t.strip() for t in item.split(",") if t.strip()] + flattened_types.extend(types_in_item) + + # Validate each type against the ToolType enum + valid_types = [] + valid_values = [tt.value for tt in ToolType] + + for tool_type in flattened_types: + if tool_type not in valid_values: + raise HTTPException( + status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}" + ) + valid_types.append(tool_type) + + return valid_types if valid_types else None + + # Parse and validate tool types (same logic as list_tools) + tool_types_str = parse_tool_types(tool_types) + exclude_tool_types_str = parse_tool_types(exclude_tool_types) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return await server.tool_manager.size_async(actor=actor, include_base_tools=include_base_tools) + + # Combine single name with names list for unified processing (same logic as list_tools) + combined_names = [] + if name is not None: + combined_names.append(name) + if names is not None: + combined_names.extend(names) + + # Use None if no names specified, otherwise use the combined list + final_names = combined_names if combined_names else None + + # Helper function to parse tool IDs - supports both repeated params and comma-separated values + def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]: + if tool_ids_input is None: + return None + + # Flatten any comma-separated values + flattened_ids = [] + for item in tool_ids_input: + # Split by comma in case user provided comma-separated values + ids_in_item = [id.strip() for id in item.split(",") if id.strip()] + flattened_ids.extend(ids_in_item) + + return flattened_ids if flattened_ids else None + + # Parse tool IDs (same logic as list_tools) + final_tool_ids = parse_tool_ids(tool_ids) + + # Get the count of tools using unified query + return await server.tool_manager.count_tools_async( + actor=actor, + tool_types=tool_types_str, + exclude_tool_types=exclude_tool_types_str, + names=final_names, + tool_ids=final_tool_ids, + search=search, + return_only_letta_tools=return_only_letta_tools, + exclude_letta_tools=exclude_letta_tools, + ) except Exception as e: print(f"Error occurred: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -99,6 +177,16 @@ async def list_tools( after: Optional[str] = None, limit: Optional[int] = 50, name: Optional[str] = None, + names: Optional[List[str]] = Query(None, description="Filter by specific tool names"), + tool_ids: Optional[List[str]] = Query( + None, description="Filter by specific tool IDs - accepts repeated params or comma-separated values" + ), + search: Optional[str] = Query(None, description="Search tool names (case-insensitive partial match)"), + tool_types: Optional[List[str]] = Query(None, description="Filter by tool type(s) - accepts repeated params or comma-separated values"), + exclude_tool_types: Optional[List[str]] = Query( + None, description="Tool type(s) to exclude - accepts repeated params or comma-separated values" + ), + return_only_letta_tools: Optional[bool] = Query(False, description="Return only tools with tool_type starting with 'letta_'"), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): @@ -106,13 +194,76 @@ async def list_tools( Get a list of all tools available to agents belonging to the org of the user """ try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - if name is not None: - tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor) - return [tool] if tool else [] + # Helper function to parse tool types - supports both repeated params and comma-separated values + def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]: + if tool_types_input is None: + return None - # Get the list of tools - return await server.tool_manager.list_tools_async(actor=actor, after=after, limit=limit) + # Flatten any comma-separated values and validate against ToolType enum + flattened_types = [] + for item in tool_types_input: + # Split by comma in case user provided comma-separated values + types_in_item = [t.strip() for t in item.split(",") if t.strip()] + flattened_types.extend(types_in_item) + + # Validate each type against the ToolType enum + valid_types = [] + valid_values = [tt.value for tt in ToolType] + + for tool_type in flattened_types: + if tool_type not in valid_values: + raise HTTPException( + status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}" + ) + valid_types.append(tool_type) + + return valid_types if valid_types else None + + # Parse and validate tool types + tool_types_str = parse_tool_types(tool_types) + exclude_tool_types_str = parse_tool_types(exclude_tool_types) + + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + + # Combine single name with names list for unified processing + combined_names = [] + if name is not None: + combined_names.append(name) + if names is not None: + combined_names.extend(names) + + # Use None if no names specified, otherwise use the combined list + final_names = combined_names if combined_names else None + + # Helper function to parse tool IDs - supports both repeated params and comma-separated values + def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]: + if tool_ids_input is None: + return None + + # Flatten any comma-separated values + flattened_ids = [] + for item in tool_ids_input: + # Split by comma in case user provided comma-separated values + ids_in_item = [id.strip() for id in item.split(",") if id.strip()] + flattened_ids.extend(ids_in_item) + + return flattened_ids if flattened_ids else None + + # Parse tool IDs + final_tool_ids = parse_tool_ids(tool_ids) + + # Get the list of tools using unified query + return await server.tool_manager.list_tools_async( + actor=actor, + after=after, + limit=limit, + tool_types=tool_types_str, + exclude_tool_types=exclude_tool_types_str, + names=final_names, + tool_ids=final_tool_ids, + search=search, + return_only_letta_tools=return_only_letta_tools, + ) except Exception as e: # Log or print the full exception here for debugging print(f"Error occurred: {e}") diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 97aa6db9..011dd4a8 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -2,7 +2,7 @@ import importlib import warnings from typing import List, Optional, Set, Union -from sqlalchemy import func, select +from sqlalchemy import and_, func, or_, select from letta.constants import ( BASE_FUNCTION_RETURN_CHAR_LIMIT, @@ -319,10 +319,30 @@ class ToolManager: @enforce_types @trace_method async def list_tools_async( - self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, upsert_base_tools: bool = True + self, + actor: PydanticUser, + after: Optional[str] = None, + limit: Optional[int] = 50, + upsert_base_tools: bool = True, + tool_types: Optional[List[str]] = None, + exclude_tool_types: Optional[List[str]] = None, + names: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, + search: Optional[str] = None, + return_only_letta_tools: bool = False, ) -> List[PydanticTool]: """List all tools with optional pagination.""" - tools = await self._list_tools_async(actor=actor, after=after, limit=limit) + tools = await self._list_tools_async( + actor=actor, + after=after, + limit=limit, + tool_types=tool_types, + exclude_tool_types=exclude_tool_types, + names=names, + tool_ids=tool_ids, + search=search, + return_only_letta_tools=return_only_letta_tools, + ) # Check if all base tools are present if we requested all the tools w/o cursor # TODO: This is a temporary hack to resolve this issue @@ -337,22 +357,86 @@ class ToolManager: logger.info(f"Missing base tools detected: {missing_base_tools}. Upserting all base tools.") await self.upsert_base_tools_async(actor=actor) # Re-fetch the tools list after upserting base tools - tools = await self._list_tools_async(actor=actor, after=after, limit=limit) + tools = await self._list_tools_async( + actor=actor, + after=after, + limit=limit, + tool_types=tool_types, + exclude_tool_types=exclude_tool_types, + names=names, + tool_ids=tool_ids, + search=search, + return_only_letta_tools=return_only_letta_tools, + ) return tools @enforce_types @trace_method - async def _list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: + async def _list_tools_async( + self, + actor: PydanticUser, + after: Optional[str] = None, + limit: Optional[int] = 50, + tool_types: Optional[List[str]] = None, + exclude_tool_types: Optional[List[str]] = None, + names: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, + search: Optional[str] = None, + return_only_letta_tools: bool = False, + ) -> List[PydanticTool]: """List all tools with optional pagination.""" tools_to_delete = [] async with db_registry.async_session() as session: - tools = await ToolModel.list_async( - db_session=session, - after=after, - limit=limit, - organization_id=actor.organization_id, - ) + # Use SQLAlchemy directly for all cases - more control and consistency + # Start with base query + query = select(ToolModel).where(ToolModel.organization_id == actor.organization_id) + + # Apply tool_types filter + if tool_types is not None: + query = query.where(ToolModel.tool_type.in_(tool_types)) + + # Apply names filter + if names is not None: + query = query.where(ToolModel.name.in_(names)) + + # Apply tool_ids filter + if tool_ids is not None: + query = query.where(ToolModel.id.in_(tool_ids)) + + # Apply search filter (ILIKE for case-insensitive partial match) + if search is not None: + query = query.where(ToolModel.name.ilike(f"%{search}%")) + + # Apply exclude_tool_types filter at database level + if exclude_tool_types is not None: + query = query.where(~ToolModel.tool_type.in_(exclude_tool_types)) + + # Apply return_only_letta_tools filter at database level + if return_only_letta_tools: + query = query.where(ToolModel.tool_type.like("letta_%")) + + # Apply pagination if specified + if after is not None: + after_tool = await session.get(ToolModel, after) + if after_tool: + query = query.where( + or_( + ToolModel.created_at < after_tool.created_at, + and_(ToolModel.created_at == after_tool.created_at, ToolModel.id < after_tool.id), + ) + ) + + # Apply limit + if limit is not None: + query = query.limit(limit) + + # Order by created_at and id for consistent pagination + query = query.order_by(ToolModel.created_at.desc(), ToolModel.id.desc()) + + # Execute query + result = await session.execute(query) + tools = list(result.scalars()) # Remove any malformed tools results = [] @@ -375,6 +459,61 @@ class ToolManager: return results + @enforce_types + @trace_method + async def count_tools_async( + self, + actor: PydanticUser, + tool_types: Optional[List[str]] = None, + exclude_tool_types: Optional[List[str]] = None, + names: Optional[List[str]] = None, + tool_ids: Optional[List[str]] = None, + search: Optional[str] = None, + return_only_letta_tools: bool = False, + exclude_letta_tools: bool = False, + ) -> int: + """Count tools with the same filtering logic as list_tools_async.""" + async with db_registry.async_session() as session: + # Use SQLAlchemy directly with COUNT query - same filtering logic as list_tools_async + # Start with base query + query = select(func.count(ToolModel.id)).where(ToolModel.organization_id == actor.organization_id) + + # Apply tool_types filter + if tool_types is not None: + query = query.where(ToolModel.tool_type.in_(tool_types)) + + # Apply names filter + if names is not None: + query = query.where(ToolModel.name.in_(names)) + + # Apply tool_ids filter + if tool_ids is not None: + query = query.where(ToolModel.id.in_(tool_ids)) + + # Apply search filter (ILIKE for case-insensitive partial match) + if search is not None: + query = query.where(ToolModel.name.ilike(f"%{search}%")) + + # Apply exclude_tool_types filter at database level + if exclude_tool_types is not None: + query = query.where(~ToolModel.tool_type.in_(exclude_tool_types)) + + # Apply return_only_letta_tools filter at database level + if return_only_letta_tools: + query = query.where(ToolModel.tool_type.like("letta_%")) + + # Handle exclude_letta_tools logic (if True, exclude Letta tools) + if exclude_letta_tools: + # Exclude tools that are in the LETTA_TOOL_SET + letta_tool_names = list(LETTA_TOOL_SET) + query = query.where(~ToolModel.name.in_(letta_tool_names)) + + # Execute count query + result = await session.execute(query) + count = result.scalar() + + return count or 0 + @enforce_types @trace_method async def size_async( diff --git a/tests/test_managers.py b/tests/test_managers.py index 6700e370..b1ebdfdf 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4143,6 +4143,487 @@ async def test_list_tools(server: SyncServer, print_tool, default_user): assert any(t.id == print_tool.id for t in tools) +@pytest.mark.asyncio +async def test_list_tools_with_tool_types(server: SyncServer, default_user): + """Test filtering tools by tool_types parameter.""" + + # create tools with different types + def calculator_tool(a: int, b: int) -> int: + """Add two numbers. + + Args: + a: First number + b: Second number + + Returns: + Sum of a and b + """ + return a + b + + def weather_tool(city: str) -> str: + """Get weather for a city. + + Args: + city: Name of the city + + Returns: + Weather information + """ + return f"Weather in {city}" + + # create custom tools + custom_tool1 = PydanticTool( + name="calculator", + description="Math tool", + source_code=parse_source_code(calculator_tool), + source_type="python", + tool_type=ToolType.CUSTOM, + ) + custom_tool1.json_schema = derive_openai_json_schema(source_code=custom_tool1.source_code, name=custom_tool1.name) + custom_tool1 = await server.tool_manager.create_or_update_tool_async(custom_tool1, actor=default_user) + + custom_tool2 = PydanticTool( + name="weather", + description="Weather tool", + source_code=parse_source_code(weather_tool), + source_type="python", + tool_type=ToolType.CUSTOM, + ) + custom_tool2.json_schema = derive_openai_json_schema(source_code=custom_tool2.source_code, name=custom_tool2.name) + custom_tool2 = await server.tool_manager.create_or_update_tool_async(custom_tool2, actor=default_user) + + # test filtering by single tool type + tools = await server.tool_manager.list_tools_async(actor=default_user, tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False) + assert len(tools) == 2 + assert all(t.tool_type == ToolType.CUSTOM for t in tools) + + # test filtering by multiple tool types (should get same result since we only have CUSTOM) + tools = await server.tool_manager.list_tools_async( + actor=default_user, tool_types=[ToolType.CUSTOM.value, ToolType.LETTA_CORE.value], upsert_base_tools=False + ) + assert len(tools) == 2 + + # test filtering by non-existent tool type + tools = await server.tool_manager.list_tools_async( + actor=default_user, tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False + ) + assert len(tools) == 0 + + +@pytest.mark.asyncio +async def test_list_tools_with_exclude_tool_types(server: SyncServer, default_user, print_tool): + """Test excluding tools by exclude_tool_types parameter.""" + # we already have print_tool which is CUSTOM type + + # create a tool with a different type (simulate by updating tool type directly) + def special_tool(msg: str) -> str: + """Special tool. + + Args: + msg: Message to return + + Returns: + The message + """ + return msg + + special = PydanticTool( + name="special", + description="Special tool", + source_code=parse_source_code(special_tool), + source_type="python", + tool_type=ToolType.CUSTOM, + ) + special.json_schema = derive_openai_json_schema(source_code=special.source_code, name=special.name) + special = await server.tool_manager.create_or_update_tool_async(special, actor=default_user) + + # test excluding EXTERNAL_MCP (should get all tools since none are MCP) + tools = await server.tool_manager.list_tools_async( + actor=default_user, exclude_tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False + ) + assert len(tools) == 2 # print_tool and special + + # test excluding CUSTOM (should get no tools) + tools = await server.tool_manager.list_tools_async( + actor=default_user, exclude_tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False + ) + assert len(tools) == 0 + + +@pytest.mark.asyncio +async def test_list_tools_with_names(server: SyncServer, default_user): + """Test filtering tools by names parameter.""" + + # create tools with specific names + def alpha_tool() -> str: + """Alpha tool. + + Returns: + Alpha string + """ + return "alpha" + + def beta_tool() -> str: + """Beta tool. + + Returns: + Beta string + """ + return "beta" + + def gamma_tool() -> str: + """Gamma tool. + + Returns: + Gamma string + """ + return "gamma" + + alpha = PydanticTool(name="alpha_tool", description="Alpha", source_code=parse_source_code(alpha_tool), source_type="python") + alpha.json_schema = derive_openai_json_schema(source_code=alpha.source_code, name=alpha.name) + alpha = await server.tool_manager.create_or_update_tool_async(alpha, actor=default_user) + + beta = PydanticTool(name="beta_tool", description="Beta", source_code=parse_source_code(beta_tool), source_type="python") + beta.json_schema = derive_openai_json_schema(source_code=beta.source_code, name=beta.name) + beta = await server.tool_manager.create_or_update_tool_async(beta, actor=default_user) + + gamma = PydanticTool(name="gamma_tool", description="Gamma", source_code=parse_source_code(gamma_tool), source_type="python") + gamma.json_schema = derive_openai_json_schema(source_code=gamma.source_code, name=gamma.name) + gamma = await server.tool_manager.create_or_update_tool_async(gamma, actor=default_user) + + # test filtering by single name + tools = await server.tool_manager.list_tools_async(actor=default_user, names=["alpha_tool"], upsert_base_tools=False) + assert len(tools) == 1 + assert tools[0].name == "alpha_tool" + + # test filtering by multiple names + tools = await server.tool_manager.list_tools_async(actor=default_user, names=["alpha_tool", "gamma_tool"], upsert_base_tools=False) + assert len(tools) == 2 + assert set(t.name for t in tools) == {"alpha_tool", "gamma_tool"} + + # test filtering by non-existent name + tools = await server.tool_manager.list_tools_async(actor=default_user, names=["non_existent_tool"], upsert_base_tools=False) + assert len(tools) == 0 + + +@pytest.mark.asyncio +async def test_list_tools_with_tool_ids(server: SyncServer, default_user): + """Test filtering tools by tool_ids parameter.""" + + # create multiple tools + def tool1() -> str: + """Tool 1. + + Returns: + String 1 + """ + return "1" + + def tool2() -> str: + """Tool 2. + + Returns: + String 2 + """ + return "2" + + def tool3() -> str: + """Tool 3. + + Returns: + String 3 + """ + return "3" + + t1 = PydanticTool(name="tool1", description="First", source_code=parse_source_code(tool1), source_type="python") + t1.json_schema = derive_openai_json_schema(source_code=t1.source_code, name=t1.name) + t1 = await server.tool_manager.create_or_update_tool_async(t1, actor=default_user) + + t2 = PydanticTool(name="tool2", description="Second", source_code=parse_source_code(tool2), source_type="python") + t2.json_schema = derive_openai_json_schema(source_code=t2.source_code, name=t2.name) + t2 = await server.tool_manager.create_or_update_tool_async(t2, actor=default_user) + + t3 = PydanticTool(name="tool3", description="Third", source_code=parse_source_code(tool3), source_type="python") + t3.json_schema = derive_openai_json_schema(source_code=t3.source_code, name=t3.name) + t3 = await server.tool_manager.create_or_update_tool_async(t3, actor=default_user) + + # test filtering by single id + tools = await server.tool_manager.list_tools_async(actor=default_user, tool_ids=[t1.id], upsert_base_tools=False) + assert len(tools) == 1 + assert tools[0].id == t1.id + + # test filtering by multiple ids + tools = await server.tool_manager.list_tools_async(actor=default_user, tool_ids=[t1.id, t3.id], upsert_base_tools=False) + assert len(tools) == 2 + assert set(t.id for t in tools) == {t1.id, t3.id} + + # test filtering by non-existent id + tools = await server.tool_manager.list_tools_async(actor=default_user, tool_ids=["non-existent-id"], upsert_base_tools=False) + assert len(tools) == 0 + + +@pytest.mark.asyncio +async def test_list_tools_with_search(server: SyncServer, default_user): + """Test searching tools by partial name match.""" + + # create tools with searchable names + def calculator_add() -> str: + """Calculator add. + + Returns: + Add operation + """ + return "add" + + def calculator_subtract() -> str: + """Calculator subtract. + + Returns: + Subtract operation + """ + return "subtract" + + def weather_forecast() -> str: + """Weather forecast. + + Returns: + Forecast data + """ + return "forecast" + + calc_add = PydanticTool( + name="calculator_add", description="Add numbers", source_code=parse_source_code(calculator_add), source_type="python" + ) + calc_add.json_schema = derive_openai_json_schema(source_code=calc_add.source_code, name=calc_add.name) + calc_add = await server.tool_manager.create_or_update_tool_async(calc_add, actor=default_user) + + calc_sub = PydanticTool( + name="calculator_subtract", description="Subtract numbers", source_code=parse_source_code(calculator_subtract), source_type="python" + ) + calc_sub.json_schema = derive_openai_json_schema(source_code=calc_sub.source_code, name=calc_sub.name) + calc_sub = await server.tool_manager.create_or_update_tool_async(calc_sub, actor=default_user) + + weather = PydanticTool( + name="weather_forecast", description="Weather", source_code=parse_source_code(weather_forecast), source_type="python" + ) + weather.json_schema = derive_openai_json_schema(source_code=weather.source_code, name=weather.name) + weather = await server.tool_manager.create_or_update_tool_async(weather, actor=default_user) + + # test searching for "calculator" (should find both calculator tools) + tools = await server.tool_manager.list_tools_async(actor=default_user, search="calculator", upsert_base_tools=False) + assert len(tools) == 2 + assert all("calculator" in t.name for t in tools) + + # test case-insensitive search + tools = await server.tool_manager.list_tools_async(actor=default_user, search="CALCULATOR", upsert_base_tools=False) + assert len(tools) == 2 + + # test partial match + tools = await server.tool_manager.list_tools_async(actor=default_user, search="calc", upsert_base_tools=False) + assert len(tools) == 2 + + # test search with no matches + tools = await server.tool_manager.list_tools_async(actor=default_user, search="nonexistent", upsert_base_tools=False) + assert len(tools) == 0 + + +@pytest.mark.asyncio +async def test_list_tools_return_only_letta_tools(server: SyncServer, default_user): + """Test filtering for only Letta tools.""" + # first, upsert base tools to ensure we have Letta tools + await server.tool_manager.upsert_base_tools_async(actor=default_user) + + # create a custom tool + def custom_tool() -> str: + """Custom tool. + + Returns: + Custom string + """ + return "custom" + + custom = PydanticTool( + name="custom_tool", + description="Custom", + source_code=parse_source_code(custom_tool), + source_type="python", + tool_type=ToolType.CUSTOM, + ) + custom.json_schema = derive_openai_json_schema(source_code=custom.source_code, name=custom.name) + custom = await server.tool_manager.create_or_update_tool_async(custom, actor=default_user) + + # test without filter (should get custom tool + all letta tools) + tools = await server.tool_manager.list_tools_async(actor=default_user, return_only_letta_tools=False, upsert_base_tools=False) + # should have at least the custom tool and some letta tools + assert len(tools) > 1 + assert any(t.name == "custom_tool" for t in tools) + + # test with filter (should only get letta tools) + tools = await server.tool_manager.list_tools_async(actor=default_user, return_only_letta_tools=True, upsert_base_tools=False) + assert len(tools) > 0 + # all tools should have tool_type starting with "letta_" + assert all(t.tool_type.value.startswith("letta_") for t in tools) + # custom tool should not be in the list + assert not any(t.name == "custom_tool" for t in tools) + + +@pytest.mark.asyncio +async def test_list_tools_combined_filters(server: SyncServer, default_user): + """Test combining multiple filters.""" + + # create various tools + def calc_add() -> str: + """Calculator add. + + Returns: + Add result + """ + return "add" + + def calc_multiply() -> str: + """Calculator multiply. + + Returns: + Multiply result + """ + return "multiply" + + def weather_tool() -> str: + """Weather tool. + + Returns: + Weather data + """ + return "weather" + + calc1 = PydanticTool( + name="calculator_add", description="Add", source_code=parse_source_code(calc_add), source_type="python", tool_type=ToolType.CUSTOM + ) + calc1.json_schema = derive_openai_json_schema(source_code=calc1.source_code, name=calc1.name) + calc1 = await server.tool_manager.create_or_update_tool_async(calc1, actor=default_user) + + calc2 = PydanticTool( + name="calculator_multiply", + description="Multiply", + source_code=parse_source_code(calc_multiply), + source_type="python", + tool_type=ToolType.CUSTOM, + ) + calc2.json_schema = derive_openai_json_schema(source_code=calc2.source_code, name=calc2.name) + calc2 = await server.tool_manager.create_or_update_tool_async(calc2, actor=default_user) + + weather = PydanticTool( + name="weather_current", + description="Weather", + source_code=parse_source_code(weather_tool), + source_type="python", + tool_type=ToolType.CUSTOM, + ) + weather.json_schema = derive_openai_json_schema(source_code=weather.source_code, name=weather.name) + weather = await server.tool_manager.create_or_update_tool_async(weather, actor=default_user) + + # combine search with tool_types + tools = await server.tool_manager.list_tools_async( + actor=default_user, search="calculator", tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False + ) + assert len(tools) == 2 + assert all("calculator" in t.name and t.tool_type == ToolType.CUSTOM for t in tools) + + # combine names with tool_ids + tools = await server.tool_manager.list_tools_async( + actor=default_user, names=["calculator_add"], tool_ids=[calc1.id], upsert_base_tools=False + ) + assert len(tools) == 1 + assert tools[0].id == calc1.id + + # combine search with exclude_tool_types + tools = await server.tool_manager.list_tools_async( + actor=default_user, search="calculator", exclude_tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False + ) + assert len(tools) == 2 + + +@pytest.mark.asyncio +async def test_count_tools_async(server: SyncServer, default_user): + """Test counting tools with various filters.""" + + # create multiple tools + def tool_a() -> str: + """Tool A. + + Returns: + String a + """ + return "a" + + def tool_b() -> str: + """Tool B. + + Returns: + String b + """ + return "b" + + def search_tool() -> str: + """Search tool. + + Returns: + Search result + """ + return "search" + + ta = PydanticTool( + name="tool_a", description="A", source_code=parse_source_code(tool_a), source_type="python", tool_type=ToolType.CUSTOM + ) + ta.json_schema = derive_openai_json_schema(source_code=ta.source_code, name=ta.name) + ta = await server.tool_manager.create_or_update_tool_async(ta, actor=default_user) + + tb = PydanticTool( + name="tool_b", description="B", source_code=parse_source_code(tool_b), source_type="python", tool_type=ToolType.CUSTOM + ) + tb.json_schema = derive_openai_json_schema(source_code=tb.source_code, name=tb.name) + tb = await server.tool_manager.create_or_update_tool_async(tb, actor=default_user) + + # upsert base tools to ensure we have Letta tools for counting + await server.tool_manager.upsert_base_tools_async(actor=default_user) + + # count all tools (should have 2 custom tools + letta tools) + count = await server.tool_manager.count_tools_async(actor=default_user) + assert count > 2 # at least our 2 custom tools + letta tools + + # count with tool_types filter + count = await server.tool_manager.count_tools_async(actor=default_user, tool_types=[ToolType.CUSTOM.value]) + assert count == 2 # only our custom tools + + # count with search filter + count = await server.tool_manager.count_tools_async(actor=default_user, search="tool") + # should at least find our 2 tools (tool_a, tool_b) + assert count >= 2 + + # count with names filter + count = await server.tool_manager.count_tools_async(actor=default_user, names=["tool_a", "tool_b"]) + assert count == 2 + + # count with return_only_letta_tools + count = await server.tool_manager.count_tools_async(actor=default_user, return_only_letta_tools=True) + assert count > 0 # should have letta tools + + # count with exclude_tool_types (exclude all letta tool types) + count = await server.tool_manager.count_tools_async( + actor=default_user, + exclude_tool_types=[ + ToolType.LETTA_CORE.value, + ToolType.LETTA_MEMORY_CORE.value, + ToolType.LETTA_MULTI_AGENT_CORE.value, + ToolType.LETTA_SLEEPTIME_CORE.value, + ToolType.LETTA_VOICE_SLEEPTIME_CORE.value, + ToolType.LETTA_BUILTIN.value, + ToolType.LETTA_FILES_CORE.value, + ], + ) + assert count == 2 # only our custom tools + + def test_update_tool_by_id(server: SyncServer, print_tool, default_user): updated_description = "updated_description" return_char_limit = 10000