diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index c5ec3533..539f34c1 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -847,3 +847,29 @@ async def send_batch_messages( ) return await batch_runner.step_until_request(batch_requests=batch_requests) + + +@router.get( + "/messages/batches/{batch_id}", + response_model=LettaBatchResponse, + operation_id="retrieve_batch_message_request", +) +async def retrieve_batch_message_request( + batch_id: str, + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Retrieve the result or current status of a previously submitted batch message request. + """ + actor = server.user_manager.get_user_or_default(user_id=actor_id) + batch_job = server.batch_manager.get_batch_job_by_id(batch_id=batch_id, actor=actor) + agent_count = server.batch_manager.count_batch_items(batch_id=batch_id) + + return LettaBatchResponse( + batch_id=batch_id, + status=batch_job.status, + agent_count=agent_count, + last_polled_at=batch_job.last_polled_at, + created_at=batch_job.created_at, + ) diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index a3ee9611..e6a207c1 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -2,7 +2,7 @@ import datetime from typing import Any, Dict, List, Optional, Tuple from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse -from sqlalchemy import tuple_ +from sqlalchemy import func, tuple_ from letta.jobs.types import BatchPollingResult, ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo from letta.log import get_logger @@ -312,3 +312,18 @@ class LLMBatchManager: with self.session_maker() as session: item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) item.hard_delete(db_session=session, actor=actor) + + @enforce_types + def count_batch_items(self, batch_id: str) -> int: + """ + Efficiently count the number of batch items for a given batch_id. + + Args: + batch_id (str): The batch identifier to count items for. + + Returns: + int: The total number of batch items associated with the given batch_id. + """ + with self.session_maker() as session: + count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.batch_id == batch_id).scalar() + return count or 0 diff --git a/tests/test_managers.py b/tests/test_managers.py index 5bca8587..16f73c29 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -5058,3 +5058,30 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m for item_id in created_ids: fetched = server.batch_manager.get_batch_item_by_id(item_id, actor=default_user) assert fetched.id in created_ids + + +def test_count_batch_items(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): + # Create a batch job first. + batch = server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + status=JobStatus.created, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + + # Create a specific number of batch items for this batch. + num_items = 5 + for _ in range(num_items): + server.batch_manager.create_batch_item( + batch_id=batch.id, + agent_id=sarah_agent.id, + llm_config=dummy_llm_config, + step_state=dummy_step_state, + actor=default_user, + ) + + # Use the count_batch_items method to count the items. + count = server.batch_manager.count_batch_items(batch_id=batch.id) + + # Assert that the count matches the expected number. + assert count == num_items, f"Expected {num_items} items, got {count}"