feat: Add cancel functionality to batch API (#1825)
This commit is contained in:
@@ -7,7 +7,7 @@ from starlette.requests import Request
|
||||
from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.job import BatchJob, JobStatus, JobType
|
||||
from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
|
||||
from letta.schemas.letta_request import CreateBatch
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
@@ -43,18 +43,17 @@ async def create_messages_batch(
|
||||
if length > max_bytes:
|
||||
raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
|
||||
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
batch_job = BatchJob(
|
||||
user_id=actor.id,
|
||||
status=JobStatus.running,
|
||||
metadata={
|
||||
"job_type": "batch_messages",
|
||||
},
|
||||
callback_url=str(payload.callback_url),
|
||||
)
|
||||
|
||||
# Create a new job
|
||||
batch_job = BatchJob(
|
||||
user_id=actor.id,
|
||||
status=JobStatus.created,
|
||||
metadata={
|
||||
"job_type": "batch_messages",
|
||||
},
|
||||
callback_url=str(payload.callback_url),
|
||||
)
|
||||
try:
|
||||
batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
|
||||
|
||||
# create the batch runner
|
||||
@@ -68,7 +67,7 @@ async def create_messages_batch(
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
)
|
||||
llm_batch_job = await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
|
||||
await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
|
||||
|
||||
# TODO: update run metadata
|
||||
except Exception as e:
|
||||
@@ -78,7 +77,7 @@ async def create_messages_batch(
|
||||
traceback.print_exc()
|
||||
|
||||
# mark job as failed
|
||||
server.job_manager.update_job_by_id(job_id=batch_job.id, job=BatchJob(status=JobStatus.failed))
|
||||
server.job_manager.update_job_by_id(job_id=batch_job.id, job=BatchJob(status=JobStatus.failed), actor=actor)
|
||||
raise
|
||||
return batch_job
|
||||
|
||||
@@ -129,8 +128,19 @@ async def cancel_batch_run(
|
||||
|
||||
try:
|
||||
job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
|
||||
job.status = JobStatus.cancelled
|
||||
server.job_manager.update_job_by_id(job_id=job, job=job)
|
||||
# TODO: actually cancel it
|
||||
job = server.job_manager.update_job_by_id(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
|
||||
|
||||
# Get related llm batch jobs
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor)
|
||||
for llm_batch_job in llm_batch_jobs:
|
||||
if llm_batch_job.status in {JobStatus.running, JobStatus.created}:
|
||||
# TODO: Extend to providers beyond anthropic
|
||||
# TODO: For now, we only support anthropic
|
||||
# Cancel the job
|
||||
anthropic_batch_id = llm_batch_job.create_batch_response.id
|
||||
await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id)
|
||||
|
||||
# Update all the batch_job statuses
|
||||
server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
@@ -7,9 +7,23 @@ from dotenv import load_dotenv
|
||||
from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.server.db import db_context
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_batch_tables():
|
||||
"""Clear batch-related tables before each test."""
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables):
|
||||
if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
|
||||
|
||||
def run_server():
|
||||
"""Starts the Letta server in a background thread."""
|
||||
load_dotenv()
|
||||
@@ -54,7 +68,8 @@ def client(server_url):
|
||||
return Letta(base_url=server_url)
|
||||
|
||||
|
||||
def test_create_batch(client: Letta):
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch(client: Letta, server: SyncServer):
|
||||
|
||||
# create agents
|
||||
agent1 = client.agents.create(
|
||||
@@ -105,11 +120,21 @@ def test_create_batch(client: Letta):
|
||||
|
||||
# list batches
|
||||
batches = client.batches.list()
|
||||
assert len(batches) > 0, f"Expected 1 batch, got {len(batches)}"
|
||||
assert len(batches) == 1, f"Expected 1 batch, got {len(batches)}"
|
||||
assert batches[0].status == JobStatus.running
|
||||
|
||||
# Poll it once
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# get the batch results
|
||||
results = client.batches.retrieve(
|
||||
batch_id=run.id,
|
||||
)
|
||||
assert results is not None
|
||||
print(results)
|
||||
|
||||
# cancel
|
||||
client.batches.cancel(batch_id=run.id)
|
||||
batch_job = client.batches.retrieve(
|
||||
batch_id=run.id,
|
||||
)
|
||||
assert batch_job.status == JobStatus.cancelled
|
||||
|
||||
Reference in New Issue
Block a user