diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index fb39b955..630324c0 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -303,17 +303,24 @@ class MessageManager: if group_id: query = query.filter(MessageModel.group_id == group_id) - # If query_text is provided, filter messages using subquery + json_array_elements. + # If query_text is provided, filter messages by matching any "text" type content block + # whose text includes the query string (case-insensitive). if query_text: - content_element = func.json_array_elements(MessageModel.content).alias("content_element") - query = query.filter( - exists( - select(1) - .select_from(content_element) - .where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text")) - .params(query_text=f"%{query_text}%") + dialect_name = session.bind.dialect.name + + if dialect_name == "postgresql": # using subquery + json_array_elements. + content_element = func.json_array_elements(MessageModel.content).alias("content_element") + subquery_sql = text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text") + subquery = select(1).select_from(content_element).where(subquery_sql) + + elif dialect_name == "sqlite": # using `json_each` and JSON path expressions + json_item = func.json_each(MessageModel.content).alias("json_item") + subquery_sql = text( + "json_extract(value, '$.type') = 'text' AND lower(json_extract(value, '$.text')) LIKE lower(:query_text)" ) - ) + subquery = select(1).select_from(json_item).where(subquery_sql) + + query = query.filter(exists(subquery.params(query_text=f"%{query_text}%"))) # If role(s) are provided, filter messages by those roles. if roles: diff --git a/tests/test_managers.py b/tests/test_managers.py index 3a9de069..a8a29309 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1588,6 +1588,46 @@ def test_modify_letta_message(server: SyncServer, sarah_agent, default_user): # TODO: tool calls/responses +def test_list_messages_with_query_text_filter(server: SyncServer, sarah_agent, default_user): + """ + Ensure that list_messages_for_agent correctly filters messages by query_text. + """ + test_contents = [ + "This is a message about unicorns and rainbows.", + "Another message discussing dragons in the sky.", + "Plain message with no magical beasts.", + "Mentioning unicorns again for good measure.", + "Something unrelated entirely.", + ] + + created_messages = [] + for content in test_contents: + message = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[{"type": "text", "text": content}], + ) + created = server.message_manager.create_message(pydantic_msg=message, actor=default_user) + created_messages.append(created) + + # Query messages that include "unicorns" + unicorn_messages = server.message_manager.list_messages_for_agent(agent_id=sarah_agent.id, actor=default_user, query_text="unicorns") + assert len(unicorn_messages) == 2 + for msg in unicorn_messages: + assert any(chunk.type == "text" and "unicorns" in chunk.text.lower() for chunk in msg.content or []) + + # Query messages that include "dragons" + dragon_messages = server.message_manager.list_messages_for_agent(agent_id=sarah_agent.id, actor=default_user, query_text="dragons") + assert len(dragon_messages) == 1 + assert any(chunk.type == "text" and "dragons" in chunk.text.lower() for chunk in dragon_messages[0].content or []) + + # Query with a word that shouldn't match any message + no_match_messages = server.message_manager.list_messages_for_agent( + agent_id=sarah_agent.id, actor=default_user, query_text="nonexistentcreature" + ) + assert len(no_match_messages) == 0 + + # ====================================================================================================================== # AgentManager Tests - Blocks Relationship # ======================================================================================================================