From dbbcbf1e2d55e580c31fed22e7f9dc14568820a8 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 12 Feb 2025 10:32:38 -0800 Subject: [PATCH] fix: Refactor listing messages to be much more performant (#963) --- letta/server/server.py | 6 +- letta/services/message_manager.py | 151 ++++++++++++--------- tests/integration_test_agent_tool_graph.py | 14 +- tests/test_managers.py | 11 -- tests/utils.py | 2 +- 5 files changed, 98 insertions(+), 86 deletions(-) diff --git a/letta/server/server.py b/letta/server/server.py index 4a02b74e..75ac63bb 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -875,14 +875,12 @@ class SyncServer(Server): # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user actor = self.user_manager.get_user_or_default(user_id=user_id) - start_date = self.message_manager.get_message_by_id(after, actor=actor).created_at if after else None - end_date = self.message_manager.get_message_by_id(before, actor=actor).created_at if before else None records = self.message_manager.list_messages_for_agent( agent_id=agent_id, actor=actor, - start_date=start_date, - end_date=end_date, + after=after, + before=before, limit=limit, ascending=not reverse, ) diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index ac00ca15..01eccb53 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -1,6 +1,8 @@ -from datetime import datetime -from typing import Dict, List, Optional +from typing import List, Optional +from sqlalchemy import and_, or_ + +from letta.orm.agent import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.message import Message as MessageModel from letta.schemas.enums import MessageRole @@ -127,44 +129,21 @@ class MessageManager: def list_user_messages_for_agent( self, agent_id: str, - actor: Optional[PydanticUser] = None, - before: Optional[str] = None, + actor: PydanticUser, after: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - limit: Optional[int] = 50, - filters: Optional[Dict] = None, + before: Optional[str] = None, query_text: Optional[str] = None, + limit: Optional[int] = 50, ascending: bool = True, ) -> List[PydanticMessage]: - """List user messages with flexible filtering and pagination options. - - Args: - before: Cursor-based pagination - return records before this ID (exclusive) - after: Cursor-based pagination - return records after this ID (exclusive) - start_date: Filter records created after this date - end_date: Filter records created before this date - limit: Maximum number of records to return - filters: Additional filters to apply - query_text: Optional text to search for in message content - - Returns: - List[PydanticMessage] - List of messages matching the criteria - """ - message_filters = {"role": "user"} - if filters: - message_filters.update(filters) - return self.list_messages_for_agent( agent_id=agent_id, actor=actor, - before=before, after=after, - start_date=start_date, - end_date=end_date, - limit=limit, - filters=message_filters, + before=before, query_text=query_text, + role=MessageRole.user, + limit=limit, ascending=ascending, ) @@ -172,48 +151,94 @@ class MessageManager: def list_messages_for_agent( self, agent_id: str, - actor: Optional[PydanticUser] = None, - before: Optional[str] = None, + actor: PydanticUser, after: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - limit: Optional[int] = 50, - filters: Optional[Dict] = None, + before: Optional[str] = None, query_text: Optional[str] = None, + role: Optional[MessageRole] = None, # New parameter for filtering by role + limit: Optional[int] = 50, ascending: bool = True, ) -> List[PydanticMessage]: - """List messages with flexible filtering and pagination options. + """ + Most performant query to list messages for an agent by directly querying the Message table. + + This function filters by the agent_id (leveraging the index on messages.agent_id) + and applies efficient pagination using (created_at, id) as the cursor. + If query_text is provided, it will filter messages whose text content partially matches the query. + If role is provided, it will filter messages by the specified role. Args: - before: Cursor-based pagination - return records before this ID (exclusive) - after: Cursor-based pagination - return records after this ID (exclusive) - start_date: Filter records created after this date - end_date: Filter records created before this date - limit: Maximum number of records to return - filters: Additional filters to apply - query_text: Optional text to search for in message content + agent_id: The ID of the agent whose messages are queried. + actor: The user performing the action (used for permission checks). + after: A message ID; if provided, only messages *after* this message (per sort order) are returned. + before: A message ID; if provided, only messages *before* this message are returned. + query_text: Optional string to partially match the message text content. + role: Optional MessageRole to filter messages by role. + limit: Maximum number of messages to return. + ascending: If True, sort by (created_at, id) ascending; if False, sort descending. Returns: - List[PydanticMessage] - List of messages matching the criteria + List[PydanticMessage]: A list of messages (converted via .to_pydantic()). + + Raises: + NoResultFound: If the provided after/before message IDs do not exist. """ with self.session_maker() as session: - # Start with base filters - message_filters = {"agent_id": agent_id} - if actor: - message_filters.update({"organization_id": actor.organization_id}) - if filters: - message_filters.update(filters) + # Permission check: raise if the agent doesn't exist or actor is not allowed. + AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - results = MessageModel.list( - db_session=session, - before=before, - after=after, - start_date=start_date, - end_date=end_date, - limit=limit, - query_text=query_text, - ascending=ascending, - **message_filters, - ) + # Build a query that directly filters the Message table by agent_id. + query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id) + # If query_text is provided, filter messages by partial match on text. + if query_text: + query = query.filter(MessageModel.text.ilike(f"%{query_text}%")) + + # If role is provided, filter messages by role. + if role: + query = query.filter(MessageModel.role == role.value) # Enum.value ensures comparison is against the string value + + # Apply 'after' pagination if specified. + if after: + after_ref = session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == after).limit(1).one_or_none() + if not after_ref: + raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.") + query = query.filter( + or_( + MessageModel.created_at > after_ref.created_at, + and_( + MessageModel.created_at == after_ref.created_at, + MessageModel.id > after_ref.id, + ), + ) + ) + + # Apply 'before' pagination if specified. + if before: + before_ref = ( + session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == before).limit(1).one_or_none() + ) + if not before_ref: + raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.") + query = query.filter( + or_( + MessageModel.created_at < before_ref.created_at, + and_( + MessageModel.created_at == before_ref.created_at, + MessageModel.id < before_ref.id, + ), + ) + ) + + # Apply ordering based on the ascending flag. + if ascending: + query = query.order_by(MessageModel.created_at.asc(), MessageModel.id.asc()) + else: + query = query.order_by(MessageModel.created_at.desc(), MessageModel.id.desc()) + + # Limit the number of results. + query = query.limit(limit) + + # Execute and convert each Message to its Pydantic representation. + results = query.all() return [msg.to_pydantic() for msg in results] diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 025f751b..97b8709f 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -186,7 +186,7 @@ def test_check_tool_rules_with_different_models(mock_e2b_api_key_none): client = create_client() config_files = [ - "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json", + "tests/configs/llm_model_configs/claude-3-5-sonnet.json", "tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json", "tests/configs/llm_model_configs/openai-gpt-4o.json", ] @@ -247,7 +247,7 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none): tools = [t1, t2] # Make agent state - anthropic_config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + anthropic_config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" for i in range(3): agent_uuid = str(uuid.uuid4()) agent_state = setup_agent( @@ -299,7 +299,7 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none): tools = [send_message, archival_memory_search, archival_memory_insert] config_files = [ - "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json", + "tests/configs/llm_model_configs/claude-3-5-sonnet.json", "tests/configs/llm_model_configs/openai-gpt-4o.json", ] @@ -383,7 +383,7 @@ def test_agent_conditional_tool_easy(mock_e2b_api_key_none): ] tools = [flip_coin_tool, reveal_secret] - config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word") @@ -455,7 +455,7 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none): # Setup agent with all tools tools = [play_game_tool, flip_coin_tool, reveal_secret] - config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) # Ask agent to try to get all secret words @@ -681,7 +681,7 @@ def test_init_tool_rule_always_fails_one_tool(): ) # Set up agent with the tool rule - claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=False) # Start conversation @@ -710,7 +710,7 @@ def test_init_tool_rule_always_fails_multiple_tools(): ) # Set up agent with the tool rule - claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=True) # Start conversation diff --git a/tests/test_managers.py b/tests/test_managers.py index 43ffbaa7..540717dc 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1971,17 +1971,6 @@ def test_message_listing_text_search(server: SyncServer, hello_world_message_fix assert len(search_results) == 0 -def test_message_listing_date_range_filtering(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent): - """Test filtering messages by date range""" - create_test_messages(server, hello_world_message_fixture, default_user) - now = datetime.utcnow() - - date_results = server.message_manager.list_user_messages_for_agent( - agent_id=sarah_agent.id, actor=default_user, start_date=now - timedelta(minutes=1), end_date=now + timedelta(minutes=1), limit=10 - ) - assert len(date_results) > 0 - - # ====================================================================================================================== # Block Manager Tests # ====================================================================================================================== diff --git a/tests/utils.py b/tests/utils.py index 46d83ed7..e16cd15a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -164,7 +164,7 @@ def wait_for_incoming_message( deadline = time.time() + max_wait_seconds while time.time() < deadline: - messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id) + messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id, actor=client.user) # Check for the system message containing `substring` if any(message.role == MessageRole.system and substring in (message.text or "") for message in messages): return True