fix: change to pure rank-based RRF for relevance ordering (#4411)

* Fix RRF

* Fix turbopuffer tests
This commit is contained in:
Matthew Zhou
2025-09-03 17:33:19 -07:00
committed by GitHub
parent fc50a41680
commit d924cc005b
5 changed files with 174 additions and 120 deletions

View File

@@ -55,33 +55,31 @@ def conversation_search(
str: Query result string containing matching messages with timestamps and content.
"""
import math
from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.helpers.json_helpers import json_dumps
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
except:
raise ValueError("'page' argument must be an integer")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
# TODO: add paging by page number. currently cursor only works with strings.
# original: start=page * count
# Use provided limit or default
if limit is None:
limit = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
messages = self.message_manager.list_messages_for_agent(
agent_id=self.agent_state.id,
actor=self.user,
query_text=query,
limit=count,
roles=roles,
limit=limit,
)
total = len(messages)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(messages) == 0:
results_str = "No results found."
else:
results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):"
results_formatted = [message.content[0].text for message in messages]
results_pref = f"Found {len(messages)} results:"
results_formatted = []
for message in messages:
# Extract text content from message
text_content = message.content[0].text if message.content else ""
result_entry = {"role": message.role, "content": text_content}
results_formatted.append(result_entry)
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str

View File

@@ -474,8 +474,16 @@ class TurbopufferClient:
# for hybrid mode, we get a multi-query response
vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
# use backwards-compatible wrapper which calls generic RRF
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
# use RRF and return only (passage, score) for backwards compatibility
results_with_metadata = self._reciprocal_rank_fusion(
vector_results=[passage for passage, _ in vector_results],
fts_results=[passage for passage, _ in fts_results],
get_id_func=lambda p: p.id,
vector_weight=vector_weight,
fts_weight=fts_weight,
top_k=top_k,
)
return [(passage, rrf_score) for passage, rrf_score, metadata in results_with_metadata]
else:
# for single queries (vector, fts, timestamp)
is_fts = search_mode == "fts"
@@ -499,7 +507,7 @@ class TurbopufferClient:
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> List[Tuple[dict, float]]:
) -> List[Tuple[dict, float, dict]]:
"""Query messages from Turbopuffer using vector search, full-text search, or hybrid search.
Args:
@@ -516,7 +524,10 @@ class TurbopufferClient:
end_date: Optional datetime to filter messages created before this date
Returns:
List of (message_dict, score) tuples where message_dict contains id, text, role, created_at
List of (message_dict, score, metadata) tuples where:
- message_dict contains id, text, role, created_at
- score is the final relevance score
- metadata contains individual scores and ranking information
"""
# Check if we should fallback to timestamp-based retrieval
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
@@ -576,9 +587,9 @@ class TurbopufferClient:
if search_mode == "hybrid":
# for hybrid mode, we get a multi-query response
vector_results = self._process_message_query_results(result.results[0])
fts_results = self._process_message_query_results(result.results[1], is_fts=True)
# use generic RRF with lambda to extract ID from dict
return self._generic_reciprocal_rank_fusion(
fts_results = self._process_message_query_results(result.results[1])
# use RRF with lambda to extract ID from dict - returns metadata
results_with_metadata = self._reciprocal_rank_fusion(
vector_results=vector_results,
fts_results=fts_results,
get_id_func=lambda msg_dict: msg_dict["id"],
@@ -586,18 +597,32 @@ class TurbopufferClient:
fts_weight=fts_weight,
top_k=top_k,
)
# return results with metadata
return results_with_metadata
else:
# for single queries (vector, fts, timestamp)
is_fts = search_mode == "fts"
return self._process_message_query_results(result, is_fts=is_fts)
results = self._process_message_query_results(result)
# add simple metadata for single search modes
results_with_metadata = []
for idx, msg_dict in enumerate(results):
metadata = {
"combined_score": 1.0 / (idx + 1), # Use rank-based score for single mode
"search_mode": search_mode,
f"{search_mode}_rank": idx + 1, # Add the rank for this search mode
}
results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))
return results_with_metadata
except Exception as e:
logger.error(f"Failed to query messages from Turbopuffer: {e}")
raise
def _process_message_query_results(self, result, is_fts: bool = False) -> List[Tuple[dict, float]]:
"""Process results from a message query into message dicts with scores."""
messages_with_scores = []
def _process_message_query_results(self, result) -> List[dict]:
"""Process results from a message query into message dicts.
For RRF, we only need the rank order - scores are not used.
"""
messages = []
for row in result.rows:
# Build message dict with key fields
@@ -609,19 +634,9 @@ class TurbopufferClient:
"role": getattr(row, "role", None),
"created_at": getattr(row, "created_at", None),
}
messages.append(message_dict)
# handle score based on search type
if is_fts:
# for FTS, use the BM25 score directly (higher is better)
score = getattr(row, "$score", 0.0)
else:
# for vector search, convert distance to similarity score
distance = getattr(row, "$dist", 0.0)
score = 1.0 - distance
messages_with_scores.append((message_dict, score))
return messages_with_scores
return messages
def _process_single_query_results(
self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
@@ -663,74 +678,78 @@ class TurbopufferClient:
return passages_with_scores
def _generic_reciprocal_rank_fusion(
def _reciprocal_rank_fusion(
self,
vector_results: List[Tuple[Any, float]],
fts_results: List[Tuple[Any, float]],
vector_results: List[Any],
fts_results: List[Any],
get_id_func: Callable[[Any], str],
vector_weight: float,
fts_weight: float,
top_k: int,
) -> List[Tuple[Any, float]]:
"""Generic RRF implementation that works with any object type.
) -> List[Tuple[Any, float, dict]]:
"""RRF implementation that works with any object type.
RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank))
RRF score = vector_weight * (1/(k + rank)) + fts_weight * (1/(k + rank))
where k is a constant (typically 60) to avoid division by zero
This is a pure rank-based fusion following the standard RRF algorithm.
Args:
vector_results: List of (item, score) tuples from vector search
fts_results: List of (item, score) tuples from FTS
vector_results: List of items from vector search (ordered by relevance)
fts_results: List of items from FTS (ordered by relevance)
get_id_func: Function to extract ID from an item
vector_weight: Weight for vector search results
fts_weight: Weight for FTS results
top_k: Number of results to return
Returns:
List of (item, score) tuples sorted by RRF score
List of (item, score, metadata) tuples sorted by RRF score
metadata contains ranks from each result list
"""
k = 60 # standard RRF constant
k = 60 # standard RRF constant from Cormack et al. (2009)
# create rank mappings using the get_id_func
vector_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(vector_results)}
fts_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(fts_results)}
# create rank mappings based on position in result lists
# rank starts at 1, not 0
vector_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(vector_results)}
fts_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(fts_results)}
# combine all unique items
# combine all unique items from both result sets
all_items = {}
for item, _ in vector_results:
for item in vector_results:
all_items[get_id_func(item)] = item
for item, _ in fts_results:
for item in fts_results:
all_items[get_id_func(item)] = item
# calculate RRF scores
# calculate RRF scores based purely on ranks
rrf_scores = {}
score_metadata = {}
for item_id in all_items:
vector_score = vector_weight / (k + vector_ranks.get(item_id, k + top_k))
fts_score = fts_weight / (k + fts_ranks.get(item_id, k + top_k))
rrf_scores[item_id] = vector_score + fts_score
# RRF formula: sum of 1/(k + rank) across result lists
# If item not in a list, we don't add anything (equivalent to rank = infinity)
vector_rrf_score = 0.0
fts_rrf_score = 0.0
# sort by RRF score and return top_k
sorted_results = sorted([(all_items[iid], score) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True)
if item_id in vector_ranks:
vector_rrf_score = vector_weight / (k + vector_ranks[item_id])
if item_id in fts_ranks:
fts_rrf_score = fts_weight / (k + fts_ranks[item_id])
combined_score = vector_rrf_score + fts_rrf_score
rrf_scores[item_id] = combined_score
score_metadata[item_id] = {
"combined_score": combined_score, # Final RRF score
"vector_rank": vector_ranks.get(item_id),
"fts_rank": fts_ranks.get(item_id),
}
# sort by RRF score and return with metadata
sorted_results = sorted(
[(all_items[iid], score, score_metadata[iid]) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True
)
return sorted_results[:top_k]
def _reciprocal_rank_fusion(
    self,
    vector_results: List[Tuple[PydanticPassage, float]],
    fts_results: List[Tuple[PydanticPassage, float]],
    vector_weight: float,
    fts_weight: float,
    top_k: int,
) -> List[Tuple[PydanticPassage, float]]:
    """Wrapper for backwards compatibility - uses generic RRF for passages.

    Keeps the legacy passage-specific signature (lists of (passage, score)
    tuples, no get_id_func) while delegating the actual fusion to the
    generic implementation keyed on passage IDs.

    Args:
        vector_results: (passage, score) tuples from vector search.
        fts_results: (passage, score) tuples from full-text search.
        vector_weight: Weight for the vector-search rank term.
        fts_weight: Weight for the FTS rank term.
        top_k: Maximum number of fused results to return.

    Returns:
        Whatever the generic RRF returns for these inputs.
        NOTE(review): the delegate's return shape must match this wrapper's
        declared List[Tuple[PydanticPassage, float]] — confirm against
        _generic_reciprocal_rank_fusion, which is defined elsewhere.
    """
    # delegate to the generic implementation, identifying passages by ID
    return self._generic_reciprocal_rank_fusion(
        vector_results=vector_results,
        fts_results=fts_results,
        get_id_func=lambda p: p.id,
        vector_weight=vector_weight,
        fts_weight=fts_weight,
        top_k=top_k,
    )
@trace_method
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
"""Delete a passage from Turbopuffer."""

View File

@@ -1,7 +1,7 @@
import json
import uuid
from datetime import datetime
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Tuple
from sqlalchemy import delete, exists, func, select, text
@@ -1065,7 +1065,7 @@ class MessageManager:
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
embedding_config: Optional[EmbeddingConfig] = None,
) -> List[PydanticMessage]:
) -> List[Tuple[PydanticMessage, dict]]:
"""
Search messages using Turbopuffer if enabled, otherwise fall back to SQL search.
@@ -1082,7 +1082,7 @@ class MessageManager:
embedding_config: Optional embedding configuration for generating query embedding
Returns:
List of matching messages
List of tuples (message, metadata) where metadata contains relevance scores
"""
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
@@ -1133,8 +1133,8 @@ class MessageManager:
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message as PydanticMessage
turbopuffer_messages = []
for msg_dict, score in results:
message_tuples = []
for msg_dict, score, metadata in results:
# create a message object with the properly extracted text from turbopuffer
message = PydanticMessage(
id=msg_dict["id"],
@@ -1146,9 +1146,10 @@ class MessageManager:
created_by_id=actor.id,
last_updated_by_id=actor.id,
)
turbopuffer_messages.append(message)
# Return tuple of (message, metadata)
message_tuples.append((message, metadata))
return turbopuffer_messages
return message_tuples
else:
return []
@@ -1163,7 +1164,16 @@ class MessageManager:
limit=limit,
ascending=False,
)
return self._combine_assistant_tool_messages(messages)
combined_messages = self._combine_assistant_tool_messages(messages)
# Add basic metadata for SQL fallback
message_tuples = []
for message in combined_messages:
metadata = {
"search_mode": "sql_fallback",
"combined_score": None, # SQL doesn't provide scores
}
message_tuples.append((message, metadata))
return message_tuples
else:
# use sql-based search
messages = await self.list_messages_for_agent_async(
@@ -1174,4 +1184,13 @@ class MessageManager:
limit=limit,
ascending=False,
)
return self._combine_assistant_tool_messages(messages)
combined_messages = self._combine_assistant_tool_messages(messages)
# Add basic metadata for SQL search
message_tuples = []
for message in combined_messages:
metadata = {
"search_mode": "sql",
"combined_score": None, # SQL doesn't provide scores
}
message_tuples.append((message, metadata))
return message_tuples

View File

@@ -155,7 +155,7 @@ class LettaCoreToolExecutor(ToolExecutor):
search_limit = limit if limit is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
# Search using the message manager's search_messages_async method
messages = await self.message_manager.search_messages_async(
message_results = await self.message_manager.search_messages_async(
agent_id=agent_state.id,
actor=actor,
query_text=query,
@@ -166,10 +166,10 @@ class LettaCoreToolExecutor(ToolExecutor):
embedding_config=agent_state.embedding_config,
)
if len(messages) == 0:
if len(message_results) == 0:
results_str = "No results found."
else:
results_pref = f"Showing {len(messages)} results:"
results_pref = f"Showing {len(message_results)} results:"
results_formatted = []
# get current time in UTC, then convert to agent timezone for consistent comparison
from datetime import timezone
@@ -184,7 +184,7 @@ class LettaCoreToolExecutor(ToolExecutor):
else:
now = now_utc
for message in messages:
for message, metadata in message_results:
# Format timestamp in agent's timezone if available
timestamp = message.created_at
time_delta_str = ""
@@ -229,6 +229,23 @@ class LettaCoreToolExecutor(ToolExecutor):
"role": message.role,
}
# Add search relevance metadata if available
if metadata:
# Only include non-None values
relevance_info = {
k: v
for k, v in {
"rrf_score": metadata.get("combined_score"),
"vector_rank": metadata.get("vector_rank"),
"fts_rank": metadata.get("fts_rank"),
"search_mode": metadata.get("search_mode"),
}.items()
if v is not None
}
if relevance_info: # Only add if we have metadata
result_dict["relevance"] = relevance_info
# _extract_message_text returns already JSON-encoded strings
# We need to parse them to get the actual content structure
if content:

