feat: add batch job tracking and generate batch APIs (#1727)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
"""Rename batch_id to llm_batch_id on llm_batch_item
|
||||
|
||||
Revision ID: 7b189006c97d
|
||||
Revises: f2f78d62005c
|
||||
Create Date: 2025-04-17 16:04:52.045672
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "7b189006c97d"
|
||||
down_revision: Union[str, None] = "f2f78d62005c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Rename ``llm_batch_items.batch_id`` to ``llm_batch_id``.

    The column is renamed in place rather than added and dropped. Adding a
    ``NOT NULL`` column with no default fails on a table that already holds
    rows, and dropping the old column would discard the existing values.
    The index and foreign key that reference the column are recreated under
    names that match the new column.
    """
    # Drop the objects that are named after the old column first.
    op.drop_index("ix_llm_batch_items_batch_id", table_name="llm_batch_items")
    op.drop_constraint("llm_batch_items_batch_id_fkey", "llm_batch_items", type_="foreignkey")

    # In-place rename keeps every existing value.
    op.alter_column(
        "llm_batch_items",
        "batch_id",
        new_column_name="llm_batch_id",
        existing_type=sa.String(),
        existing_nullable=False,
    )

    op.create_index("ix_llm_batch_items_llm_batch_id", "llm_batch_items", ["llm_batch_id"], unique=False)
    # Explicit name so the downgrade can drop the constraint. It follows the
    # same "<table>_<column>_fkey" pattern as the constraint dropped above.
    op.create_foreign_key(
        "llm_batch_items_llm_batch_id_fkey",
        "llm_batch_items",
        "llm_batch_job",
        ["llm_batch_id"],
        ["id"],
        ondelete="CASCADE",
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Rename ``llm_batch_items.llm_batch_id`` back to ``batch_id``.

    Reverses the upgrade: drops the index and foreign key named after
    ``llm_batch_id``, renames the column in place so no data is lost, then
    restores the original index and foreign key.
    """
    # A constraint cannot be dropped without a name (the auto-generated
    # ``None`` placeholder is not usable).
    # NOTE(review): assumes the FK carries the "<table>_<column>_fkey" name,
    # the same pattern as the constraint restored below -- confirm against the
    # target database before running.
    op.drop_constraint("llm_batch_items_llm_batch_id_fkey", "llm_batch_items", type_="foreignkey")
    op.drop_index("ix_llm_batch_items_llm_batch_id", table_name="llm_batch_items")

    # In-place rename keeps every existing value.
    op.alter_column(
        "llm_batch_items",
        "llm_batch_id",
        new_column_name="batch_id",
        existing_type=sa.String(),
        existing_nullable=False,
    )

    op.create_index("ix_llm_batch_items_batch_id", "llm_batch_items", ["batch_id"], unique=False)
    op.create_foreign_key(
        "llm_batch_items_batch_id_fkey",
        "llm_batch_items",
        "llm_batch_job",
        ["batch_id"],
        ["id"],
        ondelete="CASCADE",
    )
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Add letta batch job id to llm_batch_job
|
||||
|
||||
Revision ID: f2f78d62005c
|
||||
Revises: c3b1da3d1157
|
||||
Create Date: 2025-04-17 15:58:43.705483
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f2f78d62005c"
|
||||
down_revision: Union[str, None] = "c3b1da3d1157"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add ``llm_batch_job.letta_batch_job_id`` with a foreign key to ``jobs.id``.

    The foreign key is given an explicit name so that it can be dropped by
    name on downgrade.
    """
    # NOTE(review): the column is NOT NULL with no server default, so this
    # statement will be rejected if ``llm_batch_job`` already contains rows.
    # Confirm the table is empty (or backfill) before running in production.
    op.add_column("llm_batch_job", sa.Column("letta_batch_job_id", sa.String(), nullable=False))
    op.create_foreign_key(
        "llm_batch_job_letta_batch_job_id_fkey",
        "llm_batch_job",
        "jobs",
        ["letta_batch_job_id"],
        ["id"],
        ondelete="CASCADE",
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove ``llm_batch_job.letta_batch_job_id`` and its foreign key."""
    # A constraint cannot be dropped without a name (the auto-generated
    # ``None`` placeholder is not usable).
    # NOTE(review): assumes the FK is named "<table>_<column>_fkey" -- confirm
    # against the target database before running.
    op.drop_constraint("llm_batch_job_letta_batch_job_id_fkey", "llm_batch_job", type_="foreignkey")
    op.drop_column("llm_batch_job", "letta_batch_job_id")
|
||||
@@ -17,6 +17,7 @@ from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState, AgentStepState
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_request import LettaBatchRequest
|
||||
from letta.schemas.letta_response import LettaBatchResponse
|
||||
@@ -29,6 +30,7 @@ from letta.server.rest_api.utils import create_heartbeat_system_message, create_
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.llm_batch_manager import LLMBatchManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
@@ -102,6 +104,7 @@ class LettaAgentBatch:
|
||||
passage_manager: PassageManager,
|
||||
batch_manager: LLMBatchManager,
|
||||
sandbox_config_manager: SandboxConfigManager,
|
||||
job_manager: JobManager,
|
||||
actor: User,
|
||||
use_assistant_message: bool = True,
|
||||
max_steps: int = 10,
|
||||
@@ -112,12 +115,16 @@ class LettaAgentBatch:
|
||||
self.passage_manager = passage_manager
|
||||
self.batch_manager = batch_manager
|
||||
self.sandbox_config_manager = sandbox_config_manager
|
||||
self.job_manager = job_manager
|
||||
self.use_assistant_message = use_assistant_message
|
||||
self.actor = actor
|
||||
self.max_steps = max_steps
|
||||
|
||||
async def step_until_request(
|
||||
self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None
|
||||
self,
|
||||
batch_requests: List[LettaBatchRequest],
|
||||
letta_batch_job_id: str,
|
||||
agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None,
|
||||
) -> LettaBatchResponse:
|
||||
# Basic checks
|
||||
if not batch_requests:
|
||||
@@ -162,11 +169,12 @@ class LettaAgentBatch:
|
||||
)
|
||||
|
||||
# Write the response into the jobs table, where it will get picked up by the next cron run
|
||||
batch_job = self.batch_manager.create_batch_job(
|
||||
llm_batch_job = self.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic, # TODO: Expand to more providers
|
||||
create_batch_response=batch_response,
|
||||
actor=self.actor,
|
||||
status=JobStatus.running,
|
||||
letta_batch_job_id=letta_batch_job_id,
|
||||
)
|
||||
|
||||
# Create batch items in bulk for all agents
|
||||
@@ -174,7 +182,7 @@ class LettaAgentBatch:
|
||||
for agent_state in agent_states:
|
||||
agent_step_state = agent_step_state_mapping.get(agent_state.id)
|
||||
batch_item = LLMBatchItem(
|
||||
batch_id=batch_job.id,
|
||||
llm_batch_id=llm_batch_job.id,
|
||||
agent_id=agent_state.id,
|
||||
llm_config=agent_state.llm_config,
|
||||
request_status=JobStatus.created,
|
||||
@@ -185,19 +193,21 @@ class LettaAgentBatch:
|
||||
|
||||
# Create all batch items at once using the bulk operation
|
||||
if batch_items:
|
||||
self.batch_manager.create_batch_items_bulk(batch_items, actor=self.actor)
|
||||
self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor)
|
||||
|
||||
return LettaBatchResponse(
|
||||
batch_id=batch_job.id,
|
||||
status=batch_job.status,
|
||||
letta_batch_id=llm_batch_job.letta_batch_job_id,
|
||||
last_llm_batch_id=llm_batch_job.id,
|
||||
status=llm_batch_job.status,
|
||||
agent_count=len(agent_states),
|
||||
last_polled_at=get_utc_time(),
|
||||
created_at=batch_job.created_at,
|
||||
created_at=llm_batch_job.created_at,
|
||||
)
|
||||
|
||||
async def resume_step_after_request(self, batch_id: str) -> LettaBatchResponse:
|
||||
async def resume_step_after_request(self, letta_batch_id: str, llm_batch_id: str) -> LettaBatchResponse:
|
||||
# 1. gather everything we need
|
||||
ctx = await self._collect_resume_context(batch_id)
|
||||
llm_batch_job = self.batch_manager.get_llm_batch_job_by_id(llm_batch_id=llm_batch_id, actor=self.actor)
|
||||
ctx = await self._collect_resume_context(llm_batch_id)
|
||||
|
||||
# 2. persist request‑level status updates
|
||||
self._update_request_statuses(ctx.request_status_updates)
|
||||
@@ -209,19 +219,31 @@ class LettaAgentBatch:
|
||||
msg_map = self._persist_tool_messages(exec_results, ctx)
|
||||
|
||||
# 5. mark steps complete
|
||||
self._mark_steps_complete(batch_id, ctx.agent_ids)
|
||||
self._mark_steps_complete(llm_batch_id, ctx.agent_ids)
|
||||
|
||||
# 6. build next‑round requests / step‑state map
|
||||
next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map)
|
||||
if len(next_reqs) == 0:
|
||||
# mark batch job as completed
|
||||
self.job_manager.update_job_by_id(job_id=letta_batch_id, job_update=JobUpdate(status=JobStatus.completed), actor=self.actor)
|
||||
return LettaBatchResponse(
|
||||
letta_batch_id=llm_batch_job.letta_batch_job_id,
|
||||
last_llm_batch_id=llm_batch_job.id,
|
||||
status=JobStatus.completed,
|
||||
agent_count=len(ctx.agent_ids),
|
||||
last_polled_at=get_utc_time(),
|
||||
created_at=llm_batch_job.created_at,
|
||||
)
|
||||
|
||||
# 7. recurse into the normal stepping pipeline
|
||||
return await self.step_until_request(
|
||||
batch_requests=next_reqs,
|
||||
letta_batch_job_id=letta_batch_id,
|
||||
agent_step_state_mapping=next_step_state,
|
||||
)
|
||||
|
||||
async def _collect_resume_context(self, batch_id: str) -> _ResumeContext:
|
||||
batch_items = self.batch_manager.list_batch_items(batch_id=batch_id)
|
||||
async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext:
|
||||
batch_items = self.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_id)
|
||||
|
||||
agent_ids, agent_state_map = [], {}
|
||||
provider_results, name_map, args_map, cont_map = {}, {}, {}, {}
|
||||
@@ -244,7 +266,7 @@ class LettaAgentBatch:
|
||||
else JobStatus.cancelled if isinstance(pr, BetaMessageBatchCanceledResult) else JobStatus.expired
|
||||
)
|
||||
)
|
||||
request_status_updates.append(RequestStatusUpdateInfo(batch_id=batch_id, agent_id=aid, request_status=status))
|
||||
request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))
|
||||
|
||||
# translate provider‑specific response → OpenAI‑style tool call (unchanged)
|
||||
llm_client = LLMClient.create(llm_config=item.llm_config, put_inner_thoughts_first=True)
|
||||
@@ -270,7 +292,7 @@ class LettaAgentBatch:
|
||||
|
||||
def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None:
|
||||
if updates:
|
||||
self.batch_manager.bulk_update_batch_items_request_status_by_agent(updates=updates)
|
||||
self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates)
|
||||
|
||||
def _build_sandbox(self) -> Tuple[SandboxConfig, Dict[str, Any]]:
|
||||
sbx_type = SandboxType.E2B if tool_settings.e2b_api_key else SandboxType.LOCAL
|
||||
@@ -315,9 +337,11 @@ class LettaAgentBatch:
|
||||
self.message_manager.create_many_messages([m for msgs in msg_map.values() for m in msgs], actor=self.actor)
|
||||
return msg_map
|
||||
|
||||
def _mark_steps_complete(self, batch_id: str, agent_ids: List[str]) -> None:
|
||||
updates = [StepStatusUpdateInfo(batch_id=batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids]
|
||||
self.batch_manager.bulk_update_batch_items_step_status_by_agent(updates)
|
||||
def _mark_steps_complete(self, llm_batch_id: str, agent_ids: List[str]) -> None:
|
||||
updates = [
|
||||
StepStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids
|
||||
]
|
||||
self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(updates)
|
||||
|
||||
def _prepare_next_iteration(
|
||||
self,
|
||||
|
||||
@@ -102,7 +102,7 @@ async def poll_batch_updates(server: SyncServer, batch_jobs: List[LLMBatchJob],
|
||||
results: List[BatchPollingResult] = await asyncio.gather(*coros)
|
||||
|
||||
# Update the server with batch status changes
|
||||
server.batch_manager.bulk_update_batch_statuses(updates=results)
|
||||
server.batch_manager.bulk_update_llm_batch_statuses(updates=results)
|
||||
logger.info(f"[Poll BatchJob] Bulk-updated {len(results)} LLM batch(es) in the DB at job level.")
|
||||
|
||||
return results
|
||||
@@ -176,7 +176,7 @@ async def poll_running_llm_batches(server: "SyncServer") -> None:
|
||||
|
||||
try:
|
||||
# 1. Retrieve running batch jobs
|
||||
batches = server.batch_manager.list_running_batches()
|
||||
batches = server.batch_manager.list_running_llm_batches()
|
||||
metrics.total_batches = len(batches)
|
||||
|
||||
# TODO: Expand to more providers
|
||||
@@ -193,7 +193,7 @@ async def poll_running_llm_batches(server: "SyncServer") -> None:
|
||||
# 6. Bulk update all items for newly completed batch(es)
|
||||
if item_updates:
|
||||
metrics.updated_items_count = len(item_updates)
|
||||
server.batch_manager.bulk_update_batch_items_results_by_agent(item_updates)
|
||||
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(item_updates)
|
||||
else:
|
||||
logger.info("[Poll BatchJob] No item-level updates needed.")
|
||||
|
||||
|
||||
@@ -6,25 +6,25 @@ from letta.schemas.enums import AgentStepStatus, JobStatus
|
||||
|
||||
|
||||
class BatchPollingResult(NamedTuple):
|
||||
batch_id: str
|
||||
llm_batch_id: str
|
||||
request_status: JobStatus
|
||||
batch_response: Optional[BetaMessageBatch]
|
||||
|
||||
|
||||
class ItemUpdateInfo(NamedTuple):
|
||||
batch_id: str
|
||||
llm_batch_id: str
|
||||
agent_id: str
|
||||
request_status: JobStatus
|
||||
batch_request_result: Optional[BetaMessageBatchIndividualResponse]
|
||||
|
||||
|
||||
class StepStatusUpdateInfo(NamedTuple):
|
||||
batch_id: str
|
||||
llm_batch_id: str
|
||||
agent_id: str
|
||||
step_status: AgentStepStatus
|
||||
|
||||
|
||||
class RequestStatusUpdateInfo(NamedTuple):
|
||||
batch_id: str
|
||||
llm_batch_id: str
|
||||
agent_id: str
|
||||
request_status: JobStatus
|
||||
|
||||
@@ -16,6 +16,7 @@ class ToolType(str, Enum):
|
||||
class JobType(str, Enum):
|
||||
JOB = "job"
|
||||
RUN = "run"
|
||||
BATCH = "batch"
|
||||
|
||||
|
||||
class ToolSourceType(str, Enum):
|
||||
|
||||
@@ -20,7 +20,7 @@ class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
__tablename__ = "llm_batch_items"
|
||||
__pydantic_model__ = PydanticLLMBatchItem
|
||||
__table_args__ = (
|
||||
Index("ix_llm_batch_items_batch_id", "batch_id"),
|
||||
Index("ix_llm_batch_items_llm_batch_id", "llm_batch_id"),
|
||||
Index("ix_llm_batch_items_agent_id", "agent_id"),
|
||||
Index("ix_llm_batch_items_status", "request_status"),
|
||||
)
|
||||
@@ -29,7 +29,7 @@ class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
# TODO: Some still rely on the Pydantic object to do this
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"batch_item-{uuid.uuid4()}")
|
||||
|
||||
batch_id: Mapped[str] = mapped_column(
|
||||
llm_batch_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("llm_batch_job.id", ondelete="CASCADE"), doc="Foreign key to the LLM provider batch this item belongs to"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatch
|
||||
from sqlalchemy import DateTime, Index, String
|
||||
from sqlalchemy import DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import CreateBatchResponseColumn, PollBatchResponseColumn
|
||||
@@ -43,6 +43,9 @@ class LLMBatchJob(SqlalchemyBase, OrganizationMixin):
|
||||
DateTime(timezone=True), nullable=True, doc="Last time we polled the provider for status"
|
||||
)
|
||||
|
||||
# relationships
|
||||
letta_batch_job_id: Mapped[str] = mapped_column(
|
||||
String, ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the Letta batch job"
|
||||
)
|
||||
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs")
|
||||
items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin")
|
||||
|
||||
@@ -34,6 +34,41 @@ class Job(JobBase):
|
||||
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the job.")
|
||||
|
||||
|
||||
class BatchJob(JobBase):
    # Job schema variant whose ``job_type`` defaults to ``JobType.BATCH``.
    id: str = JobBase.generate_id_field()
    user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the job.")
    job_type: JobType = JobType.BATCH

    @classmethod
    def from_job(cls, job: Job) -> "BatchJob":
        """
        Build a BatchJob from a Job.

        Every field of ``job`` that is not None is copied as-is; the values
        are passed through exactly as dumped, including the id.

        Args:
            job: The Job instance to convert

        Returns:
            A new BatchJob instance populated from the job's fields
        """
        # Dump without None values, then feed the fields straight to the constructor.
        return cls(**job.model_dump(exclude_none=True))

    def to_job(self) -> Job:
        """
        Build a Job from this BatchJob.

        Every field that is not None is copied as-is; the values are passed
        through exactly as dumped, including the id.

        Returns:
            A new Job instance populated from this BatchJob's fields
        """
        dumped_fields = self.model_dump(exclude_none=True)
        return Job(**dumped_fields)
|
||||
|
||||
|
||||
class JobUpdate(JobBase):
|
||||
status: Optional[JobStatus] = Field(None, description="The status of the job.")
|
||||
|
||||
|
||||
@@ -31,3 +31,7 @@ class LettaStreamingRequest(LettaRequest):
|
||||
|
||||
class LettaBatchRequest(LettaRequest):
|
||||
agent_id: str = Field(..., description="The ID of the agent to send this batch request for")
|
||||
|
||||
|
||||
class CreateBatch(BaseModel):
    # Request body wrapping the list of per-agent batch requests.
    requests: List[LettaBatchRequest] = Field(
        ...,
        description="List of requests to be processed in batch.",
    )
|
||||
|
||||
@@ -169,7 +169,8 @@ LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStat
|
||||
|
||||
|
||||
class LettaBatchResponse(BaseModel):
|
||||
batch_id: str = Field(..., description="A unique identifier for this batch request.")
|
||||
letta_batch_id: str = Field(..., description="A unique identifier for the Letta batch request.")
|
||||
last_llm_batch_id: str = Field(..., description="A unique identifier for the most recent model provider batch request.")
|
||||
status: JobStatus = Field(..., description="The current status of the batch request.")
|
||||
agent_count: int = Field(..., description="The number of agents in the batch request.")
|
||||
last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.")
|
||||
|
||||
@@ -20,7 +20,7 @@ class LLMBatchItem(OrmMetadataBase, validate_assignment=True):
|
||||
__id_prefix__ = "batch_item"
|
||||
|
||||
id: Optional[str] = Field(None, description="The id of the batch item. Assigned by the database.")
|
||||
batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.")
|
||||
llm_batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.")
|
||||
agent_id: str = Field(..., description="The id of the agent associated with this LLM request.")
|
||||
|
||||
llm_config: LLMConfig = Field(..., description="The LLM configuration used for this request.")
|
||||
@@ -45,6 +45,7 @@ class LLMBatchJob(OrmMetadataBase, validate_assignment=True):
|
||||
id: Optional[str] = Field(None, description="The id of the batch job. Assigned by the database.")
|
||||
status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).")
|
||||
llm_provider: ProviderType = Field(..., description="The LLM provider used for the batch (e.g., anthropic, openai).")
|
||||
letta_batch_job_id: str = Field(..., description="ID of the Letta batch job")
|
||||
|
||||
create_batch_response: Union[BetaMessageBatch] = Field(..., description="The full JSON response from the initial batch creation.")
|
||||
latest_polling_response: Optional[Union[BetaMessageBatch]] = Field(
|
||||
|
||||
@@ -174,6 +174,7 @@ def create_application() -> "FastAPI":
|
||||
async def generic_error_handler(request: Request, exc: Exception):
|
||||
# Log the actual error for debugging
|
||||
log.error(f"Unhandled error: {exc}", exc_info=True)
|
||||
print(f"Unhandled error: {exc}")
|
||||
|
||||
# Print the stack trace
|
||||
print(f"Stack trace: {exc}")
|
||||
|
||||
@@ -5,6 +5,7 @@ from letta.server.rest_api.routers.v1.health import router as health_router
|
||||
from letta.server.rest_api.routers.v1.identities import router as identities_router
|
||||
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
|
||||
from letta.server.rest_api.routers.v1.llms import router as llm_router
|
||||
from letta.server.rest_api.routers.v1.messages import router as messages_router
|
||||
from letta.server.rest_api.routers.v1.providers import router as providers_router
|
||||
from letta.server.rest_api.routers.v1.runs import router as runs_router
|
||||
from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router
|
||||
@@ -29,5 +30,6 @@ ROUTERS = [
|
||||
runs_router,
|
||||
steps_router,
|
||||
tags_router,
|
||||
messages_router,
|
||||
voice_router,
|
||||
]
|
||||
|
||||
@@ -12,7 +12,6 @@ from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
@@ -21,8 +20,8 @@ from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_request import LettaBatchRequest, LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaBatchResponse, LettaResponse
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.passage import Passage, PassageUpdate
|
||||
@@ -832,57 +831,3 @@ async def list_agent_groups(
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
print("in list agents with manager_type", manager_type)
|
||||
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
|
||||
|
||||
|
||||
# Batch APIs
|
||||
|
||||
|
||||
@router.post("/messages/batches", response_model=LettaBatchResponse, operation_id="create_batch_message_request")
|
||||
async def send_batch_messages(
|
||||
batch_requests: List[LettaBatchRequest] = Body(..., description="Messages and config for all agents"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Submit a batch of agent messages for asynchronous processing.
|
||||
Creates a job that will fan out messages to all listed agents and process them in parallel.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
batch_runner = LettaAgentBatch(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
return await batch_runner.step_until_request(batch_requests=batch_requests)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/messages/batches/{batch_id}",
|
||||
response_model=LettaBatchResponse,
|
||||
operation_id="retrieve_batch_message_request",
|
||||
)
|
||||
async def retrieve_batch_message_request(
|
||||
batch_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Retrieve the result or current status of a previously submitted batch message request.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
batch_job = server.batch_manager.get_batch_job_by_id(batch_id=batch_id, actor=actor)
|
||||
agent_count = server.batch_manager.count_batch_items(batch_id=batch_id)
|
||||
|
||||
return LettaBatchResponse(
|
||||
batch_id=batch_id,
|
||||
status=batch_job.status,
|
||||
agent_count=agent_count,
|
||||
last_polled_at=batch_job.last_polled_at,
|
||||
created_at=batch_job.created_at,
|
||||
)
|
||||
|
||||
127
letta/server/rest_api/routers/v1/messages.py
Normal file
127
letta/server/rest_api/routers/v1/messages.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.job import BatchJob, JobStatus, JobType
|
||||
from letta.schemas.letta_request import CreateBatch
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/messages", tags=["messages"])
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Batch APIs
|
||||
|
||||
|
||||
@router.post(
    "/batches",
    response_model=BatchJob,
    operation_id="create_messages_batch",
)
async def create_messages_batch(
    request: CreateBatch = Body(..., description="Messages and config for all agents"),
    server: SyncServer = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Submit a batch of agent messages for asynchronous processing.
    Creates a job that will fan out messages to all listed agents and process them in parallel.
    """
    try:
        actor = server.user_manager.get_user_or_default(user_id=actor_id)

        # Persist the parent job BEFORE stepping: the batch runner records an
        # LLM batch job that carries this job's id as ``letta_batch_job_id``,
        # so the referenced job has to exist by the time that row is written.
        batch_job = server.job_manager.create_job(
            pydantic_job=BatchJob(
                user_id=actor.id,
                status=JobStatus.created,
                metadata={
                    "job_type": "batch_messages",
                },
            ),
            actor=actor,
        )

        # create the batch runner
        batch_runner = LettaAgentBatch(
            message_manager=server.message_manager,
            agent_manager=server.agent_manager,
            block_manager=server.block_manager,
            passage_manager=server.passage_manager,
            batch_manager=server.batch_manager,
            sandbox_config_manager=server.sandbox_config_manager,
            job_manager=server.job_manager,
            actor=actor,
        )
        await batch_runner.step_until_request(batch_requests=request.requests, letta_batch_job_id=batch_job.id)

        # TODO: update run metadata
    except Exception:
        # Log with the full traceback and let the application's error handling
        # produce the response.
        # TODO: mark the persisted job as failed when stepping raises.
        logger.exception("Error creating batch job")
        raise
    return batch_job
|
||||
|
||||
|
||||
@router.get("/batches/{batch_id}", response_model=BatchJob, operation_id="retrieve_batch_run")
async def retrieve_batch_run(
    batch_id: str,
    actor_id: Optional[str] = Header(None, alias="user_id"),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Get the status of a batch run.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)

    # Look the job up by id; a missing job is reported as HTTP 404.
    try:
        found_job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Batch not found")

    return BatchJob.from_job(found_job)
|
||||
|
||||
|
||||
@router.get("/batches", response_model=List[BatchJob], operation_id="list_batch_runs")
async def list_batch_runs(
    actor_id: Optional[str] = Header(None, alias="user_id"),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    List all batch runs.
    """
    # TODO: filter
    actor = server.user_manager.get_user_or_default(user_id=actor_id)

    # Only jobs of type BATCH that are still in the created or running state are requested.
    jobs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.BATCH)
    logger.debug("Found %d active batch job(s)", len(jobs))
    return [BatchJob.from_job(job) for job in jobs]
|
||||
|
||||
|
||||
@router.patch("/batches/{batch_id}/cancel", operation_id="cancel_batch_run")
async def cancel_batch_run(
    batch_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Cancel a batch run.
    """
    # Imported locally because this module does not import JobUpdate at file level.
    from letta.schemas.job import JobUpdate

    actor = server.user_manager.get_user_or_default(user_id=actor_id)

    try:
        # Look the job up first so a missing job surfaces as a 404 below.
        server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
        # Pass the job *id* (not the Job object) together with a JobUpdate
        # payload and the acting user.
        server.job_manager.update_job_by_id(
            job_id=batch_id,
            job_update=JobUpdate(status=JobStatus.cancelled),
            actor=actor,
        )
        # TODO: actually cancel it
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Run not found")
|
||||
@@ -15,6 +15,7 @@ from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step
|
||||
from letta.orm.step import Step as StepModel
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.job import BatchJob as PydanticBatchJob
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
@@ -36,7 +37,9 @@ class JobManager:
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def create_job(self, pydantic_job: Union[PydanticJob, PydanticRun], actor: PydanticUser) -> Union[PydanticJob, PydanticRun]:
|
||||
def create_job(
|
||||
self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser
|
||||
) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]:
|
||||
"""Create a new job based on the JobCreate schema."""
|
||||
with self.session_maker() as session:
|
||||
# Associate the job with the user
|
||||
|
||||
@@ -28,11 +28,12 @@ class LLMBatchManager:
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def create_batch_job(
|
||||
def create_llm_batch_job(
|
||||
self,
|
||||
llm_provider: ProviderType,
|
||||
create_batch_response: BetaMessageBatch,
|
||||
actor: PydanticUser,
|
||||
letta_batch_job_id: str,
|
||||
status: JobStatus = JobStatus.created,
|
||||
) -> PydanticLLMBatchJob:
|
||||
"""Create a new LLM batch job."""
|
||||
@@ -42,51 +43,52 @@ class LLMBatchManager:
|
||||
llm_provider=llm_provider,
|
||||
create_batch_response=create_batch_response,
|
||||
organization_id=actor.organization_id,
|
||||
letta_batch_job_id=letta_batch_job_id,
|
||||
)
|
||||
batch.create(session, actor=actor)
|
||||
return batch.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_batch_job_by_id(self, batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob:
|
||||
def get_llm_batch_job_by_id(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob:
    """Look up one LLM batch job by its ID and return it as a Pydantic model.

    Args:
        llm_batch_id: Identifier of the LLM batch job to read.
        actor: Optional user forwarded to `LLMBatchJob.read`.
    """
    with self.session_maker() as db_session:
        # Convert while the session is still open, as the row is session-bound.
        return LLMBatchJob.read(db_session=db_session, identifier=llm_batch_id, actor=actor).to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_batch_status(
|
||||
def update_llm_batch_status(
    self,
    llm_batch_id: str,
    status: JobStatus,
    actor: Optional[PydanticUser] = None,
    latest_polling_response: Optional[BetaMessageBatch] = None,
) -> PydanticLLMBatchJob:
    """Record a new status and polling response on an LLM batch job.

    Args:
        llm_batch_id: Identifier of the LLM batch job to update.
        status: Status to store on the job.
        actor: Optional user forwarded to the read and update calls.
        latest_polling_response: Stored on the job as-is. Because the
            assignment is unconditional, leaving this as None overwrites
            any previously stored response with None.

    Returns:
        The updated job as a Pydantic model.
    """
    with self.session_maker() as db_session:
        job_row = LLMBatchJob.read(db_session=db_session, identifier=llm_batch_id, actor=actor)

        job_row.status = status
        job_row.latest_polling_response = latest_polling_response
        # Timestamp is timezone-aware (UTC).
        job_row.last_polled_at = datetime.datetime.now(datetime.timezone.utc)

        return job_row.update(db_session=db_session, actor=actor).to_pydantic()
|
||||
|
||||
def bulk_update_batch_statuses(
|
||||
def bulk_update_llm_batch_statuses(
|
||||
self,
|
||||
updates: List[BatchPollingResult],
|
||||
) -> None:
|
||||
"""
|
||||
Efficiently update many LLMBatchJob rows. This is used by the cron jobs.
|
||||
|
||||
`updates` = [(batch_id, new_status, polling_response_or_None), …]
|
||||
`updates` = [(llm_batch_id, new_status, polling_response_or_None), …]
|
||||
"""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
with self.session_maker() as session:
|
||||
mappings = []
|
||||
for batch_id, status, response in updates:
|
||||
for llm_batch_id, status, response in updates:
|
||||
mappings.append(
|
||||
{
|
||||
"id": batch_id,
|
||||
"id": llm_batch_id,
|
||||
"status": status,
|
||||
"latest_polling_response": response,
|
||||
"last_polled_at": now,
|
||||
@@ -97,14 +99,51 @@ class LLMBatchManager:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_batch_request(self, batch_id: str, actor: PydanticUser) -> None:
|
||||
def list_llm_batch_jobs(
    self,
    letta_batch_id: str,
    limit: Optional[int] = None,
    actor: Optional[PydanticUser] = None,
    after: Optional[str] = None,
) -> List[PydanticLLMBatchJob]:
    """
    List the LLM batch jobs that belong to a given Letta batch job.

    Args:
        letta_batch_id: ID of the parent Letta batch job; only rows whose
            `letta_batch_job_id` equals this value are returned.
        limit: Maximum number of jobs to return. No limit is applied when None.
        actor: When provided, results are restricted to the actor's organization.
        after: A cursor string. Only jobs with an `id` greater than this value are returned.

    Returns:
        The matching LLM batch jobs as Pydantic models, ordered by id ascending.
    """
    with self.session_maker() as session:
        query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id)

        if actor is not None:
            query = query.filter(LLMBatchJob.organization_id == actor.organization_id)

        # Cursor pagination: skip everything up to and including `after`.
        if after is not None:
            query = query.filter(LLMBatchJob.id > after)

        # Order before limiting so the limit keeps the lowest ids.
        query = query.order_by(LLMBatchJob.id.asc())

        if limit is not None:
            query = query.limit(limit)

        return [job.to_pydantic() for job in query.all()]
|
||||
|
||||
@enforce_types
|
||||
def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None:
    """Permanently remove the LLM batch job with the given ID.

    The job is first read with the supplied actor, then hard-deleted in
    the same session.
    """
    with self.session_maker() as db_session:
        LLMBatchJob.read(
            db_session=db_session,
            identifier=llm_batch_id,
            actor=actor,
        ).hard_delete(db_session=db_session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def list_running_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
|
||||
def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
|
||||
"""Return all running LLM batch jobs, optionally filtered by actor's organization."""
|
||||
with self.session_maker() as session:
|
||||
query = session.query(LLMBatchJob).filter(LLMBatchJob.status == JobStatus.running)
|
||||
@@ -116,9 +155,9 @@ class LLMBatchManager:
|
||||
return [batch.to_pydantic() for batch in results]
|
||||
|
||||
@enforce_types
|
||||
def create_batch_item(
|
||||
def create_llm_batch_item(
|
||||
self,
|
||||
batch_id: str,
|
||||
llm_batch_id: str,
|
||||
agent_id: str,
|
||||
llm_config: LLMConfig,
|
||||
actor: PydanticUser,
|
||||
@@ -129,7 +168,7 @@ class LLMBatchManager:
|
||||
"""Create a new batch item."""
|
||||
with self.session_maker() as session:
|
||||
item = LLMBatchItem(
|
||||
batch_id=batch_id,
|
||||
llm_batch_id=llm_batch_id,
|
||||
agent_id=agent_id,
|
||||
llm_config=llm_config,
|
||||
request_status=request_status,
|
||||
@@ -141,7 +180,7 @@ class LLMBatchManager:
|
||||
return item.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def create_batch_items_bulk(self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser) -> List[PydanticLLMBatchItem]:
|
||||
def create_llm_batch_items_bulk(self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser) -> List[PydanticLLMBatchItem]:
|
||||
"""
|
||||
Create multiple batch items in bulk for better performance.
|
||||
|
||||
@@ -157,7 +196,7 @@ class LLMBatchManager:
|
||||
orm_items = []
|
||||
for item in llm_batch_items:
|
||||
orm_item = LLMBatchItem(
|
||||
batch_id=item.batch_id,
|
||||
llm_batch_id=item.llm_batch_id,
|
||||
agent_id=item.agent_id,
|
||||
llm_config=item.llm_config,
|
||||
request_status=item.request_status,
|
||||
@@ -174,14 +213,14 @@ class LLMBatchManager:
|
||||
return [item.to_pydantic() for item in created_items]
|
||||
|
||||
@enforce_types
|
||||
def get_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
|
||||
def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
    """Look up one LLM batch item by its ID and return it as a Pydantic model.

    Args:
        item_id: Identifier of the batch item to read.
        actor: User forwarded to `LLMBatchItem.read`.
    """
    with self.session_maker() as db_session:
        return LLMBatchItem.read(db_session=db_session, identifier=item_id, actor=actor).to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_batch_item(
|
||||
def update_llm_batch_item(
|
||||
self,
|
||||
item_id: str,
|
||||
actor: PydanticUser,
|
||||
@@ -206,9 +245,9 @@ class LLMBatchManager:
|
||||
return item.update(db_session=session, actor=actor).to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def list_batch_items(
|
||||
def list_llm_batch_items(
|
||||
self,
|
||||
batch_id: str,
|
||||
llm_batch_id: str,
|
||||
limit: Optional[int] = None,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
after: Optional[str] = None,
|
||||
@@ -217,7 +256,7 @@ class LLMBatchManager:
|
||||
step_status: Optional[AgentStepStatus] = None,
|
||||
) -> List[PydanticLLMBatchItem]:
|
||||
"""
|
||||
List all batch items for a given batch_id, optionally filtered by additional criteria and limited in count.
|
||||
List all batch items for a given llm_batch_id, optionally filtered by additional criteria and limited in count.
|
||||
|
||||
Optional filters:
|
||||
- after: A cursor string. Only items with an `id` greater than this value are returned.
|
||||
@@ -228,7 +267,7 @@ class LLMBatchManager:
|
||||
The results are ordered by their id in ascending order.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
query = session.query(LLMBatchItem).filter(LLMBatchItem.batch_id == batch_id)
|
||||
query = session.query(LLMBatchItem).filter(LLMBatchItem.llm_batch_id == llm_batch_id)
|
||||
|
||||
if actor is not None:
|
||||
query = query.filter(LLMBatchItem.organization_id == actor.organization_id)
|
||||
@@ -251,36 +290,36 @@ class LLMBatchManager:
|
||||
results = query.all()
|
||||
return [item.to_pydantic() for item in results]
|
||||
|
||||
def bulk_update_batch_items(
|
||||
def bulk_update_llm_batch_items(
|
||||
self,
|
||||
batch_id_agent_id_pairs: List[Tuple[str, str]],
|
||||
llm_batch_id_agent_id_pairs: List[Tuple[str, str]],
|
||||
field_updates: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Efficiently update multiple LLMBatchItem rows by (batch_id, agent_id) pairs.
|
||||
Efficiently update multiple LLMBatchItem rows by (llm_batch_id, agent_id) pairs.
|
||||
|
||||
Args:
|
||||
batch_id_agent_id_pairs: List of (batch_id, agent_id) tuples identifying items to update
|
||||
llm_batch_id_agent_id_pairs: List of (llm_batch_id, agent_id) tuples identifying items to update
|
||||
field_updates: List of dictionaries containing the fields to update for each item
|
||||
"""
|
||||
if not batch_id_agent_id_pairs or not field_updates:
|
||||
if not llm_batch_id_agent_id_pairs or not field_updates:
|
||||
return
|
||||
|
||||
if len(batch_id_agent_id_pairs) != len(field_updates):
|
||||
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
|
||||
raise ValueError("batch_id_agent_id_pairs and field_updates must have the same length")
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Lookup primary keys
|
||||
items = (
|
||||
session.query(LLMBatchItem.id, LLMBatchItem.batch_id, LLMBatchItem.agent_id)
|
||||
.filter(tuple_(LLMBatchItem.batch_id, LLMBatchItem.agent_id).in_(batch_id_agent_id_pairs))
|
||||
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
|
||||
.filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs))
|
||||
.all()
|
||||
)
|
||||
pair_to_pk = {(b, a): id for id, b, a in items}
|
||||
|
||||
mappings = []
|
||||
for (batch_id, agent_id), fields in zip(batch_id_agent_id_pairs, field_updates):
|
||||
pk_id = pair_to_pk.get((batch_id, agent_id))
|
||||
for (llm_batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
|
||||
pk_id = pair_to_pk.get((llm_batch_id, agent_id))
|
||||
if not pk_id:
|
||||
continue
|
||||
|
||||
@@ -293,12 +332,12 @@ class LLMBatchManager:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_batch_items_results_by_agent(
|
||||
def bulk_update_batch_llm_items_results_by_agent(
|
||||
self,
|
||||
updates: List[ItemUpdateInfo],
|
||||
) -> None:
|
||||
"""Update request status and batch results for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) for update in updates]
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [
|
||||
{
|
||||
"request_status": update.request_status,
|
||||
@@ -307,48 +346,48 @@ class LLMBatchManager:
|
||||
for update in updates
|
||||
]
|
||||
|
||||
self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates)
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_batch_items_step_status_by_agent(
|
||||
def bulk_update_llm_batch_items_step_status_by_agent(
    self,
    updates: List[StepStatusUpdateInfo],
) -> None:
    """Set the step status of several batch items in one bulk call.

    Each update is keyed by its (llm_batch_id, agent_id) pair and carries
    the new `step_status`; the actual write is delegated to
    `bulk_update_llm_batch_items`.
    """
    item_keys = [(entry.llm_batch_id, entry.agent_id) for entry in updates]
    status_changes = [{"step_status": entry.step_status} for entry in updates]

    self.bulk_update_llm_batch_items(
        llm_batch_id_agent_id_pairs=item_keys,
        field_updates=status_changes,
    )
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_batch_items_request_status_by_agent(
|
||||
def bulk_update_llm_batch_items_request_status_by_agent(
    self,
    updates: List[RequestStatusUpdateInfo],
) -> None:
    """Set the request status of several batch items in one bulk call.

    Each update is keyed by its (llm_batch_id, agent_id) pair and carries
    the new `request_status`; the actual write is delegated to
    `bulk_update_llm_batch_items`.
    """
    item_keys = [(entry.llm_batch_id, entry.agent_id) for entry in updates]
    status_changes = [{"request_status": entry.request_status} for entry in updates]

    self.bulk_update_llm_batch_items(
        llm_batch_id_agent_id_pairs=item_keys,
        field_updates=status_changes,
    )
|
||||
|
||||
@enforce_types
|
||||
def delete_batch_item(self, item_id: str, actor: PydanticUser) -> None:
|
||||
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:
    """Permanently remove the LLM batch item with the given ID.

    The item is first read with the supplied actor, then hard-deleted in
    the same session.
    """
    with self.session_maker() as db_session:
        LLMBatchItem.read(
            db_session=db_session,
            identifier=item_id,
            actor=actor,
        ).hard_delete(db_session=db_session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def count_batch_items(self, batch_id: str) -> int:
|
||||
def count_llm_batch_items(self, llm_batch_id: str) -> int:
    """
    Count the batch items attached to one LLM batch.

    Args:
        llm_batch_id (str): Identifier of the LLM batch whose items are counted.

    Returns:
        int: Number of batch items whose `llm_batch_id` matches, or 0 when
        the count query yields a falsy value.
    """
    with self.session_maker() as db_session:
        # Count in the database rather than loading the rows.
        count_query = db_session.query(func.count(LLMBatchItem.id))
        total = count_query.filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar()
        return total if total else 0
|
||||
|
||||
80
tests/integration_test_batch.py
Normal file
80
tests/integration_test_batch.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# import time
|
||||
#
|
||||
# import pytest
|
||||
# from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(scope="module")
|
||||
# def client():
|
||||
# return Letta(base_url="http://localhost:8283")
|
||||
#
|
||||
#
|
||||
# def test_create_batch(client: Letta):
|
||||
#
|
||||
# # create agents
|
||||
# agent1 = client.agents.create(
|
||||
# name="agent1",
|
||||
# memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
|
||||
# model="anthropic/claude-3-7-sonnet-20250219",
|
||||
# embedding="letta/letta-free",
|
||||
# )
|
||||
# agent2 = client.agents.create(
|
||||
# name="agent2",
|
||||
# memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
|
||||
# model="anthropic/claude-3-7-sonnet-20250219",
|
||||
# embedding="letta/letta-free",
|
||||
# )
|
||||
#
|
||||
# # create a run
|
||||
# run = client.messages.batches.create(
|
||||
# requests=[
|
||||
# LettaBatchRequest(
|
||||
# messages=[
|
||||
# MessageCreate(
|
||||
# role="user",
|
||||
# content=[
|
||||
# TextContent(
|
||||
# text="text",
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
# ],
|
||||
# agent_id=agent1.id,
|
||||
# ),
|
||||
# LettaBatchRequest(
|
||||
# messages=[
|
||||
# MessageCreate(
|
||||
# role="user",
|
||||
# content=[
|
||||
# TextContent(
|
||||
# text="text",
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
# ],
|
||||
# agent_id=agent2.id,
|
||||
# ),
|
||||
# ]
|
||||
# )
|
||||
# assert run is not None
|
||||
#
|
||||
# # list batches
|
||||
# batches = client.messages.batches.list()
|
||||
# assert len(batches) > 0, f"Expected 1 batch, got {len(batches)}"
|
||||
#
|
||||
# # check run status
|
||||
# while True:
|
||||
# run = client.messages.batches.retrieve(batch_id=run.id)
|
||||
# if run.status == "completed":
|
||||
# break
|
||||
# print("Waiting for run to complete...", run.status)
|
||||
# time.sleep(1)
|
||||
#
|
||||
# # get the batch results
|
||||
# results = client.messages.batches.retrieve(
|
||||
# run_id=run.id,
|
||||
# )
|
||||
# assert results is not None
|
||||
# print(results)
|
||||
#
|
||||
# # cancel a run
|
||||
@@ -23,6 +23,7 @@ from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
from letta.schemas.agent import AgentStepState
|
||||
from letta.schemas.enums import JobStatus, ProviderType
|
||||
from letta.schemas.job import BatchJob
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool_rule import InitToolRule
|
||||
from letta.server.db import db_context
|
||||
@@ -145,13 +146,21 @@ def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"
|
||||
)
|
||||
|
||||
|
||||
def create_test_batch_job(server, batch_response, default_user):
|
||||
def create_test_letta_batch_job(server, default_user):
    """Create a Letta batch job whose user_id is the given user's id."""
    job_payload = BatchJob(user_id=default_user.id)
    return server.job_manager.create_job(job_payload, actor=default_user)
|
||||
|
||||
|
||||
def create_test_llm_batch_job(server, batch_response, default_user):
    """Create a running Anthropic LLM batch job from the given batch response.

    A fresh Letta batch job is created first, and its id is passed as the
    new LLM batch job's `letta_batch_job_id`.
    """
    parent_job = create_test_letta_batch_job(server, default_user)

    job_kwargs = dict(
        llm_provider=ProviderType.anthropic,
        create_batch_response=batch_response,
        actor=default_user,
        status=JobStatus.running,
        letta_batch_job_id=parent_job.id,
    )
    return server.batch_manager.create_llm_batch_job(**job_kwargs)
|
||||
|
||||
|
||||
@@ -171,8 +180,8 @@ def create_test_batch_item(server, batch_id, agent_id, default_user):
|
||||
step_number=1, tool_rules_solver=ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")])
|
||||
)
|
||||
|
||||
return server.batch_manager.create_batch_item(
|
||||
batch_id=batch_id,
|
||||
return server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch_id,
|
||||
agent_id=agent_id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=common_step_state,
|
||||
@@ -242,8 +251,8 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
|
||||
agent_c = create_test_agent(client, "agent_c")
|
||||
|
||||
# --- Step 2: Create batch jobs ---
|
||||
job_a = create_test_batch_job(server, batch_a_resp, default_user)
|
||||
job_b = create_test_batch_job(server, batch_b_resp, default_user)
|
||||
job_a = create_test_llm_batch_job(server, batch_a_resp, default_user)
|
||||
job_b = create_test_llm_batch_job(server, batch_b_resp, default_user)
|
||||
|
||||
# --- Step 3: Create batch items ---
|
||||
item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
|
||||
@@ -258,8 +267,8 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# --- Step 6: Verify batch job status updates ---
|
||||
updated_job_a = server.batch_manager.get_batch_job_by_id(batch_id=job_a.id, actor=default_user)
|
||||
updated_job_b = server.batch_manager.get_batch_job_by_id(batch_id=job_b.id, actor=default_user)
|
||||
updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user)
|
||||
updated_job_b = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_b.id, actor=default_user)
|
||||
|
||||
# Job A should remain running since its processing_status is "in_progress"
|
||||
assert updated_job_a.status == JobStatus.running
|
||||
@@ -273,17 +282,17 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
|
||||
|
||||
# --- Step 7: Verify batch item status updates ---
|
||||
# Item A should remain unchanged
|
||||
updated_item_a = server.batch_manager.get_batch_item_by_id(item_a.id, actor=default_user)
|
||||
updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
|
||||
assert updated_item_a.request_status == JobStatus.created
|
||||
assert updated_item_a.batch_request_result is None
|
||||
|
||||
# Item B should be marked as completed with a successful result
|
||||
updated_item_b = server.batch_manager.get_batch_item_by_id(item_b.id, actor=default_user)
|
||||
updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
|
||||
assert updated_item_b.request_status == JobStatus.completed
|
||||
assert updated_item_b.batch_request_result is not None
|
||||
|
||||
# Item C should be marked as failed with an error result
|
||||
updated_item_c = server.batch_manager.get_batch_item_by_id(item_c.id, actor=default_user)
|
||||
updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
|
||||
assert updated_item_c.request_status == JobStatus.failed
|
||||
assert updated_item_c.batch_request_result is not None
|
||||
|
||||
@@ -307,11 +316,11 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
|
||||
|
||||
# --- Step 9: Verify that nothing changed for completed jobs ---
|
||||
# Refresh all objects
|
||||
final_job_a = server.batch_manager.get_batch_job_by_id(batch_id=job_a.id, actor=default_user)
|
||||
final_job_b = server.batch_manager.get_batch_job_by_id(batch_id=job_b.id, actor=default_user)
|
||||
final_item_a = server.batch_manager.get_batch_item_by_id(item_a.id, actor=default_user)
|
||||
final_item_b = server.batch_manager.get_batch_item_by_id(item_b.id, actor=default_user)
|
||||
final_item_c = server.batch_manager.get_batch_item_by_id(item_c.id, actor=default_user)
|
||||
final_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user)
|
||||
final_job_b = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_b.id, actor=default_user)
|
||||
final_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
|
||||
final_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
|
||||
final_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
|
||||
|
||||
# Job A should still be polling (last_polled_at should update)
|
||||
assert final_job_a.status == JobStatus.running
|
||||
|
||||
@@ -33,6 +33,7 @@ from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
from letta.schemas.agent import AgentState, AgentStepState
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
||||
from letta.schemas.job import BatchJob
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_request import LettaBatchRequest
|
||||
from letta.schemas.message import MessageCreate
|
||||
@@ -256,7 +257,7 @@ def clear_batch_tables():
|
||||
"""Clear batch-related tables before each test."""
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables):
|
||||
if table.name in {"llm_batch_job", "llm_batch_items"}:
|
||||
if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
|
||||
@@ -305,6 +306,22 @@ def client(server_url):
|
||||
return Letta(base_url=server_url)
|
||||
|
||||
|
||||
@pytest.fixture
def batch_job(default_user, server):
    """Yield a newly created batch job and delete it once the test is done."""
    pending_job = BatchJob(
        user_id=default_user.id,
        status=JobStatus.created,
        metadata={"job_type": "batch_messages"},
    )
    created_job = server.job_manager.create_job(pydantic_job=pending_job, actor=default_user)

    yield created_job

    # cleanup
    server.job_manager.delete_job_by_id(created_job.id, actor=default_user)
|
||||
|
||||
|
||||
class MockAsyncIterable:
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
@@ -324,8 +341,8 @@ class MockAsyncIterable:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_step_after_request_happy_path(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map
|
||||
async def test_resume_step_after_request_all_continue(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
anthropic_batch_id = "msgbatch_test_12345"
|
||||
dummy_batch_response = create_batch_response(
|
||||
@@ -342,6 +359,7 @@ async def test_resume_step_after_request_happy_path(
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
@@ -349,15 +367,20 @@ async def test_resume_step_after_request_happy_path(
|
||||
pre_resume_response = await batch_runner.step_until_request(
|
||||
batch_requests=batch_requests,
|
||||
agent_step_state_mapping=step_state_map,
|
||||
letta_batch_job_id=batch_job.id,
|
||||
)
|
||||
|
||||
# Basic sanity checks (this is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly`)
|
||||
# Verify batch items
|
||||
items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user)
|
||||
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
|
||||
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
|
||||
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}"
|
||||
|
||||
# 2. Invoke the polling job and mock responses from Anthropic
|
||||
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.batch_id, processing_status="ended"))
|
||||
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
|
||||
mock_items = [
|
||||
@@ -372,13 +395,13 @@ async def test_resume_step_after_request_happy_path(
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# Verify database records were updated correctly
|
||||
job = server.batch_manager.get_batch_job_by_id(pre_resume_response.batch_id, actor=default_user)
|
||||
llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)
|
||||
|
||||
# Verify job properties
|
||||
assert job.status == JobStatus.completed, "Job status should be 'completed'"
|
||||
assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"
|
||||
|
||||
# Verify batch items
|
||||
items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user)
|
||||
items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
|
||||
assert all([item.request_status == JobStatus.completed for item in items])
|
||||
|
||||
@@ -390,22 +413,27 @@ async def test_resume_step_after_request_happy_path(
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
|
||||
|
||||
post_resume_response = await letta_batch_agent.resume_step_after_request(batch_id=pre_resume_response.batch_id)
|
||||
post_resume_response = await letta_batch_agent.resume_step_after_request(
|
||||
letta_batch_id=pre_resume_response.letta_batch_id, llm_batch_id=llm_batch_job.id
|
||||
)
|
||||
|
||||
# A *new* batch job should have been spawned
|
||||
assert (
|
||||
post_resume_response.batch_id != pre_resume_response.batch_id
|
||||
), "resume_step_after_request is expected to enqueue a follow‑up batch job."
|
||||
post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
|
||||
), "resume_step_after_request is expected to have the same letta_batch_id"
|
||||
assert (
|
||||
post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
|
||||
), "resume_step_after_request is expected to have different llm_batch_id."
|
||||
assert post_resume_response.status == JobStatus.running
|
||||
assert post_resume_response.agent_count == 3
|
||||
|
||||
# New batch‑items should exist, initialised in (created, paused) state
|
||||
new_items = server.batch_manager.list_batch_items(batch_id=post_resume_response.batch_id, actor=default_user)
|
||||
new_items = server.batch_manager.list_llm_batch_items(llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user)
|
||||
assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}"
|
||||
assert {i.request_status for i in new_items} == {JobStatus.created}
|
||||
assert {i.step_status for i in new_items} == {AgentStepStatus.paused}
|
||||
@@ -420,7 +448,7 @@ async def test_resume_step_after_request_happy_path(
|
||||
|
||||
# Old items must have been flipped to completed / finished earlier
|
||||
# (sanity – we already asserted this above, but we keep it close for clarity)
|
||||
old_items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user)
|
||||
old_items = server.batch_manager.list_llm_batch_items(llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user)
|
||||
assert {i.request_status for i in old_items} == {JobStatus.completed}
|
||||
assert {i.step_status for i in old_items} == {AgentStepStatus.completed}
|
||||
|
||||
@@ -440,7 +468,7 @@ async def test_resume_step_after_request_happy_path(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response
|
||||
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
|
||||
):
|
||||
"""
|
||||
Test that step_until_request correctly:
|
||||
@@ -512,6 +540,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
@@ -519,23 +548,25 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
response = await batch_runner.step_until_request(
|
||||
batch_requests=batch_requests,
|
||||
agent_step_state_mapping=step_state_map,
|
||||
letta_batch_job_id=batch_job.id,
|
||||
)
|
||||
|
||||
# Verify the mock was called exactly once
|
||||
mock_send.assert_called_once()
|
||||
|
||||
# Verify database records were created correctly
|
||||
job = server.batch_manager.get_batch_job_by_id(response.batch_id, actor=default_user)
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=response.letta_batch_id, actor=default_user)
|
||||
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
|
||||
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}"
|
||||
|
||||
# Verify job properties
|
||||
assert job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic"
|
||||
assert job.status == JobStatus.running, "Job status should be 'running'"
|
||||
|
||||
# Verify batch items
|
||||
items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user)
|
||||
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
|
||||
assert llm_batch_job.llm_provider == ProviderType.anthropic, "Job provider should be Anthropic"
|
||||
assert llm_batch_job.status == JobStatus.running, "Job status should be 'running'"
|
||||
|
||||
# Verify all agents are represented in batch items
|
||||
agent_ids_in_items = {item.agent_id for item in items}
|
||||
agent_ids_in_items = {item.agent_id for item in llm_batch_items}
|
||||
expected_agent_ids = {agent.id for agent in agents}
|
||||
assert agent_ids_in_items == expected_agent_ids, f"Expected agent IDs {expected_agent_ids}, got {agent_ids_in_items}"
|
||||
|
||||
@@ -39,6 +39,8 @@ from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, Provide
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate, IdentityUpsert
|
||||
from letta.schemas.job import BatchJob
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
|
||||
@@ -615,6 +617,11 @@ def dummy_successful_response() -> BetaMessageBatchIndividualResponse:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def letta_batch_job(server: SyncServer, default_user) -> Job:
|
||||
return server.job_manager.create_job(BatchJob(user_id=default_user.id), actor=default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
@@ -4761,77 +4768,90 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch):
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch, letta_batch_job):
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
assert batch.id.startswith("batch_req-")
|
||||
assert batch.create_batch_response == dummy_beta_message_batch
|
||||
fetched = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user)
|
||||
fetched = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user)
|
||||
assert fetched.id == batch.id
|
||||
|
||||
|
||||
def test_update_batch_status(server, default_user, dummy_beta_message_batch):
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
def test_update_batch_status(server, default_user, dummy_beta_message_batch, letta_batch_job):
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
before = datetime.now(timezone.utc)
|
||||
|
||||
server.batch_manager.update_batch_status(
|
||||
batch_id=batch.id,
|
||||
server.batch_manager.update_llm_batch_status(
|
||||
llm_batch_id=batch.id,
|
||||
status=JobStatus.completed,
|
||||
latest_polling_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
updated = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user)
|
||||
updated = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user)
|
||||
assert updated.status == JobStatus.completed
|
||||
assert updated.latest_polling_response == dummy_beta_message_batch
|
||||
assert updated.last_polled_at >= before
|
||||
|
||||
|
||||
def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
def test_create_and_get_batch_item(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
|
||||
):
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
item = server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert item.batch_id == batch.id
|
||||
assert item.llm_batch_id == batch.id
|
||||
assert item.agent_id == sarah_agent.id
|
||||
assert item.step_state == dummy_step_state
|
||||
|
||||
fetched = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
|
||||
fetched = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
assert fetched.id == item.id
|
||||
|
||||
|
||||
def test_update_batch_item(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response
|
||||
server,
|
||||
default_user,
|
||||
sarah_agent,
|
||||
dummy_beta_message_batch,
|
||||
dummy_llm_config,
|
||||
dummy_step_state,
|
||||
dummy_successful_response,
|
||||
letta_batch_job,
|
||||
):
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
item = server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
@@ -4840,7 +4860,7 @@ def test_update_batch_item(
|
||||
|
||||
updated_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver)
|
||||
|
||||
server.batch_manager.update_batch_item(
|
||||
server.batch_manager.update_llm_batch_item(
|
||||
item_id=item.id,
|
||||
request_status=JobStatus.completed,
|
||||
step_status=AgentStepStatus.resumed,
|
||||
@@ -4849,146 +4869,166 @@ def test_update_batch_item(
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
|
||||
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
assert updated.request_status == JobStatus.completed
|
||||
assert updated.batch_request_result == dummy_successful_response
|
||||
|
||||
|
||||
def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
def test_delete_batch_item(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
|
||||
):
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
item = server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
server.batch_manager.delete_batch_item(item_id=item.id, actor=default_user)
|
||||
server.batch_manager.delete_llm_batch_item(item_id=item.id, actor=default_user)
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
|
||||
server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
|
||||
|
||||
def test_list_running_batches(server, default_user, dummy_beta_message_batch):
|
||||
server.batch_manager.create_batch_job(
|
||||
def test_list_running_batches(server, default_user, dummy_beta_message_batch, letta_batch_job):
|
||||
server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.running,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
running_batches = server.batch_manager.list_running_batches(actor=default_user)
|
||||
running_batches = server.batch_manager.list_running_llm_batches(actor=default_user)
|
||||
assert len(running_batches) >= 1
|
||||
assert all(batch.status == JobStatus.running for batch in running_batches)
|
||||
|
||||
|
||||
def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch):
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch, letta_batch_job):
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
server.batch_manager.bulk_update_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)])
|
||||
server.batch_manager.bulk_update_llm_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)])
|
||||
|
||||
updated = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user)
|
||||
updated = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user)
|
||||
assert updated.status == JobStatus.completed
|
||||
assert updated.latest_polling_response == dummy_beta_message_batch
|
||||
|
||||
|
||||
def test_bulk_update_batch_items_results_by_agent(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response
|
||||
server,
|
||||
default_user,
|
||||
sarah_agent,
|
||||
dummy_beta_message_batch,
|
||||
dummy_llm_config,
|
||||
dummy_step_state,
|
||||
dummy_successful_response,
|
||||
letta_batch_job,
|
||||
):
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
item = server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
server.batch_manager.bulk_update_batch_items_results_by_agent(
|
||||
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
|
||||
[ItemUpdateInfo(batch.id, sarah_agent.id, JobStatus.completed, dummy_successful_response)]
|
||||
)
|
||||
|
||||
updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
|
||||
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
assert updated.request_status == JobStatus.completed
|
||||
assert updated.batch_request_result == dummy_successful_response
|
||||
|
||||
|
||||
def test_bulk_update_batch_items_step_status_by_agent(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
|
||||
):
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
item = server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
server.batch_manager.bulk_update_batch_items_step_status_by_agent(
|
||||
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
|
||||
[StepStatusUpdateInfo(batch.id, sarah_agent.id, AgentStepStatus.resumed)]
|
||||
)
|
||||
|
||||
updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
|
||||
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
assert updated.step_status == AgentStepStatus.resumed
|
||||
|
||||
|
||||
def test_list_batch_items_limit_and_filter(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
def test_list_batch_items_limit_and_filter(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
|
||||
):
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user)
|
||||
limited_items = server.batch_manager.list_batch_items(batch_id=batch.id, limit=2, actor=default_user)
|
||||
all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user)
|
||||
limited_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, limit=2, actor=default_user)
|
||||
|
||||
assert len(all_items) >= 3
|
||||
assert len(limited_items) == 2
|
||||
|
||||
|
||||
def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
def test_list_batch_items_pagination(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
|
||||
):
|
||||
# Create a batch job.
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
# Create 10 batch items.
|
||||
created_items = []
|
||||
for i in range(10):
|
||||
item = server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
@@ -4997,7 +5037,7 @@ def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_be
|
||||
created_items.append(item)
|
||||
|
||||
# Retrieve all items (without pagination).
|
||||
all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user)
|
||||
all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user)
|
||||
assert len(all_items) >= 10, f"Expected at least 10 items, got {len(all_items)}"
|
||||
|
||||
# Verify the items are ordered ascending by id (based on our implementation).
|
||||
@@ -5009,7 +5049,7 @@ def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_be
|
||||
cursor = all_items[4].id
|
||||
|
||||
# Retrieve items after the cursor.
|
||||
paged_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor)
|
||||
paged_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=cursor)
|
||||
|
||||
# All returned items should have an id greater than the cursor.
|
||||
for item in paged_items:
|
||||
@@ -5023,7 +5063,7 @@ def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_be
|
||||
|
||||
# Test pagination with a limit.
|
||||
limit = 3
|
||||
limited_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor, limit=limit)
|
||||
limited_page = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=cursor, limit=limit)
|
||||
# If more than 'limit' items remain, we should only get exactly 'limit' items.
|
||||
assert len(limited_page) == min(
|
||||
limit, expected_remaining
|
||||
@@ -5031,23 +5071,24 @@ def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_be
|
||||
|
||||
# Optional: Test with a cursor beyond the last item returns an empty list.
|
||||
last_cursor = sorted_ids[-1]
|
||||
empty_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=last_cursor)
|
||||
empty_page = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=last_cursor)
|
||||
assert empty_page == [], "Expected an empty list when cursor is after the last item"
|
||||
|
||||
|
||||
def test_bulk_update_batch_items_request_status_by_agent(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
|
||||
):
|
||||
# Create a batch job
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
# Create a batch item
|
||||
item = server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
item = server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
@@ -5055,55 +5096,59 @@ def test_bulk_update_batch_items_request_status_by_agent(
|
||||
)
|
||||
|
||||
# Update the request status using the bulk update method
|
||||
server.batch_manager.bulk_update_batch_items_request_status_by_agent(
|
||||
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
|
||||
[RequestStatusUpdateInfo(batch.id, sarah_agent.id, JobStatus.expired)]
|
||||
)
|
||||
|
||||
# Verify the update was applied
|
||||
updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
|
||||
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
|
||||
assert updated.request_status == JobStatus.expired
|
||||
|
||||
|
||||
def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response):
|
||||
def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job):
|
||||
# Create a batch job
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
# Attempt to update non-existent items should not raise errors
|
||||
|
||||
# Test with the direct bulk_update_batch_items method
|
||||
# Test with the direct bulk_update_llm_batch_items method
|
||||
nonexistent_pairs = [(batch.id, "nonexistent-agent-id")]
|
||||
nonexistent_updates = [{"request_status": JobStatus.expired}]
|
||||
|
||||
# This should not raise an error, just silently skip non-existent items
|
||||
server.batch_manager.bulk_update_batch_items(nonexistent_pairs, nonexistent_updates)
|
||||
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)
|
||||
|
||||
# Test with higher-level methods
|
||||
# Results by agent
|
||||
server.batch_manager.bulk_update_batch_items_results_by_agent(
|
||||
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
|
||||
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
|
||||
)
|
||||
|
||||
# Step status by agent
|
||||
server.batch_manager.bulk_update_batch_items_step_status_by_agent(
|
||||
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
|
||||
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
|
||||
)
|
||||
|
||||
# Request status by agent
|
||||
server.batch_manager.bulk_update_batch_items_request_status_by_agent(
|
||||
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
|
||||
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
|
||||
)
|
||||
|
||||
|
||||
def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
def test_create_batch_items_bulk(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
|
||||
):
|
||||
# Create a batch job
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
llm_batch_job = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
# Prepare data for multiple batch items
|
||||
@@ -5112,7 +5157,7 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m
|
||||
|
||||
for agent_id in agent_ids:
|
||||
batch_item = LLMBatchItem(
|
||||
batch_id=batch.id,
|
||||
llm_batch_id=llm_batch_job.id,
|
||||
agent_id=agent_id,
|
||||
llm_config=dummy_llm_config,
|
||||
request_status=JobStatus.created,
|
||||
@@ -5122,7 +5167,7 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m
|
||||
batch_items.append(batch_item)
|
||||
|
||||
# Call the bulk create function
|
||||
created_items = server.batch_manager.create_batch_items_bulk(batch_items, actor=default_user)
|
||||
created_items = server.batch_manager.create_llm_batch_items_bulk(batch_items, actor=default_user)
|
||||
|
||||
# Verify the correct number of items were created
|
||||
assert len(created_items) == len(agent_ids)
|
||||
@@ -5130,7 +5175,7 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m
|
||||
# Verify each item has expected properties
|
||||
for item in created_items:
|
||||
assert item.id.startswith("batch_item-")
|
||||
assert item.batch_id == batch.id
|
||||
assert item.llm_batch_id == llm_batch_job.id
|
||||
assert item.agent_id in agent_ids
|
||||
assert item.llm_config == dummy_llm_config
|
||||
assert item.request_status == JobStatus.created
|
||||
@@ -5138,38 +5183,41 @@ def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_m
|
||||
assert item.step_state == dummy_step_state
|
||||
|
||||
# Verify items can be retrieved from the database
|
||||
all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user)
|
||||
all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(all_items) >= len(agent_ids)
|
||||
|
||||
# Verify the IDs of created items match what's in the database
|
||||
created_ids = [item.id for item in created_items]
|
||||
for item_id in created_ids:
|
||||
fetched = server.batch_manager.get_batch_item_by_id(item_id, actor=default_user)
|
||||
fetched = server.batch_manager.get_llm_batch_item_by_id(item_id, actor=default_user)
|
||||
assert fetched.id in created_ids
|
||||
|
||||
|
||||
def test_count_batch_items(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
def test_count_batch_items(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job
|
||||
):
|
||||
# Create a batch job first.
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
batch = server.batch_manager.create_llm_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
letta_batch_job_id=letta_batch_job.id,
|
||||
)
|
||||
|
||||
# Create a specific number of batch items for this batch.
|
||||
num_items = 5
|
||||
for _ in range(num_items):
|
||||
server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
server.batch_manager.create_llm_batch_item(
|
||||
llm_batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Use the count_batch_items method to count the items.
|
||||
count = server.batch_manager.count_batch_items(batch_id=batch.id)
|
||||
# Use the count_llm_batch_items method to count the items.
|
||||
count = server.batch_manager.count_llm_batch_items(llm_batch_id=batch.id)
|
||||
|
||||
# Assert that the count matches the expected number.
|
||||
assert count == num_items, f"Expected {num_items} items, got {count}"
|
||||
|
||||
Reference in New Issue
Block a user