diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py
index 047d8f5e..5424edda 100644
--- a/letta/server/rest_api/routers/v1/messages.py
+++ b/letta/server/rest_api/routers/v1/messages.py
@@ -7,7 +7,7 @@ from starlette.requests import Request
 from letta.agents.letta_agent_batch import LettaAgentBatch
 from letta.log import get_logger
 from letta.orm.errors import NoResultFound
-from letta.schemas.job import BatchJob, JobStatus, JobType
+from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
 from letta.schemas.letta_request import CreateBatch
 from letta.server.rest_api.utils import get_letta_server
 from letta.server.server import SyncServer
@@ -43,18 +43,17 @@ async def create_messages_batch(
     if length > max_bytes:
         raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
 
-    try:
-        actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    batch_job = BatchJob(
+        user_id=actor.id,
+        status=JobStatus.running,
+        metadata={
+            "job_type": "batch_messages",
+        },
+        callback_url=str(payload.callback_url),
+    )
 
-        # Create a new job
-        batch_job = BatchJob(
-            user_id=actor.id,
-            status=JobStatus.created,
-            metadata={
-                "job_type": "batch_messages",
-            },
-            callback_url=str(payload.callback_url),
-        )
+    try:
         batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
 
         # create the batch runner
@@ -68,7 +67,7 @@ async def create_messages_batch(
             job_manager=server.job_manager,
             actor=actor,
         )
-        llm_batch_job = await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
+        await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
 
         # TODO: update run metadata
     except Exception as e:
@@ -78,7 +77,7 @@ async def create_messages_batch(
         traceback.print_exc()
 
         # mark job as failed
-        server.job_manager.update_job_by_id(job_id=batch_job.id, job=BatchJob(status=JobStatus.failed))
+        server.job_manager.update_job_by_id(job_id=batch_job.id, job=BatchJob(status=JobStatus.failed), actor=actor)
         raise
 
     return batch_job
@@ -129,8 +128,19 @@ async def cancel_batch_run(
     try:
         job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
-        job.status = JobStatus.cancelled
-        server.job_manager.update_job_by_id(job_id=job, job=job)
-        # TODO: actually cancel it
+        job = server.job_manager.update_job_by_id(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
+
+        # Get the related llm batch jobs
+        llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor)
+        for llm_batch_job in llm_batch_jobs:
+            if llm_batch_job.status in {JobStatus.running, JobStatus.created}:
+                # TODO: Extend to providers beyond Anthropic;
+                # for now, Anthropic is the only supported provider.
+                # Cancel the provider-side batch
+                anthropic_batch_id = llm_batch_job.create_batch_response.id
+                await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id)
+
+                # Mark the llm batch job as cancelled
+                server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor)
     except NoResultFound:
         raise HTTPException(status_code=404, detail="Run not found")
 
diff --git a/tests/integration_test_batch.py b/tests/integration_test_batch.py
index d89a3a02..d36c5422 100644
--- a/tests/integration_test_batch.py
+++ b/tests/integration_test_batch.py
@@ -7,9 +7,23 @@ from dotenv import load_dotenv
 from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
 
 from letta.config import LettaConfig
+from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
+from letta.orm import Base
+from letta.schemas.enums import JobStatus
+from letta.server.db import db_context
 from letta.server.server import SyncServer
 
 
+@pytest.fixture(autouse=True)
+def clear_batch_tables():
+    """Clear batch-related tables before each test."""
+    with db_context() as session:
+        for table in reversed(Base.metadata.sorted_tables):
+            if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
+                session.execute(table.delete())  # Delete all rows in the table
+        session.commit()
+
+
 def run_server():
     """Starts the Letta server in a background thread."""
     load_dotenv()
@@ -54,7 +68,8 @@ def client(server_url):
     return Letta(base_url=server_url)
 
 
-def test_create_batch(client: Letta):
+@pytest.mark.asyncio
+async def test_create_batch(client: Letta, server: SyncServer):
 
     # create agents
     agent1 = client.agents.create(
@@ -105,11 +120,21 @@
 
     # list batches
     batches = client.batches.list()
-    assert len(batches) > 0, f"Expected 1 batch, got {len(batches)}"
+    assert len(batches) == 1, f"Expected 1 batch, got {len(batches)}"
+    assert batches[0].status == JobStatus.running
+
+    # Poll the running llm batches once
+    await poll_running_llm_batches(server)
 
     # get the batch results
     results = client.batches.retrieve(
         batch_id=run.id,
     )
     assert results is not None
+
+    # cancel the batch and verify its status is updated
+    client.batches.cancel(batch_id=run.id)
+    batch_job = client.batches.retrieve(
+        batch_id=run.id,
+    )
+    assert batch_job.status == JobStatus.cancelled