Files
letta-server/letta/helpers/tpuf_client.py
Kian Jones 6f746c5225 fix(core): handle Anthropic overloaded errors and Unicode encoding issues (#9305)
* fix: handle Anthropic overloaded_error in streaming interfaces

* fix: handle Unicode surrogates in OpenAI requests

Sanitize Unicode surrogate pairs before sending requests to OpenAI API.
Surrogate pairs (U+D800-U+DFFF) are UTF-16 encoding artifacts that cause
UnicodeEncodeError when encoding to UTF-8.

Fixes Datadog error: 'utf-8' codec can't encode character '\ud83c' in
position 326605: surrogates not allowed

* fix: handle UnicodeEncodeError from lone Unicode surrogates in OpenAI requests

Improved sanitize_unicode_surrogates() to explicitly filter out lone
surrogate characters (U+D800 to U+DFFF) which are invalid in UTF-8.

Previous implementation used errors='ignore' which could still fail in
edge cases. New approach directly checks Unicode code points and removes
any surrogates before data reaches httpx encoding.

Also added sanitization to stream_async_responses() method which was
missing it.

Fixes: 'utf-8' codec can't encode character '\ud83c' in position X:
surrogates not allowed
2026-02-24 10:52:06 -08:00

2117 lines
86 KiB
Python

"""Turbopuffer utilities for archival memory storage."""
import asyncio
import json
import logging
import random
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, TypeVar
import httpx
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.errors import LettaInvalidArgumentError
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, TagMatchMode
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import model_settings, settings
# Module-level logger for turbopuffer operations.
logger = logging.getLogger(__name__)
# Type variable for generic async retry decorator
# (currently unused by async_retry_with_backoff below, which is typed with Any).
T = TypeVar("T")
# Default retry configuration for turbopuffer operations
TPUF_MAX_RETRIES = 3  # maximum retry attempts before giving up
TPUF_INITIAL_DELAY = 1.0 # seconds
TPUF_EXPONENTIAL_BASE = 2.0  # delay multiplier applied after each failed attempt
TPUF_JITTER = True  # add up to 10% random jitter to each delay
def is_transient_error(error: Exception) -> bool:
    """Return True when *error* looks transient and the operation may be retried.

    Transient errors are httpx connection/timeout/network failures, plus any
    exception whose message contains a known connection-related phrase.

    Args:
        error: The exception to check

    Returns:
        True if the error is transient and can be retried
    """
    # httpx raises distinct exception types for connect, timeout, and general
    # network failures — all of them are safe to retry.
    retryable_httpx_errors = (
        httpx.ConnectError,
        httpx.TimeoutException,
        httpx.NetworkError,
    )
    if isinstance(error, retryable_httpx_errors):
        return True
    # Fall back to scanning the message for connection-related phrases.
    message = str(error).lower()
    transient_patterns = (
        "connect call failed",
        "connection refused",
        "connection reset",
        "connection timed out",
        "temporary failure",
        "name resolution",
        "dns",
        "network unreachable",
        "no route to host",
        "ssl handshake",
    )
    return any(pattern in message for pattern in transient_patterns)
def async_retry_with_backoff(
    max_retries: int = TPUF_MAX_RETRIES,
    initial_delay: float = TPUF_INITIAL_DELAY,
    exponential_base: float = TPUF_EXPONENTIAL_BASE,
    jitter: bool = TPUF_JITTER,
):
    """Decorator for async functions that retries on transient errors with exponential backoff.

    Args:
        max_retries: Maximum number of retry attempts
        initial_delay: Initial delay between retries in seconds
        exponential_base: Base for exponential backoff calculation
        jitter: Whether to add random jitter to delays

    Returns:
        Decorated async function with retry logic
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            num_retries = 0
            delay = initial_delay
            while True:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # Non-transient errors are never retried — re-raise immediately.
                    if not is_transient_error(e):
                        raise
                    num_retries += 1
                    # Check the retry budget *before* logging, so the terminal
                    # failure does not emit a misleading "Retrying in Xs..."
                    # message and a retry_attempt event for an attempt that
                    # will never happen (previous behavior).
                    if num_retries > max_retries:
                        log_event(
                            "turbopuffer_max_retries_exceeded",
                            {
                                "max_retries": max_retries,
                                "error_type": type(e).__name__,
                                "error": str(e),
                                "function": func.__name__,
                            },
                        )
                        logger.error(f"Turbopuffer operation '{func.__name__}' failed after {max_retries} retries: {e}")
                        raise
                    log_event(
                        "turbopuffer_retry_attempt",
                        {
                            "attempt": num_retries,
                            "delay": delay,
                            "error_type": type(e).__name__,
                            "error": str(e),
                            "function": func.__name__,
                        },
                    )
                    logger.warning(
                        f"Turbopuffer operation '{func.__name__}' failed with transient error "
                        f"(attempt {num_retries}/{max_retries}): {e}. Retrying in {delay:.1f}s..."
                    )
                    # Wait, then grow the delay exponentially with optional jitter.
                    await asyncio.sleep(delay)
                    delay *= exponential_base
                    if jitter:
                        delay *= 1 + random.random() * 0.1  # Add up to 10% jitter

        return wrapper

    return decorator
# Global semaphore for Turbopuffer operations to prevent overwhelming the service
# This is separate from embedding semaphore since Turbopuffer can handle more concurrency
# NOTE(review): created at import time; assumes all awaiting coroutines share one
# event loop (true on Python 3.10+, where Semaphore no longer binds a loop at
# construction) — confirm the minimum supported Python version.
_GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5)
def _run_turbopuffer_write_in_thread(
    api_key: str,
    region: str,
    namespace_name: str,
    upsert_columns: Optional[dict] = None,
    deletes: Optional[list] = None,
    delete_by_filter: Optional[tuple] = None,
    distance_metric: str = "cosine_distance",
    schema: Optional[dict] = None,
):
    """
    Sync wrapper to run turbopuffer write in isolated event loop.

    Turbopuffer's async write() does CPU-intensive base64 encoding of vectors
    synchronously in async functions, blocking the event loop. Running it in
    a thread pool with an isolated event loop prevents blocking.

    Args:
        api_key: Turbopuffer API key
        region: Turbopuffer region
        namespace_name: Namespace to write into
        upsert_columns: Optional column-oriented rows to upsert
        deletes: Optional list of ids to delete
        delete_by_filter: Optional filter expression for deletion
        distance_metric: Vector distance metric (default cosine)
        schema: Optional namespace schema overrides

    Returns:
        The turbopuffer write() response
    """
    from turbopuffer import AsyncTurbopuffer

    async def do_write():
        async with AsyncTurbopuffer(api_key=api_key, region=region) as client:
            namespace = client.namespace(namespace_name)
            # Only forward the optional write arguments that were actually provided.
            kwargs = {"distance_metric": distance_metric}
            if upsert_columns:
                kwargs["upsert_columns"] = upsert_columns
            if deletes:
                kwargs["deletes"] = deletes
            if delete_by_filter:
                kwargs["delete_by_filter"] = delete_by_filter
            if schema:
                kwargs["schema"] = schema
            return await namespace.write(**kwargs)

    # asyncio.run() creates a fresh loop for this worker thread and guarantees it
    # is properly torn down (pending-task cancellation, async-generator shutdown),
    # replacing the manual new_event_loop()/run_until_complete()/close() sequence
    # which also left a closed loop installed as the thread's current loop.
    return asyncio.run(do_write())
def should_use_tpuf() -> bool:
    """Return True when Turbopuffer usage is enabled and both required keys are set.

    An OpenAI key is required because the default embedding model is OpenAI-hosted.
    """
    required = (settings.use_tpuf, settings.tpuf_api_key, model_settings.openai_api_key)
    return all(bool(value) for value in required)
def should_use_tpuf_for_messages() -> bool:
    """Check if Turbopuffer should be used for messages."""
    # Message embedding requires both the base Turbopuffer gate and its own flag.
    if not should_use_tpuf():
        return False
    return bool(settings.embed_all_messages)
def should_use_tpuf_for_tools() -> bool:
    """Check if Turbopuffer should be used for tools."""
    # Tool embedding requires both the base Turbopuffer gate and its own flag.
    if not should_use_tpuf():
        return False
    return bool(settings.embed_tools)
class TurbopufferClient:
"""Client for managing archival memory with Turbopuffer vector database."""
default_embedding_config = EmbeddingConfig(
embedding_model="text-embedding-3-small",
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
)
def __init__(self, api_key: str = None, region: str = None):
"""Initialize Turbopuffer client."""
self.api_key = api_key or settings.tpuf_api_key
self.region = region or settings.tpuf_region
from letta.services.agent_manager import AgentManager
from letta.services.archive_manager import ArchiveManager
self.archive_manager = ArchiveManager()
self.agent_manager = AgentManager()
if not self.api_key:
raise ValueError("Turbopuffer API key not provided")
@trace_method
async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]:
"""Generate embeddings using the default embedding configuration.
Args:
texts: List of texts to embed
actor: User actor for embedding generation
Returns:
List of embedding vectors
"""
from letta.llm_api.llm_client import LLMClient
# filter out empty strings after stripping
filtered_texts = [text for text in texts if text.strip()]
# skip embedding if no valid texts
if not filtered_texts:
return []
embedding_client = LLMClient.create(
provider_type=self.default_embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings(filtered_texts, self.default_embedding_config)
return embeddings
    @trace_method
    async def _get_archive_namespace_name(self, archive_id: str) -> str:
        """Get namespace name for a specific archive.

        Pure delegation to ArchiveManager's get-or-set helper (presumably the
        namespace record is created on first use and reused afterwards —
        confirm in ArchiveManager).
        """
        return await self.archive_manager.get_or_set_vector_db_namespace_async(archive_id)
@trace_method
async def _get_message_namespace_name(self, organization_id: str) -> str:
"""Get namespace name for messages (org-scoped).
Args:
organization_id: Organization ID for namespace generation
Returns:
The org-scoped namespace name for messages
"""
environment = settings.environment
if environment:
namespace_name = f"messages_{organization_id}_{environment.lower()}"
else:
namespace_name = f"messages_{organization_id}"
return namespace_name
@trace_method
async def _get_tool_namespace_name(self, organization_id: str) -> str:
"""Get namespace name for tools (org-scoped).
Args:
organization_id: Organization ID for namespace generation
Returns:
The org-scoped namespace name for tools
"""
environment = settings.environment
if environment:
namespace_name = f"tools_{organization_id}_{environment.lower()}"
else:
namespace_name = f"tools_{organization_id}"
return namespace_name
def _extract_tool_text(self, tool: "PydanticTool") -> str:
"""Extract searchable text from a tool for embedding.
Combines name, description, and JSON schema into a structured format
that provides rich context for semantic search.
Args:
tool: The tool to extract text from
Returns:
JSON-formatted string containing tool information
"""
parts = {
"name": tool.name or "",
"description": tool.description or "",
}
# Extract parameter information from JSON schema
if tool.json_schema:
# Include function description from schema if different from tool description
schema_description = tool.json_schema.get("description", "")
if schema_description and schema_description != tool.description:
parts["schema_description"] = schema_description
# Extract parameter information
parameters = tool.json_schema.get("parameters", {})
if parameters:
properties = parameters.get("properties", {})
param_descriptions = []
for param_name, param_info in properties.items():
param_desc = param_info.get("description", "")
param_type = param_info.get("type", "any")
if param_desc:
param_descriptions.append(f"{param_name} ({param_type}): {param_desc}")
else:
param_descriptions.append(f"{param_name} ({param_type})")
if param_descriptions:
parts["parameters"] = param_descriptions
# Include tags for additional context
if tool.tags:
parts["tags"] = tool.tags
return json.dumps(parts)
@trace_method
@async_retry_with_backoff()
async def insert_tools(
self,
tools: List["PydanticTool"],
organization_id: str,
actor: "PydanticUser",
) -> bool:
"""Insert tools into Turbopuffer.
Args:
tools: List of tools to store
organization_id: Organization ID for the tools
actor: User actor for embedding generation
Returns:
True if successful
"""
from turbopuffer import AsyncTurbopuffer
if not tools:
return True
# Extract text and filter out empty content
tool_texts = []
valid_tools = []
for tool in tools:
text = self._extract_tool_text(tool)
if text.strip():
tool_texts.append(text)
valid_tools.append(tool)
if not valid_tools:
logger.warning("All tools had empty text content, skipping insertion")
return True
# Generate embeddings
embeddings = await self._generate_embeddings(tool_texts, actor)
namespace_name = await self._get_tool_namespace_name(organization_id)
# Prepare column-based data
ids = []
vectors = []
texts = []
names = []
organization_ids = []
tool_types = []
tags_arrays = []
created_ats = []
for tool, text, embedding in zip(valid_tools, tool_texts, embeddings):
ids.append(tool.id)
vectors.append(embedding)
texts.append(text)
names.append(tool.name or "")
organization_ids.append(organization_id)
tool_types.append(tool.tool_type.value if tool.tool_type else "custom")
tags_arrays.append(tool.tags or [])
created_ats.append(getattr(tool, "created_at", None) or datetime.now(timezone.utc))
upsert_columns = {
"id": ids,
"vector": vectors,
"text": texts,
"name": names,
"organization_id": organization_ids,
"tool_type": tool_types,
"tags": tags_arrays,
"created_at": created_ats,
}
try:
# Use global semaphore to limit concurrent Turbopuffer writes
async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
# Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
await asyncio.to_thread(
_run_turbopuffer_write_in_thread,
api_key=self.api_key,
region=self.region,
namespace_name=namespace_name,
upsert_columns=upsert_columns,
distance_metric="cosine_distance",
schema={"text": {"type": "string", "full_text_search": True}},
)
logger.info(f"Successfully inserted {len(ids)} tools to Turbopuffer")
return True
except Exception as e:
logger.error(f"Failed to insert tools to Turbopuffer: {e}")
raise
@trace_method
@async_retry_with_backoff()
async def insert_archival_memories(
self,
archive_id: str,
text_chunks: List[str],
passage_ids: List[str],
organization_id: str,
actor: "PydanticUser",
tags: Optional[List[str]] = None,
created_at: Optional[datetime] = None,
embeddings: Optional[List[List[float]]] = None,
) -> List[PydanticPassage]:
"""Insert passages into Turbopuffer.
Args:
archive_id: ID of the archive
text_chunks: List of text chunks to store
passage_ids: List of passage IDs (must match 1:1 with text_chunks)
organization_id: Organization ID for the passages
actor: User actor for embedding generation
tags: Optional list of tags to attach to all passages
created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
embeddings: Optional pre-computed embeddings (must match 1:1 with text_chunks). If provided, skips embedding generation.
Returns:
List of PydanticPassage objects that were inserted
"""
from turbopuffer import AsyncTurbopuffer
# filter out empty text chunks
filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]
if not filtered_chunks:
logger.warning("All text chunks were empty, skipping insertion")
return []
filtered_texts = [text for _, text in filtered_chunks]
# use provided embeddings only if dimensions match TPUF's expected dimension
use_provided_embeddings = False
if embeddings is not None:
if len(embeddings) != len(text_chunks):
raise LettaInvalidArgumentError(
f"embeddings length ({len(embeddings)}) must match text_chunks length ({len(text_chunks)})",
argument_name="embeddings",
)
# check if first non-empty embedding has correct dimensions
filtered_indices = [i for i, _ in filtered_chunks]
sample_embedding = embeddings[filtered_indices[0]] if filtered_indices else None
if sample_embedding is not None and len(sample_embedding) == self.default_embedding_config.embedding_dim:
use_provided_embeddings = True
filtered_embeddings = [embeddings[i] for i, _ in filtered_chunks]
else:
logger.debug(
f"Embedding dimension mismatch (got {len(sample_embedding) if sample_embedding else 'None'}, "
f"expected {self.default_embedding_config.embedding_dim}), regenerating embeddings"
)
if not use_provided_embeddings:
filtered_embeddings = await self._generate_embeddings(filtered_texts, actor)
namespace_name = await self._get_archive_namespace_name(archive_id)
# handle timestamp - ensure UTC
if created_at is None:
timestamp = datetime.now(timezone.utc)
else:
# ensure the provided timestamp is timezone-aware and in UTC
if created_at.tzinfo is None:
# assume UTC if no timezone provided
timestamp = created_at.replace(tzinfo=timezone.utc)
else:
# convert to UTC if in different timezone
timestamp = created_at.astimezone(timezone.utc)
# passage_ids must be provided for dual-write consistency
if not passage_ids:
raise ValueError("passage_ids must be provided for Turbopuffer insertion")
if len(passage_ids) != len(text_chunks):
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
# prepare column-based data for turbopuffer - optimized for batch insert
ids = []
vectors = []
texts = []
organization_ids = []
archive_ids = []
created_ats = []
tags_arrays = [] # Store tags as arrays
passages = []
for (original_idx, text), embedding in zip(filtered_chunks, filtered_embeddings):
passage_id = passage_ids[original_idx]
# append to columns
ids.append(passage_id)
vectors.append(embedding)
texts.append(text)
organization_ids.append(organization_id)
archive_ids.append(archive_id)
created_ats.append(timestamp)
tags_arrays.append(tags or []) # Store tags as array
# Create PydanticPassage object
passage = PydanticPassage(
id=passage_id,
text=text,
organization_id=organization_id,
archive_id=archive_id,
created_at=timestamp,
metadata_={},
tags=tags or [], # Include tags in the passage
embedding=embedding,
embedding_config=self.default_embedding_config, # Will be set by caller if needed
)
passages.append(passage)
# build column-based upsert data
upsert_columns = {
"id": ids,
"vector": vectors,
"text": texts,
"organization_id": organization_ids,
"archive_id": archive_ids,
"created_at": created_ats,
"tags": tags_arrays, # Add tags as array column
}
try:
# Use global semaphore to limit concurrent Turbopuffer writes
async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
# Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
await asyncio.to_thread(
_run_turbopuffer_write_in_thread,
api_key=self.api_key,
region=self.region,
namespace_name=namespace_name,
upsert_columns=upsert_columns,
distance_metric="cosine_distance",
schema={"text": {"type": "string", "full_text_search": True}},
)
logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}")
return passages
except Exception as e:
logger.error(f"Failed to insert passages to Turbopuffer: {e}")
# check if it's a duplicate ID error
if "duplicate" in str(e).lower():
logger.error("Duplicate passage IDs detected in batch")
raise
@trace_method
@async_retry_with_backoff()
async def insert_messages(
self,
agent_id: str,
message_texts: List[str],
message_ids: List[str],
organization_id: str,
actor: "PydanticUser",
roles: List[MessageRole],
created_ats: List[datetime],
project_id: Optional[str] = None,
template_id: Optional[str] = None,
conversation_ids: Optional[List[Optional[str]]] = None,
) -> bool:
"""Insert messages into Turbopuffer.
Args:
agent_id: ID of the agent
message_texts: List of message text content to store
message_ids: List of message IDs (must match 1:1 with message_texts)
organization_id: Organization ID for the messages
actor: User actor for embedding generation
roles: List of message roles corresponding to each message
created_ats: List of creation timestamps for each message
project_id: Optional project ID for all messages
template_id: Optional template ID for all messages
conversation_ids: Optional list of conversation IDs (one per message, must match 1:1 with message_texts)
Returns:
True if successful
"""
from turbopuffer import AsyncTurbopuffer
# filter out empty message texts
filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
if not filtered_messages:
logger.warning("All message texts were empty, skipping insertion")
return True
# generate embeddings using the default config
filtered_texts = [text for _, text in filtered_messages]
embeddings = await self._generate_embeddings(filtered_texts, actor)
namespace_name = await self._get_message_namespace_name(organization_id)
# validation checks
if not message_ids:
raise ValueError("message_ids must be provided for Turbopuffer insertion")
if len(message_ids) != len(message_texts):
raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
if len(message_ids) != len(roles):
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
if len(message_ids) != len(created_ats):
raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
if conversation_ids is not None and len(conversation_ids) != len(message_ids):
raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")
# prepare column-based data for turbopuffer - optimized for batch insert
ids = []
vectors = []
texts = []
organization_ids_list = []
agent_ids_list = []
message_roles = []
created_at_timestamps = []
project_ids_list = []
template_ids_list = []
conversation_ids_list = []
for (original_idx, text), embedding in zip(filtered_messages, embeddings):
message_id = message_ids[original_idx]
role = roles[original_idx]
created_at = created_ats[original_idx]
conversation_id = conversation_ids[original_idx] if conversation_ids else None
# ensure the provided timestamp is timezone-aware and in UTC
if created_at.tzinfo is None:
# assume UTC if no timezone provided
timestamp = created_at.replace(tzinfo=timezone.utc)
else:
# convert to UTC if in different timezone
timestamp = created_at.astimezone(timezone.utc)
# append to columns
ids.append(message_id)
vectors.append(embedding)
texts.append(text)
organization_ids_list.append(organization_id)
agent_ids_list.append(agent_id)
message_roles.append(role.value)
created_at_timestamps.append(timestamp)
project_ids_list.append(project_id)
template_ids_list.append(template_id)
conversation_ids_list.append(conversation_id)
# build column-based upsert data
upsert_columns = {
"id": ids,
"vector": vectors,
"text": texts,
"organization_id": organization_ids_list,
"agent_id": agent_ids_list,
"role": message_roles,
"created_at": created_at_timestamps,
}
# only include conversation_id if it's provided
if conversation_ids is not None:
upsert_columns["conversation_id"] = conversation_ids_list
# only include project_id if it's provided
if project_id is not None:
upsert_columns["project_id"] = project_ids_list
# only include template_id if it's provided
if template_id is not None:
upsert_columns["template_id"] = template_ids_list
try:
# Use global semaphore to limit concurrent Turbopuffer writes
async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
# Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
await asyncio.to_thread(
_run_turbopuffer_write_in_thread,
api_key=self.api_key,
region=self.region,
namespace_name=namespace_name,
upsert_columns=upsert_columns,
distance_metric="cosine_distance",
schema={
"text": {"type": "string", "full_text_search": True},
"conversation_id": {"type": "string"},
},
)
logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
return True
except Exception as e:
logger.error(f"Failed to insert messages to Turbopuffer: {e}")
# check if it's a duplicate ID error
if "duplicate" in str(e).lower():
logger.error("Duplicate message IDs detected in batch")
raise
@trace_method
@async_retry_with_backoff()
async def _execute_query(
self,
namespace_name: str,
search_mode: str,
query_embedding: Optional[List[float]],
query_text: Optional[str],
top_k: int,
include_attributes: List[str],
filters: Optional[Any] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
) -> Any:
"""Generic query execution for Turbopuffer.
Args:
namespace_name: Turbopuffer namespace to query
search_mode: "vector", "fts", "hybrid", or "timestamp"
query_embedding: Embedding for vector search
query_text: Text for full-text search
top_k: Number of results to return
include_attributes: Attributes to include in results
filters: Turbopuffer filter expression
vector_weight: Weight for vector search in hybrid mode
fts_weight: Weight for FTS in hybrid mode
Returns:
Raw Turbopuffer query results or multi-query response
"""
from turbopuffer import AsyncTurbopuffer
from turbopuffer.types import QueryParam
# validate inputs based on search mode
if search_mode == "vector" and query_embedding is None:
raise ValueError("query_embedding is required for vector search mode")
if search_mode == "fts" and query_text is None:
raise ValueError("query_text is required for FTS search mode")
if search_mode == "hybrid":
if query_embedding is None or query_text is None:
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
if search_mode == "timestamp":
# retrieve most recent items by timestamp
query_params = {
"rank_by": ("created_at", "desc"),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
elif search_mode == "vector":
# vector search query
query_params = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
elif search_mode == "fts":
# full-text search query
query_params = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
else: # hybrid mode
queries = []
# vector search query
vector_query = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
vector_query["filters"] = filters
queries.append(vector_query)
# full-text search query
fts_query = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
fts_query["filters"] = filters
queries.append(fts_query)
# execute multi-query
return await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
except Exception as e:
# Wrap turbopuffer errors with user-friendly messages
from turbopuffer import NotFoundError
if isinstance(e, NotFoundError):
# Extract just the error message without implementation details
error_msg = str(e)
if "namespace" in error_msg.lower() and "not found" in error_msg.lower():
raise ValueError("No conversation history found. Please send a message first to enable search.") from e
raise ValueError(f"Search data not found: {error_msg}") from e
# Re-raise other errors as-is
raise
@trace_method
async def query_passages(
self,
archive_id: str,
actor: "PydanticUser",
query_text: Optional[str] = None,
search_mode: str = "vector", # "vector", "fts", "hybrid"
top_k: int = 10,
tags: Optional[List[str]] = None,
tag_match_mode: TagMatchMode = TagMatchMode.ANY,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> List[Tuple[PydanticPassage, float, dict]]:
"""Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
Args:
archive_id: ID of the archive
actor: User actor for embedding generation
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
top_k: Number of results to return
tags: Optional list of tags to filter by
tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter passages created after this date
end_date: Optional datetime to filter passages created on or before this date (inclusive)
Returns:
List of (passage, score, metadata) tuples with relevance rankings
"""
# generate embedding for vector/hybrid search if query_text is provided
query_embedding = None
if query_text and search_mode in ["vector", "hybrid"]:
embeddings = await self._generate_embeddings([query_text], actor)
query_embedding = embeddings[0]
# Check if we should fallback to timestamp-based retrieval
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
# Fallback to retrieving most recent passages when no search query is provided
search_mode = "timestamp"
namespace_name = await self._get_archive_namespace_name(archive_id)
# build tag filter conditions
tag_filter = None
if tags:
if tag_match_mode == TagMatchMode.ALL:
# For ALL mode, need to check each tag individually with Contains
tag_conditions = []
for tag in tags:
tag_conditions.append(("tags", "Contains", tag))
if len(tag_conditions) == 1:
tag_filter = tag_conditions[0]
else:
tag_filter = ("And", tag_conditions)
else: # tag_match_mode == TagMatchMode.ANY
# For ANY mode, use ContainsAny to match any of the tags
tag_filter = ("tags", "ContainsAny", tags)
# build date filter conditions
date_filters = []
if start_date:
# Convert to UTC to match stored timestamps
if start_date.tzinfo is not None:
start_date = start_date.astimezone(timezone.utc)
date_filters.append(("created_at", "Gte", start_date))
if end_date:
# if end_date has no time component (is at midnight), adjust to end of day
# to make the filter inclusive of the entire day
if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
from datetime import timedelta
# add 1 day and subtract 1 microsecond to get 23:59:59.999999
end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
# Convert to UTC to match stored timestamps
if end_date.tzinfo is not None:
end_date = end_date.astimezone(timezone.utc)
date_filters.append(("created_at", "Lte", end_date))
# combine all filters
all_filters = []
if tag_filter:
all_filters.append(tag_filter)
if date_filters:
all_filters.extend(date_filters)
# create final filter expression
final_filter = None
if len(all_filters) == 1:
final_filter = all_filters[0]
elif len(all_filters) > 1:
final_filter = ("And", all_filters)
try:
# use generic query executor
result = await self._execute_query(
namespace_name=namespace_name,
search_mode=search_mode,
query_embedding=query_embedding,
query_text=query_text,
top_k=top_k,
include_attributes=["text", "organization_id", "archive_id", "created_at", "tags"],
filters=final_filter,
vector_weight=vector_weight,
fts_weight=fts_weight,
)
# process results based on search mode
if search_mode == "hybrid":
# for hybrid mode, we get a multi-query response
vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
# use RRF and include metadata with ranks
results_with_metadata = self._reciprocal_rank_fusion(
vector_results=[passage for passage, _ in vector_results],
fts_results=[passage for passage, _ in fts_results],
get_id_func=lambda p: p.id,
vector_weight=vector_weight,
fts_weight=fts_weight,
top_k=top_k,
)
# Return (passage, score, metadata) with ranks
return results_with_metadata
else:
# for single queries (vector, fts, timestamp) - add basic metadata
is_fts = search_mode == "fts"
results = self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
# Add simple metadata for single search modes
results_with_metadata = []
for idx, (passage, score) in enumerate(results):
metadata = {
"combined_score": score,
f"{search_mode}_rank": idx + 1, # Add the rank for this search mode
}
results_with_metadata.append((passage, score, metadata))
return results_with_metadata
except Exception as e:
logger.error(f"Failed to query passages from Turbopuffer: {e}")
raise
@trace_method
async def query_messages_by_agent_id(
    self,
    agent_id: str,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid", "timestamp"
    top_k: int = 10,
    roles: Optional[List[MessageRole]] = None,
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[Tuple[dict, float, dict]]:
    """Query messages from Turbopuffer using vector search, full-text search, or hybrid search.

    Falls back to "timestamp" mode (most recent first) when no query text is provided.

    Args:
        agent_id: ID of the agent (used for filtering results)
        organization_id: Organization ID for namespace lookup
        actor: User actor for embedding generation
        query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
        search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector")
        top_k: Number of results to return
        roles: Optional list of message roles to filter by
        project_id: Optional project ID to filter messages by
        template_id: Optional template ID to filter messages by
        conversation_id: Optional conversation ID to filter messages by (use "default" for NULL)
        vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
        start_date: Optional datetime to filter messages created after this date
        end_date: Optional datetime to filter messages created on or before this date (inclusive)

    Returns:
        List of (message_dict, score, metadata) tuples where:
        - message_dict contains id, text, role, created_at
        - score is the final relevance score
        - metadata contains individual scores and ranking information
    """
    # generate embedding for vector/hybrid search if query_text is provided
    query_embedding = None
    if query_text and search_mode in ["vector", "hybrid"]:
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0]

    # Check if we should fallback to timestamp-based retrieval
    if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
        # Fallback to retrieving most recent messages when no search query is provided
        search_mode = "timestamp"

    namespace_name = await self._get_message_namespace_name(organization_id)

    # build agent_id filter (always applied so results stay scoped to this agent)
    agent_filter = ("agent_id", "Eq", agent_id)

    # build role filter conditions (Eq for one role, In for several)
    role_filter = None
    if roles:
        role_values = [r.value for r in roles]
        if len(role_values) == 1:
            role_filter = ("role", "Eq", role_values[0])
        else:
            role_filter = ("role", "In", role_values)

    # build date filter conditions
    date_filters = []
    if start_date:
        # Convert to UTC to match stored timestamps
        if start_date.tzinfo is not None:
            start_date = start_date.astimezone(timezone.utc)
        date_filters.append(("created_at", "Gte", start_date))
    if end_date:
        # if end_date has no time component (is at midnight), adjust to end of day
        # to make the filter inclusive of the entire day
        if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
            from datetime import timedelta

            # add 1 day and subtract 1 microsecond to get 23:59:59.999999
            end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
        # Convert to UTC to match stored timestamps
        if end_date.tzinfo is not None:
            end_date = end_date.astimezone(timezone.utc)
        date_filters.append(("created_at", "Lte", end_date))

    # build project_id filter if provided
    project_filter = None
    if project_id:
        project_filter = ("project_id", "Eq", project_id)

    # build template_id filter if provided
    template_filter = None
    if template_id:
        template_filter = ("template_id", "Eq", template_id)

    # build conversation_id filter if provided
    # three cases:
    # 1. conversation_id=None (omitted) -> return all messages (no filter)
    # 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility
    # 3. conversation_id="xyz" -> return only messages in that conversation
    conversation_filter = None
    if conversation_id == "default":
        # "default" is reserved for default messages only (conversation_id is none)
        conversation_filter = ("conversation_id", "Eq", None)
    elif conversation_id is not None:
        # Specific conversation
        conversation_filter = ("conversation_id", "Eq", conversation_id)

    # combine all filters
    all_filters = [agent_filter]  # always include agent_id filter
    if role_filter:
        all_filters.append(role_filter)
    if project_filter:
        all_filters.append(project_filter)
    if template_filter:
        all_filters.append(template_filter)
    if conversation_filter:
        all_filters.append(conversation_filter)
    if date_filters:
        all_filters.extend(date_filters)

    # create final filter expression (single filter used as-is; several are And-ed)
    final_filter = None
    if len(all_filters) == 1:
        final_filter = all_filters[0]
    elif len(all_filters) > 1:
        final_filter = ("And", all_filters)

    try:
        # use generic query executor
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=True,
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        # process results based on search mode
        if search_mode == "hybrid":
            # for hybrid mode, we get a multi-query response: [0] vector, [1] FTS
            vector_results = self._process_message_query_results(result.results[0])
            fts_results = self._process_message_query_results(result.results[1])

            # use RRF with lambda to extract ID from dict - returns metadata
            results_with_metadata = self._reciprocal_rank_fusion(
                vector_results=vector_results,
                fts_results=fts_results,
                get_id_func=lambda msg_dict: msg_dict["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )
            # return results with metadata
            return results_with_metadata
        else:
            # for single queries (vector, fts, timestamp)
            results = self._process_message_query_results(result)

            # add simple metadata for single search modes
            results_with_metadata = []
            for idx, msg_dict in enumerate(results):
                metadata = {
                    "combined_score": 1.0 / (idx + 1),  # Use rank-based score for single mode
                    "search_mode": search_mode,
                    f"{search_mode}_rank": idx + 1,  # Add the rank for this search mode
                }
                results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))
            return results_with_metadata
    except Exception as e:
        logger.error(f"Failed to query messages from Turbopuffer: {e}")
        raise
async def query_messages_by_org_id(
    self,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "hybrid",  # "vector", "fts", "hybrid"
    top_k: int = 10,
    roles: Optional[List[MessageRole]] = None,
    agent_id: Optional[str] = None,
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[Tuple[dict, float, dict]]:
    """Query messages from Turbopuffer across an entire organization.

    Args:
        organization_id: Organization ID for namespace lookup (required)
        actor: User actor for embedding generation
        query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
        search_mode: Search mode - "vector", "fts", or "hybrid" (default: "hybrid")
        top_k: Number of results to return
        roles: Optional list of message roles to filter by
        agent_id: Optional agent ID to filter messages by
        project_id: Optional project ID to filter messages by
        template_id: Optional template ID to filter messages by
        conversation_id: Optional conversation ID to filter messages by. Special values:
            - None (omitted): Return all messages
            - "default": Return only default messages (conversation_id IS NULL)
            - Any other value: Return messages in that specific conversation
        vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
        start_date: Optional datetime to filter messages created after this date
        end_date: Optional datetime to filter messages created on or before this date (inclusive)

    Returns:
        List of (message_dict, score, metadata) tuples where:
        - message_dict contains id, text, role, created_at, agent_id
        - score is the final relevance score (RRF score for hybrid, rank-based for single mode)
        - metadata contains individual scores and ranking information
    """
    # generate embedding for vector/hybrid search if query_text is provided
    query_embedding = None
    if query_text and search_mode in ["vector", "hybrid"]:
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0]

    # Check if we should fallback to timestamp-based retrieval
    if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
        # Fallback to retrieving most recent messages when no search query is provided
        search_mode = "timestamp"

    # namespace is org-scoped
    namespace_name = await self._get_message_namespace_name(organization_id)

    # build filters
    all_filters = []

    # role filter (Eq for one role, In for several)
    if roles:
        role_values = [r.value for r in roles]
        if len(role_values) == 1:
            all_filters.append(("role", "Eq", role_values[0]))
        else:
            all_filters.append(("role", "In", role_values))

    # agent filter
    if agent_id:
        all_filters.append(("agent_id", "Eq", agent_id))

    # project filter
    if project_id:
        all_filters.append(("project_id", "Eq", project_id))

    # template filter
    if template_id:
        all_filters.append(("template_id", "Eq", template_id))

    # conversation filter
    # three cases:
    # 1. conversation_id=None (omitted) -> return all messages (no filter)
    # 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility
    # 3. conversation_id="xyz" -> return only messages in that conversation
    if conversation_id == "default":
        # "default" is reserved for default messages only (conversation_id is none)
        all_filters.append(("conversation_id", "Eq", None))
    elif conversation_id is not None:
        # Specific conversation
        all_filters.append(("conversation_id", "Eq", conversation_id))

    # date filters
    if start_date:
        # Convert to UTC to match stored timestamps
        if start_date.tzinfo is not None:
            start_date = start_date.astimezone(timezone.utc)
        all_filters.append(("created_at", "Gte", start_date))
    if end_date:
        # make end_date inclusive of the entire day when it carries no time component
        if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
            from datetime import timedelta

            end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
        # Convert to UTC to match stored timestamps
        if end_date.tzinfo is not None:
            end_date = end_date.astimezone(timezone.utc)
        all_filters.append(("created_at", "Lte", end_date))

    # combine filters (single filter used as-is; several are And-ed)
    final_filter = None
    if len(all_filters) == 1:
        final_filter = all_filters[0]
    elif len(all_filters) > 1:
        final_filter = ("And", all_filters)

    try:
        # execute query
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=True,
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        # process results based on search mode
        if search_mode == "hybrid":
            # for hybrid mode, we get a multi-query response: [0] vector, [1] FTS
            vector_results = self._process_message_query_results(result.results[0])
            fts_results = self._process_message_query_results(result.results[1])

            # use existing RRF method - it already returns metadata with ranks
            results_with_metadata = self._reciprocal_rank_fusion(
                vector_results=vector_results,
                fts_results=fts_results,
                get_id_func=lambda msg_dict: msg_dict["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

            # add raw scores to metadata if available
            # NOTE(review): this reads attributes "dist"/"score", while the passage
            # helpers in this file read "$dist"/"$score" - confirm which attribute
            # name the Turbopuffer row objects actually expose here.
            vector_scores = {}
            for row in result.results[0].rows:
                if hasattr(row, "dist"):
                    vector_scores[row.id] = row.dist
            fts_scores = {}
            for row in result.results[1].rows:
                if hasattr(row, "score"):
                    fts_scores[row.id] = row.score

            # enhance metadata with raw scores
            enhanced_results = []
            for msg_dict, rrf_score, metadata in results_with_metadata:
                msg_id = msg_dict["id"]
                if msg_id in vector_scores:
                    metadata["vector_score"] = vector_scores[msg_id]
                if msg_id in fts_scores:
                    metadata["fts_score"] = fts_scores[msg_id]
                enhanced_results.append((msg_dict, rrf_score, metadata))
            return enhanced_results
        else:
            # for single queries (vector or fts)
            results = self._process_message_query_results(result)
            results_with_metadata = []
            for idx, msg_dict in enumerate(results):
                metadata = {
                    "combined_score": 1.0 / (idx + 1),
                    "search_mode": search_mode,
                    f"{search_mode}_rank": idx + 1,
                }
                # add raw score if available (same attribute-name caveat as above)
                if hasattr(result.rows[idx], "dist"):
                    metadata["vector_score"] = result.rows[idx].dist
                elif hasattr(result.rows[idx], "score"):
                    metadata["fts_score"] = result.rows[idx].score
                results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))
            return results_with_metadata
    except Exception as e:
        logger.error(f"Failed to query messages from Turbopuffer: {e}")
        raise
def _process_message_query_results(self, result) -> List[dict]:
"""Process results from a message query into message dicts.
For RRF, we only need the rank order - scores are not used.
"""
messages = []
for row in result.rows:
# Build message dict with key fields
message_dict = {
"id": row.id,
"text": getattr(row, "text", ""),
"organization_id": getattr(row, "organization_id", None),
"agent_id": getattr(row, "agent_id", None),
"role": getattr(row, "role", None),
"created_at": getattr(row, "created_at", None),
"conversation_id": getattr(row, "conversation_id", None),
}
messages.append(message_dict)
return messages
def _process_single_query_results(
    self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
) -> List[Tuple[PydanticPassage, float]]:
    """Turn raw query rows into (PydanticPassage, score) pairs.

    FTS rows carry a BM25 score (higher is better); vector rows carry a
    distance that is converted to a similarity via 1 - distance.
    Embeddings are never returned by Turbopuffer, so placeholder values
    are used for the embedding fields.
    """
    scored = []
    for hit in result.rows:
        row_tags = getattr(hit, "tags", []) or []
        passage = PydanticPassage(
            id=hit.id,
            text=getattr(hit, "text", ""),
            organization_id=getattr(hit, "organization_id", None),
            archive_id=archive_id,  # use the archive_id the query was scoped to
            created_at=getattr(hit, "created_at", None),
            metadata_={},
            tags=row_tags,  # actual tags stored on the passage
            # required fields filled with defaults since embeddings are not stored here
            embedding=[],
            embedding_config=self.default_embedding_config,
        )
        if is_fts:
            # BM25 relevance, higher is better
            relevance = getattr(hit, "$score", 0.0)
        else:
            # convert vector distance to a similarity score
            relevance = 1.0 - getattr(hit, "$dist", 0.0)
        scored.append((passage, relevance))
    return scored
def _reciprocal_rank_fusion(
self,
vector_results: List[Any],
fts_results: List[Any],
get_id_func: Callable[[Any], str],
vector_weight: float,
fts_weight: float,
top_k: int,
) -> List[Tuple[Any, float, dict]]:
"""RRF implementation that works with any object type.
RRF score = vector_weight * (1/(k + rank)) + fts_weight * (1/(k + rank))
where k is a constant (typically 60) to avoid division by zero
This is a pure rank-based fusion following the standard RRF algorithm.
Args:
vector_results: List of items from vector search (ordered by relevance)
fts_results: List of items from FTS (ordered by relevance)
get_id_func: Function to extract ID from an item
vector_weight: Weight for vector search results
fts_weight: Weight for FTS results
top_k: Number of results to return
Returns:
List of (item, score, metadata) tuples sorted by RRF score
metadata contains ranks from each result list
"""
k = 60 # standard RRF constant from Cormack et al. (2009)
# create rank mappings based on position in result lists
# rank starts at 1, not 0
vector_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(vector_results)}
fts_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(fts_results)}
# combine all unique items from both result sets
all_items = {}
for item in vector_results:
all_items[get_id_func(item)] = item
for item in fts_results:
all_items[get_id_func(item)] = item
# calculate RRF scores based purely on ranks
rrf_scores = {}
score_metadata = {}
for item_id in all_items:
# RRF formula: sum of 1/(k + rank) across result lists
# If item not in a list, we don't add anything (equivalent to rank = infinity)
vector_rrf_score = 0.0
fts_rrf_score = 0.0
if item_id in vector_ranks:
vector_rrf_score = vector_weight / (k + vector_ranks[item_id])
if item_id in fts_ranks:
fts_rrf_score = fts_weight / (k + fts_ranks[item_id])
combined_score = vector_rrf_score + fts_rrf_score
rrf_scores[item_id] = combined_score
score_metadata[item_id] = {
"combined_score": combined_score, # Final RRF score
"vector_rank": vector_ranks.get(item_id),
"fts_rank": fts_ranks.get(item_id),
}
# sort by RRF score and return with metadata
sorted_results = sorted(
[(all_items[iid], score, score_metadata[iid]) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True
)
return sorted_results[:top_k]
@trace_method
@async_retry_with_backoff()
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
    """Delete a single passage from the archive's Turbopuffer namespace.

    Args:
        archive_id: Archive whose namespace contains the passage.
        passage_id: ID of the passage to delete.

    Returns:
        True on success; exceptions are logged and re-raised (retried by the decorator).
    """
    # Removed unused function-scope import of AsyncTurbopuffer - the write goes
    # through the threaded helper below, not the async client.
    namespace_name = await self._get_archive_namespace_name(archive_id)
    try:
        # Run in thread pool for consistency (deletes are lightweight but use same wrapper)
        await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            deletes=[passage_id],
        )
        logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to delete passage from Turbopuffer: {e}")
        raise
@trace_method
@async_retry_with_backoff()
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
    """Delete multiple passages from Turbopuffer.

    Args:
        archive_id: Archive whose namespace contains the passages.
        passage_ids: IDs to delete; an empty list is a successful no-op.

    Returns:
        True on success (including the empty no-op case).
    """
    # Removed unused function-scope import of AsyncTurbopuffer - the write goes
    # through the threaded helper below, not the async client.
    if not passage_ids:
        return True

    namespace_name = await self._get_archive_namespace_name(archive_id)
    try:
        # Run in thread pool for consistency
        await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            deletes=passage_ids,
        )
        logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to delete passages from Turbopuffer: {e}")
        raise
@trace_method
@async_retry_with_backoff()
async def delete_all_passages(self, archive_id: str) -> bool:
    """Remove every passage stored for the given archive.

    Opens a short-lived async client, resolves the archive's namespace,
    and issues a namespace-wide delete.
    """
    from turbopuffer import AsyncTurbopuffer

    ns_name = await self._get_archive_namespace_name(archive_id)
    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as tpuf_client:
            # namespace-level delete_all() wipes every row in one call
            await tpuf_client.namespace(ns_name).delete_all()
        logger.info(f"Successfully deleted all passages for archive {archive_id}")
        return True
    except Exception as err:
        logger.error(f"Failed to delete all passages from Turbopuffer: {err}")
        raise
@trace_method
@async_retry_with_backoff()
async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool:
    """Delete multiple messages from Turbopuffer.

    Args:
        agent_id: Agent the messages belong to (used for logging only).
        organization_id: Organization whose message namespace is targeted.
        message_ids: IDs to delete; an empty list is a successful no-op.

    Returns:
        True on success (including the empty no-op case).
    """
    # Removed unused function-scope import of AsyncTurbopuffer - the write goes
    # through the threaded helper below, not the async client.
    if not message_ids:
        return True

    namespace_name = await self._get_message_namespace_name(organization_id)
    try:
        # Run in thread pool for consistency
        await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            deletes=message_ids,
        )
        logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer for agent {agent_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to delete messages from Turbopuffer: {e}")
        raise
@trace_method
@async_retry_with_backoff()
async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool:
    """Delete all messages for an agent from Turbopuffer.

    Uses a filter-based delete scoped to the agent within the org-wide
    message namespace, so other agents' messages are untouched.
    """
    # Removed unused function-scope import of AsyncTurbopuffer - the write goes
    # through the threaded helper below, not the async client.
    namespace_name = await self._get_message_namespace_name(organization_id)
    try:
        # Run in thread pool for consistency
        result = await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            delete_by_filter=("agent_id", "Eq", agent_id),
        )
        logger.info(f"Successfully deleted all messages for agent {agent_id} (deleted {result.rows_affected if result else 0} rows)")
        return True
    except Exception as e:
        logger.error(f"Failed to delete all messages from Turbopuffer: {e}")
        raise
# file/source passage methods
@trace_method
async def _get_file_passages_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped Turbopuffer namespace name for file passages.

    Args:
        organization_id: Organization ID for namespace generation

    Returns:
        The org-scoped namespace name for file passages
    """
    base = f"file_passages_{organization_id}"
    env = settings.environment
    # environment suffix keeps e.g. prod and staging namespaces separate
    return f"{base}_{env.lower()}" if env else base
