diff --git a/alembic/versions/7b189006c97d_rename_batch_id_to_llm_batch_id_on_llm_.py b/alembic/versions/7b189006c97d_rename_batch_id_to_llm_batch_id_on_llm_.py new file mode 100644 index 00000000..e75e0638 --- /dev/null +++ b/alembic/versions/7b189006c97d_rename_batch_id_to_llm_batch_id_on_llm_.py @@ -0,0 +1,41 @@ +"""Rename batch_id to llm_batch_id on llm_batch_item + +Revision ID: 7b189006c97d +Revises: f2f78d62005c +Create Date: 2025-04-17 16:04:52.045672 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "7b189006c97d" +down_revision: Union[str, None] = "f2f78d62005c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("llm_batch_items", sa.Column("llm_batch_id", sa.String(), nullable=False)) + op.drop_index("ix_llm_batch_items_batch_id", table_name="llm_batch_items") + op.create_index("ix_llm_batch_items_llm_batch_id", "llm_batch_items", ["llm_batch_id"], unique=False) + op.drop_constraint("llm_batch_items_batch_id_fkey", "llm_batch_items", type_="foreignkey") + op.create_foreign_key(None, "llm_batch_items", "llm_batch_job", ["llm_batch_id"], ["id"], ondelete="CASCADE") + op.drop_column("llm_batch_items", "batch_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("llm_batch_items", sa.Column("batch_id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.drop_constraint("llm_batch_items_llm_batch_id_fkey", "llm_batch_items", type_="foreignkey")  # drop_constraint requires an explicit name; this is Postgres's default name for the unnamed FK created in upgrade() -- confirm against the live schema + op.create_foreign_key("llm_batch_items_batch_id_fkey", "llm_batch_items", "llm_batch_job", ["batch_id"], ["id"], ondelete="CASCADE") + op.drop_index("ix_llm_batch_items_llm_batch_id", table_name="llm_batch_items") + op.create_index("ix_llm_batch_items_batch_id", "llm_batch_items", ["batch_id"], unique=False) + op.drop_column("llm_batch_items", "llm_batch_id") + # ### end Alembic commands ### diff --git a/alembic/versions/f2f78d62005c_add_letta_batch_job_id_to_llm_batch_job.py b/alembic/versions/f2f78d62005c_add_letta_batch_job_id_to_llm_batch_job.py new file mode 100644 index 00000000..f0eb8454 --- /dev/null +++ b/alembic/versions/f2f78d62005c_add_letta_batch_job_id_to_llm_batch_job.py @@ -0,0 +1,33 @@ +"""Add letta batch job id to llm_batch_job + +Revision ID: f2f78d62005c +Revises: c3b1da3d1157 +Create Date: 2025-04-17 15:58:43.705483 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f2f78d62005c" +down_revision: Union[str, None] = "c3b1da3d1157" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("llm_batch_job", sa.Column("letta_batch_job_id", sa.String(), nullable=False)) + op.create_foreign_key(None, "llm_batch_job", "jobs", ["letta_batch_job_id"], ["id"], ondelete="CASCADE") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint(None, "llm_batch_job", type_="foreignkey") + op.drop_column("llm_batch_job", "letta_batch_job_id") + # ### end Alembic commands ### diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index b792ae3f..d2648944 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -17,6 +17,7 @@ from letta.log import get_logger from letta.orm.enums import ToolType from letta.schemas.agent import AgentState, AgentStepState from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType +from letta.schemas.job import JobUpdate from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_request import LettaBatchRequest from letta.schemas.letta_response import LettaBatchResponse @@ -29,6 +30,7 @@ from letta.server.rest_api.utils import create_heartbeat_system_message, create_ from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import compile_system_message +from letta.services.job_manager import JobManager from letta.services.llm_batch_manager import LLMBatchManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager @@ -102,6 +104,7 @@ class LettaAgentBatch: passage_manager: PassageManager, batch_manager: LLMBatchManager, sandbox_config_manager: SandboxConfigManager, + job_manager: JobManager, actor: User, use_assistant_message: bool = True, max_steps: int = 10, @@ -112,12 +115,16 @@ class LettaAgentBatch: self.passage_manager = passage_manager self.batch_manager = batch_manager self.sandbox_config_manager = sandbox_config_manager + self.job_manager = job_manager self.use_assistant_message = use_assistant_message self.actor = actor self.max_steps = max_steps async def step_until_request( - self, batch_requests: 
List[LettaBatchRequest], agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None + self, + batch_requests: List[LettaBatchRequest], + letta_batch_job_id: str, + agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None, ) -> LettaBatchResponse: # Basic checks if not batch_requests: @@ -162,11 +169,12 @@ class LettaAgentBatch: ) # Write the response into the jobs table, where it will get picked up by the next cron run - batch_job = self.batch_manager.create_batch_job( + llm_batch_job = self.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, # TODO: Expand to more providers create_batch_response=batch_response, actor=self.actor, status=JobStatus.running, + letta_batch_job_id=letta_batch_job_id, ) # Create batch items in bulk for all agents @@ -174,7 +182,7 @@ class LettaAgentBatch: for agent_state in agent_states: agent_step_state = agent_step_state_mapping.get(agent_state.id) batch_item = LLMBatchItem( - batch_id=batch_job.id, + llm_batch_id=llm_batch_job.id, agent_id=agent_state.id, llm_config=agent_state.llm_config, request_status=JobStatus.created, @@ -185,19 +193,21 @@ class LettaAgentBatch: # Create all batch items at once using the bulk operation if batch_items: - self.batch_manager.create_batch_items_bulk(batch_items, actor=self.actor) + self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor) return LettaBatchResponse( - batch_id=batch_job.id, - status=batch_job.status, + letta_batch_id=llm_batch_job.letta_batch_job_id, + last_llm_batch_id=llm_batch_job.id, + status=llm_batch_job.status, agent_count=len(agent_states), last_polled_at=get_utc_time(), - created_at=batch_job.created_at, + created_at=llm_batch_job.created_at, ) - async def resume_step_after_request(self, batch_id: str) -> LettaBatchResponse: + async def resume_step_after_request(self, letta_batch_id: str, llm_batch_id: str) -> LettaBatchResponse: # 1. 
gather everything we need - ctx = await self._collect_resume_context(batch_id) + llm_batch_job = self.batch_manager.get_llm_batch_job_by_id(llm_batch_id=llm_batch_id, actor=self.actor) + ctx = await self._collect_resume_context(llm_batch_id) # 2. persist request‑level status updates self._update_request_statuses(ctx.request_status_updates) @@ -209,19 +219,31 @@ class LettaAgentBatch: msg_map = self._persist_tool_messages(exec_results, ctx) # 5. mark steps complete - self._mark_steps_complete(batch_id, ctx.agent_ids) + self._mark_steps_complete(llm_batch_id, ctx.agent_ids) # 6. build next‑round requests / step‑state map next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map) + if len(next_reqs) == 0: + # mark batch job as completed + self.job_manager.update_job_by_id(job_id=letta_batch_id, job_update=JobUpdate(status=JobStatus.completed), actor=self.actor) + return LettaBatchResponse( + letta_batch_id=llm_batch_job.letta_batch_job_id, + last_llm_batch_id=llm_batch_job.id, + status=JobStatus.completed, + agent_count=len(ctx.agent_ids), + last_polled_at=get_utc_time(), + created_at=llm_batch_job.created_at, + ) # 7. 
recurse into the normal stepping pipeline return await self.step_until_request( batch_requests=next_reqs, + letta_batch_job_id=letta_batch_id, agent_step_state_mapping=next_step_state, ) - async def _collect_resume_context(self, batch_id: str) -> _ResumeContext: - batch_items = self.batch_manager.list_batch_items(batch_id=batch_id) + async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext: + batch_items = self.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_id) agent_ids, agent_state_map = [], {} provider_results, name_map, args_map, cont_map = {}, {}, {}, {} @@ -244,7 +266,7 @@ class LettaAgentBatch: else JobStatus.cancelled if isinstance(pr, BetaMessageBatchCanceledResult) else JobStatus.expired ) ) - request_status_updates.append(RequestStatusUpdateInfo(batch_id=batch_id, agent_id=aid, request_status=status)) + request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status)) # translate provider‑specific response → OpenAI‑style tool call (unchanged) llm_client = LLMClient.create(llm_config=item.llm_config, put_inner_thoughts_first=True) @@ -270,7 +292,7 @@ class LettaAgentBatch: def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None: if updates: - self.batch_manager.bulk_update_batch_items_request_status_by_agent(updates=updates) + self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates) def _build_sandbox(self) -> Tuple[SandboxConfig, Dict[str, Any]]: sbx_type = SandboxType.E2B if tool_settings.e2b_api_key else SandboxType.LOCAL @@ -315,9 +337,11 @@ class LettaAgentBatch: self.message_manager.create_many_messages([m for msgs in msg_map.values() for m in msgs], actor=self.actor) return msg_map - def _mark_steps_complete(self, batch_id: str, agent_ids: List[str]) -> None: - updates = [StepStatusUpdateInfo(batch_id=batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids] - 
self.batch_manager.bulk_update_batch_items_step_status_by_agent(updates) + def _mark_steps_complete(self, llm_batch_id: str, agent_ids: List[str]) -> None: + updates = [ + StepStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids + ] + self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(updates) def _prepare_next_iteration( self, diff --git a/letta/jobs/llm_batch_job_polling.py b/letta/jobs/llm_batch_job_polling.py index 479f33b2..cfc0a0dd 100644 --- a/letta/jobs/llm_batch_job_polling.py +++ b/letta/jobs/llm_batch_job_polling.py @@ -102,7 +102,7 @@ async def poll_batch_updates(server: SyncServer, batch_jobs: List[LLMBatchJob], results: List[BatchPollingResult] = await asyncio.gather(*coros) # Update the server with batch status changes - server.batch_manager.bulk_update_batch_statuses(updates=results) + server.batch_manager.bulk_update_llm_batch_statuses(updates=results) logger.info(f"[Poll BatchJob] Bulk-updated {len(results)} LLM batch(es) in the DB at job level.") return results @@ -176,7 +176,7 @@ async def poll_running_llm_batches(server: "SyncServer") -> None: try: # 1. Retrieve running batch jobs - batches = server.batch_manager.list_running_batches() + batches = server.batch_manager.list_running_llm_batches() metrics.total_batches = len(batches) # TODO: Expand to more providers @@ -193,7 +193,7 @@ async def poll_running_llm_batches(server: "SyncServer") -> None: # 6. 
Bulk update all items for newly completed batch(es) if item_updates: metrics.updated_items_count = len(item_updates) - server.batch_manager.bulk_update_batch_items_results_by_agent(item_updates) + server.batch_manager.bulk_update_batch_llm_items_results_by_agent(item_updates) else: logger.info("[Poll BatchJob] No item-level updates needed.") diff --git a/letta/jobs/types.py b/letta/jobs/types.py index 854e0fef..f7143541 100644 --- a/letta/jobs/types.py +++ b/letta/jobs/types.py @@ -6,25 +6,25 @@ from letta.schemas.enums import AgentStepStatus, JobStatus class BatchPollingResult(NamedTuple): - batch_id: str + llm_batch_id: str request_status: JobStatus batch_response: Optional[BetaMessageBatch] class ItemUpdateInfo(NamedTuple): - batch_id: str + llm_batch_id: str agent_id: str request_status: JobStatus batch_request_result: Optional[BetaMessageBatchIndividualResponse] class StepStatusUpdateInfo(NamedTuple): - batch_id: str + llm_batch_id: str agent_id: str step_status: AgentStepStatus class RequestStatusUpdateInfo(NamedTuple): - batch_id: str + llm_batch_id: str agent_id: str request_status: JobStatus diff --git a/letta/orm/enums.py b/letta/orm/enums.py index 62aa5a42..03905eb7 100644 --- a/letta/orm/enums.py +++ b/letta/orm/enums.py @@ -16,6 +16,7 @@ class ToolType(str, Enum): class JobType(str, Enum): JOB = "job" RUN = "run" + BATCH = "batch" class ToolSourceType(str, Enum): diff --git a/letta/orm/llm_batch_items.py b/letta/orm/llm_batch_items.py index e11de396..d9d97d1a 100644 --- a/letta/orm/llm_batch_items.py +++ b/letta/orm/llm_batch_items.py @@ -20,7 +20,7 @@ class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin): __tablename__ = "llm_batch_items" __pydantic_model__ = PydanticLLMBatchItem __table_args__ = ( - Index("ix_llm_batch_items_batch_id", "batch_id"), + Index("ix_llm_batch_items_llm_batch_id", "llm_batch_id"), Index("ix_llm_batch_items_agent_id", "agent_id"), Index("ix_llm_batch_items_status", "request_status"), ) @@ -29,7 +29,7 @@ class 
LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin): # TODO: Some still rely on the Pydantic object to do this id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"batch_item-{uuid.uuid4()}") - batch_id: Mapped[str] = mapped_column( + llm_batch_id: Mapped[str] = mapped_column( ForeignKey("llm_batch_job.id", ondelete="CASCADE"), doc="Foreign key to the LLM provider batch this item belongs to" ) diff --git a/letta/orm/llm_batch_job.py b/letta/orm/llm_batch_job.py index d86a7564..db085dc7 100644 --- a/letta/orm/llm_batch_job.py +++ b/letta/orm/llm_batch_job.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import List, Optional, Union from anthropic.types.beta.messages import BetaMessageBatch -from sqlalchemy import DateTime, Index, String +from sqlalchemy import DateTime, ForeignKey, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.custom_columns import CreateBatchResponseColumn, PollBatchResponseColumn @@ -43,6 +43,9 @@ class LLMBatchJob(SqlalchemyBase, OrganizationMixin): DateTime(timezone=True), nullable=True, doc="Last time we polled the provider for status" ) - # relationships + letta_batch_job_id: Mapped[str] = mapped_column( + String, ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the Letta batch job" + ) + organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs") items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin") diff --git a/letta/schemas/job.py b/letta/schemas/job.py index 3d5c3b2c..2e4005c9 100644 --- a/letta/schemas/job.py +++ b/letta/schemas/job.py @@ -34,6 +34,41 @@ class Job(JobBase): user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the job.") +class BatchJob(JobBase): + id: str = JobBase.generate_id_field() + user_id: Optional[str] = Field(None, description="The unique identifier of the user 
associated with the job.") + job_type: JobType = JobType.BATCH + + @classmethod + def from_job(cls, job: Job) -> "BatchJob": + """ + Convert a Job instance to a BatchJob instance by replacing the ID prefix. + All other fields are copied as-is. + + Args: + job: The Job instance to convert + + Returns: + A new BatchJob instance with the same data copied over + """ + # Convert job dict to exclude None values + job_data = job.model_dump(exclude_none=True) + + # Create new BatchJob instance with converted data + return cls(**job_data) + + def to_job(self) -> Job: + """ + Convert this BatchJob instance to a Job instance by replacing the ID prefix. + All other fields are copied as-is. + + Returns: + A new Job instance with the same data but 'job-' prefix in ID + """ + run_data = self.model_dump(exclude_none=True) + return Job(**run_data) + + class JobUpdate(JobBase): status: Optional[JobStatus] = Field(None, description="The status of the job.") diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index 4aa66e62..64936891 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -31,3 +31,7 @@ class LettaStreamingRequest(LettaRequest): class LettaBatchRequest(LettaRequest): agent_id: str = Field(..., description="The ID of the agent to send this batch request for") + + +class CreateBatch(BaseModel): + requests: List[LettaBatchRequest] = Field(..., description="List of requests to be processed in batch.") diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index 662f0f8f..453fa30a 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -169,7 +169,8 @@ LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStat class LettaBatchResponse(BaseModel): - batch_id: str = Field(..., description="A unique identifier for this batch request.") + letta_batch_id: str = Field(..., description="A unique identifier for the Letta batch request.") + 
last_llm_batch_id: str = Field(..., description="A unique identifier for the most recent model provider batch request.") status: JobStatus = Field(..., description="The current status of the batch request.") agent_count: int = Field(..., description="The number of agents in the batch request.") last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.") diff --git a/letta/schemas/llm_batch_job.py b/letta/schemas/llm_batch_job.py index 5e178d32..cde072f1 100644 --- a/letta/schemas/llm_batch_job.py +++ b/letta/schemas/llm_batch_job.py @@ -20,7 +20,7 @@ class LLMBatchItem(OrmMetadataBase, validate_assignment=True): __id_prefix__ = "batch_item" id: Optional[str] = Field(None, description="The id of the batch item. Assigned by the database.") - batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.") + llm_batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.") agent_id: str = Field(..., description="The id of the agent associated with this LLM request.") llm_config: LLMConfig = Field(..., description="The LLM configuration used for this request.") @@ -45,6 +45,7 @@ class LLMBatchJob(OrmMetadataBase, validate_assignment=True): id: Optional[str] = Field(None, description="The id of the batch job. 
Assigned by the database.") status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).") llm_provider: ProviderType = Field(..., description="The LLM provider used for the batch (e.g., anthropic, openai).") + letta_batch_job_id: str = Field(..., description="ID of the Letta batch job") create_batch_response: Union[BetaMessageBatch] = Field(..., description="The full JSON response from the initial batch creation.") latest_polling_response: Optional[Union[BetaMessageBatch]] = Field( diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index f085f499..c5dc3be7 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -174,6 +174,7 @@ def create_application() -> "FastAPI": async def generic_error_handler(request: Request, exc: Exception): # Log the actual error for debugging log.error(f"Unhandled error: {exc}", exc_info=True) + print(f"Unhandled error: {exc}") # Print the stack trace print(f"Stack trace: {exc}") diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index 8b983fb7..d08a1d82 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -5,6 +5,7 @@ from letta.server.rest_api.routers.v1.health import router as health_router from letta.server.rest_api.routers.v1.identities import router as identities_router from letta.server.rest_api.routers.v1.jobs import router as jobs_router from letta.server.rest_api.routers.v1.llms import router as llm_router +from letta.server.rest_api.routers.v1.messages import router as messages_router from letta.server.rest_api.routers.v1.providers import router as providers_router from letta.server.rest_api.routers.v1.runs import router as runs_router from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router @@ -29,5 +30,6 @@ ROUTERS = [ runs_router, steps_router, tags_router, + messages_router, 
voice_router, ] diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 7295841e..72d8dc68 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -12,7 +12,6 @@ from sqlalchemy.exc import IntegrityError, OperationalError from starlette.responses import Response, StreamingResponse from letta.agents.letta_agent import LettaAgent -from letta.agents.letta_agent_batch import LettaAgentBatch from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.log import get_logger from letta.orm.errors import NoResultFound @@ -21,8 +20,8 @@ from letta.schemas.block import Block, BlockUpdate from letta.schemas.group import Group from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion -from letta.schemas.letta_request import LettaBatchRequest, LettaRequest, LettaStreamingRequest -from letta.schemas.letta_response import LettaBatchResponse, LettaResponse +from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest +from letta.schemas.letta_response import LettaResponse from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory from letta.schemas.message import MessageCreate from letta.schemas.passage import Passage, PassageUpdate @@ -832,57 +831,3 @@ async def list_agent_groups( actor = server.user_manager.get_user_or_default(user_id=actor_id) print("in list agents with manager_type", manager_type) return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor) - - -# Batch APIs - - -@router.post("/messages/batches", response_model=LettaBatchResponse, operation_id="create_batch_message_request") -async def send_batch_messages( - batch_requests: List[LettaBatchRequest] = Body(..., description="Messages and config for all agents"), - server: SyncServer = Depends(get_letta_server), - 
actor_id: Optional[str] = Header(None, alias="user_id"), -): - """ - Submit a batch of agent messages for asynchronous processing. - Creates a job that will fan out messages to all listed agents and process them in parallel. - """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - - batch_runner = LettaAgentBatch( - message_manager=server.message_manager, - agent_manager=server.agent_manager, - block_manager=server.block_manager, - passage_manager=server.passage_manager, - batch_manager=server.batch_manager, - sandbox_config_manager=server.sandbox_config_manager, - actor=actor, - ) - - return await batch_runner.step_until_request(batch_requests=batch_requests) - - -@router.get( - "/messages/batches/{batch_id}", - response_model=LettaBatchResponse, - operation_id="retrieve_batch_message_request", -) -async def retrieve_batch_message_request( - batch_id: str, - server: SyncServer = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), -): - """ - Retrieve the result or current status of a previously submitted batch message request. 
- """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - batch_job = server.batch_manager.get_batch_job_by_id(batch_id=batch_id, actor=actor) - agent_count = server.batch_manager.count_batch_items(batch_id=batch_id) - - return LettaBatchResponse( - batch_id=batch_id, - status=batch_job.status, - agent_count=agent_count, - last_polled_at=batch_job.last_polled_at, - created_at=batch_job.created_at, - ) diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py new file mode 100644 index 00000000..02f7f201 --- /dev/null +++ b/letta/server/rest_api/routers/v1/messages.py @@ -0,0 +1,127 @@ +from typing import List, Optional + +from fastapi import APIRouter, Body, Depends, Header +from fastapi.exceptions import HTTPException + +from letta.agents.letta_agent_batch import LettaAgentBatch +from letta.log import get_logger +from letta.orm.errors import NoResultFound +from letta.schemas.job import BatchJob, JobStatus, JobType +from letta.schemas.letta_request import CreateBatch +from letta.server.rest_api.utils import get_letta_server +from letta.server.server import SyncServer + +router = APIRouter(prefix="/messages", tags=["messages"]) + +logger = get_logger(__name__) + + +# Batch APIs + + +@router.post( + "/batches", + response_model=BatchJob, + operation_id="create_messages_batch", +) +async def create_messages_batch( + request: CreateBatch = Body(..., description="Messages and config for all agents"), + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Submit a batch of agent messages for asynchronous processing. + Creates a job that will fan out messages to all listed agents and process them in parallel. 
+ """ + logger.debug("Received create_messages_batch request: %s", request) + try: + actor = server.user_manager.get_user_or_default(user_id=actor_id) + logger.debug("Resolved actor: %s", actor) + + # Create a new job + batch_job = BatchJob( + user_id=actor.id, + status=JobStatus.created, + metadata={ + "job_type": "batch_messages", + }, + ) + logger.debug("Created batch job object: %s", batch_job) + + # create the batch runner + batch_runner = LettaAgentBatch( + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + batch_manager=server.batch_manager, + sandbox_config_manager=server.sandbox_config_manager, + job_manager=server.job_manager, + actor=actor, + ) + # Persist the Letta batch job first: llm_batch_job rows FK-reference jobs.id via letta_batch_job_id + batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor) + + # TODO: update run metadata + llm_batch_job = await batch_runner.step_until_request(batch_requests=request.requests, letta_batch_job_id=batch_job.id) + except Exception as e: + logger.error("Error creating batch job: %s", e) + import traceback + + traceback.print_exc() + raise + return batch_job + + +@router.get("/batches/{batch_id}", response_model=BatchJob, operation_id="retrieve_batch_run") +async def retrieve_batch_run( + batch_id: str, + actor_id: Optional[str] = Header(None, alias="user_id"), + server: "SyncServer" = Depends(get_letta_server), +): + """ + Get the status of a batch run. + """ + actor = server.user_manager.get_user_or_default(user_id=actor_id) + + try: + job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor) + return BatchJob.from_job(job) + except NoResultFound: + raise HTTPException(status_code=404, detail="Batch not found") + + +@router.get("/batches", response_model=List[BatchJob], operation_id="list_batch_runs") +async def list_batch_runs( + actor_id: Optional[str] = Header(None, alias="user_id"), + server: "SyncServer" = Depends(get_letta_server), +): + """ + List all batch runs. 
+ """ + # TODO: filter + actor = server.user_manager.get_user_or_default(user_id=actor_id) + + jobs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.BATCH) + logger.debug("Active batch jobs: %s", jobs) + return [BatchJob.from_job(job) for job in jobs] + + +@router.patch("/batches/{batch_id}/cancel", operation_id="cancel_batch_run") +async def cancel_batch_run( + batch_id: str, + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Cancel a batch run. + """ + actor = server.user_manager.get_user_or_default(user_id=actor_id) + + try: + job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor) + job.status = JobStatus.cancelled + server.job_manager.update_job_by_id(job_id=batch_id, job_update=job, actor=actor) + # TODO: actually cancel it + except NoResultFound: + raise HTTPException(status_code=404, detail="Batch not found") diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 7a4c7f3c..d5c4f577 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -15,6 +15,7 @@ from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step from letta.orm.step import Step as StepModel from letta.schemas.enums import JobStatus, MessageRole +from letta.schemas.job import BatchJob as PydanticBatchJob from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate, LettaRequestConfig from letta.schemas.letta_message import LettaMessage @@ -36,7 +37,9 @@ class JobManager: self.session_maker = db_context @enforce_types - def create_job(self, pydantic_job: Union[PydanticJob, PydanticRun], actor: PydanticUser) -> Union[PydanticJob, PydanticRun]: + def create_job( + self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser + ) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]: """Create a new job based on the JobCreate schema.""" with self.session_maker() as session: # 
Associate the job with the user diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index 72ca9847..0e944003 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -28,11 +28,12 @@ class LLMBatchManager: self.session_maker = db_context @enforce_types - def create_batch_job( + def create_llm_batch_job( self, llm_provider: ProviderType, create_batch_response: BetaMessageBatch, actor: PydanticUser, + letta_batch_job_id: str, status: JobStatus = JobStatus.created, ) -> PydanticLLMBatchJob: """Create a new LLM batch job.""" @@ -42,51 +43,52 @@ class LLMBatchManager: llm_provider=llm_provider, create_batch_response=create_batch_response, organization_id=actor.organization_id, + letta_batch_job_id=letta_batch_job_id, ) batch.create(session, actor=actor) return batch.to_pydantic() @enforce_types - def get_batch_job_by_id(self, batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob: + def get_llm_batch_job_by_id(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob: """Retrieve a single batch job by ID.""" with self.session_maker() as session: - batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor) + batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) return batch.to_pydantic() @enforce_types - def update_batch_status( + def update_llm_batch_status( self, - batch_id: str, + llm_batch_id: str, status: JobStatus, actor: Optional[PydanticUser] = None, latest_polling_response: Optional[BetaMessageBatch] = None, ) -> PydanticLLMBatchJob: """Update a batch job’s status and optionally its polling response.""" with self.session_maker() as session: - batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor) + batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) batch.status = status batch.latest_polling_response = latest_polling_response 
batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc) batch = batch.update(db_session=session, actor=actor) return batch.to_pydantic() - def bulk_update_batch_statuses( + def bulk_update_llm_batch_statuses( self, updates: List[BatchPollingResult], ) -> None: """ Efficiently update many LLMBatchJob rows. This is used by the cron jobs. - `updates` = [(batch_id, new_status, polling_response_or_None), …] + `updates` = [(llm_batch_id, new_status, polling_response_or_None), …] """ now = datetime.datetime.now(datetime.timezone.utc) with self.session_maker() as session: mappings = [] - for batch_id, status, response in updates: + for llm_batch_id, status, response in updates: mappings.append( { - "id": batch_id, + "id": llm_batch_id, "status": status, "latest_polling_response": response, "last_polled_at": now, @@ -97,14 +99,51 @@ class LLMBatchManager: session.commit() @enforce_types - def delete_batch_request(self, batch_id: str, actor: PydanticUser) -> None: + def list_llm_batch_jobs( + self, + letta_batch_id: str, + limit: Optional[int] = None, + actor: Optional[PydanticUser] = None, + after: Optional[str] = None, + ) -> List[PydanticLLMBatchJob]: + """ + List all LLM batch jobs for a given letta_batch_id, optionally filtered and limited in count. + + Optional filters: + - after: A cursor string. Only jobs with an `id` greater than this value are returned. + - actor: When provided, restrict results to the actor's organization. + - limit: Maximum number of jobs to return. + - letta_batch_id: Required; only jobs belonging to this Letta batch are returned. + + The results are ordered by their id in ascending order. 
+ """ + with self.session_maker() as session: + query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id) + + if actor is not None: + query = query.filter(LLMBatchJob.organization_id == actor.organization_id) + + # Additional optional filters + if after is not None: + query = query.filter(LLMBatchJob.id > after) + + query = query.order_by(LLMBatchJob.id.asc()) + + if limit is not None: + query = query.limit(limit) + + results = query.all() + return [item.to_pydantic() for item in results] + + @enforce_types + def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None: """Hard delete a batch job by ID.""" with self.session_maker() as session: - batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor) + batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) batch.hard_delete(db_session=session, actor=actor) @enforce_types - def list_running_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]: + def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]: """Return all running LLM batch jobs, optionally filtered by actor's organization.""" with self.session_maker() as session: query = session.query(LLMBatchJob).filter(LLMBatchJob.status == JobStatus.running) @@ -116,9 +155,9 @@ class LLMBatchManager: return [batch.to_pydantic() for batch in results] @enforce_types - def create_batch_item( + def create_llm_batch_item( self, - batch_id: str, + llm_batch_id: str, agent_id: str, llm_config: LLMConfig, actor: PydanticUser, @@ -129,7 +168,7 @@ class LLMBatchManager: """Create a new batch item.""" with self.session_maker() as session: item = LLMBatchItem( - batch_id=batch_id, + llm_batch_id=llm_batch_id, agent_id=agent_id, llm_config=llm_config, request_status=request_status, @@ -141,7 +180,7 @@ class LLMBatchManager: return item.to_pydantic() @enforce_types - def create_batch_items_bulk(self, 
llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser) -> List[PydanticLLMBatchItem]: + def create_llm_batch_items_bulk(self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser) -> List[PydanticLLMBatchItem]: """ Create multiple batch items in bulk for better performance. @@ -157,7 +196,7 @@ class LLMBatchManager: orm_items = [] for item in llm_batch_items: orm_item = LLMBatchItem( - batch_id=item.batch_id, + llm_batch_id=item.llm_batch_id, agent_id=item.agent_id, llm_config=item.llm_config, request_status=item.request_status, @@ -174,14 +213,14 @@ class LLMBatchManager: return [item.to_pydantic() for item in created_items] @enforce_types - def get_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem: + def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem: """Retrieve a single batch item by ID.""" with self.session_maker() as session: item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) return item.to_pydantic() @enforce_types - def update_batch_item( + def update_llm_batch_item( self, item_id: str, actor: PydanticUser, @@ -206,9 +245,9 @@ class LLMBatchManager: return item.update(db_session=session, actor=actor).to_pydantic() @enforce_types - def list_batch_items( + def list_llm_batch_items( self, - batch_id: str, + llm_batch_id: str, limit: Optional[int] = None, actor: Optional[PydanticUser] = None, after: Optional[str] = None, @@ -217,7 +256,7 @@ class LLMBatchManager: step_status: Optional[AgentStepStatus] = None, ) -> List[PydanticLLMBatchItem]: """ - List all batch items for a given batch_id, optionally filtered by additional criteria and limited in count. + List all batch items for a given llm_batch_id, optionally filtered by additional criteria and limited in count. Optional filters: - after: A cursor string. Only items with an `id` greater than this value are returned. 
@@ -228,7 +267,7 @@ class LLMBatchManager: The results are ordered by their id in ascending order. """ with self.session_maker() as session: - query = session.query(LLMBatchItem).filter(LLMBatchItem.batch_id == batch_id) + query = session.query(LLMBatchItem).filter(LLMBatchItem.llm_batch_id == llm_batch_id) if actor is not None: query = query.filter(LLMBatchItem.organization_id == actor.organization_id) @@ -251,36 +290,36 @@ class LLMBatchManager: results = query.all() return [item.to_pydantic() for item in results] - def bulk_update_batch_items( + def bulk_update_llm_batch_items( self, - batch_id_agent_id_pairs: List[Tuple[str, str]], + llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], ) -> None: """ - Efficiently update multiple LLMBatchItem rows by (batch_id, agent_id) pairs. + Efficiently update multiple LLMBatchItem rows by (llm_batch_id, agent_id) pairs. Args: - batch_id_agent_id_pairs: List of (batch_id, agent_id) tuples identifying items to update + llm_batch_id_agent_id_pairs: List of (llm_batch_id, agent_id) tuples identifying items to update field_updates: List of dictionaries containing the fields to update for each item """ - if not batch_id_agent_id_pairs or not field_updates: + if not llm_batch_id_agent_id_pairs or not field_updates: return - if len(batch_id_agent_id_pairs) != len(field_updates): + if len(llm_batch_id_agent_id_pairs) != len(field_updates): raise ValueError("batch_id_agent_id_pairs and field_updates must have the same length") with self.session_maker() as session: # Lookup primary keys items = ( - session.query(LLMBatchItem.id, LLMBatchItem.batch_id, LLMBatchItem.agent_id) - .filter(tuple_(LLMBatchItem.batch_id, LLMBatchItem.agent_id).in_(batch_id_agent_id_pairs)) + session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id) + .filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs)) .all() ) pair_to_pk = {(b, a): id for id, b, a in 
items} mappings = [] - for (batch_id, agent_id), fields in zip(batch_id_agent_id_pairs, field_updates): - pk_id = pair_to_pk.get((batch_id, agent_id)) + for (llm_batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates): + pk_id = pair_to_pk.get((llm_batch_id, agent_id)) if not pk_id: continue @@ -293,12 +332,12 @@ class LLMBatchManager: session.commit() @enforce_types - def bulk_update_batch_items_results_by_agent( + def bulk_update_batch_llm_items_results_by_agent( self, updates: List[ItemUpdateInfo], ) -> None: """Update request status and batch results for multiple batch items.""" - batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) for update in updates] + batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [ { "request_status": update.request_status, @@ -307,48 +346,48 @@ class LLMBatchManager: for update in updates ] - self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates) + self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates) @enforce_types - def bulk_update_batch_items_step_status_by_agent( + def bulk_update_llm_batch_items_step_status_by_agent( self, updates: List[StepStatusUpdateInfo], ) -> None: """Update step status for multiple batch items.""" - batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) for update in updates] + batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [{"step_status": update.step_status} for update in updates] - self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates) + self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates) @enforce_types - def bulk_update_batch_items_request_status_by_agent( + def bulk_update_llm_batch_items_request_status_by_agent( self, updates: List[RequestStatusUpdateInfo], ) -> None: """Update request status for multiple batch items.""" - batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) 
for update in updates] + batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [{"request_status": update.request_status} for update in updates] - self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates) + self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates) @enforce_types - def delete_batch_item(self, item_id: str, actor: PydanticUser) -> None: + def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None: """Hard delete a batch item by ID.""" with self.session_maker() as session: item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) item.hard_delete(db_session=session, actor=actor) @enforce_types - def count_batch_items(self, batch_id: str) -> int: + def count_llm_batch_items(self, llm_batch_id: str) -> int: """ - Efficiently count the number of batch items for a given batch_id. + Efficiently count the number of batch items for a given llm_batch_id. Args: - batch_id (str): The batch identifier to count items for. + llm_batch_id (str): The batch identifier to count items for. Returns: - int: The total number of batch items associated with the given batch_id. + int: The total number of batch items associated with the given llm_batch_id. 
""" with self.session_maker() as session: - count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.batch_id == batch_id).scalar() + count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar() return count or 0 diff --git a/tests/integration_test_batch.py b/tests/integration_test_batch.py new file mode 100644 index 00000000..adb16664 --- /dev/null +++ b/tests/integration_test_batch.py @@ -0,0 +1,80 @@ +# import time +# +# import pytest +# from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent +# +# +# @pytest.fixture(scope="module") +# def client(): +# return Letta(base_url="http://localhost:8283") +# +# +# def test_create_batch(client: Letta): +# +# # create agents +# agent1 = client.agents.create( +# name="agent1", +# memory_blocks=[{"label": "persona", "value": "you are agent 1"}], +# model="anthropic/claude-3-7-sonnet-20250219", +# embedding="letta/letta-free", +# ) +# agent2 = client.agents.create( +# name="agent2", +# memory_blocks=[{"label": "persona", "value": "you are agent 2"}], +# model="anthropic/claude-3-7-sonnet-20250219", +# embedding="letta/letta-free", +# ) +# +# # create a run +# run = client.messages.batches.create( +# requests=[ +# LettaBatchRequest( +# messages=[ +# MessageCreate( +# role="user", +# content=[ +# TextContent( +# text="text", +# ) +# ], +# ) +# ], +# agent_id=agent1.id, +# ), +# LettaBatchRequest( +# messages=[ +# MessageCreate( +# role="user", +# content=[ +# TextContent( +# text="text", +# ) +# ], +# ) +# ], +# agent_id=agent2.id, +# ), +# ] +# ) +# assert run is not None +# +# # list batches +# batches = client.messages.batches.list() +# assert len(batches) > 0, f"Expected 1 batch, got {len(batches)}" +# +# # check run status +# while True: +# run = client.messages.batches.retrieve(batch_id=run.id) +# if run.status == "completed": +# break +# print("Waiting for run to complete...", run.status) +# time.sleep(1) +# +# # get the batch results 
+# results = client.messages.batches.retrieve( +# run_id=run.id, +# ) +# assert results is not None +# print(results) +# +# # cancel a run diff --git a/tests/integration_test_batch_api_cron_jobs.py b/tests/integration_test_batch_api_cron_jobs.py index 7c0322e4..044192e1 100644 --- a/tests/integration_test_batch_api_cron_jobs.py +++ b/tests/integration_test_batch_api_cron_jobs.py @@ -23,6 +23,7 @@ from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.orm import Base from letta.schemas.agent import AgentStepState from letta.schemas.enums import JobStatus, ProviderType +from letta.schemas.job import BatchJob from letta.schemas.llm_config import LLMConfig from letta.schemas.tool_rule import InitToolRule from letta.server.db import db_context @@ -145,13 +146,21 @@ def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022" ) -def create_test_batch_job(server, batch_response, default_user): +def create_test_letta_batch_job(server, default_user): """Create a test batch job with the given batch response.""" - return server.batch_manager.create_batch_job( + return server.job_manager.create_job(BatchJob(user_id=default_user.id), actor=default_user) + + +def create_test_llm_batch_job(server, batch_response, default_user): + """Create a test batch job with the given batch response.""" + letta_batch_job = create_test_letta_batch_job(server, default_user) + + return server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, create_batch_response=batch_response, actor=default_user, status=JobStatus.running, + letta_batch_job_id=letta_batch_job.id, ) @@ -171,8 +180,8 @@ def create_test_batch_item(server, batch_id, agent_id, default_user): step_number=1, tool_rules_solver=ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")]) ) - return server.batch_manager.create_batch_item( - batch_id=batch_id, + return server.batch_manager.create_llm_batch_item( + llm_batch_id=batch_id, agent_id=agent_id, 
llm_config=dummy_llm_config, step_state=common_step_state, @@ -242,8 +251,8 @@ async def test_polling_mixed_batch_jobs(client, default_user, server): agent_c = create_test_agent(client, "agent_c") # --- Step 2: Create batch jobs --- - job_a = create_test_batch_job(server, batch_a_resp, default_user) - job_b = create_test_batch_job(server, batch_b_resp, default_user) + job_a = create_test_llm_batch_job(server, batch_a_resp, default_user) + job_b = create_test_llm_batch_job(server, batch_b_resp, default_user) # --- Step 3: Create batch items --- item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user) @@ -258,8 +267,8 @@ async def test_polling_mixed_batch_jobs(client, default_user, server): await poll_running_llm_batches(server) # --- Step 6: Verify batch job status updates --- - updated_job_a = server.batch_manager.get_batch_job_by_id(batch_id=job_a.id, actor=default_user) - updated_job_b = server.batch_manager.get_batch_job_by_id(batch_id=job_b.id, actor=default_user) + updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user) + updated_job_b = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_b.id, actor=default_user) # Job A should remain running since its processing_status is "in_progress" assert updated_job_a.status == JobStatus.running @@ -273,17 +282,17 @@ async def test_polling_mixed_batch_jobs(client, default_user, server): # --- Step 7: Verify batch item status updates --- # Item A should remain unchanged - updated_item_a = server.batch_manager.get_batch_item_by_id(item_a.id, actor=default_user) + updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) assert updated_item_a.request_status == JobStatus.created assert updated_item_a.batch_request_result is None # Item B should be marked as completed with a successful result - updated_item_b = server.batch_manager.get_batch_item_by_id(item_b.id, actor=default_user) + updated_item_b = 
server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) assert updated_item_b.request_status == JobStatus.completed assert updated_item_b.batch_request_result is not None # Item C should be marked as failed with an error result - updated_item_c = server.batch_manager.get_batch_item_by_id(item_c.id, actor=default_user) + updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) assert updated_item_c.request_status == JobStatus.failed assert updated_item_c.batch_request_result is not None @@ -307,11 +316,11 @@ async def test_polling_mixed_batch_jobs(client, default_user, server): # --- Step 9: Verify that nothing changed for completed jobs --- # Refresh all objects - final_job_a = server.batch_manager.get_batch_job_by_id(batch_id=job_a.id, actor=default_user) - final_job_b = server.batch_manager.get_batch_job_by_id(batch_id=job_b.id, actor=default_user) - final_item_a = server.batch_manager.get_batch_item_by_id(item_a.id, actor=default_user) - final_item_b = server.batch_manager.get_batch_item_by_id(item_b.id, actor=default_user) - final_item_c = server.batch_manager.get_batch_item_by_id(item_c.id, actor=default_user) + final_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user) + final_job_b = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_b.id, actor=default_user) + final_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) + final_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) + final_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) # Job A should still be polling (last_polled_at should update) assert final_job_a.status == JobStatus.running diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index 1e13ec3b..9f5d5f36 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -33,6 +33,7 @@ 
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.orm import Base from letta.schemas.agent import AgentState, AgentStepState from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType +from letta.schemas.job import BatchJob from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_request import LettaBatchRequest from letta.schemas.message import MessageCreate @@ -256,7 +257,7 @@ def clear_batch_tables(): """Clear batch-related tables before each test.""" with db_context() as session: for table in reversed(Base.metadata.sorted_tables): - if table.name in {"llm_batch_job", "llm_batch_items"}: + if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}: session.execute(table.delete()) # Truncate table session.commit() @@ -305,6 +306,22 @@ def client(server_url): return Letta(base_url=server_url) +@pytest.fixture +def batch_job(default_user, server): + job = BatchJob( + user_id=default_user.id, + status=JobStatus.created, + metadata={ + "job_type": "batch_messages", + }, + ) + job = server.job_manager.create_job(pydantic_job=job, actor=default_user) + yield job + + # cleanup + server.job_manager.delete_job_by_id(job.id, actor=default_user) + + class MockAsyncIterable: def __init__(self, items): self.items = items @@ -324,8 +341,8 @@ class MockAsyncIterable: @pytest.mark.asyncio -async def test_resume_step_after_request_happy_path( - disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map +async def test_resume_step_after_request_all_continue( + disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): anthropic_batch_id = "msgbatch_test_12345" dummy_batch_response = create_batch_response( @@ -342,6 +359,7 @@ async def test_resume_step_after_request_happy_path( passage_manager=server.passage_manager, batch_manager=server.batch_manager, sandbox_config_manager=server.sandbox_config_manager, + 
job_manager=server.job_manager, actor=default_user, ) @@ -349,15 +367,20 @@ async def test_resume_step_after_request_happy_path( pre_resume_response = await batch_runner.step_until_request( batch_requests=batch_requests, agent_step_state_mapping=step_state_map, + letta_batch_job_id=batch_job.id, ) # Basic sanity checks (This is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly` # Verify batch items - items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user) - assert len(items) == 3, f"Expected 3 batch items, got {len(items)}" + llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user) + assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}" + + llm_batch_job = llm_batch_jobs[0] + llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) + assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}" # 2. 
Invoke the polling job and mock responses from Anthropic - mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.batch_id, processing_status="ended")) + mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended")) with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve): mock_items = [ @@ -372,13 +395,13 @@ async def test_resume_step_after_request_happy_path( await poll_running_llm_batches(server) # Verify database records were updated correctly - job = server.batch_manager.get_batch_job_by_id(pre_resume_response.batch_id, actor=default_user) + llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user) # Verify job properties - assert job.status == JobStatus.completed, "Job status should be 'completed'" + assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'" # Verify batch items - items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user) + items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) assert len(items) == 3, f"Expected 3 batch items, got {len(items)}" assert all([item.request_status == JobStatus.completed for item in items]) @@ -390,22 +413,27 @@ async def test_resume_step_after_request_happy_path( passage_manager=server.passage_manager, batch_manager=server.batch_manager, sandbox_config_manager=server.sandbox_config_manager, + job_manager=server.job_manager, actor=default_user, ) with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} - post_resume_response = await letta_batch_agent.resume_step_after_request(batch_id=pre_resume_response.batch_id) + post_resume_response = await 
letta_batch_agent.resume_step_after_request( + letta_batch_id=pre_resume_response.letta_batch_id, llm_batch_id=llm_batch_job.id + ) - # A *new* batch job should have been spawned assert ( - post_resume_response.batch_id != pre_resume_response.batch_id - ), "resume_step_after_request is expected to enqueue a follow‑up batch job." + post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id + ), "resume_step_after_request is expected to have the same letta_batch_id" + assert ( + post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id + ), "resume_step_after_request is expected to have different llm_batch_id." assert post_resume_response.status == JobStatus.running assert post_resume_response.agent_count == 3 # New batch‑items should exist, initialised in (created, paused) state - new_items = server.batch_manager.list_batch_items(batch_id=post_resume_response.batch_id, actor=default_user) + new_items = server.batch_manager.list_llm_batch_items(llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user) assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}" assert {i.request_status for i in new_items} == {JobStatus.created} assert {i.step_status for i in new_items} == {AgentStepStatus.paused} @@ -420,7 +448,7 @@ async def test_resume_step_after_request_happy_path( # Old items must have been flipped to completed / finished earlier # (sanity – we already asserted this above, but we keep it close for clarity) - old_items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user) + old_items = server.batch_manager.list_llm_batch_items(llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user) assert {i.request_status for i in old_items} == {JobStatus.completed} assert {i.step_status for i in old_items} == {AgentStepStatus.completed} @@ -440,7 +468,7 @@ async def test_resume_step_after_request_happy_path( @pytest.mark.asyncio async def 
test_step_until_request_prepares_and_submits_batch_correctly( - disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response + disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job ): """ Test that step_until_request correctly: @@ -512,6 +540,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly( passage_manager=server.passage_manager, batch_manager=server.batch_manager, sandbox_config_manager=server.sandbox_config_manager, + job_manager=server.job_manager, actor=default_user, ) @@ -519,23 +548,25 @@ async def test_step_until_request_prepares_and_submits_batch_correctly( response = await batch_runner.step_until_request( batch_requests=batch_requests, agent_step_state_mapping=step_state_map, + letta_batch_job_id=batch_job.id, ) # Verify the mock was called exactly once mock_send.assert_called_once() # Verify database records were created correctly - job = server.batch_manager.get_batch_job_by_id(response.batch_id, actor=default_user) + llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=response.letta_batch_id, actor=default_user) + assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}" + + llm_batch_job = llm_batch_jobs[0] + llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) + assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}" # Verify job properties - assert job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic" - assert job.status == JobStatus.running, "Job status should be 'running'" - - # Verify batch items - items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user) - assert len(items) == 3, f"Expected 3 batch items, got {len(items)}" + assert llm_batch_job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic" + 
assert llm_batch_job.status == JobStatus.running, "Job status should be 'running'" # Verify all agents are represented in batch items - agent_ids_in_items = {item.agent_id for item in items} + agent_ids_in_items = {item.agent_id for item in llm_batch_items} expected_agent_ids = {agent.id for agent in agents} assert agent_ids_in_items == expected_agent_ids, f"Expected agent IDs {expected_agent_ids}, got {agent_ids_in_items}" diff --git a/tests/test_managers.py b/tests/test_managers.py index 45fafe5d..2170b091 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -39,6 +39,8 @@ from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, Provide from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate, IdentityUpsert +from letta.schemas.job import BatchJob +from letta.schemas.job import Job from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate, LettaRequestConfig from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage @@ -615,6 +617,11 @@ def dummy_successful_response() -> BetaMessageBatchIndividualResponse: ) +@pytest.fixture +def letta_batch_job(server: SyncServer, default_user) -> Job: + return server.job_manager.create_job(BatchJob(user_id=default_user.id), actor=default_user) + + # ====================================================================================================================== # AgentManager Tests - Basic # ====================================================================================================================== @@ -4761,77 +4768,90 @@ def test_list_tags(server: SyncServer, default_user, default_organization): # 
====================================================================================================================== -def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch): - batch = server.batch_manager.create_batch_job( +def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch, letta_batch_job): + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) assert batch.id.startswith("batch_req-") assert batch.create_batch_response == dummy_beta_message_batch - fetched = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user) + fetched = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user) assert fetched.id == batch.id -def test_update_batch_status(server, default_user, dummy_beta_message_batch): - batch = server.batch_manager.create_batch_job( +def test_update_batch_status(server, default_user, dummy_beta_message_batch, letta_batch_job): + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) before = datetime.now(timezone.utc) - server.batch_manager.update_batch_status( - batch_id=batch.id, + server.batch_manager.update_llm_batch_status( + llm_batch_id=batch.id, status=JobStatus.completed, latest_polling_response=dummy_beta_message_batch, actor=default_user, ) - updated = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user) + updated = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user) assert updated.status == JobStatus.completed assert updated.latest_polling_response == dummy_beta_message_batch assert updated.last_polled_at >= before -def test_create_and_get_batch_item(server, default_user, 
sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): - batch = server.batch_manager.create_batch_job( +def test_create_and_get_batch_item( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +): + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_batch_item( - batch_id=batch.id, + item = server.batch_manager.create_llm_batch_item( + llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, step_state=dummy_step_state, actor=default_user, ) - assert item.batch_id == batch.id + assert item.llm_batch_id == batch.id assert item.agent_id == sarah_agent.id assert item.step_state == dummy_step_state - fetched = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + fetched = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) assert fetched.id == item.id def test_update_batch_item( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response + server, + default_user, + sarah_agent, + dummy_beta_message_batch, + dummy_llm_config, + dummy_step_state, + dummy_successful_response, + letta_batch_job, ): - batch = server.batch_manager.create_batch_job( + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_batch_item( - batch_id=batch.id, + item = server.batch_manager.create_llm_batch_item( + llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, step_state=dummy_step_state, @@ -4840,7 +4860,7 @@ def test_update_batch_item( 
updated_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver) - server.batch_manager.update_batch_item( + server.batch_manager.update_llm_batch_item( item_id=item.id, request_status=JobStatus.completed, step_status=AgentStepStatus.resumed, @@ -4849,146 +4869,166 @@ def test_update_batch_item( actor=default_user, ) - updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) assert updated.request_status == JobStatus.completed assert updated.batch_request_result == dummy_successful_response -def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): - batch = server.batch_manager.create_batch_job( +def test_delete_batch_item( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +): + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_batch_item( - batch_id=batch.id, + item = server.batch_manager.create_llm_batch_item( + llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, step_state=dummy_step_state, actor=default_user, ) - server.batch_manager.delete_batch_item(item_id=item.id, actor=default_user) + server.batch_manager.delete_llm_batch_item(item_id=item.id, actor=default_user) with pytest.raises(NoResultFound): - server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) -def test_list_running_batches(server, default_user, dummy_beta_message_batch): - server.batch_manager.create_batch_job( +def test_list_running_batches(server, default_user, dummy_beta_message_batch, 
letta_batch_job): + server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, status=JobStatus.running, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) - running_batches = server.batch_manager.list_running_batches(actor=default_user) + running_batches = server.batch_manager.list_running_llm_batches(actor=default_user) assert len(running_batches) >= 1 assert all(batch.status == JobStatus.running for batch in running_batches) -def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch): - batch = server.batch_manager.create_batch_job( +def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch, letta_batch_job): + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) - server.batch_manager.bulk_update_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)]) + server.batch_manager.bulk_update_llm_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)]) - updated = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user) + updated = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user) assert updated.status == JobStatus.completed assert updated.latest_polling_response == dummy_beta_message_batch def test_bulk_update_batch_items_results_by_agent( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response + server, + default_user, + sarah_agent, + dummy_beta_message_batch, + dummy_llm_config, + dummy_step_state, + dummy_successful_response, + letta_batch_job, ): - batch = server.batch_manager.create_batch_job( + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, 
create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_batch_item( - batch_id=batch.id, + item = server.batch_manager.create_llm_batch_item( + llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, step_state=dummy_step_state, actor=default_user, ) - server.batch_manager.bulk_update_batch_items_results_by_agent( + server.batch_manager.bulk_update_batch_llm_items_results_by_agent( [ItemUpdateInfo(batch.id, sarah_agent.id, JobStatus.completed, dummy_successful_response)] ) - updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) assert updated.request_status == JobStatus.completed assert updated.batch_request_result == dummy_successful_response def test_bulk_update_batch_items_step_status_by_agent( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job ): - batch = server.batch_manager.create_batch_job( + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_batch_item( - batch_id=batch.id, + item = server.batch_manager.create_llm_batch_item( + llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, step_state=dummy_step_state, actor=default_user, ) - server.batch_manager.bulk_update_batch_items_step_status_by_agent( + server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( [StepStatusUpdateInfo(batch.id, sarah_agent.id, AgentStepStatus.resumed)] ) - updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + updated = 
server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) assert updated.step_status == AgentStepStatus.resumed -def test_list_batch_items_limit_and_filter(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): - batch = server.batch_manager.create_batch_job( +def test_list_batch_items_limit_and_filter( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +): + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) for _ in range(3): - server.batch_manager.create_batch_item( - batch_id=batch.id, + server.batch_manager.create_llm_batch_item( + llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, step_state=dummy_step_state, actor=default_user, ) - all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user) - limited_items = server.batch_manager.list_batch_items(batch_id=batch.id, limit=2, actor=default_user) + all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user) + limited_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, limit=2, actor=default_user) assert len(all_items) >= 3 assert len(limited_items) == 2 -def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): +def test_list_batch_items_pagination( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +): # Create a batch job. - batch = server.batch_manager.create_batch_job( + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) # Create 10 batch items. 
created_items = [] for i in range(10): - item = server.batch_manager.create_batch_item( - batch_id=batch.id, + item = server.batch_manager.create_llm_batch_item( + llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, step_state=dummy_step_state, @@ -4997,7 +5037,7 @@ def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_be created_items.append(item) # Retrieve all items (without pagination). - all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user) + all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user) assert len(all_items) >= 10, f"Expected at least 10 items, got {len(all_items)}" # Verify the items are ordered ascending by id (based on our implementation). @@ -5009,7 +5049,7 @@ def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_be cursor = all_items[4].id # Retrieve items after the cursor. - paged_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor) + paged_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=cursor) # All returned items should have an id greater than the cursor. for item in paged_items: @@ -5023,7 +5063,7 @@ def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_be # Test pagination with a limit. limit = 3 - limited_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor, limit=limit) + limited_page = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=cursor, limit=limit) # If more than 'limit' items remain, we should only get exactly 'limit' items. assert len(limited_page) == min( limit, expected_remaining @@ -5031,23 +5071,24 @@ def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_be # Optional: Test with a cursor beyond the last item returns an empty list. 
last_cursor = sorted_ids[-1] - empty_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=last_cursor) + empty_page = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=last_cursor) assert empty_page == [], "Expected an empty list when cursor is after the last item" def test_bulk_update_batch_items_request_status_by_agent( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job ): # Create a batch job - batch = server.batch_manager.create_batch_job( + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) # Create a batch item - item = server.batch_manager.create_batch_item( - batch_id=batch.id, + item = server.batch_manager.create_llm_batch_item( + llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, step_state=dummy_step_state, @@ -5055,55 +5096,59 @@ def test_bulk_update_batch_items_request_status_by_agent( ) # Update the request status using the bulk update method - server.batch_manager.bulk_update_batch_items_request_status_by_agent( + server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( [RequestStatusUpdateInfo(batch.id, sarah_agent.id, JobStatus.expired)] ) # Verify the update was applied - updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) assert updated.request_status == JobStatus.expired -def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response): +def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response, 
letta_batch_job): # Create a batch job - batch = server.batch_manager.create_batch_job( + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) # Attempt to update non-existent items should not raise errors - # Test with the direct bulk_update_batch_items method + # Test with the direct bulk_update_llm_batch_items method nonexistent_pairs = [(batch.id, "nonexistent-agent-id")] nonexistent_updates = [{"request_status": JobStatus.expired}] # This should not raise an error, just silently skip non-existent items - server.batch_manager.bulk_update_batch_items(nonexistent_pairs, nonexistent_updates) + server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates) # Test with higher-level methods # Results by agent - server.batch_manager.bulk_update_batch_items_results_by_agent( + server.batch_manager.bulk_update_batch_llm_items_results_by_agent( [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)] ) # Step status by agent - server.batch_manager.bulk_update_batch_items_step_status_by_agent( + server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)] ) # Request status by agent - server.batch_manager.bulk_update_batch_items_request_status_by_agent( + server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)] ) -def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): +def test_create_batch_items_bulk( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +): # Create a batch job - batch = server.batch_manager.create_batch_job( + llm_batch_job = 
server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) # Prepare data for multiple batch items @@ -5112,7 +5157,7 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m for agent_id in agent_ids: batch_item = LLMBatchItem( - batch_id=batch.id, + llm_batch_id=llm_batch_job.id, agent_id=agent_id, llm_config=dummy_llm_config, request_status=JobStatus.created, @@ -5122,7 +5167,7 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m batch_items.append(batch_item) # Call the bulk create function - created_items = server.batch_manager.create_batch_items_bulk(batch_items, actor=default_user) + created_items = server.batch_manager.create_llm_batch_items_bulk(batch_items, actor=default_user) # Verify the correct number of items were created assert len(created_items) == len(agent_ids) @@ -5130,7 +5175,7 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m # Verify each item has expected properties for item in created_items: assert item.id.startswith("batch_item-") - assert item.batch_id == batch.id + assert item.llm_batch_id == llm_batch_job.id assert item.agent_id in agent_ids assert item.llm_config == dummy_llm_config assert item.request_status == JobStatus.created @@ -5138,38 +5183,41 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m assert item.step_state == dummy_step_state # Verify items can be retrieved from the database - all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user) + all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) assert len(all_items) >= len(agent_ids) # Verify the IDs of created items match what's in the database created_ids = [item.id for item in created_items] for item_id in created_ids: - fetched = 
server.batch_manager.get_batch_item_by_id(item_id, actor=default_user) + fetched = server.batch_manager.get_llm_batch_item_by_id(item_id, actor=default_user) assert fetched.id in created_ids -def test_count_batch_items(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): +def test_count_batch_items( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +): # Create a batch job first. - batch = server.batch_manager.create_batch_job( + batch = server.batch_manager.create_llm_batch_job( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, actor=default_user, + letta_batch_job_id=letta_batch_job.id, ) # Create a specific number of batch items for this batch. num_items = 5 for _ in range(num_items): - server.batch_manager.create_batch_item( - batch_id=batch.id, + server.batch_manager.create_llm_batch_item( + llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, step_state=dummy_step_state, actor=default_user, ) - # Use the count_batch_items method to count the items. - count = server.batch_manager.count_batch_items(batch_id=batch.id) + # Use the count_llm_batch_items method to count the items. + count = server.batch_manager.count_llm_batch_items(llm_batch_id=batch.id) # Assert that the count matches the expected number. assert count == num_items, f"Expected {num_items} items, got {count}"