From 280f2c8d2ad94c70ae44dfdf06eb9c327165f1ee Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:26:13 -0700 Subject: [PATCH] feat: add query parameter to exclude builtin tools for tool count (#1898) fix: FastAPI router ordering for sources and tools --- letta/server/rest_api/routers/v1/sources.py | 22 ++++++------- letta/server/rest_api/routers/v1/tools.py | 35 +++++++++++---------- letta/services/tool_manager.py | 15 ++++----- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index ac91d69b..97a76eb3 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -20,6 +20,17 @@ from letta.utils import sanitize_filename router = APIRouter(prefix="/sources", tags=["sources"]) +@router.get("/count", response_model=int, operation_id="count_sources") +def count_sources( + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Count all data sources created by a user. + """ + return server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + + @router.get("/{source_id}", response_model=Source, operation_id="retrieve_source") def retrieve_source( source_id: str, @@ -67,17 +78,6 @@ def list_sources( return server.list_all_sources(actor=actor) -@router.get("/count", response_model=int, operation_id="count_sources") -def count_sources( - server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present -): - """ - Count all data sources created by a user. - """ - return server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) - - @router.post("/", response_model=Source, operation_id="create_source") def create_source( source_create: SourceCreate, diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 06482175..a1c7591b 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -9,7 +9,7 @@ from composio.exceptions import ( EnumMetadataNotFound, EnumStringNotFound, ) -from fastapi import APIRouter, Body, Depends, Header, HTTPException +from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query from letta.errors import LettaToolCreateError from letta.functions.mcp_client.exceptions import MCPTimeoutError @@ -40,6 +40,24 @@ def delete_tool( server.tool_manager.delete_tool_by_id(tool_id=tool_id, actor=actor) +@router.get("/count", response_model=int, operation_id="count_tools") +def count_tools( + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), + include_base_tools: Optional[bool] = Query(False, description="Include built-in Letta tools in the count"), +): + """ + Get a count of all tools available to agents belonging to the org of the user. + """ + try: + return server.tool_manager.size( + actor=server.user_manager.get_user_or_default(user_id=actor_id), include_base_tools=include_base_tools + ) + except Exception as e: + print(f"Error occurred: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/{tool_id}", response_model=Tool, operation_id="retrieve_tool") def retrieve_tool( tool_id: str, @@ -80,21 +98,6 @@ def list_tools( raise HTTPException(status_code=500, detail=str(e)) -@router.get("/count", response_model=int, operation_id="count_tools") -def count_tools( - server: SyncServer = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), -): - """ - Get a count of all tools available to agents belonging to the org of the user - """ - try: - return server.tool_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) - except Exception as e: - print(f"Error occurred: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - @router.post("/", response_model=Tool, operation_id="create_tool") def create_tool( request: ToolCreate = Body(...), diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 571e67d4..d03a36f8 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -29,14 +29,6 @@ logger = get_logger(__name__) class ToolManager: """Manager class to handle business logic related to Tools.""" - BASE_TOOL_NAMES = [ - "send_message", - "conversation_search", - "archival_memory_insert", - "archival_memory_search", - ] - BASE_MEMORY_TOOL_NAMES = ["core_memory_append", "core_memory_replace"] - def __init__(self): # Fetching the db_context similarly as in OrganizationManager from letta.server.db import db_context @@ -149,12 +141,17 @@ class ToolManager: def size( self, actor: PydanticUser, + include_base_tools: bool, ) -> int: """ Get the total count of tools for the given user. + + If include_builtin is True, it will also count the built-in tools. """ with self.session_maker() as session: - return ToolModel.size(db_session=session, actor=actor) + if include_base_tools: + return ToolModel.size(db_session=session, actor=actor) + return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET) @enforce_types def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool: