feat: Embed all messages in turbopuffer [LET-4143] (#4352)

* wip

* Finish embedding

* Fix ruff and tests
This commit is contained in:
Matthew Zhou
2025-09-02 12:43:48 -07:00
committed by GitHub
parent a696d9e3d5
commit b8a198688f
4 changed files with 1396 additions and 176 deletions

View File

@@ -2,10 +2,10 @@
import logging
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple
from letta.otel.tracing import trace_method
from letta.schemas.enums import TagMatchMode
from letta.schemas.enums import MessageRole, TagMatchMode
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings
@@ -16,6 +16,11 @@ def should_use_tpuf() -> bool:
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)
def should_use_tpuf_for_messages() -> bool:
    """Check if Turbopuffer should be used for messages.

    Requires both the general Turbopuffer switch (API key + use_tpuf) and
    the embed_all_messages setting to be enabled.
    """
    if not should_use_tpuf():
        return False
    return bool(settings.embed_all_messages)
class TurbopufferClient:
"""Client for managing archival memory with Turbopuffer vector database."""
@@ -28,7 +33,7 @@ class TurbopufferClient:
raise ValueError("Turbopuffer API key not provided")
@trace_method
def _get_namespace_name(self, archive_id: str) -> str:
def _get_archive_namespace_name(self, archive_id: str) -> str:
"""Get namespace name for a specific archive."""
# use archive_id as namespace to isolate different archives' memories
# append environment suffix to namespace for isolation if environment is set
@@ -39,6 +44,18 @@ class TurbopufferClient:
namespace_name = archive_id
return namespace_name
@trace_method
def _get_message_namespace_name(self, agent_id: str) -> str:
"""Get namespace name for a specific agent's messages."""
# use agent_id as namespace to isolate different agents' messages
# append environment suffix to namespace for isolation if environment is set
environment = settings.environment
if environment:
namespace_name = f"messages_{agent_id}_{environment.lower()}"
else:
namespace_name = f"messages_{agent_id}"
return namespace_name
@trace_method
async def insert_archival_memories(
self,
@@ -66,7 +83,7 @@ class TurbopufferClient:
"""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
namespace_name = self._get_archive_namespace_name(archive_id)
# handle timestamp - ensure UTC
if created_at is None:
@@ -155,6 +172,212 @@ class TurbopufferClient:
logger.error("Duplicate passage IDs detected in batch")
raise
@trace_method
    async def insert_messages(
        self,
        agent_id: str,
        message_texts: List[str],
        embeddings: List[List[float]],
        message_ids: List[str],
        organization_id: str,
        roles: List[MessageRole],
        created_ats: List[datetime],
    ) -> bool:
        """Insert messages into Turbopuffer.

        All list arguments must be parallel (same length, index-aligned);
        mismatches are rejected before any write is attempted. Naive
        timestamps are assumed to be UTC.

        Args:
            agent_id: ID of the agent
            message_texts: List of message text content to store
            embeddings: List of embedding vectors corresponding to message texts
            message_ids: List of message IDs (must match 1:1 with message_texts)
            organization_id: Organization ID for the messages
            roles: List of message roles corresponding to each message
            created_ats: List of creation timestamps for each message

        Returns:
            True if successful

        Raises:
            ValueError: if message_ids is empty or any parallel list's length
                differs from message_ids.
            Exception: any Turbopuffer write failure is logged and re-raised.
        """
        from turbopuffer import AsyncTurbopuffer

        namespace_name = self._get_message_namespace_name(agent_id)

        # validation checks
        if not message_ids:
            raise ValueError("message_ids must be provided for Turbopuffer insertion")
        if len(message_ids) != len(message_texts):
            raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
        if len(message_ids) != len(embeddings):
            raise ValueError(f"message_ids length ({len(message_ids)}) must match embeddings length ({len(embeddings)})")
        if len(message_ids) != len(roles):
            raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
        if len(message_ids) != len(created_ats):
            raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")

        # prepare column-based data for turbopuffer - optimized for batch insert
        ids = []
        vectors = []
        texts = []
        organization_ids = []
        agent_ids = []
        message_roles = []
        created_at_timestamps = []

        for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)):
            message_id = message_ids[idx]

            # ensure the provided timestamp is timezone-aware and in UTC
            if created_at.tzinfo is None:
                # assume UTC if no timezone provided
                timestamp = created_at.replace(tzinfo=timezone.utc)
            else:
                # convert to UTC if in different timezone
                timestamp = created_at.astimezone(timezone.utc)

            # append to columns
            ids.append(message_id)
            vectors.append(embedding)
            texts.append(text)
            organization_ids.append(organization_id)
            agent_ids.append(agent_id)
            message_roles.append(role.value)
            created_at_timestamps.append(timestamp)

        # build column-based upsert data
        upsert_columns = {
            "id": ids,
            "vector": vectors,
            "text": texts,
            "organization_id": organization_ids,
            "agent_id": agent_ids,
            "role": message_roles,
            "created_at": created_at_timestamps,
        }

        try:
            # Use AsyncTurbopuffer as a context manager for proper resource cleanup
            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
                namespace = client.namespace(namespace_name)
                # turbopuffer recommends column-based writes for performance
                await namespace.write(
                    upsert_columns=upsert_columns,
                    distance_metric="cosine_distance",
                    schema={"text": {"type": "string", "full_text_search": True}},
                )
                logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
                return True
        except Exception as e:
            logger.error(f"Failed to insert messages to Turbopuffer: {e}")
            # check if it's a duplicate ID error
            if "duplicate" in str(e).lower():
                logger.error("Duplicate message IDs detected in batch")
            raise
@trace_method
async def _execute_query(
self,
namespace_name: str,
search_mode: str,
query_embedding: Optional[List[float]],
query_text: Optional[str],
top_k: int,
include_attributes: List[str],
filters: Optional[Any] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
) -> Any:
"""Generic query execution for Turbopuffer.
Args:
namespace_name: Turbopuffer namespace to query
search_mode: "vector", "fts", "hybrid", or "timestamp"
query_embedding: Embedding for vector search
query_text: Text for full-text search
top_k: Number of results to return
include_attributes: Attributes to include in results
filters: Turbopuffer filter expression
vector_weight: Weight for vector search in hybrid mode
fts_weight: Weight for FTS in hybrid mode
Returns:
Raw Turbopuffer query results or multi-query response
"""
from turbopuffer import AsyncTurbopuffer
from turbopuffer.types import QueryParam
# validate inputs based on search mode
if search_mode == "vector" and query_embedding is None:
raise ValueError("query_embedding is required for vector search mode")
if search_mode == "fts" and query_text is None:
raise ValueError("query_text is required for FTS search mode")
if search_mode == "hybrid":
if query_embedding is None or query_text is None:
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
if search_mode == "timestamp":
# retrieve most recent items by timestamp
query_params = {
"rank_by": ("created_at", "desc"),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
elif search_mode == "vector":
# vector search query
query_params = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
elif search_mode == "fts":
# full-text search query
query_params = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
query_params["filters"] = filters
return await namespace.query(**query_params)
else: # hybrid mode
queries = []
# vector search query
vector_query = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
vector_query["filters"] = filters
queries.append(vector_query)
# full-text search query
fts_query = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": include_attributes,
}
if filters:
fts_query["filters"] = filters
queries.append(fts_query)
# execute multi-query
return await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
@trace_method
async def query_passages(
self,
@@ -188,147 +411,213 @@ class TurbopufferClient:
Returns:
List of (passage, score) tuples
"""
from turbopuffer import AsyncTurbopuffer
from turbopuffer.types import QueryParam
# validate inputs based on search mode first
if search_mode == "vector" and query_embedding is None:
raise ValueError("query_embedding is required for vector search mode")
if search_mode == "fts" and query_text is None:
raise ValueError("query_text is required for FTS search mode")
if search_mode == "hybrid":
if query_embedding is None or query_text is None:
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
# Check if we should fallback to timestamp-based retrieval
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
# Fallback to retrieving most recent passages when no search query is provided
search_mode = "timestamp"
namespace_name = self._get_namespace_name(archive_id)
namespace_name = self._get_archive_namespace_name(archive_id)
# build tag filter conditions
tag_filter = None
if tags:
if tag_match_mode == TagMatchMode.ALL:
# For ALL mode, need to check each tag individually with Contains
tag_conditions = []
for tag in tags:
tag_conditions.append(("tags", "Contains", tag))
if len(tag_conditions) == 1:
tag_filter = tag_conditions[0]
else:
tag_filter = ("And", tag_conditions)
else: # tag_match_mode == TagMatchMode.ANY
# For ANY mode, use ContainsAny to match any of the tags
tag_filter = ("tags", "ContainsAny", tags)
# build date filter conditions
date_filters = []
if start_date:
date_filters.append(("created_at", "Gte", start_date))
if end_date:
date_filters.append(("created_at", "Lte", end_date))
# combine all filters
all_filters = []
if tag_filter:
all_filters.append(tag_filter)
if date_filters:
all_filters.extend(date_filters)
# create final filter expression
final_filter = None
if len(all_filters) == 1:
final_filter = all_filters[0]
elif len(all_filters) > 1:
final_filter = ("And", all_filters)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# use generic query executor
result = await self._execute_query(
namespace_name=namespace_name,
search_mode=search_mode,
query_embedding=query_embedding,
query_text=query_text,
top_k=top_k,
include_attributes=["text", "organization_id", "archive_id", "created_at", "tags"],
filters=final_filter,
vector_weight=vector_weight,
fts_weight=fts_weight,
)
# build tag filter conditions
tag_filter = None
if tags:
if tag_match_mode == TagMatchMode.ALL:
# For ALL mode, need to check each tag individually with Contains
tag_conditions = []
for tag in tags:
tag_conditions.append(("tags", "Contains", tag))
if len(tag_conditions) == 1:
tag_filter = tag_conditions[0]
else:
tag_filter = ("And", tag_conditions)
else: # tag_match_mode == TagMatchMode.ANY
# For ANY mode, use ContainsAny to match any of the tags
tag_filter = ("tags", "ContainsAny", tags)
# build date filter conditions
date_filters = []
if start_date:
# Turbopuffer expects datetime objects directly for comparison
date_filters.append(("created_at", "Gte", start_date))
if end_date:
# Turbopuffer expects datetime objects directly for comparison
date_filters.append(("created_at", "Lte", end_date))
# combine all filters
all_filters = []
if tag_filter:
all_filters.append(tag_filter)
if date_filters:
all_filters.extend(date_filters)
# create final filter expression
final_filter = None
if len(all_filters) == 1:
final_filter = all_filters[0]
elif len(all_filters) > 1:
final_filter = ("And", all_filters)
if search_mode == "timestamp":
# Fallback: retrieve most recent passages by timestamp
query_params = {
"rank_by": ("created_at", "desc"), # Order by created_at in descending order
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
query_params["filters"] = final_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags)
elif search_mode == "vector":
# single vector search query
query_params = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
query_params["filters"] = final_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags)
elif search_mode == "fts":
# single full-text search query
query_params = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
query_params["filters"] = final_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags, is_fts=True)
else: # hybrid mode
# multi-query for both vector and FTS
queries = []
# vector search query
vector_query = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
vector_query["filters"] = final_filter
queries.append(vector_query)
# full-text search query
fts_query = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if final_filter:
fts_query["filters"] = final_filter
queries.append(fts_query)
# execute multi-query
response = await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
# process and combine results using reciprocal rank fusion
vector_results = self._process_single_query_results(response.results[0], archive_id, tags)
fts_results = self._process_single_query_results(response.results[1], archive_id, tags, is_fts=True)
# combine results using reciprocal rank fusion
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
# process results based on search mode
if search_mode == "hybrid":
# for hybrid mode, we get a multi-query response
vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
# use backwards-compatible wrapper which calls generic RRF
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
else:
# for single queries (vector, fts, timestamp)
is_fts = search_mode == "fts"
return self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
except Exception as e:
logger.error(f"Failed to query passages from Turbopuffer: {e}")
raise
@trace_method
async def query_messages(
self,
agent_id: str,
query_embedding: Optional[List[float]] = None,
query_text: Optional[str] = None,
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
top_k: int = 10,
roles: Optional[List[MessageRole]] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> List[Tuple[dict, float]]:
"""Query messages from Turbopuffer using vector search, full-text search, or hybrid search.
Args:
agent_id: ID of the agent
query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector")
top_k: Number of results to return
roles: Optional list of message roles to filter by
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter messages created after this date
end_date: Optional datetime to filter messages created before this date
Returns:
List of (message_dict, score) tuples where message_dict contains id, text, role, created_at
"""
# Check if we should fallback to timestamp-based retrieval
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
# Fallback to retrieving most recent messages when no search query is provided
search_mode = "timestamp"
namespace_name = self._get_message_namespace_name(agent_id)
# build role filter conditions
role_filter = None
if roles:
role_values = [r.value for r in roles]
if len(role_values) == 1:
role_filter = ("role", "Eq", role_values[0])
else:
role_filter = ("role", "In", role_values)
# build date filter conditions
date_filters = []
if start_date:
date_filters.append(("created_at", "Gte", start_date))
if end_date:
date_filters.append(("created_at", "Lte", end_date))
# combine all filters
all_filters = []
if role_filter:
all_filters.append(role_filter)
if date_filters:
all_filters.extend(date_filters)
# create final filter expression
final_filter = None
if len(all_filters) == 1:
final_filter = all_filters[0]
elif len(all_filters) > 1:
final_filter = ("And", all_filters)
try:
# use generic query executor
result = await self._execute_query(
namespace_name=namespace_name,
search_mode=search_mode,
query_embedding=query_embedding,
query_text=query_text,
top_k=top_k,
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
filters=final_filter,
vector_weight=vector_weight,
fts_weight=fts_weight,
)
# process results based on search mode
if search_mode == "hybrid":
# for hybrid mode, we get a multi-query response
vector_results = self._process_message_query_results(result.results[0])
fts_results = self._process_message_query_results(result.results[1], is_fts=True)
# use generic RRF with lambda to extract ID from dict
return self._generic_reciprocal_rank_fusion(
vector_results=vector_results,
fts_results=fts_results,
get_id_func=lambda msg_dict: msg_dict["id"],
vector_weight=vector_weight,
fts_weight=fts_weight,
top_k=top_k,
)
else:
# for single queries (vector, fts, timestamp)
is_fts = search_mode == "fts"
return self._process_message_query_results(result, is_fts=is_fts)
except Exception as e:
logger.error(f"Failed to query messages from Turbopuffer: {e}")
raise
def _process_message_query_results(self, result, is_fts: bool = False) -> List[Tuple[dict, float]]:
"""Process results from a message query into message dicts with scores."""
messages_with_scores = []
for row in result.rows:
# Build message dict with key fields
message_dict = {
"id": row.id,
"text": getattr(row, "text", ""),
"organization_id": getattr(row, "organization_id", None),
"agent_id": getattr(row, "agent_id", None),
"role": getattr(row, "role", None),
"created_at": getattr(row, "created_at", None),
}
# handle score based on search type
if is_fts:
# for FTS, use the BM25 score directly (higher is better)
score = getattr(row, "$score", 0.0)
else:
# for vector search, convert distance to similarity score
distance = getattr(row, "$dist", 0.0)
score = 1.0 - distance
messages_with_scores.append((message_dict, score))
return messages_with_scores
def _process_single_query_results(
self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
) -> List[Tuple[PydanticPassage, float]]:
@@ -369,6 +658,56 @@ class TurbopufferClient:
return passages_with_scores
def _generic_reciprocal_rank_fusion(
self,
vector_results: List[Tuple[Any, float]],
fts_results: List[Tuple[Any, float]],
get_id_func: Callable[[Any], str],
vector_weight: float,
fts_weight: float,
top_k: int,
) -> List[Tuple[Any, float]]:
"""Generic RRF implementation that works with any object type.
RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank))
where k is a constant (typically 60) to avoid division by zero
Args:
vector_results: List of (item, score) tuples from vector search
fts_results: List of (item, score) tuples from FTS
get_id_func: Function to extract ID from an item
vector_weight: Weight for vector search results
fts_weight: Weight for FTS results
top_k: Number of results to return
Returns:
List of (item, score) tuples sorted by RRF score
"""
k = 60 # standard RRF constant
# create rank mappings using the get_id_func
vector_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(vector_results)}
fts_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(fts_results)}
# combine all unique items
all_items = {}
for item, _ in vector_results:
all_items[get_id_func(item)] = item
for item, _ in fts_results:
all_items[get_id_func(item)] = item
# calculate RRF scores
rrf_scores = {}
for item_id in all_items:
vector_score = vector_weight / (k + vector_ranks.get(item_id, k + top_k))
fts_score = fts_weight / (k + fts_ranks.get(item_id, k + top_k))
rrf_scores[item_id] = vector_score + fts_score
# sort by RRF score and return top_k
sorted_results = sorted([(all_items[iid], score) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True)
return sorted_results[:top_k]
def _reciprocal_rank_fusion(
self,
vector_results: List[Tuple[PydanticPassage, float]],
@@ -377,42 +716,22 @@ class TurbopufferClient:
fts_weight: float,
top_k: int,
) -> List[Tuple[PydanticPassage, float]]:
"""Combine vector and FTS results using Reciprocal Rank Fusion (RRF).
RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank))
where k is a constant (typically 60) to avoid division by zero
"""
k = 60 # standard RRF constant
# create rank mappings
vector_ranks = {passage.id: rank + 1 for rank, (passage, _) in enumerate(vector_results)}
fts_ranks = {passage.id: rank + 1 for rank, (passage, _) in enumerate(fts_results)}
# combine all unique passage IDs
all_passages = {}
for passage, _ in vector_results:
all_passages[passage.id] = passage
for passage, _ in fts_results:
all_passages[passage.id] = passage
# calculate RRF scores
rrf_scores = {}
for passage_id in all_passages:
vector_score = vector_weight / (k + vector_ranks.get(passage_id, k + top_k))
fts_score = fts_weight / (k + fts_ranks.get(passage_id, k + top_k))
rrf_scores[passage_id] = vector_score + fts_score
# sort by RRF score and return top_k
sorted_results = sorted([(all_passages[pid], score) for pid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True)
return sorted_results[:top_k]
"""Wrapper for backwards compatibility - uses generic RRF for passages."""
return self._generic_reciprocal_rank_fusion(
vector_results=vector_results,
fts_results=fts_results,
get_id_func=lambda p: p.id,
vector_weight=vector_weight,
fts_weight=fts_weight,
top_k=top_k,
)
@trace_method
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
"""Delete a passage from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
namespace_name = self._get_archive_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
@@ -433,7 +752,7 @@ class TurbopufferClient:
if not passage_ids:
return True
namespace_name = self._get_namespace_name(archive_id)
namespace_name = self._get_archive_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
@@ -451,7 +770,7 @@ class TurbopufferClient:
"""Delete all passages for an archive from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
namespace_name = self._get_archive_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
@@ -463,3 +782,21 @@ class TurbopufferClient:
except Exception as e:
logger.error(f"Failed to delete all passages from Turbopuffer: {e}")
raise
@trace_method
async def delete_all_messages(self, agent_id: str) -> bool:
"""Delete all messages for an agent from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_message_namespace_name(agent_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# Turbopuffer has a delete_all() method on namespace
await namespace.delete_all()
logger.info(f"Successfully deleted all messages for agent {agent_id}")
return True
except Exception as e:
logger.error(f"Failed to delete all messages from Turbopuffer: {e}")
raise

View File

@@ -1,5 +1,6 @@
import json
import uuid
from datetime import datetime
from typing import List, Optional, Sequence
from sqlalchemy import delete, exists, func, select, text
@@ -11,7 +12,7 @@ from letta.orm.message import Message as MessageModel
from letta.otel.tracing import trace_method
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import LettaMessageUpdateUnion
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent
from letta.schemas.message import Message as PydanticMessage, MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
@@ -30,6 +31,34 @@ class MessageManager:
"""Initialize the MessageManager."""
self.file_manager = FileManager()
def _extract_message_text(self, message: PydanticMessage) -> str:
"""Extract text content from a message's complex content structure.
Args:
message: The message to extract text from
Returns:
Concatenated text content from the message
"""
# TODO: Make this much more complex/extend to beyond text content
if not message.content:
return ""
# handle string content (legacy)
if isinstance(message.content, str):
return message.content
# handle list of content items
text_parts = []
for content_item in message.content:
if isinstance(content_item, TextContent):
text_parts.append(content_item.text)
elif hasattr(content_item, "text"):
# handle other content types that might have text
text_parts.append(content_item.text)
return " ".join(text_parts)
@enforce_types
@trace_method
def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
@@ -125,13 +154,19 @@ class MessageManager:
@enforce_types
@trace_method
async def create_many_messages_async(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]:
async def create_many_messages_async(
self,
pydantic_msgs: List[PydanticMessage],
actor: PydanticUser,
embedding_config: Optional[dict] = None,
) -> List[PydanticMessage]:
"""
Create multiple messages in a single database transaction asynchronously.
Args:
pydantic_msgs: List of Pydantic message models to create
actor: User performing the action
embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer
Returns:
List of created Pydantic message models
@@ -169,6 +204,62 @@ class MessageManager:
created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True)
result = [msg.to_pydantic() for msg in created_messages]
await session.commit()
# embed messages in turbopuffer if enabled and embedding_config provided
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
if should_use_tpuf_for_messages() and embedding_config and result:
try:
# extract agent_id from the first message (all should have same agent_id)
agent_id = result[0].agent_id
if agent_id:
# extract text content from each message
message_texts = []
message_ids = []
roles = []
created_ats = []
for msg in result:
text = self._extract_message_text(msg)
if text: # only embed messages with text content
message_texts.append(text)
message_ids.append(msg.id)
roles.append(msg.role)
created_ats.append(msg.created_at)
if message_texts:
# generate embeddings using provided config
from letta.llm_api.llm_client import LLMClient
# extract provider info from embedding_config
embedding_provider = embedding_config.get("provider", "openai")
embedding_api_key = embedding_config.get("api_key")
embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1")
embedding_client = LLMClient(
llm_provider_type=embedding_provider,
api_key=embedding_api_key,
endpoint=embedding_endpoint,
actor=actor,
)
embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)
# insert to turbopuffer
tpuf_client = TurbopufferClient()
await tpuf_client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=actor.organization_id,
roles=roles,
created_ats=created_ats,
)
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
except Exception as e:
# log error but don't fail the message creation
logger.error(f"Failed to embed messages in Turbopuffer: {e}")
return result
@enforce_types
@@ -672,3 +763,116 @@ class MessageManager:
# return the number of rows deleted
return result.rowcount
@enforce_types
@trace_method
async def search_messages_async(
    self,
    agent_id: str,
    actor: PydanticUser,
    query_text: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    search_mode: str = "hybrid",
    roles: Optional[List[MessageRole]] = None,
    limit: int = 50,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    embedding_config: Optional[dict] = None,
) -> List[PydanticMessage]:
    """
    Search messages using Turbopuffer if enabled, otherwise fall back to SQL search.

    Args:
        agent_id: ID of the agent whose messages to search
        actor: User performing the search
        query_text: Text query for full-text search
        query_embedding: Optional pre-computed embedding for vector search
        search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "hybrid")
        roles: Optional list of message roles to filter by
        limit: Maximum number of results to return
        start_date: Optional filter for messages created after this date
        end_date: Optional filter for messages created before this date
        embedding_config: Optional embedding configuration for generating query embedding

    Returns:
        List of matching messages
    """
    from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages

    async def _sql_fallback() -> List[PydanticMessage]:
        # Single definition of the SQL-based search used by every
        # non-Turbopuffer path (previously duplicated three times).
        return await self.list_messages_for_agent_async(
            agent_id=agent_id,
            actor=actor,
            query_text=query_text,
            roles=roles,
            limit=limit,
            ascending=False,
        )

    # check if we should use turbopuffer
    if not should_use_tpuf_for_messages():
        return await _sql_fallback()

    try:
        # generate embedding if needed and not provided
        if search_mode in ("vector", "hybrid") and query_embedding is None and query_text:
            if not embedding_config:
                # no embedding config -> vector search impossible; use SQL instead
                logger.warning("No embedding config provided for vector search, falling back to SQL")
                return await _sql_fallback()
            # generate embedding from query text
            from letta.llm_api.llm_client import LLMClient

            embedding_client = LLMClient(
                llm_provider_type=embedding_config.get("provider", "openai"),
                api_key=embedding_config.get("api_key"),
                endpoint=embedding_config.get("endpoint", "https://api.openai.com/v1"),
                actor=actor,
            )
            embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
            query_embedding = embeddings[0]

        # use turbopuffer for search
        tpuf_client = TurbopufferClient()
        results = await tpuf_client.query_messages(
            agent_id=agent_id,
            query_embedding=query_embedding,
            query_text=query_text,
            search_mode=search_mode,
            top_k=limit,
            roles=roles,
            start_date=start_date,
            end_date=end_date,
        )

        # fetch full message objects from database using the IDs
        message_ids = [msg_dict["id"] for msg_dict, _ in results]
        if not message_ids:
            return []
        messages = await self.get_messages_by_ids_async(message_ids, actor)
        # maintain the order from turbopuffer results
        message_by_id = {msg.id: msg for msg in messages}
        return [message_by_id[msg_id] for msg_id in message_ids if msg_id in message_by_id]
    except Exception as e:
        # any Turbopuffer failure degrades gracefully to the SQL backend
        logger.error(f"Failed to search messages with Turbopuffer, falling back to SQL: {e}")
        return await _sql_fallback()

View File

@@ -301,6 +301,7 @@ class Settings(BaseSettings):
use_tpuf: bool = False
tpuf_api_key: Optional[str] = None
tpuf_region: str = "gcp-us-central1"
embed_all_messages: bool = False
# File processing timeout settings
file_processing_timeout_minutes: int = 30

View File

@@ -4,9 +4,11 @@ from datetime import datetime, timezone
import pytest
from letta.config import LettaConfig
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf, should_use_tpuf_for_messages
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import TagMatchMode, VectorDBProvider
from letta.schemas.enums import MessageRole, TagMatchMode, VectorDBProvider
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.passage import Passage
from letta.server.server import SyncServer
from letta.settings import settings
@@ -69,6 +71,48 @@ def enable_turbopuffer():
settings.environment = original_environment
@pytest.fixture
def enable_message_embedding():
    """Enable both Turbopuffer and message embedding"""
    # snapshot every setting this fixture mutates so it can be restored
    saved = (
        settings.use_tpuf,
        settings.tpuf_api_key,
        settings.embed_all_messages,
        settings.environment,
    )
    settings.use_tpuf = True
    settings.tpuf_api_key = settings.tpuf_api_key or "test-key"
    settings.embed_all_messages = True
    settings.environment = "DEV"
    yield
    # restore the snapshot after the test completes
    (
        settings.use_tpuf,
        settings.tpuf_api_key,
        settings.embed_all_messages,
        settings.environment,
    ) = saved
