feat(asyncify): migrate list messages (#2272)

This commit is contained in:
cthomas
2025-05-20 16:52:11 -07:00
committed by GitHub
parent 6432dea8e1
commit 44afd54c5c
6 changed files with 146 additions and 7 deletions

View File

@@ -457,7 +457,7 @@ class VoiceAgent(BaseAgent):
keyword_results = {}
if convo_keyword_queries:
for keyword in convo_keyword_queries:
messages = self.message_manager.list_messages_for_agent(
messages = await self.message_manager.list_messages_for_agent_async(
agent_id=self.agent_id,
actor=self.actor,
query_text=keyword,

View File

@@ -190,7 +190,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
prior_messages = []
if self.group.sleeptime_agent_frequency:
try:
prior_messages = self.message_manager.list_messages_for_agent(
prior_messages = await self.message_manager.list_messages_for_agent_async(
agent_id=foreground_agent_id,
actor=self.actor,
after=last_processed_message_id,

View File

@@ -567,7 +567,7 @@ AgentMessagesResponse = Annotated[
@router.get("/{agent_id}/messages", response_model=AgentMessagesResponse, operation_id="list_messages")
def list_messages(
async def list_messages(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
@@ -582,10 +582,9 @@ def list_messages(
"""
Retrieve message history for an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return server.get_agent_recall(
user_id=actor.id,
return await server.get_agent_recall_async(
agent_id=agent_id,
after=after,
before=before,
@@ -596,6 +595,7 @@ def list_messages(
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
actor=actor,
)

View File

@@ -1096,6 +1096,44 @@ class SyncServer(Server):
return records
async def get_agent_recall_async(
self,
agent_id: str,
actor: User,
after: Optional[str] = None,
before: Optional[str] = None,
limit: Optional[int] = 100,
group_id: Optional[str] = None,
reverse: Optional[bool] = False,
return_message_object: bool = True,
use_assistant_message: bool = True,
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
) -> Union[List[Message], List[LettaMessage]]:
records = await self.message_manager.list_messages_for_agent_async(
agent_id=agent_id,
actor=actor,
after=after,
before=before,
limit=limit,
ascending=not reverse,
group_id=group_id,
)
if not return_message_object:
records = Message.to_letta_messages_from_list(
messages=records,
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
reverse=reverse,
)
if reverse:
records = records[::-1]
return records
def get_server_config(self, include_defaults: bool = False) -> dict:
"""Return the base config"""

View File

@@ -429,6 +429,107 @@ class MessageManager:
results = query.all()
return [msg.to_pydantic() for msg in results]
@enforce_types
async def list_messages_for_agent_async(
self,
agent_id: str,
actor: PydanticUser,
after: Optional[str] = None,
before: Optional[str] = None,
query_text: Optional[str] = None,
roles: Optional[Sequence[MessageRole]] = None,
limit: Optional[int] = 50,
ascending: bool = True,
group_id: Optional[str] = None,
) -> List[PydanticMessage]:
"""
Most performant query to list messages for an agent by directly querying the Message table.
This function filters by the agent_id (leveraging the index on messages.agent_id)
and applies pagination using sequence_id as the cursor.
If query_text is provided, it will filter messages whose text content partially matches the query.
If role is provided, it will filter messages by the specified role.
Args:
agent_id: The ID of the agent whose messages are queried.
actor: The user performing the action (used for permission checks).
after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned.
before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned.
query_text: Optional string to partially match the message text content.
roles: Optional MessageRole to filter messages by role.
limit: Maximum number of messages to return.
ascending: If True, sort by sequence_id ascending; if False, sort descending.
group_id: Optional group ID to filter messages by group_id.
Returns:
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
Raises:
NoResultFound: If the provided after/before message IDs do not exist.
"""
async with db_registry.async_session() as session:
# Permission check: raise if the agent doesn't exist or actor is not allowed.
await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Build a query that directly filters the Message table by agent_id.
query = select(MessageModel).where(MessageModel.agent_id == agent_id)
# If group_id is provided, filter messages by group_id.
if group_id:
query = query.where(MessageModel.group_id == group_id)
# If query_text is provided, filter messages using subquery + json_array_elements.
if query_text:
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
query = query.where(
exists(
select(1)
.select_from(content_element)
.where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
.params(query_text=f"%{query_text}%")
)
)
# If role(s) are provided, filter messages by those roles.
if roles:
role_values = [r.value for r in roles]
query = query.where(MessageModel.role.in_(role_values))
# Apply 'after' pagination if specified.
if after:
after_query = select(MessageModel.sequence_id).where(MessageModel.id == after)
after_result = await session.execute(after_query)
after_ref = after_result.one_or_none()
if not after_ref:
raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
# Filter out any messages with a sequence_id <= after_ref.sequence_id
query = query.where(MessageModel.sequence_id > after_ref.sequence_id)
# Apply 'before' pagination if specified.
if before:
before_query = select(MessageModel.sequence_id).where(MessageModel.id == before)
before_result = await session.execute(before_query)
before_ref = before_result.one_or_none()
if not before_ref:
raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
# Filter out any messages with a sequence_id >= before_ref.sequence_id
query = query.where(MessageModel.sequence_id < before_ref.sequence_id)
# Apply ordering based on the ascending flag.
if ascending:
query = query.order_by(MessageModel.sequence_id.asc())
else:
query = query.order_by(MessageModel.sequence_id.desc())
# Limit the number of results.
query = query.limit(limit)
# Execute and convert each Message to its Pydantic representation.
result = await session.execute(query)
results = result.scalars().all()
return [msg.to_pydantic() for msg in results]
@enforce_types
def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int:
"""

View File

@@ -936,7 +936,7 @@ def test_composio_client_simple(server):
assert len(actions) > 0
def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = user
# create agent