feat: Create polling endpoint for batch (#1723)
This commit is contained in:
@@ -847,3 +847,29 @@ async def send_batch_messages(
|
||||
)
|
||||
|
||||
return await batch_runner.step_until_request(batch_requests=batch_requests)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/messages/batches/{batch_id}",
|
||||
response_model=LettaBatchResponse,
|
||||
operation_id="retrieve_batch_message_request",
|
||||
)
|
||||
async def retrieve_batch_message_request(
|
||||
batch_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Retrieve the result or current status of a previously submitted batch message request.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
batch_job = server.batch_manager.get_batch_job_by_id(batch_id=batch_id, actor=actor)
|
||||
agent_count = server.batch_manager.count_batch_items(batch_id=batch_id)
|
||||
|
||||
return LettaBatchResponse(
|
||||
batch_id=batch_id,
|
||||
status=batch_job.status,
|
||||
agent_count=agent_count,
|
||||
last_polled_at=batch_job.last_polled_at,
|
||||
created_at=batch_job.created_at,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||
from sqlalchemy import tuple_
|
||||
from sqlalchemy import func, tuple_
|
||||
|
||||
from letta.jobs.types import BatchPollingResult, ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo
|
||||
from letta.log import get_logger
|
||||
@@ -312,3 +312,18 @@ class LLMBatchManager:
|
||||
with self.session_maker() as session:
|
||||
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
|
||||
item.hard_delete(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def count_batch_items(self, batch_id: str) -> int:
|
||||
"""
|
||||
Efficiently count the number of batch items for a given batch_id.
|
||||
|
||||
Args:
|
||||
batch_id (str): The batch identifier to count items for.
|
||||
|
||||
Returns:
|
||||
int: The total number of batch items associated with the given batch_id.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.batch_id == batch_id).scalar()
|
||||
return count or 0
|
||||
|
||||
@@ -5058,3 +5058,30 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m
|
||||
for item_id in created_ids:
|
||||
fetched = server.batch_manager.get_batch_item_by_id(item_id, actor=default_user)
|
||||
assert fetched.id in created_ids
|
||||
|
||||
|
||||
def test_count_batch_items(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
# Create a batch job first.
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create a specific number of batch items for this batch.
|
||||
num_items = 5
|
||||
for _ in range(num_items):
|
||||
server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Use the count_batch_items method to count the items.
|
||||
count = server.batch_manager.count_batch_items(batch_id=batch.id)
|
||||
|
||||
# Assert that the count matches the expected number.
|
||||
assert count == num_items, f"Expected {num_items} items, got {count}"
|
||||
|
||||
Reference in New Issue
Block a user