* fix: wrap turbopuffer vector writes in thread pool Turbopuffer library does CPU-intensive base64 encoding of vectors synchronously in async functions (_async_transform_recursive → b64encode_vector), blocking the event loop during file uploads. Solution: Created _run_turbopuffer_write_in_thread() helper that runs turbopuffer writes in an isolated event loop within a worker thread. Applied to all vector write operations: - insert_tools() - insert_archival_memories() - insert_messages() - insert_file_passages() This prevents pybase64.b64encode_as_string() from blocking the main event loop during vector encoding. 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * fix: wrap all turbopuffer operations in thread pool Extended the thread pool wrapping to ALL turbopuffer write operations, including delete operations, for complete isolation from the main event loop. All turbopuffer namespace.write() calls now run in isolated event loops within worker threads, preventing any potential CPU work from blocking. 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> --------- Co-authored-by: Letta <noreply@letta.com>
1944 lines
79 KiB
Python
1944 lines
79 KiB
Python
"""Turbopuffer utilities for archival memory storage."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Callable, List, Optional, Tuple
|
|
|
|
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.enums import MessageRole, TagMatchMode
|
|
from letta.schemas.passage import Passage as PydanticPassage
|
|
from letta.settings import model_settings, settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Global semaphore for Turbopuffer operations to prevent overwhelming the service
|
|
# This is separate from embedding semaphore since Turbopuffer can handle more concurrency
|
|
_GLOBAL_TURBOPUFFER_SEMAPHORE = asyncio.Semaphore(5)
|
|
|
|
|
|
def _run_turbopuffer_write_in_thread(
    api_key: str,
    region: str,
    namespace_name: str,
    upsert_columns: Optional[dict] = None,
    deletes: Optional[list] = None,
    delete_by_filter: Optional[tuple] = None,
    distance_metric: str = "cosine_distance",
    schema: Optional[dict] = None,
):
    """
    Sync wrapper to run a turbopuffer write in an isolated event loop.

    Turbopuffer's async write() does CPU-intensive base64 encoding of vectors
    synchronously in async functions, blocking the event loop. Running it in
    a thread pool with an isolated event loop prevents blocking.

    This function is synchronous and must be called from a thread that has no
    running event loop (e.g. via ``asyncio.to_thread``).

    Args:
        api_key: Turbopuffer API key used to build the client
        region: Turbopuffer region used to build the client
        namespace_name: Namespace to write to
        upsert_columns: Optional column-oriented rows to upsert (omitted from the write when empty)
        deletes: Optional list of IDs to delete (omitted from the write when empty)
        delete_by_filter: Optional filter expression selecting rows to delete (omitted when empty)
        distance_metric: Distance metric passed to the write call
        schema: Optional schema passed to the write call (omitted when empty)

    Returns:
        Whatever ``namespace.write()`` returns
    """
    from turbopuffer import AsyncTurbopuffer

    async def do_write():
        async with AsyncTurbopuffer(api_key=api_key, region=region) as client:
            namespace = client.namespace(namespace_name)

            # Build write kwargs; only forward the optional pieces that were actually supplied
            kwargs = {"distance_metric": distance_metric}
            if upsert_columns:
                kwargs["upsert_columns"] = upsert_columns
            if deletes:
                kwargs["deletes"] = deletes
            if delete_by_filter:
                kwargs["delete_by_filter"] = delete_by_filter
            if schema:
                kwargs["schema"] = schema

            return await namespace.write(**kwargs)

    # asyncio.run() creates a fresh event loop for this worker thread, runs the
    # coroutine, then finalizes async generators, closes the loop and unsets it
    # as the thread's current loop. The previous manual new_event_loop() /
    # set_event_loop() / close() sequence left the pooled thread pointing at a
    # closed loop and skipped that cleanup.
    return asyncio.run(do_write())
|
|
|
|
|
|
def should_use_tpuf() -> bool:
    """Return True only when Turbopuffer is enabled and both required API keys are set."""
    if not settings.use_tpuf:
        return False
    if not settings.tpuf_api_key:
        return False
    # We need OpenAI since we default to their embedding model
    return bool(model_settings.openai_api_key)
|
|
|
|
|
|
def should_use_tpuf_for_messages() -> bool:
    """Check if Turbopuffer should be used for messages.

    Requires the general Turbopuffer prerequisites (see ``should_use_tpuf``)
    plus the ``embed_all_messages`` setting.
    """
    if not should_use_tpuf():
        return False
    return bool(settings.embed_all_messages)
|
|
|
|
|
|
def should_use_tpuf_for_tools() -> bool:
    """Check if Turbopuffer should be used for tools.

    Requires the general Turbopuffer prerequisites (see ``should_use_tpuf``)
    plus the ``embed_tools`` setting.
    """
    if not should_use_tpuf():
        return False
    return bool(settings.embed_tools)
|
|
|
|
|
|
class TurbopufferClient:
|
|
"""Client for managing archival memory with Turbopuffer vector database."""
|
|
|
|
default_embedding_config = EmbeddingConfig(
|
|
embedding_model="text-embedding-3-small",
|
|
embedding_endpoint_type="openai",
|
|
embedding_endpoint="https://api.openai.com/v1",
|
|
embedding_dim=1536,
|
|
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
|
)
|
|
|
|
def __init__(self, api_key: str = None, region: str = None):
    """Initialize Turbopuffer client.

    Args:
        api_key: Turbopuffer API key; falls back to ``settings.tpuf_api_key`` when not given
        region: Turbopuffer region; falls back to ``settings.tpuf_region`` when not given

    Raises:
        ValueError: If no API key is available from either source
    """
    # Explicit arguments win; otherwise use the configured settings
    self.api_key = api_key if api_key else settings.tpuf_api_key
    self.region = region if region else settings.tpuf_region

    # NOTE(review): imported locally rather than at module top - presumably to
    # avoid a circular import; confirm before moving.
    from letta.services.agent_manager import AgentManager
    from letta.services.archive_manager import ArchiveManager

    self.archive_manager = ArchiveManager()
    self.agent_manager = AgentManager()

    if self.api_key:
        return
    raise ValueError("Turbopuffer API key not provided")
|
|
|
|
@trace_method
async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]:
    """Embed texts with the class-level default embedding configuration.

    Blank (empty or whitespace-only) strings are dropped before the request,
    so the result can be shorter than ``texts``.

    Args:
        texts: List of texts to embed
        actor: User actor for embedding generation

    Returns:
        List of embedding vectors for the non-blank texts, or an empty list
        when there is nothing to embed
    """
    from letta.llm_api.llm_client import LLMClient

    # keep only texts that still have content after stripping
    non_blank = [candidate for candidate in texts if candidate.strip()]

    if non_blank:
        config = self.default_embedding_config
        client = LLMClient.create(
            provider_type=config.embedding_endpoint_type,
            actor=actor,
        )
        return await client.request_embeddings(non_blank, config)

    # nothing valid to embed - skip the request entirely
    return []
|
|
|
|
@trace_method
async def _get_archive_namespace_name(self, archive_id: str) -> str:
    """Resolve the Turbopuffer namespace for an archive.

    Delegates to the archive manager's ``get_or_set_vector_db_namespace_async``.

    Args:
        archive_id: ID of the archive

    Returns:
        The namespace name returned by the archive manager
    """
    namespace_name = await self.archive_manager.get_or_set_vector_db_namespace_async(archive_id)
    return namespace_name
|
|
|
|
@trace_method
async def _get_message_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped namespace name for messages.

    The name is ``messages_<organization_id>``, with the lower-cased
    ``settings.environment`` appended as a suffix when one is configured.

    Args:
        organization_id: Organization ID for namespace generation

    Returns:
        The org-scoped namespace name for messages
    """
    env = settings.environment
    if not env:
        return f"messages_{organization_id}"
    return f"messages_{organization_id}_{env.lower()}"
|
|
|
|
@trace_method
async def _get_tool_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped namespace name for tools.

    The name is ``tools_<organization_id>``, with the lower-cased
    ``settings.environment`` appended as a suffix when one is configured.

    Args:
        organization_id: Organization ID for namespace generation

    Returns:
        The org-scoped namespace name for tools
    """
    env = settings.environment
    if not env:
        return f"tools_{organization_id}"
    return f"tools_{organization_id}_{env.lower()}"
|
|
|
|
def _extract_tool_text(self, tool: "PydanticTool") -> str:
|
|
"""Extract searchable text from a tool for embedding.
|
|
|
|
Combines name, description, and JSON schema into a structured format
|
|
that provides rich context for semantic search.
|
|
|
|
Args:
|
|
tool: The tool to extract text from
|
|
|
|
Returns:
|
|
JSON-formatted string containing tool information
|
|
"""
|
|
|
|
parts = {
|
|
"name": tool.name or "",
|
|
"description": tool.description or "",
|
|
}
|
|
|
|
# Extract parameter information from JSON schema
|
|
if tool.json_schema:
|
|
# Include function description from schema if different from tool description
|
|
schema_description = tool.json_schema.get("description", "")
|
|
if schema_description and schema_description != tool.description:
|
|
parts["schema_description"] = schema_description
|
|
|
|
# Extract parameter information
|
|
parameters = tool.json_schema.get("parameters", {})
|
|
if parameters:
|
|
properties = parameters.get("properties", {})
|
|
param_descriptions = []
|
|
for param_name, param_info in properties.items():
|
|
param_desc = param_info.get("description", "")
|
|
param_type = param_info.get("type", "any")
|
|
if param_desc:
|
|
param_descriptions.append(f"{param_name} ({param_type}): {param_desc}")
|
|
else:
|
|
param_descriptions.append(f"{param_name} ({param_type})")
|
|
if param_descriptions:
|
|
parts["parameters"] = param_descriptions
|
|
|
|
# Include tags for additional context
|
|
if tool.tags:
|
|
parts["tags"] = tool.tags
|
|
|
|
return json.dumps(parts)
|
|
|
|
@trace_method
async def insert_tools(
    self,
    tools: List["PydanticTool"],
    organization_id: str,
    actor: "PydanticUser",
) -> bool:
    """Embed tools and upsert them into the organization's tool namespace.

    Each tool is rendered to text with ``_extract_tool_text``, embedded, and
    written column-wise (id, vector, text, name, organization_id, tool_type,
    tags, created_at) in a single Turbopuffer write executed in a worker thread.

    Args:
        tools: List of tools to store
        organization_id: Organization ID for the tools
        actor: User actor for embedding generation

    Returns:
        True if successful (also when there was nothing to insert)

    Raises:
        Exception: Any error from the Turbopuffer write is logged and re-raised
    """
    # NOTE(review): AsyncTurbopuffer is not referenced in this method; the import
    # only checks that the turbopuffer package is importable - confirm whether it
    # can be dropped.
    from turbopuffer import AsyncTurbopuffer  # noqa: F401

    if not tools:
        return True

    # Render every tool to text and keep only those with non-blank content
    rendered = [(tool, self._extract_tool_text(tool)) for tool in tools]
    kept = [(tool, text) for tool, text in rendered if text.strip()]

    if not kept:
        logger.warning("All tools had empty text content, skipping insertion")
        return True

    valid_tools = [tool for tool, _ in kept]
    tool_texts = [text for _, text in kept]

    # Generate embeddings
    embeddings = await self._generate_embeddings(tool_texts, actor)

    namespace_name = await self._get_tool_namespace_name(organization_id)

    # Column-oriented payload: one list per attribute, index-aligned across columns
    upsert_columns = {
        "id": [],
        "vector": [],
        "text": [],
        "name": [],
        "organization_id": [],
        "tool_type": [],
        "tags": [],
        "created_at": [],
    }

    for tool, text, vector in zip(valid_tools, tool_texts, embeddings):
        upsert_columns["id"].append(tool.id)
        upsert_columns["vector"].append(vector)
        upsert_columns["text"].append(text)
        upsert_columns["name"].append(tool.name or "")
        upsert_columns["organization_id"].append(organization_id)
        upsert_columns["tool_type"].append(tool.tool_type.value if tool.tool_type else "custom")
        upsert_columns["tags"].append(tool.tags or [])
        # NOTE(review): unlike the passage/message inserts, created_at is not
        # normalized to UTC here - confirm whether that is intentional.
        upsert_columns["created_at"].append(getattr(tool, "created_at", None) or datetime.now(timezone.utc))

    inserted_count = len(upsert_columns["id"])

    try:
        # Use global semaphore to limit concurrent Turbopuffer writes
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
            await asyncio.to_thread(
                _run_turbopuffer_write_in_thread,
                api_key=self.api_key,
                region=self.region,
                namespace_name=namespace_name,
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
            logger.info(f"Successfully inserted {inserted_count} tools to Turbopuffer")
            return True

    except Exception as e:
        logger.error(f"Failed to insert tools to Turbopuffer: {e}")
        raise
|
|
|
|
@trace_method
async def insert_archival_memories(
    self,
    archive_id: str,
    text_chunks: List[str],
    passage_ids: List[str],
    organization_id: str,
    actor: "PydanticUser",
    tags: Optional[List[str]] = None,
    created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
    """Insert passages into Turbopuffer.

    Blank chunks are skipped. The remaining chunks are embedded and written
    column-wise in a single Turbopuffer write executed in a worker thread.

    Args:
        archive_id: ID of the archive
        text_chunks: List of text chunks to store
        passage_ids: List of passage IDs (must match 1:1 with text_chunks)
        organization_id: Organization ID for the passages
        actor: User actor for embedding generation
        tags: Optional list of tags to attach to all passages
        created_at: Optional timestamp for retroactive entries (defaults to current UTC time)

    Returns:
        List of PydanticPassage objects that were inserted (empty when every
        chunk was blank)

    Raises:
        ValueError: If passage_ids is empty or its length differs from text_chunks
        Exception: Any error from the Turbopuffer write is logged and re-raised
    """
    # filter out empty text chunks, remembering each chunk's original index so it
    # can be paired with the right passage ID later
    filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]

    if not filtered_chunks:
        logger.warning("All text chunks were empty, skipping insertion")
        return []

    # passage_ids must be provided for dual-write consistency.
    # Validate BEFORE generating embeddings / resolving the namespace so that
    # invalid input fails fast instead of after a wasted embedding request.
    if not passage_ids:
        raise ValueError("passage_ids must be provided for Turbopuffer insertion")
    if len(passage_ids) != len(text_chunks):
        raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")

    # generate embeddings using the default config
    filtered_texts = [text for _, text in filtered_chunks]
    embeddings = await self._generate_embeddings(filtered_texts, actor)

    namespace_name = await self._get_archive_namespace_name(archive_id)

    # handle timestamp - ensure UTC
    if created_at is None:
        timestamp = datetime.now(timezone.utc)
    elif created_at.tzinfo is None:
        # assume UTC if no timezone provided
        timestamp = created_at.replace(tzinfo=timezone.utc)
    else:
        # convert to UTC if in different timezone
        timestamp = created_at.astimezone(timezone.utc)

    # prepare column-based data for turbopuffer - optimized for batch insert
    ids = []
    vectors = []
    texts = []
    organization_ids = []
    archive_ids = []
    created_ats = []
    tags_arrays = []  # Store tags as arrays
    passages = []

    for (original_idx, text), embedding in zip(filtered_chunks, embeddings):
        passage_id = passage_ids[original_idx]

        # append to columns
        ids.append(passage_id)
        vectors.append(embedding)
        texts.append(text)
        organization_ids.append(organization_id)
        archive_ids.append(archive_id)
        created_ats.append(timestamp)
        tags_arrays.append(tags or [])  # Store tags as array

        # Create PydanticPassage object mirroring what is written to Turbopuffer
        passage = PydanticPassage(
            id=passage_id,
            text=text,
            organization_id=organization_id,
            archive_id=archive_id,
            created_at=timestamp,
            metadata_={},
            tags=tags or [],  # Include tags in the passage
            embedding=embedding,
            embedding_config=self.default_embedding_config,  # Will be set by caller if needed
        )
        passages.append(passage)

    # build column-based upsert data
    upsert_columns = {
        "id": ids,
        "vector": vectors,
        "text": texts,
        "organization_id": organization_ids,
        "archive_id": archive_ids,
        "created_at": created_ats,
        "tags": tags_arrays,  # Add tags as array column
    }

    try:
        # Use global semaphore to limit concurrent Turbopuffer writes
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
            await asyncio.to_thread(
                _run_turbopuffer_write_in_thread,
                api_key=self.api_key,
                region=self.region,
                namespace_name=namespace_name,
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
            logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}")
            return passages

    except Exception as e:
        logger.error(f"Failed to insert passages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate passage IDs detected in batch")
        raise
|
|
|
|
@trace_method
async def insert_messages(
    self,
    agent_id: str,
    message_texts: List[str],
    message_ids: List[str],
    organization_id: str,
    actor: "PydanticUser",
    roles: List[MessageRole],
    created_ats: List[datetime],
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    conversation_ids: Optional[List[Optional[str]]] = None,
) -> bool:
    """Insert messages into Turbopuffer.

    Blank message texts are skipped. The remaining messages are embedded and
    written column-wise to the organization's message namespace in a single
    Turbopuffer write executed in a worker thread.

    Args:
        agent_id: ID of the agent
        message_texts: List of message text content to store
        message_ids: List of message IDs (must match 1:1 with message_texts)
        organization_id: Organization ID for the messages
        actor: User actor for embedding generation
        roles: List of message roles corresponding to each message
        created_ats: List of creation timestamps for each message
        project_id: Optional project ID for all messages
        template_id: Optional template ID for all messages
        conversation_ids: Optional list of conversation IDs (one per message, must match 1:1 with message_texts)

    Returns:
        True if successful (also when every message text was blank)

    Raises:
        ValueError: If message_ids is empty, or any per-message list length
            does not match message_ids
        Exception: Any error from the Turbopuffer write is logged and re-raised
    """
    # filter out empty message texts, remembering each message's original index
    filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]

    if not filtered_messages:
        logger.warning("All message texts were empty, skipping insertion")
        return True

    # validation checks - done BEFORE generating embeddings so that invalid
    # input fails fast instead of after a wasted embedding request
    if not message_ids:
        raise ValueError("message_ids must be provided for Turbopuffer insertion")
    if len(message_ids) != len(message_texts):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
    if len(message_ids) != len(roles):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
    if len(message_ids) != len(created_ats):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
    if conversation_ids is not None and len(conversation_ids) != len(message_ids):
        raise ValueError(f"conversation_ids length ({len(conversation_ids)}) must match message_ids length ({len(message_ids)})")

    # generate embeddings using the default config
    filtered_texts = [text for _, text in filtered_messages]
    embeddings = await self._generate_embeddings(filtered_texts, actor)

    namespace_name = await self._get_message_namespace_name(organization_id)

    # prepare column-based data for turbopuffer - optimized for batch insert
    ids = []
    vectors = []
    texts = []
    organization_ids_list = []
    agent_ids_list = []
    message_roles = []
    created_at_timestamps = []
    project_ids_list = []
    template_ids_list = []
    conversation_ids_list = []

    for (original_idx, text), embedding in zip(filtered_messages, embeddings):
        message_id = message_ids[original_idx]
        role = roles[original_idx]
        created_at = created_ats[original_idx]
        conversation_id = conversation_ids[original_idx] if conversation_ids else None

        # ensure the provided timestamp is timezone-aware and in UTC
        if created_at.tzinfo is None:
            # assume UTC if no timezone provided
            timestamp = created_at.replace(tzinfo=timezone.utc)
        else:
            # convert to UTC if in different timezone
            timestamp = created_at.astimezone(timezone.utc)

        # append to columns
        ids.append(message_id)
        vectors.append(embedding)
        texts.append(text)
        organization_ids_list.append(organization_id)
        agent_ids_list.append(agent_id)
        message_roles.append(role.value)
        created_at_timestamps.append(timestamp)
        project_ids_list.append(project_id)
        template_ids_list.append(template_id)
        conversation_ids_list.append(conversation_id)

    # build column-based upsert data
    upsert_columns = {
        "id": ids,
        "vector": vectors,
        "text": texts,
        "organization_id": organization_ids_list,
        "agent_id": agent_ids_list,
        "role": message_roles,
        "created_at": created_at_timestamps,
    }

    # only include conversation_id if it's provided
    if conversation_ids is not None:
        upsert_columns["conversation_id"] = conversation_ids_list

    # only include project_id if it's provided
    if project_id is not None:
        upsert_columns["project_id"] = project_ids_list

    # only include template_id if it's provided
    if template_id is not None:
        upsert_columns["template_id"] = template_ids_list

    try:
        # Use global semaphore to limit concurrent Turbopuffer writes
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            # Run in thread pool to prevent CPU-intensive base64 encoding from blocking event loop
            await asyncio.to_thread(
                _run_turbopuffer_write_in_thread,
                api_key=self.api_key,
                region=self.region,
                namespace_name=namespace_name,
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={
                    "text": {"type": "string", "full_text_search": True},
                    "conversation_id": {"type": "string"},
                },
            )
            logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
            return True

    except Exception as e:
        logger.error(f"Failed to insert messages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate message IDs detected in batch")
        raise
|
|
|
|
@trace_method
async def _execute_query(
    self,
    namespace_name: str,
    search_mode: str,
    query_embedding: Optional[List[float]],
    query_text: Optional[str],
    top_k: int,
    include_attributes: List[str],
    filters: Optional[Any] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> Any:
    """Run a single query (or a hybrid multi-query) against a Turbopuffer namespace.

    Args:
        namespace_name: Turbopuffer namespace to query
        search_mode: "vector", "fts", "hybrid", or "timestamp"
        query_embedding: Embedding for vector search
        query_text: Text for full-text search
        top_k: Number of results to return
        include_attributes: Attributes to include in results
        filters: Turbopuffer filter expression
        vector_weight: Weight for vector search in hybrid mode (accepted but not used inside this method)
        fts_weight: Weight for FTS in hybrid mode (accepted but not used inside this method)

    Returns:
        Raw Turbopuffer query results or multi-query response

    Raises:
        ValueError: If the search mode is unknown, a required input for the
            mode is missing, or Turbopuffer reports a NotFoundError
    """
    from turbopuffer import AsyncTurbopuffer
    from turbopuffer.types import QueryParam

    # validate inputs based on search mode
    if search_mode == "vector" and query_embedding is None:
        raise ValueError("query_embedding is required for vector search mode")
    if search_mode == "fts" and query_text is None:
        raise ValueError("query_text is required for FTS search mode")
    if search_mode == "hybrid" and (query_embedding is None or query_text is None):
        raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
    if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
        raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")

    def build_params(rank_by):
        # Every query shares the same shape; only the ranking expression differs
        params = {
            "rank_by": rank_by,
            "top_k": top_k,
            "include_attributes": include_attributes,
        }
        if filters:
            params["filters"] = filters
        return params

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            namespace = client.namespace(namespace_name)

            if search_mode == "hybrid":
                # vector query first, full-text query second - callers index the
                # multi-query response in that order
                sub_queries = [
                    build_params(("vector", "ANN", query_embedding)),
                    build_params(("text", "BM25", query_text)),
                ]
                return await namespace.multi_query(queries=[QueryParam(**q) for q in sub_queries])

            if search_mode == "timestamp":
                # retrieve most recent items by timestamp
                rank_by = ("created_at", "desc")
            elif search_mode == "vector":
                rank_by = ("vector", "ANN", query_embedding)
            else:  # fts
                rank_by = ("text", "BM25", query_text)

            return await namespace.query(**build_params(rank_by))
    except Exception as exc:
        # Wrap turbopuffer errors with user-friendly messages
        from turbopuffer import NotFoundError

        # Re-raise other errors as-is
        if not isinstance(exc, NotFoundError):
            raise

        # Extract just the error message without implementation details
        error_msg = str(exc)
        lowered = error_msg.lower()
        if "namespace" in lowered and "not found" in lowered:
            raise ValueError("No conversation history found. Please send a message first to enable search.") from exc
        raise ValueError(f"Search data not found: {error_msg}") from exc
|
|
|
|
@trace_method
async def query_passages(
    self,
    archive_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid"
    top_k: int = 10,
    tags: Optional[List[str]] = None,
    tag_match_mode: TagMatchMode = TagMatchMode.ANY,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[Tuple[PydanticPassage, float, dict]]:
    """Query passages from Turbopuffer using vector search, full-text search, or hybrid search.

    When ``query_text`` is None the search falls back to "timestamp" mode and
    returns the most recent passages.

    Args:
        archive_id: ID of the archive
        actor: User actor for embedding generation
        query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
        search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
        top_k: Number of results to return
        tags: Optional list of tags to filter by
        tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY
        vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
        start_date: Optional datetime; only passages created at or after it are returned
        end_date: Optional datetime to filter passages created on or before this date (inclusive)

    Returns:
        List of (passage, score, metadata) tuples with relevance rankings

    Raises:
        ValueError: If the query inputs do not satisfy the chosen search mode
            (e.g. blank ``query_text`` in vector mode) or Turbopuffer reports
            missing data
        Exception: Any other query error is logged and re-raised
    """
    # generate embedding for vector/hybrid search if query_text is provided
    query_embedding = None
    if query_text and search_mode in ["vector", "hybrid"]:
        embeddings = await self._generate_embeddings([query_text], actor)
        # _generate_embeddings drops blank strings, so whitespace-only query text
        # yields an empty list; guard the index instead of raising IndexError.
        # A blank query then behaves like an empty string: the query executor
        # rejects it with a descriptive ValueError.
        if embeddings:
            query_embedding = embeddings[0]

    # Check if we should fallback to timestamp-based retrieval
    if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
        # Fallback to retrieving most recent passages when no search query is provided
        search_mode = "timestamp"

    namespace_name = await self._get_archive_namespace_name(archive_id)

    # build tag filter conditions
    tag_filter = None
    if tags:
        if tag_match_mode == TagMatchMode.ALL:
            # For ALL mode, need to check each tag individually with Contains
            tag_conditions = [("tags", "Contains", tag) for tag in tags]
            if len(tag_conditions) == 1:
                tag_filter = tag_conditions[0]
            else:
                tag_filter = ("And", tag_conditions)
        else:  # tag_match_mode == TagMatchMode.ANY
            # For ANY mode, use ContainsAny to match any of the tags
            tag_filter = ("tags", "ContainsAny", tags)

    # build date filter conditions
    date_filters = []
    if start_date:
        # Convert to UTC to match stored timestamps
        if start_date.tzinfo is not None:
            start_date = start_date.astimezone(timezone.utc)
        date_filters.append(("created_at", "Gte", start_date))
    if end_date:
        # if end_date has no time component (is at midnight), adjust to end of day
        # to make the filter inclusive of the entire day
        if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
            from datetime import timedelta

            # add 1 day and subtract 1 microsecond to get 23:59:59.999999
            end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
        # Convert to UTC to match stored timestamps
        if end_date.tzinfo is not None:
            end_date = end_date.astimezone(timezone.utc)
        date_filters.append(("created_at", "Lte", end_date))

    # combine all filters
    all_filters = []
    if tag_filter:
        all_filters.append(tag_filter)
    if date_filters:
        all_filters.extend(date_filters)

    # create final filter expression: a lone condition is passed as-is,
    # several conditions are AND-ed together
    final_filter = None
    if len(all_filters) == 1:
        final_filter = all_filters[0]
    elif len(all_filters) > 1:
        final_filter = ("And", all_filters)

    try:
        # use generic query executor
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "organization_id", "archive_id", "created_at", "tags"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        # process results based on search mode
        if search_mode == "hybrid":
            # for hybrid mode, we get a multi-query response:
            # results[0] is the vector query, results[1] the full-text query
            vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
            fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
            # use RRF and include metadata with ranks
            results_with_metadata = self._reciprocal_rank_fusion(
                vector_results=[passage for passage, _ in vector_results],
                fts_results=[passage for passage, _ in fts_results],
                get_id_func=lambda p: p.id,
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )
            # Return (passage, score, metadata) with ranks
            return results_with_metadata
        else:
            # for single queries (vector, fts, timestamp) - add basic metadata
            is_fts = search_mode == "fts"
            results = self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
            # Add simple metadata for single search modes
            results_with_metadata = []
            for idx, (passage, score) in enumerate(results):
                metadata = {
                    "combined_score": score,
                    f"{search_mode}_rank": idx + 1,  # Add the rank for this search mode
                }
                results_with_metadata.append((passage, score, metadata))
            return results_with_metadata

    except Exception as e:
        logger.error(f"Failed to query passages from Turbopuffer: {e}")
        raise
|
|
|
|
@trace_method
async def query_messages_by_agent_id(
    self,
    agent_id: str,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid", "timestamp"
    top_k: int = 10,
    roles: Optional[List[MessageRole]] = None,
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[Tuple[dict, float, dict]]:
    """Query one agent's messages from Turbopuffer.

    Supports vector, full-text, hybrid and timestamp retrieval. When
    ``query_text`` is missing or empty there is nothing to embed or match on,
    so the search falls back to ``"timestamp"`` mode instead of forwarding a
    vector/hybrid request without an embedding.

    Args:
        agent_id: ID of the agent; always applied as an ``agent_id`` filter.
        organization_id: Organization ID used to resolve the message namespace.
        actor: User actor passed to embedding generation.
        query_text: Text query. Embedded for "vector"/"hybrid" modes and
            forwarded to the query executor. ``None`` or ``""`` triggers the
            timestamp fallback.
        search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "vector").
        top_k: Number of results to return.
        roles: Optional list of message roles to filter by.
        project_id: Optional project ID to filter messages by.
        template_id: Optional template ID to filter messages by.
        conversation_id: Optional conversation filter. ``None`` applies no
            filter, ``"default"`` matches messages whose conversation_id is
            null, any other value matches that conversation.
        vector_weight: Weight for vector results in hybrid mode (default: 0.5).
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5).
        start_date: Optional lower bound on ``created_at`` (inclusive).
        end_date: Optional upper bound on ``created_at`` (inclusive). A value
            at exactly midnight is extended to the last microsecond of that day.

    Returns:
        List of (message_dict, score, metadata) tuples where:
        - message_dict is built by ``_process_message_query_results``
        - score is the fused score in hybrid mode, or ``1 / rank`` otherwise
        - metadata carries the combined score and per-mode rank information

    Raises:
        Exception: Any error raised while executing the query or processing
            its result is logged and re-raised unchanged.
    """
    has_query = bool(query_text)

    # Only vector and hybrid searches need an embedding of the query text.
    query_embedding = None
    if has_query and search_mode in ("vector", "hybrid"):
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0]

    # Without a usable query (None or "") fall back to retrieving the most
    # recent messages rather than issuing a search that has nothing to rank by.
    if not has_query and search_mode != "timestamp":
        search_mode = "timestamp"

    namespace_name = await self._get_message_namespace_name(organization_id)

    # The agent filter is always present; optional filters follow in a fixed
    # order: role, project, template, conversation, dates.
    all_filters = [("agent_id", "Eq", agent_id)]

    if roles:
        role_values = [role.value for role in roles]
        if len(role_values) == 1:
            all_filters.append(("role", "Eq", role_values[0]))
        else:
            all_filters.append(("role", "In", role_values))

    if project_id:
        all_filters.append(("project_id", "Eq", project_id))

    if template_id:
        all_filters.append(("template_id", "Eq", template_id))

    # conversation_id has three cases:
    #   None      -> no filter, all messages
    #   "default" -> only messages whose conversation_id is null
    #   other     -> only messages in that conversation
    if conversation_id == "default":
        all_filters.append(("conversation_id", "Eq", None))
    elif conversation_id is not None:
        all_filters.append(("conversation_id", "Eq", conversation_id))

    if start_date:
        # Convert aware datetimes to UTC to match stored timestamps
        if start_date.tzinfo is not None:
            start_date = start_date.astimezone(timezone.utc)
        all_filters.append(("created_at", "Gte", start_date))

    if end_date:
        # A bound with no time component (exactly midnight) is widened to
        # 23:59:59.999999 so the filter covers the entire day.
        at_midnight = (end_date.hour, end_date.minute, end_date.second, end_date.microsecond) == (0, 0, 0, 0)
        if at_midnight:
            from datetime import timedelta

            end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
        # Convert aware datetimes to UTC to match stored timestamps
        if end_date.tzinfo is not None:
            end_date = end_date.astimezone(timezone.utc)
        all_filters.append(("created_at", "Lte", end_date))

    # A single condition is passed as-is; several are combined with "And".
    final_filter = all_filters[0] if len(all_filters) == 1 else ("And", all_filters)

    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=True,
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode == "hybrid":
            # Hybrid mode yields a multi-query response: vector first, FTS second.
            vector_results = self._process_message_query_results(result.results[0])
            fts_results = self._process_message_query_results(result.results[1])
            return self._reciprocal_rank_fusion(
                vector_results=vector_results,
                fts_results=fts_results,
                get_id_func=lambda msg_dict: msg_dict["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

        # Single-query modes (vector, fts, timestamp): score by position.
        results_with_metadata = []
        for rank, msg_dict in enumerate(self._process_message_query_results(result), start=1):
            rank_score = 1.0 / rank
            metadata = {
                "combined_score": rank_score,
                "search_mode": search_mode,
                f"{search_mode}_rank": rank,
            }
            results_with_metadata.append((msg_dict, rank_score, metadata))
        return results_with_metadata

    except Exception as e:
        logger.error(f"Failed to query messages from Turbopuffer: {e}")
        raise
|
|
|
|
async def query_messages_by_org_id(
    self,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "hybrid",  # "vector", "fts", "hybrid"
    top_k: int = 10,
    roles: Optional[List[MessageRole]] = None,
    agent_id: Optional[str] = None,
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[Tuple[dict, float, dict]]:
    """Query messages from Turbopuffer across an entire organization.

    When ``query_text`` is missing or empty there is nothing to embed or match
    on, so the search falls back to ``"timestamp"`` mode instead of forwarding
    a vector/hybrid request without an embedding.

    Args:
        organization_id: Organization ID used to resolve the message namespace (required).
        actor: User actor passed to embedding generation.
        query_text: Text query. Embedded for "vector"/"hybrid" modes and
            forwarded to the query executor. ``None`` or ``""`` triggers the
            timestamp fallback.
        search_mode: "vector", "fts", or "hybrid" (default: "hybrid").
        top_k: Number of results to return.
        roles: Optional list of message roles to filter by.
        agent_id: Optional agent ID to filter messages by.
        project_id: Optional project ID to filter messages by.
        template_id: Optional template ID to filter messages by.
        conversation_id: Optional conversation filter. Special values:
            - None (omitted): return all messages
            - "default": return only messages whose conversation_id is null
            - any other value: return messages in that conversation
        vector_weight: Weight for vector results in hybrid mode (default: 0.5).
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5).
        start_date: Optional lower bound on ``created_at`` (inclusive).
        end_date: Optional upper bound on ``created_at`` (inclusive). A value
            at exactly midnight is extended to the last microsecond of that day.

    Returns:
        List of (message_dict, score, metadata) tuples where:
        - message_dict is built by ``_process_message_query_results``
        - score is the fused score in hybrid mode, or ``1 / rank`` otherwise
        - metadata carries rank information plus ``vector_score`` /
          ``fts_score`` when the result rows expose a raw distance / score

    Raises:
        Exception: Any error raised while executing the query or processing
            its result is logged and re-raised unchanged.
    """
    missing = object()

    def raw_value(row, name):
        # The passage result processors in this file read raw scores under the
        # "$"-prefixed names ("$dist", "$score"); prefer those and keep the
        # bare name as a fallback so either row shape is understood.
        value = getattr(row, f"${name}", missing)
        return getattr(row, name, missing) if value is missing else value

    has_query = bool(query_text)

    # Only vector and hybrid searches need an embedding of the query text.
    query_embedding = None
    if has_query and search_mode in ("vector", "hybrid"):
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0]

    # Without a usable query (None or "") fall back to retrieving the most
    # recent messages rather than issuing a search that has nothing to rank by.
    if not has_query and search_mode != "timestamp":
        search_mode = "timestamp"

    # namespace is org-scoped
    namespace_name = await self._get_message_namespace_name(organization_id)

    all_filters = []

    if roles:
        role_values = [role.value for role in roles]
        if len(role_values) == 1:
            all_filters.append(("role", "Eq", role_values[0]))
        else:
            all_filters.append(("role", "In", role_values))

    if agent_id:
        all_filters.append(("agent_id", "Eq", agent_id))

    if project_id:
        all_filters.append(("project_id", "Eq", project_id))

    if template_id:
        all_filters.append(("template_id", "Eq", template_id))

    # conversation_id has three cases:
    #   None      -> no filter, all messages
    #   "default" -> only messages whose conversation_id is null
    #   other     -> only messages in that conversation
    if conversation_id == "default":
        all_filters.append(("conversation_id", "Eq", None))
    elif conversation_id is not None:
        all_filters.append(("conversation_id", "Eq", conversation_id))

    if start_date:
        # Convert aware datetimes to UTC to match stored timestamps
        if start_date.tzinfo is not None:
            start_date = start_date.astimezone(timezone.utc)
        all_filters.append(("created_at", "Gte", start_date))

    if end_date:
        # A bound at exactly midnight is widened to cover the entire day.
        at_midnight = (end_date.hour, end_date.minute, end_date.second, end_date.microsecond) == (0, 0, 0, 0)
        if at_midnight:
            from datetime import timedelta

            end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
        # Convert aware datetimes to UTC to match stored timestamps
        if end_date.tzinfo is not None:
            end_date = end_date.astimezone(timezone.utc)
        all_filters.append(("created_at", "Lte", end_date))

    # No conditions -> no filter; one -> passed as-is; several -> "And".
    final_filter = None
    if len(all_filters) == 1:
        final_filter = all_filters[0]
    elif len(all_filters) > 1:
        final_filter = ("And", all_filters)

    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=True,
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode == "hybrid":
            # Hybrid mode yields a multi-query response: vector first, FTS second.
            vector_rows = result.results[0]
            fts_rows = result.results[1]

            results_with_metadata = self._reciprocal_rank_fusion(
                vector_results=self._process_message_query_results(vector_rows),
                fts_results=self._process_message_query_results(fts_rows),
                get_id_func=lambda msg_dict: msg_dict["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

            # Collect raw per-row scores so they can be attached to the fused metadata.
            vector_scores = {}
            for row in vector_rows.rows:
                dist = raw_value(row, "dist")
                if dist is not missing:
                    vector_scores[row.id] = dist

            fts_scores = {}
            for row in fts_rows.rows:
                bm25 = raw_value(row, "score")
                if bm25 is not missing:
                    fts_scores[row.id] = bm25

            enhanced_results = []
            for msg_dict, rrf_score, metadata in results_with_metadata:
                msg_id = msg_dict["id"]
                if msg_id in vector_scores:
                    metadata["vector_score"] = vector_scores[msg_id]
                if msg_id in fts_scores:
                    metadata["fts_score"] = fts_scores[msg_id]
                enhanced_results.append((msg_dict, rrf_score, metadata))

            return enhanced_results

        # Single-query modes: score by position and attach the raw score if present.
        results = self._process_message_query_results(result)
        results_with_metadata = []
        for idx, msg_dict in enumerate(results):
            rank = idx + 1
            metadata = {
                "combined_score": 1.0 / rank,
                "search_mode": search_mode,
                f"{search_mode}_rank": rank,
            }

            row = result.rows[idx]
            dist = raw_value(row, "dist")
            if dist is not missing:
                metadata["vector_score"] = dist
            else:
                bm25 = raw_value(row, "score")
                if bm25 is not missing:
                    metadata["fts_score"] = bm25

            results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))

        return results_with_metadata

    except Exception as e:
        logger.error(f"Failed to query messages from Turbopuffer: {e}")
        raise
|
|
|
|
def _process_message_query_results(self, result) -> List[dict]:
|
|
"""Process results from a message query into message dicts.
|
|
|
|
For RRF, we only need the rank order - scores are not used.
|
|
"""
|
|
messages = []
|
|
|
|
for row in result.rows:
|
|
# Build message dict with key fields
|
|
message_dict = {
|
|
"id": row.id,
|
|
"text": getattr(row, "text", ""),
|
|
"organization_id": getattr(row, "organization_id", None),
|
|
"agent_id": getattr(row, "agent_id", None),
|
|
"role": getattr(row, "role", None),
|
|
"created_at": getattr(row, "created_at", None),
|
|
"conversation_id": getattr(row, "conversation_id", None),
|
|
}
|
|
messages.append(message_dict)
|
|
|
|
return messages
|
|
|
|
def _process_single_query_results(
    self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
) -> List[Tuple[PydanticPassage, float]]:
    """Convert the rows of one archival query result into scored passages.

    Args:
        result: Query result exposing a ``rows`` iterable.
        archive_id: Archive ID stamped onto every returned passage.
        tags: Accepted for interface compatibility; not read by this method.
        is_fts: True when the rows come from a full-text search, in which case
            the row's ``$score`` is used as-is. Otherwise the row's ``$dist``
            is turned into a similarity as ``1.0 - distance``.

    Returns:
        List of (passage, score) tuples in row order. A missing ``$score`` or
        ``$dist`` is treated as 0.0.
    """
    scored_passages = []

    for hit in result.rows:
        # Embeddings are not read back from the row, so the passage carries an
        # empty embedding together with the client's default embedding config.
        hit_passage = PydanticPassage(
            id=hit.id,
            text=getattr(hit, "text", ""),
            organization_id=getattr(hit, "organization_id", None),
            archive_id=archive_id,
            created_at=getattr(hit, "created_at", None),
            metadata_={},
            tags=getattr(hit, "tags", []) or [],  # a null tags value becomes []
            embedding=[],
            embedding_config=self.default_embedding_config,
        )

        # FTS: higher score is better. Vector: smaller distance is better, so flip it.
        relevance = getattr(hit, "$score", 0.0) if is_fts else 1.0 - getattr(hit, "$dist", 0.0)

        scored_passages.append((hit_passage, relevance))

    return scored_passages
|
|
|
|
def _reciprocal_rank_fusion(
|
|
self,
|
|
vector_results: List[Any],
|
|
fts_results: List[Any],
|
|
get_id_func: Callable[[Any], str],
|
|
vector_weight: float,
|
|
fts_weight: float,
|
|
top_k: int,
|
|
) -> List[Tuple[Any, float, dict]]:
|
|
"""RRF implementation that works with any object type.
|
|
|
|
RRF score = vector_weight * (1/(k + rank)) + fts_weight * (1/(k + rank))
|
|
where k is a constant (typically 60) to avoid division by zero
|
|
|
|
This is a pure rank-based fusion following the standard RRF algorithm.
|
|
|
|
Args:
|
|
vector_results: List of items from vector search (ordered by relevance)
|
|
fts_results: List of items from FTS (ordered by relevance)
|
|
get_id_func: Function to extract ID from an item
|
|
vector_weight: Weight for vector search results
|
|
fts_weight: Weight for FTS results
|
|
top_k: Number of results to return
|
|
|
|
Returns:
|
|
List of (item, score, metadata) tuples sorted by RRF score
|
|
metadata contains ranks from each result list
|
|
"""
|
|
k = 60 # standard RRF constant from Cormack et al. (2009)
|
|
|
|
# create rank mappings based on position in result lists
|
|
# rank starts at 1, not 0
|
|
vector_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(vector_results)}
|
|
fts_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(fts_results)}
|
|
|
|
# combine all unique items from both result sets
|
|
all_items = {}
|
|
for item in vector_results:
|
|
all_items[get_id_func(item)] = item
|
|
for item in fts_results:
|
|
all_items[get_id_func(item)] = item
|
|
|
|
# calculate RRF scores based purely on ranks
|
|
rrf_scores = {}
|
|
score_metadata = {}
|
|
for item_id in all_items:
|
|
# RRF formula: sum of 1/(k + rank) across result lists
|
|
# If item not in a list, we don't add anything (equivalent to rank = infinity)
|
|
vector_rrf_score = 0.0
|
|
fts_rrf_score = 0.0
|
|
|
|
if item_id in vector_ranks:
|
|
vector_rrf_score = vector_weight / (k + vector_ranks[item_id])
|
|
if item_id in fts_ranks:
|
|
fts_rrf_score = fts_weight / (k + fts_ranks[item_id])
|
|
|
|
combined_score = vector_rrf_score + fts_rrf_score
|
|
|
|
rrf_scores[item_id] = combined_score
|
|
score_metadata[item_id] = {
|
|
"combined_score": combined_score, # Final RRF score
|
|
"vector_rank": vector_ranks.get(item_id),
|
|
"fts_rank": fts_ranks.get(item_id),
|
|
}
|
|
|
|
# sort by RRF score and return with metadata
|
|
sorted_results = sorted(
|
|
[(all_items[iid], score, score_metadata[iid]) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True
|
|
)
|
|
|
|
return sorted_results[:top_k]
|
|
|
|
@trace_method
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
    """Delete a single passage from an archive's Turbopuffer namespace.

    The delete is issued through ``_run_turbopuffer_write_in_thread`` on a
    worker thread via ``asyncio.to_thread``.

    Args:
        archive_id: Archive whose namespace holds the passage.
        passage_id: ID of the passage to delete.

    Returns:
        True when the delete call completes.

    Raises:
        Exception: Any failure of the delete call is logged and re-raised.
    """
    # Not referenced below; kept so the method still fails at this point when
    # the turbopuffer package cannot be imported.
    from turbopuffer import AsyncTurbopuffer

    target_namespace = await self._get_archive_namespace_name(archive_id)

    try:
        write_kwargs = {
            "api_key": self.api_key,
            "region": self.region,
            "namespace_name": target_namespace,
            "deletes": [passage_id],
        }
        # Run in thread pool for consistency with the other write paths
        await asyncio.to_thread(_run_turbopuffer_write_in_thread, **write_kwargs)
        logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}")
        return True
    except Exception as exc:
        logger.error(f"Failed to delete passage from Turbopuffer: {exc}")
        raise
|
|
|
|
@trace_method
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
    """Delete several passages from an archive's Turbopuffer namespace.

    Args:
        archive_id: Archive whose namespace holds the passages.
        passage_ids: IDs of the passages to delete. An empty list is a no-op.

    Returns:
        True when there was nothing to delete or the delete call completes.

    Raises:
        Exception: Any failure of the delete call is logged and re-raised.
    """
    # Not referenced below; kept so the method still fails at this point when
    # the turbopuffer package cannot be imported.
    from turbopuffer import AsyncTurbopuffer

    # Nothing to delete: skip the namespace lookup and the write entirely.
    if not passage_ids:
        return True

    target_namespace = await self._get_archive_namespace_name(archive_id)

    try:
        write_kwargs = {
            "api_key": self.api_key,
            "region": self.region,
            "namespace_name": target_namespace,
            "deletes": passage_ids,
        }
        # Run in thread pool for consistency with the other write paths
        await asyncio.to_thread(_run_turbopuffer_write_in_thread, **write_kwargs)
        logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}")
        return True
    except Exception as exc:
        logger.error(f"Failed to delete passages from Turbopuffer: {exc}")
        raise
|
|
|
|
@trace_method
async def delete_all_passages(self, archive_id: str) -> bool:
    """Delete every passage stored for an archive in Turbopuffer.

    Calls ``delete_all()`` on the archive's namespace through an
    ``AsyncTurbopuffer`` client opened for this call.

    Args:
        archive_id: Archive whose namespace is cleared.

    Returns:
        True when the delete call completes.

    Raises:
        Exception: Any failure while opening the client or deleting is logged
            and re-raised.
    """
    from turbopuffer import AsyncTurbopuffer

    target_namespace = await self._get_archive_namespace_name(archive_id)

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as tpuf:
            # The namespace object exposes delete_all() for a full wipe
            await tpuf.namespace(target_namespace).delete_all()
            logger.info(f"Successfully deleted all passages for archive {archive_id}")
            return True
    except Exception as exc:
        logger.error(f"Failed to delete all passages from Turbopuffer: {exc}")
        raise
|
|
|
|
@trace_method
async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool:
    """Delete several messages from the organization's message namespace.

    Args:
        agent_id: Agent the messages belong to; only used in the log message.
        organization_id: Organization ID used to resolve the message namespace.
        message_ids: IDs of the messages to delete. An empty list is a no-op.

    Returns:
        True when there was nothing to delete or the delete call completes.

    Raises:
        Exception: Any failure of the delete call is logged and re-raised.
    """
    # Not referenced below; kept so the method still fails at this point when
    # the turbopuffer package cannot be imported.
    from turbopuffer import AsyncTurbopuffer

    # Nothing to delete: skip the namespace lookup and the write entirely.
    if not message_ids:
        return True

    target_namespace = await self._get_message_namespace_name(organization_id)

    try:
        write_kwargs = {
            "api_key": self.api_key,
            "region": self.region,
            "namespace_name": target_namespace,
            "deletes": message_ids,
        }
        # Run in thread pool for consistency with the other write paths
        await asyncio.to_thread(_run_turbopuffer_write_in_thread, **write_kwargs)
        logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer for agent {agent_id}")
        return True
    except Exception as exc:
        logger.error(f"Failed to delete messages from Turbopuffer: {exc}")
        raise
|
|
|
|
@trace_method
async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool:
    """Delete every message of one agent from the organization's message namespace.

    Issues a delete-by-filter on ``agent_id`` through
    ``_run_turbopuffer_write_in_thread`` on a worker thread.

    Args:
        agent_id: Agent whose messages are removed.
        organization_id: Organization ID used to resolve the message namespace.

    Returns:
        True when the delete call completes.

    Raises:
        Exception: Any failure of the delete call is logged and re-raised.
    """
    # Not referenced below; kept so the method still fails at this point when
    # the turbopuffer package cannot be imported.
    from turbopuffer import AsyncTurbopuffer

    target_namespace = await self._get_message_namespace_name(organization_id)

    try:
        write_kwargs = {
            "api_key": self.api_key,
            "region": self.region,
            "namespace_name": target_namespace,
            "delete_by_filter": ("agent_id", "Eq", agent_id),
        }
        # Run in thread pool for consistency with the other write paths
        outcome = await asyncio.to_thread(_run_turbopuffer_write_in_thread, **write_kwargs)
        # A falsy write result is reported as zero affected rows.
        rows_deleted = outcome.rows_affected if outcome else 0
        logger.info(f"Successfully deleted all messages for agent {agent_id} (deleted {rows_deleted} rows)")
        return True
    except Exception as exc:
        logger.error(f"Failed to delete all messages from Turbopuffer: {exc}")
        raise
|
|
|
|
# file/source passage methods
|
|
|
|
@trace_method
async def _get_file_passages_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped Turbopuffer namespace name for file passages.

    Args:
        organization_id: Organization ID embedded in the namespace name.

    Returns:
        ``file_passages_<organization_id>``, with ``_<environment>`` (lower-cased)
        appended when ``settings.environment`` is set.
    """
    env_name = settings.environment

    # No environment configured: the namespace is scoped by organization only.
    if not env_name:
        return f"file_passages_{organization_id}"

    return f"file_passages_{organization_id}_{env_name.lower()}"
|
|
|
|
@trace_method
async def insert_file_passages(
    self,
    source_id: str,
    file_id: str,
    text_chunks: List[str],
    organization_id: str,
    actor: "PydanticUser",
    created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
    """Embed text chunks of a file and upsert them into the org-scoped namespace.

    Blank chunks are dropped before embedding. The write itself runs on a
    worker thread while holding the global Turbopuffer semaphore.

    Args:
        source_id: ID of the source containing the file.
        file_id: ID of the file.
        text_chunks: Text chunks to store.
        organization_id: Organization ID used for the namespace and stored with each row.
        actor: User actor passed to embedding generation; its organization_id
            is set on the returned passages.
        created_at: Optional timestamp for retroactive entries. Naive values
            are taken as UTC, aware values are converted to UTC, and None
            means the current UTC time.

    Returns:
        The PydanticPassage objects that were inserted, or an empty list when
        there were no non-blank chunks.

    Raises:
        Exception: Any failure of the write is logged and re-raised.
    """
    # Not referenced below; kept so the method still fails at this point when
    # the turbopuffer package cannot be imported.
    from turbopuffer import AsyncTurbopuffer

    if not text_chunks:
        return []

    # Whitespace-only chunks carry nothing worth embedding.
    usable_chunks = [chunk for chunk in text_chunks if chunk.strip()]
    if not usable_chunks:
        logger.warning("All text chunks were empty, skipping file passage insertion")
        return []

    # Embeddings come from the default config
    embeddings = await self._generate_embeddings(usable_chunks, actor)

    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    # Normalise the timestamp to an aware UTC datetime.
    if created_at is None:
        timestamp = datetime.now(timezone.utc)
    elif created_at.tzinfo is None:
        # assume UTC if no timezone provided
        timestamp = created_at.replace(tzinfo=timezone.utc)
    else:
        timestamp = created_at.astimezone(timezone.utc)

    # Pair every chunk with its embedding, then build one passage per pair.
    chunk_vector_pairs = list(zip(usable_chunks, embeddings))
    passages = [
        PydanticPassage(
            text=chunk,
            file_id=file_id,
            source_id=source_id,
            embedding=vector,
            embedding_config=self.default_embedding_config,
            organization_id=actor.organization_id,
        )
        for chunk, vector in chunk_vector_pairs
    ]
    row_count = len(passages)
    passage_ids = [stored.id for stored in passages]

    # Column-oriented payload: one list per attribute, all the same length.
    upsert_columns = {
        "id": passage_ids,
        "vector": [vector for _, vector in chunk_vector_pairs],
        "text": [chunk for chunk, _ in chunk_vector_pairs],
        "organization_id": [organization_id] * row_count,
        "source_id": [source_id] * row_count,
        "file_id": [file_id] * row_count,
        "created_at": [timestamp] * row_count,
    }

    try:
        # Use global semaphore to limit concurrent Turbopuffer writes
        async with _GLOBAL_TURBOPUFFER_SEMAPHORE:
            # Worker thread keeps the CPU-heavy vector encoding off the event loop
            await asyncio.to_thread(
                _run_turbopuffer_write_in_thread,
                api_key=self.api_key,
                region=self.region,
                namespace_name=namespace_name,
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
            logger.info(f"Successfully inserted {len(passage_ids)} file passages to Turbopuffer for source {source_id}, file {file_id}")
            return passages

    except Exception as exc:
        logger.error(f"Failed to insert file passages to Turbopuffer: {exc}")
        # Call out duplicate-ID failures explicitly
        if "duplicate" in str(exc).lower():
            logger.error("Duplicate passage IDs detected in batch")
        raise
|
|
|
|
@trace_method
async def query_file_passages(
    self,
    source_ids: List[str],
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid"
    top_k: int = 10,
    file_id: Optional[str] = None,  # optional filter by specific file
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> List[Tuple[PydanticPassage, float, dict]]:
    """Query file passages from Turbopuffer using the org-scoped namespace.

    When ``query_text`` is missing or empty there is nothing to embed or match
    on, so the search falls back to ``"timestamp"`` mode instead of forwarding
    a vector/hybrid request without an embedding.

    Args:
        source_ids: Source IDs to search within; always applied as a filter.
        organization_id: Organization ID used to resolve the namespace.
        actor: User actor passed to embedding generation.
        query_text: Text query. ``None`` or ``""`` triggers the timestamp fallback.
        search_mode: "vector", "fts", or "hybrid" (default: "vector").
        top_k: Number of results to return.
        file_id: Optional file ID restricting results to one file.
        vector_weight: Weight for vector results in hybrid mode (default: 0.5).
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5).

    Returns:
        List of (passage, score, metadata) tuples. In hybrid mode the score is
        the fused rank score; otherwise it is the per-row score produced by
        ``_process_file_query_results`` and metadata carries the row's rank.

    Raises:
        Exception: Any error raised while executing the query or processing
            its result is logged and re-raised unchanged.
    """
    has_query = bool(query_text)

    # Only vector and hybrid searches need an embedding of the query text.
    query_embedding = None
    if has_query and search_mode in ("vector", "hybrid"):
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0]

    # Without a usable query (None or "") fall back to retrieving the most
    # recent passages rather than issuing a search that has nothing to rank by.
    if not has_query and search_mode != "timestamp":
        search_mode = "timestamp"

    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    # Always filter by source: Eq for a single source, In for several.
    if len(source_ids) == 1:
        filters = [("source_id", "Eq", source_ids[0])]
    else:
        filters = [("source_id", "In", source_ids)]

    if file_id:
        filters.append(("file_id", "Eq", file_id))

    # A single condition is passed as-is; several are combined with "And".
    final_filter = filters[0] if len(filters) == 1 else ("And", filters)

    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "organization_id", "source_id", "file_id", "created_at"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode == "hybrid":
            # Hybrid mode yields a multi-query response: vector first, FTS second.
            vector_results = self._process_file_query_results(result.results[0])
            fts_results = self._process_file_query_results(result.results[1], is_fts=True)
            # Fusion works on rank only, so the per-row scores are dropped here.
            return self._reciprocal_rank_fusion(
                vector_results=[passage for passage, _ in vector_results],
                fts_results=[passage for passage, _ in fts_results],
                get_id_func=lambda p: p.id,
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

        # Single-query modes (vector, fts, timestamp): attach basic rank metadata.
        results = self._process_file_query_results(result, is_fts=(search_mode == "fts"))
        results_with_metadata = []
        for rank, (passage, score) in enumerate(results, start=1):
            metadata = {
                "combined_score": score,
                f"{search_mode}_rank": rank,
            }
            results_with_metadata.append((passage, score, metadata))
        return results_with_metadata

    except Exception as e:
        logger.error(f"Failed to query file passages from Turbopuffer: {e}")
        raise
|
|
|
|
def _process_file_query_results(self, result, is_fts: bool = False) -> List[Tuple[PydanticPassage, float]]:
    """Convert the rows of one file-passage query result into scored passages.

    Args:
        result: Query result exposing a ``rows`` iterable.
        is_fts: True when the rows come from a full-text search, in which case
            the row's ``$score`` is used as-is. Otherwise the row's ``$dist``
            is turned into a similarity as ``1.0 - distance``.

    Returns:
        List of (passage, score) tuples in row order. A missing ``$score`` or
        ``$dist`` is treated as 0.0.
    """
    scored_passages = []

    for hit in result.rows:
        # Embeddings are not read back from the row, so the passage carries an
        # empty embedding together with the client's default embedding config.
        hit_passage = PydanticPassage(
            id=hit.id,
            text=getattr(hit, "text", ""),
            organization_id=getattr(hit, "organization_id", None),
            source_id=getattr(hit, "source_id", None),
            file_id=getattr(hit, "file_id", None),
            created_at=getattr(hit, "created_at", None),
            metadata_={},
            tags=[],
            embedding=[],
            embedding_config=self.default_embedding_config,
        )

        # FTS: higher score is better. Vector: smaller distance is better, so flip it.
        relevance = getattr(hit, "$score", 0.0) if is_fts else 1.0 - getattr(hit, "$dist", 0.0)

        scored_passages.append((hit_passage, relevance))

    return scored_passages
|
|
|
|
@trace_method
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
    """Delete all passages of one file from the org-scoped file-passage namespace.

    Issues a delete-by-filter matching both ``source_id`` and ``file_id``
    through ``_run_turbopuffer_write_in_thread`` on a worker thread.

    Args:
        source_id: Source that contains the file.
        file_id: File whose passages are removed.
        organization_id: Organization ID used to resolve the namespace.

    Returns:
        True when the delete call completes.

    Raises:
        Exception: Any failure of the delete call is logged and re-raised.
    """
    # Not referenced below; kept so the method still fails at this point when
    # the turbopuffer package cannot be imported.
    from turbopuffer import AsyncTurbopuffer

    target_namespace = await self._get_file_passages_namespace_name(organization_id)

    try:
        # Both conditions are required so only this file's passages match.
        file_scope = ("And", [("source_id", "Eq", source_id), ("file_id", "Eq", file_id)])
        write_kwargs = {
            "api_key": self.api_key,
            "region": self.region,
            "namespace_name": target_namespace,
            "delete_by_filter": file_scope,
        }

        # Run in thread pool for consistency with the other write paths
        outcome = await asyncio.to_thread(_run_turbopuffer_write_in_thread, **write_kwargs)
        # A falsy write result is reported as zero affected rows.
        rows_deleted = outcome.rows_affected if outcome else 0
        logger.info(f"Successfully deleted passages for file {file_id} from source {source_id} (deleted {rows_deleted} rows)")
        return True
    except Exception as exc:
        logger.error(f"Failed to delete file passages from Turbopuffer: {exc}")
        raise
|
|
|
|
@trace_method
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
    """Delete all passages for a source from Turbopuffer.

    Args:
        source_id: ID of the source whose passages should be removed.
        organization_id: Organization ID used to resolve the file-passages namespace.

    Returns:
        True once the delete has completed.

    Raises:
        Exception: Any error raised by the write is logged and re-raised unchanged.
    """
    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    try:
        # The write is dispatched to a worker thread via the shared helper,
        # same as the other Turbopuffer writes in this module.
        result = await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            # every row whose source_id matches is removed
            delete_by_filter=("source_id", "Eq", source_id),
        )
        logger.info(f"Successfully deleted all passages for source {source_id} (deleted {result.rows_affected if result else 0} rows)")
        return True
    except Exception as e:
        logger.error(f"Failed to delete source passages from Turbopuffer: {e}")
        raise
|
|
|
|
# tool methods
|
|
|
|
@trace_method
async def delete_tools(self, organization_id: str, tool_ids: List[str]) -> bool:
    """Delete tools from Turbopuffer.

    Args:
        organization_id: Organization ID for namespace lookup
        tool_ids: List of tool IDs to delete

    Returns:
        True if successful (including when ``tool_ids`` is empty, in which
        case nothing is sent to Turbopuffer)

    Raises:
        Exception: Any error raised by the write is logged and re-raised unchanged.
    """
    # nothing to delete: skip the namespace lookup and the write entirely
    if not tool_ids:
        return True

    namespace_name = await self._get_tool_namespace_name(organization_id)

    try:
        # The write is dispatched to a worker thread via the shared helper,
        # same as the other Turbopuffer writes in this module.
        await asyncio.to_thread(
            _run_turbopuffer_write_in_thread,
            api_key=self.api_key,
            region=self.region,
            namespace_name=namespace_name,
            deletes=tool_ids,
        )
        logger.info(f"Successfully deleted {len(tool_ids)} tools from Turbopuffer")
        return True
    except Exception as e:
        logger.error(f"Failed to delete tools from Turbopuffer: {e}")
        raise
|
|
|
|
@trace_method
async def query_tools(
    self,
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "hybrid",  # "vector", "fts", "hybrid", "timestamp"
    top_k: int = 50,
    tool_types: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> List[Tuple[dict, float, dict]]:
    """Search the organization's tool namespace in Turbopuffer.

    Args:
        organization_id: Organization ID for namespace lookup
        actor: User actor passed to embedding generation
        query_text: Text query for search; when it is None the search mode
            is switched to "timestamp"
        search_mode: One of "vector", "fts", "hybrid" or "timestamp"
        top_k: Number of results to return
        tool_types: Optional tool types to filter by (one value uses "Eq",
            several use "In")
        tags: Optional tags to filter by (a row matches if it contains any)
        vector_weight: Weight for vector search in hybrid mode
        fts_weight: Weight for FTS in hybrid mode

    Returns:
        List of (tool_dict, score, metadata) tuples. In hybrid mode this is
        whatever ``_reciprocal_rank_fusion`` produces; in the other modes the
        score is ``1 / rank`` and the metadata records the score, the search
        mode and the 1-based rank.

    Raises:
        Exception: Any error from the query or result processing is logged
            and re-raised unchanged.
    """
    # only the vector-based modes need the query text embedded
    query_embedding = None
    if query_text and search_mode in ("vector", "hybrid"):
        generated = await self._generate_embeddings([query_text], actor)
        if generated:
            query_embedding = generated[0]

    # Fallback to timestamp-based retrieval when no query was supplied
    if query_text is None and query_embedding is None and search_mode != "timestamp":
        search_mode = "timestamp"

    namespace_name = await self._get_tool_namespace_name(organization_id)

    # collect the individual filter clauses
    clauses = []
    if tool_types:
        type_clause = ("tool_type", "Eq", tool_types[0]) if len(tool_types) == 1 else ("tool_type", "In", tool_types)
        clauses.append(type_clause)
    if tags:
        clauses.append(("tags", "ContainsAny", tags))

    # a single clause is passed bare; several are wrapped in an "And"
    if not clauses:
        combined_filter = None
    elif len(clauses) == 1:
        combined_filter = clauses[0]
    else:
        combined_filter = ("And", clauses)

    try:
        response = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "name", "organization_id", "tool_type", "tags", "created_at"],
            filters=combined_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode != "hybrid":
            # single-mode search: the score is derived purely from the row's position
            ranked = []
            for rank, tool in enumerate(self._process_tool_query_results(response), start=1):
                rank_score = 1.0 / rank
                details = {
                    "combined_score": rank_score,
                    "search_mode": search_mode,
                    f"{search_mode}_rank": rank,
                }
                ranked.append((tool, details["combined_score"], details))
            return ranked

        # hybrid: the first sub-result is used as the vector ranking and the
        # second as the FTS ranking; the two are then fused
        return self._reciprocal_rank_fusion(
            vector_results=self._process_tool_query_results(response.results[0]),
            fts_results=self._process_tool_query_results(response.results[1]),
            get_id_func=lambda d: d["id"],
            vector_weight=vector_weight,
            fts_weight=fts_weight,
            top_k=top_k,
        )

    except Exception as e:
        logger.error(f"Failed to query tools from Turbopuffer: {e}")
        raise
|
|
|
|
def _process_tool_query_results(self, result) -> List[dict]:
    """Convert the rows of a tool query result into plain dicts.

    Args:
        result: Query result exposing a ``rows`` iterable. Each row must have
            an ``id``; every other attribute is optional.

    Returns:
        One dict per row, in row order, with the keys ``id``, ``text``,
        ``name``, ``organization_id``, ``tool_type``, ``tags`` and
        ``created_at``. Missing attributes fall back to ``""`` for ``text``
        and ``name``, ``[]`` for ``tags`` and ``None`` for the rest.
    """
    return [
        {
            "id": row.id,
            "text": getattr(row, "text", ""),
            "name": getattr(row, "name", ""),
            "organization_id": getattr(row, "organization_id", None),
            "tool_type": getattr(row, "tool_type", None),
            # a fresh list per row so callers can mutate one without affecting others
            "tags": getattr(row, "tags", []),
            "created_at": getattr(row, "created_at", None),
        }
        for row in result.rows
    ]
|