diff --git a/letta/jobs/llm_batch_job_polling.py b/letta/jobs/llm_batch_job_polling.py
index 401860e8..7651c283 100644
--- a/letta/jobs/llm_batch_job_polling.py
+++ b/letta/jobs/llm_batch_job_polling.py
@@ -11,6 +11,7 @@ from letta.schemas.letta_response import LettaBatchResponse
 from letta.schemas.llm_batch_job import LLMBatchJob
 from letta.schemas.user import User
 from letta.server.server import SyncServer
+from letta.settings import settings
 
 logger = get_logger(__name__)
 
@@ -180,7 +181,10 @@ async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchRespo
 
     try:
         # 1. Retrieve running batch jobs
-        batches = await server.batch_manager.list_running_llm_batches_async()
+        # Clamp the lookback to at least one week: a zero or negative setting would place the
+        # cutoff at or after "now" and silently exclude every running batch.
+        lookback_weeks = max(settings.batch_job_polling_lookback_weeks, 1)
+        batches = await server.batch_manager.list_running_llm_batches_async(weeks=lookback_weeks)
         metrics.total_batches = len(batches)
 
         # TODO: Expand to more providers
diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py
index c296a64e..c9071eba 100644
--- a/letta/services/llm_batch_manager.py
+++ b/letta/services/llm_batch_manager.py
@@ -205,14 +205,31 @@ class LLMBatchManager:
 
     @enforce_types
     @trace_method
-    async def list_running_llm_batches_async(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
-        """Return all running LLM batch jobs, optionally filtered by actor's organization."""
+    async def list_running_llm_batches_async(
+        self, actor: Optional[PydanticUser] = None, weeks: Optional[int] = None
+    ) -> List[PydanticLLMBatchJob]:
+        """Return running LLM batch jobs, optionally filtered by organization and creation time.
+
+        Args:
+            actor: If given, only jobs belonging to this user's organization are returned.
+            weeks: If given, only jobs created within the last `weeks` weeks are returned. A value
+                of 0 or less puts the cutoff at or after the current time.
+
+        Returns:
+            The matching jobs converted with `to_pydantic()`.
+        """
         async with db_registry.async_session() as session:
             query = select(LLMBatchJob).where(LLMBatchJob.status == JobStatus.running)
 
             if actor is not None:
                 query = query.where(LLMBatchJob.organization_id == actor.organization_id)
 
+            if weeks is not None:
+                # Use an aware UTC "now": datetime.utcnow() is deprecated since Python 3.12 and returns a
+                # naive value. NOTE(review): assumes created_at is stored as UTC / timezone-aware -- confirm.
+                cutoff_datetime = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(weeks=weeks)
+                query = query.where(LLMBatchJob.created_at >= cutoff_datetime)
+
             results = await session.execute(query)
             return [batch.to_pydantic() for batch in results.scalars().all()]
 
diff --git a/letta/settings.py b/letta/settings.py
index 06311432..b30c3f35 100644
--- a/letta/settings.py
+++ b/letta/settings.py
@@ -228,6 +228,8 @@ class Settings(BaseSettings):
     enable_batch_job_polling: bool = False
     poll_running_llm_batches_interval_seconds: int = 5 * 60
     poll_lock_retry_interval_seconds: int = 5 * 60
+    # Only running batches created within this many weeks are polled (the poller clamps values below 1 to 1)
+    batch_job_polling_lookback_weeks: int = 2
 
     @property
     def letta_pg_uri(self) -> str:
diff --git a/tests/test_managers.py b/tests/test_managers.py
index e033b994..c767355c 100644
--- a/tests/test_managers.py
+++ b/tests/test_managers.py
@@ -5245,6 +5245,7 @@ async def test_delete_batch_item(
 
 @pytest.mark.asyncio
 async def test_list_running_batches(server, default_user, dummy_beta_message_batch, letta_batch_job, event_loop):
+    # Create a recent running batch
     await server.batch_manager.create_llm_batch_job_async(
         llm_provider=ProviderType.anthropic,
         status=JobStatus.running,
@@ -5253,10 +5254,20 @@ async def test_list_running_batches(server, default_user, dummy_beta_message_bat
         letta_batch_job_id=letta_batch_job.id,
     )
 
+    # Should return at least one running batch (no time filter)
     running_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user)
     assert len(running_batches) >= 1
     assert all(batch.status == JobStatus.running for batch in running_batches)
 
+    # Should return the same when filtering by recent 1 week
+    recent_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user, weeks=1)
+    assert len(recent_batches) >= 1
+    assert all(batch.status == JobStatus.running for batch in recent_batches)
+
+    # A zero-week lookback puts the cutoff at the query time, so the batch created above must be excluded
+    zero_week_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user, weeks=0)
+    assert zero_week_batches == []
+
 
 @pytest.mark.asyncio
 async def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch, letta_batch_job, event_loop):