feat: Add pagination for list tools (#1907)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
@@ -200,7 +200,7 @@ class AbstractClient(object):
|
||||
) -> Tool:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_tools(self) -> List[Tool]:
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_tool(self, id: str) -> Tool:
|
||||
@@ -1382,14 +1382,19 @@ class RESTClient(AbstractClient):
|
||||
# raise ValueError(f"Failed to create tool: {response.text}")
|
||||
# return ToolModel(**response.json())
|
||||
|
||||
def list_tools(self) -> List[Tool]:
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
|
||||
"""
|
||||
List available tools for the user.
|
||||
|
||||
Returns:
|
||||
tools (List[Tool]): List of tools
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", headers=self.headers)
|
||||
params = {}
|
||||
if cursor:
|
||||
params["cursor"] = str(cursor)
|
||||
if limit:
|
||||
params["limit"] = limit
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list tools: {response.text}")
|
||||
return [Tool(**tool) for tool in response.json()]
|
||||
@@ -2281,15 +2286,14 @@ class LocalClient(AbstractClient):
|
||||
ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
|
||||
"""
|
||||
List available tools for the user.
|
||||
|
||||
Returns:
|
||||
tools (List[Tool]): List of tools
|
||||
"""
|
||||
tools = self.server.list_tools(user_id=self.user_id)
|
||||
return tools
|
||||
return self.server.list_tools(cursor=cursor, limit=limit, user_id=self.user_id)
|
||||
|
||||
def get_tool(self, id: str) -> Optional[Tool]:
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,9 @@ from sqlalchemy import (
|
||||
Integer,
|
||||
String,
|
||||
TypeDecorator,
|
||||
asc,
|
||||
desc,
|
||||
or_,
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
@@ -707,12 +709,19 @@ class MetadataStore:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools
|
||||
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(ToolModel).filter(ToolModel.user_id == None).all()
|
||||
if user_id:
|
||||
results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
|
||||
# Query for public tools or user-specific tools
|
||||
query = session.query(ToolModel).filter(or_(ToolModel.user_id == None, ToolModel.user_id == user_id))
|
||||
|
||||
# Apply cursor if provided (assuming cursor is an ID)
|
||||
if cursor:
|
||||
query = query.filter(ToolModel.id > cursor)
|
||||
|
||||
# Order by ID and apply limit
|
||||
results = query.order_by(asc(ToolModel.id)).limit(limit).all()
|
||||
|
||||
# Convert to records
|
||||
res = [r.to_record() for r in results]
|
||||
return res
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ def create_application() -> "FastAPI":
|
||||
title="Letta",
|
||||
summary="Create LLM agents with long-term memory and custom tools 📚🦙",
|
||||
version="1.0.0", # TODO wire this up to the version in the package
|
||||
debug=True,
|
||||
)
|
||||
|
||||
if "--ade" in sys.argv:
|
||||
|
||||
@@ -59,18 +59,21 @@ def get_tool_id(
|
||||
|
||||
@router.get("/", response_model=List[Tool], operation_id="list_tools")
|
||||
def list_all_tools(
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a list of all tools available to agents created by a user
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor.id
|
||||
|
||||
# TODO: add back when user-specific
|
||||
return server.list_tools(user_id=actor.id)
|
||||
# return server.ms.list_tools(user_id=None)
|
||||
try:
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
return server.list_tools(cursor=cursor, limit=limit, user_id=actor.id)
|
||||
except Exception as e:
|
||||
# Log or print the full exception here for debugging
|
||||
print(f"Error occurred: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/", response_model=Tool, operation_id="create_tool")
|
||||
|
||||
@@ -1981,9 +1981,9 @@ class SyncServer(Server):
|
||||
"""Delete a tool"""
|
||||
self.ms.delete_tool(tool_id)
|
||||
|
||||
def list_tools(self, user_id: str) -> List[Tool]:
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[Tool]:
|
||||
"""List tools available to user_id"""
|
||||
tools = self.ms.list_tools(user_id)
|
||||
tools = self.ms.list_tools(cursor=cursor, limit=limit, user_id=user_id)
|
||||
return tools
|
||||
|
||||
def add_default_tools(self, module_name="base", user_id: Optional[str] = None):
|
||||
|
||||
@@ -299,6 +299,28 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# print("CONFIG", config_response)
|
||||
|
||||
|
||||
def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
tools = client.list_tools()
|
||||
visited_ids = {t.id: False for t in tools}
|
||||
|
||||
cursor = None
|
||||
# Choose 3 for uneven buckets (only 7 default tools)
|
||||
num_tools = 3
|
||||
# Construct a complete pagination test to see if we can return all the tools eventually
|
||||
for _ in range(0, len(tools), num_tools):
|
||||
curr_tools = client.list_tools(cursor, num_tools)
|
||||
assert len(curr_tools) <= num_tools
|
||||
|
||||
for curr_tool in curr_tools:
|
||||
assert curr_tool.id in visited_ids
|
||||
visited_ids[curr_tool.id] = True
|
||||
|
||||
cursor = curr_tools[-1].id
|
||||
|
||||
# Assert that everything has been visited
|
||||
assert all(visited_ids.values())
|
||||
|
||||
|
||||
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# clear sources
|
||||
for source in client.list_sources():
|
||||
|
||||
@@ -36,8 +36,6 @@ def agent(client):
|
||||
|
||||
|
||||
def test_agent(client: Union[LocalClient, RESTClient]):
|
||||
tools = client.list_tools()
|
||||
|
||||
# create agent
|
||||
agent_state_test = client.create_agent(
|
||||
name="test_agent2",
|
||||
@@ -51,6 +49,7 @@ def test_agent(client: Union[LocalClient, RESTClient]):
|
||||
assert agent_state_test.id in [a.id for a in agents]
|
||||
|
||||
# get agent
|
||||
tools = client.list_tools()
|
||||
print("TOOLS", [t.name for t in tools])
|
||||
agent_state = client.get_agent(agent_state_test.id)
|
||||
assert agent_state.name == "test_agent2"
|
||||
|
||||
Reference in New Issue
Block a user