fix: Refactor listing messages to be much more performant (#963)
This commit is contained in:
@@ -875,14 +875,12 @@ class SyncServer(Server):
|
||||
# TODO: Thread actor directly through this function, since the top-level caller has most likely already retrieved the user
|
||||
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
start_date = self.message_manager.get_message_by_id(after, actor=actor).created_at if after else None
|
||||
end_date = self.message_manager.get_message_by_id(before, actor=actor).created_at if before else None
|
||||
|
||||
records = self.message_manager.list_messages_for_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
ascending=not reverse,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.schemas.enums import MessageRole
|
||||
@@ -127,44 +129,21 @@ class MessageManager:
|
||||
def list_user_messages_for_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
before: Optional[str] = None,
|
||||
actor: PydanticUser,
|
||||
after: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
filters: Optional[Dict] = None,
|
||||
before: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
ascending: bool = True,
|
||||
) -> List[PydanticMessage]:
|
||||
"""List user messages with flexible filtering and pagination options.
|
||||
|
||||
Args:
|
||||
before: Cursor-based pagination - return records before this ID (exclusive)
|
||||
after: Cursor-based pagination - return records after this ID (exclusive)
|
||||
start_date: Filter records created after this date
|
||||
end_date: Filter records created before this date
|
||||
limit: Maximum number of records to return
|
||||
filters: Additional filters to apply
|
||||
query_text: Optional text to search for in message content
|
||||
|
||||
Returns:
|
||||
List[PydanticMessage] - List of messages matching the criteria
|
||||
"""
|
||||
message_filters = {"role": "user"}
|
||||
if filters:
|
||||
message_filters.update(filters)
|
||||
|
||||
return self.list_messages_for_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
before=before,
|
||||
after=after,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
filters=message_filters,
|
||||
before=before,
|
||||
query_text=query_text,
|
||||
role=MessageRole.user,
|
||||
limit=limit,
|
||||
ascending=ascending,
|
||||
)
|
||||
|
||||
@@ -172,48 +151,94 @@ class MessageManager:
|
||||
def list_messages_for_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
before: Optional[str] = None,
|
||||
actor: PydanticUser,
|
||||
after: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
filters: Optional[Dict] = None,
|
||||
before: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
role: Optional[MessageRole] = None, # New parameter for filtering by role
|
||||
limit: Optional[int] = 50,
|
||||
ascending: bool = True,
|
||||
) -> List[PydanticMessage]:
|
||||
"""List messages with flexible filtering and pagination options.
|
||||
"""
|
||||
Most performant query to list messages for an agent by directly querying the Message table.
|
||||
|
||||
This function filters by the agent_id (leveraging the index on messages.agent_id)
|
||||
and applies efficient pagination using (created_at, id) as the cursor.
|
||||
If query_text is provided, it will filter messages whose text content partially matches the query.
|
||||
If role is provided, it will filter messages by the specified role.
|
||||
|
||||
Args:
|
||||
before: Cursor-based pagination - return records before this ID (exclusive)
|
||||
after: Cursor-based pagination - return records after this ID (exclusive)
|
||||
start_date: Filter records created after this date
|
||||
end_date: Filter records created before this date
|
||||
limit: Maximum number of records to return
|
||||
filters: Additional filters to apply
|
||||
query_text: Optional text to search for in message content
|
||||
agent_id: The ID of the agent whose messages are queried.
|
||||
actor: The user performing the action (used for permission checks).
|
||||
after: A message ID; if provided, only messages *after* this message (per sort order) are returned.
|
||||
before: A message ID; if provided, only messages *before* this message are returned.
|
||||
query_text: Optional string to partially match the message text content.
|
||||
role: Optional MessageRole to filter messages by role.
|
||||
limit: Maximum number of messages to return.
|
||||
ascending: If True, sort by (created_at, id) ascending; if False, sort descending.
|
||||
|
||||
Returns:
|
||||
List[PydanticMessage] - List of messages matching the criteria
|
||||
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the provided after/before message IDs do not exist.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Start with base filters
|
||||
message_filters = {"agent_id": agent_id}
|
||||
if actor:
|
||||
message_filters.update({"organization_id": actor.organization_id})
|
||||
if filters:
|
||||
message_filters.update(filters)
|
||||
# Permission check: raise if the agent doesn't exist or actor is not allowed.
|
||||
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
results = MessageModel.list(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
query_text=query_text,
|
||||
ascending=ascending,
|
||||
**message_filters,
|
||||
)
|
||||
# Build a query that directly filters the Message table by agent_id.
|
||||
query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id)
|
||||
|
||||
# If query_text is provided, filter messages by partial match on text.
|
||||
if query_text:
|
||||
query = query.filter(MessageModel.text.ilike(f"%{query_text}%"))
|
||||
|
||||
# If role is provided, filter messages by role.
|
||||
if role:
|
||||
query = query.filter(MessageModel.role == role.value) # Enum.value ensures comparison is against the string value
|
||||
|
||||
# Apply 'after' pagination if specified.
|
||||
if after:
|
||||
after_ref = session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == after).limit(1).one_or_none()
|
||||
if not after_ref:
|
||||
raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
|
||||
query = query.filter(
|
||||
or_(
|
||||
MessageModel.created_at > after_ref.created_at,
|
||||
and_(
|
||||
MessageModel.created_at == after_ref.created_at,
|
||||
MessageModel.id > after_ref.id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Apply 'before' pagination if specified.
|
||||
if before:
|
||||
before_ref = (
|
||||
session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == before).limit(1).one_or_none()
|
||||
)
|
||||
if not before_ref:
|
||||
raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
|
||||
query = query.filter(
|
||||
or_(
|
||||
MessageModel.created_at < before_ref.created_at,
|
||||
and_(
|
||||
MessageModel.created_at == before_ref.created_at,
|
||||
MessageModel.id < before_ref.id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Apply ordering based on the ascending flag.
|
||||
if ascending:
|
||||
query = query.order_by(MessageModel.created_at.asc(), MessageModel.id.asc())
|
||||
else:
|
||||
query = query.order_by(MessageModel.created_at.desc(), MessageModel.id.desc())
|
||||
|
||||
# Limit the number of results.
|
||||
query = query.limit(limit)
|
||||
|
||||
# Execute and convert each Message to its Pydantic representation.
|
||||
results = query.all()
|
||||
return [msg.to_pydantic() for msg in results]
|
||||
|
||||
@@ -186,7 +186,7 @@ def test_check_tool_rules_with_different_models(mock_e2b_api_key_none):
|
||||
client = create_client()
|
||||
|
||||
config_files = [
|
||||
"tests/configs/llm_model_configs/claude-3-sonnet-20240229.json",
|
||||
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
|
||||
"tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json",
|
||||
"tests/configs/llm_model_configs/openai-gpt-4o.json",
|
||||
]
|
||||
@@ -247,7 +247,7 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
|
||||
tools = [t1, t2]
|
||||
|
||||
# Make agent state
|
||||
anthropic_config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
|
||||
anthropic_config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
for i in range(3):
|
||||
agent_uuid = str(uuid.uuid4())
|
||||
agent_state = setup_agent(
|
||||
@@ -299,7 +299,7 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none):
|
||||
tools = [send_message, archival_memory_search, archival_memory_insert]
|
||||
|
||||
config_files = [
|
||||
"tests/configs/llm_model_configs/claude-3-sonnet-20240229.json",
|
||||
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
|
||||
"tests/configs/llm_model_configs/openai-gpt-4o.json",
|
||||
]
|
||||
|
||||
@@ -383,7 +383,7 @@ def test_agent_conditional_tool_easy(mock_e2b_api_key_none):
|
||||
]
|
||||
tools = [flip_coin_tool, reveal_secret]
|
||||
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")
|
||||
|
||||
@@ -455,7 +455,7 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none):
|
||||
|
||||
# Setup agent with all tools
|
||||
tools = [play_game_tool, flip_coin_tool, reveal_secret]
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# Ask agent to try to get all secret words
|
||||
@@ -681,7 +681,7 @@ def test_init_tool_rule_always_fails_one_tool():
|
||||
)
|
||||
|
||||
# Set up agent with the tool rule
|
||||
claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
|
||||
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=False)
|
||||
|
||||
# Start conversation
|
||||
@@ -710,7 +710,7 @@ def test_init_tool_rule_always_fails_multiple_tools():
|
||||
)
|
||||
|
||||
# Set up agent with the tool rule
|
||||
claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
|
||||
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=True)
|
||||
|
||||
# Start conversation
|
||||
|
||||
@@ -1971,17 +1971,6 @@ def test_message_listing_text_search(server: SyncServer, hello_world_message_fix
|
||||
assert len(search_results) == 0
|
||||
|
||||
|
||||
def test_message_listing_date_range_filtering(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
|
||||
"""Test filtering messages by date range"""
|
||||
create_test_messages(server, hello_world_message_fixture, default_user)
|
||||
now = datetime.utcnow()
|
||||
|
||||
date_results = server.message_manager.list_user_messages_for_agent(
|
||||
agent_id=sarah_agent.id, actor=default_user, start_date=now - timedelta(minutes=1), end_date=now + timedelta(minutes=1), limit=10
|
||||
)
|
||||
assert len(date_results) > 0
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -164,7 +164,7 @@ def wait_for_incoming_message(
|
||||
deadline = time.time() + max_wait_seconds
|
||||
|
||||
while time.time() < deadline:
|
||||
messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id)
|
||||
messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id, actor=client.user)
|
||||
# Check for the system message containing `substring`
|
||||
if any(message.role == MessageRole.system and substring in (message.text or "") for message in messages):
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user