feat: add batch job tracking and generate batch APIs (#1727)

Sarah Wooders
2025-04-17 17:02:07 -07:00
committed by GitHub
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>

parent ec623325da
commit da62cc6898
22 changed files with 690 additions and 262 deletions
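
Before the diff: a minimal sketch of the two-level batch model these test changes exercise, assembled from the calls visible below. This is illustrative only, not the exact API surface; `server`, `batch_runner`, `batch_requests`, `step_state_map`, and `user` stand in for the test fixtures.

    # Sketch only: names other than the imports are fixture stand-ins.
    from letta.schemas.enums import JobStatus
    from letta.schemas.job import BatchJob

    async def submit_round(server, batch_runner, batch_requests, step_state_map, user):
        # 1. The top-level Letta batch job lives in the `jobs` table and
        #    keeps a stable id across submission/resume rounds.
        job = BatchJob(user_id=user.id, status=JobStatus.created, metadata={"job_type": "batch_messages"})
        job = server.job_manager.create_job(pydantic_job=job, actor=user)

        # 2. Each round of step_until_request spawns one llm_batch_job
        #    (plus one llm_batch_item per agent) under that parent id.
        response = await batch_runner.step_until_request(
            batch_requests=batch_requests,
            agent_step_state_mapping=step_state_map,
            letta_batch_job_id=job.id,
        )
        assert response.letta_batch_id == job.id  # parent id is stable
        return response.last_llm_batch_id  # changes on every resume round

The invariant the assertions below keep checking: `letta_batch_id` never changes across resumes, while each resume produces a fresh `last_llm_batch_id`.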


@@ -33,6 +33,7 @@ from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
 from letta.orm import Base
 from letta.schemas.agent import AgentState, AgentStepState
 from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
+from letta.schemas.job import BatchJob
 from letta.schemas.letta_message_content import TextContent
 from letta.schemas.letta_request import LettaBatchRequest
 from letta.schemas.message import MessageCreate
@@ -256,7 +257,7 @@ def clear_batch_tables():
     """Clear batch-related tables before each test."""
     with db_context() as session:
         for table in reversed(Base.metadata.sorted_tables):
-            if table.name in {"llm_batch_job", "llm_batch_items"}:
+            if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
                 session.execute(table.delete())  # Truncate table
         session.commit()
@@ -305,6 +306,22 @@ def client(server_url):
     return Letta(base_url=server_url)


+@pytest.fixture
+def batch_job(default_user, server):
+    job = BatchJob(
+        user_id=default_user.id,
+        status=JobStatus.created,
+        metadata={
+            "job_type": "batch_messages",
+        },
+    )
+    job = server.job_manager.create_job(pydantic_job=job, actor=default_user)
+    yield job
+
+    # cleanup
+    server.job_manager.delete_job_by_id(job.id, actor=default_user)
+
+
 class MockAsyncIterable:
     def __init__(self, items):
         self.items = items
@@ -324,8 +341,8 @@ class MockAsyncIterable:
 @pytest.mark.asyncio
-async def test_resume_step_after_request_happy_path(
-    disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map
+async def test_resume_step_after_request_all_continue(
+    disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
 ):
     anthropic_batch_id = "msgbatch_test_12345"
     dummy_batch_response = create_batch_response(
@@ -342,6 +359,7 @@ async def test_resume_step_after_request_happy_path(
         passage_manager=server.passage_manager,
         batch_manager=server.batch_manager,
         sandbox_config_manager=server.sandbox_config_manager,
+        job_manager=server.job_manager,
         actor=default_user,
     )
@@ -349,15 +367,20 @@ async def test_resume_step_after_request_happy_path(
     pre_resume_response = await batch_runner.step_until_request(
         batch_requests=batch_requests,
         agent_step_state_mapping=step_state_map,
+        letta_batch_job_id=batch_job.id,
     )

     # Basic sanity checks (This is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly`)
-    # Verify batch items
-    items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user)
-    assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
+    llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
+    assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
+    llm_batch_job = llm_batch_jobs[0]
+
+    llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
+    assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}"

     # 2. Invoke the polling job and mock responses from Anthropic
-    mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.batch_id, processing_status="ended"))
+    mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))

     with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
         mock_items = [
@@ -372,13 +395,13 @@ async def test_resume_step_after_request_happy_path(
         await poll_running_llm_batches(server)

     # Verify database records were updated correctly
-    job = server.batch_manager.get_batch_job_by_id(pre_resume_response.batch_id, actor=default_user)
+    llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)

     # Verify job properties
-    assert job.status == JobStatus.completed, "Job status should be 'completed'"
+    assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"

     # Verify batch items
-    items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user)
+    items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
     assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
     assert all([item.request_status == JobStatus.completed for item in items])
@@ -390,22 +413,27 @@ async def test_resume_step_after_request_happy_path(
         passage_manager=server.passage_manager,
         batch_manager=server.batch_manager,
         sandbox_config_manager=server.sandbox_config_manager,
+        job_manager=server.job_manager,
         actor=default_user,
     )

     with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
         msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
-        post_resume_response = await letta_batch_agent.resume_step_after_request(batch_id=pre_resume_response.batch_id)
+        post_resume_response = await letta_batch_agent.resume_step_after_request(
+            letta_batch_id=pre_resume_response.letta_batch_id, llm_batch_id=llm_batch_job.id
+        )

     # A *new* batch job should have been spawned
     assert (
-        post_resume_response.batch_id != pre_resume_response.batch_id
-    ), "resume_step_after_request is expected to enqueue a followup batch job."
+        post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
+    ), "resume_step_after_request is expected to have the same letta_batch_id"
+    assert (
+        post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
+    ), "resume_step_after_request is expected to have a different llm_batch_id."
     assert post_resume_response.status == JobStatus.running
     assert post_resume_response.agent_count == 3

     # New batch items should exist, initialised in (created, paused) state
-    new_items = server.batch_manager.list_batch_items(batch_id=post_resume_response.batch_id, actor=default_user)
+    new_items = server.batch_manager.list_llm_batch_items(llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user)
     assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}"
     assert {i.request_status for i in new_items} == {JobStatus.created}
     assert {i.step_status for i in new_items} == {AgentStepStatus.paused}
@@ -420,7 +448,7 @@ async def test_resume_step_after_request_happy_path(
     # Old items must have been flipped to completed / finished earlier
     # (sanity: we already asserted this above, but we keep it close for clarity)
-    old_items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user)
+    old_items = server.batch_manager.list_llm_batch_items(llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user)
     assert {i.request_status for i in old_items} == {JobStatus.completed}
     assert {i.step_status for i in old_items} == {AgentStepStatus.completed}
@@ -440,7 +468,7 @@ async def test_resume_step_after_request_happy_path(
 @pytest.mark.asyncio
 async def test_step_until_request_prepares_and_submits_batch_correctly(
-    disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response
+    disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
 ):
     """
     Test that step_until_request correctly:
@@ -512,6 +540,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
         passage_manager=server.passage_manager,
         batch_manager=server.batch_manager,
         sandbox_config_manager=server.sandbox_config_manager,
+        job_manager=server.job_manager,
         actor=default_user,
     )
@@ -519,23 +548,25 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
     response = await batch_runner.step_until_request(
         batch_requests=batch_requests,
         agent_step_state_mapping=step_state_map,
+        letta_batch_job_id=batch_job.id,
     )

     # Verify the mock was called exactly once
     mock_send.assert_called_once()

     # Verify database records were created correctly
-    job = server.batch_manager.get_batch_job_by_id(response.batch_id, actor=default_user)
+    llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=response.letta_batch_id, actor=default_user)
+    assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
+    llm_batch_job = llm_batch_jobs[0]
+
+    llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
+    assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}"

     # Verify job properties
-    assert job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic"
-    assert job.status == JobStatus.running, "Job status should be 'running'"
-
-    # Verify batch items
-    items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user)
-    assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
+    assert llm_batch_job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic"
+    assert llm_batch_job.status == JobStatus.running, "Job status should be 'running'"

     # Verify all agents are represented in batch items
-    agent_ids_in_items = {item.agent_id for item in items}
+    agent_ids_in_items = {item.agent_id for item in llm_batch_items}
     expected_agent_ids = {agent.id for agent in agents}
     assert agent_ids_in_items == expected_agent_ids, f"Expected agent IDs {expected_agent_ids}, got {agent_ids_in_items}"
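
For reference, the polling step these tests drive: a minimal sketch of the mock pattern used above, assuming the test module's `create_batch_response` helper and a `server` wired up as in the fixtures. Everything else is standard unittest.mock plus the poller imported at the top of the file.

    # Sketch only: create_batch_response is the test module's helper.
    from unittest.mock import AsyncMock, patch

    from letta.jobs.llm_batch_job_polling import poll_running_llm_batches

    async def poll_once_with_mock(server, batch_id):
        # Pretend Anthropic reports the batch as finished ("ended") so the
        # poller flips the llm_batch_job and its items to completed.
        mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=batch_id, processing_status="ended"))
        with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
            await poll_running_llm_batches(server)

Patching `retrieve` on the shared async client means the poller's own batching logic runs unmodified; only the network round-trip to Anthropic is faked.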