feat: add pagination to list agents for blocks endpoint (#3476)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, or_, select
|
||||
@@ -17,6 +18,7 @@ from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.enums import ActorType
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -357,19 +359,87 @@ class BlockManager:
|
||||
block_id: str,
|
||||
actor: PydanticUser,
|
||||
include_relationships: Optional[List[str]] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
ascending: bool = True,
|
||||
) -> List[PydanticAgentState]:
|
||||
"""
|
||||
Retrieve all agents associated with a given block.
|
||||
Retrieve all agents associated with a given block with pagination support.
|
||||
|
||||
Args:
|
||||
block_id: ID of the block to get agents for
|
||||
actor: User performing the operation
|
||||
include_relationships: List of relationships to include in the response
|
||||
before: Cursor for pagination (get items before this ID)
|
||||
after: Cursor for pagination (get items after this ID)
|
||||
limit: Maximum number of items to return
|
||||
ascending: Sort order (True for ascending, False for descending)
|
||||
|
||||
Returns:
|
||||
List of agent states associated with the block
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Start with a basic query
|
||||
query = (
|
||||
select(AgentModel)
|
||||
.where(AgentModel.id.in_(select(BlocksAgents.agent_id).where(BlocksAgents.block_id == block_id)))
|
||||
.where(AgentModel.organization_id == actor.organization_id)
|
||||
)
|
||||
|
||||
# Apply pagination using cursor-based approach
|
||||
if after:
|
||||
result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after))).first()
|
||||
if result:
|
||||
after_sort_value, after_id = result
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime):
|
||||
after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if ascending:
|
||||
query = query.where(
|
||||
AgentModel.created_at > after_sort_value,
|
||||
or_(AgentModel.created_at == after_sort_value, AgentModel.id > after_id),
|
||||
)
|
||||
else:
|
||||
query = query.where(
|
||||
AgentModel.created_at < after_sort_value,
|
||||
or_(AgentModel.created_at == after_sort_value, AgentModel.id < after_id),
|
||||
)
|
||||
|
||||
if before:
|
||||
result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before))).first()
|
||||
if result:
|
||||
before_sort_value, before_id = result
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime):
|
||||
before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if ascending:
|
||||
query = query.where(
|
||||
AgentModel.created_at < before_sort_value,
|
||||
or_(AgentModel.created_at == before_sort_value, AgentModel.id < before_id),
|
||||
)
|
||||
else:
|
||||
query = query.where(
|
||||
AgentModel.created_at > before_sort_value,
|
||||
or_(AgentModel.created_at == before_sort_value, AgentModel.id > before_id),
|
||||
)
|
||||
|
||||
# Apply sorting
|
||||
if ascending:
|
||||
query = query.order_by(AgentModel.created_at.asc(), AgentModel.id.asc())
|
||||
else:
|
||||
query = query.order_by(AgentModel.created_at.desc(), AgentModel.id.desc())
|
||||
|
||||
# Apply limit
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Execute the query
|
||||
result = await session.execute(query)
|
||||
agents_orm = result.scalars().all()
|
||||
|
||||
agents = await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents_orm])
|
||||
return agents
|
||||
|
||||
|
||||
Reference in New Issue
Block a user