diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index e6a207c1..72ca9847 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -205,22 +205,47 @@ class LLMBatchManager: return item.update(db_session=session, actor=actor).to_pydantic() - # TODO: Maybe make this paginated? @enforce_types def list_batch_items( self, batch_id: str, limit: Optional[int] = None, actor: Optional[PydanticUser] = None, + after: Optional[str] = None, + agent_id: Optional[str] = None, + request_status: Optional[JobStatus] = None, + step_status: Optional[AgentStepStatus] = None, ) -> List[PydanticLLMBatchItem]: - """List all batch items for a given batch_id, optionally filtered by organization and limited in count.""" + """ + List all batch items for a given batch_id, optionally filtered by additional criteria and limited in count. + + Optional filters: + - after: A cursor string. Only items with an `id` greater than this value are returned. + - agent_id: Restrict the result set to a specific agent. + - request_status: Filter items based on their request status (e.g., created, completed, expired). + - step_status: Filter items based on their step execution status. + + The results are ordered by their id in ascending order. + """ with self.session_maker() as session: query = session.query(LLMBatchItem).filter(LLMBatchItem.batch_id == batch_id) if actor is not None: query = query.filter(LLMBatchItem.organization_id == actor.organization_id) - if limit: + # Additional optional filters + if agent_id is not None: + query = query.filter(LLMBatchItem.agent_id == agent_id) + if request_status is not None: + query = query.filter(LLMBatchItem.request_status == request_status) + if step_status is not None: + query = query.filter(LLMBatchItem.step_status == step_status) + if after is not None: + query = query.filter(LLMBatchItem.id > after) + + query = query.order_by(LLMBatchItem.id.asc()) + + if limit is not None: query = query.limit(limit) results = query.all() diff --git a/tests/test_managers.py b/tests/test_managers.py index 16f73c29..2d868cd3 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4947,6 +4947,65 @@ def test_list_batch_items_limit_and_filter(server, default_user, sarah_agent, du assert len(limited_items) == 2 +def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): + # Create a batch job. + batch = server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + + # Create 10 batch items. + created_items = [] + for i in range(10): + item = server.batch_manager.create_batch_item( + batch_id=batch.id, + agent_id=sarah_agent.id, + llm_config=dummy_llm_config, + step_state=dummy_step_state, + actor=default_user, + ) + created_items.append(item) + + # Retrieve all items (without pagination). + all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user) + assert len(all_items) >= 10, f"Expected at least 10 items, got {len(all_items)}" + + # Verify the items are ordered ascending by id (based on our implementation). + sorted_ids = [item.id for item in sorted(all_items, key=lambda i: i.id)] + retrieved_ids = [item.id for item in all_items] + assert retrieved_ids == sorted_ids, "Batch items are not ordered in ascending order by id" + + # Choose a cursor: the id of the 5th item. + cursor = all_items[4].id + + # Retrieve items after the cursor. + paged_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor) + + # All returned items should have an id greater than the cursor. + for item in paged_items: + assert item.id > cursor, f"Item id {item.id} is not greater than the cursor {cursor}" + + # Count expected remaining items. + # Find the index of the cursor in our sorted list. + cursor_index = sorted_ids.index(cursor) + expected_remaining = len(sorted_ids) - cursor_index - 1 + assert len(paged_items) == expected_remaining, f"Expected {expected_remaining} items after cursor, got {len(paged_items)}" + + # Test pagination with a limit. + limit = 3 + limited_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor, limit=limit) + # If more than 'limit' items remain, we should only get exactly 'limit' items. + assert len(limited_page) == min( + limit, expected_remaining + ), f"Expected {min(limit, expected_remaining)} items with limit {limit}, got {len(limited_page)}" + + # Optional: Test with a cursor beyond the last item returns an empty list. + last_cursor = sorted_ids[-1] + empty_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=last_cursor) + assert empty_page == [], "Expected an empty list when cursor is after the last item" + + def test_bulk_update_batch_items_request_status_by_agent( server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state ):