@pytest.fixture
def disable_turbopuffer():
    """Ensure Turbopuffer is disabled for testing"""
    # remember prior state, force both flags off for the duration of the test
    saved = (settings.use_tpuf, settings.embed_all_messages)
    settings.use_tpuf = False
    settings.embed_all_messages = False
    yield
    # put the settings back exactly as they were
    (settings.use_tpuf, settings.embed_all_messages) = saved
@pytest.fixture
def sample_embedding_config():
    """Provide a sample embedding configuration"""
    # default "letta" embedding config is sufficient for these tests
    config = EmbeddingConfig.default_config(model_name="letta")
    return config
class TestTurbopufferIntegration:
"""Test Turbopuffer integration functionality with real connections"""
@@ -374,14 +418,6 @@ class TestTurbopufferIntegration:
# all results should have scores
assert all(isinstance(score, float) for _, score in vector_heavy_results)
# Test error handling - missing embedding for vector mode
with pytest.raises(ValueError, match="query_embedding is required for vector search mode"):
await client.query_passages(archive_id=archive_id, search_mode="vector", top_k=3)
# Test error handling - missing text for FTS mode
with pytest.raises(ValueError, match="query_text is required for FTS search mode"):
await client.query_passages(archive_id=archive_id, search_mode="fts", top_k=3)
# Test error handling - missing text for hybrid mode (embedding provided but text missing)
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
await client.query_passages(archive_id=archive_id, query_embedding=[1.0, 2.0, 3.0], search_mode="hybrid", top_k=3)
@@ -751,3 +787,645 @@ class TestTurbopufferParametrized:
# Clean up
await server.passage_manager.delete_agent_passages_async(recent_passage, default_user)
await server.passage_manager.delete_agent_passages_async(old_passage, default_user)
class TestTurbopufferMessagesIntegration:
"""Test Turbopuffer message embedding functionality"""
def test_should_use_tpuf_for_messages_settings(self):
    """Test that should_use_tpuf_for_messages correctly checks both use_tpuf AND embed_all_messages"""
    # snapshot the settings this test mutates
    saved = (settings.use_tpuf, settings.tpuf_api_key, settings.embed_all_messages)
    # (use_tpuf, tpuf_api_key, embed_all_messages) -> expected result
    cases = [
        (True, "test-key", True, True),    # both enabled with key
        (False, "test-key", True, False),  # use_tpuf off
        (True, "test-key", False, False),  # embed_all_messages off
        (False, "test-key", False, False), # both off
        (True, None, True, False),         # API key missing
    ]
    try:
        for use_tpuf, api_key, embed_all, expected in cases:
            settings.use_tpuf = use_tpuf
            settings.tpuf_api_key = api_key
            settings.embed_all_messages = embed_all
            assert should_use_tpuf_for_messages() is expected
    finally:
        # restore the snapshot regardless of assertion outcome
        (settings.use_tpuf, settings.tpuf_api_key, settings.embed_all_messages) = saved
def test_message_text_extraction(self, server, default_user):
    """Test extraction of text from various message content structures"""
    manager = server.message_manager

    def extract(role, content):
        # build a throwaway message and run it through the extractor
        message = PydanticMessage(role=role, content=content, agent_id="test-agent")
        return manager._extract_message_text(message)

    # single TextContent entries come back verbatim
    assert extract(MessageRole.user, [TextContent(text="Simple text content")]) == "Simple text content"
    assert extract(MessageRole.user, [TextContent(text="Single text content")]) == "Single text content"

    # multiple TextContent items are joined with single spaces
    parts = [
        TextContent(text="First part"),
        TextContent(text="Second part"),
        TextContent(text="Third part"),
    ]
    assert extract(MessageRole.user, parts) == "First part Second part Third part"

    # None content and an empty list both yield the empty string
    assert extract(MessageRole.system, None) == ""
    assert extract(MessageRole.assistant, []) == ""
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_embedding_without_config(self, server, default_user, sarah_agent, enable_message_embedding):
    """Test that messages are NOT embedded without embedding_config even when tpuf is enabled"""
    manager = server.message_manager

    # build two messages that will be created WITHOUT an embedding_config
    pending = [
        PydanticMessage(
            role=role,
            content=[TextContent(text=text)],
            agent_id=sarah_agent.id,
        )
        for role, text in [
            (MessageRole.user, "Test message without embedding config"),
            (MessageRole.assistant, "Response without embedding config"),
        ]
    ]
    created = await manager.create_many_messages_async(
        pydantic_msgs=pending,
        actor=default_user,
        embedding_config=None,  # No config provided
    )
    assert len(created) == 2
    assert all(msg.agent_id == sarah_agent.id for msg in created)

    # the messages must still land in SQL even though embedding was skipped
    sql_messages = await manager.list_messages_for_agent_async(
        agent_id=sarah_agent.id,
        actor=default_user,
        limit=10,
    )
    assert len(sql_messages) >= 2

    # Clean up
    await manager.delete_messages_by_ids_async([msg.id for msg in created], default_user)
