feat: async list/prepare messages (#2181)

Co-authored-by: Caren Thomas <carenthomas@gmail.com>
This commit is contained in:
Andy Li
2025-05-15 00:34:04 -07:00
committed by GitHub
parent dde50d3c63
commit 955873ab4d
8 changed files with 429 additions and 178 deletions

View File

@@ -60,6 +60,46 @@ def _prepare_in_context_messages(
return current_in_context_messages, new_in_context_messages
async def _prepare_in_context_messages_async(
    input_messages: List[MessageCreate],
    agent_state: AgentState,
    message_manager: MessageManager,
    actor: User,
) -> Tuple[List[Message], List[Message]]:
    """
    Prepares in-context messages for an agent, based on the current state and a new user input.
    Async version of _prepare_in_context_messages.

    Args:
        input_messages (List[MessageCreate]): The new user input messages to process.
        agent_state (AgentState): The current state of the agent, including message buffer config.
        message_manager (MessageManager): The manager used to retrieve and create messages.
        actor (User): The user performing the action, used for access control and attribution.

    Returns:
        Tuple[List[Message], List[Message]]: A tuple containing:
            - The current in-context messages (existing context for the agent).
            - The new in-context messages (messages created from the new input).
    """
    if agent_state.message_buffer_autoclear:
        # If autoclear is enabled, keep only the most recent system message (usually at index 0).
        # Fetch just the first ID instead of the whole buffer: this avoids loading every message
        # only to discard all but one, and an empty buffer now yields an empty context instead of
        # raising IndexError.
        current_in_context_messages = await message_manager.get_messages_by_ids_async(
            message_ids=agent_state.message_ids[:1], actor=actor
        )
    else:
        # Otherwise, include the full list of messages by ID for context
        current_in_context_messages = await message_manager.get_messages_by_ids_async(
            message_ids=agent_state.message_ids, actor=actor
        )
    # Create a new user message from the input and store it
    # TODO: make this async (create_many_messages is still a blocking call here)
    new_in_context_messages = message_manager.create_many_messages(
        create_input_messages(input_messages=input_messages, agent_id=agent_state.id, actor=actor), actor=actor
    )
    return current_in_context_messages, new_in_context_messages
def serialize_message_history(messages: List[str], context: str) -> str:
"""
Produce an XML document like:

View File

@@ -7,7 +7,7 @@ from aiomultiprocess import Pool
from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMessageBatchErroredResult, BetaMessageBatchSucceededResult
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _prepare_in_context_messages
from letta.agents.helpers import _prepare_in_context_messages_async
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import enable_strict_mode
@@ -126,6 +126,7 @@ class LettaAgentBatch(BaseAgent):
letta_batch_job_id: str,
agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None,
) -> LettaBatchResponse:
"""Carry out agent steps until the LLM request is sent."""
log_event(name="validate_inputs")
if not batch_requests:
raise ValueError("Empty list of batch_requests passed in!")
@@ -133,15 +134,26 @@ class LettaAgentBatch(BaseAgent):
agent_step_state_mapping = {}
log_event(name="load_and_prepare_agents")
agent_messages_mapping: Dict[str, List[Message]] = {}
agent_tools_mapping: Dict[str, List[dict]] = {}
# prepares (1) agent states, (2) step states, (3) LLMBatchItems (4) message batch_item_ids (5) messages per agent (6) tools per agent
agent_messages_mapping: dict[str, list[Message]] = {}
agent_tools_mapping: dict[str, list[dict]] = {}
# TODO: This isn't optimal, moving fast - prone to bugs because we pass around this half formed pydantic object
agent_batch_item_mapping: Dict[str, LLMBatchItem] = {}
agent_batch_item_mapping: dict[str, LLMBatchItem] = {}
# fetch agent states in batch
agent_mapping = {
agent_state.id: agent_state
for agent_state in await self.agent_manager.get_agents_by_ids_async(
agent_ids=[request.agent_id for request in batch_requests], actor=self.actor
)
}
agent_states = []
for batch_request in batch_requests:
agent_id = batch_request.agent_id
agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor)
agent_states.append(agent_state)
agent_state = agent_mapping[agent_id]
agent_states.append(agent_state) # keeping this to maintain ordering, but may not be necessary
if agent_id not in agent_step_state_mapping:
agent_step_state_mapping[agent_id] = AgentStepState(
@@ -162,7 +174,7 @@ class LettaAgentBatch(BaseAgent):
for msg in batch_request.messages:
msg.batch_item_id = llm_batch_item.id
agent_messages_mapping[agent_id] = self._prepare_in_context_messages_per_agent(
agent_messages_mapping[agent_id] = await self._prepare_in_context_messages_per_agent_async(
agent_state=agent_state, input_messages=batch_request.messages
)
@@ -528,12 +540,14 @@ class LettaAgentBatch(BaseAgent):
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
async def _prepare_in_context_messages_per_agent_async(
    self, agent_state: AgentState, input_messages: List[MessageCreate]
) -> List[Message]:
    """Build the full in-context message list for a single agent.

    Fetches the agent's existing context and persists the new input messages via
    _prepare_in_context_messages_async, then rebuilds memory over the combined
    sequence before returning it.

    Args:
        agent_state: The agent whose context is being prepared.
        input_messages: New input messages to append to the context.

    Returns:
        The rebuilt in-context message list (existing context + new messages).
    """
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
        input_messages, agent_state, self.message_manager, self.actor
    )
    # Rebuild memory over the merged context so downstream steps see a consistent view.
    in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)
    return in_context_messages
# TODO: Make this a bulk function

View File

@@ -12,7 +12,7 @@ from composio.exceptions import (
)
class AsyncComposioToolSet(BaseComposioToolSet, runtime="letta"):
class AsyncComposioToolSet(BaseComposioToolSet, runtime="letta", description_char_limit=1024):
"""
Async version of ComposioToolSet client for interacting with Composio API
Used to asynchronously hit the execute action endpoint

View File

@@ -114,155 +114,324 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if before_obj and after_obj and before_obj.created_at < after_obj.created_at:
raise ValueError("'before' reference must be later than 'after' reference")
query = select(cls)
query = cls._list_preprocess(
before_obj=before_obj,
after_obj=after_obj,
start_date=start_date,
end_date=end_date,
limit=limit,
query_text=query_text,
query_embedding=query_embedding,
ascending=ascending,
tags=tags,
match_all_tags=match_all_tags,
actor=actor,
access=access,
access_type=access_type,
join_model=join_model,
join_conditions=join_conditions,
identifier_keys=identifier_keys,
identity_id=identity_id,
**kwargs,
)
if join_model and join_conditions:
query = query.join(join_model, and_(*join_conditions))
# Execute the query
results = session.execute(query)
# Apply access predicate if actor is provided
if actor:
query = cls.apply_access_predicate(query, actor, access, access_type)
# Handle tag filtering if the model has tags
if tags and hasattr(cls, "tags"):
query = select(cls)
if match_all_tags:
# Match ALL tags - use subqueries
subquery = (
select(cls.tags.property.mapper.class_.agent_id)
.where(cls.tags.property.mapper.class_.tag.in_(tags))
.group_by(cls.tags.property.mapper.class_.agent_id)
.having(func.count() == len(tags))
)
query = query.filter(cls.id.in_(subquery))
else:
# Match ANY tag - use join and filter
query = (
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
) # Deduplicate results
# select distinct primary key
query = query.distinct(cls.id).order_by(cls.id)
if identifier_keys and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
# given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids
if identity_id and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)
# Apply filtering logic from kwargs
for key, value in kwargs.items():
if "." in key:
# Handle joined table columns
table_name, column_name = key.split(".")
joined_table = locals().get(table_name) or globals().get(table_name)
column = getattr(joined_table, column_name)
else:
# Handle columns from main table
column = getattr(cls, key)
if isinstance(value, (list, tuple, set)):
query = query.where(column.in_(value))
else:
query = query.where(column == value)
# Date range filtering
if start_date:
query = query.filter(cls.created_at > start_date)
if end_date:
query = query.filter(cls.created_at < end_date)
# Handle pagination based on before/after
if before or after:
conditions = []
if before and after:
# Window-based query - get records between before and after
conditions = [
or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id)),
or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id)),
]
else:
# Pure pagination query
if before:
conditions.append(
or_(
cls.created_at < before_obj.created_at,
and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
)
)
if after:
conditions.append(
or_(
cls.created_at > after_obj.created_at,
and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
)
)
if conditions:
query = query.where(and_(*conditions))
# Text search
if query_text:
if hasattr(cls, "text"):
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
elif hasattr(cls, "name"):
# Special case for Agent model - search across name
query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))
# Embedding search (for Passages)
is_ordered = False
if query_embedding:
if not hasattr(cls, "embedding"):
raise ValueError(f"Class {cls.__name__} does not have an embedding column")
from letta.settings import settings
if settings.letta_pg_uri_no_default:
# PostgreSQL with pgvector
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
else:
# SQLite with custom vector type
query_embedding_binary = adapt_array(query_embedding)
query = query.order_by(
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
cls.created_at.asc() if ascending else cls.created_at.desc(),
cls.id.asc(),
)
is_ordered = True
# Handle soft deletes
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
# Apply ordering
if not is_ordered:
if ascending:
query = query.order_by(cls.created_at.asc(), cls.id.asc())
else:
query = query.order_by(cls.created_at.desc(), cls.id.desc())
# Apply limit, adjusting for both bounds if necessary
if before and after:
# When both bounds are provided, we need to fetch enough records to satisfy
# the limit while respecting both bounds. We'll fetch more and then trim.
query = query.limit(limit * 2)
else:
query = query.limit(limit)
results = list(session.execute(query).scalars())
# If we have both bounds, take the middle portion
if before and after and len(results) > limit:
middle = len(results) // 2
start = max(0, middle - limit // 2)
end = min(len(results), start + limit)
results = results[start:end]
results = list(results.scalars())
results = cls._list_postprocess(
before=before,
after=after,
limit=limit,
results=results,
)
return results
@classmethod
@handle_db_timeout
async def list_async(
    cls,
    *,
    db_session: "AsyncSession",
    before: Optional[str] = None,
    after: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    ascending: bool = True,
    tags: Optional[List[str]] = None,
    match_all_tags: bool = False,
    actor: Optional["User"] = None,
    access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
    access_type: AccessType = AccessType.ORGANIZATION,
    join_model: Optional[Base] = None,
    join_conditions: Optional[Union[Tuple, List]] = None,
    identifier_keys: Optional[List[str]] = None,
    identity_id: Optional[str] = None,
    **kwargs,
) -> List["SqlalchemyBase"]:
    """
    Async version of the synchronous ``list`` method.
    NOTE: Keep in sync.

    Lists records with before/after cursor pagination, ordered by created_at.
    ``before`` and ``after`` may be combined to fetch a window of records.

    Args:
        db_session: Async SQLAlchemy session
        before: ID of item to paginate before (upper bound)
        after: ID of item to paginate after (lower bound)
        start_date: Filter items after this date
        end_date: Filter items before this date
        limit: Maximum number of items to return
        query_text: Text to search for
        query_embedding: Vector to search for similar embeddings
        ascending: Sort direction
        tags: List of tags to filter by
        match_all_tags: If True, return items matching all tags. If False, match any tag.
        **kwargs: Additional filters to apply
    """
    if start_date and end_date and start_date > end_date:
        raise ValueError("start_date must be earlier than or equal to end_date")
    logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
    async with db_session as session:
        # Resolve the pagination anchors up front; a dangling cursor id is an error.
        before_obj = await session.get(cls, before) if before else None
        if before and not before_obj:
            raise NoResultFound(f"No {cls.__name__} found with id {before}")
        after_obj = await session.get(cls, after) if after else None
        if after and not after_obj:
            raise NoResultFound(f"No {cls.__name__} found with id {after}")
        # When both anchors exist, 'before' must not precede 'after' chronologically.
        if before_obj and after_obj and before_obj.created_at < after_obj.created_at:
            raise ValueError("'before' reference must be later than 'after' reference")
        # Shared query construction with the sync path.
        query = cls._list_preprocess(
            before_obj=before_obj,
            after_obj=after_obj,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            query_text=query_text,
            query_embedding=query_embedding,
            ascending=ascending,
            tags=tags,
            match_all_tags=match_all_tags,
            actor=actor,
            access=access,
            access_type=access_type,
            join_model=join_model,
            join_conditions=join_conditions,
            identifier_keys=identifier_keys,
            identity_id=identity_id,
            **kwargs,
        )
        # Execute, then trim the result window if both bounds were supplied.
        rows = await session.execute(query)
        fetched = list(rows.scalars())
        return cls._list_postprocess(
            before=before,
            after=after,
            limit=limit,
            results=fetched,
        )
@classmethod
def _list_preprocess(
    cls,
    *,
    before_obj,
    after_obj,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    ascending: bool = True,
    tags: Optional[List[str]] = None,
    match_all_tags: bool = False,
    actor: Optional["User"] = None,
    access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
    access_type: AccessType = AccessType.ORGANIZATION,
    join_model: Optional[Base] = None,
    join_conditions: Optional[Union[Tuple, List]] = None,
    identifier_keys: Optional[List[str]] = None,
    identity_id: Optional[str] = None,
    **kwargs,
):
    """
    Constructs the query for listing records.

    Shared by the sync ``list`` and async ``list_async`` paths. ``before_obj`` /
    ``after_obj`` are pre-resolved ORM rows (the callers fetch them by cursor id),
    used here only for their ``created_at`` / ``id`` tiebreaker values. Returns a
    SQLAlchemy Select; execution is left to the caller.
    """
    query = select(cls)
    # Optional extra join supplied by the caller (e.g. filtering via a related table).
    if join_model and join_conditions:
        query = query.join(join_model, and_(*join_conditions))
    # Apply access predicate if actor is provided
    if actor:
        query = cls.apply_access_predicate(query, actor, access, access_type)
    # Handle tag filtering if the model has tags
    if tags and hasattr(cls, "tags"):
        # NOTE(review): this re-creates the base select, which appears to discard the
        # access predicate and join applied above for tag-filtered queries — confirm
        # whether that is intentional (same pattern exists in the sync `list`).
        query = select(cls)
        if match_all_tags:
            # Match ALL tags - use subqueries
            subquery = (
                select(cls.tags.property.mapper.class_.agent_id)
                .where(cls.tags.property.mapper.class_.tag.in_(tags))
                .group_by(cls.tags.property.mapper.class_.agent_id)
                .having(func.count() == len(tags))
            )
            query = query.filter(cls.id.in_(subquery))
        else:
            # Match ANY tag - use join and filter
            query = (
                query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
            )  # Deduplicate results
        # select distinct primary key
        query = query.distinct(cls.id).order_by(cls.id)
    # Filter by identity keys via the identities relationship, when the model has one.
    if identifier_keys and hasattr(cls, "identities"):
        query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
    # given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids
    if identity_id and hasattr(cls, "identities"):
        query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)
    # Apply filtering logic from kwargs
    for key, value in kwargs.items():
        if "." in key:
            # Handle joined table columns
            # NOTE(review): resolves the table class by name from locals()/globals();
            # only works when the joined class is visible in this module's scope.
            table_name, column_name = key.split(".")
            joined_table = locals().get(table_name) or globals().get(table_name)
            column = getattr(joined_table, column_name)
        else:
            # Handle columns from main table
            column = getattr(cls, key)
        # Collection values become IN-clauses; scalars become equality filters.
        if isinstance(value, (list, tuple, set)):
            query = query.where(column.in_(value))
        else:
            query = query.where(column == value)
    # Date range filtering (exclusive bounds on both ends)
    if start_date:
        query = query.filter(cls.created_at > start_date)
    if end_date:
        query = query.filter(cls.created_at < end_date)
    # Handle pagination based on before/after
    if before_obj or after_obj:
        conditions = []
        if before_obj and after_obj:
            # Window-based query - get records between before and after
            # (created_at with id as tiebreaker for identical timestamps)
            conditions = [
                or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id)),
                or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id)),
            ]
        else:
            # Pure pagination query
            if before_obj:
                conditions.append(
                    or_(
                        cls.created_at < before_obj.created_at,
                        and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
                    )
                )
            if after_obj:
                conditions.append(
                    or_(
                        cls.created_at > after_obj.created_at,
                        and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
                    )
                )
        if conditions:
            query = query.where(and_(*conditions))
    # Text search
    if query_text:
        if hasattr(cls, "text"):
            query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
        elif hasattr(cls, "name"):
            # Special case for Agent model - search across name
            query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))
    # Embedding search (for Passages)
    is_ordered = False
    if query_embedding:
        if not hasattr(cls, "embedding"):
            raise ValueError(f"Class {cls.__name__} does not have an embedding column")
        from letta.settings import settings

        if settings.letta_pg_uri_no_default:
            # PostgreSQL with pgvector
            query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
        else:
            # SQLite with custom vector type
            query_embedding_binary = adapt_array(query_embedding)
            query = query.order_by(
                func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
                cls.created_at.asc() if ascending else cls.created_at.desc(),
                cls.id.asc(),
            )
        is_ordered = True
    # Handle soft deletes
    if hasattr(cls, "is_deleted"):
        query = query.where(cls.is_deleted == False)
    # Apply ordering (skipped when an embedding similarity ordering is already in place)
    if not is_ordered:
        if ascending:
            query = query.order_by(cls.created_at.asc(), cls.id.asc())
        else:
            query = query.order_by(cls.created_at.desc(), cls.id.desc())
    # Apply limit, adjusting for both bounds if necessary
    if before_obj and after_obj:
        # When both bounds are provided, we need to fetch enough records to satisfy
        # the limit while respecting both bounds. We'll fetch more and then trim
        # (see _list_postprocess).
        query = query.limit(limit * 2)
    else:
        query = query.limit(limit)
    return query
