feat(asyncify): migrate jobs (#2420)

This commit is contained in:
cthomas
2025-05-25 18:58:06 -07:00
committed by GitHub
parent 19efa1a89a
commit c071e079dc
4 changed files with 109 additions and 77 deletions

View File

@@ -12,7 +12,7 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
@router.get("/", response_model=List[Job], operation_id="list_jobs")
def list_jobs(
async def list_jobs(
server: "SyncServer" = Depends(get_letta_server),
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -20,10 +20,10 @@ def list_jobs(
"""
List all jobs.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# TODO: add filtering by status
jobs = server.job_manager.list_jobs(actor=actor)
jobs = await server.job_manager.list_jobs_async(actor=actor)
if source_id:
# can't be in the ORM since we have source_id stored in the metadata
@@ -46,7 +46,7 @@ async def list_active_jobs(
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")
def retrieve_job(
async def retrieve_job(
job_id: str,
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
@@ -54,16 +54,16 @@ def retrieve_job(
"""
Get the status of a job.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
return server.job_manager.get_job_by_id(job_id=job_id, actor=actor)
return await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
except NoResultFound:
raise HTTPException(status_code=404, detail="Job not found")
@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
def delete_job(
async def delete_job(
job_id: str,
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
@@ -71,10 +71,10 @@ def delete_job(
"""
Delete a job by its job_id.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
job = server.job_manager.delete_job_by_id(job_id=job_id, actor=actor)
job = await server.job_manager.delete_job_by_id_async(job_id=job_id, actor=actor)
return job
except NoResultFound:
raise HTTPException(status_code=404, detail="Job not found")

View File

@@ -199,7 +199,7 @@ async def list_run_steps(
@router.delete("/{run_id}", response_model=Run, operation_id="delete_run")
def delete_run(
async def delete_run(
run_id: str,
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
@@ -207,10 +207,10 @@ def delete_run(
"""
Delete a run by its run_id.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
job = server.job_manager.delete_job_by_id(job_id=run_id, actor=actor)
job = await server.job_manager.delete_job_by_id_async(job_id=run_id, actor=actor)
return Run.from_job(job)
except NoResultFound:
raise HTTPException(status_code=404, detail="Run not found")

View File

@@ -197,6 +197,15 @@ class JobManager:
job.hard_delete(db_session=session, actor=actor)
return job.to_pydantic()
@enforce_types
@trace_method
async def delete_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Asynchronously hard-delete a job and return its Pydantic representation.

Async counterpart of the synchronous ``delete_job_by_id``.

Args:
job_id: ID of the job to delete.
actor: User passed to both the access check and the delete call.

Returns:
The deleted job, converted with ``to_pydantic()``.
"""
async with db_registry.async_session() as session:
# Look the job up through the access-check helper before deleting it.
# NOTE(review): the helper is not visible here; the tests in this change expect
# NoResultFound for a non-existent ID, presumably raised by this call - confirm.
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor)
# Hard delete (as opposed to a soft delete); exact semantics live in hard_delete_async.
await job.hard_delete_async(db_session=session, actor=actor)
# The ORM object is converted after the delete call, same order as the sync variant.
return job.to_pydantic()
@enforce_types
@trace_method
def get_job_messages(
@@ -599,25 +608,23 @@ class JobManager:
async def _dispatch_callback_async(self, session, job: JobModel) -> None:
"""
POST a standard JSON payload to job.callback_url
and record timestamp + HTTP status asynchronously.
POST a standard JSON payload to job.callback_url and record timestamp + HTTP status asynchronously.
"""
payload = {
"job_id": job.id,
"status": job.status,
"completed_at": job.completed_at.isoformat(),
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
}
try:
import httpx
async with httpx.AsyncClient() as client:
resp = await client.post(job.callback_url, json=payload, timeout=5.0)
job.callback_sent_at = get_utc_time()
# Ensure timestamp is timezone-naive for DB compatibility
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
job.callback_status_code = resp.status_code
except Exception:
return
session.add(job)
await session.commit()
# Silently fail on callback errors - job updates should still succeed
# In production, this would include proper error logging
pass

View File

@@ -4245,14 +4245,15 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
# ======================================================================================================================
def test_create_job(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_create_job(server: SyncServer, default_user, event_loop):
"""Test creating a job."""
job_data = PydanticJob(
status=JobStatus.created,
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
# Assertions to ensure the created job matches the expected values
assert created_job.user_id == default_user.id
@@ -4260,17 +4261,18 @@ def test_create_job(server: SyncServer, default_user):
assert created_job.metadata == {"type": "test"}
def test_get_job_by_id(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_get_job_by_id(server: SyncServer, default_user, event_loop):
"""Test fetching a job by ID."""
# Create a job
job_data = PydanticJob(
status=JobStatus.created,
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
# Fetch the job by ID
fetched_job = server.job_manager.get_job_by_id(created_job.id, actor=default_user)
fetched_job = await server.job_manager.get_job_by_id_async(created_job.id, actor=default_user)
# Assertions to ensure the fetched job matches the created job
assert fetched_job.id == created_job.id
@@ -4278,7 +4280,8 @@ def test_get_job_by_id(server: SyncServer, default_user):
assert fetched_job.metadata == {"type": "test"}
def test_list_jobs(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_list_jobs(server: SyncServer, default_user, event_loop):
"""Test listing jobs."""
# Create multiple jobs
for i in range(3):
@@ -4286,10 +4289,10 @@ def test_list_jobs(server: SyncServer, default_user):
status=JobStatus.created,
metadata={"type": f"test-{i}"},
)
server.job_manager.create_job(job_data, actor=default_user)
await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
# List jobs
jobs = server.job_manager.list_jobs(actor=default_user)
jobs = await server.job_manager.list_jobs_async(actor=default_user)
# Assertions to check that the created jobs are listed
assert len(jobs) == 3
@@ -4297,19 +4300,20 @@ def test_list_jobs(server: SyncServer, default_user):
assert all(job.metadata["type"].startswith("test") for job in jobs)
def test_update_job_by_id(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_update_job_by_id(server: SyncServer, default_user, event_loop):
"""Test updating a job by its ID."""
# Create a job
job_data = PydanticJob(
status=JobStatus.created,
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
assert created_job.metadata == {"type": "test"}
# Update the job
update_data = JobUpdate(status=JobStatus.completed, metadata={"type": "updated"})
updated_job = server.job_manager.update_job_by_id(created_job.id, update_data, actor=default_user)
updated_job = await server.job_manager.update_job_by_id_async(created_job.id, update_data, actor=default_user)
# Assertions to ensure the job was updated
assert updated_job.status == JobStatus.completed
@@ -4317,56 +4321,61 @@ def test_update_job_by_id(server: SyncServer, default_user):
assert updated_job.completed_at is not None
def test_delete_job_by_id(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_delete_job_by_id(server: SyncServer, default_user, event_loop):
"""Test deleting a job by its ID."""
# Create a job
job_data = PydanticJob(
status=JobStatus.created,
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
# Delete the job
server.job_manager.delete_job_by_id(created_job.id, actor=default_user)
await server.job_manager.delete_job_by_id_async(created_job.id, actor=default_user)
# List jobs to ensure the job was deleted
jobs = server.job_manager.list_jobs(actor=default_user)
jobs = await server.job_manager.list_jobs_async(actor=default_user)
assert len(jobs) == 0
def test_update_job_auto_complete(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_update_job_auto_complete(server: SyncServer, default_user, event_loop):
"""Test that updating a job's status to 'completed' automatically sets completed_at."""
# Create a job
job_data = PydanticJob(
status=JobStatus.created,
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
# Update the job's status to 'completed'
update_data = JobUpdate(status=JobStatus.completed)
updated_job = server.job_manager.update_job_by_id(created_job.id, update_data, actor=default_user)
updated_job = await server.job_manager.update_job_by_id_async(created_job.id, update_data, actor=default_user)
# Assertions to check that completed_at was set
assert updated_job.status == JobStatus.completed
assert updated_job.completed_at is not None
def test_get_job_not_found(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_get_job_not_found(server: SyncServer, default_user, event_loop):
"""Test fetching a non-existent job."""
non_existent_job_id = "nonexistent-id"
with pytest.raises(NoResultFound):
server.job_manager.get_job_by_id(non_existent_job_id, actor=default_user)
await server.job_manager.get_job_by_id_async(non_existent_job_id, actor=default_user)
def test_delete_job_not_found(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_delete_job_not_found(server: SyncServer, default_user, event_loop):
"""Test deleting a non-existent job."""
non_existent_job_id = "nonexistent-id"
with pytest.raises(NoResultFound):
server.job_manager.delete_job_by_id(non_existent_job_id, actor=default_user)
await server.job_manager.delete_job_by_id_async(non_existent_job_id, actor=default_user)
def test_list_jobs_pagination(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_list_jobs_pagination(server: SyncServer, default_user, event_loop):
"""Test listing jobs with pagination."""
# Create multiple jobs
for i in range(10):
@@ -4374,19 +4383,19 @@ def test_list_jobs_pagination(server: SyncServer, default_user):
status=JobStatus.created,
metadata={"type": f"test-{i}"},
)
server.job_manager.create_job(job_data, actor=default_user)
await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
# List jobs with a limit
jobs = server.job_manager.list_jobs(actor=default_user, limit=5)
jobs = await server.job_manager.list_jobs_async(actor=default_user, limit=5)
assert len(jobs) == 5
assert all(job.user_id == default_user.id for job in jobs)
# Test cursor-based pagination
first_page = server.job_manager.list_jobs(actor=default_user, limit=3, ascending=True) # [J0, J1, J2]
first_page = await server.job_manager.list_jobs_async(actor=default_user, limit=3, ascending=True) # [J0, J1, J2]
assert len(first_page) == 3
assert first_page[0].created_at <= first_page[1].created_at <= first_page[2].created_at
last_page = server.job_manager.list_jobs(actor=default_user, limit=3, ascending=False) # [J9, J8, J7]
last_page = await server.job_manager.list_jobs_async(actor=default_user, limit=3, ascending=False) # [J9, J8, J7]
assert len(last_page) == 3
assert last_page[0].created_at >= last_page[1].created_at >= last_page[2].created_at
first_page_ids = set(job.id for job in first_page)
@@ -4394,7 +4403,7 @@ def test_list_jobs_pagination(server: SyncServer, default_user):
assert first_page_ids.isdisjoint(last_page_ids)
# Test middle page using both before and after
middle_page = server.job_manager.list_jobs(
middle_page = await server.job_manager.list_jobs_async(
actor=default_user, before=last_page[-1].id, after=first_page[-1].id, ascending=True
) # [J3, J4, J5, J6]
assert len(middle_page) == 4 # Should include jobs between first and second page
@@ -4402,7 +4411,7 @@ def test_list_jobs_pagination(server: SyncServer, default_user):
assert all(job.id not in head_tail_jobs for job in middle_page)
# Test descending order
middle_page_desc = server.job_manager.list_jobs(
middle_page_desc = await server.job_manager.list_jobs_async(
actor=default_user, before=last_page[-1].id, after=first_page[-1].id, ascending=False
) # [J6, J5, J4, J3]
assert len(middle_page_desc) == 4
@@ -4413,13 +4422,14 @@ def test_list_jobs_pagination(server: SyncServer, default_user):
# BONUS
job_7 = last_page[-1].id
earliest_jobs = server.job_manager.list_jobs(actor=default_user, ascending=False, before=job_7)
earliest_jobs = await server.job_manager.list_jobs_async(actor=default_user, ascending=False, before=job_7)
assert len(earliest_jobs) == 7
assert all(j.id not in last_page_ids for j in earliest_jobs)
assert all(earliest_jobs[i].created_at >= earliest_jobs[i + 1].created_at for i in range(len(earliest_jobs) - 1))
def test_list_jobs_by_status(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_list_jobs_by_status(server: SyncServer, default_user, event_loop):
"""Test listing jobs filtered by status."""
# Create multiple jobs with different statuses
job_data_created = PydanticJob(
@@ -4435,14 +4445,14 @@ def test_list_jobs_by_status(server: SyncServer, default_user):
metadata={"type": "test-completed"},
)
server.job_manager.create_job(job_data_created, actor=default_user)
server.job_manager.create_job(job_data_in_progress, actor=default_user)
server.job_manager.create_job(job_data_completed, actor=default_user)
await server.job_manager.create_job_async(pydantic_job=job_data_created, actor=default_user)
await server.job_manager.create_job_async(pydantic_job=job_data_in_progress, actor=default_user)
await server.job_manager.create_job_async(pydantic_job=job_data_completed, actor=default_user)
# List jobs filtered by status
created_jobs = server.job_manager.list_jobs(actor=default_user, statuses=[JobStatus.created])
in_progress_jobs = server.job_manager.list_jobs(actor=default_user, statuses=[JobStatus.running])
completed_jobs = server.job_manager.list_jobs(actor=default_user, statuses=[JobStatus.completed])
created_jobs = await server.job_manager.list_jobs_async(actor=default_user, statuses=[JobStatus.created])
in_progress_jobs = await server.job_manager.list_jobs_async(actor=default_user, statuses=[JobStatus.running])
completed_jobs = await server.job_manager.list_jobs_async(actor=default_user, statuses=[JobStatus.completed])
# Assertions
assert len(created_jobs) == 1
@@ -4455,7 +4465,8 @@ def test_list_jobs_by_status(server: SyncServer, default_user):
assert completed_jobs[0].metadata["type"] == job_data_completed.metadata["type"]
def test_list_jobs_filter_by_type(server: SyncServer, default_user, default_job):
@pytest.mark.asyncio
async def test_list_jobs_filter_by_type(server: SyncServer, default_user, default_job, event_loop):
"""Test that list_jobs correctly filters by job_type."""
# Create a run job
run_pydantic = PydanticJob(
@@ -4463,48 +4474,62 @@ def test_list_jobs_filter_by_type(server: SyncServer, default_user, default_job)
status=JobStatus.pending,
job_type=JobType.RUN,
)
run = server.job_manager.create_job(pydantic_job=run_pydantic, actor=default_user)
run = await server.job_manager.create_job_async(pydantic_job=run_pydantic, actor=default_user)
# List only regular jobs
jobs = server.job_manager.list_jobs(actor=default_user)
jobs = await server.job_manager.list_jobs_async(actor=default_user)
assert len(jobs) == 1
assert jobs[0].id == default_job.id
# List only run jobs
jobs = server.job_manager.list_jobs(actor=default_user, job_type=JobType.RUN)
jobs = await server.job_manager.list_jobs_async(actor=default_user, job_type=JobType.RUN)
assert len(jobs) == 1
assert jobs[0].id == run.id
def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user):
"""Test that job callbacks are properly dispatched when a job is completed."""
captured = {}
def fake_post(url, json, timeout):
# Create a simple mock for the async HTTP client
class MockAsyncResponse:
status_code = 202
async def mock_post(url, json, timeout):
captured["url"] = url
captured["json"] = json
return MockAsyncResponse()
class FakeResponse:
status_code = 202
class MockAsyncClient:
async def __aenter__(self):
return self
return FakeResponse()
async def __aexit__(self, *args):
pass
monkeypatch.setattr(httpx, "post", fake_post)
async def post(self, url, json, timeout):
return await mock_post(url, json, timeout)
# Patch the AsyncClient
monkeypatch.setattr(httpx, "AsyncClient", MockAsyncClient)
job_in = PydanticJob(status=JobStatus.created, metadata={"foo": "bar"}, callback_url="http://example.test/webhook/jobs")
created = server.job_manager.create_job(job_in, actor=default_user)
created = await server.job_manager.create_job_async(pydantic_job=job_in, actor=default_user)
assert created.callback_url == "http://example.test/webhook/jobs"
# Update the job status to completed, which should trigger the callback
update = JobUpdate(status=JobStatus.completed)
updated = server.job_manager.update_job_by_id(created.id, update, actor=default_user)
updated = await server.job_manager.update_job_by_id_async(created.id, update, actor=default_user)
assert captured["url"] == created.callback_url
assert captured["json"]["job_id"] == created.id
assert captured["json"]["status"] == JobStatus.completed.value
# Verify the callback was triggered with the correct parameters
assert captured["url"] == created.callback_url, "Callback URL doesn't match"
assert captured["json"]["job_id"] == created.id, "Job ID in callback doesn't match"
assert captured["json"]["status"] == JobStatus.completed.value, "Job status in callback doesn't match"
# Normalize the received completed_at to compare properly
# Verify the completed_at timestamp is reasonable
actual_dt = datetime.fromisoformat(captured["json"]["completed_at"]).replace(tzinfo=None)
expected_dt = updated.completed_at.replace(tzinfo=None)
assert actual_dt == expected_dt
assert abs((actual_dt - updated.completed_at).total_seconds()) < 1, "Timestamp difference is too large"
assert isinstance(updated.callback_sent_at, datetime)
assert updated.callback_status_code == 202