From d9d586e43100729581c6d84d6b0f0562a9c7e3b8 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 2 May 2025 11:14:03 -0700 Subject: [PATCH] feat: Add message listing for a letta batch (#1982) --- ...5b1eb9c40_add_batch_item_id_to_messages.py | 31 ++++++++ letta/agents/helpers.py | 5 +- letta/agents/letta_agent_batch.py | 50 ++++++++----- letta/helpers/message_helper.py | 37 +++++----- letta/orm/message.py | 4 + letta/schemas/letta_response.py | 5 ++ letta/schemas/llm_batch_job.py | 10 ++- letta/schemas/message.py | 2 + letta/server/rest_api/routers/v1/messages.py | 47 +++++++++++- letta/server/rest_api/utils.py | 11 +-- letta/services/llm_batch_manager.py | 61 +++++++++++++++- letta/services/message_manager.py | 1 + tests/test_letta_agent_batch.py | 73 ++++++++++++++++++- 13 files changed, 289 insertions(+), 48 deletions(-) create mode 100644 alembic/versions/0335b1eb9c40_add_batch_item_id_to_messages.py diff --git a/alembic/versions/0335b1eb9c40_add_batch_item_id_to_messages.py b/alembic/versions/0335b1eb9c40_add_batch_item_id_to_messages.py new file mode 100644 index 00000000..01c87429 --- /dev/null +++ b/alembic/versions/0335b1eb9c40_add_batch_item_id_to_messages.py @@ -0,0 +1,31 @@ +"""Add batch_item_id to messages + +Revision ID: 0335b1eb9c40 +Revises: 373dabcba6cf +Create Date: 2025-05-02 10:30:08.156190 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0335b1eb9c40" +down_revision: Union[str, None] = "373dabcba6cf" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("messages", sa.Column("batch_item_id", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("messages", "batch_item_id") + # ### end Alembic commands ### diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 38228119..ce07bafc 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -21,7 +21,10 @@ def _create_letta_response(new_in_context_messages: list[Message], use_assistant def _prepare_in_context_messages( - input_messages: List[MessageCreate], agent_state: AgentState, message_manager: MessageManager, actor: User + input_messages: List[MessageCreate], + agent_state: AgentState, + message_manager: MessageManager, + actor: User, ) -> Tuple[List[Message], List[Message]]: """ Prepares in-context messages for an agent, based on the current state and a new user input. diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 3610bf2e..b9e30ac0 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -137,21 +137,37 @@ class LettaAgentBatch: log_event(name="load_and_prepare_agents") agent_messages_mapping: Dict[str, List[Message]] = {} agent_tools_mapping: Dict[str, List[dict]] = {} + # TODO: This isn't optimal, moving fast - prone to bugs because we pass around this half formed pydantic object + agent_batch_item_mapping: Dict[str, LLMBatchItem] = {} agent_states = [] for batch_request in batch_requests: agent_id = batch_request.agent_id agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor) agent_states.append(agent_state) - agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent( - agent_state=agent_state, input_messages=batch_request.messages - ) - if agent_id not in agent_step_state_mapping: agent_step_state_mapping[agent_id] = AgentStepState( step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules) ) + llm_batch_item = LLMBatchItem( + llm_batch_id="", # TODO: This is hacky, it gets filled in later + agent_id=agent_state.id, + llm_config=agent_state.llm_config, + request_status=JobStatus.created, + step_status=AgentStepStatus.paused, + step_state=agent_step_state_mapping[agent_id], + ) + agent_batch_item_mapping[agent_id] = llm_batch_item + + # Fill in the batch_item_id for the message + for msg in batch_request.messages: + msg.batch_item_id = llm_batch_item.id + + agent_messages_mapping[agent_id] = self._prepare_in_context_messages_per_agent( + agent_state=agent_state, input_messages=batch_request.messages + ) + agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(agent_state, agent_step_state_mapping[agent_id].tool_rules_solver) log_event(name="init_llm_client") @@ -182,21 +198,14 @@ class LettaAgentBatch: log_event(name="prepare_batch_items") batch_items = [] for state in agent_states: - step_state = agent_step_state_mapping[state.id] - batch_items.append( - LLMBatchItem( - llm_batch_id=llm_batch_job.id, - agent_id=state.id, - llm_config=state.llm_config, - request_status=JobStatus.created, - step_status=AgentStepStatus.paused, - step_state=step_state, - ) - ) + llm_batch_item = agent_batch_item_mapping[state.id] + # TODO This is hacky + llm_batch_item.llm_batch_id = llm_batch_job.id + batch_items.append(llm_batch_item) if batch_items: log_event(name="bulk_create_batch_items") - self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor) + batch_items_persisted = self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor) log_event(name="return_batch_response") return LettaBatchResponse( @@ -335,9 +344,14 @@ class LettaAgentBatch: exec_results: Sequence[Tuple[str, Tuple[str, bool]]], ctx: _ResumeContext, ) -> Dict[str, List[Message]]: + # TODO: This is redundant, we should have this ready on the ctx + # TODO: I am doing it quick and dirty for now + agent_item_map: Dict[str, LLMBatchItem] = {item.agent_id: item for item in ctx.batch_items} + msg_map: Dict[str, List[Message]] = {} for aid, (tool_res, success) in exec_results: msgs = self._create_tool_call_messages( + llm_batch_item_id=agent_item_map[aid].id, agent_state=ctx.agent_state_map[aid], tool_call_name=ctx.tool_call_name_map[aid], tool_call_args=ctx.tool_call_args_map[aid], @@ -399,6 +413,7 @@ class LettaAgentBatch: def _create_tool_call_messages( self, + llm_batch_item_id: str, agent_state: AgentState, tool_call_name: str, tool_call_args: Dict[str, Any], @@ -421,6 +436,7 @@ class LettaAgentBatch: reasoning_content=reasoning_content, pre_computed_assistant_message_id=None, pre_computed_tool_message_id=None, + llm_batch_item_id=llm_batch_item_id, ) return tool_call_messages @@ -477,7 +493,7 @@ class LettaAgentBatch: valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools])) return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)] - def _get_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]: + def _prepare_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]: current_in_context_messages, new_in_context_messages = _prepare_in_context_messages( input_messages, agent_state, self.message_manager, self.actor ) diff --git a/letta/helpers/message_helper.py b/letta/helpers/message_helper.py index be05b85a..90d7b680 100644 --- a/letta/helpers/message_helper.py +++ b/letta/helpers/message_helper.py @@ -5,57 +5,58 @@ from letta.schemas.message import Message, MessageCreate def convert_message_creates_to_messages( - messages: list[MessageCreate], + message_creates: list[MessageCreate], agent_id: str, wrap_user_message: bool = True, wrap_system_message: bool = True, ) -> list[Message]: return [ _convert_message_create_to_message( - message=message, + message_create=create, agent_id=agent_id, wrap_user_message=wrap_user_message, wrap_system_message=wrap_system_message, ) - for message in messages + for create in message_creates ] def _convert_message_create_to_message( - message: MessageCreate, + message_create: MessageCreate, agent_id: str, wrap_user_message: bool = True, wrap_system_message: bool = True, ) -> Message: """Converts a MessageCreate object into a Message object, applying wrapping if needed.""" # TODO: This seems like extra boilerplate with little benefit - assert isinstance(message, MessageCreate) + assert isinstance(message_create, MessageCreate) # Extract message content - if isinstance(message.content, str): - message_content = message.content - elif message.content and len(message.content) > 0 and isinstance(message.content[0], TextContent): - message_content = message.content[0].text + if isinstance(message_create.content, str): + message_content = message_create.content + elif message_create.content and len(message_create.content) > 0 and isinstance(message_create.content[0], TextContent): + message_content = message_create.content[0].text else: raise ValueError("Message content is empty or invalid") # Apply wrapping if needed - if message.role not in {MessageRole.user, MessageRole.system}: - raise ValueError(f"Invalid message role: {message.role}") - elif message.role == MessageRole.user and wrap_user_message: + if message_create.role not in {MessageRole.user, MessageRole.system}: + raise ValueError(f"Invalid message role: {message_create.role}") + elif message_create.role == MessageRole.user and wrap_user_message: message_content = system.package_user_message(user_message=message_content) - elif message.role == MessageRole.system and wrap_system_message: + elif message_create.role == MessageRole.system and wrap_system_message: message_content = system.package_system_message(system_message=message_content) return Message( agent_id=agent_id, - role=message.role, + role=message_create.role, content=[TextContent(text=message_content)] if message_content else [], - name=message.name, + name=message_create.name, model=None, # assigned later? tool_calls=None, # irrelevant tool_call_id=None, - otid=message.otid, - sender_id=message.sender_id, - group_id=message.group_id, + otid=message_create.otid, + sender_id=message_create.sender_id, + group_id=message_create.group_id, + batch_item_id=message_create.batch_item_id, ) diff --git a/letta/orm/message.py b/letta/orm/message.py index b5c65ec3..0495da20 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -44,6 +44,10 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): sender_id: Mapped[Optional[str]] = mapped_column( nullable=True, doc="The id of the sender of the message, can be an identity id or agent id" ) + batch_item_id: Mapped[Optional[str]] = mapped_column( + nullable=True, + doc="The id of the LLMBatchItem that this message is associated with", + ) # Monotonically increasing sequence for efficient/correct listing sequence_id: Mapped[int] = mapped_column( diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index 453fa30a..a4057298 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from letta.helpers.json_helpers import json_dumps from letta.schemas.enums import JobStatus, MessageStreamStatus from letta.schemas.letta_message import LettaMessage, LettaMessageUnion +from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics # TODO: consider moving into own file @@ -175,3 +176,7 @@ class LettaBatchResponse(BaseModel): agent_count: int = Field(..., description="The number of agents in the batch request.") last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.") created_at: datetime = Field(..., description="The timestamp when the batch request was created.") + + +class LettaBatchMessages(BaseModel): + messages: List[Message] diff --git a/letta/schemas/llm_batch_job.py b/letta/schemas/llm_batch_job.py index cde072f1..a6e537f0 100644 --- a/letta/schemas/llm_batch_job.py +++ b/letta/schemas/llm_batch_job.py @@ -10,16 +10,18 @@ from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.llm_config import LLMConfig -class LLMBatchItem(OrmMetadataBase, validate_assignment=True): +class LLMBatchItemBase(OrmMetadataBase, validate_assignment=True): + __id_prefix__ = "batch_item" + + +class LLMBatchItem(LLMBatchItemBase, validate_assignment=True): """ Represents a single agent's LLM request within a batch. This object captures the configuration, execution status, and eventual result of one agent's request within a larger LLM batch job. """ - __id_prefix__ = "batch_item" - - id: Optional[str] = Field(None, description="The id of the batch item. Assigned by the database.") + id: str = LLMBatchItemBase.generate_id_field() llm_batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.") agent_id: str = Field(..., description="The id of the agent associated with this LLM request.") diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 6c310801..7fbe1fd4 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -85,6 +85,7 @@ class MessageCreate(BaseModel): name: Optional[str] = Field(None, description="The name of the participant.") otid: Optional[str] = Field(None, description="The offline threading id associated with this message") sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id") + batch_item_id: Optional[str] = Field(None, description="The id of the LLMBatchItem that this message is associated with") group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in") def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]: @@ -168,6 +169,7 @@ class Message(BaseMessage): tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls") group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in") sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id") + batch_item_id: Optional[str] = Field(None, description="The id of the LLMBatchItem that this message is associated with") # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index 252e6fe8..95b3748f 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -1,6 +1,6 @@ from typing import List, Optional -from fastapi import APIRouter, Body, Depends, Header, status +from fastapi import APIRouter, Body, Depends, Header, Query, status from fastapi.exceptions import HTTPException from starlette.requests import Request @@ -9,6 +9,7 @@ from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate from letta.schemas.letta_request import CreateBatch +from letta.schemas.letta_response import LettaBatchMessages from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer from letta.settings import settings @@ -123,6 +124,50 @@ async def list_batch_runs( return [BatchJob.from_job(job) for job in jobs] +@router.get( + "/batches/{batch_id}/messages", + response_model=LettaBatchMessages, + operation_id="list_batch_messages", +) +async def list_batch_messages( + batch_id: str, + limit: int = Query(100, description="Maximum number of messages to return"), + cursor: Optional[str] = Query( + None, description="Message ID to use as pagination cursor (get messages before/after this ID) depending on sort_descending." + ), + agent_id: Optional[str] = Query(None, description="Filter messages by agent ID"), + sort_descending: bool = Query(True, description="Sort messages by creation time (true=newest first)"), + actor_id: Optional[str] = Header(None, alias="user_id"), + server: SyncServer = Depends(get_letta_server), +): + """ + Get messages for a specific batch job. + + Returns messages associated with the batch in chronological order. + + Pagination: + - For the first page, omit the cursor parameter + - For subsequent pages, use the ID of the last message from the previous response as the cursor + - Results will include messages before/after the cursor based on sort_descending + """ + actor = server.user_manager.get_user_or_default(user_id=actor_id) + + # First, verify the batch job exists and the user has access to it + try: + job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor) + BatchJob.from_job(job) + except NoResultFound: + raise HTTPException(status_code=404, detail="Batch not found") + + # Get messages directly using our efficient method + # We'll need to update the underlying implementation to use message_id as cursor + messages = server.batch_manager.get_messages_for_letta_batch( + letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, sort_descending=sort_descending, cursor=cursor + ) + + return LettaBatchMessages(messages=messages) + + @router.patch("/batches/{batch_id}/cancel", operation_id="cancel_batch_run") async def cancel_batch_run( batch_id: str, diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 2e9b3e9a..eff457eb 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -168,6 +168,7 @@ def create_letta_messages_from_llm_response( reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, pre_computed_tool_message_id: Optional[str] = None, + llm_batch_item_id: Optional[str] = None, ) -> List[Message]: messages = [] @@ -192,6 +193,7 @@ def create_letta_messages_from_llm_response( tool_calls=[tool_call], tool_call_id=tool_call_id, created_at=get_utc_time(), + batch_item_id=llm_batch_item_id, ) if pre_computed_assistant_message_id: assistant_message.id = pre_computed_assistant_message_id @@ -209,6 +211,7 @@ def create_letta_messages_from_llm_response( tool_call_id=tool_call_id, created_at=get_utc_time(), name=function_name, + batch_item_id=llm_batch_item_id, ) if pre_computed_tool_message_id: tool_message.id = pre_computed_tool_message_id @@ -216,7 +219,7 @@ def create_letta_messages_from_llm_response( if add_heartbeat_request_system_message: heartbeat_system_message = create_heartbeat_system_message( - agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor + agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor, llm_batch_item_id=llm_batch_item_id ) messages.append(heartbeat_system_message) @@ -224,10 +227,7 @@ def create_letta_messages_from_llm_response( def create_heartbeat_system_message( - agent_id: str, - model: str, - function_call_success: bool, - actor: User, + agent_id: str, model: str, function_call_success: bool, actor: User, llm_batch_item_id: Optional[str] = None ) -> Message: text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE heartbeat_system_message = Message( @@ -239,6 +239,7 @@ def create_heartbeat_system_message( tool_calls=[], tool_call_id=None, created_at=get_utc_time(), + batch_item_id=llm_batch_item_id, ) return heartbeat_system_message diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index ec3a947b..caebaaf0 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -2,10 +2,11 @@ import datetime from typing import Any, Dict, List, Optional, Tuple from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse -from sqlalchemy import func, tuple_ +from sqlalchemy import desc, func, tuple_ from letta.jobs.types import BatchPollingResult, ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo from letta.log import get_logger +from letta.orm import Message as MessageModel from letta.orm.llm_batch_items import LLMBatchItem from letta.orm.llm_batch_job import LLMBatchJob from letta.schemas.agent import AgentStepState @@ -13,6 +14,7 @@ from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message as PydanticMessage from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types @@ -142,6 +144,62 @@ class LLMBatchManager: batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) batch.hard_delete(db_session=session, actor=actor) + @enforce_types + def get_messages_for_letta_batch( + self, + letta_batch_job_id: str, + limit: int = 100, + actor: Optional[PydanticUser] = None, + agent_id: Optional[str] = None, + sort_descending: bool = True, + cursor: Optional[str] = None, # Message ID as cursor + ) -> List[PydanticMessage]: + """ + Retrieve messages across all LLM batch jobs associated with a Letta batch job. + Optimized for PostgreSQL performance using ID-based keyset pagination. + """ + with self.session_maker() as session: + # If cursor is provided, get sequence_id for that message + cursor_sequence_id = None + if cursor: + cursor_query = session.query(MessageModel.sequence_id).filter(MessageModel.id == cursor).limit(1) + cursor_result = cursor_query.first() + if cursor_result: + cursor_sequence_id = cursor_result[0] + else: + # If cursor message doesn't exist, ignore it + pass + + query = ( + session.query(MessageModel) + .join(LLMBatchItem, MessageModel.batch_item_id == LLMBatchItem.id) + .join(LLMBatchJob, LLMBatchItem.llm_batch_id == LLMBatchJob.id) + .filter(LLMBatchJob.letta_batch_job_id == letta_batch_job_id) + ) + + if actor is not None: + query = query.filter(MessageModel.organization_id == actor.organization_id) + + if agent_id is not None: + query = query.filter(MessageModel.agent_id == agent_id) + + # Apply cursor-based pagination if cursor exists + if cursor_sequence_id is not None: + if sort_descending: + query = query.filter(MessageModel.sequence_id < cursor_sequence_id) + else: + query = query.filter(MessageModel.sequence_id > cursor_sequence_id) + + if sort_descending: + query = query.order_by(desc(MessageModel.sequence_id)) + else: + query = query.order_by(MessageModel.sequence_id) + + query = query.limit(limit) + + results = query.all() + return [message.to_pydantic() for message in results] + @enforce_types def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]: """Return all running LLM batch jobs, optionally filtered by actor's organization.""" @@ -196,6 +254,7 @@ class LLMBatchManager: orm_items = [] for item in llm_batch_items: orm_item = LLMBatchItem( + id=item.id, llm_batch_id=item.llm_batch_id, agent_id=item.agent_id, llm_config=item.llm_config, diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index bca13353..e87f9917 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -73,6 +73,7 @@ class MessageManager: Returns: List of created Pydantic message models """ + if not pydantic_msgs: return [] diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index 20f16611..ee668fd0 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -23,7 +23,7 @@ from letta.helpers import ToolRulesSolver from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.orm import Base from letta.schemas.agent import AgentState, AgentStepState -from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType +from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType from letta.schemas.job import BatchJob from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_request import LettaBatchRequest @@ -589,6 +589,26 @@ async def test_partial_error_from_anthropic_batch( len(refreshed_agent.message_ids) == 6 ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" + # Check the total list of messages + messages = server.batch_manager.get_messages_for_letta_batch( + letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user + ) + assert len(messages) == (len(agents) - 1) * 4 + 1 + assert_descending_order(messages) + # Check that each agent is represented + for agent in agents_continue: + agent_messages = [m for m in messages if m.agent_id == agent.id] + assert len(agent_messages) == 4 + assert agent_messages[-1].role == MessageRole.user, "Expected initial user message" + assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message" + assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call" + assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message" + + for agent in agents_failed: + agent_messages = [m for m in messages if m.agent_id == agent.id] + assert len(agent_messages) == 1 + assert agent_messages[0].role == MessageRole.user, "Expected initial user message" + @pytest.mark.asyncio async def test_resume_step_some_stop( @@ -718,6 +738,42 @@ async def test_resume_step_some_stop( len(refreshed_agent.message_ids) == 6 ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" + # Check the total list of messages + messages = server.batch_manager.get_messages_for_letta_batch( + letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user + ) + assert len(messages) == len(agents) * 3 + 1 + assert_descending_order(messages) + # Check that each agent is represented + for agent in agents_continue: + agent_messages = [m for m in messages if m.agent_id == agent.id] + assert len(agent_messages) == 4 + assert agent_messages[-1].role == MessageRole.user, "Expected initial user message" + assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message" + assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call" + assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message" + + for agent in agents_finish: + agent_messages = [m for m in messages if m.agent_id == agent.id] + assert len(agent_messages) == 3 + assert agent_messages[-1].role == MessageRole.user, "Expected initial user message" + assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message" + assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call" + + +def assert_descending_order(messages): + """Assert messages are in descending order by created_at timestamps.""" + if len(messages) <= 1: + return True + + for i in range(1, len(messages)): + assert messages[i].created_at <= messages[i - 1].created_at, ( + f"Order violation: {messages[i - 1].id} ({messages[i - 1].created_at}) " + f"followed by {messages[i].id} ({messages[i].created_at})" + ) + + return True + @pytest.mark.asyncio async def test_resume_step_after_request_all_continue( @@ -841,6 +897,21 @@ async def test_resume_step_after_request_all_continue( len(refreshed_agent.message_ids) == 6 ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" + # Check the total list of messages + messages = server.batch_manager.get_messages_for_letta_batch( + letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user + ) + assert len(messages) == len(agents) * 4 + assert_descending_order(messages) + # Check that each agent is represented + for agent in agents: + agent_messages = [m for m in messages if m.agent_id == agent.id] + assert len(agent_messages) == 4 + assert agent_messages[-1].role == MessageRole.user, "Expected initial user message" + assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message" + assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call" + assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message" + @pytest.mark.asyncio async def test_step_until_request_prepares_and_submits_batch_correctly(