@trace_method
@async_retry_with_backoff()
async def insert_file_passages(
    self,
    source_id: str,
    file_id: str,
    text_chunks: List[str],
    organization_id: str,
    actor: "PydanticUser",
    created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
    """Insert file passages into Turbopuffer using org-scoped namespace.

    Args:
        source_id: ID of the source containing the file
        file_id: ID of the file
        text_chunks: List of text chunks to store
        organization_id: Organization ID for the passages
        actor: User actor for embedding generation
        created_at: Optional timestamp for retroactive entries (defaults to current UTC time)

    Returns:
        List of PydanticPassage objects that were inserted
    """
    # Removed unused function-scope import of AsyncTurbopuffer - the write goes
    # through the threaded helper below, not the async client.
    if not text_chunks:
        return []

    # filter out empty/whitespace-only text chunks
    filtered_chunks = [text for text in text_chunks if text.strip()]
    if not filtered_chunks:
        logger.warning("All text chunks were empty, skipping file passage insertion")
        return []

    # generate embeddings using the default config
    embeddings = await self._generate_embeddings(filtered_chunks, actor)

    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    # handle timestamp - normalize to timezone-aware UTC
    if created_at is None:
        timestamp = datetime.now(timezone.utc)
    elif created_at.tzinfo is None:
        # assume UTC if no timezone provided
        timestamp = created_at.replace(tzinfo=timezone.utc)
    else:
        # convert to UTC if in different timezone
        timestamp = created_at.astimezone(timezone.utc)

    # prepare column-based data for turbopuffer - optimized for batch insert
    ids = []
    vectors = []
    texts = []
    organization_ids = []
    source_ids = []
    file_ids = []
    created_ats = []
    passages = []

    for text, embedding in zip(filtered_chunks, embeddings):
        # NOTE(review): the Passage object is built with actor.organization_id
        # while the Turbopuffer column stores the organization_id argument -
        # confirm callers always pass matching values.
        passage = PydanticPassage(
            text=text,
            file_id=file_id,
            source_id=source_id,
            embedding=embedding,
            embedding_config=self.default_embedding_config,
            organization_id=actor.organization_id,
        )
        passages.append(passage)

        # append to columns
        ids.append(passage.id)
        vectors.append(embedding)
        texts.append(text)
        organization_ids.append(organization_id)
        source_ids.append(source_id)
        file_ids.append(file_id)
        created_ats.append(timestamp)

    # build column-based upsert data
    upsert_columns = {
        "id": ids,
        "vector": vectors,
        "text": texts,
        "organization_id": organization_ids,
        "source_id": source_ids,
        "file_id": file_ids,
        "created_at": created_ats,
    }

    try:
        # Use global semaphore to limit concurrent Turbopuffer writes
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
            await asyncio.to_thread(
                _run_turbopuffer_write_in_thread,
                api_key=self.api_key,
                region=self.region,
                namespace_name=namespace_name,
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
        logger.info(f"Successfully inserted {len(ids)} file passages to Turbopuffer for source {source_id}, file {file_id}")
        return passages
    except Exception as e:
        logger.error(f"Failed to insert file passages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate passage IDs detected in batch")
        raise
@trace_method
async def query_file_passages(
    self,
    source_ids: List[str],
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid"
    top_k: int = 10,
    file_id: Optional[str] = None,  # optional filter by specific file
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> List[Tuple[PydanticPassage, float, dict]]:
    """Query file passages from Turbopuffer using org-scoped namespace.

    Falls back to "timestamp" mode (most recent first) when no query text is provided.

    Args:
        source_ids: List of source IDs to query
        organization_id: Organization ID for namespace lookup
        actor: User actor for embedding generation
        query_text: Text query for search
        search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
        top_k: Number of results to return
        file_id: Optional file ID to filter results to a specific file
        vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5)

    Returns:
        List of (passage, score, metadata) tuples with relevance rankings
    """
    # generate embedding for vector/hybrid search if query_text is provided
    query_embedding = None
    if query_text and search_mode in ["vector", "hybrid"]:
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0]

    # check if we should fallback to timestamp-based retrieval
    if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
        # fallback to retrieving most recent passages when no search query is provided
        search_mode = "timestamp"

    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    # build filters - always filter by source_ids
    if len(source_ids) == 1:
        # single source_id, use Eq for efficiency
        filters = [("source_id", "Eq", source_ids[0])]
    else:
        # multiple source_ids, use In operator
        filters = [("source_id", "In", source_ids)]

    # add file filter if specified
    if file_id:
        filters.append(("file_id", "Eq", file_id))

    # combine filters (single filter used as-is; several are And-ed)
    final_filter = filters[0] if len(filters) == 1 else ("And", filters)

    try:
        # use generic query executor
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "organization_id", "source_id", "file_id", "created_at"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        # process results based on search mode
        if search_mode == "hybrid":
            # for hybrid mode, we get a multi-query response: [0] vector, [1] FTS
            vector_results = self._process_file_query_results(result.results[0])
            fts_results = self._process_file_query_results(result.results[1], is_fts=True)

            # use RRF and include metadata with ranks
            results_with_metadata = self._reciprocal_rank_fusion(
                vector_results=[passage for passage, _ in vector_results],
                fts_results=[passage for passage, _ in fts_results],
                get_id_func=lambda p: p.id,
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )
            return results_with_metadata
        else:
            # for single queries (vector, fts, timestamp) - add basic metadata
            is_fts = search_mode == "fts"
            results = self._process_file_query_results(result, is_fts=is_fts)

            # add simple metadata for single search modes
            results_with_metadata = []
            for idx, (passage, score) in enumerate(results):
                metadata = {
                    "combined_score": score,
                    f"{search_mode}_rank": idx + 1,  # add the rank for this search mode
                }
                results_with_metadata.append((passage, score, metadata))
            return results_with_metadata
    except Exception as e:
        logger.error(f"Failed to query file passages from Turbopuffer: {e}")
        raise
def _process_file_query_results(self, result, is_fts: bool = False) -> List[Tuple[PydanticPassage, float]]:
    """Turn raw file-passage query rows into (PydanticPassage, score) pairs.

    FTS rows carry a BM25 score (higher is better); vector rows carry a
    distance that is converted to a similarity via 1 - distance.
    Embeddings are never returned by Turbopuffer, so placeholder values
    are used for the embedding fields.
    """
    scored = []
    for hit in result.rows:
        passage = PydanticPassage(
            id=hit.id,
            text=getattr(hit, "text", ""),
            organization_id=getattr(hit, "organization_id", None),
            source_id=getattr(hit, "source_id", None),  # source comes back on the row
            file_id=getattr(hit, "file_id", None),
            created_at=getattr(hit, "created_at", None),
            metadata_={},
            tags=[],
            # required fields filled with defaults since embeddings are not stored here
            embedding=[],
            embedding_config=self.default_embedding_config,
        )
        if is_fts:
            # BM25 relevance, higher is better
            relevance = getattr(hit, "$score", 0.0)
        else:
            # convert vector distance to a similarity score
            relevance = 1.0 - getattr(hit, "$dist", 0.0)
        scored.append((passage, relevance))
    return scored
@trace_method
@async_retry_with_backoff()
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
    """Delete all passages for a specific file from Turbopuffer.

    Args:
        source_id: Source that contains the file.
        file_id: File whose passages should be removed.
        organization_id: Organization whose file-passage namespace is targeted.

    Returns:
        True on success; exceptions are logged and re-raised.
    """
    # Removed unused function-scope import of AsyncTurbopuffer - the write goes
    # through the threaded helper below, not the async client.
    namespace_name = await self._get_file_passages_namespace_name(organization_id)
    try:
        # use delete_by_filter to only delete passages for this file
        # need to filter by both source_id and file_id
        filter_expr = ("And", [("source_id", "Eq", source_id), ("file_id", "Eq", file_id)])

        # Run in thread pool for consistency
        result = await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            delete_by_filter=filter_expr,
        )
        logger.info(
            f"Successfully deleted passages for file {file_id} from source {source_id} (deleted {result.rows_affected if result else 0} rows)"
        )
        return True
    except Exception as e:
        logger.error(f"Failed to delete file passages from Turbopuffer: {e}")
        raise
@trace_method
@async_retry_with_backoff()
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
    """Delete all passages for a source from Turbopuffer.

    Args:
        source_id: Source whose passages (across all its files) are removed.
        organization_id: Organization whose file-passage namespace is targeted.

    Returns:
        True on success; exceptions are logged and re-raised.
    """
    # Removed unused function-scope import of AsyncTurbopuffer - the write goes
    # through the threaded helper below, not the async client.
    namespace_name = await self._get_file_passages_namespace_name(organization_id)
    try:
        # Run in thread pool for consistency
        result = await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            delete_by_filter=("source_id", "Eq", source_id),
        )
        logger.info(f"Successfully deleted all passages for source {source_id} (deleted {result.rows_affected if result else 0} rows)")
        return True
    except Exception as e:
        logger.error(f"Failed to delete source passages from Turbopuffer: {e}")
        raise
