feat: add batch job tracking and generate batch APIs (#1727)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
from letta.schemas.agent import AgentState, AgentStepState
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
||||
from letta.schemas.job import BatchJob
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_request import LettaBatchRequest
|
||||
from letta.schemas.message import MessageCreate
|
||||
@@ -256,7 +257,7 @@ def clear_batch_tables():
|
||||
"""Clear batch-related tables before each test."""
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables):
|
||||
if table.name in {"llm_batch_job", "llm_batch_items"}:
|
||||
if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
|
||||
@@ -305,6 +306,22 @@ def client(server_url):
|
||||
return Letta(base_url=server_url)
|
||||
|
||||
|
||||
@pytest.fixture
def batch_job(default_user, server):
    """Provide a persisted batch job owned by ``default_user``.

    A ``BatchJob`` in the ``created`` status, tagged with the
    ``batch_messages`` job type, is stored through the server's job manager
    and handed to the test. After the test has finished, the stored job is
    deleted again.
    """
    job_metadata = {"job_type": "batch_messages"}
    pending_job = BatchJob(
        user_id=default_user.id,
        status=JobStatus.created,
        metadata=job_metadata,
    )
    persisted_job = server.job_manager.create_job(pydantic_job=pending_job, actor=default_user)

    yield persisted_job

    # Teardown: remove the job that was created for this test.
    server.job_manager.delete_job_by_id(persisted_job.id, actor=default_user)
|
||||
|
||||
|
||||
class MockAsyncIterable:
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
@@ -324,8 +341,8 @@ class MockAsyncIterable:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_step_after_request_happy_path(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map
|
||||
async def test_resume_step_after_request_all_continue(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
anthropic_batch_id = "msgbatch_test_12345"
|
||||
dummy_batch_response = create_batch_response(
|
||||
@@ -342,6 +359,7 @@ async def test_resume_step_after_request_happy_path(
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
@@ -349,15 +367,20 @@ async def test_resume_step_after_request_happy_path(
|
||||
pre_resume_response = await batch_runner.step_until_request(
|
||||
batch_requests=batch_requests,
|
||||
agent_step_state_mapping=step_state_map,
|
||||
letta_batch_job_id=batch_job.id,
|
||||
)
|
||||
|
||||
# Basic sanity checks (this is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly`)
|
||||
# Verify batch items
|
||||
items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user)
|
||||
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
|
||||
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
|
||||
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}"
|
||||
|
||||
# 2. Invoke the polling job and mock responses from Anthropic
|
||||
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.batch_id, processing_status="ended"))
|
||||
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
|
||||
mock_items = [
|
||||
@@ -372,13 +395,13 @@ async def test_resume_step_after_request_happy_path(
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# Verify database records were updated correctly
|
||||
job = server.batch_manager.get_batch_job_by_id(pre_resume_response.batch_id, actor=default_user)
|
||||
llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)
|
||||
|
||||
# Verify job properties
|
||||
assert job.status == JobStatus.completed, "Job status should be 'completed'"
|
||||
assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"
|
||||
|
||||
# Verify batch items
|
||||
items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user)
|
||||
items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
|
||||
assert all([item.request_status == JobStatus.completed for item in items])
|
||||
|
||||
@@ -390,22 +413,27 @@ async def test_resume_step_after_request_happy_path(
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
|
||||
|
||||
post_resume_response = await letta_batch_agent.resume_step_after_request(batch_id=pre_resume_response.batch_id)
|
||||
post_resume_response = await letta_batch_agent.resume_step_after_request(
|
||||
letta_batch_id=pre_resume_response.letta_batch_id, llm_batch_id=llm_batch_job.id
|
||||
)
|
||||
|
||||
# A *new* batch job should have been spawned
|
||||
assert (
|
||||
post_resume_response.batch_id != pre_resume_response.batch_id
|
||||
), "resume_step_after_request is expected to enqueue a follow‑up batch job."
|
||||
post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
|
||||
), "resume_step_after_request is expected to have the same letta_batch_id"
|
||||
assert (
|
||||
post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
|
||||
), "resume_step_after_request is expected to have different llm_batch_id."
|
||||
assert post_resume_response.status == JobStatus.running
|
||||
assert post_resume_response.agent_count == 3
|
||||
|
||||
# New batch‑items should exist, initialised in (created, paused) state
|
||||
new_items = server.batch_manager.list_batch_items(batch_id=post_resume_response.batch_id, actor=default_user)
|
||||
new_items = server.batch_manager.list_llm_batch_items(llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user)
|
||||
assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}"
|
||||
assert {i.request_status for i in new_items} == {JobStatus.created}
|
||||
assert {i.step_status for i in new_items} == {AgentStepStatus.paused}
|
||||
@@ -420,7 +448,7 @@ async def test_resume_step_after_request_happy_path(
|
||||
|
||||
# Old items must have been flipped to completed / finished earlier
|
||||
# (sanity – we already asserted this above, but we keep it close for clarity)
|
||||
old_items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user)
|
||||
old_items = server.batch_manager.list_llm_batch_items(llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user)
|
||||
assert {i.request_status for i in old_items} == {JobStatus.completed}
|
||||
assert {i.step_status for i in old_items} == {AgentStepStatus.completed}
|
||||
|
||||
@@ -440,7 +468,7 @@ async def test_resume_step_after_request_happy_path(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response
|
||||
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
|
||||
):
|
||||
"""
|
||||
Test that step_until_request correctly:
|
||||
@@ -512,6 +540,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
@@ -519,23 +548,25 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
response = await batch_runner.step_until_request(
|
||||
batch_requests=batch_requests,
|
||||
agent_step_state_mapping=step_state_map,
|
||||
letta_batch_job_id=batch_job.id,
|
||||
)
|
||||
|
||||
# Verify the mock was called exactly once
|
||||
mock_send.assert_called_once()
|
||||
|
||||
# Verify database records were created correctly
|
||||
job = server.batch_manager.get_batch_job_by_id(response.batch_id, actor=default_user)
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=response.letta_batch_id, actor=default_user)
|
||||
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
|
||||
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}"
|
||||
|
||||
# Verify job properties
|
||||
assert job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic"
|
||||
assert job.status == JobStatus.running, "Job status should be 'running'"
|
||||
|
||||
# Verify batch items
|
||||
items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user)
|
||||
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
|
||||
assert llm_batch_job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic"
|
||||
assert llm_batch_job.status == JobStatus.running, "Job status should be 'running'"
|
||||
|
||||
# Verify all agents are represented in batch items
|
||||
agent_ids_in_items = {item.agent_id for item in items}
|
||||
agent_ids_in_items = {item.agent_id for item in llm_batch_items}
|
||||
expected_agent_ids = {agent.id for agent in agents}
|
||||
assert agent_ids_in_items == expected_agent_ids, f"Expected agent IDs {expected_agent_ids}, got {agent_ids_in_items}"
|
||||
|
||||
Reference in New Issue
Block a user