@classmethod
def _list_postprocess(
cls,
before: str | None,
after: str | None,
limit: int | None,
results: list,
):
# If we have both bounds, take the middle portion
if before and after and len(results) > limit:
middle = len(results) // 2
start = max(0, middle - limit // 2)
end = min(len(results), start + limit)
results = results[start:end]
return results
@classmethod
@handle_db_timeout
def read(
@@ -305,7 +474,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
@handle_db_timeout
async def read_async(
cls,
db_session: "Session",
db_session: "AsyncSession",
identifier: Optional[str] = None,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],

View File

@@ -577,6 +577,13 @@ class AgentManager:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
return agent.to_pydantic()
@enforce_types
async def get_agents_by_ids_async(self, agent_ids: list[str], actor: PydanticUser) -> list[PydanticAgentState]:
    """Fetch a batch of agents by ID in one query, scoped to the actor's access."""
    async with db_registry.async_session() as session:
        orm_agents = await AgentModel.read_multiple_async(db_session=session, identifiers=agent_ids, actor=actor)
        # Convert while the session is still open, matching the sync variant.
        return [orm_agent.to_pydantic() for orm_agent in orm_agents]
@enforce_types
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""

View File

@@ -36,15 +36,29 @@ class MessageManager:
"""Fetch messages by ID and return them in the requested order."""
with db_registry.session() as session:
results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids))
return self._get_messages_by_id_postprocess(results, message_ids)
if len(results) != len(message_ids):
logger.warning(
f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
)
@enforce_types
async def get_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
    """Fetch messages by ID, returned in the order of ``message_ids``. Async version of above function."""
    async with db_registry.async_session() as session:
        fetched = await MessageModel.list_async(
            db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids)
        )
        # Reorder (and warn on misses) exactly like the sync path.
        return self._get_messages_by_id_postprocess(fetched, message_ids)
