diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index f4c0e530..97c52054 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -1,7 +1,7 @@ import asyncio from typing import Dict, List, Optional -from sqlalchemy import delete, select +from sqlalchemy import delete, or_, select from sqlalchemy.orm import Session from letta.log import get_logger @@ -176,7 +176,10 @@ class BlockManager: template_name: Optional[str] = None, identity_id: Optional[str] = None, identifier_keys: Optional[List[str]] = None, + before: Optional[str] = None, + after: Optional[str] = None, limit: Optional[int] = 50, + ascending: bool = True, ) -> List[PydanticBlock]: """Async version of get_blocks method. Retrieve blocks based on various optional filters.""" from sqlalchemy import select @@ -205,19 +208,67 @@ class BlockManager: if template_name: query = query.where(BlockModel.template_name == template_name) + needs_distinct = False + if identifier_keys: - query = ( - query.join(BlockModel.identities) - .filter(BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) - .distinct(BlockModel.id) + query = query.join(BlockModel.identities).filter( + BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys) ) + needs_distinct = True if identity_id: - query = ( - query.join(BlockModel.identities) - .filter(BlockModel.identities.property.mapper.class_.id == identity_id) - .distinct(BlockModel.id) - ) + query = query.join(BlockModel.identities).filter(BlockModel.identities.property.mapper.class_.id == identity_id) + needs_distinct = True + + if after: + result = (await session.execute(select(BlockModel.created_at, BlockModel.id).where(BlockModel.id == after))).first() + if result: + after_sort_value, after_id = result + # SQLite does not support as granular timestamping, so we need to round the timestamp + if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime): + after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S") + + if ascending: + query = query.where( + BlockModel.created_at > after_sort_value, + or_(BlockModel.created_at == after_sort_value, BlockModel.id > after_id), + ) + else: + query = query.where( + BlockModel.created_at < after_sort_value, + or_(BlockModel.created_at == after_sort_value, BlockModel.id < after_id), + ) + + if before: + result = (await session.execute(select(BlockModel.created_at, BlockModel.id).where(BlockModel.id == before))).first() + if result: + before_sort_value, before_id = result + # SQLite does not support as granular timestamping, so we need to round the timestamp + if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime): + before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S") + + if ascending: + query = query.where( + BlockModel.created_at < before_sort_value, + or_(BlockModel.created_at == before_sort_value, BlockModel.id < before_id), + ) + else: + query = query.where( + BlockModel.created_at > before_sort_value, + or_(BlockModel.created_at == before_sort_value, BlockModel.id > before_id), + ) + + # Apply ordering and handle distinct if needed + if needs_distinct: + if ascending: + query = query.distinct(BlockModel.id).order_by(BlockModel.id.asc(), BlockModel.created_at.asc()) + else: + query = query.distinct(BlockModel.id).order_by(BlockModel.id.desc(), BlockModel.created_at.desc()) + else: + if ascending: + query = query.order_by(BlockModel.created_at.asc(), BlockModel.id.asc()) + else: + query = query.order_by(BlockModel.created_at.desc(), BlockModel.id.desc()) # Add limit if limit: