feat(asyncify): more batch db calls (#2350)
This commit is contained in:
@@ -233,7 +233,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
ctx = await self._collect_resume_context(llm_batch_id)
|
||||
|
||||
log_event(name="update_statuses")
|
||||
self._update_request_statuses(ctx.request_status_updates)
|
||||
await self._update_request_statuses_async(ctx.request_status_updates)
|
||||
|
||||
log_event(name="exec_tools")
|
||||
exec_results = await self._execute_tools(ctx)
|
||||
@@ -242,7 +242,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
msg_map = await self._persist_tool_messages(exec_results, ctx)
|
||||
|
||||
log_event(name="mark_steps_done")
|
||||
self._mark_steps_complete(llm_batch_id, ctx.agent_ids)
|
||||
await self._mark_steps_complete_async(llm_batch_id, ctx.agent_ids)
|
||||
|
||||
log_event(name="prepare_next")
|
||||
next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map)
|
||||
@@ -382,9 +382,9 @@ class LettaAgentBatch(BaseAgent):
|
||||
|
||||
return self._extract_tool_call_and_decide_continue(tool_call, item.step_state)
|
||||
|
||||
def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None:
|
||||
async def _update_request_statuses_async(self, updates: List[RequestStatusUpdateInfo]) -> None:
|
||||
if updates:
|
||||
self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates)
|
||||
await self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(updates=updates)
|
||||
|
||||
def _build_sandbox(self) -> Tuple[SandboxConfig, Dict[str, Any]]:
|
||||
sbx_type = SandboxType.E2B if tool_settings.e2b_api_key else SandboxType.LOCAL
|
||||
@@ -474,11 +474,11 @@ class LettaAgentBatch(BaseAgent):
|
||||
await self.message_manager.create_many_messages_async([m for msgs in msg_map.values() for m in msgs], actor=self.actor)
|
||||
return msg_map
|
||||
|
||||
def _mark_steps_complete(self, llm_batch_id: str, agent_ids: List[str]) -> None:
|
||||
async def _mark_steps_complete_async(self, llm_batch_id: str, agent_ids: List[str]) -> None:
|
||||
updates = [
|
||||
StepStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids
|
||||
]
|
||||
self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(updates)
|
||||
await self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(updates)
|
||||
|
||||
def _prepare_next_iteration(
|
||||
self,
|
||||
|
||||
@@ -106,7 +106,7 @@ async def poll_batch_updates(server: SyncServer, batch_jobs: List[LLMBatchJob],
|
||||
results: List[BatchPollingResult] = await asyncio.gather(*coros)
|
||||
|
||||
# Update the server with batch status changes
|
||||
server.batch_manager.bulk_update_llm_batch_statuses(updates=results)
|
||||
await server.batch_manager.bulk_update_llm_batch_statuses_async(updates=results)
|
||||
logger.info(f"[Poll BatchJob] Bulk-updated {len(results)} LLM batch(es) in the DB at job level.")
|
||||
|
||||
return results
|
||||
@@ -197,13 +197,13 @@ async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchRespo
|
||||
# 6. Bulk update all items for newly completed batch(es)
|
||||
if item_updates:
|
||||
metrics.updated_items_count = len(item_updates)
|
||||
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(item_updates)
|
||||
await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(item_updates)
|
||||
|
||||
# ─── Kick off post‑processing for each batch that just completed ───
|
||||
completed = [r for r in batch_results if r.request_status == JobStatus.completed]
|
||||
|
||||
async def _resume(batch_row: LLMBatchJob) -> LettaBatchResponse:
|
||||
actor: User = server.user_manager.get_user_by_id(batch_row.created_by_id)
|
||||
actor: User = await server.user_manager.get_actor_by_id_async(batch_row.created_by_id)
|
||||
runner = LettaAgentBatch(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
|
||||
@@ -7,7 +7,7 @@ from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.log import get_logger
|
||||
from letta.server.db import db_context
|
||||
from letta.server.db import db_registry
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
@@ -34,18 +34,15 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
|
||||
acquired_lock = False
|
||||
try:
|
||||
# Use a temporary connection context for the attempt initially
|
||||
with db_context() as session:
|
||||
engine = session.get_bind()
|
||||
# Get raw connection - MUST be kept open if lock is acquired
|
||||
raw_conn = engine.raw_connection()
|
||||
cur = raw_conn.cursor()
|
||||
async with db_registry.async_session() as session:
|
||||
raw_conn = await session.connection()
|
||||
|
||||
cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
|
||||
acquired_lock = cur.fetchone()[0]
|
||||
# Try to acquire the advisory lock
|
||||
result = await session.execute(f"SELECT pg_try_advisory_lock(CAST({ADVISORY_LOCK_KEY} AS bigint))")
|
||||
acquired_lock = result.scalar_one()
|
||||
|
||||
if not acquired_lock:
|
||||
cur.close()
|
||||
raw_conn.close()
|
||||
await raw_conn.close()
|
||||
logger.info("Scheduler lock held by another instance.")
|
||||
return False
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ async def list_batch_messages(
|
||||
|
||||
# Get messages directly using our efficient method
|
||||
# We'll need to update the underlying implementation to use message_id as cursor
|
||||
messages = server.batch_manager.get_messages_for_letta_batch(
|
||||
messages = await server.batch_manager.get_messages_for_letta_batch_async(
|
||||
letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, sort_descending=sort_descending, cursor=cursor
|
||||
)
|
||||
|
||||
@@ -184,7 +184,7 @@ async def cancel_batch_run(
|
||||
job = await server.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
|
||||
|
||||
# Get related llm batch jobs
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor)
|
||||
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(letta_batch_id=job.id, actor=actor)
|
||||
for llm_batch_job in llm_batch_jobs:
|
||||
if llm_batch_job.status in {JobStatus.running, JobStatus.created}:
|
||||
# TODO: Extend to providers beyond anthropic
|
||||
@@ -194,6 +194,8 @@ async def cancel_batch_run(
|
||||
await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id)
|
||||
|
||||
# Update all the batch_job statuses
|
||||
server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor)
|
||||
await server.batch_manager.update_llm_batch_status_async(
|
||||
llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor
|
||||
)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
@@ -58,7 +58,7 @@ class LLMBatchManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_llm_batch_status(
|
||||
async def update_llm_batch_status_async(
|
||||
self,
|
||||
llm_batch_id: str,
|
||||
status: JobStatus,
|
||||
@@ -66,15 +66,15 @@ class LLMBatchManager:
|
||||
latest_polling_response: Optional[BetaMessageBatch] = None,
|
||||
) -> PydanticLLMBatchJob:
|
||||
"""Update a batch job’s status and optionally its polling response."""
|
||||
with db_registry.session() as session:
|
||||
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
|
||||
async with db_registry.async_session() as session:
|
||||
batch = await LLMBatchJob.read_async(db_session=session, identifier=llm_batch_id, actor=actor)
|
||||
batch.status = status
|
||||
batch.latest_polling_response = latest_polling_response
|
||||
batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
batch = batch.update(db_session=session, actor=actor)
|
||||
batch = await batch.update_async(db_session=session, actor=actor)
|
||||
return batch.to_pydantic()
|
||||
|
||||
def bulk_update_llm_batch_statuses(
|
||||
async def bulk_update_llm_batch_statuses_async(
|
||||
self,
|
||||
updates: List[BatchPollingResult],
|
||||
) -> None:
|
||||
@@ -85,7 +85,7 @@ class LLMBatchManager:
|
||||
"""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
with db_registry.session() as session:
|
||||
async with db_registry.async_session() as session:
|
||||
mappings = []
|
||||
for llm_batch_id, status, response in updates:
|
||||
mappings.append(
|
||||
@@ -97,18 +97,18 @@ class LLMBatchManager:
|
||||
}
|
||||
)
|
||||
|
||||
session.bulk_update_mappings(LLMBatchJob, mappings)
|
||||
session.commit()
|
||||
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchJob, mappings))
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def list_llm_batch_jobs(
|
||||
async def list_llm_batch_jobs_async(
|
||||
self,
|
||||
letta_batch_id: str,
|
||||
limit: Optional[int] = None,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
after: Optional[str] = None,
|
||||
) -> List[PydanticLLMBatchItem]:
|
||||
) -> List[PydanticLLMBatchJob]:
|
||||
"""
|
||||
List all LLM batch jobs for a given letta_batch_id, optionally filtered by additional criteria and limited in count.
|
||||
|
||||
@@ -120,35 +120,35 @@ class LLMBatchManager:
|
||||
|
||||
The results are ordered by their id in ascending order.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id)
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(LLMBatchJob).where(LLMBatchJob.letta_batch_job_id == letta_batch_id)
|
||||
|
||||
if actor is not None:
|
||||
query = query.filter(LLMBatchJob.organization_id == actor.organization_id)
|
||||
query = query.where(LLMBatchJob.organization_id == actor.organization_id)
|
||||
|
||||
# Additional optional filters
|
||||
if after is not None:
|
||||
query = query.filter(LLMBatchJob.id > after)
|
||||
query = query.where(LLMBatchJob.id > after)
|
||||
|
||||
query = query.order_by(LLMBatchJob.id.asc())
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
results = query.all()
|
||||
return [item.to_pydantic() for item in results]
|
||||
results = await session.execute(query)
|
||||
return [item.to_pydantic() for item in results.scalars().all()]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None:
|
||||
async def delete_llm_batch_request_async(self, llm_batch_id: str, actor: PydanticUser) -> None:
|
||||
"""Hard delete a batch job by ID."""
|
||||
with db_registry.session() as session:
|
||||
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
|
||||
batch.hard_delete(db_session=session, actor=actor)
|
||||
async with db_registry.async_session() as session:
|
||||
batch = await LLMBatchJob.read_async(db_session=session, identifier=llm_batch_id, actor=actor)
|
||||
await batch.hard_delete_async(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_messages_for_letta_batch(
|
||||
async def get_messages_for_letta_batch_async(
|
||||
self,
|
||||
letta_batch_job_id: str,
|
||||
limit: int = 100,
|
||||
@@ -161,12 +161,12 @@ class LLMBatchManager:
|
||||
Retrieve messages across all LLM batch jobs associated with a Letta batch job.
|
||||
Optimized for PostgreSQL performance using ID-based keyset pagination.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
async with db_registry.async_session() as session:
|
||||
# If cursor is provided, get sequence_id for that message
|
||||
cursor_sequence_id = None
|
||||
if cursor:
|
||||
cursor_query = session.query(MessageModel.sequence_id).filter(MessageModel.id == cursor).limit(1)
|
||||
cursor_result = cursor_query.first()
|
||||
cursor_query = select(MessageModel.sequence_id).where(MessageModel.id == cursor).limit(1)
|
||||
cursor_result = await session.execute(cursor_query)
|
||||
if cursor_result:
|
||||
cursor_sequence_id = cursor_result[0]
|
||||
else:
|
||||
@@ -174,24 +174,24 @@ class LLMBatchManager:
|
||||
pass
|
||||
|
||||
query = (
|
||||
session.query(MessageModel)
|
||||
select(MessageModel)
|
||||
.join(LLMBatchItem, MessageModel.batch_item_id == LLMBatchItem.id)
|
||||
.join(LLMBatchJob, LLMBatchItem.llm_batch_id == LLMBatchJob.id)
|
||||
.filter(LLMBatchJob.letta_batch_job_id == letta_batch_job_id)
|
||||
.where(LLMBatchJob.letta_batch_job_id == letta_batch_job_id)
|
||||
)
|
||||
|
||||
if actor is not None:
|
||||
query = query.filter(MessageModel.organization_id == actor.organization_id)
|
||||
query = query.where(MessageModel.organization_id == actor.organization_id)
|
||||
|
||||
if agent_id is not None:
|
||||
query = query.filter(MessageModel.agent_id == agent_id)
|
||||
query = query.where(MessageModel.agent_id == agent_id)
|
||||
|
||||
# Apply cursor-based pagination if cursor exists
|
||||
if cursor_sequence_id is not None:
|
||||
if sort_descending:
|
||||
query = query.filter(MessageModel.sequence_id < cursor_sequence_id)
|
||||
query = query.where(MessageModel.sequence_id < cursor_sequence_id)
|
||||
else:
|
||||
query = query.filter(MessageModel.sequence_id > cursor_sequence_id)
|
||||
query = query.where(MessageModel.sequence_id > cursor_sequence_id)
|
||||
|
||||
if sort_descending:
|
||||
query = query.order_by(desc(MessageModel.sequence_id))
|
||||
@@ -200,8 +200,8 @@ class LLMBatchManager:
|
||||
|
||||
query = query.limit(limit)
|
||||
|
||||
results = query.all()
|
||||
return [message.to_pydantic() for message in results]
|
||||
results = await session.execute(query)
|
||||
return [message.to_pydantic() for message in results.scalars().all()]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -218,7 +218,7 @@ class LLMBatchManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_llm_batch_item(
|
||||
async def create_llm_batch_item_async(
|
||||
self,
|
||||
llm_batch_id: str,
|
||||
agent_id: str,
|
||||
@@ -229,7 +229,7 @@ class LLMBatchManager:
|
||||
step_state: Optional[AgentStepState] = None,
|
||||
) -> PydanticLLMBatchItem:
|
||||
"""Create a new batch item."""
|
||||
with db_registry.session() as session:
|
||||
async with db_registry.async_session() as session:
|
||||
item = LLMBatchItem(
|
||||
llm_batch_id=llm_batch_id,
|
||||
agent_id=agent_id,
|
||||
@@ -239,7 +239,7 @@ class LLMBatchManager:
|
||||
step_state=step_state,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
item.create(session, actor=actor)
|
||||
await item.create_async(session, actor=actor)
|
||||
return item.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -280,15 +280,15 @@ class LLMBatchManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
|
||||
async def get_llm_batch_item_by_id_async(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
|
||||
"""Retrieve a single batch item by ID."""
|
||||
with db_registry.session() as session:
|
||||
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
|
||||
async with db_registry.async_session() as session:
|
||||
item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor)
|
||||
return item.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_llm_batch_item(
|
||||
async def update_llm_batch_item_async(
|
||||
self,
|
||||
item_id: str,
|
||||
actor: PydanticUser,
|
||||
@@ -298,8 +298,8 @@ class LLMBatchManager:
|
||||
step_state: Optional[AgentStepState] = None,
|
||||
) -> PydanticLLMBatchItem:
|
||||
"""Update fields on a batch item."""
|
||||
with db_registry.session() as session:
|
||||
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
|
||||
async with db_registry.async_session() as session:
|
||||
item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor)
|
||||
|
||||
if request_status:
|
||||
item.request_status = request_status
|
||||
@@ -310,7 +310,8 @@ class LLMBatchManager:
|
||||
if step_state:
|
||||
item.step_state = step_state
|
||||
|
||||
return item.update(db_session=session, actor=actor).to_pydantic()
|
||||
result = await item.update_async(db_session=session, actor=actor)
|
||||
return result.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -360,7 +361,7 @@ class LLMBatchManager:
|
||||
return [item.to_pydantic() for item in results.scalars()]
|
||||
|
||||
@trace_method
|
||||
def bulk_update_llm_batch_items(
|
||||
async def bulk_update_llm_batch_items_async(
|
||||
self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
@@ -378,13 +379,13 @@ class LLMBatchManager:
|
||||
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
|
||||
raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length")
|
||||
|
||||
with db_registry.session() as session:
|
||||
async with db_registry.async_session() as session:
|
||||
# Lookup primary keys for all requested (batch_id, agent_id) pairs
|
||||
items = (
|
||||
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
|
||||
.filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs))
|
||||
.all()
|
||||
query = select(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).filter(
|
||||
tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
items = result.all()
|
||||
pair_to_pk = {(batch_id, agent_id): pk for pk, batch_id, agent_id in items}
|
||||
|
||||
if strict:
|
||||
@@ -409,12 +410,12 @@ class LLMBatchManager:
|
||||
mappings.append(update_fields)
|
||||
|
||||
if mappings:
|
||||
session.bulk_update_mappings(LLMBatchItem, mappings)
|
||||
session.commit()
|
||||
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchItem, mappings))
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def bulk_update_batch_llm_items_results_by_agent(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None:
|
||||
async def bulk_update_batch_llm_items_results_by_agent_async(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None:
|
||||
"""Update request status and batch results for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [
|
||||
@@ -425,37 +426,41 @@ class LLMBatchManager:
|
||||
for update in updates
|
||||
]
|
||||
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def bulk_update_llm_batch_items_step_status_by_agent(self, updates: List[StepStatusUpdateInfo], strict: bool = True) -> None:
|
||||
async def bulk_update_llm_batch_items_step_status_by_agent_async(
|
||||
self, updates: List[StepStatusUpdateInfo], strict: bool = True
|
||||
) -> None:
|
||||
"""Update step status for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [{"step_status": update.step_status} for update in updates]
|
||||
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def bulk_update_llm_batch_items_request_status_by_agent(self, updates: List[RequestStatusUpdateInfo], strict: bool = True) -> None:
|
||||
async def bulk_update_llm_batch_items_request_status_by_agent_async(
|
||||
self, updates: List[RequestStatusUpdateInfo], strict: bool = True
|
||||
) -> None:
|
||||
"""Update request status for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [{"request_status": update.request_status} for update in updates]
|
||||
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:
|
||||
async def delete_llm_batch_item_async(self, item_id: str, actor: PydanticUser) -> None:
|
||||
"""Hard delete a batch item by ID."""
|
||||
with db_registry.session() as session:
|
||||
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
|
||||
item.hard_delete(db_session=session, actor=actor)
|
||||
async with db_registry.async_session() as session:
|
||||
item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor)
|
||||
await item.hard_delete_async(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def count_llm_batch_items(self, llm_batch_id: str) -> int:
|
||||
async def count_llm_batch_items_async(self, llm_batch_id: str) -> int:
|
||||
"""
|
||||
Efficiently count the number of batch items for a given llm_batch_id.
|
||||
|
||||
@@ -465,6 +470,6 @@ class LLMBatchManager:
|
||||
Returns:
|
||||
int: The total number of batch items associated with the given llm_batch_id.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar()
|
||||
return count or 0
|
||||
async with db_registry.async_session() as session:
|
||||
count = await session.execute(select(func.count(LLMBatchItem.id)).where(LLMBatchItem.llm_batch_id == llm_batch_id))
|
||||
return count.scalar() or 0
|
||||
|
||||
@@ -185,7 +185,7 @@ async def create_test_llm_batch_job_async(server, batch_response, default_user):
|
||||
)
|
||||
|
||||
|
||||
def create_test_batch_item(server, batch_id, agent_id, default_user):
|
||||
async def create_test_batch_item(server, batch_id, agent_id, default_user):
|
||||
"""Create a test batch item for the given batch and agent."""
|
||||
dummy_llm_config = LLMConfig(
|
||||
model="claude-3-7-sonnet-latest",
|
||||
@@ -201,7 +201,7 @@ def create_test_batch_item(server, batch_id, agent_id, default_user):
|
||||
step_number=1, tool_rules_solver=ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")])
|
||||
)
|
||||
|
||||
return server.batch_manager.create_llm_batch_item(
|
||||
return await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch_id,
|
||||
agent_id=agent_id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -289,9 +289,9 @@ async def test_polling_mixed_batch_jobs(default_user, server):
|
||||
job_b = await create_test_llm_batch_job_async(server, batch_b_resp, default_user)
|
||||
|
||||
# --- Step 3: Create batch items ---
|
||||
item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
|
||||
item_b = create_test_batch_item(server, job_b.id, agent_b.id, default_user)
|
||||
item_c = create_test_batch_item(server, job_b.id, agent_c.id, default_user)
|
||||
item_a = await create_test_batch_item(server, job_a.id, agent_a.id, default_user)
|
||||
item_b = await create_test_batch_item(server, job_b.id, agent_b.id, default_user)
|
||||
item_c = await create_test_batch_item(server, job_b.id, agent_c.id, default_user)
|
||||
|
||||
# --- Step 4: Mock the Anthropic client ---
|
||||
mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b.id, agent_c.id)
|
||||
@@ -316,17 +316,17 @@ async def test_polling_mixed_batch_jobs(default_user, server):
|
||||
|
||||
# --- Step 7: Verify batch item status updates ---
|
||||
# Item A should remain unchanged
|
||||
updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
|
||||
updated_item_a = await server.batch_manager.get_llm_batch_item_by_id_async(item_a.id, actor=default_user)
|
||||
assert updated_item_a.request_status == JobStatus.created
|
||||
assert updated_item_a.batch_request_result is None
|
||||
|
||||
# Item B should be marked as completed with a successful result
|
||||
updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
|
||||
updated_item_b = await server.batch_manager.get_llm_batch_item_by_id_async(item_b.id, actor=default_user)
|
||||
assert updated_item_b.request_status == JobStatus.completed
|
||||
assert updated_item_b.batch_request_result is not None
|
||||
|
||||
# Item C should be marked as failed with an error result
|
||||
updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
|
||||
updated_item_c = await server.batch_manager.get_llm_batch_item_by_id_async(item_c.id, actor=default_user)
|
||||
assert updated_item_c.request_status == JobStatus.failed
|
||||
assert updated_item_c.batch_request_result is not None
|
||||
|
||||
@@ -352,9 +352,9 @@ async def test_polling_mixed_batch_jobs(default_user, server):
|
||||
# Refresh all objects
|
||||
final_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user)
|
||||
final_job_b = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_b.id, actor=default_user)
|
||||
final_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
|
||||
final_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
|
||||
final_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
|
||||
final_item_a = await server.batch_manager.get_llm_batch_item_by_id_async(item_a.id, actor=default_user)
|
||||
final_item_b = await server.batch_manager.get_llm_batch_item_by_id_async(item_b.id, actor=default_user)
|
||||
final_item_c = await server.batch_manager.get_llm_batch_item_by_id_async(item_c.id, actor=default_user)
|
||||
|
||||
# Job A should still be polling (last_polled_at should update)
|
||||
assert final_job_a.status == JobStatus.running
|
||||
|
||||
@@ -458,7 +458,9 @@ async def test_partial_error_from_anthropic_batch(
|
||||
letta_batch_job_id=batch_job.id,
|
||||
)
|
||||
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
|
||||
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(
|
||||
letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user
|
||||
)
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
|
||||
# 2. Invoke the polling job and mock responses from Anthropic
|
||||
@@ -571,7 +573,7 @@ async def test_partial_error_from_anthropic_batch(
|
||||
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
|
||||
# Check the total list of messages
|
||||
messages = server.batch_manager.get_messages_for_letta_batch(
|
||||
messages = await server.batch_manager.get_messages_for_letta_batch_async(
|
||||
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
|
||||
)
|
||||
assert len(messages) == (len(agents) - 1) * 4 + 1
|
||||
@@ -621,7 +623,9 @@ async def test_resume_step_some_stop(
|
||||
letta_batch_job_id=batch_job.id,
|
||||
)
|
||||
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
|
||||
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(
|
||||
letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user
|
||||
)
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
|
||||
# 2. Invoke the polling job and mock responses from Anthropic
|
||||
@@ -723,7 +727,7 @@ async def test_resume_step_some_stop(
|
||||
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
|
||||
# Check the total list of messages
|
||||
messages = server.batch_manager.get_messages_for_letta_batch(
|
||||
messages = await server.batch_manager.get_messages_for_letta_batch_async(
|
||||
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
|
||||
)
|
||||
assert len(messages) == len(agents) * 3 + 1
|
||||
@@ -789,7 +793,9 @@ async def test_resume_step_after_request_all_continue(
|
||||
|
||||
# Basic sanity checks (this is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly`)
|
||||
# Verify batch items
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
|
||||
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(
|
||||
letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user
|
||||
)
|
||||
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
|
||||
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
@@ -883,7 +889,7 @@ async def test_resume_step_after_request_all_continue(
|
||||
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
|
||||
# Check the total list of messages
|
||||
messages = server.batch_manager.get_messages_for_letta_batch(
|
||||
messages = await server.batch_manager.get_messages_for_letta_batch_async(
|
||||
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
|
||||
)
|
||||
assert len(messages) == len(agents) * 4
|
||||
@@ -987,7 +993,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
mock_send.assert_called_once()
|
||||
|
||||
# Verify database records were created correctly
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=response.letta_batch_id, actor=default_user)
|
||||
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(letta_batch_id=response.letta_batch_id, actor=default_user)
|
||||
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
|
||||
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
|
||||
@@ -5107,7 +5107,7 @@ async def test_update_batch_status(server, default_user, dummy_beta_message_batc
|
||||
)
|
||||
before = datetime.now(timezone.utc)
|
||||
|
||||
server.batch_manager.update_llm_batch_status(
|
||||
await server.batch_manager.update_llm_batch_status_async(
|
||||
llm_batch_id=batch.id,
|
||||
status=JobStatus.completed,
|
||||
latest_polling_response=dummy_beta_message_batch,
|
||||
@@ -5132,7 +5132,7 @@ async def test_create_and_get_batch_item(
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
item = await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -5144,7 +5144,7 @@ async def test_create_and_get_batch_item(
|
||||
assert item.agent_id == sarah_agent.id
|
||||
assert item.step_state == dummy_step_state
|
||||
|
||||
fetched = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
fetched = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
|
||||
assert fetched.id == item.id
|
||||
|
||||
|
||||
@@ -5168,7 +5168,7 @@ async def test_update_batch_item(
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
item = await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -5178,7 +5178,7 @@ async def test_update_batch_item(
|
||||
|
||||
updated_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver)
|
||||
|
||||
server.batch_manager.update_llm_batch_item(
|
||||
await server.batch_manager.update_llm_batch_item_async(
|
||||
item_id=item.id,
|
||||
request_status=JobStatus.completed,
|
||||
step_status=AgentStepStatus.resumed,
|
||||
@@ -5187,7 +5187,7 @@ async def test_update_batch_item(
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
|
||||
assert updated.request_status == JobStatus.completed
|
||||
assert updated.batch_request_result == dummy_successful_response
|
||||
|
||||
@@ -5204,7 +5204,7 @@ async def test_delete_batch_item(
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
item = await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -5212,10 +5212,10 @@ async def test_delete_batch_item(
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
server.batch_manager.delete_llm_batch_item(item_id=item.id, actor=default_user)
|
||||
await server.batch_manager.delete_llm_batch_item_async(item_id=item.id, actor=default_user)
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -5243,7 +5243,7 @@ async def test_bulk_update_batch_statuses(server, default_user, dummy_beta_messa
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
server.batch_manager.bulk_update_llm_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)])
|
||||
await server.batch_manager.bulk_update_llm_batch_statuses_async([(batch.id, JobStatus.completed, dummy_beta_message_batch)])
|
||||
|
||||
updated = await server.batch_manager.get_llm_batch_job_by_id_async(batch.id, actor=default_user)
|
||||
assert updated.status == JobStatus.completed
|
||||
@@ -5268,7 +5268,7 @@ async def test_bulk_update_batch_items_results_by_agent(
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
item = await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -5276,11 +5276,11 @@ async def test_bulk_update_batch_items_results_by_agent(
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
|
||||
await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(
|
||||
[ItemUpdateInfo(batch.id, sarah_agent.id, JobStatus.completed, dummy_successful_response)]
|
||||
)
|
||||
|
||||
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
|
||||
assert updated.request_status == JobStatus.completed
|
||||
assert updated.batch_request_result == dummy_successful_response
|
||||
|
||||
@@ -5295,7 +5295,7 @@ async def test_bulk_update_batch_items_step_status_by_agent(
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
item = await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -5303,11 +5303,11 @@ async def test_bulk_update_batch_items_step_status_by_agent(
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
|
||||
await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(
|
||||
[StepStatusUpdateInfo(batch.id, sarah_agent.id, AgentStepStatus.resumed)]
|
||||
)
|
||||
|
||||
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
|
||||
assert updated.step_status == AgentStepStatus.resumed
|
||||
|
||||
|
||||
@@ -5323,7 +5323,7 @@ async def test_list_batch_items_limit_and_filter(
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
server.batch_manager.create_llm_batch_item(
|
||||
await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -5353,7 +5353,7 @@ async def test_list_batch_items_pagination(
|
||||
# Create 10 batch items.
|
||||
created_items = []
|
||||
for i in range(10):
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
item = await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -5416,7 +5416,7 @@ async def test_bulk_update_batch_items_request_status_by_agent(
|
||||
)
|
||||
|
||||
# Create a batch item
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
item = await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -5425,12 +5425,12 @@ async def test_bulk_update_batch_items_request_status_by_agent(
|
||||
)
|
||||
|
||||
# Update the request status using the bulk update method
|
||||
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
|
||||
await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(
|
||||
[RequestStatusUpdateInfo(batch.id, sarah_agent.id, JobStatus.expired)]
|
||||
)
|
||||
|
||||
# Verify the update was applied
|
||||
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
|
||||
assert updated.request_status == JobStatus.expired
|
||||
|
||||
|
||||
@@ -5459,20 +5459,20 @@ async def test_bulk_update_nonexistent_items_should_error(
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
|
||||
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)
|
||||
await server.batch_manager.bulk_update_llm_batch_items_async(nonexistent_pairs, nonexistent_updates)
|
||||
|
||||
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
|
||||
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
|
||||
await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(
|
||||
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
|
||||
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
|
||||
await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(
|
||||
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
|
||||
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
|
||||
await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(
|
||||
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
|
||||
)
|
||||
|
||||
@@ -5496,21 +5496,21 @@ async def test_bulk_update_nonexistent_items(
|
||||
nonexistent_updates = [{"request_status": JobStatus.expired}]
|
||||
|
||||
# This should not raise an error, just silently skip non-existent items
|
||||
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates, strict=False)
|
||||
await server.batch_manager.bulk_update_llm_batch_items_async(nonexistent_pairs, nonexistent_updates, strict=False)
|
||||
|
||||
# Test with higher-level methods
|
||||
# Results by agent
|
||||
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
|
||||
await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(
|
||||
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)], strict=False
|
||||
)
|
||||
|
||||
# Step status by agent
|
||||
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
|
||||
await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(
|
||||
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)], strict=False
|
||||
)
|
||||
|
||||
# Request status by agent
|
||||
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
|
||||
await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(
|
||||
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)], strict=False
|
||||
)
|
||||
|
||||
@@ -5565,7 +5565,7 @@ async def test_create_batch_items_bulk(
|
||||
# Verify the IDs of created items match what's in the database
|
||||
created_ids = [item.id for item in created_items]
|
||||
for item_id in created_ids:
|
||||
fetched = server.batch_manager.get_llm_batch_item_by_id(item_id, actor=default_user)
|
||||
fetched = await server.batch_manager.get_llm_batch_item_by_id_async(item_id, actor=default_user)
|
||||
assert fetched.id in created_ids
|
||||
|
||||
|
||||
@@ -5585,7 +5585,7 @@ async def test_count_batch_items(
|
||||
# Create a specific number of batch items for this batch.
|
||||
num_items = 5
|
||||
for _ in range(num_items):
|
||||
server.batch_manager.create_llm_batch_item(
|
||||
await server.batch_manager.create_llm_batch_item_async(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
@@ -5594,7 +5594,7 @@ async def test_count_batch_items(
|
||||
)
|
||||
|
||||
# Use the count_llm_batch_items method to count the items.
|
||||
count = server.batch_manager.count_llm_batch_items(llm_batch_id=batch.id)
|
||||
count = await server.batch_manager.count_llm_batch_items_async(llm_batch_id=batch.id)
|
||||
|
||||
# Assert that the count matches the expected number.
|
||||
assert count == num_items, f"Expected {num_items} items, got {count}"
|
||||
|
||||
Reference in New Issue
Block a user