feat: Create polling endpoint for batch (#1723)

Author: Matthew Zhou
Date: 2025-04-15 14:43:23 -07:00
Committed by: GitHub
Parent: f6cbaa04c9
Commit: a7fefea13c
3 changed files with 69 additions and 1 deletion

View File

@@ -847,3 +847,29 @@ async def send_batch_messages(
    )
    return await batch_runner.step_until_request(batch_requests=batch_requests)


@router.get(
    "/messages/batches/{batch_id}",
    response_model=LettaBatchResponse,
    operation_id="retrieve_batch_message_request",
)
async def retrieve_batch_message_request(
    batch_id: str,
    server: SyncServer = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Retrieve the result or current status of a previously submitted batch message request.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    batch_job = server.batch_manager.get_batch_job_by_id(batch_id=batch_id, actor=actor)
    agent_count = server.batch_manager.count_batch_items(batch_id=batch_id)
    return LettaBatchResponse(
        batch_id=batch_id,
        status=batch_job.status,
        agent_count=agent_count,
        last_polled_at=batch_job.last_polled_at,
        created_at=batch_job.created_at,
    )
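
As a usage sketch (not part of this diff), a client could poll the new endpoint over plain HTTP. The base URL, API prefix, auth header value, and terminal status names below are assumptions for illustration only:

# Hypothetical polling loop for the new batch-status endpoint (base URL, '/v1' prefix, and status names assumed).
import time

import requests

BASE_URL = "http://localhost:8283"  # assumed local Letta server address
BATCH_ID = "batch-123"              # placeholder id returned when the batch was submitted

while True:
    resp = requests.get(
        f"{BASE_URL}/v1/messages/batches/{BATCH_ID}",  # route path is from the diff; the '/v1' prefix is assumed
        headers={"user_id": "user-123"},               # matches the 'user_id' header alias in the route
    )
    resp.raise_for_status()
    body = resp.json()
    print(body["status"], body["agent_count"])
    if body["status"] not in ("created", "running"):   # terminal-state check; status names beyond 'created' are assumed
        break
    time.sleep(5)  # arbitrary poll interval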

View File

@@ -2,7 +2,7 @@ import datetime
from typing import Any, Dict, List, Optional, Tuple
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
from sqlalchemy import tuple_
from sqlalchemy import func, tuple_
from letta.jobs.types import BatchPollingResult, ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo
from letta.log import get_logger
@@ -312,3 +312,18 @@ class LLMBatchManager:
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            item.hard_delete(db_session=session, actor=actor)

    @enforce_types
    def count_batch_items(self, batch_id: str) -> int:
        """
        Efficiently count the number of batch items for a given batch_id.

        Args:
            batch_id (str): The batch identifier to count items for.

        Returns:
            int: The total number of batch items associated with the given batch_id.
        """
        with self.session_maker() as session:
            count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.batch_id == batch_id).scalar()
            return count or 0
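
For context, the func.count(...).scalar() pattern issues a single aggregate query instead of loading every row into memory. A minimal, self-contained sketch of the same pattern, using a simplified table and model rather than Letta's actual schema:

# Standalone illustration of counting rows with a SQL aggregate (simplified model, not the project's).
from sqlalchemy import Column, Integer, String, create_engine, func
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    batch_id = Column(String, index=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Item(batch_id="b1"), Item(batch_id="b1"), Item(batch_id="b2")])
    session.commit()
    # Emits roughly: SELECT count(items.id) FROM items WHERE items.batch_id = 'b1'
    count = session.query(func.count(Item.id)).filter(Item.batch_id == "b1").scalar()
    print(count or 0)  # -> 2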

View File

@@ -5058,3 +5058,30 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m
    for item_id in created_ids:
        fetched = server.batch_manager.get_batch_item_by_id(item_id, actor=default_user)
        assert fetched.id in created_ids


def test_count_batch_items(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
    # Create a batch job first.
    batch = server.batch_manager.create_batch_job(
        llm_provider=ProviderType.anthropic,
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    # Create a specific number of batch items for this batch.
    num_items = 5
    for _ in range(num_items):
        server.batch_manager.create_batch_item(
            batch_id=batch.id,
            agent_id=sarah_agent.id,
            llm_config=dummy_llm_config,
            step_state=dummy_step_state,
            actor=default_user,
        )

    # Use the count_batch_items method to count the items.
    count = server.batch_manager.count_batch_items(batch_id=batch.id)

    # Assert that the count matches the expected number.
    assert count == num_items, f"Expected {num_items} items, got {count}"