fix: Refactor listing messages to be much more performant (#963)

This commit is contained in:
Matthew Zhou
2025-02-12 10:32:38 -08:00
committed by GitHub
parent b8bd29e5f0
commit 9fd8d2f56b
5 changed files with 98 additions and 86 deletions

View File

@@ -875,14 +875,12 @@ class SyncServer(Server):
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
start_date = self.message_manager.get_message_by_id(after, actor=actor).created_at if after else None
end_date = self.message_manager.get_message_by_id(before, actor=actor).created_at if before else None
records = self.message_manager.list_messages_for_agent(
agent_id=agent_id,
actor=actor,
start_date=start_date,
end_date=end_date,
after=after,
before=before,
limit=limit,
ascending=not reverse,
)

View File

@@ -1,6 +1,8 @@
from datetime import datetime
from typing import Dict, List, Optional
from typing import List, Optional
from sqlalchemy import and_, or_
from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.schemas.enums import MessageRole
@@ -127,44 +129,21 @@ class MessageManager:
def list_user_messages_for_agent(
self,
agent_id: str,
actor: Optional[PydanticUser] = None,
before: Optional[str] = None,
actor: PydanticUser,
after: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
filters: Optional[Dict] = None,
before: Optional[str] = None,
query_text: Optional[str] = None,
limit: Optional[int] = 50,
ascending: bool = True,
) -> List[PydanticMessage]:
"""List user messages with flexible filtering and pagination options.
Args:
before: Cursor-based pagination - return records before this ID (exclusive)
after: Cursor-based pagination - return records after this ID (exclusive)
start_date: Filter records created after this date
end_date: Filter records created before this date
limit: Maximum number of records to return
filters: Additional filters to apply
query_text: Optional text to search for in message content
Returns:
List[PydanticMessage] - List of messages matching the criteria
"""
message_filters = {"role": "user"}
if filters:
message_filters.update(filters)
return self.list_messages_for_agent(
agent_id=agent_id,
actor=actor,
before=before,
after=after,
start_date=start_date,
end_date=end_date,
limit=limit,
filters=message_filters,
before=before,
query_text=query_text,
role=MessageRole.user,
limit=limit,
ascending=ascending,
)
@@ -172,48 +151,94 @@ class MessageManager:
def list_messages_for_agent(
self,
agent_id: str,
actor: Optional[PydanticUser] = None,
before: Optional[str] = None,
actor: PydanticUser,
after: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
filters: Optional[Dict] = None,
before: Optional[str] = None,
query_text: Optional[str] = None,
role: Optional[MessageRole] = None, # New parameter for filtering by role
limit: Optional[int] = 50,
ascending: bool = True,
) -> List[PydanticMessage]:
"""List messages with flexible filtering and pagination options.
"""
Most performant query to list messages for an agent by directly querying the Message table.
This function filters by the agent_id (leveraging the index on messages.agent_id)
and applies efficient pagination using (created_at, id) as the cursor.
If query_text is provided, it will filter messages whose text content partially matches the query.
If role is provided, it will filter messages by the specified role.
Args:
before: Cursor-based pagination - return records before this ID (exclusive)
after: Cursor-based pagination - return records after this ID (exclusive)
start_date: Filter records created after this date
end_date: Filter records created before this date
limit: Maximum number of records to return
filters: Additional filters to apply
query_text: Optional text to search for in message content
agent_id: The ID of the agent whose messages are queried.
actor: The user performing the action (used for permission checks).
after: A message ID; if provided, only messages that come *after* this message in (created_at, id) order are returned, regardless of the `ascending` flag.
before: A message ID; if provided, only messages *before* this message are returned.
query_text: Optional string to partially match the message text content.
role: Optional MessageRole to filter messages by role.
limit: Maximum number of messages to return.
ascending: If True, sort by (created_at, id) ascending; if False, sort descending.
Returns:
List[PydanticMessage] - List of messages matching the criteria
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
Raises:
NoResultFound: If the provided after/before message IDs do not exist.
"""
with self.session_maker() as session:
# Start with base filters
message_filters = {"agent_id": agent_id}
if actor:
message_filters.update({"organization_id": actor.organization_id})
if filters:
message_filters.update(filters)
# Permission check: raise if the agent doesn't exist or actor is not allowed.
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
results = MessageModel.list(
db_session=session,
before=before,
after=after,
start_date=start_date,
end_date=end_date,
limit=limit,
query_text=query_text,
ascending=ascending,
**message_filters,
)
# Build a query that directly filters the Message table by agent_id.
query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id)
# If query_text is provided, filter messages by partial match on text.
if query_text:
query = query.filter(MessageModel.text.ilike(f"%{query_text}%"))
# If role is provided, filter messages by role.
if role:
query = query.filter(MessageModel.role == role.value) # Enum.value ensures comparison is against the string value
# Apply 'after' pagination if specified.
if after:
after_ref = session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == after).limit(1).one_or_none()
if not after_ref:
raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
query = query.filter(
or_(
MessageModel.created_at > after_ref.created_at,
and_(
MessageModel.created_at == after_ref.created_at,
MessageModel.id > after_ref.id,
),
)
)
# Apply 'before' pagination if specified.
if before:
before_ref = (
session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == before).limit(1).one_or_none()
)
if not before_ref:
raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
query = query.filter(
or_(
MessageModel.created_at < before_ref.created_at,
and_(
MessageModel.created_at == before_ref.created_at,
MessageModel.id < before_ref.id,
),
)
)
# Apply ordering based on the ascending flag.
if ascending:
query = query.order_by(MessageModel.created_at.asc(), MessageModel.id.asc())
else:
query = query.order_by(MessageModel.created_at.desc(), MessageModel.id.desc())
# Limit the number of results.
query = query.limit(limit)
# Execute and convert each Message to its Pydantic representation.
results = query.all()
return [msg.to_pydantic() for msg in results]

View File

@@ -186,7 +186,7 @@ def test_check_tool_rules_with_different_models(mock_e2b_api_key_none):
client = create_client()
config_files = [
"tests/configs/llm_model_configs/claude-3-sonnet-20240229.json",
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
"tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json",
"tests/configs/llm_model_configs/openai-gpt-4o.json",
]
@@ -247,7 +247,7 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
tools = [t1, t2]
# Make agent state
anthropic_config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
anthropic_config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
for i in range(3):
agent_uuid = str(uuid.uuid4())
agent_state = setup_agent(
@@ -299,7 +299,7 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none):
tools = [send_message, archival_memory_search, archival_memory_insert]
config_files = [
"tests/configs/llm_model_configs/claude-3-sonnet-20240229.json",
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
"tests/configs/llm_model_configs/openai-gpt-4o.json",
]
@@ -383,7 +383,7 @@ def test_agent_conditional_tool_easy(mock_e2b_api_key_none):
]
tools = [flip_coin_tool, reveal_secret]
config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")
@@ -455,7 +455,7 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none):
# Setup agent with all tools
tools = [play_game_tool, flip_coin_tool, reveal_secret]
config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# Ask agent to try to get all secret words
@@ -681,7 +681,7 @@ def test_init_tool_rule_always_fails_one_tool():
)
# Set up agent with the tool rule
claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=False)
# Start conversation
@@ -710,7 +710,7 @@ def test_init_tool_rule_always_fails_multiple_tools():
)
# Set up agent with the tool rule
claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=True)
# Start conversation

View File

@@ -1971,17 +1971,6 @@ def test_message_listing_text_search(server: SyncServer, hello_world_message_fix
assert len(search_results) == 0
def test_message_listing_date_range_filtering(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
"""Test filtering messages by date range"""
create_test_messages(server, hello_world_message_fixture, default_user)
now = datetime.utcnow()
date_results = server.message_manager.list_user_messages_for_agent(
agent_id=sarah_agent.id, actor=default_user, start_date=now - timedelta(minutes=1), end_date=now + timedelta(minutes=1), limit=10
)
assert len(date_results) > 0
# ======================================================================================================================
# Block Manager Tests
# ======================================================================================================================

View File

@@ -164,7 +164,7 @@ def wait_for_incoming_message(
deadline = time.time() + max_wait_seconds
while time.time() < deadline:
messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id)
messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id, actor=client.user)
# Check for the system message containing `substring`
if any(message.role == MessageRole.system and substring in (message.text or "") for message in messages):
return True