@pytest.mark.asyncio
async def test_generic_reciprocal_rank_fusion(self):
    """Test the generic RRF function with different object types"""
    from letta.helpers.tpuf_client import TurbopufferClient

    client = TurbopufferClient()

    def make_passage(passage_id, text):
        # minimal passage fixture for RRF input
        return Passage(
            id=passage_id,
            text=text,
            organization_id="org1",
            archive_id="archive1",
            created_at=datetime.now(timezone.utc),
            metadata_={},
            tags=[],
            embedding=[],
            embedding_config=None,
        )

    p1_id = "passage-78d49031-8502-49c1-a970-45663e9f6e07"
    p2_id = "passage-90df8386-4caf-49cc-acbc-d71526de6f77"
    first_passage = make_passage(p1_id, "First passage")
    second_passage = make_passage(p2_id, "Second passage")

    # passages go through the backward-compatible wrapper
    combined = client._reciprocal_rank_fusion(
        vector_results=[(first_passage, 0.9), (second_passage, 0.7)],
        fts_results=[(second_passage, 0.8), (first_passage, 0.6)],
        vector_weight=0.5,
        fts_weight=0.5,
        top_k=2,
    )
    assert len(combined) == 2
    fused_ids = [passage.id for passage, _ in combined]
    assert p1_id in fused_ids
    assert p2_id in fused_ids

    # message dicts exercise the generic function directly
    msg_a = {"id": "m1", "text": "First message"}
    msg_b = {"id": "m2", "text": "Second message"}
    msg_c = {"id": "m3", "text": "Third message"}
    combined_msgs = client._generic_reciprocal_rank_fusion(
        vector_results=[(msg_a, 0.95), (msg_b, 0.85), (msg_c, 0.75)],
        fts_results=[(msg_b, 0.90), (msg_c, 0.80), (msg_a, 0.70)],
        get_id_func=lambda item: item["id"],
        vector_weight=0.6,
        fts_weight=0.4,
        top_k=3,
    )
    assert len(combined_msgs) == 3
    fused_msg_ids = [item["id"] for item, _ in combined_msgs]
    for expected_id in ("m1", "m2", "m3"):
        assert expected_id in fused_msg_ids

    # edge case: both result lists empty -> empty fusion
    empty_combined = client._generic_reciprocal_rank_fusion(
        vector_results=[],
        fts_results=[],
        get_id_func=lambda item: item["id"],
        vector_weight=0.5,
        fts_weight=0.5,
        top_k=10,
    )
    assert len(empty_combined) == 0

    # edge case: a single vector result survives fusion unchanged
    solo = client._generic_reciprocal_rank_fusion(
        vector_results=[(msg_a, 0.9)],
        fts_results=[],
        get_id_func=lambda item: item["id"],
        vector_weight=0.5,
        fts_weight=0.5,
        top_k=10,
    )
    assert len(solo) == 1
    assert solo[0][0]["id"] == "m1"
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding):
    """Test actual message embedding and storage in Turbopuffer"""
    import uuid
    from datetime import datetime, timezone

    from letta.helpers.tpuf_client import TurbopufferClient
    from letta.schemas.enums import MessageRole

    client = TurbopufferClient()
    agent_id = f"test-agent-{uuid.uuid4()}"
    org_id = str(uuid.uuid4())
    try:
        # Prepare test messages
        message_texts = [
            "Hello, how can I help you today?",
            "I need help with Python programming.",
            "Sure, what specific Python topic?",
        ]
        message_ids = [str(uuid.uuid4()) for _ in message_texts]
        roles = [MessageRole.assistant, MessageRole.user, MessageRole.assistant]
        created_ats = [datetime.now(timezone.utc) for _ in message_texts]
        # Generate embeddings (dummy values are enough to exercise the insert path)
        embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]
        # Insert messages into Turbopuffer
        success = await client.insert_messages(
            agent_id=agent_id,
            message_texts=message_texts,
            embeddings=embeddings,
            message_ids=message_ids,
            organization_id=org_id,
            roles=roles,
            created_ats=created_ats,
        )
        # identity comparison, not `== True` (ruff E712)
        assert success is True
        # Verify we can query the messages back
        results = await client.query_messages(
            agent_id=agent_id,
            search_mode="timestamp",
            top_k=10,
        )
        assert len(results) == 3
        # Results should be ordered by timestamp (most recent first)
        for msg_dict, score in results:
            assert msg_dict["agent_id"] == agent_id
            assert msg_dict["organization_id"] == org_id
            assert msg_dict["text"] in message_texts
            assert msg_dict["role"] in ["assistant", "user"]
    finally:
        # Best-effort namespace cleanup; `except Exception` (not a bare except)
        # so KeyboardInterrupt/SystemExit are not swallowed
        try:
            await client.delete_all_messages(agent_id)
        except Exception:
            pass
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding):
    """Test vector search on messages in Turbopuffer"""
    import uuid
    from datetime import datetime, timezone

    from letta.helpers.tpuf_client import TurbopufferClient
    from letta.schemas.enums import MessageRole

    client = TurbopufferClient()
    agent_id = f"test-agent-{uuid.uuid4()}"
    org_id = str(uuid.uuid4())
    try:
        # Insert messages with different embeddings
        message_texts = [
            "Python is a great programming language",
            "JavaScript is used for web development",
            "Machine learning with Python is powerful",
        ]
        message_ids = [str(uuid.uuid4()) for _ in message_texts]
        roles = [MessageRole.assistant] * len(message_texts)
        created_ats = [datetime.now(timezone.utc) for _ in message_texts]
        # Create embeddings that reflect content similarity
        embeddings = [
            [1.0, 0.0, 0.0],  # Python programming
            [0.0, 1.0, 0.0],  # JavaScript web
            [0.8, 0.0, 0.2],  # ML with Python (similar to first)
        ]
        # Insert messages
        await client.insert_messages(
            agent_id=agent_id,
            message_texts=message_texts,
            embeddings=embeddings,
            message_ids=message_ids,
            organization_id=org_id,
            roles=roles,
            created_ats=created_ats,
        )
        # Search for Python-related messages using vector search
        query_embedding = [0.9, 0.0, 0.1]  # Similar to Python messages
        results = await client.query_messages(
            agent_id=agent_id,
            query_embedding=query_embedding,
            search_mode="vector",
            top_k=2,
        )
        assert len(results) == 2
        # Should return Python-related messages first
        result_texts = [msg["text"] for msg, _ in results]
        assert "Python is a great programming language" in result_texts
        assert "Machine learning with Python is powerful" in result_texts
    finally:
        # Best-effort namespace cleanup; narrowed from a bare `except:` so
        # KeyboardInterrupt/SystemExit still propagate
        try:
            await client.delete_all_messages(agent_id)
        except Exception:
            pass
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding):
    """Test hybrid search combining vector and FTS for messages"""
    import uuid
    from datetime import datetime, timezone

    from letta.helpers.tpuf_client import TurbopufferClient
    from letta.schemas.enums import MessageRole

    client = TurbopufferClient()
    agent_id = f"test-agent-{uuid.uuid4()}"
    org_id = str(uuid.uuid4())
    try:
        # Insert diverse messages
        message_texts = [
            "The quick brown fox jumps over the lazy dog",
            "Machine learning algorithms are fascinating",
            "Quick tutorial on Python programming",
            "Deep learning with neural networks",
        ]
        message_ids = [str(uuid.uuid4()) for _ in message_texts]
        roles = [MessageRole.assistant] * len(message_texts)
        created_ats = [datetime.now(timezone.utc) for _ in message_texts]
        # Embeddings
        embeddings = [
            [0.1, 0.9, 0.0],  # fox text
            [0.9, 0.1, 0.0],  # ML algorithms
            [0.5, 0.5, 0.0],  # Quick Python
            [0.8, 0.2, 0.0],  # Deep learning
        ]
        # Insert messages
        await client.insert_messages(
            agent_id=agent_id,
            message_texts=message_texts,
            embeddings=embeddings,
            message_ids=message_ids,
            organization_id=org_id,
            roles=roles,
            created_ats=created_ats,
        )
        # Hybrid search - vector similar to ML but text contains "quick"
        results = await client.query_messages(
            agent_id=agent_id,
            query_embedding=[0.7, 0.3, 0.0],  # Similar to ML messages
            query_text="quick",  # Text search for "quick"
            search_mode="hybrid",
            top_k=3,
            vector_weight=0.5,
            fts_weight=0.5,
        )
        assert len(results) > 0
        # Should get a mix of results based on both vector and text similarity
        result_texts = [msg["text"] for msg, _ in results]
        # At least one result should contain "quick" due to FTS
        assert any("quick" in text.lower() for text in result_texts)
    finally:
        # Best-effort namespace cleanup; narrowed from a bare `except:` so
        # KeyboardInterrupt/SystemExit still propagate
        try:
            await client.delete_all_messages(agent_id)
        except Exception:
            pass
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding):
    """Test filtering messages by role"""
    import uuid
    from datetime import datetime, timezone

    from letta.helpers.tpuf_client import TurbopufferClient
    from letta.schemas.enums import MessageRole

    client = TurbopufferClient()
    agent_id = f"test-agent-{uuid.uuid4()}"
    org_id = str(uuid.uuid4())
    try:
        # Insert messages with different roles
        message_data = [
            ("Hello! How can I help?", MessageRole.assistant),
            ("I need help with Python", MessageRole.user),
            ("Here's a Python example", MessageRole.assistant),
            ("Can you explain this?", MessageRole.user),
            ("System message here", MessageRole.system),
        ]
        message_texts = [text for text, _ in message_data]
        roles = [role for _, role in message_data]
        message_ids = [str(uuid.uuid4()) for _ in message_texts]
        created_ats = [datetime.now(timezone.utc) for _ in message_texts]
        embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]
        # Insert messages
        await client.insert_messages(
            agent_id=agent_id,
            message_texts=message_texts,
            embeddings=embeddings,
            message_ids=message_ids,
            organization_id=org_id,
            roles=roles,
            created_ats=created_ats,
        )
        # Query only user messages
        user_results = await client.query_messages(
            agent_id=agent_id,
            search_mode="timestamp",
            top_k=10,
            roles=[MessageRole.user],
        )
        assert len(user_results) == 2
        for msg, _ in user_results:
            assert msg["role"] == "user"
            assert msg["text"] in ["I need help with Python", "Can you explain this?"]
        # Query assistant and system messages
        non_user_results = await client.query_messages(
            agent_id=agent_id,
            search_mode="timestamp",
            top_k=10,
            roles=[MessageRole.assistant, MessageRole.system],
        )
        assert len(non_user_results) == 3
        for msg, _ in non_user_results:
            assert msg["role"] in ["assistant", "system"]
    finally:
        # Best-effort namespace cleanup; narrowed from a bare `except:` so
        # KeyboardInterrupt/SystemExit still propagate
        try:
            await client.delete_all_messages(agent_id)
        except Exception:
            pass