# tool methods
@trace_method
@async_retry_with_backoff()
async def delete_tools(self, organization_id: str, tool_ids: List[str]) -> bool:
    """Delete tools from Turbopuffer.

    Args:
        organization_id: Organization ID for namespace lookup
        tool_ids: List of tool IDs to delete; an empty list is a successful no-op.

    Returns:
        True if successful
    """
    # Removed unused function-scope import of AsyncTurbopuffer - the write goes
    # through the threaded helper below, not the async client.
    if not tool_ids:
        return True

    namespace_name = await self._get_tool_namespace_name(organization_id)
    try:
        # Run in thread pool for consistency
        await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            deletes=tool_ids,
        )
        logger.info(f"Successfully deleted {len(tool_ids)} tools from Turbopuffer")
        return True
    except Exception as e:
        logger.error(f"Failed to delete tools from Turbopuffer: {e}")
        raise
@trace_method
async def query_tools(
    self,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "hybrid",  # "vector", "fts", "hybrid", "timestamp"
    top_k: int = 50,
    tool_types: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> List[Tuple[dict, float, dict]]:
    """Query tools from Turbopuffer using semantic search.

    Args:
        organization_id: Organization ID for namespace lookup
        actor: User actor for embedding generation
        query_text: Text query for search
        search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp"
        top_k: Number of results to return
        tool_types: Optional list of tool types to filter by
        tags: Optional list of tags to filter by (match any)
        vector_weight: Weight for vector search in hybrid mode
        fts_weight: Weight for FTS in hybrid mode

    Returns:
        List of (tool_dict, score, metadata) tuples
    """
    # Generate embedding for vector/hybrid search
    query_embedding = None
    if query_text and search_mode in ["vector", "hybrid"]:
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0] if embeddings else None
    # Fallback to timestamp-based retrieval when no query
    # (NOTE(review): if query_text is given but embedding generation returns
    # nothing, search_mode is left unchanged — vector/hybrid would then run
    # with query_embedding=None; confirm _execute_query tolerates that.)
    if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
        search_mode = "timestamp"
    namespace_name = await self._get_tool_namespace_name(organization_id)
    # Build filters — single values use "Eq", multiple use "In"/"ContainsAny".
    all_filters = []
    if tool_types:
        if len(tool_types) == 1:
            all_filters.append(("tool_type", "Eq", tool_types[0]))
        else:
            all_filters.append(("tool_type", "In", tool_types))
    if tags:
        # Any-match semantics: a tool passes if it carries at least one tag.
        all_filters.append(("tags", "ContainsAny", tags))
    # Combine filters: a single filter is used directly, multiple are ANDed.
    final_filter = None
    if len(all_filters) == 1:
        final_filter = all_filters[0]
    elif len(all_filters) > 1:
        final_filter = ("And", all_filters)
    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "name", "organization_id", "tool_type", "tags", "created_at"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )
        if search_mode == "hybrid":
            # Hybrid returns two result sets — results[0] from vector search,
            # results[1] from FTS — which are merged via reciprocal rank fusion.
            vector_results = self._process_tool_query_results(result.results[0])
            fts_results = self._process_tool_query_results(result.results[1])
            results_with_metadata = self._reciprocal_rank_fusion(
                vector_results=vector_results,
                fts_results=fts_results,
                get_id_func=lambda d: d["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )
            return results_with_metadata
        else:
            # Single-mode result: score each hit by its rank position (1/rank)
            # and record which mode produced it alongside its rank.
            results = self._process_tool_query_results(result)
            results_with_metadata = []
            for idx, tool_dict in enumerate(results):
                metadata = {
                    "combined_score": 1.0 / (idx + 1),
                    "search_mode": search_mode,
                    f"{search_mode}_rank": idx + 1,
                }
                results_with_metadata.append((tool_dict, metadata["combined_score"], metadata))
            return results_with_metadata
    except Exception as e:
        logger.error(f"Failed to query tools from Turbopuffer: {e}")
        raise
def _process_tool_query_results(self, result) -> List[dict]:
"""Process results from a tool query into tool dicts."""
tools = []
for row in result.rows:
tool_dict = {
"id": row.id,
"text": getattr(row, "text", ""),
"name": getattr(row, "name", ""),
"organization_id": getattr(row, "organization_id", None),
"tool_type": getattr(row, "tool_type", None),
"tags": getattr(row, "tags", []),
"created_at": getattr(row, "created_at", None),
}
tools.append(tool_dict)
return tools