fix: change to pure rank-based RRF for relevance ordering (#4411)
* Fix RRF
* Fix turbopuffer tests
This commit is contained in:
@@ -991,18 +991,19 @@ class TestTurbopufferMessagesIntegration:
|
||||
vector_results = [(passage1, 0.9), (passage2, 0.7)]
|
||||
fts_results = [(passage2, 0.8), (passage1, 0.6)]
|
||||
|
||||
# Test with passages using the wrapper function
|
||||
# Test with passages using the RRF function
|
||||
combined = client._reciprocal_rank_fusion(
|
||||
vector_results=vector_results,
|
||||
fts_results=fts_results,
|
||||
vector_results=[passage for passage, _ in vector_results],
|
||||
fts_results=[passage for passage, _ in fts_results],
|
||||
get_id_func=lambda p: p.id,
|
||||
vector_weight=0.5,
|
||||
fts_weight=0.5,
|
||||
top_k=2,
|
||||
)
|
||||
|
||||
assert len(combined) == 2
|
||||
# Both passages should be in results
|
||||
result_ids = [p.id for p, _ in combined]
|
||||
# Both passages should be in results - now returns (passage, score, metadata)
|
||||
result_ids = [p.id for p, _, _ in combined]
|
||||
assert p1_id in result_ids
|
||||
assert p2_id in result_ids
|
||||
|
||||
@@ -1014,9 +1015,9 @@ class TestTurbopufferMessagesIntegration:
|
||||
vector_msg_results = [(msg1, 0.95), (msg2, 0.85), (msg3, 0.75)]
|
||||
fts_msg_results = [(msg2, 0.90), (msg3, 0.80), (msg1, 0.70)]
|
||||
|
||||
combined_msgs = client._generic_reciprocal_rank_fusion(
|
||||
vector_results=vector_msg_results,
|
||||
fts_results=fts_msg_results,
|
||||
combined_msgs = client._reciprocal_rank_fusion(
|
||||
vector_results=[msg for msg, _ in vector_msg_results],
|
||||
fts_results=[msg for msg, _ in fts_msg_results],
|
||||
get_id_func=lambda m: m["id"],
|
||||
vector_weight=0.6,
|
||||
fts_weight=0.4,
|
||||
@@ -1024,14 +1025,14 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
|
||||
assert len(combined_msgs) == 3
|
||||
msg_ids = [m["id"] for m, _ in combined_msgs]
|
||||
msg_ids = [m["id"] for m, _, _ in combined_msgs]
|
||||
assert "m1" in msg_ids
|
||||
assert "m2" in msg_ids
|
||||
assert "m3" in msg_ids
|
||||
|
||||
# Test edge cases
|
||||
# Empty results
|
||||
empty_combined = client._generic_reciprocal_rank_fusion(
|
||||
empty_combined = client._reciprocal_rank_fusion(
|
||||
vector_results=[],
|
||||
fts_results=[],
|
||||
get_id_func=lambda x: x["id"],
|
||||
@@ -1042,8 +1043,8 @@ class TestTurbopufferMessagesIntegration:
|
||||
assert len(empty_combined) == 0
|
||||
|
||||
# Single result list
|
||||
single_combined = client._generic_reciprocal_rank_fusion(
|
||||
vector_results=[(msg1, 0.9)],
|
||||
single_combined = client._reciprocal_rank_fusion(
|
||||
vector_results=[msg1],
|
||||
fts_results=[],
|
||||
get_id_func=lambda m: m["id"],
|
||||
vector_weight=0.5,
|
||||
@@ -1104,7 +1105,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
assert len(results) == 3
|
||||
# Results should be ordered by timestamp (most recent first)
|
||||
for msg_dict, score in results:
|
||||
for msg_dict, score, metadata in results:
|
||||
assert msg_dict["agent_id"] == agent_id
|
||||
assert msg_dict["organization_id"] == org_id
|
||||
assert msg_dict["text"] in message_texts
|
||||
@@ -1172,7 +1173,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
assert len(results) == 2
|
||||
# Should return Python-related messages first
|
||||
result_texts = [msg["text"] for msg, _ in results]
|
||||
result_texts = [msg["text"] for msg, _, _ in results]
|
||||
assert "Python is a great programming language" in result_texts
|
||||
assert "Machine learning with Python is powerful" in result_texts
|
||||
|
||||
@@ -1242,7 +1243,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
assert len(results) > 0
|
||||
# Should get a mix of results based on both vector and text similarity
|
||||
result_texts = [msg["text"] for msg, _ in results]
|
||||
result_texts = [msg["text"] for msg, _, _ in results]
|
||||
# At least one result should contain "quick" due to FTS
|
||||
assert any("quick" in text.lower() for text in result_texts)
|
||||
|
||||
@@ -1304,7 +1305,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
|
||||
assert len(user_results) == 2
|
||||
for msg, _ in user_results:
|
||||
for msg, _, _ in user_results:
|
||||
assert msg["role"] == "user"
|
||||
assert msg["text"] in ["I need help with Python", "Can you explain this?"]
|
||||
|
||||
@@ -1318,7 +1319,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
|
||||
assert len(non_user_results) == 3
|
||||
for msg, _ in non_user_results:
|
||||
for msg, _, _ in non_user_results:
|
||||
assert msg["role"] in ["assistant", "system"]
|
||||
|
||||
finally:
|
||||
@@ -1363,7 +1364,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
# Should return results from SQL search
|
||||
assert len(results) > 0
|
||||
# Extract text from messages and check for "fallback"
|
||||
for msg in results:
|
||||
for msg, metadata in results:
|
||||
text = server.message_manager._extract_message_text(msg)
|
||||
if "fallback" in text.lower():
|
||||
break
|
||||
@@ -1410,7 +1411,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(python_results) > 0
|
||||
assert any(msg.id == message_id for msg in python_results)
|
||||
assert any(msg.id == message_id for msg, metadata in python_results)
|
||||
|
||||
# Update the message content
|
||||
updated_message = await server.message_manager.update_message_by_id_async(
|
||||
@@ -1433,7 +1434,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# Should either find no results or results that don't include our message
|
||||
assert not any(msg.id == message_id for msg in python_results_after)
|
||||
assert not any(msg.id == message_id for msg, metadata in python_results_after)
|
||||
|
||||
# Search for "JavaScript" - should find the updated message
|
||||
js_results = await server.message_manager.search_messages_async(
|
||||
@@ -1445,7 +1446,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(js_results) > 0
|
||||
assert any(msg.id == message_id for msg in js_results)
|
||||
assert any(msg.id == message_id for msg, metadata in js_results)
|
||||
|
||||
# Clean up
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
|
||||
@@ -1561,7 +1562,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
assert len(agent_a_final) == 2
|
||||
# Verify the remaining messages are the correct ones
|
||||
remaining_ids = {msg.id for msg in agent_a_final}
|
||||
remaining_ids = {msg.id for msg, metadata in agent_a_final}
|
||||
assert agent_a_messages[3].id in remaining_ids
|
||||
assert agent_a_messages[4].id in remaining_ids
|
||||
|
||||
@@ -1616,7 +1617,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(initial_search) > 0
|
||||
assert any(msg.id == message_id for msg in initial_search)
|
||||
assert any(msg.id == message_id for msg, metadata in initial_search)
|
||||
|
||||
# Update message WITHOUT embedding_config - should update postgres but not turbopuffer
|
||||
updated_message = await server.message_manager.update_message_by_id_async(
|
||||
@@ -1642,7 +1643,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(still_searchable) > 0
|
||||
assert any(msg.id == message_id for msg in still_searchable)
|
||||
assert any(msg.id == message_id for msg, metadata in still_searchable)
|
||||
|
||||
# New content should NOT be searchable (wasn't re-indexed)
|
||||
not_searchable = await server.message_manager.search_messages_async(
|
||||
@@ -1654,7 +1655,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# Should either find no results or results that don't include our message
|
||||
assert not any(msg.id == message_id for msg in not_searchable)
|
||||
assert not any(msg.id == message_id for msg, metadata in not_searchable)
|
||||
|
||||
# Clean up
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
|
||||
@@ -1789,7 +1790,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
# Should get today's and yesterday's messages
|
||||
assert len(recent_results) == 2
|
||||
result_texts = [msg["text"] for msg, _ in recent_results]
|
||||
result_texts = [msg["text"] for msg, _, _ in recent_results]
|
||||
assert "Today's message" in result_texts
|
||||
assert "Yesterday's message" in result_texts
|
||||
|
||||
@@ -1820,7 +1821,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
# Should get only recent messages
|
||||
assert len(filtered_vector_results) == 2
|
||||
for msg, _ in filtered_vector_results:
|
||||
for msg, _, _ in filtered_vector_results:
|
||||
assert msg["text"] in ["Today's message", "Yesterday's message"]
|
||||
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user