View File

@@ -991,18 +991,19 @@ class TestTurbopufferMessagesIntegration:
vector_results = [(passage1, 0.9), (passage2, 0.7)]
fts_results = [(passage2, 0.8), (passage1, 0.6)]
# Test with passages using the wrapper function
# Test with passages using the RRF function
combined = client._reciprocal_rank_fusion(
vector_results=vector_results,
fts_results=fts_results,
vector_results=[passage for passage, _ in vector_results],
fts_results=[passage for passage, _ in fts_results],
get_id_func=lambda p: p.id,
vector_weight=0.5,
fts_weight=0.5,
top_k=2,
)
assert len(combined) == 2
# Both passages should be in results
result_ids = [p.id for p, _ in combined]
# Both passages should be in results - now returns (passage, score, metadata)
result_ids = [p.id for p, _, _ in combined]
assert p1_id in result_ids
assert p2_id in result_ids
@@ -1014,9 +1015,9 @@ class TestTurbopufferMessagesIntegration:
vector_msg_results = [(msg1, 0.95), (msg2, 0.85), (msg3, 0.75)]
fts_msg_results = [(msg2, 0.90), (msg3, 0.80), (msg1, 0.70)]
combined_msgs = client._generic_reciprocal_rank_fusion(
vector_results=vector_msg_results,
fts_results=fts_msg_results,
combined_msgs = client._reciprocal_rank_fusion(
vector_results=[msg for msg, _ in vector_msg_results],
fts_results=[msg for msg, _ in fts_msg_results],
get_id_func=lambda m: m["id"],
vector_weight=0.6,
fts_weight=0.4,
@@ -1024,14 +1025,14 @@ class TestTurbopufferMessagesIntegration:
)
assert len(combined_msgs) == 3
msg_ids = [m["id"] for m, _ in combined_msgs]
msg_ids = [m["id"] for m, _, _ in combined_msgs]
assert "m1" in msg_ids
assert "m2" in msg_ids
assert "m3" in msg_ids
# Test edge cases
# Empty results
empty_combined = client._generic_reciprocal_rank_fusion(
empty_combined = client._reciprocal_rank_fusion(
vector_results=[],
fts_results=[],
get_id_func=lambda x: x["id"],
@@ -1042,8 +1043,8 @@ class TestTurbopufferMessagesIntegration:
assert len(empty_combined) == 0
# Single result list
single_combined = client._generic_reciprocal_rank_fusion(
vector_results=[(msg1, 0.9)],
single_combined = client._reciprocal_rank_fusion(
vector_results=[msg1],
fts_results=[],
get_id_func=lambda m: m["id"],
vector_weight=0.5,
@@ -1104,7 +1105,7 @@ class TestTurbopufferMessagesIntegration:
assert len(results) == 3
# Results should be ordered by timestamp (most recent first)
for msg_dict, score in results:
for msg_dict, score, metadata in results:
assert msg_dict["agent_id"] == agent_id
assert msg_dict["organization_id"] == org_id
assert msg_dict["text"] in message_texts
@@ -1172,7 +1173,7 @@ class TestTurbopufferMessagesIntegration:
assert len(results) == 2
# Should return Python-related messages first
result_texts = [msg["text"] for msg, _ in results]
result_texts = [msg["text"] for msg, _, _ in results]
assert "Python is a great programming language" in result_texts
assert "Machine learning with Python is powerful" in result_texts
@@ -1242,7 +1243,7 @@ class TestTurbopufferMessagesIntegration:
assert len(results) > 0
# Should get a mix of results based on both vector and text similarity
result_texts = [msg["text"] for msg, _ in results]
result_texts = [msg["text"] for msg, _, _ in results]
# At least one result should contain "quick" due to FTS
assert any("quick" in text.lower() for text in result_texts)
@@ -1304,7 +1305,7 @@ class TestTurbopufferMessagesIntegration:
)
assert len(user_results) == 2
for msg, _ in user_results:
for msg, _, _ in user_results:
assert msg["role"] == "user"
assert msg["text"] in ["I need help with Python", "Can you explain this?"]
@@ -1318,7 +1319,7 @@ class TestTurbopufferMessagesIntegration:
)
assert len(non_user_results) == 3
for msg, _ in non_user_results:
for msg, _, _ in non_user_results:
assert msg["role"] in ["assistant", "system"]
finally:
@@ -1363,7 +1364,7 @@ class TestTurbopufferMessagesIntegration:
# Should return results from SQL search
assert len(results) > 0
# Extract text from messages and check for "fallback"
for msg in results:
for msg, metadata in results:
text = server.message_manager._extract_message_text(msg)
if "fallback" in text.lower():
break
@@ -1410,7 +1411,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
assert len(python_results) > 0
assert any(msg.id == message_id for msg in python_results)
assert any(msg.id == message_id for msg, metadata in python_results)
# Update the message content
updated_message = await server.message_manager.update_message_by_id_async(
@@ -1433,7 +1434,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
# Should either find no results or results that don't include our message
assert not any(msg.id == message_id for msg in python_results_after)
assert not any(msg.id == message_id for msg, metadata in python_results_after)
# Search for "JavaScript" - should find the updated message
js_results = await server.message_manager.search_messages_async(
@@ -1445,7 +1446,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
assert len(js_results) > 0
assert any(msg.id == message_id for msg in js_results)
assert any(msg.id == message_id for msg, metadata in js_results)
# Clean up
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
@@ -1561,7 +1562,7 @@ class TestTurbopufferMessagesIntegration:
)
assert len(agent_a_final) == 2
# Verify the remaining messages are the correct ones
remaining_ids = {msg.id for msg in agent_a_final}
remaining_ids = {msg.id for msg, metadata in agent_a_final}
assert agent_a_messages[3].id in remaining_ids
assert agent_a_messages[4].id in remaining_ids
@@ -1616,7 +1617,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
assert len(initial_search) > 0
assert any(msg.id == message_id for msg in initial_search)
assert any(msg.id == message_id for msg, metadata in initial_search)
# Update message WITHOUT embedding_config - should update postgres but not turbopuffer
updated_message = await server.message_manager.update_message_by_id_async(
@@ -1642,7 +1643,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
assert len(still_searchable) > 0
assert any(msg.id == message_id for msg in still_searchable)
assert any(msg.id == message_id for msg, metadata in still_searchable)
# New content should NOT be searchable (wasn't re-indexed)
not_searchable = await server.message_manager.search_messages_async(
@@ -1654,7 +1655,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
# Should either find no results or results that don't include our message
assert not any(msg.id == message_id for msg in not_searchable)
assert not any(msg.id == message_id for msg, metadata in not_searchable)
# Clean up
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
@@ -1789,7 +1790,7 @@ class TestTurbopufferMessagesIntegration:
# Should get today's and yesterday's messages
assert len(recent_results) == 2
result_texts = [msg["text"] for msg, _ in recent_results]
result_texts = [msg["text"] for msg, _, _ in recent_results]
assert "Today's message" in result_texts
assert "Yesterday's message" in result_texts
@@ -1820,7 +1821,7 @@ class TestTurbopufferMessagesIntegration:
# Should get only recent messages
assert len(filtered_vector_results) == 2
for msg, _ in filtered_vector_results:
for msg, _, _ in filtered_vector_results:
assert msg["text"] in ["Today's message", "Yesterday's message"]
finally: