From 955873ab4dbeffc9ebca5553de20de714c09d02e Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Thu, 15 May 2025 00:34:04 -0700 Subject: [PATCH] feat: async list/prepare messages (#2181) Co-authored-by: Caren Thomas --- letta/agents/helpers.py | 40 ++ letta/agents/letta_agent_batch.py | 34 +- letta/functions/async_composio_toolset.py | 2 +- letta/orm/sqlalchemy_base.py | 459 +++++++++++++++------- letta/services/agent_manager.py | 7 + letta/services/message_manager.py | 28 +- letta/types/__init__.py | 0 tests/test_letta_agent_batch.py | 37 +- 8 files changed, 429 insertions(+), 178 deletions(-) create mode 100644 letta/types/__init__.py diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 5f60dcbb..17d9e676 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -60,6 +60,46 @@ def _prepare_in_context_messages( return current_in_context_messages, new_in_context_messages +async def _prepare_in_context_messages_async( + input_messages: List[MessageCreate], + agent_state: AgentState, + message_manager: MessageManager, + actor: User, +) -> Tuple[List[Message], List[Message]]: + """ + Prepares in-context messages for an agent, based on the current state and a new user input. + Async version of _prepare_in_context_messages. + + Args: + input_messages (List[MessageCreate]): The new user input messages to process. + agent_state (AgentState): The current state of the agent, including message buffer config. + message_manager (MessageManager): The manager used to retrieve and create messages. + actor (User): The user performing the action, used for access control and attribution. + + Returns: + Tuple[List[Message], List[Message]]: A tuple containing: + - The current in-context messages (existing context for the agent). + - The new in-context messages (messages created from the new input). + """ + + if agent_state.message_buffer_autoclear: + # If autoclear is enabled, only include the most recent system message (usually at index 0) + current_in_context_messages = [ + (await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor))[0] + ] + else: + # Otherwise, include the full list of messages by ID for context + current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor) + + # Create a new user message from the input and store it + # TODO: make this async + new_in_context_messages = message_manager.create_many_messages( + create_input_messages(input_messages=input_messages, agent_id=agent_state.id, actor=actor), actor=actor + ) + + return current_in_context_messages, new_in_context_messages + + def serialize_message_history(messages: List[str], context: str) -> str: """ Produce an XML document like: diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index ba426688..f844876e 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -7,7 +7,7 @@ from aiomultiprocess import Pool from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMessageBatchErroredResult, BetaMessageBatchSucceededResult from letta.agents.base_agent import BaseAgent -from letta.agents.helpers import _prepare_in_context_messages +from letta.agents.helpers import _prepare_in_context_messages_async from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.tool_execution_helper import enable_strict_mode @@ -126,6 +126,7 @@ class LettaAgentBatch(BaseAgent): letta_batch_job_id: str, agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None, ) -> LettaBatchResponse: + """Carry out agent steps until the LLM request is sent.""" log_event(name="validate_inputs") if not batch_requests: raise ValueError("Empty list of batch_requests passed in!") @@ -133,15 +134,26 @@ class LettaAgentBatch(BaseAgent): agent_step_state_mapping = {} log_event(name="load_and_prepare_agents") - agent_messages_mapping: Dict[str, List[Message]] = {} - agent_tools_mapping: Dict[str, List[dict]] = {} + # prepares (1) agent states, (2) step states, (3) LLMBatchItems (4) message batch_item_ids (5) messages per agent (6) tools per agent + + agent_messages_mapping: dict[str, list[Message]] = {} + agent_tools_mapping: dict[str, list[dict]] = {} # TODO: This isn't optimal, moving fast - prone to bugs because we pass around this half formed pydantic object - agent_batch_item_mapping: Dict[str, LLMBatchItem] = {} + agent_batch_item_mapping: dict[str, LLMBatchItem] = {} + + # fetch agent states in batch + agent_mapping = { + agent_state.id: agent_state + for agent_state in await self.agent_manager.get_agents_by_ids_async( + agent_ids=[request.agent_id for request in batch_requests], actor=self.actor + ) + } + agent_states = [] for batch_request in batch_requests: agent_id = batch_request.agent_id - agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor) - agent_states.append(agent_state) + agent_state = agent_mapping[agent_id] + agent_states.append(agent_state) # keeping this to maintain ordering, but may not be necessary if agent_id not in agent_step_state_mapping: agent_step_state_mapping[agent_id] = AgentStepState( @@ -162,7 +174,7 @@ class LettaAgentBatch(BaseAgent): for msg in batch_request.messages: msg.batch_item_id = llm_batch_item.id - agent_messages_mapping[agent_id] = self._prepare_in_context_messages_per_agent( + agent_messages_mapping[agent_id] = await self._prepare_in_context_messages_per_agent_async( agent_state=agent_state, input_messages=batch_request.messages ) @@ -528,12 +540,14 @@ class LettaAgentBatch(BaseAgent): valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools])) return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)] - def _prepare_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]: - current_in_context_messages, new_in_context_messages = _prepare_in_context_messages( + async def _prepare_in_context_messages_per_agent_async( + self, agent_state: AgentState, input_messages: List[MessageCreate] + ) -> List[Message]: + current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( input_messages, agent_state, self.message_manager, self.actor ) - in_context_messages = self._rebuild_memory(current_in_context_messages + new_in_context_messages, agent_state) + in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state) return in_context_messages # TODO: Make this a bullk function diff --git a/letta/functions/async_composio_toolset.py b/letta/functions/async_composio_toolset.py index f240721e..bcea60d6 100644 --- a/letta/functions/async_composio_toolset.py +++ b/letta/functions/async_composio_toolset.py @@ -12,7 +12,7 @@ from composio.exceptions import ( ) -class AsyncComposioToolSet(BaseComposioToolSet, runtime="letta"): +class AsyncComposioToolSet(BaseComposioToolSet, runtime="letta", description_char_limit=1024): """ Async version of ComposioToolSet client for interacting with Composio API Used to asynchronously hit the execute action endpoint diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index dcb4cebf..2e5abe99 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -114,155 +114,324 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): if before_obj and after_obj and before_obj.created_at < after_obj.created_at: raise ValueError("'before' reference must be later than 'after' reference") - query = select(cls) + query = cls._list_preprocess( + before_obj=before_obj, + after_obj=after_obj, + start_date=start_date, + end_date=end_date, + limit=limit, + query_text=query_text, + query_embedding=query_embedding, + ascending=ascending, + tags=tags, + match_all_tags=match_all_tags, + actor=actor, + access=access, + access_type=access_type, + join_model=join_model, + join_conditions=join_conditions, + identifier_keys=identifier_keys, + identity_id=identity_id, + **kwargs, + ) - if join_model and join_conditions: - query = query.join(join_model, and_(*join_conditions)) + # Execute the query + results = session.execute(query) - # Apply access predicate if actor is provided - if actor: - query = cls.apply_access_predicate(query, actor, access, access_type) - - # Handle tag filtering if the model has tags - if tags and hasattr(cls, "tags"): - query = select(cls) - - if match_all_tags: - # Match ALL tags - use subqueries - subquery = ( - select(cls.tags.property.mapper.class_.agent_id) - .where(cls.tags.property.mapper.class_.tag.in_(tags)) - .group_by(cls.tags.property.mapper.class_.agent_id) - .having(func.count() == len(tags)) - ) - query = query.filter(cls.id.in_(subquery)) - else: - # Match ANY tag - use join and filter - query = ( - query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id) - ) # Deduplicate results - - # select distinct primary key - query = query.distinct(cls.id).order_by(cls.id) - - if identifier_keys and hasattr(cls, "identities"): - query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) - - # given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids - if identity_id and hasattr(cls, "identities"): - query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id) - - # Apply filtering logic from kwargs - for key, value in kwargs.items(): - if "." in key: - # Handle joined table columns - table_name, column_name = key.split(".") - joined_table = locals().get(table_name) or globals().get(table_name) - column = getattr(joined_table, column_name) - else: - # Handle columns from main table - column = getattr(cls, key) - - if isinstance(value, (list, tuple, set)): - query = query.where(column.in_(value)) - else: - query = query.where(column == value) - - # Date range filtering - if start_date: - query = query.filter(cls.created_at > start_date) - if end_date: - query = query.filter(cls.created_at < end_date) - - # Handle pagination based on before/after - if before or after: - conditions = [] - - if before and after: - # Window-based query - get records between before and after - conditions = [ - or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id)), - or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id)), - ] - else: - # Pure pagination query - if before: - conditions.append( - or_( - cls.created_at < before_obj.created_at, - and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id), - ) - ) - if after: - conditions.append( - or_( - cls.created_at > after_obj.created_at, - and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id), - ) - ) - - if conditions: - query = query.where(and_(*conditions)) - - # Text search - if query_text: - if hasattr(cls, "text"): - query = query.filter(func.lower(cls.text).contains(func.lower(query_text))) - elif hasattr(cls, "name"): - # Special case for Agent model - search across name - query = query.filter(func.lower(cls.name).contains(func.lower(query_text))) - - # Embedding search (for Passages) - is_ordered = False - if query_embedding: - if not hasattr(cls, "embedding"): - raise ValueError(f"Class {cls.__name__} does not have an embedding column") - - from letta.settings import settings - - if settings.letta_pg_uri_no_default: - # PostgreSQL with pgvector - query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc()) - else: - # SQLite with custom vector type - query_embedding_binary = adapt_array(query_embedding) - query = query.order_by( - func.cosine_distance(cls.embedding, query_embedding_binary).asc(), - cls.created_at.asc() if ascending else cls.created_at.desc(), - cls.id.asc(), - ) - is_ordered = True - - # Handle soft deletes - if hasattr(cls, "is_deleted"): - query = query.where(cls.is_deleted == False) - - # Apply ordering - if not is_ordered: - if ascending: - query = query.order_by(cls.created_at.asc(), cls.id.asc()) - else: - query = query.order_by(cls.created_at.desc(), cls.id.desc()) - - # Apply limit, adjusting for both bounds if necessary - if before and after: - # When both bounds are provided, we need to fetch enough records to satisfy - # the limit while respecting both bounds. We'll fetch more and then trim. - query = query.limit(limit * 2) - else: - query = query.limit(limit) - - results = list(session.execute(query).scalars()) - - # If we have both bounds, take the middle portion - if before and after and len(results) > limit: - middle = len(results) // 2 - start = max(0, middle - limit // 2) - end = min(len(results), start + limit) - results = results[start:end] + results = list(results.scalars()) + results = cls._list_postprocess( + before=before, + after=after, + limit=limit, + results=results, + ) return results + @classmethod + @handle_db_timeout + async def list_async( + cls, + *, + db_session: "AsyncSession", + before: Optional[str] = None, + after: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + limit: Optional[int] = 50, + query_text: Optional[str] = None, + query_embedding: Optional[List[float]] = None, + ascending: bool = True, + tags: Optional[List[str]] = None, + match_all_tags: bool = False, + actor: Optional["User"] = None, + access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], + access_type: AccessType = AccessType.ORGANIZATION, + join_model: Optional[Base] = None, + join_conditions: Optional[Union[Tuple, List]] = None, + identifier_keys: Optional[List[str]] = None, + identity_id: Optional[str] = None, + **kwargs, + ) -> List["SqlalchemyBase"]: + """ + Async version of list method above. + NOTE: Keep in sync. + List records with before/after pagination, ordering by created_at. + Can use both before and after to fetch a window of records. + + Args: + db_session: SQLAlchemy session + before: ID of item to paginate before (upper bound) + after: ID of item to paginate after (lower bound) + start_date: Filter items after this date + end_date: Filter items before this date + limit: Maximum number of items to return + query_text: Text to search for + query_embedding: Vector to search for similar embeddings + ascending: Sort direction + tags: List of tags to filter by + match_all_tags: If True, return items matching all tags. If False, match any tag. + **kwargs: Additional filters to apply + """ + if start_date and end_date and start_date > end_date: + raise ValueError("start_date must be earlier than or equal to end_date") + + logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}") + + async with db_session as session: + # Get the reference objects for pagination + before_obj = None + after_obj = None + + if before: + before_obj = await session.get(cls, before) + if not before_obj: + raise NoResultFound(f"No {cls.__name__} found with id {before}") + + if after: + after_obj = await session.get(cls, after) + if not after_obj: + raise NoResultFound(f"No {cls.__name__} found with id {after}") + + # Validate that before comes after the after object if both are provided + if before_obj and after_obj and before_obj.created_at < after_obj.created_at: + raise ValueError("'before' reference must be later than 'after' reference") + + query = cls._list_preprocess( + before_obj=before_obj, + after_obj=after_obj, + start_date=start_date, + end_date=end_date, + limit=limit, + query_text=query_text, + query_embedding=query_embedding, + ascending=ascending, + tags=tags, + match_all_tags=match_all_tags, + actor=actor, + access=access, + access_type=access_type, + join_model=join_model, + join_conditions=join_conditions, + identifier_keys=identifier_keys, + identity_id=identity_id, + **kwargs, + ) + + # Execute the query + results = await session.execute(query) + + results = list(results.scalars()) + results = cls._list_postprocess( + before=before, + after=after, + limit=limit, + results=results, + ) + + return results + + @classmethod + def _list_preprocess( + cls, + *, + before_obj, + after_obj, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + limit: Optional[int] = 50, + query_text: Optional[str] = None, + query_embedding: Optional[List[float]] = None, + ascending: bool = True, + tags: Optional[List[str]] = None, + match_all_tags: bool = False, + actor: Optional["User"] = None, + access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], + access_type: AccessType = AccessType.ORGANIZATION, + join_model: Optional[Base] = None, + join_conditions: Optional[Union[Tuple, List]] = None, + identifier_keys: Optional[List[str]] = None, + identity_id: Optional[str] = None, + **kwargs, + ): + """ + Constructs the query for listing records. + """ + query = select(cls) + + if join_model and join_conditions: + query = query.join(join_model, and_(*join_conditions)) + + # Apply access predicate if actor is provided + if actor: + query = cls.apply_access_predicate(query, actor, access, access_type) + + # Handle tag filtering if the model has tags + if tags and hasattr(cls, "tags"): + query = select(cls) + + if match_all_tags: + # Match ALL tags - use subqueries + subquery = ( + select(cls.tags.property.mapper.class_.agent_id) + .where(cls.tags.property.mapper.class_.tag.in_(tags)) + .group_by(cls.tags.property.mapper.class_.agent_id) + .having(func.count() == len(tags)) + ) + query = query.filter(cls.id.in_(subquery)) + else: + # Match ANY tag - use join and filter + query = ( + query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id) + ) # Deduplicate results + + # select distinct primary key + query = query.distinct(cls.id).order_by(cls.id) + + if identifier_keys and hasattr(cls, "identities"): + query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) + + # given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids + if identity_id and hasattr(cls, "identities"): + query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id) + + # Apply filtering logic from kwargs + for key, value in kwargs.items(): + if "." in key: + # Handle joined table columns + table_name, column_name = key.split(".") + joined_table = locals().get(table_name) or globals().get(table_name) + column = getattr(joined_table, column_name) + else: + # Handle columns from main table + column = getattr(cls, key) + + if isinstance(value, (list, tuple, set)): + query = query.where(column.in_(value)) + else: + query = query.where(column == value) + + # Date range filtering + if start_date: + query = query.filter(cls.created_at > start_date) + if end_date: + query = query.filter(cls.created_at < end_date) + + # Handle pagination based on before/after + if before_obj or after_obj: + conditions = [] + + if before_obj and after_obj: + # Window-based query - get records between before and after + conditions = [ + or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id)), + or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id)), + ] + else: + # Pure pagination query + if before_obj: + conditions.append( + or_( + cls.created_at < before_obj.created_at, + and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id), + ) + ) + if after_obj: + conditions.append( + or_( + cls.created_at > after_obj.created_at, + and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id), + ) + ) + + if conditions: + query = query.where(and_(*conditions)) + + # Text search + if query_text: + if hasattr(cls, "text"): + query = query.filter(func.lower(cls.text).contains(func.lower(query_text))) + elif hasattr(cls, "name"): + # Special case for Agent model - search across name + query = query.filter(func.lower(cls.name).contains(func.lower(query_text))) + + # Embedding search (for Passages) + is_ordered = False + if query_embedding: + if not hasattr(cls, "embedding"): + raise ValueError(f"Class {cls.__name__} does not have an embedding column") + + from letta.settings import settings + + if settings.letta_pg_uri_no_default: + # PostgreSQL with pgvector + query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc()) + else: + # SQLite with custom vector type + query_embedding_binary = adapt_array(query_embedding) + query = query.order_by( + func.cosine_distance(cls.embedding, query_embedding_binary).asc(), + cls.created_at.asc() if ascending else cls.created_at.desc(), + cls.id.asc(), + ) + is_ordered = True + + # Handle soft deletes + if hasattr(cls, "is_deleted"): + query = query.where(cls.is_deleted == False) + + # Apply ordering + if not is_ordered: + if ascending: + query = query.order_by(cls.created_at.asc(), cls.id.asc()) + else: + query = query.order_by(cls.created_at.desc(), cls.id.desc()) + + # Apply limit, adjusting for both bounds if necessary + if before_obj and after_obj: + # When both bounds are provided, we need to fetch enough records to satisfy + # the limit while respecting both bounds. We'll fetch more and then trim. + query = query.limit(limit * 2) + else: + query = query.limit(limit) + return query + + @classmethod + def _list_postprocess( + cls, + before: str | None, + after: str | None, + limit: int | None, + results: list, + ): + # If we have both bounds, take the middle portion + if before and after and len(results) > limit: + middle = len(results) // 2 + start = max(0, middle - limit // 2) + end = min(len(results), start + limit) + results = results[start:end] + return results + @classmethod @handle_db_timeout def read( @@ -305,7 +474,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): @handle_db_timeout async def read_async( cls, - db_session: "Session", + db_session: "AsyncSession", identifier: Optional[str] = None, actor: Optional["User"] = None, access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 0ff701f4..9747ea60 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -577,6 +577,13 @@ class AgentManager: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) return agent.to_pydantic() + @enforce_types + async def get_agents_by_ids_async(self, agent_ids: list[str], actor: PydanticUser) -> list[PydanticAgentState]: + """Fetch a list of agents by their IDs.""" + async with db_registry.async_session() as session: + agents = await AgentModel.read_multiple_async(db_session=session, identifiers=agent_ids, actor=actor) + return [agent.to_pydantic() for agent in agents] + @enforce_types def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState: """Fetch an agent by its ID.""" diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index c6ca4579..0dda1cfe 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -36,15 +36,29 @@ class MessageManager: """Fetch messages by ID and return them in the requested order.""" with db_registry.session() as session: results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids)) + return self._get_messages_by_id_postprocess(results, message_ids) - if len(results) != len(message_ids): - logger.warning( - f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}" - ) + @enforce_types + async def get_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]: + """Fetch messages by ID and return them in the requested order. Async version of above function.""" + async with db_registry.async_session() as session: + results = await MessageModel.list_async( + db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids) + ) + return self._get_messages_by_id_postprocess(results, message_ids) - # Sort results directly based on message_ids - result_dict = {msg.id: msg.to_pydantic() for msg in results} - return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids])) + def _get_messages_by_id_postprocess( + self, + results: List[MessageModel], + message_ids: List[str], + ) -> List[PydanticMessage]: + if len(results) != len(message_ids): + logger.warning( + f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}" + ) + # Sort results directly based on message_ids + result_dict = {msg.id: msg.to_pydantic() for msg in results} + return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids])) @enforce_types def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage: diff --git a/letta/types/__init__.py b/letta/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index ee668fd0..b043b216 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -1,3 +1,4 @@ +import asyncio import os import threading from datetime import datetime, timezone @@ -164,6 +165,14 @@ def step_state_map(agents): return {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents} +@pytest.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + def create_batch_response(batch_id: str, processing_status: str = "in_progress") -> BetaMessageBatch: """Create a dummy BetaMessageBatch with the specified ID and status.""" now = datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc) @@ -452,7 +461,7 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv @pytest.mark.asyncio async def test_partial_error_from_anthropic_batch( - disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job + disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop ): anthropic_batch_id = "msgbatch_test_12345" dummy_batch_response = create_batch_response( @@ -594,7 +603,7 @@ async def test_partial_error_from_anthropic_batch( letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user ) assert len(messages) == (len(agents) - 1) * 4 + 1 - assert_descending_order(messages) + _assert_descending_order(messages) # Check that each agent is represented for agent in agents_continue: agent_messages = [m for m in messages if m.agent_id == agent.id] @@ -612,7 +621,7 @@ async def test_partial_error_from_anthropic_batch( @pytest.mark.asyncio async def test_resume_step_some_stop( - disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job + disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop ): anthropic_batch_id = "msgbatch_test_12345" dummy_batch_response = create_batch_response( @@ -743,7 +752,7 @@ async def test_resume_step_some_stop( letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user ) assert len(messages) == len(agents) * 3 + 1 - assert_descending_order(messages) + _assert_descending_order(messages) # Check that each agent is represented for agent in agents_continue: agent_messages = [m for m in messages if m.agent_id == agent.id] @@ -761,23 +770,21 @@ async def test_resume_step_some_stop( assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call" -def assert_descending_order(messages): - """Assert messages are in descending order by created_at timestamps.""" +def _assert_descending_order(messages): + """Assert messages are in monotonically decreasing by created_at timestamps.""" if len(messages) <= 1: return True - for i in range(1, len(messages)): - assert messages[i].created_at <= messages[i - 1].created_at, ( - f"Order violation: {messages[i - 1].id} ({messages[i - 1].created_at}) " - f"followed by {messages[i].id} ({messages[i].created_at})" - ) - + for prev, next in zip(messages[:-1], messages[1:]): + assert ( + prev.created_at >= next.created_at + ), f"Order violation: {prev.id} ({prev.created_at}) followed by {next.id} ({next.created_at})" return True @pytest.mark.asyncio async def test_resume_step_after_request_all_continue( - disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job + disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop ): anthropic_batch_id = "msgbatch_test_12345" dummy_batch_response = create_batch_response( @@ -902,7 +909,7 @@ async def test_resume_step_after_request_all_continue( letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user ) assert len(messages) == len(agents) * 4 - assert_descending_order(messages) + _assert_descending_order(messages) # Check that each agent is represented for agent in agents: agent_messages = [m for m in messages if m.agent_id == agent.id] @@ -915,7 +922,7 @@ async def test_resume_step_after_request_all_continue( @pytest.mark.asyncio async def test_step_until_request_prepares_and_submits_batch_correctly( - disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job + disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job, event_loop ): """ Test that step_until_request correctly: