Revert "fix: truncate oversized text in embedding requests (#9196)" This reverts commit a9c342087e022519c63d62fb76b72aed8859539b.
2118 lines
86 KiB
Python
2118 lines
86 KiB
Python
"""Turbopuffer utilities for archival memory storage."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import random
|
|
from datetime import datetime, timezone
|
|
from functools import wraps
|
|
from typing import Any, Callable, List, Optional, Tuple, TypeVar
|
|
|
|
import httpx
|
|
|
|
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
|
from letta.errors import LettaInvalidArgumentError
|
|
from letta.otel.tracing import trace_method, log_event
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.enums import MessageRole, TagMatchMode
|
|
from letta.schemas.passage import Passage as PydanticPassage
|
|
from letta.settings import model_settings, settings
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Generic type variable declared for the async retry decorator.
T = TypeVar("T")

# Default retry configuration for turbopuffer operations.
TPUF_MAX_RETRIES = 3  # retries allowed after the initial attempt
TPUF_INITIAL_DELAY = 1.0  # seconds to wait before the first retry
TPUF_EXPONENTIAL_BASE = 2.0  # factor the delay is multiplied by after each retry
TPUF_JITTER = True  # whether each delay is stretched by a small random amount
|
|
|
|
|
|
def is_transient_error(error: Exception) -> bool:
    """Decide whether a failure looks temporary and is therefore worth retrying.

    Args:
        error: The exception raised by the failed operation.

    Returns:
        True when the exception is an httpx connect, timeout, or network
        error, or when its message (compared case-insensitively) contains one
        of the known connection-failure phrases; False otherwise.
    """
    # Typed httpx failures: connection setup, timeouts, and other network errors.
    if isinstance(error, (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError)):
        return True

    # Otherwise fall back to substring matching on the lowercased message, which
    # catches connection problems surfaced through other exception types.
    message = str(error).lower()
    markers = (
        "connect call failed",
        "connection refused",
        "connection reset",
        "connection timed out",
        "temporary failure",
        "name resolution",
        "dns",
        "network unreachable",
        "no route to host",
        "ssl handshake",
    )
    return any(marker in message for marker in markers)
|
|
|
|
|
|
def async_retry_with_backoff(
    max_retries: int = TPUF_MAX_RETRIES,
    initial_delay: float = TPUF_INITIAL_DELAY,
    exponential_base: float = TPUF_EXPONENTIAL_BASE,
    jitter: bool = TPUF_JITTER,
):
    """Decorator for async functions that retries on transient errors with exponential backoff.

    The wrapped coroutine is attempted once and then retried up to
    ``max_retries`` more times. Only exceptions accepted by
    ``is_transient_error`` are retried; anything else propagates immediately.
    When the retries are exhausted the last exception is re-raised unchanged.

    Args:
        max_retries: Maximum number of retry attempts (after the first call)
        initial_delay: Delay before the first retry, in seconds
        exponential_base: Factor the delay is multiplied by after every retry
        jitter: Whether to stretch each subsequent delay by up to 10%

    Returns:
        Decorated async function with retry logic
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            num_retries = 0
            delay = initial_delay

            while True:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # Non-transient errors are never retried.
                    if not is_transient_error(e):
                        raise

                    num_retries += 1

                    # Check the budget BEFORE announcing a retry, so the final
                    # failure is not logged as "Retrying in ..." with an attempt
                    # counter that exceeds max_retries.
                    if num_retries > max_retries:
                        log_event(
                            "turbopuffer_max_retries_exceeded",
                            {
                                "max_retries": max_retries,
                                "error_type": type(e).__name__,
                                "error": str(e),
                                "function": func.__name__,
                            },
                        )
                        logger.error(
                            f"Turbopuffer operation '{func.__name__}' failed after {max_retries} retries: {e}"
                        )
                        raise

                    # A retry will actually happen: record it.
                    log_event(
                        "turbopuffer_retry_attempt",
                        {
                            "attempt": num_retries,
                            "delay": delay,
                            "error_type": type(e).__name__,
                            "error": str(e),
                            "function": func.__name__,
                        },
                    )
                    logger.warning(
                        f"Turbopuffer operation '{func.__name__}' failed with transient error "
                        f"(attempt {num_retries}/{max_retries}): {e}. Retrying in {delay:.1f}s..."
                    )

                    # Wait with exponential backoff
                    await asyncio.sleep(delay)

                    # Calculate next delay with optional jitter
                    delay *= exponential_base
                    if jitter:
                        delay *= 1 + random.random() * 0.1  # Add up to 10% jitter

        return wrapper

    return decorator
|
|
|
|
# Shared cap on concurrent Turbopuffer operations so the service is not overwhelmed.
# Kept separate from the embedding semaphore because Turbopuffer can handle more concurrency.
_GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5)
|
|
|
|
|
|
def _run_turbopuffer_write_in_thread(
    api_key: str,
    region: str,
    namespace_name: str,
    upsert_columns: Optional[dict] = None,
    deletes: Optional[list] = None,
    delete_by_filter: Optional[tuple] = None,
    distance_metric: str = "cosine_distance",
    schema: Optional[dict] = None,
):
    """
    Sync wrapper to run turbopuffer write in isolated event loop.

    Turbopuffer's async write() does CPU-intensive base64 encoding of vectors
    synchronously in async functions, blocking the event loop. Running it in
    a thread pool with an isolated event loop prevents blocking.

    Args:
        api_key: Turbopuffer API key used to build the client
        region: Turbopuffer region used to build the client
        namespace_name: Namespace the write is issued against
        upsert_columns: Column-oriented rows to upsert (omitted when empty)
        deletes: IDs to delete (omitted when empty)
        delete_by_filter: Filter expression selecting rows to delete (omitted when empty)
        distance_metric: Distance metric, always forwarded to the write call
        schema: Attribute schema (omitted when empty)

    Returns:
        Whatever ``namespace.write`` returns.
    """
    from turbopuffer import AsyncTurbopuffer

    async def do_write():
        async with AsyncTurbopuffer(api_key=api_key, region=region) as client:
            namespace = client.namespace(namespace_name)

            # Build write kwargs; only truthy optional arguments are forwarded
            kwargs = {"distance_metric": distance_metric}
            if upsert_columns:
                kwargs["upsert_columns"] = upsert_columns
            if deletes:
                kwargs["deletes"] = deletes
            if delete_by_filter:
                kwargs["delete_by_filter"] = delete_by_filter
            if schema:
                kwargs["schema"] = schema

            return await namespace.write(**kwargs)

    # asyncio.run creates a fresh loop for this worker thread, finalizes async
    # generators and leftover tasks, closes the loop, and does not leave a
    # closed loop installed as the thread's current event loop.
    return asyncio.run(do_write())
|
|
|
|
|
|
def should_use_tpuf() -> bool:
    """Report whether Turbopuffer is enabled and fully configured.

    Returns:
        True only when ``settings.use_tpuf``, ``settings.tpuf_api_key`` and
        ``model_settings.openai_api_key`` are all truthy.
    """
    if not settings.use_tpuf:
        return False
    if not settings.tpuf_api_key:
        return False
    # We need OpenAI since we default to their embedding model
    return bool(model_settings.openai_api_key)
|
|
|
|
|
|
def should_use_tpuf_for_messages() -> bool:
    """Report whether messages should be stored in Turbopuffer.

    Returns:
        True when ``should_use_tpuf()`` holds and ``settings.embed_all_messages``
        is truthy.
    """
    if not should_use_tpuf():
        return False
    return bool(settings.embed_all_messages)
|
|
|
|
|
|
def should_use_tpuf_for_tools() -> bool:
    """Report whether tools should be stored in Turbopuffer.

    Returns:
        True when ``should_use_tpuf()`` holds and ``settings.embed_tools`` is
        truthy.
    """
    if not should_use_tpuf():
        return False
    return bool(settings.embed_tools)
|
|
|
|
|
|
class TurbopufferClient:
|
|
"""Client for managing archival memory with Turbopuffer vector database."""
|
|
|
|
default_embedding_config = EmbeddingConfig(
|
|
embedding_model="text-embedding-3-small",
|
|
embedding_endpoint_type="openai",
|
|
embedding_endpoint="https://api.openai.com/v1",
|
|
embedding_dim=1536,
|
|
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
|
)
|
|
|
|
def __init__(self, api_key: Optional[str] = None, region: Optional[str] = None):
    """Initialize Turbopuffer client.

    Args:
        api_key: Turbopuffer API key; falls back to ``settings.tpuf_api_key``
        region: Turbopuffer region; falls back to ``settings.tpuf_region``

    Raises:
        ValueError: If no API key is available from either source.
    """
    self.api_key = api_key or settings.tpuf_api_key
    self.region = region or settings.tpuf_region

    # Fail fast: reject a missing key before importing and constructing the
    # manager objects, which are useless without a usable client.
    if not self.api_key:
        raise ValueError("Turbopuffer API key not provided")

    # Imported locally rather than at module level
    # (presumably to avoid a circular import - TODO confirm).
    from letta.services.agent_manager import AgentManager
    from letta.services.archive_manager import ArchiveManager

    self.archive_manager = ArchiveManager()
    self.agent_manager = AgentManager()
