fix: change to pure rank-based RRF for relevance ordering (#4411)

* Fix RRF

* Fix turbopuffer tests
This commit is contained in:
Matthew Zhou
2025-09-03 17:33:19 -07:00
committed by GitHub
parent fc50a41680
commit d924cc005b
5 changed files with 174 additions and 120 deletions

View File

@@ -991,18 +991,19 @@ class TestTurbopufferMessagesIntegration:
vector_results = [(passage1, 0.9), (passage2, 0.7)]
fts_results = [(passage2, 0.8), (passage1, 0.6)]
# Test with passages using the wrapper function
# Test with passages using the RRF function
combined = client._reciprocal_rank_fusion(
vector_results=vector_results,
fts_results=fts_results,
vector_results=[passage for passage, _ in vector_results],
fts_results=[passage for passage, _ in fts_results],
get_id_func=lambda p: p.id,
vector_weight=0.5,
fts_weight=0.5,
top_k=2,
)
assert len(combined) == 2
# Both passages should be in results
result_ids = [p.id for p, _ in combined]
# Both passages should be in results - now returns (passage, score, metadata)
result_ids = [p.id for p, _, _ in combined]
assert p1_id in result_ids
assert p2_id in result_ids
@@ -1014,9 +1015,9 @@ class TestTurbopufferMessagesIntegration:
vector_msg_results = [(msg1, 0.95), (msg2, 0.85), (msg3, 0.75)]
fts_msg_results = [(msg2, 0.90), (msg3, 0.80), (msg1, 0.70)]
combined_msgs = client._generic_reciprocal_rank_fusion(
vector_results=vector_msg_results,
fts_results=fts_msg_results,
combined_msgs = client._reciprocal_rank_fusion(
vector_results=[msg for msg, _ in vector_msg_results],
fts_results=[msg for msg, _ in fts_msg_results],
get_id_func=lambda m: m["id"],
vector_weight=0.6,
fts_weight=0.4,
@@ -1024,14 +1025,14 @@ class TestTurbopufferMessagesIntegration:
)
assert len(combined_msgs) == 3
msg_ids = [m["id"] for m, _ in combined_msgs]
msg_ids = [m["id"] for m, _, _ in combined_msgs]
assert "m1" in msg_ids
assert "m2" in msg_ids
assert "m3" in msg_ids
# Test edge cases
# Empty results
empty_combined = client._generic_reciprocal_rank_fusion(
empty_combined = client._reciprocal_rank_fusion(
vector_results=[],
fts_results=[],
get_id_func=lambda x: x["id"],
@@ -1042,8 +1043,8 @@ class TestTurbopufferMessagesIntegration:
assert len(empty_combined) == 0
# Single result list
single_combined = client._generic_reciprocal_rank_fusion(
vector_results=[(msg1, 0.9)],
single_combined = client._reciprocal_rank_fusion(
vector_results=[msg1],
fts_results=[],
get_id_func=lambda m: m["id"],
vector_weight=0.5,
@@ -1104,7 +1105,7 @@ class TestTurbopufferMessagesIntegration:
assert len(results) == 3
# Results should be ordered by timestamp (most recent first)
for msg_dict, score in results:
for msg_dict, score, metadata in results:
assert msg_dict["agent_id"] == agent_id
assert msg_dict["organization_id"] == org_id
assert msg_dict["text"] in message_texts
@@ -1172,7 +1173,7 @@ class TestTurbopufferMessagesIntegration:
assert len(results) == 2
# Should return Python-related messages first
result_texts = [msg["text"] for msg, _ in results]
result_texts = [msg["text"] for msg, _, _ in results]
assert "Python is a great programming language" in result_texts
assert "Machine learning with Python is powerful" in result_texts
@@ -1242,7 +1243,7 @@ class TestTurbopufferMessagesIntegration:
assert len(results) > 0
# Should get a mix of results based on both vector and text similarity
result_texts = [msg["text"] for msg, _ in results]
result_texts = [msg["text"] for msg, _, _ in results]
# At least one result should contain "quick" due to FTS
assert any("quick" in text.lower() for text in result_texts)
@@ -1304,7 +1305,7 @@ class TestTurbopufferMessagesIntegration:
)
assert len(user_results) == 2
for msg, _ in user_results:
for msg, _, _ in user_results:
assert msg["role"] == "user"
assert msg["text"] in ["I need help with Python", "Can you explain this?"]
@@ -1318,7 +1319,7 @@ class TestTurbopufferMessagesIntegration:
)
assert len(non_user_results) == 3
for msg, _ in non_user_results:
for msg, _, _ in non_user_results:
assert msg["role"] in ["assistant", "system"]
finally:
@@ -1363,7 +1364,7 @@ class TestTurbopufferMessagesIntegration:
# Should return results from SQL search
assert len(results) > 0
# Extract text from messages and check for "fallback"
for msg in results:
for msg, metadata in results:
text = server.message_manager._extract_message_text(msg)
if "fallback" in text.lower():
break
@@ -1410,7 +1411,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
assert len(python_results) > 0
assert any(msg.id == message_id for msg in python_results)
assert any(msg.id == message_id for msg, metadata in python_results)
# Update the message content
updated_message = await server.message_manager.update_message_by_id_async(
@@ -1433,7 +1434,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
# Should either find no results or results that don't include our message
assert not any(msg.id == message_id for msg in python_results_after)
assert not any(msg.id == message_id for msg, metadata in python_results_after)
# Search for "JavaScript" - should find the updated message
js_results = await server.message_manager.search_messages_async(
@@ -1445,7 +1446,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
assert len(js_results) > 0
assert any(msg.id == message_id for msg in js_results)
assert any(msg.id == message_id for msg, metadata in js_results)
# Clean up
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
@@ -1561,7 +1562,7 @@ class TestTurbopufferMessagesIntegration:
)
assert len(agent_a_final) == 2
# Verify the remaining messages are the correct ones
remaining_ids = {msg.id for msg in agent_a_final}
remaining_ids = {msg.id for msg, metadata in agent_a_final}
assert agent_a_messages[3].id in remaining_ids
assert agent_a_messages[4].id in remaining_ids
@@ -1616,7 +1617,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
assert len(initial_search) > 0
assert any(msg.id == message_id for msg in initial_search)
assert any(msg.id == message_id for msg, metadata in initial_search)
# Update message WITHOUT embedding_config - should update postgres but not turbopuffer
updated_message = await server.message_manager.update_message_by_id_async(
@@ -1642,7 +1643,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
assert len(still_searchable) > 0
assert any(msg.id == message_id for msg in still_searchable)
assert any(msg.id == message_id for msg, metadata in still_searchable)
# New content should NOT be searchable (wasn't re-indexed)
not_searchable = await server.message_manager.search_messages_async(
@@ -1654,7 +1655,7 @@ class TestTurbopufferMessagesIntegration:
embedding_config=embedding_config,
)
# Should either find no results or results that don't include our message
assert not any(msg.id == message_id for msg in not_searchable)
assert not any(msg.id == message_id for msg, metadata in not_searchable)
# Clean up
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
@@ -1789,7 +1790,7 @@ class TestTurbopufferMessagesIntegration:
# Should get today's and yesterday's messages
assert len(recent_results) == 2
result_texts = [msg["text"] for msg, _ in recent_results]
result_texts = [msg["text"] for msg, _, _ in recent_results]
assert "Today's message" in result_texts
assert "Yesterday's message" in result_texts
@@ -1820,7 +1821,7 @@ class TestTurbopufferMessagesIntegration:
# Should get only recent messages
assert len(filtered_vector_results) == 2
for msg, _ in filtered_vector_results:
for msg, _, _ in filtered_vector_results:
assert msg["text"] in ["Today's message", "Yesterday's message"]
finally: