feat: add batch job tracking and generate batch APIs (#1727)

Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
Sarah Wooders
2025-04-17 17:02:07 -07:00
committed by GitHub
parent ec623325da
commit da62cc6898
22 changed files with 690 additions and 262 deletions

View File

@@ -0,0 +1,41 @@
"""Rename batch_id to llm_batch_id on llm_batch_item
Revision ID: 7b189006c97d
Revises: f2f78d62005c
Create Date: 2025-04-17 16:04:52.045672
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "7b189006c97d"
down_revision: Union[str, None] = "f2f78d62005c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Replace llm_batch_items.batch_id with llm_batch_id (index + FK follow the rename)."""
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): adding a NOT NULL column with no server_default fails on a
    # non-empty table, and drop_column discards existing batch_id values — if
    # this must run against live data, backfill llm_batch_id from batch_id
    # (or use op.alter_column(new_column_name=...)) instead. TODO confirm.
    op.add_column("llm_batch_items", sa.Column("llm_batch_id", sa.String(), nullable=False))
    op.drop_index("ix_llm_batch_items_batch_id", table_name="llm_batch_items")
    op.create_index("ix_llm_batch_items_llm_batch_id", "llm_batch_items", ["llm_batch_id"], unique=False)
    op.drop_constraint("llm_batch_items_batch_id_fkey", "llm_batch_items", type_="foreignkey")
    # Name the FK explicitly (matches the Postgres default naming convention) so
    # downgrade() can drop it by name instead of passing None.
    op.create_foreign_key(
        "llm_batch_items_llm_batch_id_fkey",
        "llm_batch_items",
        "llm_batch_job",
        ["llm_batch_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.drop_column("llm_batch_items", "batch_id")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Restore llm_batch_items.batch_id (reverses upgrade; llm_batch_id values are discarded)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("llm_batch_items", sa.Column("batch_id", sa.VARCHAR(), autoincrement=False, nullable=False))
    # Bug fix: op.drop_constraint(None, ...) raises — a constraint name is
    # required. This is the Postgres default name for the FK that upgrade()
    # created on llm_batch_id.
    op.drop_constraint("llm_batch_items_llm_batch_id_fkey", "llm_batch_items", type_="foreignkey")
    op.create_foreign_key("llm_batch_items_batch_id_fkey", "llm_batch_items", "llm_batch_job", ["batch_id"], ["id"], ondelete="CASCADE")
    op.drop_index("ix_llm_batch_items_llm_batch_id", table_name="llm_batch_items")
    op.create_index("ix_llm_batch_items_batch_id", "llm_batch_items", ["batch_id"], unique=False)
    op.drop_column("llm_batch_items", "llm_batch_id")
    # ### end Alembic commands ###

View File

@@ -0,0 +1,33 @@
"""Add letta batch job id to llm_batch_job
Revision ID: f2f78d62005c
Revises: c3b1da3d1157
Create Date: 2025-04-17 15:58:43.705483
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "f2f78d62005c"
down_revision: Union[str, None] = "c3b1da3d1157"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add llm_batch_job.letta_batch_job_id with an FK to jobs.id."""
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): a NOT NULL column with no server_default fails if
    # llm_batch_job already has rows — backfill first if needed. TODO confirm.
    op.add_column("llm_batch_job", sa.Column("letta_batch_job_id", sa.String(), nullable=False))
    # Name the FK explicitly (Postgres default convention) so downgrade() can
    # drop it by name instead of passing None.
    op.create_foreign_key(
        "llm_batch_job_letta_batch_job_id_fkey",
        "llm_batch_job",
        "jobs",
        ["letta_batch_job_id"],
        ["id"],
        ondelete="CASCADE",
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Remove llm_batch_job.letta_batch_job_id and its FK to jobs.id."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Bug fix: op.drop_constraint(None, ...) raises — drop the FK by its
    # (Postgres-default) name as created in upgrade().
    op.drop_constraint("llm_batch_job_letta_batch_job_id_fkey", "llm_batch_job", type_="foreignkey")
    op.drop_column("llm_batch_job", "letta_batch_job_id")
    # ### end Alembic commands ###

View File

@@ -17,6 +17,7 @@ from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
from letta.schemas.job import JobUpdate
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.letta_response import LettaBatchResponse
@@ -29,6 +30,7 @@ from letta.server.rest_api.utils import create_heartbeat_system_message, create_
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.job_manager import JobManager
from letta.services.llm_batch_manager import LLMBatchManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
@@ -102,6 +104,7 @@ class LettaAgentBatch:
passage_manager: PassageManager,
batch_manager: LLMBatchManager,
sandbox_config_manager: SandboxConfigManager,
job_manager: JobManager,
actor: User,
use_assistant_message: bool = True,
max_steps: int = 10,
@@ -112,12 +115,16 @@ class LettaAgentBatch:
self.passage_manager = passage_manager
self.batch_manager = batch_manager
self.sandbox_config_manager = sandbox_config_manager
self.job_manager = job_manager
self.use_assistant_message = use_assistant_message
self.actor = actor
self.max_steps = max_steps
async def step_until_request(
self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None
self,
batch_requests: List[LettaBatchRequest],
letta_batch_job_id: str,
agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None,
) -> LettaBatchResponse:
# Basic checks
if not batch_requests:
@@ -162,11 +169,12 @@ class LettaAgentBatch:
)
# Write the response into the jobs table, where it will get picked up by the next cron run
batch_job = self.batch_manager.create_batch_job(
llm_batch_job = self.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic, # TODO: Expand to more providers
create_batch_response=batch_response,
actor=self.actor,
status=JobStatus.running,
letta_batch_job_id=letta_batch_job_id,
)
# Create batch items in bulk for all agents
@@ -174,7 +182,7 @@ class LettaAgentBatch:
for agent_state in agent_states:
agent_step_state = agent_step_state_mapping.get(agent_state.id)
batch_item = LLMBatchItem(
batch_id=batch_job.id,
llm_batch_id=llm_batch_job.id,
agent_id=agent_state.id,
llm_config=agent_state.llm_config,
request_status=JobStatus.created,
@@ -185,19 +193,21 @@ class LettaAgentBatch:
# Create all batch items at once using the bulk operation
if batch_items:
self.batch_manager.create_batch_items_bulk(batch_items, actor=self.actor)
self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor)
return LettaBatchResponse(
batch_id=batch_job.id,
status=batch_job.status,
letta_batch_id=llm_batch_job.letta_batch_job_id,
last_llm_batch_id=llm_batch_job.id,
status=llm_batch_job.status,
agent_count=len(agent_states),
last_polled_at=get_utc_time(),
created_at=batch_job.created_at,
created_at=llm_batch_job.created_at,
)
async def resume_step_after_request(self, batch_id: str) -> LettaBatchResponse:
async def resume_step_after_request(self, letta_batch_id: str, llm_batch_id: str) -> LettaBatchResponse:
# 1. gather everything we need
ctx = await self._collect_resume_context(batch_id)
llm_batch_job = self.batch_manager.get_llm_batch_job_by_id(llm_batch_id=llm_batch_id, actor=self.actor)
ctx = await self._collect_resume_context(llm_batch_id)
# 2. persist request-level status updates
self._update_request_statuses(ctx.request_status_updates)
@@ -209,19 +219,31 @@ class LettaAgentBatch:
msg_map = self._persist_tool_messages(exec_results, ctx)
# 5. mark steps complete
self._mark_steps_complete(batch_id, ctx.agent_ids)
self._mark_steps_complete(llm_batch_id, ctx.agent_ids)
# 6. build next-round requests / step-state map
next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map)
if len(next_reqs) == 0:
# mark batch job as completed
self.job_manager.update_job_by_id(job_id=letta_batch_id, job_update=JobUpdate(status=JobStatus.completed), actor=self.actor)
return LettaBatchResponse(
letta_batch_id=llm_batch_job.letta_batch_job_id,
last_llm_batch_id=llm_batch_job.id,
status=JobStatus.completed,
agent_count=len(ctx.agent_ids),
last_polled_at=get_utc_time(),
created_at=llm_batch_job.created_at,
)
# 7. recurse into the normal stepping pipeline
return await self.step_until_request(
batch_requests=next_reqs,
letta_batch_job_id=letta_batch_id,
agent_step_state_mapping=next_step_state,
)
async def _collect_resume_context(self, batch_id: str) -> _ResumeContext:
batch_items = self.batch_manager.list_batch_items(batch_id=batch_id)
async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext:
batch_items = self.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_id)
agent_ids, agent_state_map = [], {}
provider_results, name_map, args_map, cont_map = {}, {}, {}, {}
@@ -244,7 +266,7 @@ class LettaAgentBatch:
else JobStatus.cancelled if isinstance(pr, BetaMessageBatchCanceledResult) else JobStatus.expired
)
)
request_status_updates.append(RequestStatusUpdateInfo(batch_id=batch_id, agent_id=aid, request_status=status))
request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))
# translate provider-specific response → OpenAI-style tool call (unchanged)
llm_client = LLMClient.create(llm_config=item.llm_config, put_inner_thoughts_first=True)
@@ -270,7 +292,7 @@ class LettaAgentBatch:
def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None:
if updates:
self.batch_manager.bulk_update_batch_items_request_status_by_agent(updates=updates)
self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates)
def _build_sandbox(self) -> Tuple[SandboxConfig, Dict[str, Any]]:
sbx_type = SandboxType.E2B if tool_settings.e2b_api_key else SandboxType.LOCAL
@@ -315,9 +337,11 @@ class LettaAgentBatch:
self.message_manager.create_many_messages([m for msgs in msg_map.values() for m in msgs], actor=self.actor)
return msg_map
def _mark_steps_complete(self, batch_id: str, agent_ids: List[str]) -> None:
updates = [StepStatusUpdateInfo(batch_id=batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids]
self.batch_manager.bulk_update_batch_items_step_status_by_agent(updates)
def _mark_steps_complete(self, llm_batch_id: str, agent_ids: List[str]) -> None:
updates = [
StepStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids
]
self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(updates)
def _prepare_next_iteration(
self,

View File

@@ -102,7 +102,7 @@ async def poll_batch_updates(server: SyncServer, batch_jobs: List[LLMBatchJob],
results: List[BatchPollingResult] = await asyncio.gather(*coros)
# Update the server with batch status changes
server.batch_manager.bulk_update_batch_statuses(updates=results)
server.batch_manager.bulk_update_llm_batch_statuses(updates=results)
logger.info(f"[Poll BatchJob] Bulk-updated {len(results)} LLM batch(es) in the DB at job level.")
return results
@@ -176,7 +176,7 @@ async def poll_running_llm_batches(server: "SyncServer") -> None:
try:
# 1. Retrieve running batch jobs
batches = server.batch_manager.list_running_batches()
batches = server.batch_manager.list_running_llm_batches()
metrics.total_batches = len(batches)
# TODO: Expand to more providers
@@ -193,7 +193,7 @@ async def poll_running_llm_batches(server: "SyncServer") -> None:
# 6. Bulk update all items for newly completed batch(es)
if item_updates:
metrics.updated_items_count = len(item_updates)
server.batch_manager.bulk_update_batch_items_results_by_agent(item_updates)
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(item_updates)
else:
logger.info("[Poll BatchJob] No item-level updates needed.")

View File

@@ -6,25 +6,25 @@ from letta.schemas.enums import AgentStepStatus, JobStatus
class BatchPollingResult(NamedTuple):
batch_id: str
llm_batch_id: str
request_status: JobStatus
batch_response: Optional[BetaMessageBatch]
class ItemUpdateInfo(NamedTuple):
batch_id: str
llm_batch_id: str
agent_id: str
request_status: JobStatus
batch_request_result: Optional[BetaMessageBatchIndividualResponse]
class StepStatusUpdateInfo(NamedTuple):
batch_id: str
llm_batch_id: str
agent_id: str
step_status: AgentStepStatus
class RequestStatusUpdateInfo(NamedTuple):
batch_id: str
llm_batch_id: str
agent_id: str
request_status: JobStatus

View File

@@ -16,6 +16,7 @@ class ToolType(str, Enum):
class JobType(str, Enum):
JOB = "job"
RUN = "run"
BATCH = "batch"
class ToolSourceType(str, Enum):

View File

@@ -20,7 +20,7 @@ class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin):
__tablename__ = "llm_batch_items"
__pydantic_model__ = PydanticLLMBatchItem
__table_args__ = (
Index("ix_llm_batch_items_batch_id", "batch_id"),
Index("ix_llm_batch_items_llm_batch_id", "llm_batch_id"),
Index("ix_llm_batch_items_agent_id", "agent_id"),
Index("ix_llm_batch_items_status", "request_status"),
)
@@ -29,7 +29,7 @@ class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin):
# TODO: Some still rely on the Pydantic object to do this
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"batch_item-{uuid.uuid4()}")
batch_id: Mapped[str] = mapped_column(
llm_batch_id: Mapped[str] = mapped_column(
ForeignKey("llm_batch_job.id", ondelete="CASCADE"), doc="Foreign key to the LLM provider batch this item belongs to"
)

View File

@@ -3,7 +3,7 @@ from datetime import datetime
from typing import List, Optional, Union
from anthropic.types.beta.messages import BetaMessageBatch
from sqlalchemy import DateTime, Index, String
from sqlalchemy import DateTime, ForeignKey, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.custom_columns import CreateBatchResponseColumn, PollBatchResponseColumn
@@ -43,6 +43,9 @@ class LLMBatchJob(SqlalchemyBase, OrganizationMixin):
DateTime(timezone=True), nullable=True, doc="Last time we polled the provider for status"
)
# relationships
letta_batch_job_id: Mapped[str] = mapped_column(
String, ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the Letta batch job"
)
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs")
items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin")

View File

@@ -34,6 +34,41 @@ class Job(JobBase):
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the job.")
class BatchJob(JobBase):
    """A Job specialized for batch-message processing (job_type=BATCH)."""

    id: str = JobBase.generate_id_field()
    user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the job.")
    job_type: JobType = JobType.BATCH

    @classmethod
    def from_job(cls, job: Job) -> "BatchJob":
        """
        Build a BatchJob from a Job, copying all set fields as-is.

        Note: the ID is copied unchanged — `model_dump` includes `id`, which
        overrides the generated default, so no prefix rewriting happens here.

        Args:
            job: The Job instance to convert

        Returns:
            A new BatchJob instance with the same field values
        """
        # Exclude None values so unset fields fall back to BatchJob defaults.
        job_data = job.model_dump(exclude_none=True)
        return cls(**job_data)

    def to_job(self) -> Job:
        """
        Convert this BatchJob back to a plain Job, copying all set fields as-is.

        Returns:
            A new Job instance with the same field values (ID unchanged)
        """
        job_data = self.model_dump(exclude_none=True)
        return Job(**job_data)
class JobUpdate(JobBase):
status: Optional[JobStatus] = Field(None, description="The status of the job.")

View File

@@ -31,3 +31,7 @@ class LettaStreamingRequest(LettaRequest):
class LettaBatchRequest(LettaRequest):
agent_id: str = Field(..., description="The ID of the agent to send this batch request for")
class CreateBatch(BaseModel):
requests: List[LettaBatchRequest] = Field(..., description="List of requests to be processed in batch.")

View File

@@ -169,7 +169,8 @@ LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStat
class LettaBatchResponse(BaseModel):
batch_id: str = Field(..., description="A unique identifier for this batch request.")
letta_batch_id: str = Field(..., description="A unique identifier for the Letta batch request.")
last_llm_batch_id: str = Field(..., description="A unique identifier for the most recent model provider batch request.")
status: JobStatus = Field(..., description="The current status of the batch request.")
agent_count: int = Field(..., description="The number of agents in the batch request.")
last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.")

View File

@@ -20,7 +20,7 @@ class LLMBatchItem(OrmMetadataBase, validate_assignment=True):
__id_prefix__ = "batch_item"
id: Optional[str] = Field(None, description="The id of the batch item. Assigned by the database.")
batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.")
llm_batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.")
agent_id: str = Field(..., description="The id of the agent associated with this LLM request.")
llm_config: LLMConfig = Field(..., description="The LLM configuration used for this request.")
@@ -45,6 +45,7 @@ class LLMBatchJob(OrmMetadataBase, validate_assignment=True):
id: Optional[str] = Field(None, description="The id of the batch job. Assigned by the database.")
status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).")
llm_provider: ProviderType = Field(..., description="The LLM provider used for the batch (e.g., anthropic, openai).")
letta_batch_job_id: str = Field(..., description="ID of the Letta batch job")
create_batch_response: Union[BetaMessageBatch] = Field(..., description="The full JSON response from the initial batch creation.")
latest_polling_response: Optional[Union[BetaMessageBatch]] = Field(

View File

@@ -174,6 +174,7 @@ def create_application() -> "FastAPI":
async def generic_error_handler(request: Request, exc: Exception):
# Log the actual error for debugging
log.error(f"Unhandled error: {exc}", exc_info=True)
print(f"Unhandled error: {exc}")
# Print the stack trace
print(f"Stack trace: {exc}")

View File

@@ -5,6 +5,7 @@ from letta.server.rest_api.routers.v1.health import router as health_router
from letta.server.rest_api.routers.v1.identities import router as identities_router
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
from letta.server.rest_api.routers.v1.llms import router as llm_router
from letta.server.rest_api.routers.v1.messages import router as messages_router
from letta.server.rest_api.routers.v1.providers import router as providers_router
from letta.server.rest_api.routers.v1.runs import router as runs_router
from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router
@@ -29,5 +30,6 @@ ROUTERS = [
runs_router,
steps_router,
tags_router,
messages_router,
voice_router,
]

View File

@@ -12,7 +12,6 @@ from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.orm.errors import NoResultFound
@@ -21,8 +20,8 @@ from letta.schemas.block import Block, BlockUpdate
from letta.schemas.group import Group
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
from letta.schemas.letta_request import LettaBatchRequest, LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaBatchResponse, LettaResponse
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
from letta.schemas.message import MessageCreate
from letta.schemas.passage import Passage, PassageUpdate
@@ -832,57 +831,3 @@ async def list_agent_groups(
actor = server.user_manager.get_user_or_default(user_id=actor_id)
print("in list agents with manager_type", manager_type)
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
# Batch APIs
@router.post("/messages/batches", response_model=LettaBatchResponse, operation_id="create_batch_message_request")
async def send_batch_messages(
batch_requests: List[LettaBatchRequest] = Body(..., description="Messages and config for all agents"),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Submit a batch of agent messages for asynchronous processing.
Creates a job that will fan out messages to all listed agents and process them in parallel.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
batch_runner = LettaAgentBatch(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
batch_manager=server.batch_manager,
sandbox_config_manager=server.sandbox_config_manager,
actor=actor,
)
return await batch_runner.step_until_request(batch_requests=batch_requests)
@router.get(
"/messages/batches/{batch_id}",
response_model=LettaBatchResponse,
operation_id="retrieve_batch_message_request",
)
async def retrieve_batch_message_request(
batch_id: str,
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Retrieve the result or current status of a previously submitted batch message request.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
batch_job = server.batch_manager.get_batch_job_by_id(batch_id=batch_id, actor=actor)
agent_count = server.batch_manager.count_batch_items(batch_id=batch_id)
return LettaBatchResponse(
batch_id=batch_id,
status=batch_job.status,
agent_count=agent_count,
last_polled_at=batch_job.last_polled_at,
created_at=batch_job.created_at,
)

View File

@@ -0,0 +1,127 @@
from typing import List, Optional

from fastapi import APIRouter, Body, Depends, Header
from fastapi.exceptions import HTTPException

from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
from letta.schemas.letta_request import CreateBatch
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
router = APIRouter(prefix="/messages", tags=["messages"])
logger = get_logger(__name__)
# Batch APIs
@router.post(
    "/batches",
    response_model=BatchJob,
    operation_id="create_messages_batch",
)
async def create_messages_batch(
    request: CreateBatch = Body(..., description="Messages and config for all agents"),
    server: SyncServer = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Submit a batch of agent messages for asynchronous processing.
    Creates a job that will fan out messages to all listed agents and process them in parallel.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)

    batch_job = BatchJob(
        user_id=actor.id,
        status=JobStatus.created,
        metadata={
            "job_type": "batch_messages",
        },
    )

    try:
        # Persist the Letta batch job *before* kicking off the LLM batch:
        # llm_batch_job.letta_batch_job_id carries an FK to jobs.id, so the
        # jobs row must exist when step_until_request() writes the
        # llm_batch_job record (the old code created it afterwards).
        batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)

        batch_runner = LettaAgentBatch(
            message_manager=server.message_manager,
            agent_manager=server.agent_manager,
            block_manager=server.block_manager,
            passage_manager=server.passage_manager,
            batch_manager=server.batch_manager,
            sandbox_config_manager=server.sandbox_config_manager,
            job_manager=server.job_manager,
            actor=actor,
        )
        await batch_runner.step_until_request(batch_requests=request.requests, letta_batch_job_id=batch_job.id)
        # TODO: update run metadata
    except Exception:
        # Log with the full traceback instead of print()/traceback to stdout.
        logger.exception("Error creating batch job")
        raise

    return batch_job
@router.get("/batches/{batch_id}", response_model=BatchJob, operation_id="retrieve_batch_run")
async def retrieve_batch_run(
    batch_id: str,
    actor_id: Optional[str] = Header(None, alias="user_id"),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Get the status of a batch run.
    """
    # Resolve the requesting user (falls back to the default user when no
    # user_id header is supplied).
    requesting_user = server.user_manager.get_user_or_default(user_id=actor_id)
    try:
        # A batch run is stored as a plain Job row; re-wrap it for the response.
        underlying_job = server.job_manager.get_job_by_id(job_id=batch_id, actor=requesting_user)
        return BatchJob.from_job(underlying_job)
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Batch not found")
@router.get("/batches", response_model=List[BatchJob], operation_id="list_batch_runs")
async def list_batch_runs(
    actor_id: Optional[str] = Header(None, alias="user_id"),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    List all batch runs.

    Only jobs still in `created` or `running` state are returned.
    """
    # TODO: filter
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    jobs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.BATCH)
    # (removed leftover debug print of the job list)
    return [BatchJob.from_job(job) for job in jobs]
@router.patch("/batches/{batch_id}/cancel", operation_id="cancel_batch_run")
async def cancel_batch_run(
    batch_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Cancel a batch run.

    Marks the underlying job as cancelled; 404 if no such batch exists.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    try:
        job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
        # Bug fix: update_job_by_id takes the job *ID* string plus a JobUpdate
        # payload (the old call passed the Job object as job_id and used a
        # nonexistent `job=` kwarg). Mirrors the call style used by
        # LettaAgentBatch.resume_step_after_request.
        server.job_manager.update_job_by_id(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
        # TODO: actually cancel it
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Run not found")

View File

@@ -15,6 +15,7 @@ from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step
from letta.orm.step import Step as StepModel
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.job import BatchJob as PydanticBatchJob
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessage
@@ -36,7 +37,9 @@ class JobManager:
self.session_maker = db_context
@enforce_types
def create_job(self, pydantic_job: Union[PydanticJob, PydanticRun], actor: PydanticUser) -> Union[PydanticJob, PydanticRun]:
def create_job(
self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser
) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]:
"""Create a new job based on the JobCreate schema."""
with self.session_maker() as session:
# Associate the job with the user

View File

@@ -28,11 +28,12 @@ class LLMBatchManager:
self.session_maker = db_context
@enforce_types
def create_batch_job(
def create_llm_batch_job(
self,
llm_provider: ProviderType,
create_batch_response: BetaMessageBatch,
actor: PydanticUser,
letta_batch_job_id: str,
status: JobStatus = JobStatus.created,
) -> PydanticLLMBatchJob:
"""Create a new LLM batch job."""
@@ -42,51 +43,52 @@ class LLMBatchManager:
llm_provider=llm_provider,
create_batch_response=create_batch_response,
organization_id=actor.organization_id,
letta_batch_job_id=letta_batch_job_id,
)
batch.create(session, actor=actor)
return batch.to_pydantic()
@enforce_types
def get_batch_job_by_id(self, batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob:
def get_llm_batch_job_by_id(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob:
"""Retrieve a single batch job by ID."""
with self.session_maker() as session:
batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
return batch.to_pydantic()
@enforce_types
def update_batch_status(
def update_llm_batch_status(
self,
batch_id: str,
llm_batch_id: str,
status: JobStatus,
actor: Optional[PydanticUser] = None,
latest_polling_response: Optional[BetaMessageBatch] = None,
) -> PydanticLLMBatchJob:
"""Update a batch job's status and optionally its polling response."""
with self.session_maker() as session:
batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
batch.status = status
batch.latest_polling_response = latest_polling_response
batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc)
batch = batch.update(db_session=session, actor=actor)
return batch.to_pydantic()
def bulk_update_batch_statuses(
def bulk_update_llm_batch_statuses(
self,
updates: List[BatchPollingResult],
) -> None:
"""
Efficiently update many LLMBatchJob rows. This is used by the cron jobs.
`updates` = [(batch_id, new_status, polling_response_or_None), …]
`updates` = [(llm_batch_id, new_status, polling_response_or_None), …]
"""
now = datetime.datetime.now(datetime.timezone.utc)
with self.session_maker() as session:
mappings = []
for batch_id, status, response in updates:
for llm_batch_id, status, response in updates:
mappings.append(
{
"id": batch_id,
"id": llm_batch_id,
"status": status,
"latest_polling_response": response,
"last_polled_at": now,
@@ -97,14 +99,51 @@ class LLMBatchManager:
session.commit()
@enforce_types
def delete_batch_request(self, batch_id: str, actor: PydanticUser) -> None:
def list_llm_batch_jobs(
    self,
    letta_batch_id: str,
    limit: Optional[int] = None,
    actor: Optional[PydanticUser] = None,
    after: Optional[str] = None,
) -> List[PydanticLLMBatchJob]:
    """
    List all LLM batch jobs that belong to the given Letta batch job,
    optionally filtered and limited in count.

    Optional filters:
        - actor: restrict results to the actor's organization.
        - after: a cursor string; only jobs with an `id` greater than this value are returned.
        - limit: maximum number of rows returned.

    The results are ordered by their id in ascending order.

    Returns:
        A list of PydanticLLMBatchJob (the original annotation said batch
        *items*, but the query selects LLMBatchJob rows).
    """
    with self.session_maker() as session:
        query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id)

        if actor is not None:
            query = query.filter(LLMBatchJob.organization_id == actor.organization_id)

        # Cursor-style pagination: ids strictly greater than `after`.
        if after is not None:
            query = query.filter(LLMBatchJob.id > after)

        query = query.order_by(LLMBatchJob.id.asc())

        if limit is not None:
            query = query.limit(limit)

        results = query.all()
        return [job.to_pydantic() for job in results]
@enforce_types
def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None:
"""Hard delete a batch job by ID."""
with self.session_maker() as session:
batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
batch.hard_delete(db_session=session, actor=actor)
@enforce_types
def list_running_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
"""Return all running LLM batch jobs, optionally filtered by actor's organization."""
with self.session_maker() as session:
query = session.query(LLMBatchJob).filter(LLMBatchJob.status == JobStatus.running)
@@ -116,9 +155,9 @@ class LLMBatchManager:
return [batch.to_pydantic() for batch in results]
@enforce_types
def create_batch_item(
def create_llm_batch_item(
self,
batch_id: str,
llm_batch_id: str,
agent_id: str,
llm_config: LLMConfig,
actor: PydanticUser,
@@ -129,7 +168,7 @@ class LLMBatchManager:
"""Create a new batch item."""
with self.session_maker() as session:
item = LLMBatchItem(
batch_id=batch_id,
llm_batch_id=llm_batch_id,
agent_id=agent_id,
llm_config=llm_config,
request_status=request_status,
@@ -141,7 +180,7 @@ class LLMBatchManager:
return item.to_pydantic()
@enforce_types
def create_batch_items_bulk(self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser) -> List[PydanticLLMBatchItem]:
def create_llm_batch_items_bulk(self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser) -> List[PydanticLLMBatchItem]:
"""
Create multiple batch items in bulk for better performance.
@@ -157,7 +196,7 @@ class LLMBatchManager:
orm_items = []
for item in llm_batch_items:
orm_item = LLMBatchItem(
batch_id=item.batch_id,
llm_batch_id=item.llm_batch_id,
agent_id=item.agent_id,
llm_config=item.llm_config,
request_status=item.request_status,
@@ -174,14 +213,14 @@ class LLMBatchManager:
return [item.to_pydantic() for item in created_items]
@enforce_types
def get_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
"""Retrieve a single batch item by ID."""
with self.session_maker() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
return item.to_pydantic()
@enforce_types
def update_batch_item(
def update_llm_batch_item(
self,
item_id: str,
actor: PydanticUser,
@@ -206,9 +245,9 @@ class LLMBatchManager:
return item.update(db_session=session, actor=actor).to_pydantic()
@enforce_types
def list_batch_items(
def list_llm_batch_items(
self,
batch_id: str,
llm_batch_id: str,
limit: Optional[int] = None,
actor: Optional[PydanticUser] = None,
after: Optional[str] = None,
@@ -217,7 +256,7 @@ class LLMBatchManager:
step_status: Optional[AgentStepStatus] = None,
) -> List[PydanticLLMBatchItem]:
"""
List all batch items for a given batch_id, optionally filtered by additional criteria and limited in count.
List all batch items for a given llm_batch_id, optionally filtered by additional criteria and limited in count.
Optional filters:
- after: A cursor string. Only items with an `id` greater than this value are returned.
@@ -228,7 +267,7 @@ class LLMBatchManager:
The results are ordered by their id in ascending order.
"""
with self.session_maker() as session:
query = session.query(LLMBatchItem).filter(LLMBatchItem.batch_id == batch_id)
query = session.query(LLMBatchItem).filter(LLMBatchItem.llm_batch_id == llm_batch_id)
if actor is not None:
query = query.filter(LLMBatchItem.organization_id == actor.organization_id)
@@ -251,36 +290,36 @@ class LLMBatchManager:
results = query.all()
return [item.to_pydantic() for item in results]
def bulk_update_batch_items(
def bulk_update_llm_batch_items(
self,
batch_id_agent_id_pairs: List[Tuple[str, str]],
llm_batch_id_agent_id_pairs: List[Tuple[str, str]],
field_updates: List[Dict[str, Any]],
) -> None:
"""
Efficiently update multiple LLMBatchItem rows by (batch_id, agent_id) pairs.
Efficiently update multiple LLMBatchItem rows by (llm_batch_id, agent_id) pairs.
Args:
batch_id_agent_id_pairs: List of (batch_id, agent_id) tuples identifying items to update
llm_batch_id_agent_id_pairs: List of (llm_batch_id, agent_id) tuples identifying items to update
field_updates: List of dictionaries containing the fields to update for each item
"""
if not batch_id_agent_id_pairs or not field_updates:
if not llm_batch_id_agent_id_pairs or not field_updates:
return
if len(batch_id_agent_id_pairs) != len(field_updates):
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
raise ValueError("batch_id_agent_id_pairs and field_updates must have the same length")
with self.session_maker() as session:
# Lookup primary keys
items = (
session.query(LLMBatchItem.id, LLMBatchItem.batch_id, LLMBatchItem.agent_id)
.filter(tuple_(LLMBatchItem.batch_id, LLMBatchItem.agent_id).in_(batch_id_agent_id_pairs))
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
.filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs))
.all()
)
pair_to_pk = {(b, a): id for id, b, a in items}
mappings = []
for (batch_id, agent_id), fields in zip(batch_id_agent_id_pairs, field_updates):
pk_id = pair_to_pk.get((batch_id, agent_id))
for (llm_batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
pk_id = pair_to_pk.get((llm_batch_id, agent_id))
if not pk_id:
continue
@@ -293,12 +332,12 @@ class LLMBatchManager:
session.commit()
@enforce_types
def bulk_update_batch_items_results_by_agent(
def bulk_update_batch_llm_items_results_by_agent(
self,
updates: List[ItemUpdateInfo],
) -> None:
"""Update request status and batch results for multiple batch items."""
batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) for update in updates]
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [
{
"request_status": update.request_status,
@@ -307,48 +346,48 @@ class LLMBatchManager:
for update in updates
]
self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates)
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
@enforce_types
def bulk_update_batch_items_step_status_by_agent(
def bulk_update_llm_batch_items_step_status_by_agent(
self,
updates: List[StepStatusUpdateInfo],
) -> None:
"""Update step status for multiple batch items."""
batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) for update in updates]
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [{"step_status": update.step_status} for update in updates]
self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates)
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
@enforce_types
def bulk_update_batch_items_request_status_by_agent(
def bulk_update_llm_batch_items_request_status_by_agent(
self,
updates: List[RequestStatusUpdateInfo],
) -> None:
"""Update request status for multiple batch items."""
batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) for update in updates]
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [{"request_status": update.request_status} for update in updates]
self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates)
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
@enforce_types
def delete_batch_item(self, item_id: str, actor: PydanticUser) -> None:
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:
"""Hard delete a batch item by ID."""
with self.session_maker() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
item.hard_delete(db_session=session, actor=actor)
@enforce_types
def count_batch_items(self, batch_id: str) -> int:
def count_llm_batch_items(self, llm_batch_id: str) -> int:
"""
Efficiently count the number of batch items for a given batch_id.
Efficiently count the number of batch items for a given llm_batch_id.
Args:
batch_id (str): The batch identifier to count items for.
llm_batch_id (str): The batch identifier to count items for.
Returns:
int: The total number of batch items associated with the given batch_id.
int: The total number of batch items associated with the given llm_batch_id.
"""
with self.session_maker() as session:
count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.batch_id == batch_id).scalar()
count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar()
return count or 0

View File

@@ -0,0 +1,80 @@
# import time
#
# import pytest
# from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
#
#
# @pytest.fixture(scope="module")
# def client():
# return Letta(base_url="http://localhost:8283")
#
#
# def test_create_batch(client: Letta):
#
# # create agents
# agent1 = client.agents.create(
# name="agent1",
# memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
# model="anthropic/claude-3-7-sonnet-20250219",
# embedding="letta/letta-free",
# )
# agent2 = client.agents.create(
# name="agent2",
# memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
# model="anthropic/claude-3-7-sonnet-20250219",
# embedding="letta/letta-free",
# )
#
# # create a run
# run = client.messages.batches.create(
# requests=[
# LettaBatchRequest(
# messages=[
# MessageCreate(
# role="user",
# content=[
# TextContent(
# text="text",
# )
# ],
# )
# ],
# agent_id=agent1.id,
# ),
# LettaBatchRequest(
# messages=[
# MessageCreate(
# role="user",
# content=[
# TextContent(
# text="text",
# )
# ],
# )
# ],
# agent_id=agent2.id,
# ),
# ]
# )
# assert run is not None
#
# # list batches
# batches = client.messages.batches.list()
# assert len(batches) > 0, f"Expected 1 batch, got {len(batches)}"
#
# # check run status
# while True:
# run = client.messages.batches.retrieve(batch_id=run.id)
# if run.status == "completed":
# break
# print("Waiting for run to complete...", run.status)
# time.sleep(1)
#
# # get the batch results
# results = client.messages.batches.retrieve(
# run_id=run.id,
# )
# assert results is not None
# print(results)
#
# # cancel a run

View File

@@ -23,6 +23,7 @@ from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.agent import AgentStepState
from letta.schemas.enums import JobStatus, ProviderType
from letta.schemas.job import BatchJob
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import InitToolRule
from letta.server.db import db_context
@@ -145,13 +146,21 @@ def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"
)
def create_test_batch_job(server, batch_response, default_user):
def create_test_letta_batch_job(server, default_user):
"""Create a test batch job with the given batch response."""
return server.batch_manager.create_batch_job(
return server.job_manager.create_job(BatchJob(user_id=default_user.id), actor=default_user)
def create_test_llm_batch_job(server, batch_response, default_user):
"""Create a test batch job with the given batch response."""
letta_batch_job = create_test_letta_batch_job(server, default_user)
return server.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic,
create_batch_response=batch_response,
actor=default_user,
status=JobStatus.running,
letta_batch_job_id=letta_batch_job.id,
)
@@ -171,8 +180,8 @@ def create_test_batch_item(server, batch_id, agent_id, default_user):
step_number=1, tool_rules_solver=ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")])
)
return server.batch_manager.create_batch_item(
batch_id=batch_id,
return server.batch_manager.create_llm_batch_item(
llm_batch_id=batch_id,
agent_id=agent_id,
llm_config=dummy_llm_config,
step_state=common_step_state,
@@ -242,8 +251,8 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
agent_c = create_test_agent(client, "agent_c")
# --- Step 2: Create batch jobs ---
job_a = create_test_batch_job(server, batch_a_resp, default_user)
job_b = create_test_batch_job(server, batch_b_resp, default_user)
job_a = create_test_llm_batch_job(server, batch_a_resp, default_user)
job_b = create_test_llm_batch_job(server, batch_b_resp, default_user)
# --- Step 3: Create batch items ---
item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
@@ -258,8 +267,8 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
await poll_running_llm_batches(server)
# --- Step 6: Verify batch job status updates ---
updated_job_a = server.batch_manager.get_batch_job_by_id(batch_id=job_a.id, actor=default_user)
updated_job_b = server.batch_manager.get_batch_job_by_id(batch_id=job_b.id, actor=default_user)
updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user)
updated_job_b = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_b.id, actor=default_user)
# Job A should remain running since its processing_status is "in_progress"
assert updated_job_a.status == JobStatus.running
@@ -273,17 +282,17 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
# --- Step 7: Verify batch item status updates ---
# Item A should remain unchanged
updated_item_a = server.batch_manager.get_batch_item_by_id(item_a.id, actor=default_user)
updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
assert updated_item_a.request_status == JobStatus.created
assert updated_item_a.batch_request_result is None
# Item B should be marked as completed with a successful result
updated_item_b = server.batch_manager.get_batch_item_by_id(item_b.id, actor=default_user)
updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
assert updated_item_b.request_status == JobStatus.completed
assert updated_item_b.batch_request_result is not None
# Item C should be marked as failed with an error result
updated_item_c = server.batch_manager.get_batch_item_by_id(item_c.id, actor=default_user)
updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
assert updated_item_c.request_status == JobStatus.failed
assert updated_item_c.batch_request_result is not None
@@ -307,11 +316,11 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
# --- Step 9: Verify that nothing changed for completed jobs ---
# Refresh all objects
final_job_a = server.batch_manager.get_batch_job_by_id(batch_id=job_a.id, actor=default_user)
final_job_b = server.batch_manager.get_batch_job_by_id(batch_id=job_b.id, actor=default_user)
final_item_a = server.batch_manager.get_batch_item_by_id(item_a.id, actor=default_user)
final_item_b = server.batch_manager.get_batch_item_by_id(item_b.id, actor=default_user)
final_item_c = server.batch_manager.get_batch_item_by_id(item_c.id, actor=default_user)
final_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user)
final_job_b = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_b.id, actor=default_user)
final_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
final_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
final_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
# Job A should still be polling (last_polled_at should update)
assert final_job_a.status == JobStatus.running

View File

@@ -33,6 +33,7 @@ from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
from letta.schemas.job import BatchJob
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.message import MessageCreate
@@ -256,7 +257,7 @@ def clear_batch_tables():
"""Clear batch-related tables before each test."""
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables):
if table.name in {"llm_batch_job", "llm_batch_items"}:
if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
session.execute(table.delete()) # Truncate table
session.commit()
@@ -305,6 +306,22 @@ def client(server_url):
return Letta(base_url=server_url)
@pytest.fixture
def batch_job(default_user, server):
job = BatchJob(
user_id=default_user.id,
status=JobStatus.created,
metadata={
"job_type": "batch_messages",
},
)
job = server.job_manager.create_job(pydantic_job=job, actor=default_user)
yield job
# cleanup
server.job_manager.delete_job_by_id(job.id, actor=default_user)
class MockAsyncIterable:
def __init__(self, items):
self.items = items
@@ -324,8 +341,8 @@ class MockAsyncIterable:
@pytest.mark.asyncio
async def test_resume_step_after_request_happy_path(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map
async def test_resume_step_after_request_all_continue(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
anthropic_batch_id = "msgbatch_test_12345"
dummy_batch_response = create_batch_response(
@@ -342,6 +359,7 @@ async def test_resume_step_after_request_happy_path(
passage_manager=server.passage_manager,
batch_manager=server.batch_manager,
sandbox_config_manager=server.sandbox_config_manager,
job_manager=server.job_manager,
actor=default_user,
)
@@ -349,15 +367,20 @@ async def test_resume_step_after_request_happy_path(
pre_resume_response = await batch_runner.step_until_request(
batch_requests=batch_requests,
agent_step_state_mapping=step_state_map,
letta_batch_job_id=batch_job.id,
)
# Basic sanity checks (This is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly`
# Verify batch items
items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user)
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
llm_batch_job = llm_batch_jobs[0]
llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}"
# 2. Invoke the polling job and mock responses from Anthropic
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.batch_id, processing_status="ended"))
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
mock_items = [
@@ -372,13 +395,13 @@ async def test_resume_step_after_request_happy_path(
await poll_running_llm_batches(server)
# Verify database records were updated correctly
job = server.batch_manager.get_batch_job_by_id(pre_resume_response.batch_id, actor=default_user)
llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)
# Verify job properties
assert job.status == JobStatus.completed, "Job status should be 'completed'"
assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"
# Verify batch items
items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user)
items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
assert all([item.request_status == JobStatus.completed for item in items])
@@ -390,22 +413,27 @@ async def test_resume_step_after_request_happy_path(
passage_manager=server.passage_manager,
batch_manager=server.batch_manager,
sandbox_config_manager=server.sandbox_config_manager,
job_manager=server.job_manager,
actor=default_user,
)
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
post_resume_response = await letta_batch_agent.resume_step_after_request(batch_id=pre_resume_response.batch_id)
post_resume_response = await letta_batch_agent.resume_step_after_request(
letta_batch_id=pre_resume_response.letta_batch_id, llm_batch_id=llm_batch_job.id
)
# A *new* batch job should have been spawned
assert (
post_resume_response.batch_id != pre_resume_response.batch_id
), "resume_step_after_request is expected to enqueue a followup batch job."
post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
), "resume_step_after_request is expected to have the same letta_batch_id"
assert (
post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
), "resume_step_after_request is expected to have different llm_batch_id."
assert post_resume_response.status == JobStatus.running
assert post_resume_response.agent_count == 3
# New batchitems should exist, initialised in (created, paused) state
new_items = server.batch_manager.list_batch_items(batch_id=post_resume_response.batch_id, actor=default_user)
new_items = server.batch_manager.list_llm_batch_items(llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user)
assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}"
assert {i.request_status for i in new_items} == {JobStatus.created}
assert {i.step_status for i in new_items} == {AgentStepStatus.paused}
@@ -420,7 +448,7 @@ async def test_resume_step_after_request_happy_path(
# Old items must have been flipped to completed / finished earlier
# (sanity we already asserted this above, but we keep it close for clarity)
old_items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user)
old_items = server.batch_manager.list_llm_batch_items(llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user)
assert {i.request_status for i in old_items} == {JobStatus.completed}
assert {i.step_status for i in old_items} == {AgentStepStatus.completed}
@@ -440,7 +468,7 @@ async def test_resume_step_after_request_happy_path(
@pytest.mark.asyncio
async def test_step_until_request_prepares_and_submits_batch_correctly(
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
):
"""
Test that step_until_request correctly:
@@ -512,6 +540,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
passage_manager=server.passage_manager,
batch_manager=server.batch_manager,
sandbox_config_manager=server.sandbox_config_manager,
job_manager=server.job_manager,
actor=default_user,
)
@@ -519,23 +548,25 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
response = await batch_runner.step_until_request(
batch_requests=batch_requests,
agent_step_state_mapping=step_state_map,
letta_batch_job_id=batch_job.id,
)
# Verify the mock was called exactly once
mock_send.assert_called_once()
# Verify database records were created correctly
job = server.batch_manager.get_batch_job_by_id(response.batch_id, actor=default_user)
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=response.letta_batch_id, actor=default_user)
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
llm_batch_job = llm_batch_jobs[0]
llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}"
# Verify job properties
assert job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic"
assert job.status == JobStatus.running, "Job status should be 'running'"
# Verify batch items
items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user)
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
assert llm_batch_job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic"
assert llm_batch_job.status == JobStatus.running, "Job status should be 'running'"
# Verify all agents are represented in batch items
agent_ids_in_items = {item.agent_id for item in items}
agent_ids_in_items = {item.agent_id for item in llm_batch_items}
expected_agent_ids = {agent.id for agent in agents}
assert agent_ids_in_items == expected_agent_ids, f"Expected agent IDs {expected_agent_ids}, got {agent_ids_in_items}"

View File

@@ -39,6 +39,8 @@ from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, Provide
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate, IdentityUpsert
from letta.schemas.job import BatchJob
from letta.schemas.job import Job
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
@@ -615,6 +617,11 @@ def dummy_successful_response() -> BetaMessageBatchIndividualResponse:
)
@pytest.fixture
def letta_batch_job(server: SyncServer, default_user) -> Job:
return server.job_manager.create_job(BatchJob(user_id=default_user.id), actor=default_user)
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@@ -4761,77 +4768,90 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
# ======================================================================================================================
def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch):
batch = server.batch_manager.create_batch_job(
def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch, letta_batch_job):
batch = server.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
assert batch.id.startswith("batch_req-")
assert batch.create_batch_response == dummy_beta_message_batch
fetched = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user)
fetched = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user)
assert fetched.id == batch.id
def test_update_batch_status(server, default_user, dummy_beta_message_batch):
batch = server.batch_manager.create_batch_job(
def test_update_batch_status(server, default_user, dummy_beta_message_batch, letta_batch_job):
batch = server.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
before = datetime.now(timezone.utc)
server.batch_manager.update_batch_status(
batch_id=batch.id,
server.batch_manager.update_llm_batch_status(
llm_batch_id=batch.id,
status=JobStatus.completed,
latest_polling_response=dummy_beta_message_batch,
actor=default_user,
)
updated = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user)
updated = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user)
assert updated.status == JobStatus.completed
assert updated.latest_polling_response == dummy_beta_message_batch
assert updated.last_polled_at >= before
def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
batch = server.batch_manager.create_batch_job(
def test_create_and_get_batch_item(
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
):
batch = server.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
item = server.batch_manager.create_batch_item(
batch_id=batch.id,
item = server.batch_manager.create_llm_batch_item(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
step_state=dummy_step_state,
actor=default_user,
)
assert item.batch_id == batch.id
assert item.llm_batch_id == batch.id
assert item.agent_id == sarah_agent.id
assert item.step_state == dummy_step_state
fetched = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
fetched = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
assert fetched.id == item.id
def test_update_batch_item(
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response
server,
default_user,
sarah_agent,
dummy_beta_message_batch,
dummy_llm_config,
dummy_step_state,
dummy_successful_response,
letta_batch_job,
):
batch = server.batch_manager.create_batch_job(
batch = server.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
item = server.batch_manager.create_batch_item(
batch_id=batch.id,
item = server.batch_manager.create_llm_batch_item(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
step_state=dummy_step_state,
@@ -4840,7 +4860,7 @@ def test_update_batch_item(
updated_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver)
server.batch_manager.update_batch_item(
server.batch_manager.update_llm_batch_item(
item_id=item.id,
request_status=JobStatus.completed,
step_status=AgentStepStatus.resumed,
@@ -4849,146 +4869,166 @@ def test_update_batch_item(
actor=default_user,
)
updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
assert updated.request_status == JobStatus.completed
assert updated.batch_request_result == dummy_successful_response
def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
batch = server.batch_manager.create_batch_job(
def test_delete_batch_item(
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
):
batch = server.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
item = server.batch_manager.create_batch_item(
batch_id=batch.id,
item = server.batch_manager.create_llm_batch_item(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
step_state=dummy_step_state,
actor=default_user,
)
server.batch_manager.delete_batch_item(item_id=item.id, actor=default_user)
server.batch_manager.delete_llm_batch_item(item_id=item.id, actor=default_user)
with pytest.raises(NoResultFound):
server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
def test_list_running_batches(server, default_user, dummy_beta_message_batch):
server.batch_manager.create_batch_job(
def test_list_running_batches(server, default_user, dummy_beta_message_batch, letta_batch_job):
server.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.running,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
running_batches = server.batch_manager.list_running_batches(actor=default_user)
running_batches = server.batch_manager.list_running_llm_batches(actor=default_user)
assert len(running_batches) >= 1
assert all(batch.status == JobStatus.running for batch in running_batches)
def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch):
batch = server.batch_manager.create_batch_job(
def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch, letta_batch_job):
batch = server.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
server.batch_manager.bulk_update_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)])
server.batch_manager.bulk_update_llm_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)])
updated = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user)
updated = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user)
assert updated.status == JobStatus.completed
assert updated.latest_polling_response == dummy_beta_message_batch
def test_bulk_update_batch_items_results_by_agent(
    server,
    default_user,
    sarah_agent,
    dummy_beta_message_batch,
    dummy_llm_config,
    dummy_step_state,
    dummy_successful_response,
    letta_batch_job,
):
    """Updating item results keyed by (batch id, agent id) sets request_status and batch_request_result."""
    batch = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    item = server.batch_manager.create_llm_batch_item(
        llm_batch_id=batch.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
        [ItemUpdateInfo(batch.id, sarah_agent.id, JobStatus.completed, dummy_successful_response)]
    )

    updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
    assert updated.request_status == JobStatus.completed
    assert updated.batch_request_result == dummy_successful_response
def test_bulk_update_batch_items_step_status_by_agent(
    server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
):
    """Updating step status keyed by (batch id, agent id) sets step_status on the matching item."""
    batch = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    item = server.batch_manager.create_llm_batch_item(
        llm_batch_id=batch.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
        [StepStatusUpdateInfo(batch.id, sarah_agent.id, AgentStepStatus.resumed)]
    )

    updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
    assert updated.step_status == AgentStepStatus.resumed
def test_list_batch_items_limit_and_filter(
    server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
):
    """list_llm_batch_items returns all items for a batch and honors the limit parameter."""
    batch = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    for _ in range(3):
        server.batch_manager.create_llm_batch_item(
            llm_batch_id=batch.id,
            agent_id=sarah_agent.id,
            llm_config=dummy_llm_config,
            step_state=dummy_step_state,
            actor=default_user,
        )

    all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user)
    limited_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, limit=2, actor=default_user)

    assert len(all_items) >= 3
    assert len(limited_items) == 2
def test_list_batch_items_pagination(
    server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
):
    """Cursor-based pagination over llm batch items: after=<cursor>, limit, and past-the-end behavior."""
    # Create a batch job.
    batch = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    # Create 10 batch items.
    created_items = []
    for i in range(10):
        item = server.batch_manager.create_llm_batch_item(
            llm_batch_id=batch.id,
            agent_id=sarah_agent.id,
            llm_config=dummy_llm_config,
            step_state=dummy_step_state,
            actor=default_user,  # NOTE(review): reconstructed — this line fell inside an omitted diff hunk; verify
        )
        created_items.append(item)

    # Retrieve all items (without pagination).
    all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user)
    assert len(all_items) >= 10, f"Expected at least 10 items, got {len(all_items)}"

    # Verify the items are ordered ascending by id (based on our implementation).
    # NOTE(review): the ordering check body fell inside an omitted diff hunk; reconstructed — verify.
    sorted_ids = [item.id for item in all_items]
    assert sorted_ids == sorted(sorted_ids)

    # Pick a cursor in the middle of the result set.
    cursor = all_items[4].id

    # Retrieve items after the cursor.
    paged_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=cursor)

    # All returned items should have an id greater than the cursor.
    for item in paged_items:
        assert item.id > cursor
    # NOTE(review): reconstructed — original computation fell inside an omitted diff hunk; verify.
    expected_remaining = len(all_items) - 5

    # Test pagination with a limit.
    limit = 3
    limited_page = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=cursor, limit=limit)
    # If more than 'limit' items remain, we should only get exactly 'limit' items.
    assert len(limited_page) == min(
        limit, expected_remaining
    ), f"Expected {min(limit, expected_remaining)} items, got {len(limited_page)}"

    # Optional: Test with a cursor beyond the last item returns an empty list.
    last_cursor = sorted_ids[-1]
    empty_page = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=last_cursor)
    assert empty_page == [], "Expected an empty list when cursor is after the last item"
def test_bulk_update_batch_items_request_status_by_agent(
    server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
):
    """Updating request status keyed by (batch id, agent id) sets request_status on the matching item."""
    # Create a batch job
    batch = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    # Create a batch item
    item = server.batch_manager.create_llm_batch_item(
        llm_batch_id=batch.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,  # NOTE(review): reconstructed from an omitted diff hunk; matches every sibling call
    )

    # Update the request status using the bulk update method
    server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
        [RequestStatusUpdateInfo(batch.id, sarah_agent.id, JobStatus.expired)]
    )

    # Verify the update was applied
    updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
    assert updated.request_status == JobStatus.expired
def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job):
    """Bulk updates targeting non-existent (batch, agent) pairs are silently skipped, never raised."""
    # Create a batch job
    batch = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    # Attempt to update non-existent items should not raise errors

    # Test with the direct bulk_update_llm_batch_items method
    nonexistent_pairs = [(batch.id, "nonexistent-agent-id")]
    nonexistent_updates = [{"request_status": JobStatus.expired}]

    # This should not raise an error, just silently skip non-existent items
    server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)

    # Test with higher-level methods

    # Results by agent
    server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
        [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
    )

    # Step status by agent
    server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
        [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
    )

    # Request status by agent
    server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
        [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
    )
def test_create_batch_items_bulk(
    server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
):
    """create_llm_batch_items_bulk persists a list of LLMBatchItem objects and returns them with ids."""
    # Create a batch job
    llm_batch_job = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    # Prepare data for multiple batch items
    # NOTE(review): agent-id setup fell inside an omitted diff hunk; reconstructed — verify against original.
    agent_ids = [sarah_agent.id]
    batch_items = []
    for agent_id in agent_ids:
        batch_item = LLMBatchItem(
            llm_batch_id=llm_batch_job.id,
            agent_id=agent_id,
            llm_config=dummy_llm_config,
            request_status=JobStatus.created,
            step_state=dummy_step_state,  # NOTE(review): tail of this constructor fell in an omitted hunk; verify
        )
        batch_items.append(batch_item)

    # Call the bulk create function
    created_items = server.batch_manager.create_llm_batch_items_bulk(batch_items, actor=default_user)

    # Verify the correct number of items were created
    assert len(created_items) == len(agent_ids)

    # Verify each item has expected properties
    for item in created_items:
        assert item.id.startswith("batch_item-")
        assert item.llm_batch_id == llm_batch_job.id
        assert item.agent_id in agent_ids
        assert item.llm_config == dummy_llm_config
        assert item.request_status == JobStatus.created
        assert item.step_state == dummy_step_state

    # Verify items can be retrieved from the database
    all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
    assert len(all_items) >= len(agent_ids)

    # Verify the IDs of created items match what's in the database
    created_ids = [item.id for item in created_items]
    for item_id in created_ids:
        fetched = server.batch_manager.get_llm_batch_item_by_id(item_id, actor=default_user)
        assert fetched.id in created_ids
def test_count_batch_items(
    server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
):
    """count_llm_batch_items returns exactly the number of items created for the batch."""
    # Create a batch job first.
    batch = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    # Create a specific number of batch items for this batch.
    num_items = 5
    for _ in range(num_items):
        server.batch_manager.create_llm_batch_item(
            llm_batch_id=batch.id,
            agent_id=sarah_agent.id,
            llm_config=dummy_llm_config,
            step_state=dummy_step_state,
            actor=default_user,
        )

    # Use the count_llm_batch_items method to count the items.
    count = server.batch_manager.count_llm_batch_items(llm_batch_id=batch.id)

    # Assert that the count matches the expected number.
    assert count == num_items, f"Expected {num_items} items, got {count}"