feat: Add pagination for list tools (#1907)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
Matthew Zhou
2024-10-18 15:19:45 -07:00
committed by GitHub
parent 11b8371953
commit aa55a3d10e
7 changed files with 59 additions and 21 deletions

View File

@@ -200,7 +200,7 @@ class AbstractClient(object):
) -> Tool:
raise NotImplementedError
def list_tools(self) -> List[Tool]:
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
raise NotImplementedError
def get_tool(self, id: str) -> Tool:
@@ -1382,14 +1382,19 @@ class RESTClient(AbstractClient):
# raise ValueError(f"Failed to create tool: {response.text}")
# return ToolModel(**response.json())
def list_tools(self) -> List[Tool]:
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
"""
List available tools for the user.
Returns:
tools (List[Tool]): List of tools
"""
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", headers=self.headers)
params = {}
if cursor:
params["cursor"] = str(cursor)
if limit:
params["limit"] = limit
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", params=params, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list tools: {response.text}")
return [Tool(**tool) for tool in response.json()]
@@ -2281,15 +2286,14 @@ class LocalClient(AbstractClient):
ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id
)
def list_tools(self):
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
"""
List available tools for the user.
Returns:
tools (List[Tool]): List of tools
"""
tools = self.server.list_tools(user_id=self.user_id)
return tools
return self.server.list_tools(cursor=cursor, limit=limit, user_id=self.user_id)
def get_tool(self, id: str) -> Optional[Tool]:
"""

View File

@@ -14,7 +14,9 @@ from sqlalchemy import (
Integer,
String,
TypeDecorator,
asc,
desc,
or_,
)
from sqlalchemy.sql import func
@@ -707,12 +709,19 @@ class MetadataStore:
session.commit()
@enforce_types
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]:
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
# Query for public tools or user-specific tools
query = session.query(ToolModel).filter(or_(ToolModel.user_id == None, ToolModel.user_id == user_id))
# Apply cursor if provided (assuming cursor is an ID)
if cursor:
query = query.filter(ToolModel.id > cursor)
# Order by ID and apply limit
results = query.order_by(asc(ToolModel.id)).limit(limit).all()
# Convert to records
res = [r.to_record() for r in results]
return res

View File

@@ -70,6 +70,7 @@ def create_application() -> "FastAPI":
title="Letta",
summary="Create LLM agents with long-term memory and custom tools 📚🦙",
version="1.0.0", # TODO wire this up to the version in the package
debug=True,
)
if "--ade" in sys.argv:

View File

@@ -59,18 +59,21 @@ def get_tool_id(
@router.get("/", response_model=List[Tool], operation_id="list_tools")
def list_all_tools(
cursor: Optional[str] = None,
limit: Optional[int] = 50,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a list of all tools available to agents created by a user
"""
actor = server.get_user_or_default(user_id=user_id)
actor.id
# TODO: add back when user-specific
return server.list_tools(user_id=actor.id)
# return server.ms.list_tools(user_id=None)
try:
actor = server.get_user_or_default(user_id=user_id)
return server.list_tools(cursor=cursor, limit=limit, user_id=actor.id)
except Exception as e:
# Log or print the full exception here for debugging
print(f"Error occurred: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/", response_model=Tool, operation_id="create_tool")

View File

@@ -1981,9 +1981,9 @@ class SyncServer(Server):
"""Delete a tool"""
self.ms.delete_tool(tool_id)
def list_tools(self, user_id: str) -> List[Tool]:
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[Tool]:
"""List tools available to user_id"""
tools = self.ms.list_tools(user_id)
tools = self.ms.list_tools(cursor=cursor, limit=limit, user_id=user_id)
return tools
def add_default_tools(self, module_name="base", user_id: Optional[str] = None):

View File

@@ -299,6 +299,28 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
# print("CONFIG", config_response)
def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
tools = client.list_tools()
visited_ids = {t.id: False for t in tools}
cursor = None
# Choose 3 for uneven buckets (only 7 default tools)
num_tools = 3
# Construct a complete pagination test to see if we can return all the tools eventually
for _ in range(0, len(tools), num_tools):
curr_tools = client.list_tools(cursor, num_tools)
assert len(curr_tools) <= num_tools
for curr_tool in curr_tools:
assert curr_tool.id in visited_ids
visited_ids[curr_tool.id] = True
cursor = curr_tools[-1].id
# Assert that everything has been visited
assert all(visited_ids.values())
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
# clear sources
for source in client.list_sources():

View File

@@ -36,8 +36,6 @@ def agent(client):
def test_agent(client: Union[LocalClient, RESTClient]):
tools = client.list_tools()
# create agent
agent_state_test = client.create_agent(
name="test_agent2",
@@ -51,6 +49,7 @@ def test_agent(client: Union[LocalClient, RESTClient]):
assert agent_state_test.id in [a.id for a in agents]
# get agent
tools = client.list_tools()
print("TOOLS", [t.name for t in tools])
agent_state = client.get_agent(agent_state_test.id)
assert agent_state.name == "test_agent2"