# Sort results directly based on message_ids
result_dict = {msg.id: msg.to_pydantic() for msg in results}
return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids]))
def _get_messages_by_id_postprocess(
self,
results: List[MessageModel],
message_ids: List[str],
) -> List[PydanticMessage]:
if len(results) != len(message_ids):
logger.warning(
f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
)
# Sort results directly based on message_ids
result_dict = {msg.id: msg.to_pydantic() for msg in results}
return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids]))
@enforce_types
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:

0
letta/types/__init__.py Normal file
View File

View File

@@ -1,3 +1,4 @@
import asyncio
import os
import threading
from datetime import datetime, timezone
@@ -164,6 +165,14 @@ def step_state_map(agents):
return {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents}
@pytest.fixture(scope="session")
def event_loop(request):
    """Create one event loop shared by the entire test session (fixture is session-scoped)."""
    # A fresh loop from the policy, closed once at session teardown.
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
def create_batch_response(batch_id: str, processing_status: str = "in_progress") -> BetaMessageBatch:
"""Create a dummy BetaMessageBatch with the specified ID and status."""
now = datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc)
@@ -452,7 +461,7 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv
@pytest.mark.asyncio
async def test_partial_error_from_anthropic_batch(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop
):
anthropic_batch_id = "msgbatch_test_12345"
dummy_batch_response = create_batch_response(
@@ -594,7 +603,7 @@ async def test_partial_error_from_anthropic_batch(
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
)
assert len(messages) == (len(agents) - 1) * 4 + 1
assert_descending_order(messages)
_assert_descending_order(messages)
# Check that each agent is represented
for agent in agents_continue:
agent_messages = [m for m in messages if m.agent_id == agent.id]
@@ -612,7 +621,7 @@ async def test_partial_error_from_anthropic_batch(
@pytest.mark.asyncio
async def test_resume_step_some_stop(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop
):
anthropic_batch_id = "msgbatch_test_12345"
dummy_batch_response = create_batch_response(
@@ -743,7 +752,7 @@ async def test_resume_step_some_stop(
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
)
assert len(messages) == len(agents) * 3 + 1
assert_descending_order(messages)
_assert_descending_order(messages)
# Check that each agent is represented
for agent in agents_continue:
agent_messages = [m for m in messages if m.agent_id == agent.id]
@@ -761,23 +770,21 @@ async def test_resume_step_some_stop(
assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
def assert_descending_order(messages):
"""Assert messages are in descending order by created_at timestamps."""
def _assert_descending_order(messages):
"""Assert messages are in monotonically decreasing by created_at timestamps."""
if len(messages) <= 1:
return True
for i in range(1, len(messages)):
assert messages[i].created_at <= messages[i - 1].created_at, (
f"Order violation: {messages[i - 1].id} ({messages[i - 1].created_at}) "
f"followed by {messages[i].id} ({messages[i].created_at})"
)
for prev, next in zip(messages[:-1], messages[1:]):
assert (
prev.created_at >= next.created_at
), f"Order violation: {prev.id} ({prev.created_at}) followed by {next.id} ({next.created_at})"
return True
@pytest.mark.asyncio
async def test_resume_step_after_request_all_continue(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop
):
anthropic_batch_id = "msgbatch_test_12345"
dummy_batch_response = create_batch_response(
@@ -902,7 +909,7 @@ async def test_resume_step_after_request_all_continue(
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
)
assert len(messages) == len(agents) * 4
assert_descending_order(messages)
_assert_descending_order(messages)
# Check that each agent is represented
for agent in agents:
agent_messages = [m for m in messages if m.agent_id == agent.id]
@@ -915,7 +922,7 @@ async def test_resume_step_after_request_all_continue(
@pytest.mark.asyncio
async def test_step_until_request_prepares_and_submits_batch_correctly(
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job, event_loop
):
"""
Test that step_until_request correctly: