Files
letta-server/letta/helpers/tpuf_client.py
2025-08-30 19:31:07 -07:00

466 lines
20 KiB
Python

"""Turbopuffer utilities for archival memory storage."""
import logging
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from letta.otel.tracing import trace_method
from letta.schemas.enums import TagMatchMode
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings
logger = logging.getLogger(__name__)
def should_use_tpuf() -> bool:
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)
class TurbopufferClient:
"""Client for managing archival memory with Turbopuffer vector database."""
def __init__(self, api_key: str = None, region: str = None):
"""Initialize Turbopuffer client."""
self.api_key = api_key or settings.tpuf_api_key
self.region = region or settings.tpuf_region
if not self.api_key:
raise ValueError("Turbopuffer API key not provided")
@trace_method
def _get_namespace_name(self, archive_id: str) -> str:
"""Get namespace name for a specific archive."""
# use archive_id as namespace to isolate different archives' memories
# append environment suffix to namespace for isolation if environment is set
environment = settings.environment
if environment:
namespace_name = f"{archive_id}_{environment.lower()}"
else:
namespace_name = archive_id
return namespace_name
@trace_method
async def insert_archival_memories(
self,
archive_id: str,
text_chunks: List[str],
embeddings: List[List[float]],
passage_ids: List[str],
organization_id: str,
tags: Optional[List[str]] = None,
created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
"""Insert passages into Turbopuffer.
Args:
archive_id: ID of the archive
text_chunks: List of text chunks to store
embeddings: List of embedding vectors corresponding to text chunks
passage_ids: List of passage IDs (must match 1:1 with text_chunks)
organization_id: Organization ID for the passages
tags: Optional list of tags to attach to all passages
created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
Returns:
List of PydanticPassage objects that were inserted
"""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
# handle timestamp - ensure UTC
if created_at is None:
timestamp = datetime.now(timezone.utc)
else:
# ensure the provided timestamp is timezone-aware and in UTC
if created_at.tzinfo is None:
# assume UTC if no timezone provided
timestamp = created_at.replace(tzinfo=timezone.utc)
else:
# convert to UTC if in different timezone
timestamp = created_at.astimezone(timezone.utc)
# passage_ids must be provided for dual-write consistency
if not passage_ids:
raise ValueError("passage_ids must be provided for Turbopuffer insertion")
if len(passage_ids) != len(text_chunks):
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
if len(passage_ids) != len(embeddings):
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})")
# prepare column-based data for turbopuffer - optimized for batch insert
ids = []
vectors = []
texts = []
organization_ids = []
archive_ids = []
created_ats = []
tags_arrays = [] # Store tags as arrays
passages = []
for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
passage_id = passage_ids[idx]
# append to columns
ids.append(passage_id)
vectors.append(embedding)
texts.append(text)
organization_ids.append(organization_id)
archive_ids.append(archive_id)
created_ats.append(timestamp)
tags_arrays.append(tags or []) # Store tags as array
# Create PydanticPassage object
passage = PydanticPassage(
id=passage_id,
text=text,
organization_id=organization_id,
archive_id=archive_id,
created_at=timestamp,
metadata_={},
tags=tags or [], # Include tags in the passage
embedding=embedding,
embedding_config=None, # Will be set by caller if needed
)
passages.append(passage)
# build column-based upsert data
upsert_columns = {
"id": ids,
"vector": vectors,
"text": texts,
"organization_id": organization_ids,
"archive_id": archive_ids,
"created_at": created_ats,
"tags": tags_arrays, # Add tags as array column
}
try:
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# turbopuffer recommends column-based writes for performance
await namespace.write(
upsert_columns=upsert_columns,
distance_metric="cosine_distance",
schema={"text": {"type": "string", "full_text_search": True}},
)
logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}")
return passages
except Exception as e:
logger.error(f"Failed to insert passages to Turbopuffer: {e}")
# check if it's a duplicate ID error
if "duplicate" in str(e).lower():
logger.error("Duplicate passage IDs detected in batch")
raise
@trace_method
async def query_passages(
self,
archive_id: str,
query_embedding: Optional[List[float]] = None,
query_text: Optional[str] = None,
search_mode: str = "vector", # "vector", "fts", "hybrid"
top_k: int = 10,
tags: Optional[List[str]] = None,
tag_match_mode: TagMatchMode = TagMatchMode.ANY,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> List[Tuple[PydanticPassage, float]]:
"""Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
Args:
archive_id: ID of the archive
query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
top_k: Number of results to return
tags: Optional list of tags to filter by
tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter passages created after this date
end_date: Optional datetime to filter passages created before this date
Returns:
List of (passage, score) tuples
"""
from turbopuffer import AsyncTurbopuffer
from turbopuffer.types import QueryParam
# validate inputs based on search mode first
if search_mode == "vector" and query_embedding is None:
raise ValueError("query_embedding is required for vector search mode")
if search_mode == "fts" and query_text is None:
raise ValueError("query_text is required for FTS search mode")
if search_mode == "hybrid":
if query_embedding is None or query_text is None:
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
# Check if we should fallback to timestamp-based retrieval
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
# Fallback to retrieving most recent passages when no search query is provided
search_mode = "timestamp"
namespace_name = self._get_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# build tag filter conditions
tag_filter = None
if tags:
if tag_match_mode == TagMatchMode.ALL:
# For ALL mode, need to check each tag individually with Contains
tag_conditions = []
for tag in tags:
tag_conditions.append(("tags", "Contains", tag))
if len(tag_conditions) == 1:
tag_filter = tag_conditions[0]
else:
tag_filter = ("And", tag_conditions)
else: # tag_match_mode == TagMatchMode.ANY
# For ANY mode, use ContainsAny to match any of the tags
tag_filter = ("tags", "ContainsAny", tags)
# build date filter conditions
date_filters = []
if start_date:
# Turbopuffer expects datetime objects directly for comparison
date_filters.append(("created_at", "Gte", start_date))
if end_date:
# Turbopuffer expects datetime objects directly for comparison
date_filters.append(("created_at", "Lte", end_date))
# combine all filters
all_filters = []
if tag_filter:
all_filters.append(tag_filter)
if date_filters:
all_filters.extend(date_filters)
# create final filter expression
final_filter = None
if len(all_filters) == 1:
final_filter = all_filters[0]
elif len(all_filters) > 1:
final_filter = ("And", all_filters)
if search_mode == "timestamp":
# Fallback: retrieve most recent passages by timestamp
query_params = {
"rank_by": ("created_at", "desc"), # Order by created_at in descending order
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
query_params["filters"] = final_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags)
elif search_mode == "vector":
# single vector search query
query_params = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
query_params["filters"] = final_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags)
elif search_mode == "fts":
# single full-text search query
query_params = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
query_params["filters"] = final_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags, is_fts=True)
else: # hybrid mode
# multi-query for both vector and FTS
queries = []
# vector search query
vector_query = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
vector_query["filters"] = final_filter
queries.append(vector_query)
# full-text search query
fts_query = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
fts_query["filters"] = final_filter
queries.append(fts_query)
# execute multi-query
response = await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
# process and combine results using reciprocal rank fusion
vector_results = self._process_single_query_results(response.results[0], archive_id, tags)
fts_results = self._process_single_query_results(response.results[1], archive_id, tags, is_fts=True)
# combine results using reciprocal rank fusion
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
except Exception as e:
logger.error(f"Failed to query passages from Turbopuffer: {e}")
raise
def _process_single_query_results(
self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
) -> List[Tuple[PydanticPassage, float]]:
"""Process results from a single query into passage objects with scores."""
passages_with_scores = []
for row in result.rows:
# Extract tags from the result row
passage_tags = getattr(row, "tags", []) or []
# Build metadata
metadata = {}
# Create a passage with minimal fields - embeddings are not returned from Turbopuffer
passage = PydanticPassage(
id=row.id,
text=getattr(row, "text", ""),
organization_id=getattr(row, "organization_id", None),
archive_id=archive_id, # use the archive_id from the query
created_at=getattr(row, "created_at", None),
metadata_=metadata,
tags=passage_tags, # Set the actual tags from the passage
# Set required fields to empty/default values since we don't store embeddings
embedding=[], # Empty embedding since we don't return it from Turbopuffer
embedding_config=None, # No embedding config needed for retrieved passages
)
# handle score based on search type
if is_fts:
# for FTS, use the BM25 score directly (higher is better)
score = getattr(row, "$score", 0.0)
else:
# for vector search, convert distance to similarity score
distance = getattr(row, "$dist", 0.0)
score = 1.0 - distance
passages_with_scores.append((passage, score))
return passages_with_scores
def _reciprocal_rank_fusion(
self,
vector_results: List[Tuple[PydanticPassage, float]],
fts_results: List[Tuple[PydanticPassage, float]],
vector_weight: float,
fts_weight: float,
top_k: int,
) -> List[Tuple[PydanticPassage, float]]:
"""Combine vector and FTS results using Reciprocal Rank Fusion (RRF).
RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank))
where k is a constant (typically 60) to avoid division by zero
"""
k = 60 # standard RRF constant
# create rank mappings
vector_ranks = {passage.id: rank + 1 for rank, (passage, _) in enumerate(vector_results)}
fts_ranks = {passage.id: rank + 1 for rank, (passage, _) in enumerate(fts_results)}
# combine all unique passage IDs
all_passages = {}
for passage, _ in vector_results:
all_passages[passage.id] = passage
for passage, _ in fts_results:
all_passages[passage.id] = passage
# calculate RRF scores
rrf_scores = {}
for passage_id in all_passages:
vector_score = vector_weight / (k + vector_ranks.get(passage_id, k + top_k))
fts_score = fts_weight / (k + fts_ranks.get(passage_id, k + top_k))
rrf_scores[passage_id] = vector_score + fts_score
# sort by RRF score and return top_k
sorted_results = sorted([(all_passages[pid], score) for pid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True)
return sorted_results[:top_k]
@trace_method
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
"""Delete a passage from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# Use write API with deletes parameter as per Turbopuffer docs
await namespace.write(deletes=[passage_id])
logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}")
return True
except Exception as e:
logger.error(f"Failed to delete passage from Turbopuffer: {e}")
raise
@trace_method
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
"""Delete multiple passages from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
if not passage_ids:
return True
namespace_name = self._get_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# Use write API with deletes parameter as per Turbopuffer docs
await namespace.write(deletes=passage_ids)
logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}")
return True
except Exception as e:
logger.error(f"Failed to delete passages from Turbopuffer: {e}")
raise
@trace_method
async def delete_all_passages(self, archive_id: str) -> bool:
"""Delete all passages for an archive from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# Turbopuffer has a delete_all() method on namespace
await namespace.delete_all()
logger.info(f"Successfully deleted all passages for archive {archive_id}")
return True
except Exception as e:
logger.error(f"Failed to delete all passages from Turbopuffer: {e}")
raise