feat(asyncify): migrate list messages (#2272)
This commit is contained in:
@@ -457,7 +457,7 @@ class VoiceAgent(BaseAgent):
|
||||
keyword_results = {}
|
||||
if convo_keyword_queries:
|
||||
for keyword in convo_keyword_queries:
|
||||
messages = self.message_manager.list_messages_for_agent(
|
||||
messages = await self.message_manager.list_messages_for_agent_async(
|
||||
agent_id=self.agent_id,
|
||||
actor=self.actor,
|
||||
query_text=keyword,
|
||||
|
||||
@@ -190,7 +190,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
prior_messages = []
|
||||
if self.group.sleeptime_agent_frequency:
|
||||
try:
|
||||
prior_messages = self.message_manager.list_messages_for_agent(
|
||||
prior_messages = await self.message_manager.list_messages_for_agent_async(
|
||||
agent_id=foreground_agent_id,
|
||||
actor=self.actor,
|
||||
after=last_processed_message_id,
|
||||
|
||||
@@ -567,7 +567,7 @@ AgentMessagesResponse = Annotated[
|
||||
|
||||
|
||||
@router.get("/{agent_id}/messages", response_model=AgentMessagesResponse, operation_id="list_messages")
|
||||
def list_messages(
|
||||
async def list_messages(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
|
||||
@@ -582,10 +582,9 @@ def list_messages(
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
return server.get_agent_recall(
|
||||
user_id=actor.id,
|
||||
return await server.get_agent_recall_async(
|
||||
agent_id=agent_id,
|
||||
after=after,
|
||||
before=before,
|
||||
@@ -596,6 +595,7 @@ def list_messages(
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1096,6 +1096,44 @@ class SyncServer(Server):
|
||||
|
||||
return records
|
||||
|
||||
async def get_agent_recall_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: User,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
group_id: Optional[str] = None,
|
||||
reverse: Optional[bool] = False,
|
||||
return_message_object: bool = True,
|
||||
use_assistant_message: bool = True,
|
||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
records = await self.message_manager.list_messages_for_agent_async(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
ascending=not reverse,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
if not return_message_object:
|
||||
records = Message.to_letta_messages_from_list(
|
||||
messages=records,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
reverse=reverse,
|
||||
)
|
||||
|
||||
if reverse:
|
||||
records = records[::-1]
|
||||
|
||||
return records
|
||||
|
||||
def get_server_config(self, include_defaults: bool = False) -> dict:
|
||||
"""Return the base config"""
|
||||
|
||||
|
||||
@@ -429,6 +429,107 @@ class MessageManager:
|
||||
results = query.all()
|
||||
return [msg.to_pydantic() for msg in results]
|
||||
|
||||
@enforce_types
|
||||
async def list_messages_for_agent_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: PydanticUser,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
roles: Optional[Sequence[MessageRole]] = None,
|
||||
limit: Optional[int] = 50,
|
||||
ascending: bool = True,
|
||||
group_id: Optional[str] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Most performant query to list messages for an agent by directly querying the Message table.
|
||||
|
||||
This function filters by the agent_id (leveraging the index on messages.agent_id)
|
||||
and applies pagination using sequence_id as the cursor.
|
||||
If query_text is provided, it will filter messages whose text content partially matches the query.
|
||||
If role is provided, it will filter messages by the specified role.
|
||||
|
||||
Args:
|
||||
agent_id: The ID of the agent whose messages are queried.
|
||||
actor: The user performing the action (used for permission checks).
|
||||
after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned.
|
||||
before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned.
|
||||
query_text: Optional string to partially match the message text content.
|
||||
roles: Optional MessageRole to filter messages by role.
|
||||
limit: Maximum number of messages to return.
|
||||
ascending: If True, sort by sequence_id ascending; if False, sort descending.
|
||||
group_id: Optional group ID to filter messages by group_id.
|
||||
|
||||
Returns:
|
||||
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the provided after/before message IDs do not exist.
|
||||
"""
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Permission check: raise if the agent doesn't exist or actor is not allowed.
|
||||
await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Build a query that directly filters the Message table by agent_id.
|
||||
query = select(MessageModel).where(MessageModel.agent_id == agent_id)
|
||||
|
||||
# If group_id is provided, filter messages by group_id.
|
||||
if group_id:
|
||||
query = query.where(MessageModel.group_id == group_id)
|
||||
|
||||
# If query_text is provided, filter messages using subquery + json_array_elements.
|
||||
if query_text:
|
||||
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
|
||||
query = query.where(
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(content_element)
|
||||
.where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
|
||||
.params(query_text=f"%{query_text}%")
|
||||
)
|
||||
)
|
||||
|
||||
# If role(s) are provided, filter messages by those roles.
|
||||
if roles:
|
||||
role_values = [r.value for r in roles]
|
||||
query = query.where(MessageModel.role.in_(role_values))
|
||||
|
||||
# Apply 'after' pagination if specified.
|
||||
if after:
|
||||
after_query = select(MessageModel.sequence_id).where(MessageModel.id == after)
|
||||
after_result = await session.execute(after_query)
|
||||
after_ref = after_result.one_or_none()
|
||||
if not after_ref:
|
||||
raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
|
||||
# Filter out any messages with a sequence_id <= after_ref.sequence_id
|
||||
query = query.where(MessageModel.sequence_id > after_ref.sequence_id)
|
||||
|
||||
# Apply 'before' pagination if specified.
|
||||
if before:
|
||||
before_query = select(MessageModel.sequence_id).where(MessageModel.id == before)
|
||||
before_result = await session.execute(before_query)
|
||||
before_ref = before_result.one_or_none()
|
||||
if not before_ref:
|
||||
raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
|
||||
# Filter out any messages with a sequence_id >= before_ref.sequence_id
|
||||
query = query.where(MessageModel.sequence_id < before_ref.sequence_id)
|
||||
|
||||
# Apply ordering based on the ascending flag.
|
||||
if ascending:
|
||||
query = query.order_by(MessageModel.sequence_id.asc())
|
||||
else:
|
||||
query = query.order_by(MessageModel.sequence_id.desc())
|
||||
|
||||
# Limit the number of results.
|
||||
query = query.limit(limit)
|
||||
|
||||
# Execute and convert each Message to its Pydantic representation.
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
return [msg.to_pydantic() for msg in results]
|
||||
|
||||
@enforce_types
|
||||
def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int:
|
||||
"""
|
||||
|
||||
@@ -936,7 +936,7 @@ def test_composio_client_simple(server):
|
||||
assert len(actions) > 0
|
||||
|
||||
|
||||
def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
|
||||
async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
actor = user
|
||||
# create agent
|
||||
|
||||
Reference in New Issue
Block a user