|
|
@trace_method
async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]:
    """Embed texts with the client's default embedding configuration.

    Texts that are empty after stripping whitespace are dropped before the
    request, so the returned list lines up with the non-blank inputs and can
    be shorter than ``texts``.

    Args:
        texts: List of texts to embed
        actor: User actor for embedding generation

    Returns:
        List of embedding vectors (empty when every input is blank)
    """
    from letta.llm_api.llm_client import LLMClient

    # keep only texts with visible content
    non_blank = [candidate for candidate in texts if candidate.strip()]

    # nothing to embed: skip the provider call entirely
    if not non_blank:
        return []

    config = self.default_embedding_config
    client = LLMClient.create(provider_type=config.embedding_endpoint_type, actor=actor)
    return await client.request_embeddings(non_blank, config)
|
|
|
|
@trace_method
async def _get_archive_namespace_name(self, archive_id: str) -> str:
    """Resolve the namespace name for an archive.

    Delegates to ``archive_manager.get_or_set_vector_db_namespace_async``.

    Args:
        archive_id: ID of the archive

    Returns:
        The namespace name reported by the archive manager
    """
    namespace_name = await self.archive_manager.get_or_set_vector_db_namespace_async(archive_id)
    return namespace_name
|
|
|
|
@trace_method
async def _get_message_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped namespace name for messages.

    The name is ``messages_<organization_id>``, with the lowercased
    ``settings.environment`` appended after an underscore when it is set.

    Args:
        organization_id: Organization ID for namespace generation

    Returns:
        The org-scoped namespace name for messages
    """
    env = settings.environment
    suffix = f"_{env.lower()}" if env else ""
    return f"messages_{organization_id}{suffix}"
|
|
|
|
@trace_method
async def _get_tool_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped namespace name for tools.

    The name is ``tools_<organization_id>``, with the lowercased
    ``settings.environment`` appended after an underscore when it is set.

    Args:
        organization_id: Organization ID for namespace generation

    Returns:
        The org-scoped namespace name for tools
    """
    env = settings.environment
    suffix = f"_{env.lower()}" if env else ""
    return f"tools_{organization_id}{suffix}"
|
|
|
|
def _extract_tool_text(self, tool: "PydanticTool") -> str:
    """Serialize a tool's searchable fields into a JSON string for embedding.

    The payload always carries ``name`` and ``description`` (empty strings
    when unset). It additionally carries ``schema_description`` when the JSON
    schema has a description that differs from the tool's own, ``parameters``
    as a list of ``"<name> (<type>)[: <description>]"`` strings when the
    schema declares properties, and ``tags`` when the tool has any.

    Args:
        tool: The tool to extract text from

    Returns:
        JSON-formatted string containing tool information
    """
    payload = {"name": tool.name or "", "description": tool.description or ""}

    schema = tool.json_schema
    if schema:
        # Only keep the schema's description when it adds something new.
        alt_description = schema.get("description", "")
        if alt_description and alt_description != tool.description:
            payload["schema_description"] = alt_description

        params = schema.get("parameters", {})
        if params:
            rendered = []
            for key, spec in params.get("properties", {}).items():
                label = f"{key} ({spec.get('type', 'any')})"
                detail = spec.get("description", "")
                rendered.append(f"{label}: {detail}" if detail else label)
            if rendered:
                payload["parameters"] = rendered

    # Tags give the embedding extra context.
    if tool.tags:
        payload["tags"] = tool.tags

    return json.dumps(payload)
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def insert_tools(
    self,
    tools: List["PydanticTool"],
    organization_id: str,
    actor: "PydanticUser",
) -> bool:
    """Insert tools into Turbopuffer.

    Each tool is rendered to text with ``_extract_tool_text``, embedded with
    the default embedding configuration, and upserted into the org-scoped
    tools namespace.

    Args:
        tools: List of tools to store
        organization_id: Organization ID for the tools
        actor: User actor for embedding generation

    Returns:
        True if successful (also when there was nothing to insert)

    Raises:
        ValueError: If the number of embeddings returned does not match the
            number of tools being inserted.
    """
    if not tools:
        return True

    # Extract text and filter out empty content
    tool_texts = []
    valid_tools = []
    for tool in tools:
        text = self._extract_tool_text(tool)
        if text.strip():
            tool_texts.append(text)
            valid_tools.append(tool)

    if not valid_tools:
        logger.warning("All tools had empty text content, skipping insertion")
        return True

    # Generate embeddings
    embeddings = await self._generate_embeddings(tool_texts, actor)

    # zip() would silently truncate on a mismatch, dropping tools while this
    # method still reported success; surface the problem instead.
    if len(embeddings) != len(valid_tools):
        raise ValueError(f"Expected {len(valid_tools)} tool embeddings but received {len(embeddings)}")

    namespace_name = await self._get_tool_namespace_name(organization_id)

    # Prepare column-based data
    ids = []
    vectors = []
    texts = []
    names = []
    organization_ids = []
    tool_types = []
    tags_arrays = []
    created_ats = []

    for tool, text, embedding in zip(valid_tools, tool_texts, embeddings):
        ids.append(tool.id)
        vectors.append(embedding)
        texts.append(text)
        names.append(tool.name or "")
        organization_ids.append(organization_id)
        tool_types.append(tool.tool_type.value if tool.tool_type else "custom")
        tags_arrays.append(tool.tags or [])
        # fall back to "now" (UTC) when the tool carries no creation time
        created_ats.append(getattr(tool, "created_at", None) or datetime.now(timezone.utc))

    upsert_columns = {
        "id": ids,
        "vector": vectors,
        "text": texts,
        "name": names,
        "organization_id": organization_ids,
        "tool_type": tool_types,
        "tags": tags_arrays,
        "created_at": created_ats,
    }

    try:
        # Use global semaphore to limit concurrent Turbopuffer writes
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
            await asyncio.to_thread(
                _run_turbopuffer_write_in_thread,
                api_key=self.api_key,
                region=self.region,
                namespace_name=namespace_name,
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
            logger.info(f"Successfully inserted {len(ids)} tools to Turbopuffer")
            return True

    except Exception as e:
        logger.error(f"Failed to insert tools to Turbopuffer: {e}")
        raise
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def insert_archival_memories(
    self,
    archive_id: str,
    text_chunks: List[str],
    passage_ids: List[str],
    organization_id: str,
    actor: "PydanticUser",
    tags: Optional[List[str]] = None,
    created_at: Optional[datetime] = None,
    embeddings: Optional[List[List[float]]] = None,
) -> List[PydanticPassage]:
    """Insert passages into Turbopuffer.

    Blank chunks are skipped. All argument validation happens before any
    embedding request or namespace lookup, so invalid calls are rejected
    without paying for those remote operations.

    Args:
        archive_id: ID of the archive
        text_chunks: List of text chunks to store
        passage_ids: List of passage IDs (must match 1:1 with text_chunks)
        organization_id: Organization ID for the passages
        actor: User actor for embedding generation
        tags: Optional list of tags to attach to all passages
        created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
        embeddings: Optional pre-computed embeddings (must match 1:1 with text_chunks).
            Used only when the first non-blank chunk's embedding has the
            default config's dimension; otherwise embeddings are regenerated.

    Returns:
        List of PydanticPassage objects that were inserted (empty when every
        chunk was blank)

    Raises:
        LettaInvalidArgumentError: If embeddings is given with a length different from text_chunks.
        ValueError: If passage_ids is empty or its length differs from text_chunks.
    """
    # filter out empty text chunks, remembering each chunk's original index
    filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]

    if not filtered_chunks:
        logger.warning("All text chunks were empty, skipping insertion")
        return []

    filtered_texts = [text for _, text in filtered_chunks]

    # use provided embeddings only if dimensions match TPUF's expected dimension
    use_provided_embeddings = False
    if embeddings is not None:
        if len(embeddings) != len(text_chunks):
            raise LettaInvalidArgumentError(
                f"embeddings length ({len(embeddings)}) must match text_chunks length ({len(text_chunks)})",
                argument_name="embeddings",
            )
        # only the first non-empty chunk's embedding is sampled for the dimension check
        sample_embedding = embeddings[filtered_chunks[0][0]]
        if sample_embedding is not None and len(sample_embedding) == self.default_embedding_config.embedding_dim:
            use_provided_embeddings = True
            filtered_embeddings = [embeddings[i] for i, _ in filtered_chunks]
        else:
            logger.debug(
                f"Embedding dimension mismatch (got {len(sample_embedding) if sample_embedding else 'None'}, "
                f"expected {self.default_embedding_config.embedding_dim}), regenerating embeddings"
            )

    # passage_ids must be provided for dual-write consistency; checked before
    # the embedding request so bad input fails without a remote call
    if not passage_ids:
        raise ValueError("passage_ids must be provided for Turbopuffer insertion")
    if len(passage_ids) != len(text_chunks):
        raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")

    if not use_provided_embeddings:
        filtered_embeddings = await self._generate_embeddings(filtered_texts, actor)

    namespace_name = await self._get_archive_namespace_name(archive_id)

    # handle timestamp - ensure UTC
    if created_at is None:
        timestamp = datetime.now(timezone.utc)
    elif created_at.tzinfo is None:
        # assume UTC if no timezone provided
        timestamp = created_at.replace(tzinfo=timezone.utc)
    else:
        # convert to UTC if in different timezone
        timestamp = created_at.astimezone(timezone.utc)

    # prepare column-based data for turbopuffer - optimized for batch insert
    ids = []
    vectors = []
    texts = []
    organization_ids = []
    archive_ids = []
    created_ats = []
    tags_arrays = []  # Store tags as arrays
    passages = []

    for (original_idx, text), embedding in zip(filtered_chunks, filtered_embeddings):
        passage_id = passage_ids[original_idx]

        # append to columns
        ids.append(passage_id)
        vectors.append(embedding)
        texts.append(text)
        organization_ids.append(organization_id)
        archive_ids.append(archive_id)
        created_ats.append(timestamp)
        tags_arrays.append(tags or [])  # Store tags as array

        # Create PydanticPassage object
        passage = PydanticPassage(
            id=passage_id,
            text=text,
            organization_id=organization_id,
            archive_id=archive_id,
            created_at=timestamp,
            metadata_={},
            tags=tags or [],  # Include tags in the passage
            embedding=embedding,
            embedding_config=self.default_embedding_config,  # Will be set by caller if needed
        )
        passages.append(passage)

    # build column-based upsert data
    upsert_columns = {
        "id": ids,
        "vector": vectors,
        "text": texts,
        "organization_id": organization_ids,
        "archive_id": archive_ids,
        "created_at": created_ats,
        "tags": tags_arrays,  # Add tags as array column
    }

    try:
        # Use global semaphore to limit concurrent Turbopuffer writes
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
            await asyncio.to_thread(
                _run_turbopuffer_write_in_thread,
                api_key=self.api_key,
                region=self.region,
                namespace_name=namespace_name,
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
            logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}")
            return passages

    except Exception as e:
        logger.error(f"Failed to insert passages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate passage IDs detected in batch")
        raise
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def insert_messages(
    self,
    agent_id: str,
    message_texts: List[str],
    message_ids: List[str],
    organization_id: str,
    actor: "PydanticUser",
    roles: List[MessageRole],
    created_ats: List[datetime],
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    conversation_ids: Optional[List[Optional[str]]] = None,
) -> bool:
    """Insert messages into Turbopuffer.

    Blank messages are skipped. All argument-length validation happens before
    the embedding request, so invalid calls are rejected without paying for a
    remote embedding call.

    Args:
        agent_id: ID of the agent
        message_texts: List of message text content to store
        message_ids: List of message IDs (must match 1:1 with message_texts)
        organization_id: Organization ID for the messages
        actor: User actor for embedding generation
        roles: List of message roles corresponding to each message
        created_ats: List of creation timestamps for each message
        project_id: Optional project ID for all messages
        template_id: Optional template ID for all messages
        conversation_ids: Optional list of conversation IDs (one per message, must match 1:1 with message_texts)

    Returns:
        True if successful (also when every message text was blank)

    Raises:
        ValueError: If message_ids is empty, or if message_texts, roles,
            created_ats or conversation_ids differ in length from message_ids.
    """
    # filter out empty message texts, remembering each message's original index
    filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]

    if not filtered_messages:
        logger.warning("All message texts were empty, skipping insertion")
        return True

    # validation checks - done before the embedding request so bad input fails
    # without a remote call
    if not message_ids:
        raise ValueError("message_ids must be provided for Turbopuffer insertion")
    if len(message_ids) != len(message_texts):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
    if len(message_ids) != len(roles):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
    if len(message_ids) != len(created_ats):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
    if conversation_ids is not None and len(conversation_ids) != len(message_ids):
        raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")

    # generate embeddings using the default config
    filtered_texts = [text for _, text in filtered_messages]
    embeddings = await self._generate_embeddings(filtered_texts, actor)

    namespace_name = await self._get_message_namespace_name(organization_id)

    # prepare column-based data for turbopuffer - optimized for batch insert
    ids = []
    vectors = []
    texts = []
    organization_ids_list = []
    agent_ids_list = []
    message_roles = []
    created_at_timestamps = []
    project_ids_list = []
    template_ids_list = []
    conversation_ids_list = []

    for (original_idx, text), embedding in zip(filtered_messages, embeddings):
        created_at = created_ats[original_idx]

        # ensure the provided timestamp is timezone-aware and in UTC
        if created_at.tzinfo is None:
            # assume UTC if no timezone provided
            timestamp = created_at.replace(tzinfo=timezone.utc)
        else:
            # convert to UTC if in different timezone
            timestamp = created_at.astimezone(timezone.utc)

        # append to columns
        ids.append(message_ids[original_idx])
        vectors.append(embedding)
        texts.append(text)
        organization_ids_list.append(organization_id)
        agent_ids_list.append(agent_id)
        message_roles.append(roles[original_idx].value)
        created_at_timestamps.append(timestamp)
        project_ids_list.append(project_id)
        template_ids_list.append(template_id)
        conversation_ids_list.append(conversation_ids[original_idx] if conversation_ids else None)

    # build column-based upsert data
    upsert_columns = {
        "id": ids,
        "vector": vectors,
        "text": texts,
        "organization_id": organization_ids_list,
        "agent_id": agent_ids_list,
        "role": message_roles,
        "created_at": created_at_timestamps,
    }

    # only include conversation_id if it's provided
    if conversation_ids is not None:
        upsert_columns["conversation_id"] = conversation_ids_list

    # only include project_id if it's provided
    if project_id is not None:
        upsert_columns["project_id"] = project_ids_list

    # only include template_id if it's provided
    if template_id is not None:
        upsert_columns["template_id"] = template_ids_list

    try:
        # Use global semaphore to limit concurrent Turbopuffer writes
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
            await asyncio.to_thread(
                _run_turbopuffer_write_in_thread,
                api_key=self.api_key,
                region=self.region,
                namespace_name=namespace_name,
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={
                    "text": {"type": "string", "full_text_search": True},
                    "conversation_id": {"type": "string"},
                },
            )
            logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
            return True

    except Exception as e:
        logger.error(f"Failed to insert messages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate message IDs detected in batch")
        raise
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def _execute_query(
    self,
    namespace_name: str,
    search_mode: str,
    query_embedding: Optional[List[float]],
    query_text: Optional[str],
    top_k: int,
    include_attributes: List[str],
    filters: Optional[Any] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> Any:
    """Run one Turbopuffer query (or a hybrid multi-query) against a namespace.

    Args:
        namespace_name: Turbopuffer namespace to query
        search_mode: "vector", "fts", "hybrid", or "timestamp"
        query_embedding: Embedding for vector search
        query_text: Text for full-text search
        top_k: Number of results to return
        include_attributes: Attributes to include in results
        filters: Turbopuffer filter expression (only sent when truthy)
        vector_weight: Weight for vector search in hybrid mode; accepted but
            not used inside this method
        fts_weight: Weight for FTS in hybrid mode; accepted but not used
            inside this method

    Returns:
        Raw Turbopuffer query results, or the multi-query response (vector
        query first, full-text query second) in hybrid mode

    Raises:
        ValueError: If the inputs required by the search mode are missing,
            the mode is unknown, or Turbopuffer raises NotFoundError.
    """
    from turbopuffer import AsyncTurbopuffer
    from turbopuffer.types import QueryParam

    # validate inputs based on search mode
    if search_mode == "vector" and query_embedding is None:
        raise ValueError("query_embedding is required for vector search mode")
    if search_mode == "fts" and query_text is None:
        raise ValueError("query_text is required for FTS search mode")
    if search_mode == "hybrid" and (query_embedding is None or query_text is None):
        raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
    if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
        raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")

    def build_params(rank_by):
        # Shared shape of every query; only the ranking expression varies.
        params = {
            "rank_by": rank_by,
            "top_k": top_k,
            "include_attributes": include_attributes,
        }
        if filters:
            params["filters"] = filters
        return params

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            namespace = client.namespace(namespace_name)

            if search_mode == "hybrid":
                # one vector query and one full-text query, sent together
                sub_queries = [
                    build_params(("vector", "ANN", query_embedding)),
                    build_params(("text", "BM25", query_text)),
                ]
                return await namespace.multi_query(queries=[QueryParam(**q) for q in sub_queries])

            if search_mode == "timestamp":
                # retrieve most recent items by timestamp
                rank_by = ("created_at", "desc")
            elif search_mode == "vector":
                rank_by = ("vector", "ANN", query_embedding)
            else:  # "fts" - the only mode left after validation
                rank_by = ("text", "BM25", query_text)

            return await namespace.query(**build_params(rank_by))
    except Exception as e:
        # Wrap turbopuffer errors with user-friendly messages
        from turbopuffer import NotFoundError

        # Re-raise other errors as-is
        if not isinstance(e, NotFoundError):
            raise

        # Extract just the error message without implementation details
        error_msg = str(e)
        lowered = error_msg.lower()
        if "namespace" in lowered and "not found" in lowered:
            raise ValueError("No conversation history found. Please send a message first to enable search.") from e
        raise ValueError(f"Search data not found: {error_msg}") from e
|
|
|
|
@trace_method
|
|
async def query_passages(
|
|
self,
|
|
archive_id: str,
|
|
actor: "PydanticUser",
|
|
query_text: Optional[str] = None,
|
|
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
|
top_k: int = 10,
|
|
tags: Optional[List[str]] = None,
|
|
tag_match_mode: TagMatchMode = TagMatchMode.ANY,
|
|
vector_weight: float = 0.5,
|
|
fts_weight: float = 0.5,
|
|
start_date: Optional[datetime] = None,
|
|
end_date: Optional[datetime] = None,
|
|
) -> List[Tuple[PydanticPassage, float, dict]]:
|
|
"""Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
|
|
|
|
Args:
|
|
archive_id: ID of the archive
|
|
actor: User actor for embedding generation
|
|
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
|
|
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
|
|
top_k: Number of results to return
|
|
tags: Optional list of tags to filter by
|
|
tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY
|
|
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
|
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
|
start_date: Optional datetime to filter passages created after this date
|
|
end_date: Optional datetime to filter passages created on or before this date (inclusive)
|
|
|
|
Returns:
|
|
List of (passage, score, metadata) tuples with relevance rankings
|
|
"""
|
|
# generate embedding for vector/hybrid search if query_text is provided
|
|
query_embedding = None
|
|
if query_text and search_mode in ["vector", "hybrid"]:
|
|
embeddings = await self._generate_embeddings([query_text], actor)
|
|
query_embedding = embeddings[0]
|
|
|
|
# Check if we should fallback to timestamp-based retrieval
|
|
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
|
|
# Fallback to retrieving most recent passages when no search query is provided
|
|
search_mode = "timestamp"
|
|
|
|
namespace_name = await self._get_archive_namespace_name(archive_id)
|
|
|
|
# build tag filter conditions
|
|
tag_filter = None
|
|
if tags:
|
|
if tag_match_mode == TagMatchMode.ALL:
|
|
# For ALL mode, need to check each tag individually with Contains
|
|
tag_conditions = []
|
|
for tag in tags:
|
|
tag_conditions.append(("tags", "Contains", tag))
|
|
if len(tag_conditions) == 1:
|
|
tag_filter = tag_conditions[0]
|
|
else:
|
|
tag_filter = ("And", tag_conditions)
|
|
else: # tag_match_mode == TagMatchMode.ANY
|
|
# For ANY mode, use ContainsAny to match any of the tags
|
|
tag_filter = ("tags", "ContainsAny", tags)
|
|
|
|
# build date filter conditions
|
|
date_filters = []
|
|
if start_date:
|
|
# Convert to UTC to match stored timestamps
|
|
if start_date.tzinfo is not None:
|
|
start_date = start_date.astimezone(timezone.utc)
|
|
date_filters.append(("created_at", "Gte", start_date))
|
|
if end_date:
|
|
# if end_date has no time component (is at midnight), adjust to end of day
|
|
# to make the filter inclusive of the entire day
|
|
if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
|
|
from datetime import timedelta
|
|
|
|
# add 1 day and subtract 1 microsecond to get 23:59:59.999999
|
|
end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
|
|
# Convert to UTC to match stored timestamps
|
|
if end_date.tzinfo is not None:
|
|
end_date = end_date.astimezone(timezone.utc)
|
|
date_filters.append(("created_at", "Lte", end_date))
|
|
|
|
# combine all filters
|
|
all_filters = []
|
|
if tag_filter:
|
|
all_filters.append(tag_filter)
|
|
if date_filters:
|
|
all_filters.extend(date_filters)
|
|
|
|
# create final filter expression
|
|
final_filter = None
|
|
if len(all_filters) == 1:
|
|
final_filter = all_filters[0]
|
|
elif len(all_filters) > 1:
|
|
final_filter = ("And", all_filters)
|
|
|
|
try:
|
|
# use generic query executor
|
|
result = await self._execute_query(
|
|
namespace_name=namespace_name,
|
|
search_mode=search_mode,
|
|
query_embedding=query_embedding,
|
|
query_text=query_text,
|
|
top_k=top_k,
|
|
include_attributes=["text", "organization_id", "archive_id", "created_at", "tags"],
|
|
filters=final_filter,
|
|
vector_weight=vector_weight,
|
|
fts_weight=fts_weight,
|
|
)
|
|
|
|
# process results based on search mode
|
|
if search_mode == "hybrid":
|
|
# for hybrid mode, we get a multi-query response
|
|
vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
|
|
fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
|
|
# use RRF and include metadata with ranks
|
|
results_with_metadata = self._reciprocal_rank_fusion(
|
|
vector_results=[passage for passage, _ in vector_results],
|
|
fts_results=[passage for passage, _ in fts_results],
|
|
get_id_func=lambda p: p.id,
|
|
vector_weight=vector_weight,
|
|
fts_weight=fts_weight,
|
|
top_k=top_k,
|
|
)
|
|
# Return (passage, score, metadata) with ranks
|
|
return results_with_metadata
|
|
else:
|
|
# for single queries (vector, fts, timestamp) - add basic metadata
|
|
is_fts = search_mode == "fts"
|
|
results = self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
|
|
# Add simple metadata for single search modes
|
|
results_with_metadata = []
|
|
for idx, (passage, score) in enumerate(results):
|
|
metadata = {
|
|
"combined_score": score,
|
|
f"{search_mode}_rank": idx + 1, # Add the rank for this search mode
|
|
}
|
|
results_with_metadata.append((passage, score, metadata))
|
|
return results_with_metadata
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to query passages from Turbopuffer: {e}")
|
|
raise
|
|
|
|
@trace_method
async def query_messages_by_agent_id(
    self,
    agent_id: str,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid", "timestamp"
    top_k: int = 10,
    roles: Optional[List[MessageRole]] = None,
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[Tuple[dict, float, dict]]:
    """Search a single agent's messages stored in Turbopuffer.

    The query always carries an ``agent_id`` equality filter; every other
    filter is added only when its argument is supplied.

    Args:
        agent_id: Agent whose messages are searched (always filtered on).
        organization_id: Organization used to resolve the message namespace.
        actor: User actor passed through to embedding generation.
        query_text: Search text. It is embedded for "vector"/"hybrid" modes and
            forwarded to the query executor as-is.
        search_mode: "vector", "fts", "hybrid" or "timestamp". When
            ``query_text`` is None the mode is switched to "timestamp".
        top_k: Number of results to return.
        roles: Optional message roles to restrict the search to.
        project_id: Optional project ID filter.
        template_id: Optional template ID filter.
        conversation_id: None applies no filter, "default" matches messages
            whose conversation_id is None, any other value matches that
            conversation exactly.
        vector_weight: Weight of the vector ranking in hybrid mode.
        fts_weight: Weight of the full-text ranking in hybrid mode.
        start_date: Optional lower bound on ``created_at`` (inclusive, "Gte").
        end_date: Optional upper bound on ``created_at`` (inclusive, "Lte"). A
            value at exactly midnight is stretched to the end of that day.

    Returns:
        List of (message_dict, score, metadata) tuples. In hybrid mode these
        come from reciprocal rank fusion; otherwise the score is 1 / rank and
        the metadata records the search mode and the rank.

    Raises:
        Exception: Any error raised while executing or post-processing the
            query is logged and re-raised unchanged.
    """
    # Only vector-backed modes need an embedding, and only when there is text to embed.
    query_embedding = None
    if query_text and search_mode in ("vector", "hybrid"):
        generated = await self._generate_embeddings([query_text], actor)
        query_embedding = generated[0]

    # Without any query text there is nothing to rank against: list most recent messages instead.
    if search_mode != "timestamp" and query_text is None and query_embedding is None:
        search_mode = "timestamp"

    namespace_name = await self._get_message_namespace_name(organization_id)

    # The agent filter is unconditional; the remaining conditions are appended in a fixed order.
    conditions = [("agent_id", "Eq", agent_id)]

    if roles:
        role_values = [role.value for role in roles]
        if len(role_values) == 1:
            conditions.append(("role", "Eq", role_values[0]))
        else:
            conditions.append(("role", "In", role_values))

    if project_id:
        conditions.append(("project_id", "Eq", project_id))

    if template_id:
        conditions.append(("template_id", "Eq", template_id))

    # "default" is a reserved value meaning "messages that belong to no conversation".
    if conversation_id == "default":
        conditions.append(("conversation_id", "Eq", None))
    elif conversation_id is not None:
        conditions.append(("conversation_id", "Eq", conversation_id))

    if start_date:
        # Aware datetimes are normalised to UTC; naive ones are passed through untouched.
        if start_date.tzinfo is not None:
            start_date = start_date.astimezone(timezone.utc)
        conditions.append(("created_at", "Gte", start_date))

    if end_date:
        at_midnight = (end_date.hour, end_date.minute, end_date.second, end_date.microsecond) == (0, 0, 0, 0)
        if at_midnight:
            from datetime import timedelta

            # A bare date means "the whole day": move to 23:59:59.999999 of that day.
            end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
        if end_date.tzinfo is not None:
            end_date = end_date.astimezone(timezone.utc)
        conditions.append(("created_at", "Lte", end_date))

    # A lone condition is sent bare; several are wrapped in an "And" expression.
    final_filter = conditions[0] if len(conditions) == 1 else ("And", conditions)

    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=True,
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode == "hybrid":
            # Hybrid responses hold two sub-results: vector first, full-text second.
            from_vector = self._process_message_query_results(result.results[0])
            from_fts = self._process_message_query_results(result.results[1])
            return self._reciprocal_rank_fusion(
                vector_results=from_vector,
                fts_results=from_fts,
                get_id_func=lambda message: message["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

        # Single-mode queries (vector, fts, timestamp): derive a score from the rank alone.
        ranked = []
        for rank, message in enumerate(self._process_message_query_results(result), start=1):
            rank_score = 1.0 / rank
            rank_metadata = {
                "combined_score": rank_score,
                "search_mode": search_mode,
                f"{search_mode}_rank": rank,
            }
            ranked.append((message, rank_score, rank_metadata))
        return ranked

    except Exception as e:
        logger.error(f"Failed to query messages from Turbopuffer: {e}")
        raise
|
|
|
|
async def query_messages_by_org_id(
    self,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "hybrid",  # "vector", "fts", "hybrid"
    top_k: int = 10,
    roles: Optional[List[MessageRole]] = None,
    agent_id: Optional[str] = None,
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[Tuple[dict, float, dict]]:
    """Query messages from Turbopuffer across an entire organization.

    Args:
        organization_id: Organization ID for namespace lookup (required).
        actor: User actor for embedding generation.
        query_text: Text query. It is embedded for "vector"/"hybrid" modes and
            forwarded to the query executor as-is.
        search_mode: "vector", "fts" or "hybrid" (default "hybrid"). When
            ``query_text`` is None the mode is switched to "timestamp".
        top_k: Number of results to return.
        roles: Optional list of message roles to filter by.
        agent_id: Optional agent ID to filter messages by.
        project_id: Optional project ID to filter messages by.
        template_id: Optional template ID to filter messages by.
        conversation_id: Optional conversation filter. Special values:
            - None (omitted): no filter, all messages are eligible
            - "default": only messages whose conversation_id is None
            - any other value: only messages in that conversation
        vector_weight: Weight for vector results in hybrid mode (default 0.5).
        fts_weight: Weight for FTS results in hybrid mode (default 0.5).
        start_date: Optional lower bound on ``created_at`` (inclusive, "Gte").
        end_date: Optional upper bound on ``created_at`` (inclusive, "Lte"). A
            value at exactly midnight is stretched to the end of that day.

    Returns:
        List of (message_dict, score, metadata) tuples where:
            - score is the RRF score in hybrid mode and 1 / rank otherwise
            - metadata holds the ranks and, when the result rows expose them,
              the raw "vector_score" / "fts_score" values

    Raises:
        Exception: Any error raised while executing or post-processing the
            query is logged and re-raised unchanged.
    """

    def _raw_score(row, attribute_names):
        """Return the first non-None attribute of ``row`` among ``attribute_names``, else None."""
        for attribute_name in attribute_names:
            value = getattr(row, attribute_name, None)
            if value is not None:
                return value
        return None

    # `_process_single_query_results` in this file reads "$dist" / "$score" from result rows,
    # so probe those spellings first and keep the un-prefixed ones for backward compatibility.
    # NOTE(review): confirm against the Turbopuffer response schema which attribute carries
    # the BM25 score for full-text queries.
    distance_attrs = ("$dist", "dist")
    fts_score_attrs = ("$score", "score")

    # generate embedding for vector/hybrid search if query_text is provided
    query_embedding = None
    if query_text and search_mode in ["vector", "hybrid"]:
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0]

    # fall back to retrieving the most recent messages when no search query is provided
    if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
        search_mode = "timestamp"

    # namespace is org-scoped
    namespace_name = await self._get_message_namespace_name(organization_id)

    # build filters
    all_filters = []

    # role filter
    if roles:
        role_values = [r.value for r in roles]
        if len(role_values) == 1:
            all_filters.append(("role", "Eq", role_values[0]))
        else:
            all_filters.append(("role", "In", role_values))

    # agent filter
    if agent_id:
        all_filters.append(("agent_id", "Eq", agent_id))

    # project filter
    if project_id:
        all_filters.append(("project_id", "Eq", project_id))

    # template filter
    if template_id:
        all_filters.append(("template_id", "Eq", template_id))

    # conversation filter: "default" is reserved for messages with no conversation
    if conversation_id == "default":
        all_filters.append(("conversation_id", "Eq", None))
    elif conversation_id is not None:
        all_filters.append(("conversation_id", "Eq", conversation_id))

    # date filters
    if start_date:
        # aware datetimes are converted to UTC; naive ones are passed through untouched
        if start_date.tzinfo is not None:
            start_date = start_date.astimezone(timezone.utc)
        all_filters.append(("created_at", "Gte", start_date))
    if end_date:
        # a bare date (midnight) is stretched to 23:59:59.999999 so the whole day is included
        if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
            from datetime import timedelta

            end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
        if end_date.tzinfo is not None:
            end_date = end_date.astimezone(timezone.utc)
        all_filters.append(("created_at", "Lte", end_date))

    # combine filters: none -> None, one -> bare condition, several -> "And" expression
    final_filter = None
    if len(all_filters) == 1:
        final_filter = all_filters[0]
    elif len(all_filters) > 1:
        final_filter = ("And", all_filters)

    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=True,
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode == "hybrid":
            # hybrid responses hold two sub-results: vector first, full-text second
            vector_rows = result.results[0].rows
            fts_rows = result.results[1].rows
            vector_results = self._process_message_query_results(result.results[0])
            fts_results = self._process_message_query_results(result.results[1])

            # rank fusion already returns metadata with the per-list ranks
            results_with_metadata = self._reciprocal_rank_fusion(
                vector_results=vector_results,
                fts_results=fts_results,
                get_id_func=lambda msg_dict: msg_dict["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

            # collect raw scores per message id, when the rows expose them
            vector_scores = {}
            for row in vector_rows:
                distance = _raw_score(row, distance_attrs)
                if distance is not None:
                    vector_scores[row.id] = distance

            fts_scores = {}
            for row in fts_rows:
                fts_score = _raw_score(row, fts_score_attrs)
                if fts_score is not None:
                    fts_scores[row.id] = fts_score

            # enhance metadata with raw scores
            enhanced_results = []
            for msg_dict, rrf_score, metadata in results_with_metadata:
                msg_id = msg_dict["id"]
                if msg_id in vector_scores:
                    metadata["vector_score"] = vector_scores[msg_id]
                if msg_id in fts_scores:
                    metadata["fts_score"] = fts_scores[msg_id]
                enhanced_results.append((msg_dict, rrf_score, metadata))

            return enhanced_results

        # single queries (vector, fts, or the timestamp fallback): score derives from the rank
        results = self._process_message_query_results(result)
        results_with_metadata = []
        for idx, msg_dict in enumerate(results):
            metadata = {
                "combined_score": 1.0 / (idx + 1),
                "search_mode": search_mode,
                f"{search_mode}_rank": idx + 1,
            }

            # add the raw score if available; an FTS-only query has no vector distance,
            # so never label a value as "vector_score" in that mode
            row = result.rows[idx]
            distance = None if search_mode == "fts" else _raw_score(row, distance_attrs)
            if distance is not None:
                metadata["vector_score"] = distance
            else:
                fts_score = _raw_score(row, fts_score_attrs)
                if fts_score is not None:
                    metadata["fts_score"] = fts_score

            results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))

        return results_with_metadata

    except Exception as e:
        logger.error(f"Failed to query messages from Turbopuffer: {e}")
        raise
|
|
|
|
def _process_message_query_results(self, result) -> List[dict]:
|
|
"""Process results from a message query into message dicts.
|
|
|
|
For RRF, we only need the rank order - scores are not used.
|
|
"""
|
|
messages = []
|
|
|
|
for row in result.rows:
|
|
# Build message dict with key fields
|
|
message_dict = {
|
|
"id": row.id,
|
|
"text": getattr(row, "text", ""),
|
|
"organization_id": getattr(row, "organization_id", None),
|
|
"agent_id": getattr(row, "agent_id", None),
|
|
"role": getattr(row, "role", None),
|
|
"created_at": getattr(row, "created_at", None),
|
|
"conversation_id": getattr(row, "conversation_id", None),
|
|
}
|
|
messages.append(message_dict)
|
|
|
|
return messages
|
|
|
|
def _process_single_query_results(
    self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
) -> List[Tuple[PydanticPassage, float]]:
    """Turn the rows of one query result into (passage, score) pairs.

    Args:
        result: Query result exposing a ``rows`` iterable.
        archive_id: Archive ID stamped on every returned passage (taken from
            the query, not from the row).
        tags: Tags used by the caller's query; not read by this method.
        is_fts: True when the rows come from a full-text query.

    Returns:
        List of (passage, score) tuples in row order. For full-text rows the
        score is the row's "$score" attribute (0.0 when absent); otherwise it
        is ``1.0 - "$dist"`` with a missing distance treated as 0.0.
    """
    scored_passages = []

    for row in result.rows:
        # Embeddings are not part of the result rows, so the passage carries an empty
        # embedding together with the client's default embedding config.
        passage = PydanticPassage(
            id=row.id,
            text=getattr(row, "text", ""),
            organization_id=getattr(row, "organization_id", None),
            archive_id=archive_id,
            created_at=getattr(row, "created_at", None),
            metadata_={},
            tags=getattr(row, "tags", []) or [],  # a missing or None tags attribute becomes []
            embedding=[],
            embedding_config=self.default_embedding_config,
        )

        if is_fts:
            # full-text score is used as-is (higher is better)
            relevance = getattr(row, "$score", 0.0)
        else:
            # distance is flipped into a similarity
            relevance = 1.0 - getattr(row, "$dist", 0.0)

        scored_passages.append((passage, relevance))

    return scored_passages
|
|
|
|
def _reciprocal_rank_fusion(
|
|
self,
|
|
vector_results: List[Any],
|
|
fts_results: List[Any],
|
|
get_id_func: Callable[[Any], str],
|
|
vector_weight: float,
|
|
fts_weight: float,
|
|
top_k: int,
|
|
) -> List[Tuple[Any, float, dict]]:
|
|
"""RRF implementation that works with any object type.
|
|
|
|
RRF score = vector_weight * (1/(k + rank)) + fts_weight * (1/(k + rank))
|
|
where k is a constant (typically 60) to avoid division by zero
|
|
|
|
This is a pure rank-based fusion following the standard RRF algorithm.
|
|
|
|
Args:
|
|
vector_results: List of items from vector search (ordered by relevance)
|
|
fts_results: List of items from FTS (ordered by relevance)
|
|
get_id_func: Function to extract ID from an item
|
|
vector_weight: Weight for vector search results
|
|
fts_weight: Weight for FTS results
|
|
top_k: Number of results to return
|
|
|
|
Returns:
|
|
List of (item, score, metadata) tuples sorted by RRF score
|
|
metadata contains ranks from each result list
|
|
"""
|
|
k = 60 # standard RRF constant from Cormack et al. (2009)
|
|
|
|
# create rank mappings based on position in result lists
|
|
# rank starts at 1, not 0
|
|
vector_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(vector_results)}
|
|
fts_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(fts_results)}
|
|
|
|
# combine all unique items from both result sets
|
|
all_items = {}
|
|
for item in vector_results:
|
|
all_items[get_id_func(item)] = item
|
|
for item in fts_results:
|
|
all_items[get_id_func(item)] = item
|
|
|
|
# calculate RRF scores based purely on ranks
|
|
rrf_scores = {}
|
|
score_metadata = {}
|
|
for item_id in all_items:
|
|
# RRF formula: sum of 1/(k + rank) across result lists
|
|
# If item not in a list, we don't add anything (equivalent to rank = infinity)
|
|
vector_rrf_score = 0.0
|
|
fts_rrf_score = 0.0
|
|
|
|
if item_id in vector_ranks:
|
|
vector_rrf_score = vector_weight / (k + vector_ranks[item_id])
|
|
if item_id in fts_ranks:
|
|
fts_rrf_score = fts_weight / (k + fts_ranks[item_id])
|
|
|
|
combined_score = vector_rrf_score + fts_rrf_score
|
|
|
|
rrf_scores[item_id] = combined_score
|
|
score_metadata[item_id] = {
|
|
"combined_score": combined_score, # Final RRF score
|
|
"vector_rank": vector_ranks.get(item_id),
|
|
"fts_rank": fts_ranks.get(item_id),
|
|
}
|
|
|
|
# sort by RRF score and return with metadata
|
|
sorted_results = sorted(
|
|
[(all_items[iid], score, score_metadata[iid]) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True
|
|
)
|
|
|
|
return sorted_results[:top_k]
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
    """Delete a single passage from an archive's Turbopuffer namespace.

    Args:
        archive_id: Archive whose namespace holds the passage.
        passage_id: ID of the passage to delete.

    Returns:
        True once the delete has been issued without error.

    Raises:
        Exception: Any error from the write is logged and re-raised unchanged.
    """
    namespace_name = await self._get_archive_namespace_name(archive_id)

    try:
        # Run in thread pool for consistency (deletes are lightweight but use same wrapper)
        await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            deletes=[passage_id],
        )
        logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to delete passage from Turbopuffer: {e}")
        raise
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
    """Delete several passages from an archive's Turbopuffer namespace.

    Args:
        archive_id: Archive whose namespace holds the passages.
        passage_ids: IDs of the passages to delete. An empty list is a no-op.

    Returns:
        True when there was nothing to delete or the delete was issued without
        error.

    Raises:
        Exception: Any error from the write is logged and re-raised unchanged.
    """
    # Nothing to delete: skip the namespace lookup and the write entirely.
    if not passage_ids:
        return True

    namespace_name = await self._get_archive_namespace_name(archive_id)

    try:
        # Run in thread pool for consistency
        await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            deletes=passage_ids,
        )
        logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to delete passages from Turbopuffer: {e}")
        raise
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def delete_all_passages(self, archive_id: str) -> bool:
    """Remove every passage stored for an archive by clearing its namespace.

    Args:
        archive_id: Archive whose Turbopuffer namespace is emptied.

    Returns:
        True once the namespace-wide delete has completed without error.

    Raises:
        Exception: Any error from the Turbopuffer client is logged and
            re-raised unchanged.
    """
    from turbopuffer import AsyncTurbopuffer

    namespace_name = await self._get_archive_namespace_name(archive_id)

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            # delete_all() is called on the namespace handle and wipes the whole namespace
            await client.namespace(namespace_name).delete_all()
            logger.info(f"Successfully deleted all passages for archive {archive_id}")
    except Exception as e:
        logger.error(f"Failed to delete all passages from Turbopuffer: {e}")
        raise

    return True
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool:
    """Delete several messages from the organization's Turbopuffer message namespace.

    Args:
        agent_id: Agent the messages belong to; only used in the log message.
        organization_id: Organization used to resolve the message namespace.
        message_ids: IDs of the messages to delete. An empty list is a no-op.

    Returns:
        True when there was nothing to delete or the delete was issued without
        error.

    Raises:
        Exception: Any error from the write is logged and re-raised unchanged.
    """
    # Nothing to delete: skip the namespace lookup and the write entirely.
    if not message_ids:
        return True

    namespace_name = await self._get_message_namespace_name(organization_id)

    try:
        # Run in thread pool for consistency
        await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            deletes=message_ids,
        )
        logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer for agent {agent_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to delete messages from Turbopuffer: {e}")
        raise
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool:
    """Delete every message of one agent from the organization's message namespace.

    The namespace is shared by the whole organization, so the delete is issued
    as a filter on ``agent_id`` rather than by clearing the namespace.

    Args:
        agent_id: Agent whose messages are removed.
        organization_id: Organization used to resolve the message namespace.

    Returns:
        True once the filtered delete has been issued without error.

    Raises:
        Exception: Any error from the write is logged and re-raised unchanged.
    """
    namespace_name = await self._get_message_namespace_name(organization_id)

    try:
        # Run in thread pool for consistency
        result = await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            delete_by_filter=("agent_id", "Eq", agent_id),
        )
        # The write helper may return a falsy result; report zero rows in that case.
        rows_affected = result.rows_affected if result else 0
        logger.info(f"Successfully deleted all messages for agent {agent_id} (deleted {rows_affected} rows)")
        return True
    except Exception as e:
        logger.error(f"Failed to delete all messages from Turbopuffer: {e}")
        raise
|
|
|
|
# file/source passage methods
|
|
|
|
@trace_method
async def _get_file_passages_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped Turbopuffer namespace name for file passages.

    Args:
        organization_id: Organization the namespace belongs to.

    Returns:
        ``file_passages_<organization_id>``, suffixed with the lower-cased
        ``settings.environment`` when that setting is non-empty.
    """
    environment = settings.environment
    base_name = f"file_passages_{organization_id}"
    # The environment suffix is appended only when an environment is configured.
    return f"{base_name}_{environment.lower()}" if environment else base_name
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def insert_file_passages(
    self,
    source_id: str,
    file_id: str,
    text_chunks: List[str],
    organization_id: str,
    actor: "PydanticUser",
    created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
    """Insert file passages into Turbopuffer using the org-scoped namespace.

    Args:
        source_id: ID of the source containing the file.
        file_id: ID of the file.
        text_chunks: Text chunks to store; whitespace-only chunks are dropped.
        organization_id: Organization ID used for the namespace and stored on
            every row.
        actor: User actor for embedding generation.
        created_at: Optional timestamp for retroactive entries. Defaults to the
            current UTC time; naive values are treated as UTC and aware values
            are converted to UTC.

    Returns:
        The PydanticPassage objects that were inserted, or an empty list when
        there was no non-empty chunk to insert.

    Raises:
        Exception: Any error from the write is logged and re-raised unchanged.
    """
    if not text_chunks:
        return []

    # filter out empty text chunks
    filtered_chunks = [text for text in text_chunks if text.strip()]

    if not filtered_chunks:
        logger.warning("All text chunks were empty, skipping file passage insertion")
        return []

    # generate embeddings using the default config
    embeddings = await self._generate_embeddings(filtered_chunks, actor)

    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    # handle timestamp - ensure UTC
    if created_at is None:
        timestamp = datetime.now(timezone.utc)
    elif created_at.tzinfo is None:
        # assume UTC if no timezone provided
        timestamp = created_at.replace(tzinfo=timezone.utc)
    else:
        # convert to UTC if in a different timezone
        timestamp = created_at.astimezone(timezone.utc)

    # prepare column-based data for turbopuffer - optimized for batch insert
    ids = []
    vectors = []
    texts = []
    organization_ids = []
    source_ids = []
    file_ids = []
    created_ats = []
    passages = []

    for text, embedding in zip(filtered_chunks, embeddings):
        # NOTE(review): the returned passage carries actor.organization_id while the stored
        # column below uses the organization_id argument - confirm these always agree.
        passage = PydanticPassage(
            text=text,
            file_id=file_id,
            source_id=source_id,
            embedding=embedding,
            embedding_config=self.default_embedding_config,
            organization_id=actor.organization_id,
        )
        passages.append(passage)

        # append to columns; the row ID is the one generated on the passage object
        ids.append(passage.id)
        vectors.append(embedding)
        texts.append(text)
        organization_ids.append(organization_id)
        source_ids.append(source_id)
        file_ids.append(file_id)
        created_ats.append(timestamp)

    # build column-based upsert data
    upsert_columns = {
        "id": ids,
        "vector": vectors,
        "text": texts,
        "organization_id": organization_ids,
        "source_id": source_ids,
        "file_id": file_ids,
        "created_at": created_ats,
    }

    try:
        # Use global semaphore to limit concurrent Turbopuffer writes
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
            await asyncio.to_thread(
                _run_turbopuffer_write_in_thread,
                api_key=self.api_key,
                region=self.region,
                namespace_name=namespace_name,
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
        logger.info(f"Successfully inserted {len(ids)} file passages to Turbopuffer for source {source_id}, file {file_id}")
        return passages

    except Exception as e:
        logger.error(f"Failed to insert file passages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate passage IDs detected in batch")
        raise
|
|
|
|
@trace_method
async def query_file_passages(
    self,
    source_ids: List[str],
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid", "timestamp"
    top_k: int = 10,
    file_id: Optional[str] = None,  # optional filter by specific file
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> List[Tuple[PydanticPassage, float, dict]]:
    """Query file passages from Turbopuffer using org-scoped namespace.

    Results are always restricted to ``source_ids`` and, when ``file_id`` is
    given, additionally to that file.

    Args:
        source_ids: List of source IDs to query
        organization_id: Organization ID for namespace lookup
        actor: User actor for embedding generation
        query_text: Text query for search. When it is None (and no embedding
            was generated) the search mode is switched to "timestamp".
        search_mode: Search mode - "vector", "fts", "hybrid" or "timestamp"
            (default: "vector")
        top_k: Number of results to return
        file_id: Optional file ID to filter results to a specific file
        vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5)

    Returns:
        List of (passage, score, metadata) tuples with relevance rankings.
        In hybrid mode this is whatever ``self._reciprocal_rank_fusion``
        returns; in the other modes the metadata dict holds
        ``combined_score`` and ``<search_mode>_rank`` (1-based).

    Raises:
        Exception: Any error raised while executing or processing the query
            is logged and re-raised unchanged.
    """
    # generate embedding for vector/hybrid search if query_text is provided
    query_embedding = None
    if query_text and search_mode in ["vector", "hybrid"]:
        embeddings = await self._generate_embeddings([query_text], actor)
        # guard against an empty embedding response (same handling as query_tools)
        # instead of failing with an opaque IndexError
        query_embedding = embeddings[0] if embeddings else None

    # check if we should fallback to timestamp-based retrieval
    if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
        # fallback to retrieving most recent passages when no search query is provided
        search_mode = "timestamp"

    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    # build filters - always filter by source_ids
    if len(source_ids) == 1:
        # single source_id, use Eq for efficiency
        filters = [("source_id", "Eq", source_ids[0])]
    else:
        # multiple source_ids, use In operator
        filters = [("source_id", "In", source_ids)]

    # add file filter if specified
    if file_id:
        filters.append(("file_id", "Eq", file_id))

    # combine filters
    final_filter = filters[0] if len(filters) == 1 else ("And", filters)

    try:
        # use generic query executor
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "organization_id", "source_id", "file_id", "created_at"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        # process results based on search mode
        if search_mode == "hybrid":
            # for hybrid mode, we get a multi-query response:
            # results[0] is treated as the vector query, results[1] as the FTS query
            vector_results = self._process_file_query_results(result.results[0])
            fts_results = self._process_file_query_results(result.results[1], is_fts=True)
            # use RRF and include metadata with ranks
            results_with_metadata = self._reciprocal_rank_fusion(
                vector_results=[passage for passage, _ in vector_results],
                fts_results=[passage for passage, _ in fts_results],
                get_id_func=lambda p: p.id,
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )
            return results_with_metadata
        else:
            # for single queries (vector, fts, timestamp) - add basic metadata
            is_fts = search_mode == "fts"
            results = self._process_file_query_results(result, is_fts=is_fts)
            # add simple metadata for single search modes
            results_with_metadata = []
            for idx, (passage, score) in enumerate(results):
                metadata = {
                    "combined_score": score,
                    f"{search_mode}_rank": idx + 1,  # add the rank for this search mode
                }
                results_with_metadata.append((passage, score, metadata))
            return results_with_metadata

    except Exception as e:
        logger.error(f"Failed to query file passages from Turbopuffer: {e}")
        raise
|
|
|
|
def _process_file_query_results(self, result, is_fts: bool = False) -> List[Tuple[PydanticPassage, float]]:
    """Turn the rows of a file-passage query into (passage, score) pairs.

    Args:
        result: Query result exposing a ``rows`` iterable.
        is_fts: When True the row's ``$score`` attribute is used as the score;
            otherwise the score is ``1.0 - $dist``.

    Returns:
        One (PydanticPassage, score) tuple per row, in row order.
    """
    scored_passages = []

    for hit in result.rows:
        # Turbopuffer rows carry no embedding here, so the passage is built with an
        # empty embedding plus the default embedding config; attributes missing
        # from the row fall back to empty/None values.
        file_passage = PydanticPassage(
            id=hit.id,
            text=getattr(hit, "text", ""),
            organization_id=getattr(hit, "organization_id", None),
            source_id=getattr(hit, "source_id", None),
            file_id=getattr(hit, "file_id", None),
            created_at=getattr(hit, "created_at", None),
            metadata_={},
            tags=[],
            embedding=[],
            embedding_config=self.default_embedding_config,
        )

        # FTS: use the BM25 score as-is (higher is better).
        # Otherwise: convert the distance into a similarity score.
        relevance = getattr(hit, "$score", 0.0) if is_fts else 1.0 - getattr(hit, "$dist", 0.0)

        scored_passages.append((file_passage, relevance))

    return scored_passages
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
    """Delete all passages for a specific file from Turbopuffer.

    Args:
        source_id: ID of the source the file belongs to
        file_id: ID of the file whose passages should be removed
        organization_id: Organization ID for namespace lookup

    Returns:
        True if the delete request completed without raising

    Raises:
        Exception: Any error from the write is logged and re-raised
            (the retry decorator wraps this method - retry semantics are
            defined by ``async_retry_with_backoff`` elsewhere in this file).
    """
    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    try:
        # use delete_by_filter to only delete passages for this file
        # need to filter by both source_id and file_id
        filter_expr = ("And", [("source_id", "Eq", source_id), ("file_id", "Eq", file_id)])

        # Run in thread pool for consistency
        result = await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            delete_by_filter=filter_expr,
        )
        logger.info(
            f"Successfully deleted passages for file {file_id} from source {source_id} (deleted {result.rows_affected if result else 0} rows)"
        )
        return True
    except Exception as e:
        logger.error(f"Failed to delete file passages from Turbopuffer: {e}")
        raise
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
    """Delete all passages for a source from Turbopuffer.

    Args:
        source_id: ID of the source whose passages should be removed
        organization_id: Organization ID for namespace lookup

    Returns:
        True if the delete request completed without raising

    Raises:
        Exception: Any error from the write is logged and re-raised
            (the retry decorator wraps this method - retry semantics are
            defined by ``async_retry_with_backoff`` elsewhere in this file).
    """
    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    try:
        # Run in thread pool for consistency
        result = await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            # every row of this source, regardless of file
            delete_by_filter=("source_id", "Eq", source_id),
        )
        logger.info(f"Successfully deleted all passages for source {source_id} (deleted {result.rows_affected if result else 0} rows)")
        return True
    except Exception as e:
        logger.error(f"Failed to delete source passages from Turbopuffer: {e}")
        raise
|
|
|
|
# tool methods
|
|
|
|
@trace_method
@async_retry_with_backoff()
async def delete_tools(self, organization_id: str, tool_ids: List[str]) -> bool:
    """Delete tools from Turbopuffer.

    Args:
        organization_id: Organization ID for namespace lookup
        tool_ids: List of tool IDs to delete

    Returns:
        True if successful (including the no-op case of an empty ``tool_ids``)

    Raises:
        Exception: Any error from the write is logged and re-raised
            (the retry decorator wraps this method - retry semantics are
            defined by ``async_retry_with_backoff`` elsewhere in this file).
    """
    # nothing to delete - skip the namespace lookup and the write entirely
    if not tool_ids:
        return True

    namespace_name = await self._get_tool_namespace_name(organization_id)

    try:
        # Run in thread pool for consistency
        await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            deletes=tool_ids,
        )
        logger.info(f"Successfully deleted {len(tool_ids)} tools from Turbopuffer")
        return True
    except Exception as e:
        logger.error(f"Failed to delete tools from Turbopuffer: {e}")
        raise
|
|
|
|
@trace_method
async def query_tools(
    self,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "hybrid",  # "vector", "fts", "hybrid", "timestamp"
    top_k: int = 50,
    tool_types: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> List[Tuple[dict, float, dict]]:
    """Search the organization's tool namespace in Turbopuffer.

    Args:
        organization_id: Organization ID used to resolve the tool namespace
        actor: User actor passed to embedding generation
        query_text: Text to search for; when None the mode becomes "timestamp"
        search_mode: One of "vector", "fts", "hybrid" or "timestamp"
        top_k: Number of results to return
        tool_types: Optional tool types to restrict the search to
        tags: Optional tags to restrict the search to (match any)
        vector_weight: Weight of the vector ranking in hybrid mode
        fts_weight: Weight of the FTS ranking in hybrid mode

    Returns:
        List of (tool_dict, score, metadata) tuples. Outside hybrid mode the
        score is the reciprocal of the 1-based rank.
    """
    # Only vector-backed modes need an embedding, and only when there is text to embed
    query_embedding = None
    if query_text and search_mode in ["vector", "hybrid"]:
        generated = await self._generate_embeddings([query_text], actor)
        query_embedding = generated[0] if generated else None

    # Without any query, switch to timestamp-based retrieval
    if query_embedding is None and query_text is None and search_mode != "timestamp":
        search_mode = "timestamp"

    namespace_name = await self._get_tool_namespace_name(organization_id)

    # Collect the optional filter clauses
    clauses = []
    if tool_types:
        type_clause = ("tool_type", "Eq", tool_types[0]) if len(tool_types) == 1 else ("tool_type", "In", tool_types)
        clauses.append(type_clause)
    if tags:
        clauses.append(("tags", "ContainsAny", tags))

    # No clause -> no filter, one clause -> used as-is, several -> AND-ed together
    if not clauses:
        final_filter = None
    elif len(clauses) == 1:
        final_filter = clauses[0]
    else:
        final_filter = ("And", clauses)

    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "name", "organization_id", "tool_type", "tags", "created_at"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode == "hybrid":
            # Multi-query response: first entry is handled as the vector results,
            # second as the FTS results; both are fused by rank
            from_vector = self._process_tool_query_results(result.results[0])
            from_fts = self._process_tool_query_results(result.results[1])
            return self._reciprocal_rank_fusion(
                vector_results=from_vector,
                fts_results=from_fts,
                get_id_func=lambda d: d["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

        # Single-query modes: derive the score from the position in the result list
        ranked = []
        for rank, tool_dict in enumerate(self._process_tool_query_results(result), start=1):
            reciprocal = 1.0 / rank
            metadata = {
                "combined_score": reciprocal,
                "search_mode": search_mode,
                f"{search_mode}_rank": rank,
            }
            ranked.append((tool_dict, metadata["combined_score"], metadata))
        return ranked

    except Exception as e:
        logger.error(f"Failed to query tools from Turbopuffer: {e}")
        raise
|
|
|
|
def _process_tool_query_results(self, result) -> List[dict]:
|
|
"""Process results from a tool query into tool dicts."""
|
|
tools = []
|
|
for row in result.rows:
|
|
tool_dict = {
|
|
"id": row.id,
|
|
"text": getattr(row, "text", ""),
|
|
"name": getattr(row, "name", ""),
|
|
"organization_id": getattr(row, "organization_id", None),
|
|
"tool_type": getattr(row, "tool_type", None),
|
|
"tags": getattr(row, "tags", []),
|
|
"created_at": getattr(row, "created_at", None),
|
|
}
|
|
tools.append(tool_dict)
|
|
return tools
|