From d924cc005b5cb88356e1ed0afe4ba3d96ebda364 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 3 Sep 2025 17:33:19 -0700 Subject: [PATCH] fix: change to pure rank-based RRF for relevance ordering (#4411) * Fix RRF * Fix turbopuffer tests --- letta/functions/function_sets/base.py | 30 ++-- letta/helpers/tpuf_client.py | 147 ++++++++++-------- letta/services/message_manager.py | 37 +++-- .../tool_executor/core_tool_executor.py | 25 ++- tests/integration_test_turbopuffer.py | 55 +++---- 5 files changed, 174 insertions(+), 120 deletions(-) diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index 0f802a72..cccd5ab5 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -55,33 +55,31 @@ def conversation_search( str: Query result string containing matching messages with timestamps and content. """ - import math - from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE from letta.helpers.json_helpers import json_dumps - if page is None or (isinstance(page, str) and page.lower().strip() == "none"): - page = 0 - try: - page = int(page) - except: - raise ValueError("'page' argument must be an integer") - count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE - # TODO: add paging by page number. currently cursor only works with strings. - # original: start=page * count + # Use provided limit or default + if limit is None: + limit = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + messages = self.message_manager.list_messages_for_agent( agent_id=self.agent_state.id, actor=self.user, query_text=query, - limit=count, + roles=roles, + limit=limit, ) - total = len(messages) - num_pages = math.ceil(total / count) - 1 # 0 index + if len(messages) == 0: results_str = "No results found." else: - results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):" - results_formatted = [message.content[0].text for message in messages] + results_pref = f"Found {len(messages)} results:" + results_formatted = [] + for message in messages: + # Extract text content from message + text_content = message.content[0].text if message.content else "" + result_entry = {"role": message.role, "content": text_content} + results_formatted.append(result_entry) results_str = f"{results_pref} {json_dumps(results_formatted)}" return results_str diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 78eb0572..57f81b2c 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -474,8 +474,16 @@ class TurbopufferClient: # for hybrid mode, we get a multi-query response vector_results = self._process_single_query_results(result.results[0], archive_id, tags) fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True) - # use backwards-compatible wrapper which calls generic RRF - return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k) + # use RRF and return only (passage, score) for backwards compatibility + results_with_metadata = self._reciprocal_rank_fusion( + vector_results=[passage for passage, _ in vector_results], + fts_results=[passage for passage, _ in fts_results], + get_id_func=lambda p: p.id, + vector_weight=vector_weight, + fts_weight=fts_weight, + top_k=top_k, + ) + return [(passage, rrf_score) for passage, rrf_score, metadata in results_with_metadata] else: # for single queries (vector, fts, timestamp) is_fts = search_mode == "fts" @@ -499,7 +507,7 @@ class TurbopufferClient: fts_weight: float = 0.5, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, - ) -> List[Tuple[dict, float]]: + ) -> List[Tuple[dict, float, dict]]: """Query messages from Turbopuffer using vector search, full-text search, or hybrid search. Args: @@ -516,7 +524,10 @@ class TurbopufferClient: end_date: Optional datetime to filter messages created before this date Returns: - List of (message_dict, score) tuples where message_dict contains id, text, role, created_at + List of (message_dict, score, metadata) tuples where: + - message_dict contains id, text, role, created_at + - score is the final relevance score + - metadata contains individual scores and ranking information """ # Check if we should fallback to timestamp-based retrieval if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: @@ -576,9 +587,9 @@ class TurbopufferClient: if search_mode == "hybrid": # for hybrid mode, we get a multi-query response vector_results = self._process_message_query_results(result.results[0]) - fts_results = self._process_message_query_results(result.results[1], is_fts=True) - # use generic RRF with lambda to extract ID from dict - return self._generic_reciprocal_rank_fusion( + fts_results = self._process_message_query_results(result.results[1]) + # use RRF with lambda to extract ID from dict - returns metadata + results_with_metadata = self._reciprocal_rank_fusion( vector_results=vector_results, fts_results=fts_results, get_id_func=lambda msg_dict: msg_dict["id"], @@ -586,18 +597,32 @@ class TurbopufferClient: fts_weight=fts_weight, top_k=top_k, ) + # return results with metadata + return results_with_metadata else: # for single queries (vector, fts, timestamp) - is_fts = search_mode == "fts" - return self._process_message_query_results(result, is_fts=is_fts) + results = self._process_message_query_results(result) + # add simple metadata for single search modes + results_with_metadata = [] + for idx, msg_dict in enumerate(results): + metadata = { + "combined_score": 1.0 / (idx + 1), # Use rank-based score for single mode + "search_mode": search_mode, + f"{search_mode}_rank": idx + 1, # Add the rank for this search mode + } + results_with_metadata.append((msg_dict, metadata["combined_score"], metadata)) + return results_with_metadata except Exception as e: logger.error(f"Failed to query messages from Turbopuffer: {e}") raise - def _process_message_query_results(self, result, is_fts: bool = False) -> List[Tuple[dict, float]]: - """Process results from a message query into message dicts with scores.""" - messages_with_scores = [] + def _process_message_query_results(self, result) -> List[dict]: + """Process results from a message query into message dicts. + + For RRF, we only need the rank order - scores are not used. + """ + messages = [] for row in result.rows: # Build message dict with key fields @@ -609,19 +634,9 @@ class TurbopufferClient: "role": getattr(row, "role", None), "created_at": getattr(row, "created_at", None), } + messages.append(message_dict) - # handle score based on search type - if is_fts: - # for FTS, use the BM25 score directly (higher is better) - score = getattr(row, "$score", 0.0) - else: - # for vector search, convert distance to similarity score - distance = getattr(row, "$dist", 0.0) - score = 1.0 - distance - - messages_with_scores.append((message_dict, score)) - - return messages_with_scores + return messages def _process_single_query_results( self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False @@ -663,74 +678,78 @@ class TurbopufferClient: return passages_with_scores - def _generic_reciprocal_rank_fusion( + def _reciprocal_rank_fusion( self, - vector_results: List[Tuple[Any, float]], - fts_results: List[Tuple[Any, float]], + vector_results: List[Any], + fts_results: List[Any], get_id_func: Callable[[Any], str], vector_weight: float, fts_weight: float, top_k: int, - ) -> List[Tuple[Any, float]]: - """Generic RRF implementation that works with any object type. + ) -> List[Tuple[Any, float, dict]]: + """RRF implementation that works with any object type. - RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank)) + RRF score = vector_weight * (1/(k + rank)) + fts_weight * (1/(k + rank)) where k is a constant (typically 60) to avoid division by zero + This is a pure rank-based fusion following the standard RRF algorithm. + Args: - vector_results: List of (item, score) tuples from vector search - fts_results: List of (item, score) tuples from FTS + vector_results: List of items from vector search (ordered by relevance) + fts_results: List of items from FTS (ordered by relevance) get_id_func: Function to extract ID from an item vector_weight: Weight for vector search results fts_weight: Weight for FTS results top_k: Number of results to return Returns: - List of (item, score) tuples sorted by RRF score + List of (item, score, metadata) tuples sorted by RRF score + metadata contains ranks from each result list """ - k = 60 # standard RRF constant + k = 60 # standard RRF constant from Cormack et al. (2009) - # create rank mappings using the get_id_func - vector_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(vector_results)} - fts_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(fts_results)} + # create rank mappings based on position in result lists + # rank starts at 1, not 0 + vector_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(vector_results)} + fts_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(fts_results)} - # combine all unique items + # combine all unique items from both result sets all_items = {} - for item, _ in vector_results: + for item in vector_results: all_items[get_id_func(item)] = item - for item, _ in fts_results: + for item in fts_results: all_items[get_id_func(item)] = item - # calculate RRF scores + # calculate RRF scores based purely on ranks rrf_scores = {} + score_metadata = {} for item_id in all_items: - vector_score = vector_weight / (k + vector_ranks.get(item_id, k + top_k)) - fts_score = fts_weight / (k + fts_ranks.get(item_id, k + top_k)) - rrf_scores[item_id] = vector_score + fts_score + # RRF formula: sum of 1/(k + rank) across result lists + # If item not in a list, we don't add anything (equivalent to rank = infinity) + vector_rrf_score = 0.0 + fts_rrf_score = 0.0 - # sort by RRF score and return top_k - sorted_results = sorted([(all_items[iid], score) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True) + if item_id in vector_ranks: + vector_rrf_score = vector_weight / (k + vector_ranks[item_id]) + if item_id in fts_ranks: + fts_rrf_score = fts_weight / (k + fts_ranks[item_id]) + + combined_score = vector_rrf_score + fts_rrf_score + + rrf_scores[item_id] = combined_score + score_metadata[item_id] = { + "combined_score": combined_score, # Final RRF score + "vector_rank": vector_ranks.get(item_id), + "fts_rank": fts_ranks.get(item_id), + } + + # sort by RRF score and return with metadata + sorted_results = sorted( + [(all_items[iid], score, score_metadata[iid]) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True + ) return sorted_results[:top_k] - def _reciprocal_rank_fusion( - self, - vector_results: List[Tuple[PydanticPassage, float]], - fts_results: List[Tuple[PydanticPassage, float]], - vector_weight: float, - fts_weight: float, - top_k: int, - ) -> List[Tuple[PydanticPassage, float]]: - """Wrapper for backwards compatibility - uses generic RRF for passages.""" - return self._generic_reciprocal_rank_fusion( - vector_results=vector_results, - fts_results=fts_results, - get_id_func=lambda p: p.id, - vector_weight=vector_weight, - fts_weight=fts_weight, - top_k=top_k, - ) - @trace_method async def delete_passage(self, archive_id: str, passage_id: str) -> bool: """Delete a passage from Turbopuffer.""" diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 51c96558..774eac69 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -1,7 +1,7 @@ import json import uuid from datetime import datetime -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Tuple from sqlalchemy import delete, exists, func, select, text @@ -1065,7 +1065,7 @@ class MessageManager: start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, embedding_config: Optional[EmbeddingConfig] = None, - ) -> List[PydanticMessage]: + ) -> List[Tuple[PydanticMessage, dict]]: """ Search messages using Turbopuffer if enabled, otherwise fall back to SQL search. @@ -1082,7 +1082,7 @@ class MessageManager: embedding_config: Optional embedding configuration for generating query embedding Returns: - List of matching messages + List of tuples (message, metadata) where metadata contains relevance scores """ from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages @@ -1133,8 +1133,8 @@ class MessageManager: from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message as PydanticMessage - turbopuffer_messages = [] - for msg_dict, score in results: + message_tuples = [] + for msg_dict, score, metadata in results: # create a message object with the properly extracted text from turbopuffer message = PydanticMessage( id=msg_dict["id"], @@ -1146,9 +1146,10 @@ class MessageManager: created_by_id=actor.id, last_updated_by_id=actor.id, ) - turbopuffer_messages.append(message) + # Return tuple of (message, metadata) + message_tuples.append((message, metadata)) - return turbopuffer_messages + return message_tuples else: return [] @@ -1163,7 +1164,16 @@ class MessageManager: limit=limit, ascending=False, ) - return self._combine_assistant_tool_messages(messages) + combined_messages = self._combine_assistant_tool_messages(messages) + # Add basic metadata for SQL fallback + message_tuples = [] + for message in combined_messages: + metadata = { + "search_mode": "sql_fallback", + "combined_score": None, # SQL doesn't provide scores + } + message_tuples.append((message, metadata)) + return message_tuples else: # use sql-based search messages = await self.list_messages_for_agent_async( @@ -1174,4 +1184,13 @@ class MessageManager: limit=limit, ascending=False, ) - return self._combine_assistant_tool_messages(messages) + combined_messages = self._combine_assistant_tool_messages(messages) + # Add basic metadata for SQL search + message_tuples = [] + for message in combined_messages: + metadata = { + "search_mode": "sql", + "combined_score": None, # SQL doesn't provide scores + } + message_tuples.append((message, metadata)) + return message_tuples diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index 5561e00d..3338914c 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -155,7 +155,7 @@ class LettaCoreToolExecutor(ToolExecutor): search_limit = limit if limit is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE # Search using the message manager's search_messages_async method - messages = await self.message_manager.search_messages_async( + message_results = await self.message_manager.search_messages_async( agent_id=agent_state.id, actor=actor, query_text=query, @@ -166,10 +166,10 @@ class LettaCoreToolExecutor(ToolExecutor): embedding_config=agent_state.embedding_config, ) - if len(messages) == 0: + if len(message_results) == 0: results_str = "No results found." else: - results_pref = f"Showing {len(messages)} results:" + results_pref = f"Showing {len(message_results)} results:" results_formatted = [] # get current time in UTC, then convert to agent timezone for consistent comparison from datetime import timezone @@ -184,7 +184,7 @@ class LettaCoreToolExecutor(ToolExecutor): else: now = now_utc - for message in messages: + for message, metadata in message_results: # Format timestamp in agent's timezone if available timestamp = message.created_at time_delta_str = "" @@ -229,6 +229,23 @@ class LettaCoreToolExecutor(ToolExecutor): "role": message.role, } + # Add search relevance metadata if available + if metadata: + # Only include non-None values + relevance_info = { + k: v + for k, v in { + "rrf_score": metadata.get("combined_score"), + "vector_rank": metadata.get("vector_rank"), + "fts_rank": metadata.get("fts_rank"), + "search_mode": metadata.get("search_mode"), + }.items() + if v is not None + } + + if relevance_info: # Only add if we have metadata + result_dict["relevance"] = relevance_info + # _extract_message_text returns already JSON-encoded strings # We need to parse them to get the actual content structure if content: diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index d8f1aacc..1b19a96c 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -991,18 +991,19 @@ class TestTurbopufferMessagesIntegration: vector_results = [(passage1, 0.9), (passage2, 0.7)] fts_results = [(passage2, 0.8), (passage1, 0.6)] - # Test with passages using the wrapper function + # Test with passages using the RRF function combined = client._reciprocal_rank_fusion( - vector_results=vector_results, - fts_results=fts_results, + vector_results=[passage for passage, _ in vector_results], + fts_results=[passage for passage, _ in fts_results], + get_id_func=lambda p: p.id, vector_weight=0.5, fts_weight=0.5, top_k=2, ) assert len(combined) == 2 - # Both passages should be in results - result_ids = [p.id for p, _ in combined] + # Both passages should be in results - now returns (passage, score, metadata) + result_ids = [p.id for p, _, _ in combined] assert p1_id in result_ids assert p2_id in result_ids @@ -1014,9 +1015,9 @@ class TestTurbopufferMessagesIntegration: vector_msg_results = [(msg1, 0.95), (msg2, 0.85), (msg3, 0.75)] fts_msg_results = [(msg2, 0.90), (msg3, 0.80), (msg1, 0.70)] - combined_msgs = client._generic_reciprocal_rank_fusion( - vector_results=vector_msg_results, - fts_results=fts_msg_results, + combined_msgs = client._reciprocal_rank_fusion( + vector_results=[msg for msg, _ in vector_msg_results], + fts_results=[msg for msg, _ in fts_msg_results], get_id_func=lambda m: m["id"], vector_weight=0.6, fts_weight=0.4, @@ -1024,14 +1025,14 @@ class TestTurbopufferMessagesIntegration: ) assert len(combined_msgs) == 3 - msg_ids = [m["id"] for m, _ in combined_msgs] + msg_ids = [m["id"] for m, _, _ in combined_msgs] assert "m1" in msg_ids assert "m2" in msg_ids assert "m3" in msg_ids # Test edge cases # Empty results - empty_combined = client._generic_reciprocal_rank_fusion( + empty_combined = client._reciprocal_rank_fusion( vector_results=[], fts_results=[], get_id_func=lambda x: x["id"], @@ -1042,8 +1043,8 @@ class TestTurbopufferMessagesIntegration: assert len(empty_combined) == 0 # Single result list - single_combined = client._generic_reciprocal_rank_fusion( - vector_results=[(msg1, 0.9)], + single_combined = client._reciprocal_rank_fusion( + vector_results=[msg1], fts_results=[], get_id_func=lambda m: m["id"], vector_weight=0.5, @@ -1104,7 +1105,7 @@ class TestTurbopufferMessagesIntegration: assert len(results) == 3 # Results should be ordered by timestamp (most recent first) - for msg_dict, score in results: + for msg_dict, score, metadata in results: assert msg_dict["agent_id"] == agent_id assert msg_dict["organization_id"] == org_id assert msg_dict["text"] in message_texts @@ -1172,7 +1173,7 @@ class TestTurbopufferMessagesIntegration: assert len(results) == 2 # Should return Python-related messages first - result_texts = [msg["text"] for msg, _ in results] + result_texts = [msg["text"] for msg, _, _ in results] assert "Python is a great programming language" in result_texts assert "Machine learning with Python is powerful" in result_texts @@ -1242,7 +1243,7 @@ class TestTurbopufferMessagesIntegration: assert len(results) > 0 # Should get a mix of results based on both vector and text similarity - result_texts = [msg["text"] for msg, _ in results] + result_texts = [msg["text"] for msg, _, _ in results] # At least one result should contain "quick" due to FTS assert any("quick" in text.lower() for text in result_texts) @@ -1304,7 +1305,7 @@ class TestTurbopufferMessagesIntegration: ) assert len(user_results) == 2 - for msg, _ in user_results: + for msg, _, _ in user_results: assert msg["role"] == "user" assert msg["text"] in ["I need help with Python", "Can you explain this?"] @@ -1318,7 +1319,7 @@ class TestTurbopufferMessagesIntegration: ) assert len(non_user_results) == 3 - for msg, _ in non_user_results: + for msg, _, _ in non_user_results: assert msg["role"] in ["assistant", "system"] finally: @@ -1363,7 +1364,7 @@ class TestTurbopufferMessagesIntegration: # Should return results from SQL search assert len(results) > 0 # Extract text from messages and check for "fallback" - for msg in results: + for msg, metadata in results: text = server.message_manager._extract_message_text(msg) if "fallback" in text.lower(): break @@ -1410,7 +1411,7 @@ class TestTurbopufferMessagesIntegration: embedding_config=embedding_config, ) assert len(python_results) > 0 - assert any(msg.id == message_id for msg in python_results) + assert any(msg.id == message_id for msg, metadata in python_results) # Update the message content updated_message = await server.message_manager.update_message_by_id_async( @@ -1433,7 +1434,7 @@ class TestTurbopufferMessagesIntegration: embedding_config=embedding_config, ) # Should either find no results or results that don't include our message - assert not any(msg.id == message_id for msg in python_results_after) + assert not any(msg.id == message_id for msg, metadata in python_results_after) # Search for "JavaScript" - should find the updated message js_results = await server.message_manager.search_messages_async( @@ -1445,7 +1446,7 @@ class TestTurbopufferMessagesIntegration: embedding_config=embedding_config, ) assert len(js_results) > 0 - assert any(msg.id == message_id for msg in js_results) + assert any(msg.id == message_id for msg, metadata in js_results) # Clean up await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True) @@ -1561,7 +1562,7 @@ class TestTurbopufferMessagesIntegration: ) assert len(agent_a_final) == 2 # Verify the remaining messages are the correct ones - remaining_ids = {msg.id for msg in agent_a_final} + remaining_ids = {msg.id for msg, metadata in agent_a_final} assert agent_a_messages[3].id in remaining_ids assert agent_a_messages[4].id in remaining_ids @@ -1616,7 +1617,7 @@ class TestTurbopufferMessagesIntegration: embedding_config=embedding_config, ) assert len(initial_search) > 0 - assert any(msg.id == message_id for msg in initial_search) + assert any(msg.id == message_id for msg, metadata in initial_search) # Update message WITHOUT embedding_config - should update postgres but not turbopuffer updated_message = await server.message_manager.update_message_by_id_async( @@ -1642,7 +1643,7 @@ class TestTurbopufferMessagesIntegration: embedding_config=embedding_config, ) assert len(still_searchable) > 0 - assert any(msg.id == message_id for msg in still_searchable) + assert any(msg.id == message_id for msg, metadata in still_searchable) # New content should NOT be searchable (wasn't re-indexed) not_searchable = await server.message_manager.search_messages_async( @@ -1654,7 +1655,7 @@ class TestTurbopufferMessagesIntegration: embedding_config=embedding_config, ) # Should either find no results or results that don't include our message - assert not any(msg.id == message_id for msg in not_searchable) + assert not any(msg.id == message_id for msg, metadata in not_searchable) # Clean up await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True) @@ -1789,7 +1790,7 @@ class TestTurbopufferMessagesIntegration: # Should get today's and yesterday's messages assert len(recent_results) == 2 - result_texts = [msg["text"] for msg, _ in recent_results] + result_texts = [msg["text"] for msg, _, _ in recent_results] assert "Today's message" in result_texts assert "Yesterday's message" in result_texts @@ -1820,7 +1821,7 @@ class TestTurbopufferMessagesIntegration: # Should get only recent messages assert len(filtered_vector_results) == 2 - for msg, _ in filtered_vector_results: + for msg, _, _ in filtered_vector_results: assert msg["text"] in ["Today's message", "Yesterday's message"] finally: