"""Turbopuffer utilities for archival memory storage.""" import asyncio import json import logging import random from datetime import datetime, timezone from functools import wraps from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, TypeVar if TYPE_CHECKING: from letta.schemas.tool import Tool as PydanticTool from letta.schemas.user import User as PydanticUser import httpx from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE from letta.errors import LettaInvalidArgumentError from letta.otel.tracing import log_event, trace_method from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole, TagMatchMode from letta.schemas.passage import Passage as PydanticPassage from letta.settings import model_settings, settings logger = logging.getLogger(__name__) # Type variable for generic async retry decorator T = TypeVar("T") # Default retry configuration for turbopuffer operations TPUF_MAX_RETRIES = 3 TPUF_INITIAL_DELAY = 1.0 # seconds TPUF_EXPONENTIAL_BASE = 2.0 TPUF_JITTER = True def is_transient_error(error: Exception) -> bool: """Check if an error is transient and should be retried. Args: error: The exception to check Returns: True if the error is transient and can be retried """ # httpx connection errors (network issues, DNS failures, etc.) if isinstance(error, httpx.ConnectError): return True # httpx timeout errors if isinstance(error, httpx.TimeoutException): return True # httpx network errors if isinstance(error, httpx.NetworkError): return True # Check for connection-related errors in the error message error_str = str(error).lower() transient_patterns = [ "connect call failed", "connection refused", "connection reset", "connection timed out", "temporary failure", "name resolution", "dns", "network unreachable", "no route to host", "ssl handshake", ] for pattern in transient_patterns: if pattern in error_str: return True return False def async_retry_with_backoff( max_retries: int = TPUF_MAX_RETRIES, initial_delay: float = TPUF_INITIAL_DELAY, exponential_base: float = TPUF_EXPONENTIAL_BASE, jitter: bool = TPUF_JITTER, ): """Decorator for async functions that retries on transient errors with exponential backoff. Args: max_retries: Maximum number of retry attempts initial_delay: Initial delay between retries in seconds exponential_base: Base for exponential backoff calculation jitter: Whether to add random jitter to delays Returns: Decorated async function with retry logic """ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) async def wrapper(*args, **kwargs) -> Any: num_retries = 0 delay = initial_delay while True: try: return await func(*args, **kwargs) except Exception as e: # Check if this is a retryable error if not is_transient_error(e): # Not a transient error, re-raise immediately raise num_retries += 1 # Log the retry attempt log_event( "turbopuffer_retry_attempt", { "attempt": num_retries, "delay": delay, "error_type": type(e).__name__, "error": str(e), "function": func.__name__, }, ) logger.warning( f"Turbopuffer operation '{func.__name__}' failed with transient error " f"(attempt {num_retries}/{max_retries}): {e}. Retrying in {delay:.1f}s..." ) # Check if max retries exceeded if num_retries > max_retries: log_event( "turbopuffer_max_retries_exceeded", { "max_retries": max_retries, "error_type": type(e).__name__, "error": str(e), "function": func.__name__, }, ) logger.error(f"Turbopuffer operation '{func.__name__}' failed after {max_retries} retries: {e}") raise # Wait with exponential backoff await asyncio.sleep(delay) # Calculate next delay with optional jitter delay *= exponential_base if jitter: delay *= 1 + random.random() * 0.1 # Add up to 10% jitter return wrapper return decorator # Global semaphore for Turbopuffer operations to prevent overwhelming the service # This is separate from embedding semaphore since Turbopuffer can handle more concurrency _GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5) def _run_turbopuffer_write_in_thread( api_key: str, region: str, namespace_name: str, upsert_columns: dict | None = None, deletes: list | None = None, delete_by_filter: tuple | None = None, distance_metric: str = "cosine_distance", schema: dict | None = None, ): """ Sync wrapper to run turbopuffer write in isolated event loop. Turbopuffer's async write() does CPU-intensive base64 encoding of vectors synchronously in async functions, blocking the event loop. Running it in a thread pool with an isolated event loop prevents blocking. """ from turbopuffer import AsyncTurbopuffer # Create new event loop for this worker thread loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: async def do_write(): async with AsyncTurbopuffer(api_key=api_key, region=region) as client: namespace = client.namespace(namespace_name) # Build write kwargs kwargs = {"distance_metric": distance_metric} if upsert_columns: kwargs["upsert_columns"] = upsert_columns if deletes: kwargs["deletes"] = deletes if delete_by_filter: kwargs["delete_by_filter"] = delete_by_filter if schema: kwargs["schema"] = schema return await namespace.write(**kwargs) return loop.run_until_complete(do_write()) finally: loop.close() def should_use_tpuf() -> bool: # We need OpenAI since we default to their embedding model return bool(settings.use_tpuf) and bool(settings.tpuf_api_key) and bool(model_settings.openai_api_key) def should_use_tpuf_for_messages() -> bool: """Check if Turbopuffer should be used for messages.""" return should_use_tpuf() and bool(settings.embed_all_messages) def should_use_tpuf_for_tools() -> bool: """Check if Turbopuffer should be used for tools.""" return should_use_tpuf() and bool(settings.embed_tools) class TurbopufferClient: """Client for managing archival memory with Turbopuffer vector database.""" default_embedding_config = EmbeddingConfig( embedding_model="text-embedding-3-small", embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", embedding_dim=1536, embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, ) def __init__(self, api_key: str | None = None, region: str | None = None): """Initialize Turbopuffer client.""" self.api_key = api_key or settings.tpuf_api_key self.region = region or settings.tpuf_region from letta.services.agent_manager import AgentManager from letta.services.archive_manager import ArchiveManager self.archive_manager = ArchiveManager() self.agent_manager = AgentManager() if not self.api_key: raise ValueError("Turbopuffer API key not provided") @trace_method async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]: """Generate embeddings using the default embedding configuration. Args: texts: List of texts to embed actor: User actor for embedding generation Returns: List of embedding vectors """ from letta.llm_api.llm_client import LLMClient # filter out empty strings after stripping filtered_texts = [text for text in texts if text.strip()] # skip embedding if no valid texts if not filtered_texts: return [] embedding_client = LLMClient.create( provider_type=self.default_embedding_config.embedding_endpoint_type, actor=actor, ) embeddings = await embedding_client.request_embeddings(filtered_texts, self.default_embedding_config) return embeddings @trace_method async def _get_archive_namespace_name(self, archive_id: str) -> str: """Get namespace name for a specific archive.""" return await self.archive_manager.get_or_set_vector_db_namespace_async(archive_id) @trace_method async def _get_message_namespace_name(self, organization_id: str) -> str: """Get namespace name for messages (org-scoped). Args: organization_id: Organization ID for namespace generation Returns: The org-scoped namespace name for messages """ environment = settings.environment if environment: namespace_name = f"messages_{organization_id}_{environment.lower()}" else: namespace_name = f"messages_{organization_id}" return namespace_name @trace_method async def _get_tool_namespace_name(self, organization_id: str) -> str: """Get namespace name for tools (org-scoped). Args: organization_id: Organization ID for namespace generation Returns: The org-scoped namespace name for tools """ environment = settings.environment if environment: namespace_name = f"tools_{organization_id}_{environment.lower()}" else: namespace_name = f"tools_{organization_id}" return namespace_name def _extract_tool_text(self, tool: "PydanticTool") -> str: """Extract searchable text from a tool for embedding. Combines name, description, and JSON schema into a structured format that provides rich context for semantic search. Args: tool: The tool to extract text from Returns: JSON-formatted string containing tool information """ parts = { "name": tool.name or "", "description": tool.description or "", } # Extract parameter information from JSON schema if tool.json_schema: # Include function description from schema if different from tool description schema_description = tool.json_schema.get("description", "") if schema_description and schema_description != tool.description: parts["schema_description"] = schema_description # Extract parameter information parameters = tool.json_schema.get("parameters", {}) if parameters: properties = parameters.get("properties", {}) param_descriptions = [] for param_name, param_info in properties.items(): param_desc = param_info.get("description", "") param_type = param_info.get("type", "any") if param_desc: param_descriptions.append(f"{param_name} ({param_type}): {param_desc}") else: param_descriptions.append(f"{param_name} ({param_type})") if param_descriptions: parts["parameters"] = param_descriptions # Include tags for additional context if tool.tags: parts["tags"] = tool.tags return json.dumps(parts) @trace_method @async_retry_with_backoff() async def insert_tools( self, tools: List["PydanticTool"], organization_id: str, actor: "PydanticUser", ) -> bool: """Insert tools into Turbopuffer. Args: tools: List of tools to store organization_id: Organization ID for the tools actor: User actor for embedding generation Returns: True if successful """ if not tools: return True # Extract text and filter out empty content tool_texts = [] valid_tools = [] for tool in tools: text = self._extract_tool_text(tool) if text.strip(): tool_texts.append(text) valid_tools.append(tool) if not valid_tools: logger.warning("All tools had empty text content, skipping insertion") return True # Generate embeddings embeddings = await self._generate_embeddings(tool_texts, actor) namespace_name = await self._get_tool_namespace_name(organization_id) # Prepare column-based data ids = [] vectors = [] texts = [] names = [] organization_ids = [] tool_types = [] tags_arrays = [] created_ats = [] for tool, text, embedding in zip(valid_tools, tool_texts, embeddings): ids.append(tool.id) vectors.append(embedding) texts.append(text) names.append(tool.name or "") organization_ids.append(organization_id) tool_types.append(tool.tool_type.value if tool.tool_type else "custom") tags_arrays.append(tool.tags or []) created_ats.append(getattr(tool, "created_at", None) or datetime.now(timezone.utc)) upsert_columns = { "id": ids, "vector": vectors, "text": texts, "name": names, "organization_id": organization_ids, "tool_type": tool_types, "tags": tags_arrays, "created_at": created_ats, } try: # Use global semaphore to limit concurrent Turbopuffer writes async with _GLOBAL_TURBOPUFFER_SEMAPHORE: # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, upsert_columns=upsert_columns, distance_metric="cosine_distance", schema={"text": {"type": "string", "full_text_search": True}}, ) logger.info(f"Successfully inserted {len(ids)} tools to Turbopuffer") return True except Exception as e: logger.error(f"Failed to insert tools to Turbopuffer: {e}") raise @trace_method @async_retry_with_backoff() async def insert_archival_memories( self, archive_id: str, text_chunks: List[str], passage_ids: List[str], organization_id: str, actor: "PydanticUser", tags: Optional[List[str]] = None, created_at: Optional[datetime] = None, embeddings: Optional[List[List[float]]] = None, ) -> List[PydanticPassage]: """Insert passages into Turbopuffer. Args: archive_id: ID of the archive text_chunks: List of text chunks to store passage_ids: List of passage IDs (must match 1:1 with text_chunks) organization_id: Organization ID for the passages actor: User actor for embedding generation tags: Optional list of tags to attach to all passages created_at: Optional timestamp for retroactive entries (defaults to current UTC time) embeddings: Optional pre-computed embeddings (must match 1:1 with text_chunks). If provided, skips embedding generation. Returns: List of PydanticPassage objects that were inserted """ # filter out empty text chunks filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()] if not filtered_chunks: logger.warning("All text chunks were empty, skipping insertion") return [] filtered_texts = [text for _, text in filtered_chunks] # use provided embeddings only if dimensions match TPUF's expected dimension use_provided_embeddings = False if embeddings is not None: if len(embeddings) != len(text_chunks): raise LettaInvalidArgumentError( f"embeddings length ({len(embeddings)}) must match text_chunks length ({len(text_chunks)})", argument_name="embeddings", ) # check if first non-empty embedding has correct dimensions filtered_indices = [i for i, _ in filtered_chunks] sample_embedding = embeddings[filtered_indices[0]] if filtered_indices else None if sample_embedding is not None and len(sample_embedding) == self.default_embedding_config.embedding_dim: use_provided_embeddings = True filtered_embeddings = [embeddings[i] for i, _ in filtered_chunks] else: logger.debug( f"Embedding dimension mismatch (got {len(sample_embedding) if sample_embedding else 'None'}, " f"expected {self.default_embedding_config.embedding_dim}), regenerating embeddings" ) if not use_provided_embeddings: filtered_embeddings = await self._generate_embeddings(filtered_texts, actor) namespace_name = await self._get_archive_namespace_name(archive_id) # handle timestamp - ensure UTC if created_at is None: timestamp = datetime.now(timezone.utc) else: # ensure the provided timestamp is timezone-aware and in UTC if created_at.tzinfo is None: # assume UTC if no timezone provided timestamp = created_at.replace(tzinfo=timezone.utc) else: # convert to UTC if in different timezone timestamp = created_at.astimezone(timezone.utc) # passage_ids must be provided for dual-write consistency if not passage_ids: raise ValueError("passage_ids must be provided for Turbopuffer insertion") if len(passage_ids) != len(text_chunks): raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})") # prepare column-based data for turbopuffer - optimized for batch insert ids = [] vectors = [] texts = [] organization_ids = [] archive_ids = [] created_ats = [] tags_arrays = [] # Store tags as arrays passages = [] for (original_idx, text), embedding in zip(filtered_chunks, filtered_embeddings): passage_id = passage_ids[original_idx] # append to columns ids.append(passage_id) vectors.append(embedding) texts.append(text) organization_ids.append(organization_id) archive_ids.append(archive_id) created_ats.append(timestamp) tags_arrays.append(tags or []) # Store tags as array # Create PydanticPassage object passage = PydanticPassage( id=passage_id, text=text, organization_id=organization_id, archive_id=archive_id, created_at=timestamp, metadata_={}, tags=tags or [], # Include tags in the passage embedding=embedding, embedding_config=self.default_embedding_config, # Will be set by caller if needed ) passages.append(passage) # build column-based upsert data upsert_columns = { "id": ids, "vector": vectors, "text": texts, "organization_id": organization_ids, "archive_id": archive_ids, "created_at": created_ats, "tags": tags_arrays, # Add tags as array column } try: # Use global semaphore to limit concurrent Turbopuffer writes async with _GLOBAL_TURBOPUFFER_SEMAPHORE: # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, upsert_columns=upsert_columns, distance_metric="cosine_distance", schema={"text": {"type": "string", "full_text_search": True}}, ) logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}") return passages except Exception as e: logger.error(f"Failed to insert passages to Turbopuffer: {e}") # check if it's a duplicate ID error if "duplicate" in str(e).lower(): logger.error("Duplicate passage IDs detected in batch") raise @trace_method @async_retry_with_backoff() async def insert_messages( self, agent_id: str, message_texts: List[str], message_ids: List[str], organization_id: str, actor: "PydanticUser", roles: List[MessageRole], created_ats: List[datetime], project_id: Optional[str] = None, template_id: Optional[str] = None, conversation_ids: Optional[List[Optional[str]]] = None, ) -> bool: """Insert messages into Turbopuffer. Args: agent_id: ID of the agent message_texts: List of message text content to store message_ids: List of message IDs (must match 1:1 with message_texts) organization_id: Organization ID for the messages actor: User actor for embedding generation roles: List of message roles corresponding to each message created_ats: List of creation timestamps for each message project_id: Optional project ID for all messages template_id: Optional template ID for all messages conversation_ids: Optional list of conversation IDs (one per message, must match 1:1 with message_texts) Returns: True if successful """ # filter out empty message texts filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()] if not filtered_messages: logger.warning("All message texts were empty, skipping insertion") return True # generate embeddings using the default config filtered_texts = [text for _, text in filtered_messages] embeddings = await self._generate_embeddings(filtered_texts, actor) namespace_name = await self._get_message_namespace_name(organization_id) # validation checks if not message_ids: raise ValueError("message_ids must be provided for Turbopuffer insertion") if len(message_ids) != len(message_texts): raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})") if len(message_ids) != len(roles): raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})") if len(message_ids) != len(created_ats): raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})") if conversation_ids is not None and len(conversation_ids) != len(message_ids): raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})") # prepare column-based data for turbopuffer - optimized for batch insert ids = [] vectors = [] texts = [] organization_ids_list = [] agent_ids_list = [] message_roles = [] created_at_timestamps = [] project_ids_list = [] template_ids_list = [] conversation_ids_list = [] for (original_idx, text), embedding in zip(filtered_messages, embeddings): message_id = message_ids[original_idx] role = roles[original_idx] created_at = created_ats[original_idx] conversation_id = conversation_ids[original_idx] if conversation_ids else None # ensure the provided timestamp is timezone-aware and in UTC if created_at.tzinfo is None: # assume UTC if no timezone provided timestamp = created_at.replace(tzinfo=timezone.utc) else: # convert to UTC if in different timezone timestamp = created_at.astimezone(timezone.utc) # append to columns ids.append(message_id) vectors.append(embedding) texts.append(text) organization_ids_list.append(organization_id) agent_ids_list.append(agent_id) message_roles.append(role.value) created_at_timestamps.append(timestamp) project_ids_list.append(project_id) template_ids_list.append(template_id) conversation_ids_list.append(conversation_id) # build column-based upsert data upsert_columns = { "id": ids, "vector": vectors, "text": texts, "organization_id": organization_ids_list, "agent_id": agent_ids_list, "role": message_roles, "created_at": created_at_timestamps, } # only include conversation_id if it's provided if conversation_ids is not None: upsert_columns["conversation_id"] = conversation_ids_list # only include project_id if it's provided if project_id is not None: upsert_columns["project_id"] = project_ids_list # only include template_id if it's provided if template_id is not None: upsert_columns["template_id"] = template_ids_list try: # Use global semaphore to limit concurrent Turbopuffer writes async with _GLOBAL_TURBOPUFFER_SEMAPHORE: # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, upsert_columns=upsert_columns, distance_metric="cosine_distance", schema={ "text": {"type": "string", "full_text_search": True}, "conversation_id": {"type": "string"}, }, ) logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}") return True except Exception as e: logger.error(f"Failed to insert messages to Turbopuffer: {e}") # check if it's a duplicate ID error if "duplicate" in str(e).lower(): logger.error("Duplicate message IDs detected in batch") raise @trace_method @async_retry_with_backoff() async def _execute_query( self, namespace_name: str, search_mode: str, query_embedding: Optional[List[float]], query_text: Optional[str], top_k: int, include_attributes: List[str], filters: Optional[Any] = None, vector_weight: float = 0.5, fts_weight: float = 0.5, ) -> Any: """Generic query execution for Turbopuffer. Args: namespace_name: Turbopuffer namespace to query search_mode: "vector", "fts", "hybrid", or "timestamp" query_embedding: Embedding for vector search query_text: Text for full-text search top_k: Number of results to return include_attributes: Attributes to include in results filters: Turbopuffer filter expression vector_weight: Weight for vector search in hybrid mode fts_weight: Weight for FTS in hybrid mode Returns: Raw Turbopuffer query results or multi-query response """ from turbopuffer import AsyncTurbopuffer from turbopuffer.types import QueryParam # validate inputs based on search mode if search_mode == "vector" and query_embedding is None: raise ValueError("query_embedding is required for vector search mode") if search_mode == "fts" and query_text is None: raise ValueError("query_text is required for FTS search mode") if search_mode == "hybrid": if query_embedding is None or query_text is None: raise ValueError("Both query_embedding and query_text are required for hybrid search mode") if search_mode not in ["vector", "fts", "hybrid", "timestamp"]: raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'") try: async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: namespace = client.namespace(namespace_name) if search_mode == "timestamp": # retrieve most recent items by timestamp query_params = { "rank_by": ("created_at", "desc"), "top_k": top_k, "include_attributes": include_attributes, } if filters: query_params["filters"] = filters return await namespace.query(**query_params) elif search_mode == "vector": # vector search query query_params = { "rank_by": ("vector", "ANN", query_embedding), "top_k": top_k, "include_attributes": include_attributes, } if filters: query_params["filters"] = filters return await namespace.query(**query_params) elif search_mode == "fts": # full-text search query query_params = { "rank_by": ("text", "BM25", query_text), "top_k": top_k, "include_attributes": include_attributes, } if filters: query_params["filters"] = filters return await namespace.query(**query_params) else: # hybrid mode queries = [] # vector search query vector_query = { "rank_by": ("vector", "ANN", query_embedding), "top_k": top_k, "include_attributes": include_attributes, } if filters: vector_query["filters"] = filters queries.append(vector_query) # full-text search query fts_query = { "rank_by": ("text", "BM25", query_text), "top_k": top_k, "include_attributes": include_attributes, } if filters: fts_query["filters"] = filters queries.append(fts_query) # execute multi-query return await namespace.multi_query(queries=[QueryParam(**q) for q in queries]) except Exception as e: # Wrap turbopuffer errors with user-friendly messages from turbopuffer import NotFoundError if isinstance(e, NotFoundError): # Extract just the error message without implementation details error_msg = str(e) if "namespace" in error_msg.lower() and "not found" in error_msg.lower(): raise ValueError("No conversation history found. Please send a message first to enable search.") from e raise ValueError(f"Search data not found: {error_msg}") from e # Re-raise other errors as-is raise @trace_method async def query_passages( self, archive_id: str, actor: "PydanticUser", query_text: Optional[str] = None, search_mode: str = "vector", # "vector", "fts", "hybrid" top_k: int = 10, tags: Optional[List[str]] = None, tag_match_mode: TagMatchMode = TagMatchMode.ANY, vector_weight: float = 0.5, fts_weight: float = 0.5, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, ) -> List[Tuple[PydanticPassage, float, dict]]: """Query passages from Turbopuffer using vector search, full-text search, or hybrid search. Args: archive_id: ID of the archive actor: User actor for embedding generation query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes) search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector") top_k: Number of results to return tags: Optional list of tags to filter by tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY vector_weight: Weight for vector search results in hybrid mode (default: 0.5) fts_weight: Weight for FTS results in hybrid mode (default: 0.5) start_date: Optional datetime to filter passages created after this date end_date: Optional datetime to filter passages created on or before this date (inclusive) Returns: List of (passage, score, metadata) tuples with relevance rankings """ # generate embedding for vector/hybrid search if query_text is provided query_embedding = None if query_text and search_mode in ["vector", "hybrid"]: embeddings = await self._generate_embeddings([query_text], actor) query_embedding = embeddings[0] # Check if we should fallback to timestamp-based retrieval if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: # Fallback to retrieving most recent passages when no search query is provided search_mode = "timestamp" namespace_name = await self._get_archive_namespace_name(archive_id) # build tag filter conditions tag_filter = None if tags: if tag_match_mode == TagMatchMode.ALL: # For ALL mode, need to check each tag individually with Contains tag_conditions = [] for tag in tags: tag_conditions.append(("tags", "Contains", tag)) if len(tag_conditions) == 1: tag_filter = tag_conditions[0] else: tag_filter = ("And", tag_conditions) else: # tag_match_mode == TagMatchMode.ANY # For ANY mode, use ContainsAny to match any of the tags tag_filter = ("tags", "ContainsAny", tags) # build date filter conditions date_filters = [] if start_date: # Convert to UTC to match stored timestamps if start_date.tzinfo is not None: start_date = start_date.astimezone(timezone.utc) date_filters.append(("created_at", "Gte", start_date)) if end_date: # if end_date has no time component (is at midnight), adjust to end of day # to make the filter inclusive of the entire day if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0: from datetime import timedelta # add 1 day and subtract 1 microsecond to get 23:59:59.999999 end_date = end_date + timedelta(days=1) - timedelta(microseconds=1) # Convert to UTC to match stored timestamps if end_date.tzinfo is not None: end_date = end_date.astimezone(timezone.utc) date_filters.append(("created_at", "Lte", end_date)) # combine all filters all_filters = [] if tag_filter: all_filters.append(tag_filter) if date_filters: all_filters.extend(date_filters) # create final filter expression final_filter = None if len(all_filters) == 1: final_filter = all_filters[0] elif len(all_filters) > 1: final_filter = ("And", all_filters) try: # use generic query executor result = await self._execute_query( namespace_name=namespace_name, search_mode=search_mode, query_embedding=query_embedding, query_text=query_text, top_k=top_k, include_attributes=["text", "organization_id", "archive_id", "created_at", "tags"], filters=final_filter, vector_weight=vector_weight, fts_weight=fts_weight, ) # process results based on search mode if search_mode == "hybrid": # for hybrid mode, we get a multi-query response vector_results = self._process_single_query_results(result.results[0], archive_id, tags) fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True) # use RRF and include metadata with ranks results_with_metadata = self._reciprocal_rank_fusion( vector_results=[passage for passage, _ in vector_results], fts_results=[passage for passage, _ in fts_results], get_id_func=lambda p: p.id, vector_weight=vector_weight, fts_weight=fts_weight, top_k=top_k, ) # Return (passage, score, metadata) with ranks return results_with_metadata else: # for single queries (vector, fts, timestamp) - add basic metadata is_fts = search_mode == "fts" results = self._process_single_query_results(result, archive_id, tags, is_fts=is_fts) # Add simple metadata for single search modes results_with_metadata = [] for idx, (passage, score) in enumerate(results): metadata = { "combined_score": score, f"{search_mode}_rank": idx + 1, # Add the rank for this search mode } results_with_metadata.append((passage, score, metadata)) return results_with_metadata except Exception as e: logger.error(f"Failed to query passages from Turbopuffer: {e}") raise @trace_method async def query_messages_by_agent_id( self, agent_id: str, organization_id: str, actor: "PydanticUser", query_text: Optional[str] = None, search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp" top_k: int = 10, roles: Optional[List[MessageRole]] = None, project_id: Optional[str] = None, template_id: Optional[str] = None, conversation_id: Optional[str] = None, vector_weight: float = 0.5, fts_weight: float = 0.5, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, ) -> List[Tuple[dict, float, dict]]: """Query messages from Turbopuffer using vector search, full-text search, or hybrid search. Args: agent_id: ID of the agent (used for filtering results) organization_id: Organization ID for namespace lookup actor: User actor for embedding generation query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes) search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector") top_k: Number of results to return roles: Optional list of message roles to filter by project_id: Optional project ID to filter messages by template_id: Optional template ID to filter messages by conversation_id: Optional conversation ID to filter messages by (use "default" for NULL) vector_weight: Weight for vector search results in hybrid mode (default: 0.5) fts_weight: Weight for FTS results in hybrid mode (default: 0.5) start_date: Optional datetime to filter messages created after this date end_date: Optional datetime to filter messages created on or before this date (inclusive) Returns: List of (message_dict, score, metadata) tuples where: - message_dict contains id, text, role, created_at - score is the final relevance score - metadata contains individual scores and ranking information """ # generate embedding for vector/hybrid search if query_text is provided query_embedding = None if query_text and search_mode in ["vector", "hybrid"]: embeddings = await self._generate_embeddings([query_text], actor) query_embedding = embeddings[0] # Check if we should fallback to timestamp-based retrieval if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: # Fallback to retrieving most recent messages when no search query is provided search_mode = "timestamp" namespace_name = await self._get_message_namespace_name(organization_id) # build agent_id filter agent_filter = ("agent_id", "Eq", agent_id) # build role filter conditions role_filter = None if roles: role_values = [r.value for r in roles] if len(role_values) == 1: role_filter = ("role", "Eq", role_values[0]) else: role_filter = ("role", "In", role_values) # build date filter conditions date_filters = [] if start_date: # Convert to UTC to match stored timestamps if start_date.tzinfo is not None: start_date = start_date.astimezone(timezone.utc) date_filters.append(("created_at", "Gte", start_date)) if end_date: # if end_date has no time component (is at midnight), adjust to end of day # to make the filter inclusive of the entire day if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0: from datetime import timedelta # add 1 day and subtract 1 microsecond to get 23:59:59.999999 end_date = end_date + timedelta(days=1) - timedelta(microseconds=1) # Convert to UTC to match stored timestamps if end_date.tzinfo is not None: end_date = end_date.astimezone(timezone.utc) date_filters.append(("created_at", "Lte", end_date)) # build project_id filter if provided project_filter = None if project_id: project_filter = ("project_id", "Eq", project_id) # build template_id filter if provided template_filter = None if template_id: template_filter = ("template_id", "Eq", template_id) # build conversation_id filter if provided # three cases: # 1. conversation_id=None (omitted) -> return all messages (no filter) # 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility # 3. conversation_id="xyz" -> return only messages in that conversation conversation_filter = None if conversation_id == "default": # "default" is reserved for default messages only (conversation_id is none) conversation_filter = ("conversation_id", "Eq", None) elif conversation_id is not None: # Specific conversation conversation_filter = ("conversation_id", "Eq", conversation_id) # combine all filters all_filters = [agent_filter] # always include agent_id filter if role_filter: all_filters.append(role_filter) if project_filter: all_filters.append(project_filter) if template_filter: all_filters.append(template_filter) if conversation_filter: all_filters.append(conversation_filter) if date_filters: all_filters.extend(date_filters) # create final filter expression final_filter = None if len(all_filters) == 1: final_filter = all_filters[0] elif len(all_filters) > 1: final_filter = ("And", all_filters) try: # use generic query executor result = await self._execute_query( namespace_name=namespace_name, search_mode=search_mode, query_embedding=query_embedding, query_text=query_text, top_k=top_k, include_attributes=True, filters=final_filter, vector_weight=vector_weight, fts_weight=fts_weight, ) # process results based on search mode if search_mode == "hybrid": # for hybrid mode, we get a multi-query response vector_results = self._process_message_query_results(result.results[0]) fts_results = self._process_message_query_results(result.results[1]) # use RRF with lambda to extract ID from dict - returns metadata results_with_metadata = self._reciprocal_rank_fusion( vector_results=vector_results, fts_results=fts_results, get_id_func=lambda msg_dict: msg_dict["id"], vector_weight=vector_weight, fts_weight=fts_weight, top_k=top_k, ) # return results with metadata return results_with_metadata else: # for single queries (vector, fts, timestamp) results = self._process_message_query_results(result) # add simple metadata for single search modes results_with_metadata = [] for idx, msg_dict in enumerate(results): metadata = { "combined_score": 1.0 / (idx + 1), # Use rank-based score for single mode "search_mode": search_mode, f"{search_mode}_rank": idx + 1, # Add the rank for this search mode } results_with_metadata.append((msg_dict, metadata["combined_score"], metadata)) return results_with_metadata except Exception as e: logger.error(f"Failed to query messages from Turbopuffer: {e}") raise async def query_messages_by_org_id( self, organization_id: str, actor: "PydanticUser", query_text: Optional[str] = None, search_mode: str = "hybrid", # "vector", "fts", "hybrid" top_k: int = 10, roles: Optional[List[MessageRole]] = None, agent_id: Optional[str] = None, project_id: Optional[str] = None, template_id: Optional[str] = None, conversation_id: Optional[str] = None, vector_weight: float = 0.5, fts_weight: float = 0.5, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, ) -> List[Tuple[dict, float, dict]]: """Query messages from Turbopuffer across an entire organization. Args: organization_id: Organization ID for namespace lookup (required) actor: User actor for embedding generation query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes) search_mode: Search mode - "vector", "fts", or "hybrid" (default: "hybrid") top_k: Number of results to return roles: Optional list of message roles to filter by agent_id: Optional agent ID to filter messages by project_id: Optional project ID to filter messages by template_id: Optional template ID to filter messages by conversation_id: Optional conversation ID to filter messages by. Special values: - None (omitted): Return all messages - "default": Return only default messages (conversation_id IS NULL) - Any other value: Return messages in that specific conversation vector_weight: Weight for vector search results in hybrid mode (default: 0.5) fts_weight: Weight for FTS results in hybrid mode (default: 0.5) start_date: Optional datetime to filter messages created after this date end_date: Optional datetime to filter messages created on or before this date (inclusive) Returns: List of (message_dict, score, metadata) tuples where: - message_dict contains id, text, role, created_at, agent_id - score is the final relevance score (RRF score for hybrid, rank-based for single mode) - metadata contains individual scores and ranking information """ # generate embedding for vector/hybrid search if query_text is provided query_embedding = None if query_text and search_mode in ["vector", "hybrid"]: embeddings = await self._generate_embeddings([query_text], actor) query_embedding = embeddings[0] # Check if we should fallback to timestamp-based retrieval if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: # Fallback to retrieving most recent messages when no search query is provided search_mode = "timestamp" # namespace is org-scoped namespace_name = await self._get_message_namespace_name(organization_id) # build filters all_filters = [] # role filter if roles: role_values = [r.value for r in roles] if len(role_values) == 1: all_filters.append(("role", "Eq", role_values[0])) else: all_filters.append(("role", "In", role_values)) # agent filter if agent_id: all_filters.append(("agent_id", "Eq", agent_id)) # project filter if project_id: all_filters.append(("project_id", "Eq", project_id)) # template filter if template_id: all_filters.append(("template_id", "Eq", template_id)) # conversation filter # three cases: # 1. conversation_id=None (omitted) -> return all messages (no filter) # 2. conversation_id="default" -> return only default messages (conversation_id is none), for backward compatibility # 3. conversation_id="xyz" -> return only messages in that conversation if conversation_id == "default": # "default" is reserved for default messages only (conversation_id is none) all_filters.append(("conversation_id", "Eq", None)) elif conversation_id is not None: # Specific conversation all_filters.append(("conversation_id", "Eq", conversation_id)) # date filters if start_date: # Convert to UTC to match stored timestamps if start_date.tzinfo is not None: start_date = start_date.astimezone(timezone.utc) all_filters.append(("created_at", "Gte", start_date)) if end_date: # make end_date inclusive of the entire day if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0: from datetime import timedelta end_date = end_date + timedelta(days=1) - timedelta(microseconds=1) # Convert to UTC to match stored timestamps if end_date.tzinfo is not None: end_date = end_date.astimezone(timezone.utc) all_filters.append(("created_at", "Lte", end_date)) # combine filters final_filter = None if len(all_filters) == 1: final_filter = all_filters[0] elif len(all_filters) > 1: final_filter = ("And", all_filters) try: # execute query result = await self._execute_query( namespace_name=namespace_name, search_mode=search_mode, query_embedding=query_embedding, query_text=query_text, top_k=top_k, include_attributes=True, filters=final_filter, vector_weight=vector_weight, fts_weight=fts_weight, ) # process results based on search mode if search_mode == "hybrid": # for hybrid mode, we get a multi-query response vector_results = self._process_message_query_results(result.results[0]) fts_results = self._process_message_query_results(result.results[1]) # use existing RRF method - it already returns metadata with ranks results_with_metadata = self._reciprocal_rank_fusion( vector_results=vector_results, fts_results=fts_results, get_id_func=lambda msg_dict: msg_dict["id"], vector_weight=vector_weight, fts_weight=fts_weight, top_k=top_k, ) # add raw scores to metadata if available vector_scores = {} for row in result.results[0].rows: if hasattr(row, "dist"): vector_scores[row.id] = row.dist fts_scores = {} for row in result.results[1].rows: if hasattr(row, "score"): fts_scores[row.id] = row.score # enhance metadata with raw scores enhanced_results = [] for msg_dict, rrf_score, metadata in results_with_metadata: msg_id = msg_dict["id"] if msg_id in vector_scores: metadata["vector_score"] = vector_scores[msg_id] if msg_id in fts_scores: metadata["fts_score"] = fts_scores[msg_id] enhanced_results.append((msg_dict, rrf_score, metadata)) return enhanced_results else: # for single queries (vector or fts) results = self._process_message_query_results(result) results_with_metadata = [] for idx, msg_dict in enumerate(results): metadata = { "combined_score": 1.0 / (idx + 1), "search_mode": search_mode, f"{search_mode}_rank": idx + 1, } # add raw score if available if hasattr(result.rows[idx], "dist"): metadata["vector_score"] = result.rows[idx].dist elif hasattr(result.rows[idx], "score"): metadata["fts_score"] = result.rows[idx].score results_with_metadata.append((msg_dict, metadata["combined_score"], metadata)) return results_with_metadata except Exception as e: logger.error(f"Failed to query messages from Turbopuffer: {e}") raise def _process_message_query_results(self, result) -> List[dict]: """Process results from a message query into message dicts. For RRF, we only need the rank order - scores are not used. """ messages = [] for row in result.rows: # Build message dict with key fields message_dict = { "id": row.id, "text": getattr(row, "text", ""), "organization_id": getattr(row, "organization_id", None), "agent_id": getattr(row, "agent_id", None), "role": getattr(row, "role", None), "created_at": getattr(row, "created_at", None), "conversation_id": getattr(row, "conversation_id", None), } messages.append(message_dict) return messages def _process_single_query_results( self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False ) -> List[Tuple[PydanticPassage, float]]: """Process results from a single query into passage objects with scores.""" passages_with_scores = [] for row in result.rows: # Extract tags from the result row passage_tags = getattr(row, "tags", []) or [] # Build metadata metadata = {} # Create a passage with minimal fields - embeddings are not returned from Turbopuffer passage = PydanticPassage( id=row.id, text=getattr(row, "text", ""), organization_id=getattr(row, "organization_id", None), archive_id=archive_id, # use the archive_id from the query created_at=getattr(row, "created_at", None), metadata_=metadata, tags=passage_tags, # Set the actual tags from the passage # Set required fields to empty/default values since we don't store embeddings embedding=[], # Empty embedding since we don't return it from Turbopuffer embedding_config=self.default_embedding_config, # No embedding config needed for retrieved passages ) # handle score based on search type if is_fts: # for FTS, use the BM25 score directly (higher is better) score = getattr(row, "$score", 0.0) else: # for vector search, convert distance to similarity score distance = getattr(row, "$dist", 0.0) score = 1.0 - distance passages_with_scores.append((passage, score)) return passages_with_scores def _reciprocal_rank_fusion( self, vector_results: List[Any], fts_results: List[Any], get_id_func: Callable[[Any], str], vector_weight: float, fts_weight: float, top_k: int, ) -> List[Tuple[Any, float, dict]]: """RRF implementation that works with any object type. RRF score = vector_weight * (1/(k + rank)) + fts_weight * (1/(k + rank)) where k is a constant (typically 60) to avoid division by zero This is a pure rank-based fusion following the standard RRF algorithm. Args: vector_results: List of items from vector search (ordered by relevance) fts_results: List of items from FTS (ordered by relevance) get_id_func: Function to extract ID from an item vector_weight: Weight for vector search results fts_weight: Weight for FTS results top_k: Number of results to return Returns: List of (item, score, metadata) tuples sorted by RRF score metadata contains ranks from each result list """ k = 60 # standard RRF constant from Cormack et al. (2009) # create rank mappings based on position in result lists # rank starts at 1, not 0 vector_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(vector_results)} fts_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(fts_results)} # combine all unique items from both result sets all_items = {} for item in vector_results: all_items[get_id_func(item)] = item for item in fts_results: all_items[get_id_func(item)] = item # calculate RRF scores based purely on ranks rrf_scores = {} score_metadata = {} for item_id in all_items: # RRF formula: sum of 1/(k + rank) across result lists # If item not in a list, we don't add anything (equivalent to rank = infinity) vector_rrf_score = 0.0 fts_rrf_score = 0.0 if item_id in vector_ranks: vector_rrf_score = vector_weight / (k + vector_ranks[item_id]) if item_id in fts_ranks: fts_rrf_score = fts_weight / (k + fts_ranks[item_id]) combined_score = vector_rrf_score + fts_rrf_score rrf_scores[item_id] = combined_score score_metadata[item_id] = { "combined_score": combined_score, # Final RRF score "vector_rank": vector_ranks.get(item_id), "fts_rank": fts_ranks.get(item_id), } # sort by RRF score and return with metadata sorted_results = sorted( [(all_items[iid], score, score_metadata[iid]) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True ) return sorted_results[:top_k] @trace_method @async_retry_with_backoff() async def delete_passage(self, archive_id: str, passage_id: str) -> bool: """Delete a passage from Turbopuffer.""" namespace_name = await self._get_archive_namespace_name(archive_id) try: # Run in thread pool for consistency (deletes are lightweight but use same wrapper) await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, deletes=[passage_id], ) logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}") return True except Exception as e: logger.error(f"Failed to delete passage from Turbopuffer: {e}") raise @trace_method @async_retry_with_backoff() async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool: """Delete multiple passages from Turbopuffer.""" if not passage_ids: return True namespace_name = await self._get_archive_namespace_name(archive_id) try: # Run in thread pool for consistency await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, deletes=passage_ids, ) logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}") return True except Exception as e: logger.error(f"Failed to delete passages from Turbopuffer: {e}") raise @trace_method @async_retry_with_backoff() async def delete_all_passages(self, archive_id: str) -> bool: """Delete all passages for an archive from Turbopuffer.""" from turbopuffer import AsyncTurbopuffer namespace_name = await self._get_archive_namespace_name(archive_id) try: async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: namespace = client.namespace(namespace_name) # Turbopuffer has a delete_all() method on namespace await namespace.delete_all() logger.info(f"Successfully deleted all passages for archive {archive_id}") return True except Exception as e: logger.error(f"Failed to delete all passages from Turbopuffer: {e}") raise @trace_method @async_retry_with_backoff() async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool: """Delete multiple messages from Turbopuffer.""" if not message_ids: return True namespace_name = await self._get_message_namespace_name(organization_id) try: # Run in thread pool for consistency await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, deletes=message_ids, ) logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer for agent {agent_id}") return True except Exception as e: logger.error(f"Failed to delete messages from Turbopuffer: {e}") raise @trace_method @async_retry_with_backoff() async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool: """Delete all messages for an agent from Turbopuffer.""" namespace_name = await self._get_message_namespace_name(organization_id) try: # Run in thread pool for consistency result = await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, delete_by_filter=("agent_id", "Eq", agent_id), ) logger.info(f"Successfully deleted all messages for agent {agent_id} (deleted {result.rows_affected if result else 0} rows)") return True except Exception as e: logger.error(f"Failed to delete all messages from Turbopuffer: {e}") raise # file/source passage methods @trace_method async def _get_file_passages_namespace_name(self, organization_id: str) -> str: """Get namespace name for file passages (org-scoped). Args: organization_id: Organization ID for namespace generation Returns: The org-scoped namespace name for file passages """ environment = settings.environment if environment: namespace_name = f"file_passages_{organization_id}_{environment.lower()}" else: namespace_name = f"file_passages_{organization_id}" return namespace_name @trace_method @async_retry_with_backoff() async def insert_file_passages( self, source_id: str, file_id: str, text_chunks: List[str], organization_id: str, actor: "PydanticUser", created_at: Optional[datetime] = None, ) -> List[PydanticPassage]: """Insert file passages into Turbopuffer using org-scoped namespace. Args: source_id: ID of the source containing the file file_id: ID of the file text_chunks: List of text chunks to store organization_id: Organization ID for the passages actor: User actor for embedding generation created_at: Optional timestamp for retroactive entries (defaults to current UTC time) Returns: List of PydanticPassage objects that were inserted """ if not text_chunks: return [] # filter out empty text chunks filtered_chunks = [text for text in text_chunks if text.strip()] if not filtered_chunks: logger.warning("All text chunks were empty, skipping file passage insertion") return [] # generate embeddings using the default config embeddings = await self._generate_embeddings(filtered_chunks, actor) namespace_name = await self._get_file_passages_namespace_name(organization_id) # handle timestamp - ensure UTC if created_at is None: timestamp = datetime.now(timezone.utc) else: # ensure the provided timestamp is timezone-aware and in UTC if created_at.tzinfo is None: # assume UTC if no timezone provided timestamp = created_at.replace(tzinfo=timezone.utc) else: # convert to UTC if in different timezone timestamp = created_at.astimezone(timezone.utc) # prepare column-based data for turbopuffer - optimized for batch insert ids = [] vectors = [] texts = [] organization_ids = [] source_ids = [] file_ids = [] created_ats = [] passages = [] for text, embedding in zip(filtered_chunks, embeddings): passage = PydanticPassage( text=text, file_id=file_id, source_id=source_id, embedding=embedding, embedding_config=self.default_embedding_config, organization_id=actor.organization_id, ) passages.append(passage) # append to columns ids.append(passage.id) vectors.append(embedding) texts.append(text) organization_ids.append(organization_id) source_ids.append(source_id) file_ids.append(file_id) created_ats.append(timestamp) # build column-based upsert data upsert_columns = { "id": ids, "vector": vectors, "text": texts, "organization_id": organization_ids, "source_id": source_ids, "file_id": file_ids, "created_at": created_ats, } try: # Use global semaphore to limit concurrent Turbopuffer writes async with _GLOBAL_TURBOPUFFER_SEMAPHORE: # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, upsert_columns=upsert_columns, distance_metric="cosine_distance", schema={"text": {"type": "string", "full_text_search": True}}, ) logger.info(f"Successfully inserted {len(ids)} file passages to Turbopuffer for source {source_id}, file {file_id}") return passages except Exception as e: logger.error(f"Failed to insert file passages to Turbopuffer: {e}") # check if it's a duplicate ID error if "duplicate" in str(e).lower(): logger.error("Duplicate passage IDs detected in batch") raise @trace_method async def query_file_passages( self, source_ids: List[str], organization_id: str, actor: "PydanticUser", query_text: Optional[str] = None, search_mode: str = "vector", # "vector", "fts", "hybrid" top_k: int = 10, file_id: Optional[str] = None, # optional filter by specific file vector_weight: float = 0.5, fts_weight: float = 0.5, ) -> List[Tuple[PydanticPassage, float, dict]]: """Query file passages from Turbopuffer using org-scoped namespace. Args: source_ids: List of source IDs to query organization_id: Organization ID for namespace lookup actor: User actor for embedding generation query_text: Text query for search search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector") top_k: Number of results to return file_id: Optional file ID to filter results to a specific file vector_weight: Weight for vector search results in hybrid mode (default: 0.5) fts_weight: Weight for FTS results in hybrid mode (default: 0.5) Returns: List of (passage, score, metadata) tuples with relevance rankings """ # generate embedding for vector/hybrid search if query_text is provided query_embedding = None if query_text and search_mode in ["vector", "hybrid"]: embeddings = await self._generate_embeddings([query_text], actor) query_embedding = embeddings[0] # check if we should fallback to timestamp-based retrieval if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: # fallback to retrieving most recent passages when no search query is provided search_mode = "timestamp" namespace_name = await self._get_file_passages_namespace_name(organization_id) # build filters - always filter by source_ids if len(source_ids) == 1: # single source_id, use Eq for efficiency filters = [("source_id", "Eq", source_ids[0])] else: # multiple source_ids, use In operator filters = [("source_id", "In", source_ids)] # add file filter if specified if file_id: filters.append(("file_id", "Eq", file_id)) # combine filters final_filter = filters[0] if len(filters) == 1 else ("And", filters) try: # use generic query executor result = await self._execute_query( namespace_name=namespace_name, search_mode=search_mode, query_embedding=query_embedding, query_text=query_text, top_k=top_k, include_attributes=["text", "organization_id", "source_id", "file_id", "created_at"], filters=final_filter, vector_weight=vector_weight, fts_weight=fts_weight, ) # process results based on search mode if search_mode == "hybrid": # for hybrid mode, we get a multi-query response vector_results = self._process_file_query_results(result.results[0]) fts_results = self._process_file_query_results(result.results[1], is_fts=True) # use RRF and include metadata with ranks results_with_metadata = self._reciprocal_rank_fusion( vector_results=[passage for passage, _ in vector_results], fts_results=[passage for passage, _ in fts_results], get_id_func=lambda p: p.id, vector_weight=vector_weight, fts_weight=fts_weight, top_k=top_k, ) return results_with_metadata else: # for single queries (vector, fts, timestamp) - add basic metadata is_fts = search_mode == "fts" results = self._process_file_query_results(result, is_fts=is_fts) # add simple metadata for single search modes results_with_metadata = [] for idx, (passage, score) in enumerate(results): metadata = { "combined_score": score, f"{search_mode}_rank": idx + 1, # add the rank for this search mode } results_with_metadata.append((passage, score, metadata)) return results_with_metadata except Exception as e: logger.error(f"Failed to query file passages from Turbopuffer: {e}") raise def _process_file_query_results(self, result, is_fts: bool = False) -> List[Tuple[PydanticPassage, float]]: """Process results from a file query into passage objects with scores.""" passages_with_scores = [] for row in result.rows: # build metadata metadata = {} # create a passage with minimal fields - embeddings are not returned from Turbopuffer passage = PydanticPassage( id=row.id, text=getattr(row, "text", ""), organization_id=getattr(row, "organization_id", None), source_id=getattr(row, "source_id", None), # get source_id from the row file_id=getattr(row, "file_id", None), created_at=getattr(row, "created_at", None), metadata_=metadata, tags=[], # set required fields to empty/default values since we don't store embeddings embedding=[], # empty embedding since we don't return it from Turbopuffer embedding_config=self.default_embedding_config, ) # handle score based on search type if is_fts: # for FTS, use the BM25 score directly (higher is better) score = getattr(row, "$score", 0.0) else: # for vector search, convert distance to similarity score distance = getattr(row, "$dist", 0.0) score = 1.0 - distance passages_with_scores.append((passage, score)) return passages_with_scores @trace_method @async_retry_with_backoff() async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool: """Delete all passages for a specific file from Turbopuffer.""" namespace_name = await self._get_file_passages_namespace_name(organization_id) try: # use delete_by_filter to only delete passages for this file # need to filter by both source_id and file_id filter_expr = ("And", [("source_id", "Eq", source_id), ("file_id", "Eq", file_id)]) # Run in thread pool for consistency result = await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, delete_by_filter=filter_expr, ) logger.info( f"Successfully deleted passages for file {file_id} from source {source_id} (deleted {result.rows_affected if result else 0} rows)" ) return True except Exception as e: logger.error(f"Failed to delete file passages from Turbopuffer: {e}") raise @trace_method @async_retry_with_backoff() async def delete_source_passages(self, source_id: str, organization_id: str) -> bool: """Delete all passages for a source from Turbopuffer.""" namespace_name = await self._get_file_passages_namespace_name(organization_id) try: # Run in thread pool for consistency result = await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, delete_by_filter=("source_id", "Eq", source_id), ) logger.info(f"Successfully deleted all passages for source {source_id} (deleted {result.rows_affected if result else 0} rows)") return True except Exception as e: logger.error(f"Failed to delete source passages from Turbopuffer: {e}") raise # tool methods @trace_method @async_retry_with_backoff() async def delete_tools(self, organization_id: str, tool_ids: List[str]) -> bool: """Delete tools from Turbopuffer. Args: organization_id: Organization ID for namespace lookup tool_ids: List of tool IDs to delete Returns: True if successful """ if not tool_ids: return True namespace_name = await self._get_tool_namespace_name(organization_id) try: # Run in thread pool for consistency await asyncio.to_thread( _run_turbopuffer_write_in_thread, api_key=self.api_key, region=self.region, namespace_name=namespace_name, deletes=tool_ids, ) logger.info(f"Successfully deleted {len(tool_ids)} tools from Turbopuffer") return True except Exception as e: logger.error(f"Failed to delete tools from Turbopuffer: {e}") raise @trace_method async def query_tools( self, organization_id: str, actor: "PydanticUser", query_text: Optional[str] = None, search_mode: str = "hybrid", # "vector", "fts", "hybrid", "timestamp" top_k: int = 50, tool_types: Optional[List[str]] = None, tags: Optional[List[str]] = None, vector_weight: float = 0.5, fts_weight: float = 0.5, ) -> List[Tuple[dict, float, dict]]: """Query tools from Turbopuffer using semantic search. Args: organization_id: Organization ID for namespace lookup actor: User actor for embedding generation query_text: Text query for search search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" top_k: Number of results to return tool_types: Optional list of tool types to filter by tags: Optional list of tags to filter by (match any) vector_weight: Weight for vector search in hybrid mode fts_weight: Weight for FTS in hybrid mode Returns: List of (tool_dict, score, metadata) tuples """ # Generate embedding for vector/hybrid search query_embedding = None if query_text and search_mode in ["vector", "hybrid"]: embeddings = await self._generate_embeddings([query_text], actor) query_embedding = embeddings[0] if embeddings else None # Fallback to timestamp-based retrieval when no query if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: search_mode = "timestamp" namespace_name = await self._get_tool_namespace_name(organization_id) # Build filters all_filters = [] if tool_types: if len(tool_types) == 1: all_filters.append(("tool_type", "Eq", tool_types[0])) else: all_filters.append(("tool_type", "In", tool_types)) if tags: all_filters.append(("tags", "ContainsAny", tags)) # Combine filters final_filter = None if len(all_filters) == 1: final_filter = all_filters[0] elif len(all_filters) > 1: final_filter = ("And", all_filters) try: result = await self._execute_query( namespace_name=namespace_name, search_mode=search_mode, query_embedding=query_embedding, query_text=query_text, top_k=top_k, include_attributes=["text", "name", "organization_id", "tool_type", "tags", "created_at"], filters=final_filter, vector_weight=vector_weight, fts_weight=fts_weight, ) if search_mode == "hybrid": vector_results = self._process_tool_query_results(result.results[0]) fts_results = self._process_tool_query_results(result.results[1]) results_with_metadata = self._reciprocal_rank_fusion( vector_results=vector_results, fts_results=fts_results, get_id_func=lambda d: d["id"], vector_weight=vector_weight, fts_weight=fts_weight, top_k=top_k, ) return results_with_metadata else: results = self._process_tool_query_results(result) results_with_metadata = [] for idx, tool_dict in enumerate(results): metadata = { "combined_score": 1.0 / (idx + 1), "search_mode": search_mode, f"{search_mode}_rank": idx + 1, } results_with_metadata.append((tool_dict, metadata["combined_score"], metadata)) return results_with_metadata except Exception as e: logger.error(f"Failed to query tools from Turbopuffer: {e}") raise def _process_tool_query_results(self, result) -> List[dict]: """Process results from a tool query into tool dicts.""" tools = [] for row in result.rows: tool_dict = { "id": row.id, "text": getattr(row, "text", ""), "name": getattr(row, "name", ""), "organization_id": getattr(row, "organization_id", None), "tool_type": getattr(row, "tool_type", None), "tags": getattr(row, "tags", []), "created_at": getattr(row, "created_at", None), } tools.append(tool_dict) return tools