diff --git a/letta/server/rest_api/routers/v1/tags.py b/letta/server/rest_api/routers/v1/tags.py index 4ffae32e..9b160bec 100644 --- a/letta/server/rest_api/routers/v1/tags.py +++ b/letta/server/rest_api/routers/v1/tags.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Literal, Optional from fastapi import APIRouter, Depends, Header, Query @@ -13,15 +13,26 @@ router = APIRouter(prefix="/tags", tags=["tag", "admin"]) @router.get("/", tags=["admin"], response_model=List[str], operation_id="list_tags") async def list_tags( - after: Optional[str] = Query(None), - limit: Optional[int] = Query(50), + before: Optional[str] = Query( + None, description="Tag cursor for pagination. Returns tags that come before this tag in the specified sort order" + ), + after: Optional[str] = Query( + None, description="Tag cursor for pagination. Returns tags that come after this tag in the specified sort order" + ), + limit: Optional[int] = Query(50, description="Maximum number of tags to return"), + order: Literal["asc", "desc"] = Query( + "asc", description="Sort order for tags. 'asc' for alphabetical order, 'desc' for reverse alphabetical order" + ), + order_by: Literal["name"] = Query("name", description="Field to sort by"), + query_text: Optional[str] = Query(None, description="Filter tags by text search"), server: "SyncServer" = Depends(get_letta_server), - query_text: Optional[str] = Query(None), actor_id: Optional[str] = Header(None, alias="user_id"), ): """ - Get a list of all tags in the database + Get a list of all agent tags in the database. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - tags = await server.agent_manager.list_tags_async(actor=actor, after=after, limit=limit, query_text=query_text) + tags = await server.agent_manager.list_tags_async( + actor=actor, before=before, after=after, limit=limit, query_text=query_text, ascending=(order == "asc") + ) return tags diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 417a010c..ec35270e 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -3542,19 +3542,27 @@ class AgentManager: @enforce_types @trace_method async def list_tags_async( - self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None + self, + actor: PydanticUser, + before: Optional[str] = None, + after: Optional[str] = None, + limit: Optional[int] = 50, + query_text: Optional[str] = None, + ascending: bool = True, ) -> List[str]: """ Get all tags a user has created, ordered alphabetically. Args: actor: User performing the action. - after: Cursor for forward pagination. - limit: Maximum number of tags to return. - query text to filter tags by. + before: Cursor for backward pagination (tags before this tag). + after: Cursor for forward pagination (tags after this tag). + limit: Maximum number of tags to return (default: 50). + query_text: Filter tags by text search. + ascending: Sort order - True for alphabetical, False for reverse (default: True). Returns: - List[str]: List of all tags. + List[str]: List of all tags matching the criteria. """ async with db_registry.async_session() as session: # Build the query using select() for async SQLAlchemy @@ -3573,10 +3581,26 @@ class AgentManager: # SQLite: Use LIKE with LOWER for case-insensitive search query = query.where(func.lower(AgentsTags.tag).like(func.lower(f"%{query_text}%"))) + # Handle pagination cursors if after: - query = query.where(AgentsTags.tag > after) + if ascending: + query = query.where(AgentsTags.tag > after) + else: + query = query.where(AgentsTags.tag < after) - query = query.order_by(AgentsTags.tag).limit(limit) + if before: + if ascending: + query = query.where(AgentsTags.tag < before) + else: + query = query.where(AgentsTags.tag > before) + + # Apply ordering based on ascending parameter + if ascending: + query = query.order_by(AgentsTags.tag.asc()) + else: + query = query.order_by(AgentsTags.tag.desc()) + + query = query.limit(limit) # Execute the query asynchronously result = await session.execute(query)