From aa55a3d10ec4802cd18815ed88d1a299776aa229 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 18 Oct 2024 15:19:45 -0700 Subject: [PATCH] feat: Add pagination for list tools (#1907) Co-authored-by: Matt Zhou --- letta/client/client.py | 16 ++++++++++------ letta/metadata.py | 19 ++++++++++++++----- letta/server/rest_api/app.py | 1 + letta/server/rest_api/routers/v1/tools.py | 15 +++++++++------ letta/server/server.py | 4 ++-- tests/test_client.py | 22 ++++++++++++++++++++++ tests/test_new_client.py | 3 +-- 7 files changed, 59 insertions(+), 21 deletions(-) diff --git a/letta/client/client.py b/letta/client/client.py index bea579a3..79e6da4c 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -200,7 +200,7 @@ class AbstractClient(object): ) -> Tool: raise NotImplementedError - def list_tools(self) -> List[Tool]: + def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]: raise NotImplementedError def get_tool(self, id: str) -> Tool: @@ -1382,14 +1382,19 @@ class RESTClient(AbstractClient): # raise ValueError(f"Failed to create tool: {response.text}") # return ToolModel(**response.json()) - def list_tools(self) -> List[Tool]: + def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]: """ List available tools for the user. Returns: tools (List[Tool]): List of tools """ - response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", headers=self.headers) + params = {} + if cursor: + params["cursor"] = str(cursor) + if limit: + params["limit"] = limit + response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", params=params, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to list tools: {response.text}") return [Tool(**tool) for tool in response.json()] @@ -2281,15 +2286,14 @@ class LocalClient(AbstractClient): ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id ) - def list_tools(self): + def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]: """ List available tools for the user. Returns: tools (List[Tool]): List of tools """ - tools = self.server.list_tools(user_id=self.user_id) - return tools + return self.server.list_tools(cursor=cursor, limit=limit, user_id=self.user_id) def get_tool(self, id: str) -> Optional[Tool]: """ diff --git a/letta/metadata.py b/letta/metadata.py index ed2b4202..1d36d216 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -14,7 +14,9 @@ from sqlalchemy import ( Integer, String, TypeDecorator, + asc, desc, + or_, ) from sqlalchemy.sql import func @@ -707,12 +709,19 @@ class MetadataStore: session.commit() @enforce_types - # def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools - def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]: + def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]: with self.session_maker() as session: - results = session.query(ToolModel).filter(ToolModel.user_id == None).all() - if user_id: - results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all() + # Query for public tools or user-specific tools + query = session.query(ToolModel).filter(or_(ToolModel.user_id == None, ToolModel.user_id == user_id)) + + # Apply cursor if provided (assuming cursor is an ID) + if cursor: + query = query.filter(ToolModel.id > cursor) + + # Order by ID and apply limit + results = query.order_by(asc(ToolModel.id)).limit(limit).all() + + # Convert to records res = [r.to_record() for r in results] return res diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 7b73674e..303ed0ad 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -70,6 +70,7 @@ def create_application() -> "FastAPI": title="Letta", summary="Create LLM agents with long-term memory and custom tools 📚🦙", version="1.0.0", # TODO wire this up to the version in the package + debug=True, ) if "--ade" in sys.argv: diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 0defac11..e8782b89 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -59,18 +59,21 @@ def get_tool_id( @router.get("/", response_model=List[Tool], operation_id="list_tools") def list_all_tools( + cursor: Optional[str] = None, + limit: Optional[int] = 50, server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get a list of all tools available to agents created by a user """ - actor = server.get_user_or_default(user_id=user_id) - actor.id - - # TODO: add back when user-specific - return server.list_tools(user_id=actor.id) - # return server.ms.list_tools(user_id=None) + try: + actor = server.get_user_or_default(user_id=user_id) + return server.list_tools(cursor=cursor, limit=limit, user_id=actor.id) + except Exception as e: + # Log or print the full exception here for debugging + print(f"Error occurred: {e}") + raise HTTPException(status_code=500, detail=str(e)) @router.post("/", response_model=Tool, operation_id="create_tool") diff --git a/letta/server/server.py b/letta/server/server.py index bf1ac91e..283f55db 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1981,9 +1981,9 @@ class SyncServer(Server): """Delete a tool""" self.ms.delete_tool(tool_id) - def list_tools(self, user_id: str) -> List[Tool]: + def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[Tool]: """List tools available to user_id""" - tools = self.ms.list_tools(user_id) + tools = self.ms.list_tools(cursor=cursor, limit=limit, user_id=user_id) return tools def add_default_tools(self, module_name="base", user_id: Optional[str] = None): diff --git a/tests/test_client.py b/tests/test_client.py index 1a91b7e3..0a5e5620 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -299,6 +299,28 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState): # print("CONFIG", config_response) +def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState): + tools = client.list_tools() + visited_ids = {t.id: False for t in tools} + + cursor = None + # Choose 3 for uneven buckets (only 7 default tools) + num_tools = 3 + # Construct a complete pagination test to see if we can return all the tools eventually + for _ in range(0, len(tools), num_tools): + curr_tools = client.list_tools(cursor, num_tools) + assert len(curr_tools) <= num_tools + + for curr_tool in curr_tools: + assert curr_tool.id in visited_ids + visited_ids[curr_tool.id] = True + + cursor = curr_tools[-1].id + + # Assert that everything has been visited + assert all(visited_ids.values()) + + def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState): # clear sources for source in client.list_sources(): diff --git a/tests/test_new_client.py b/tests/test_new_client.py index 3ddfc1ea..fd502ab8 100644 --- a/tests/test_new_client.py +++ b/tests/test_new_client.py @@ -36,8 +36,6 @@ def agent(client): def test_agent(client: Union[LocalClient, RESTClient]): - tools = client.list_tools() - # create agent agent_state_test = client.create_agent( name="test_agent2", @@ -51,6 +49,7 @@ def test_agent(client: Union[LocalClient, RESTClient]): assert agent_state_test.id in [a.id for a in agents] # get agent + tools = client.list_tools() print("TOOLS", [t.name for t in tools]) agent_state = client.get_agent(agent_state_test.id) assert agent_state.name == "test_agent2"