@pytest.mark.asyncio
async def test_message_search_fallback_to_sql(self, server, default_user, sarah_agent):
    """Test that message search falls back to SQL when Turbopuffer is disabled"""
    # Save original settings so global state is restored even on failure
    original_use_tpuf = settings.use_tpuf
    original_embed_messages = settings.embed_all_messages
    try:
        # Disable Turbopuffer for messages
        settings.use_tpuf = False
        settings.embed_all_messages = False
        # Create a message the SQL backend should find
        await server.message_manager.create_many_messages_async(
            pydantic_msgs=[
                PydanticMessage(
                    role=MessageRole.user,
                    content=[TextContent(text="Test message for SQL fallback")],
                    agent_id=sarah_agent.id,
                )
            ],
            actor=default_user,
        )
        # Search should use SQL backend (not Turbopuffer)
        results = await server.message_manager.search_messages_async(
            actor=default_user,
            agent_id=sarah_agent.id,
            query_text="fallback",
            limit=10,
        )
        # Should return results from SQL search
        assert len(results) > 0
        # At least one returned message must actually contain the query term
        # (replaces the old for/else ending in `assert False`, which is both
        # non-idiomatic and stripped under `python -O`)
        extracted_texts = [server.message_manager._extract_message_text(msg) for msg in results]
        assert any("fallback" in text.lower() for text in extracted_texts), "No messages containing 'fallback' found"
    finally:
        # Restore settings
        settings.use_tpuf = original_use_tpuf
        settings.embed_all_messages = original_embed_messages
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):
    """Test filtering messages by date range"""
    import uuid
    from datetime import datetime, timedelta, timezone

    from letta.helpers.tpuf_client import TurbopufferClient
    from letta.schemas.enums import MessageRole

    client = TurbopufferClient()
    agent_id = f"test-agent-{uuid.uuid4()}"
    org_id = str(uuid.uuid4())
    try:
        # Create messages with different timestamps
        now = datetime.now(timezone.utc)
        yesterday = now - timedelta(days=1)
        last_week = now - timedelta(days=7)
        last_month = now - timedelta(days=30)
        message_data = [
            ("Today's message", now),
            ("Yesterday's message", yesterday),
            ("Last week's message", last_week),
            ("Last month's message", last_month),
        ]
        for text, timestamp in message_data:
            await client.insert_messages(
                agent_id=agent_id,
                message_texts=[text],
                embeddings=[[1.0, 2.0, 3.0]],
                message_ids=[str(uuid.uuid4())],
                organization_id=org_id,
                roles=[MessageRole.assistant],
                created_ats=[timestamp],
            )
        # Query messages from the last 3 days
        three_days_ago = now - timedelta(days=3)
        recent_results = await client.query_messages(
            agent_id=agent_id,
            search_mode="timestamp",
            top_k=10,
            start_date=three_days_ago,
        )
        # Should get today's and yesterday's messages
        assert len(recent_results) == 2
        result_texts = [msg["text"] for msg, _ in recent_results]
        assert "Today's message" in result_texts
        assert "Yesterday's message" in result_texts
        # Query messages between 2 weeks ago and 1 week ago
        two_weeks_ago = now - timedelta(days=14)
        week_results = await client.query_messages(
            agent_id=agent_id,
            search_mode="timestamp",
            top_k=10,
            start_date=two_weeks_ago,
            end_date=last_week + timedelta(days=1),  # Include last week's message
        )
        # Should get only last week's message
        assert len(week_results) == 1
        assert week_results[0][0]["text"] == "Last week's message"
        # Query with vector search and date filtering
        filtered_vector_results = await client.query_messages(
            agent_id=agent_id,
            query_embedding=[1.0, 2.0, 3.0],
            search_mode="vector",
            top_k=10,
            start_date=three_days_ago,
        )
        # Should get only recent messages
        assert len(filtered_vector_results) == 2
        for msg, _ in filtered_vector_results:
            assert msg["text"] in ["Today's message", "Yesterday's message"]
    finally:
        # Best-effort namespace cleanup; narrowed from a bare `except:` so
        # KeyboardInterrupt/SystemExit still propagate
        try:
            await client.delete_all_messages(agent_id)
        except Exception:
            pass