feat: Add hybrid search on turbopuffer [LET-4096] (#4284)
Add hybrid search
This commit is contained in:
@@ -145,7 +145,11 @@ class TurbopufferClient:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
# turbopuffer recommends column-based writes for performance
|
||||
await namespace.write(upsert_columns=upsert_columns, distance_metric="cosine_distance")
|
||||
await namespace.write(
|
||||
upsert_columns=upsert_columns,
|
||||
distance_metric="cosine_distance",
|
||||
schema={"text": {"type": "string", "full_text_search": True}},
|
||||
)
|
||||
logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}")
|
||||
return passages
|
||||
|
||||
@@ -158,10 +162,44 @@ class TurbopufferClient:
|
||||
|
||||
@trace_method
|
||||
async def query_passages(
|
||||
self, archive_id: str, query_embedding: List[float], top_k: int = 10, filters: Optional[Dict[str, Any]] = None
|
||||
self,
|
||||
archive_id: str,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
||||
top_k: int = 10,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
) -> List[Tuple[PydanticPassage, float]]:
|
||||
"""Query passages from Turbopuffer."""
|
||||
"""Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
|
||||
|
||||
Args:
|
||||
archive_id: ID of the archive
|
||||
query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
|
||||
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
|
||||
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
|
||||
top_k: Number of results to return
|
||||
filters: Optional filter conditions
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
|
||||
Returns:
|
||||
List of (passage, score) tuples
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
from turbopuffer.types import QueryParam
|
||||
|
||||
# validate inputs based on search mode
|
||||
if search_mode == "vector" and query_embedding is None:
|
||||
raise ValueError("query_embedding is required for vector search mode")
|
||||
if search_mode == "fts" and query_text is None:
|
||||
raise ValueError("query_text is required for FTS search mode")
|
||||
if search_mode == "hybrid":
|
||||
if query_embedding is None or query_text is None:
|
||||
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
|
||||
if search_mode not in ["vector", "fts", "hybrid"]:
|
||||
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', or 'hybrid'")
|
||||
|
||||
namespace_name = self._get_namespace_name(archive_id)
|
||||
|
||||
@@ -175,49 +213,150 @@ class TurbopufferClient:
|
||||
for key, value in filters.items():
|
||||
filter_conditions.append((key, "Eq", value))
|
||||
|
||||
query_params = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
}
|
||||
base_filter = (
|
||||
("And", filter_conditions) if len(filter_conditions) > 1 else (filter_conditions[0] if filter_conditions else None)
|
||||
)
|
||||
|
||||
if filter_conditions:
|
||||
query_params["filters"] = ("And", filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0]
|
||||
if search_mode == "vector":
|
||||
# single vector search query
|
||||
query_params = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
}
|
||||
if base_filter:
|
||||
query_params["filters"] = base_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, filters)
|
||||
|
||||
# convert results back to passages
|
||||
passages_with_scores = []
|
||||
# Turbopuffer returns a NamespaceQueryResponse with a rows attribute
|
||||
for row in result.rows:
|
||||
# Build metadata including any filter conditions that were applied
|
||||
metadata = {}
|
||||
if filters:
|
||||
metadata["applied_filters"] = filters
|
||||
elif search_mode == "fts":
|
||||
# single full-text search query
|
||||
query_params = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
}
|
||||
if base_filter:
|
||||
query_params["filters"] = base_filter
|
||||
|
||||
# Create a passage with minimal fields - embeddings are not returned from Turbopuffer
|
||||
passage = PydanticPassage(
|
||||
id=row.id,
|
||||
text=getattr(row, "text", ""),
|
||||
organization_id=getattr(row, "organization_id", None),
|
||||
archive_id=archive_id, # use the archive_id from the query
|
||||
created_at=getattr(row, "created_at", None),
|
||||
metadata_=metadata, # Include filter conditions in metadata
|
||||
# Set required fields to empty/default values since we don't store embeddings
|
||||
embedding=[], # Empty embedding since we don't return it from Turbopuffer
|
||||
embedding_config=None, # No embedding config needed for retrieved passages
|
||||
)
|
||||
# turbopuffer returns distance in $dist attribute, convert to similarity score
|
||||
distance = getattr(row, "$dist", 0.0)
|
||||
score = 1.0 - distance
|
||||
passages_with_scores.append((passage, score))
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, filters, is_fts=True)
|
||||
|
||||
return passages_with_scores
|
||||
else: # hybrid mode
|
||||
# multi-query for both vector and FTS
|
||||
queries = []
|
||||
|
||||
# vector search query
|
||||
vector_query = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
}
|
||||
if base_filter:
|
||||
vector_query["filters"] = base_filter
|
||||
queries.append(vector_query)
|
||||
|
||||
# full-text search query
|
||||
fts_query = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
}
|
||||
if base_filter:
|
||||
fts_query["filters"] = base_filter
|
||||
queries.append(fts_query)
|
||||
|
||||
# execute multi-query
|
||||
response = await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
|
||||
|
||||
# process and combine results using reciprocal rank fusion
|
||||
vector_results = self._process_single_query_results(response.results[0], archive_id, filters)
|
||||
fts_results = self._process_single_query_results(response.results[1], archive_id, filters, is_fts=True)
|
||||
|
||||
# combine results using reciprocal rank fusion
|
||||
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query passages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
def _process_single_query_results(
    self, result, archive_id: str, filters: Optional[Dict[str, Any]], is_fts: bool = False
) -> List[Tuple[PydanticPassage, float]]:
    """Turn the rows of one Turbopuffer query response into scored passages.

    Args:
        result: Query response; only its ``rows`` attribute is read.
        archive_id: Archive the query ran against; stamped onto every passage.
        filters: Filters supplied by the caller. When truthy they are recorded
            under ``metadata_["applied_filters"]`` of each passage.
        is_fts: True when the rows come from a full-text (BM25) query.

    Returns:
        List of (passage, score) tuples in the same order as ``result.rows``.
    """

    def relevance(hit) -> float:
        # FTS rows carry a BM25 score in "$score" (higher is better); vector
        # rows carry a distance in "$dist", converted to a similarity score.
        if is_fts:
            return getattr(hit, "$score", 0.0)
        return 1.0 - getattr(hit, "$dist", 0.0)

    scored_passages = []
    for hit in result.rows:
        # a fresh metadata dict per passage; it notes the filters that were applied
        passage_metadata = {"applied_filters": filters} if filters else {}

        # embeddings are not returned from Turbopuffer, so the embedding
        # fields are filled with empty placeholder values
        passage = PydanticPassage(
            id=hit.id,
            text=getattr(hit, "text", ""),
            organization_id=getattr(hit, "organization_id", None),
            archive_id=archive_id,  # taken from the query, not from the row
            created_at=getattr(hit, "created_at", None),
            metadata_=passage_metadata,
            embedding=[],
            embedding_config=None,
        )
        scored_passages.append((passage, relevance(hit)))

    return scored_passages
|
||||
|
||||
def _reciprocal_rank_fusion(
    self,
    vector_results: List[Tuple[PydanticPassage, float]],
    fts_results: List[Tuple[PydanticPassage, float]],
    vector_weight: float,
    fts_weight: float,
    top_k: int,
) -> List[Tuple[PydanticPassage, float]]:
    """Merge vector and FTS result lists with weighted Reciprocal Rank Fusion.

    Each passage is scored as

        vector_weight / (k + vector_rank) + fts_weight / (k + fts_rank)

    with 1-based ranks and k = 60. The incoming per-list scores are ignored;
    only list positions matter. A passage absent from one list is given the
    rank ``k + top_k`` for that list, so it still receives a small
    contribution from it.

    Args:
        vector_results: (passage, score) tuples from the vector query, best first.
        fts_results: (passage, score) tuples from the full-text query, best first.
        vector_weight: Multiplier for the vector-rank term.
        fts_weight: Multiplier for the FTS-rank term.
        top_k: Maximum number of fused results to return.

    Returns:
        Up to ``top_k`` (passage, rrf_score) tuples, highest score first.
    """
    rrf_k = 60  # standard RRF constant
    absent_rank = rrf_k + top_k  # rank used when a passage is missing from a list

    # Record each passage's 1-based position per list while collecting the
    # union of passages (vector hits first, then FTS hits).
    passages_by_id = {}
    vector_rank_of = {}
    for position, (passage, _) in enumerate(vector_results, start=1):
        vector_rank_of[passage.id] = position
        passages_by_id[passage.id] = passage

    fts_rank_of = {}
    for position, (passage, _) in enumerate(fts_results, start=1):
        fts_rank_of[passage.id] = position
        passages_by_id[passage.id] = passage

    fused = []
    for passage_id, passage in passages_by_id.items():
        from_vector = vector_weight / (rrf_k + vector_rank_of.get(passage_id, absent_rank))
        from_fts = fts_weight / (rrf_k + fts_rank_of.get(passage_id, absent_rank))
        fused.append((passage, from_vector + from_fts))

    # stable descending sort, then keep the best top_k
    fused.sort(key=lambda pair: pair[1], reverse=True)
    return fused[:top_k]
|
||||
|
||||
@trace_method
|
||||
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
|
||||
"""Delete a passage from Turbopuffer."""
|
||||
|
||||
@@ -2671,11 +2671,14 @@ class AgentManager:
|
||||
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
|
||||
query_embedding = embeddings[0]
|
||||
|
||||
# Query Turbopuffer
|
||||
# Query Turbopuffer - use hybrid search when text is available
|
||||
tpuf_client = TurbopufferClient()
|
||||
# use hybrid search to combine vector and full-text search
|
||||
passages_with_scores = await tpuf_client.query_passages(
|
||||
archive_id=archive_ids[0],
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text, # pass text for potential hybrid search
|
||||
search_mode="hybrid", # use hybrid mode for better results
|
||||
top_k=limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -289,6 +289,111 @@ class TestTurbopufferIntegration:
|
||||
# Should still work with native PostgreSQL
|
||||
assert isinstance(vector_results, list)
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
async def test_hybrid_search_with_real_tpuf(self, enable_turbopuffer):
    """Exercise vector, FTS and hybrid query modes against a real Turbopuffer namespace.

    Inserts five passages into a throwaway archive, runs each search mode,
    checks result counts, score types and ordering, verifies the input
    validation errors, and finally deletes the archive's passages.
    """

    import uuid

    from letta.helpers.tpuf_client import TurbopufferClient

    client = TurbopufferClient()
    archive_id = f"test-hybrid-{datetime.now().timestamp()}"
    org_id = str(uuid.uuid4())

    try:
        # Insert test passages with different characteristics
        texts = [
            "Turbopuffer is a vector database optimized for high-performance similarity search",
            "The quick brown fox jumps over the lazy dog",
            "Machine learning models require vector embeddings for semantic search",
            "Database optimization techniques improve query performance",
            "Turbopuffer supports both vector and full-text search capabilities",
        ]

        # Create simple embeddings for testing (normally you'd use a real embedding model)
        embeddings = [[float(i), float(i + 5), float(i + 10)] for i in range(len(texts))]
        passage_ids = [f"passage-{uuid.uuid4()}" for _ in texts]

        # Insert passages
        await client.insert_archival_memories(
            archive_id=archive_id, text_chunks=texts, embeddings=embeddings, passage_ids=passage_ids, organization_id=org_id
        )

        # Test vector-only search
        vector_results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[1.0, 6.0, 11.0],  # identical to the second passage's embedding
            search_mode="vector",
            top_k=3,
        )
        assert 0 < len(vector_results) <= 3
        # all results should have scores
        assert all(isinstance(score, float) for _, score in vector_results)

        # Test FTS-only search
        fts_results = await client.query_passages(
            archive_id=archive_id, query_text="Turbopuffer vector database", search_mode="fts", top_k=3
        )
        assert 0 < len(fts_results) <= 3
        # should find passages mentioning Turbopuffer
        assert any("Turbopuffer" in passage.text for passage, _ in fts_results)
        # all results should have scores
        assert all(isinstance(score, float) for _, score in fts_results)

        # Test hybrid search
        hybrid_results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[2.0, 7.0, 12.0],
            query_text="vector search Turbopuffer",
            search_mode="hybrid",
            top_k=3,
            vector_weight=0.5,
            fts_weight=0.5,
        )
        assert 0 < len(hybrid_results) <= 3
        # hybrid should combine both vector and text relevance
        assert any("Turbopuffer" in passage.text or "vector" in passage.text for passage, _ in hybrid_results)
        # all results should have scores
        assert all(isinstance(score, float) for _, score in hybrid_results)
        # results should be sorted by score (highest first)
        scores = [score for _, score in hybrid_results]
        assert scores == sorted(scores, reverse=True)

        # Test with different weights
        vector_heavy_results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[0.0, 5.0, 10.0],  # identical to the first passage's embedding
            query_text="quick brown fox",  # matches second passage
            search_mode="hybrid",
            top_k=3,
            vector_weight=0.8,  # emphasize vector search
            fts_weight=0.2,
        )
        assert 0 < len(vector_heavy_results) <= 3
        # all results should have scores
        assert all(isinstance(score, float) for _, score in vector_heavy_results)

        # Test error handling - missing embedding for vector mode
        with pytest.raises(ValueError, match="query_embedding is required"):
            await client.query_passages(archive_id=archive_id, search_mode="vector", top_k=3)

        # Test error handling - missing text for FTS mode
        with pytest.raises(ValueError, match="query_text is required"):
            await client.query_passages(archive_id=archive_id, search_mode="fts", top_k=3)

        # Test error handling - hybrid mode given an embedding but no query_text
        with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
            await client.query_passages(archive_id=archive_id, query_embedding=[1.0, 2.0, 3.0], search_mode="hybrid", top_k=3)

    finally:
        # Clean up (best effort). Catch Exception rather than using a bare
        # except so KeyboardInterrupt, SystemExit and asyncio cancellation
        # are not silently swallowed.
        try:
            await client.delete_all_passages(archive_id)
        except Exception:
            pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
|
||||
class TestTurbopufferParametrized:
|
||||
|
||||
Reference in New Issue
Block a user