feat: Add lookback weeks parameter for batch polling (#2407)

This commit is contained in:
Matthew Zhou
2025-05-25 20:02:06 -07:00
committed by GitHub
parent e813a65351
commit dad8766dfb
4 changed files with 22 additions and 3 deletions

View File

@@ -11,6 +11,7 @@ from letta.schemas.letta_response import LettaBatchResponse
from letta.schemas.llm_batch_job import LLMBatchJob
from letta.schemas.user import User
from letta.server.server import SyncServer
from letta.settings import settings
logger = get_logger(__name__)
@@ -180,7 +181,7 @@ async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchRespo
try:
# 1. Retrieve running batch jobs
batches = await server.batch_manager.list_running_llm_batches_async()
batches = await server.batch_manager.list_running_llm_batches_async(weeks=max(settings.batch_job_polling_lookback_weeks, 1))
metrics.total_batches = len(batches)
# TODO: Expand to more providers

View File

@@ -205,14 +205,20 @@ class LLMBatchManager:
@enforce_types
@trace_method
async def list_running_llm_batches_async(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
"""Return all running LLM batch jobs, optionally filtered by actor's organization."""
async def list_running_llm_batches_async(
    self, actor: Optional[PydanticUser] = None, weeks: Optional[int] = None
) -> List[PydanticLLMBatchJob]:
    """Return running LLM batch jobs, optionally narrowed by organization and age.

    Args:
        actor: When provided, only jobs whose ``organization_id`` equals the
            actor's ``organization_id`` are returned.
        weeks: When provided, only jobs whose ``created_at`` is at or after
            "now minus ``weeks`` weeks" are returned. ``None`` disables the
            time filter entirely.

    Returns:
        The matching jobs, each converted via ``to_pydantic()``.
    """
    async with db_registry.async_session() as session:
        # Base filter: only jobs still marked as running.
        query = select(LLMBatchJob).where(LLMBatchJob.status == JobStatus.running)

        if actor is not None:
            query = query.where(LLMBatchJob.organization_id == actor.organization_id)

        if weeks is not None:
            # Build the cutoff as a timezone-aware UTC timestamp. The naive
            # ``utcnow()`` is deprecated since Python 3.12, and a naive value
            # carries no offset, so it can be misread as local time when bound
            # against a timezone-aware column.
            # NOTE(review): assumes ``created_at`` is stored timezone-aware (UTC) - confirm.
            cutoff_datetime = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(weeks=weeks)
            query = query.where(LLMBatchJob.created_at >= cutoff_datetime)

        results = await session.execute(query)
        return [batch.to_pydantic() for batch in results.scalars().all()]

View File

@@ -228,6 +228,7 @@ class Settings(BaseSettings):
enable_batch_job_polling: bool = False
poll_running_llm_batches_interval_seconds: int = 5 * 60
poll_lock_retry_interval_seconds: int = 5 * 60
batch_job_polling_lookback_weeks: int = 2
@property
def letta_pg_uri(self) -> str:

View File

@@ -5245,6 +5245,7 @@ async def test_delete_batch_item(
@pytest.mark.asyncio
async def test_list_running_batches(server, default_user, dummy_beta_message_batch, letta_batch_job, event_loop):
# Create a recent running batch
await server.batch_manager.create_llm_batch_job_async(
llm_provider=ProviderType.anthropic,
status=JobStatus.running,
@@ -5253,10 +5254,20 @@ async def test_list_running_batches(server, default_user, dummy_beta_message_bat
letta_batch_job_id=letta_batch_job.id,
)
# Should return at least one running batch (no time filter)
running_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user)
assert len(running_batches) >= 1
assert all(batch.status == JobStatus.running for batch in running_batches)
# Should return the same when filtering by recent 1 week
recent_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user, weeks=1)
assert len(recent_batches) >= 1
assert all(batch.status == JobStatus.running for batch in recent_batches)
# Should return nothing if filtering by a very small timeframe (e.g., 0 weeks)
future_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user, weeks=0)
assert all(batch.created_at >= datetime.utcnow() - timedelta(weeks=0) for batch in future_batches)
@pytest.mark.asyncio
async def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch, letta_batch_job, event_loop):