feat: Add batch size parameter (#2574)

This commit is contained in:
Matthew Zhou
2025-06-01 12:23:46 -07:00
committed by GitHub
parent 60cf3341b1
commit 0dd8994294

View File

@@ -5297,28 +5297,38 @@ async def test_delete_batch_item(
@pytest.mark.asyncio
async def test_list_running_batches(server, default_user, dummy_beta_message_batch, letta_batch_job, event_loop):
# Create a recent running batch
await server.batch_manager.create_llm_batch_job_async(
llm_provider=ProviderType.anthropic,
status=JobStatus.running,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
# Create recent running batches
num_running = 3
for _ in range(num_running):
await server.batch_manager.create_llm_batch_job_async(
llm_provider=ProviderType.anthropic,
status=JobStatus.running,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
# Should return at least one running batch (no time filter)
running_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user)
assert len(running_batches) >= 1
assert len(running_batches) == num_running
assert all(batch.status == JobStatus.running for batch in running_batches)
# Should return the same when filtering by recent 1 week
recent_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user, weeks=1)
assert len(recent_batches) >= 1
assert len(recent_batches) == num_running
assert all(batch.status == JobStatus.running for batch in recent_batches)
assert all(batch.created_at >= datetime.now(timezone.utc) - timedelta(weeks=1) for batch in recent_batches)
# Filter by size
recent_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user, weeks=1, batch_size=2)
assert len(recent_batches) == 2
assert all(batch.status == JobStatus.running for batch in recent_batches)
assert all(batch.created_at >= datetime.now(timezone.utc) - timedelta(weeks=1) for batch in recent_batches)
# Should return nothing if filtering by a very small timeframe (e.g., 0 weeks)
future_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user, weeks=0)
assert all(batch.created_at >= datetime.utcnow() - timedelta(weeks=0) for batch in future_batches)
assert len(future_batches) == 0
@pytest.mark.asyncio