feat: Add search messages endpoint [LET-4144] (#4434)

* Add search messages endpoint

* Run fern autogen and fix tests
This commit is contained in:
Matthew Zhou
2025-09-05 14:28:27 -07:00
committed by GitHub
parent 80adb82b34
commit 2e3cabc080
13 changed files with 462 additions and 312 deletions

View File

@@ -178,7 +178,6 @@ class BaseAgent(ABC):
curr_system_message.id,
message_update=MessageUpdate(content=new_system_message_str),
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
return [new_system_message] + in_context_messages[1:]

View File

@@ -117,7 +117,6 @@ async def _prepare_in_context_messages_async(
new_in_context_messages = await message_manager.create_many_messages_async(
create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
actor=actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)

View File

@@ -497,7 +497,6 @@ class LettaAgent(BaseAgent):
await self.message_manager.create_many_messages_async(
initial_messages,
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
@@ -828,7 +827,6 @@ class LettaAgent(BaseAgent):
await self.message_manager.create_many_messages_async(
initial_messages,
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
@@ -1267,7 +1265,6 @@ class LettaAgent(BaseAgent):
await self.message_manager.create_many_messages_async(
initial_messages,
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
@@ -1676,7 +1673,7 @@ class LettaAgent(BaseAgent):
)
messages_to_persist = (initial_messages or []) + tool_call_messages
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id
messages_to_persist, actor=self.actor, project_id=agent_state.project_id
)
return persisted_messages, continue_stepping, stop_reason
@@ -1788,7 +1785,7 @@ class LettaAgent(BaseAgent):
messages_to_persist = (initial_messages or []) + tool_call_messages
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id
messages_to_persist, actor=self.actor, project_id=agent_state.project_id
)
if run_id:

View File

@@ -4,26 +4,37 @@ import logging
from datetime import datetime, timezone
from typing import Any, Callable, List, Optional, Tuple
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, TagMatchMode
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings
from letta.settings import model_settings, settings
logger = logging.getLogger(__name__)
def should_use_tpuf() -> bool:
    """Return whether Turbopuffer is usable: feature flag on AND an API key configured."""
    # We need OpenAI since we default to their embedding model
    tpuf_enabled = bool(settings.use_tpuf)
    has_api_key = bool(settings.tpuf_api_key)
    return tpuf_enabled and has_api_key
def should_use_tpuf_for_messages() -> bool:
"""Check if Turbopuffer should be used for messages."""
return should_use_tpuf() and bool(settings.embed_all_messages)
return should_use_tpuf() and bool(settings.embed_all_messages) and bool(model_settings.openai_api_key)
class TurbopufferClient:
"""Client for managing archival memory with Turbopuffer vector database."""
default_embedding_config = EmbeddingConfig(
embedding_model="text-embedding-3-small",
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
)
def __init__(self, api_key: str = None, region: str = None):
"""Initialize Turbopuffer client."""
self.api_key = api_key or settings.tpuf_api_key
@@ -38,6 +49,26 @@ class TurbopufferClient:
if not self.api_key:
raise ValueError("Turbopuffer API key not provided")
    @trace_method
    async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]:
        """Generate embeddings using the default embedding configuration.

        Args:
            texts: List of texts to embed
            actor: User actor for embedding generation

        Returns:
            List of embedding vectors
        """
        # function-scope import — presumably avoids a circular import at module load; TODO confirm
        from letta.llm_api.llm_client import LLMClient

        # client provider is driven by the class-level default config (OpenAI
        # text-embedding-3-small), not by any per-agent embedding config
        embedding_client = LLMClient.create(
            provider_type=self.default_embedding_config.embedding_endpoint_type,
            actor=actor,
        )
        embeddings = await embedding_client.request_embeddings(texts, self.default_embedding_config)
        return embeddings
@trace_method
async def _get_archive_namespace_name(self, archive_id: str) -> str:
"""Get namespace name for a specific archive."""
@@ -61,9 +92,9 @@ class TurbopufferClient:
self,
archive_id: str,
text_chunks: List[str],
embeddings: List[List[float]],
passage_ids: List[str],
organization_id: str,
actor: "PydanticUser",
tags: Optional[List[str]] = None,
created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
@@ -72,9 +103,9 @@ class TurbopufferClient:
Args:
archive_id: ID of the archive
text_chunks: List of text chunks to store
embeddings: List of embedding vectors corresponding to text chunks
passage_ids: List of passage IDs (must match 1:1 with text_chunks)
organization_id: Organization ID for the passages
actor: User actor for embedding generation
tags: Optional list of tags to attach to all passages
created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
@@ -83,6 +114,9 @@ class TurbopufferClient:
"""
from turbopuffer import AsyncTurbopuffer
# generate embeddings using the default config
embeddings = await self._generate_embeddings(text_chunks, actor)
namespace_name = await self._get_archive_namespace_name(archive_id)
# handle timestamp - ensure UTC
@@ -102,8 +136,6 @@ class TurbopufferClient:
raise ValueError("passage_ids must be provided for Turbopuffer insertion")
if len(passage_ids) != len(text_chunks):
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
if len(passage_ids) != len(embeddings):
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})")
# prepare column-based data for turbopuffer - optimized for batch insert
ids = []
@@ -137,7 +169,7 @@ class TurbopufferClient:
metadata_={},
tags=tags or [], # Include tags in the passage
embedding=embedding,
embedding_config=None, # Will be set by caller if needed
embedding_config=self.default_embedding_config, # Will be set by caller if needed
)
passages.append(passage)
@@ -177,9 +209,9 @@ class TurbopufferClient:
self,
agent_id: str,
message_texts: List[str],
embeddings: List[List[float]],
message_ids: List[str],
organization_id: str,
actor: "PydanticUser",
roles: List[MessageRole],
created_ats: List[datetime],
project_id: Optional[str] = None,
@@ -189,9 +221,9 @@ class TurbopufferClient:
Args:
agent_id: ID of the agent
message_texts: List of message text content to store
embeddings: List of embedding vectors corresponding to message texts
message_ids: List of message IDs (must match 1:1 with message_texts)
organization_id: Organization ID for the messages
actor: User actor for embedding generation
roles: List of message roles corresponding to each message
created_ats: List of creation timestamps for each message
project_id: Optional project ID for all messages
@@ -201,6 +233,9 @@ class TurbopufferClient:
"""
from turbopuffer import AsyncTurbopuffer
# generate embeddings using the default config
embeddings = await self._generate_embeddings(message_texts, actor)
namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
# validation checks
@@ -208,8 +243,6 @@ class TurbopufferClient:
raise ValueError("message_ids must be provided for Turbopuffer insertion")
if len(message_ids) != len(message_texts):
raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
if len(message_ids) != len(embeddings):
raise ValueError(f"message_ids length ({len(message_ids)}) must match embeddings length ({len(embeddings)})")
if len(message_ids) != len(roles):
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
if len(message_ids) != len(created_ats):
@@ -390,7 +423,7 @@ class TurbopufferClient:
async def query_passages(
self,
archive_id: str,
query_embedding: Optional[List[float]] = None,
actor: "PydanticUser",
query_text: Optional[str] = None,
search_mode: str = "vector", # "vector", "fts", "hybrid"
top_k: int = 10,
@@ -405,8 +438,8 @@ class TurbopufferClient:
Args:
archive_id: ID of the archive
query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
actor: User actor for embedding generation
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
top_k: Number of results to return
tags: Optional list of tags to filter by
@@ -419,6 +452,12 @@ class TurbopufferClient:
Returns:
List of (passage, score, metadata) tuples with relevance rankings
"""
# generate embedding for vector/hybrid search if query_text is provided
query_embedding = None
if query_text and search_mode in ["vector", "hybrid"]:
embeddings = await self._generate_embeddings([query_text], actor)
query_embedding = embeddings[0]
# Check if we should fallback to timestamp-based retrieval
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
# Fallback to retrieving most recent passages when no search query is provided
@@ -519,11 +558,11 @@ class TurbopufferClient:
raise
@trace_method
async def query_messages(
async def query_messages_by_agent_id(
self,
agent_id: str,
organization_id: str,
query_embedding: Optional[List[float]] = None,
actor: "PydanticUser",
query_text: Optional[str] = None,
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
top_k: int = 10,
@@ -539,8 +578,8 @@ class TurbopufferClient:
Args:
agent_id: ID of the agent (used for filtering results)
organization_id: Organization ID for namespace lookup
query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
actor: User actor for embedding generation
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector")
top_k: Number of results to return
roles: Optional list of message roles to filter by
@@ -556,6 +595,12 @@ class TurbopufferClient:
- score is the final relevance score
- metadata contains individual scores and ranking information
"""
# generate embedding for vector/hybrid search if query_text is provided
query_embedding = None
if query_text and search_mode in ["vector", "hybrid"]:
embeddings = await self._generate_embeddings([query_text], actor)
query_embedding = embeddings[0]
# Check if we should fallback to timestamp-based retrieval
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
# Fallback to retrieving most recent messages when no search query is provided
@@ -658,6 +703,159 @@ class TurbopufferClient:
logger.error(f"Failed to query messages from Turbopuffer: {e}")
raise
    async def query_messages_by_org_id(
        self,
        organization_id: str,
        actor: "PydanticUser",
        query_text: Optional[str] = None,
        search_mode: str = "hybrid",  # "vector", "fts", "hybrid"
        top_k: int = 10,
        roles: Optional[List[MessageRole]] = None,
        project_id: Optional[str] = None,
        vector_weight: float = 0.5,
        fts_weight: float = 0.5,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> List[Tuple[dict, float, dict]]:
        """Query messages from Turbopuffer across an entire organization.

        Args:
            organization_id: Organization ID for namespace lookup (required)
            actor: User actor for embedding generation
            query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
            search_mode: Search mode - "vector", "fts", or "hybrid" (default: "hybrid")
            top_k: Number of results to return
            roles: Optional list of message roles to filter by
            project_id: Optional project ID to filter messages by
            vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
            fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
            start_date: Optional datetime to filter messages created after this date
            end_date: Optional datetime to filter messages created on or before this date (inclusive)

        Returns:
            List of (message_dict, score, metadata) tuples where:
            - message_dict contains id, text, role, created_at, agent_id
            - score is the final relevance score (RRF score for hybrid, rank-based for single mode)
            - metadata contains individual scores and ranking information
        """
        # generate embedding for vector/hybrid search if query_text is provided
        query_embedding = None
        if query_text and search_mode in ["vector", "hybrid"]:
            embeddings = await self._generate_embeddings([query_text], actor)
            query_embedding = embeddings[0]

        # namespace is org-scoped (unlike per-agent queries, no agent filter is applied here)
        namespace_name = f"letta_messages_{organization_id}"

        # build filters
        all_filters = []

        # role filter
        if roles:
            role_values = [r.value for r in roles]
            if len(role_values) == 1:
                all_filters.append(("role", "Eq", role_values[0]))
            else:
                all_filters.append(("role", "In", role_values))

        # project filter
        if project_id:
            all_filters.append(("project_id", "Eq", project_id))

        # date filters
        if start_date:
            all_filters.append(("created_at", "Gte", start_date))
        if end_date:
            # make end_date inclusive of the entire day: a midnight timestamp is
            # widened to 23:59:59.999999 of the same day before the Lte compare
            if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
                from datetime import timedelta

                end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
            all_filters.append(("created_at", "Lte", end_date))

        # combine filters (single filter is passed bare; multiple are And-ed)
        final_filter = None
        if len(all_filters) == 1:
            final_filter = all_filters[0]
        elif len(all_filters) > 1:
            final_filter = ("And", all_filters)

        try:
            # execute query
            result = await self._execute_query(
                namespace_name=namespace_name,
                search_mode=search_mode,
                query_embedding=query_embedding,
                query_text=query_text,
                top_k=top_k,
                include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
                filters=final_filter,
                vector_weight=vector_weight,
                fts_weight=fts_weight,
            )

            # process results based on search mode
            if search_mode == "hybrid":
                # for hybrid mode, we get a multi-query response:
                # results[0] is assumed to be the vector leg, results[1] the FTS leg — TODO confirm ordering in _execute_query
                vector_results = self._process_message_query_results(result.results[0])
                fts_results = self._process_message_query_results(result.results[1])

                # use existing RRF method - it already returns metadata with ranks
                results_with_metadata = self._reciprocal_rank_fusion(
                    vector_results=vector_results,
                    fts_results=fts_results,
                    get_id_func=lambda msg_dict: msg_dict["id"],
                    vector_weight=vector_weight,
                    fts_weight=fts_weight,
                    top_k=top_k,
                )

                # add raw scores to metadata if available
                # (hasattr guards because the row objects only carry dist/score for the matching leg)
                vector_scores = {}
                for row in result.results[0].rows:
                    if hasattr(row, "dist"):
                        vector_scores[row.id] = row.dist

                fts_scores = {}
                for row in result.results[1].rows:
                    if hasattr(row, "score"):
                        fts_scores[row.id] = row.score

                # enhance metadata with raw scores
                enhanced_results = []
                for msg_dict, rrf_score, metadata in results_with_metadata:
                    msg_id = msg_dict["id"]
                    if msg_id in vector_scores:
                        metadata["vector_score"] = vector_scores[msg_id]
                    if msg_id in fts_scores:
                        metadata["fts_score"] = fts_scores[msg_id]
                    enhanced_results.append((msg_dict, rrf_score, metadata))

                return enhanced_results
            else:
                # for single queries (vector or fts): score is a simple reciprocal of rank
                results = self._process_message_query_results(result)
                results_with_metadata = []
                for idx, msg_dict in enumerate(results):
                    metadata = {
                        "combined_score": 1.0 / (idx + 1),
                        "search_mode": search_mode,
                        f"{search_mode}_rank": idx + 1,
                    }
                    # add raw score if available
                    if hasattr(result.rows[idx], "dist"):
                        metadata["vector_score"] = result.rows[idx].dist
                    elif hasattr(result.rows[idx], "score"):
                        metadata["fts_score"] = result.rows[idx].score
                    results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))

                return results_with_metadata

        except Exception as e:
            # log and re-raise so callers can map to an HTTP error
            logger.error(f"Failed to query messages from Turbopuffer: {e}")
            raise
def _process_message_query_results(self, result) -> List[dict]:
"""Process results from a message query into message dicts.
@@ -703,7 +901,7 @@ class TurbopufferClient:
tags=passage_tags, # Set the actual tags from the passage
# Set required fields to empty/default values since we don't store embeddings
embedding=[], # Empty embedding since we don't return it from Turbopuffer
embedding_config=None, # No embedding config needed for retrieved passages
embedding_config=self.default_embedding_config, # No embedding config needed for retrieved passages
)
# handle score based on search type

View File

@@ -1187,3 +1187,26 @@ class ToolReturn(BaseModel):
stdout: Optional[List[str]] = Field(default=None, description="Captured stdout (e.g. prints, logs) from the tool invocation")
stderr: Optional[List[str]] = Field(default=None, description="Captured stderr from the tool invocation")
# func_return: Optional[Any] = Field(None, description="The function return object")
class MessageSearchRequest(BaseModel):
    """Request model for searching messages across the organization.

    NOTE(review): query_text is used both for FTS and for embedding generation
    in vector/hybrid modes, not only for full-text search as the field
    description suggests.
    """

    query_text: Optional[str] = Field(None, description="Text query for full-text search")
    search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use")
    roles: Optional[List[MessageRole]] = Field(None, description="Filter messages by role")
    project_id: Optional[str] = Field(None, description="Filter messages by project ID")
    # limit is clamped server-side to 1..100 by pydantic validation (ge/le)
    limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
    start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
    end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
class MessageSearchResult(BaseModel):
    """Result from a message search operation with scoring details.

    Per-mode scores/ranks are optional because a result may come from only one
    search leg (FTS-only or vector-only); rrf_score is always present.
    """

    message: Message = Field(..., description="The message content and metadata")
    fts_score: Optional[float] = Field(None, description="Full-text search (BM25) score if FTS was used")
    fts_rank: Optional[int] = Field(None, description="Full-text search rank position if FTS was used")
    vector_score: Optional[float] = Field(None, description="Vector similarity score if vector search was used")
    vector_rank: Optional[int] = Field(None, description="Vector search rank position if vector search was used")
    rrf_score: float = Field(..., description="Reciprocal Rank Fusion combined score")

View File

@@ -39,7 +39,7 @@ from letta.schemas.memory import (
CreateArchivalMemory,
Memory,
)
from letta.schemas.message import MessageCreate
from letta.schemas.message import MessageCreate, MessageSearchRequest, MessageSearchResult
from letta.schemas.passage import Passage
from letta.schemas.run import Run
from letta.schemas.source import Source
@@ -1498,6 +1498,42 @@ async def cancel_agent_run(
return results
@router.post("/messages/search", response_model=List[MessageSearchResult], operation_id="search_messages")
async def search_messages(
    request: MessageSearchRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    actor_id: str | None = Header(None, alias="user_id"),
):
    """
    Search messages across the entire organization with optional project filtering.
    Returns messages with FTS/vector ranks and total RRF score.
    Requires message embedding and Turbopuffer to be enabled.
    """
    # resolve the acting user (falls back to a default actor when no header is sent)
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

    # get embedding config from the default agent if needed
    # check if any agents exist in the org
    # NOTE(review): embeddings now come from a fixed default config, so this
    # agent-count precondition looks vestigial — confirm whether it can be dropped
    agent_count = await server.agent_manager.size_async(actor=actor)
    if agent_count == 0:
        raise HTTPException(status_code=400, detail="No agents found in organization to derive embedding configuration from")

    try:
        results = await server.message_manager.search_messages_org_async(
            actor=actor,
            query_text=request.query_text,
            search_mode=request.search_mode,
            roles=request.roles,
            project_id=request.project_id,
            limit=request.limit,
            start_date=request.start_date,
            end_date=request.end_date,
        )
        return results
    except ValueError as e:
        # ValueError signals a misconfiguration (e.g. Turbopuffer disabled) → 400
        raise HTTPException(status_code=400, detail=str(e))
async def _process_message_background(
run_id: str,
server: SyncServer,

View File

@@ -719,9 +719,7 @@ class AgentManager:
# Only create messages if we initialized with messages
if not _init_with_no_messages:
await self.message_manager.create_many_messages_async(
pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config, project_id=result.project_id
)
await self.message_manager.create_many_messages_async(pydantic_msgs=init_messages, actor=actor, project_id=result.project_id)
return result
@enforce_types
@@ -1834,7 +1832,6 @@ class AgentManager:
message_id=curr_system_message.id,
message_update=MessageUpdate(**temp_message.model_dump()),
actor=actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
else:
@@ -1889,9 +1886,7 @@ class AgentManager:
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
) -> PydanticAgentState:
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
messages = await self.message_manager.create_many_messages_async(
messages, actor=actor, embedding_config=agent.embedding_config, project_id=agent.project_id
)
messages = await self.message_manager.create_many_messages_async(messages, actor=actor, project_id=agent.project_id)
message_ids = agent.message_ids or []
message_ids += [m.id for m in messages]
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)
@@ -2692,7 +2687,6 @@ class AgentManager:
# use hybrid search to combine vector and full-text search
passages_with_scores = await tpuf_client.query_passages(
archive_id=archive_ids[0],
query_embedding=query_embedding,
query_text=query_text, # pass text for potential hybrid search
search_mode="hybrid", # use hybrid mode for better results
top_k=limit,
@@ -2700,6 +2694,7 @@ class AgentManager:
tag_match_mode=tag_match_mode or TagMatchMode.ANY,
start_date=start_date,
end_date=end_date,
actor=actor,
)
# Return full tuples with metadata

View File

@@ -678,7 +678,6 @@ class AgentSerializationManager:
created_messages = await self.message_manager.create_many_messages_async(
pydantic_msgs=messages,
actor=actor,
embedding_config=created_agent.embedding_config,
project_id=created_agent.project_id,
)
imported_count += len(created_messages)

View File

@@ -11,11 +11,10 @@ from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import LettaMessageUpdateUnion
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent
from letta.schemas.message import Message as PydanticMessage, MessageUpdate
from letta.schemas.message import Message as PydanticMessage, MessageSearchResult, MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.file_manager import FileManager
@@ -314,7 +313,6 @@ class MessageManager:
self,
pydantic_msgs: List[PydanticMessage],
actor: PydanticUser,
embedding_config: Optional[EmbeddingConfig] = None,
strict_mode: bool = False,
project_id: Optional[str] = None,
) -> List[PydanticMessage]:
@@ -324,7 +322,6 @@ class MessageManager:
Args:
pydantic_msgs: List of Pydantic message models to create
actor: User performing the action
embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer
strict_mode: If True, wait for embedding to complete; if False, run in background
project_id: Optional project ID for the messages (for Turbopuffer indexing)
@@ -365,20 +362,20 @@ class MessageManager:
result = [msg.to_pydantic() for msg in created_messages]
await session.commit()
# embed messages in turbopuffer if enabled and embedding_config provided
# embed messages in turbopuffer if enabled
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
if should_use_tpuf_for_messages() and embedding_config and result:
if should_use_tpuf_for_messages() and result:
# extract agent_id from the first message (all should have same agent_id)
agent_id = result[0].agent_id
if agent_id:
if strict_mode:
# wait for embedding to complete
await self._embed_messages_background(result, embedding_config, actor, agent_id, project_id)
await self._embed_messages_background(result, actor, agent_id, project_id)
else:
# fire and forget - run embedding in background
fire_and_forget(
self._embed_messages_background(result, embedding_config, actor, agent_id, project_id),
self._embed_messages_background(result, actor, agent_id, project_id),
task_name=f"embed_messages_for_agent_{agent_id}",
)
@@ -387,7 +384,6 @@ class MessageManager:
async def _embed_messages_background(
self,
messages: List[PydanticMessage],
embedding_config: EmbeddingConfig,
actor: PydanticUser,
agent_id: str,
project_id: Optional[str] = None,
@@ -396,14 +392,12 @@ class MessageManager:
Args:
messages: List of messages to embed
embedding_config: Embedding configuration
actor: User performing the action
agent_id: Agent ID for the messages
project_id: Optional project ID for the messages
"""
try:
from letta.helpers.tpuf_client import TurbopufferClient
from letta.llm_api.llm_client import LLMClient
# extract text content from each message
message_texts = []
@@ -423,21 +417,14 @@ class MessageManager:
created_ats.append(msg.created_at)
if message_texts:
# generate embeddings using provided config
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)
# insert to turbopuffer
# insert to turbopuffer - TurbopufferClient will generate embeddings internally
tpuf_client = TurbopufferClient()
await tpuf_client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=actor.organization_id,
actor=actor,
roles=roles,
created_ats=created_ats,
project_id=project_id,
@@ -550,7 +537,6 @@ class MessageManager:
message_id: str,
message_update: MessageUpdate,
actor: PydanticUser,
embedding_config: Optional[EmbeddingConfig] = None,
strict_mode: bool = False,
project_id: Optional[str] = None,
) -> PydanticMessage:
@@ -562,7 +548,6 @@ class MessageManager:
message_id: ID of the message to update
message_update: Update data for the message
actor: User performing the action
embedding_config: Optional embedding configuration for Turbopuffer
strict_mode: If True, wait for embedding update to complete; if False, run in background
project_id: Optional project ID for the message (for Turbopuffer indexing)
"""
@@ -582,7 +567,7 @@ class MessageManager:
# update message in turbopuffer if enabled (delete and re-insert)
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
if should_use_tpuf_for_messages() and embedding_config and pydantic_message.agent_id:
if should_use_tpuf_for_messages() and pydantic_message.agent_id:
# extract text content from updated message
text = self._extract_message_text(pydantic_message)
@@ -590,51 +575,42 @@ class MessageManager:
if text:
if strict_mode:
# wait for embedding update to complete
await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id)
await self._update_message_embedding_background(pydantic_message, text, actor, project_id)
else:
# fire and forget - run embedding update in background
fire_and_forget(
self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id),
self._update_message_embedding_background(pydantic_message, text, actor, project_id),
task_name=f"update_message_embedding_{message_id}",
)
return pydantic_message
async def _update_message_embedding_background(
self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser, project_id: Optional[str] = None
self, message: PydanticMessage, text: str, actor: PydanticUser, project_id: Optional[str] = None
) -> None:
"""Background task to update a message's embedding in Turbopuffer.
Args:
message: The updated message
text: Extracted text content from the message
embedding_config: Embedding configuration
actor: User performing the action
project_id: Optional project ID for the message
"""
try:
from letta.helpers.tpuf_client import TurbopufferClient
from letta.llm_api.llm_client import LLMClient
tpuf_client = TurbopufferClient()
# delete old message from turbopuffer
await tpuf_client.delete_messages(agent_id=message.agent_id, organization_id=actor.organization_id, message_ids=[message.id])
# generate new embedding
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([text], embedding_config)
# re-insert with updated content
# re-insert with updated content - TurbopufferClient will generate embeddings internally
await tpuf_client.insert_messages(
agent_id=message.agent_id,
message_texts=[text],
embeddings=embeddings,
message_ids=[message.id],
organization_id=actor.organization_id,
actor=actor,
roles=[message.role],
created_ats=[message.created_at],
project_id=project_id,
@@ -1119,13 +1095,11 @@ class MessageManager:
agent_id: str,
actor: PydanticUser,
query_text: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
search_mode: str = "hybrid",
roles: Optional[List[MessageRole]] = None,
limit: int = 50,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
embedding_config: Optional[EmbeddingConfig] = None,
) -> List[Tuple[PydanticMessage, dict]]:
"""
Search messages using Turbopuffer if enabled, otherwise fall back to SQL search.
@@ -1133,14 +1107,12 @@ class MessageManager:
Args:
agent_id: ID of the agent whose messages to search
actor: User performing the search
query_text: Text query for full-text search
query_embedding: Optional pre-computed embedding for vector search
query_text: Text query (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "hybrid")
roles: Optional list of message roles to filter by
limit: Maximum number of results to return
start_date: Optional filter for messages created after this date
end_date: Optional filter for messages created on or before this date (inclusive)
embedding_config: Optional embedding configuration for generating query embedding
Returns:
List of tuples (message, metadata) where metadata contains relevance scores
@@ -1150,36 +1122,12 @@ class MessageManager:
# check if we should use turbopuffer
if should_use_tpuf_for_messages():
try:
# generate embedding if needed and not provided
if search_mode in ["vector", "hybrid"] and query_embedding is None and query_text:
if not embedding_config:
# fall back to SQL search if no embedding config
logger.warning("No embedding config provided for vector search, falling back to SQL")
return await self.list_messages_for_agent_async(
agent_id=agent_id,
actor=actor,
query_text=query_text,
roles=roles,
limit=limit,
ascending=False,
)
# generate embedding from query text
from letta.llm_api.llm_client import LLMClient
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
query_embedding = embeddings[0]
# use turbopuffer for search
# use turbopuffer for search - TurbopufferClient will generate embeddings internally
tpuf_client = TurbopufferClient()
results = await tpuf_client.query_messages(
results = await tpuf_client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=actor.organization_id,
query_embedding=query_embedding,
actor=actor,
query_text=query_text,
search_mode=search_mode,
top_k=limit,
@@ -1255,3 +1203,76 @@ class MessageManager:
}
message_tuples.append((message, metadata))
return message_tuples
async def search_messages_org_async(
self,
actor: PydanticUser,
query_text: Optional[str] = None,
search_mode: str = "hybrid",
roles: Optional[List[MessageRole]] = None,
project_id: Optional[str] = None,
limit: int = 50,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> List[MessageSearchResult]:
"""
Search messages across entire organization using Turbopuffer.
Args:
actor: User performing the search (must have org access)
query_text: Text query for full-text search
search_mode: "vector", "fts", or "hybrid" (default: "hybrid")
roles: Optional list of message roles to filter by
project_id: Optional project ID to filter messages by
limit: Maximum number of results to return
start_date: Optional filter for messages created after this date
end_date: Optional filter for messages created on or before this date (inclusive)
Returns:
List of MessageSearchResult objects with scoring details
Raises:
ValueError: If message embedding or Turbopuffer is not enabled
"""
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
# check if turbopuffer is enabled
# TODO: extend to non-Turbopuffer in the future.
if not should_use_tpuf_for_messages():
raise ValueError("Message search requires message embedding, OpenAI, and Turbopuffer to be enabled.")
# use turbopuffer for search - TurbopufferClient will generate embeddings internally
tpuf_client = TurbopufferClient()
results = await tpuf_client.query_messages_by_org_id(
organization_id=actor.organization_id,
actor=actor,
query_text=query_text,
search_mode=search_mode,
top_k=limit,
roles=roles,
project_id=project_id,
start_date=start_date,
end_date=end_date,
)
# convert results to MessageSearchResult objects
if not results:
return []
# create message mapping
message_ids = [msg_dict["id"] for msg_dict, _, _ in results]
messages = await self.get_messages_by_ids_async(message_ids=message_ids, actor=actor)
message_mapping = {message.id: message for message in messages}
# create search results using list comprehension
return [
MessageSearchResult(
message=message_mapping.get(msg_dict["id"]),
fts_score=metadata.get("fts_score"),
fts_rank=metadata.get("fts_rank"),
vector_score=metadata.get("vector_score"),
vector_rank=metadata.get("vector_rank"),
rrf_score=rrf_score,
)
for msg_dict, rrf_score, metadata in results
]

View File

@@ -623,12 +623,13 @@ class PassageManager:
passage_texts = [p.text for p in passages]
# Insert to Turbopuffer with the same IDs as SQL
# TurbopufferClient will generate embeddings internally using default config
await tpuf_client.insert_archival_memories(
archive_id=archive.id,
text_chunks=passage_texts,
embeddings=embeddings,
passage_ids=passage_ids, # Use same IDs as SQL
organization_id=actor.organization_id,
actor=actor,
tags=tags,
created_at=passages[0].created_at if passages else None,
)

View File

@@ -195,7 +195,6 @@ class Summarizer:
await self.message_manager.create_many_messages_async(
pydantic_msgs=[summary_message_obj],
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)

View File

@@ -163,7 +163,6 @@ class LettaCoreToolExecutor(ToolExecutor):
limit=search_limit,
start_date=start_datetime,
end_date=end_datetime,
embedding_config=agent_state.embedding_config,
)
if len(message_results) == 0:

View File

@@ -233,7 +233,7 @@ class TestTurbopufferIntegration:
pass
@pytest.mark.asyncio
async def test_turbopuffer_metadata_attributes(self, enable_turbopuffer):
async def test_turbopuffer_metadata_attributes(self, default_user, enable_turbopuffer):
"""Test that Turbopuffer properly stores and retrieves metadata attributes"""
# Only run if we have a real API key
@@ -273,17 +273,16 @@ class TestTurbopufferIntegration:
result = await client.insert_archival_memories(
archive_id=archive_id,
text_chunks=[d["text"] for d in test_data],
embeddings=[d["vector"] for d in test_data],
passage_ids=[d["id"] for d in test_data],
organization_id="org-123", # Default org
actor=default_user,
created_at=datetime.now(timezone.utc),
)
assert len(result) == 3
# Query all passages (no tag filtering)
query_vector = [0.15] * 1536
results = await client.query_passages(archive_id=archive_id, query_embedding=query_vector, top_k=10)
results = await client.query_passages(archive_id=archive_id, actor=default_user, top_k=10)
# Should get all passages
assert len(results) == 3 # All three passages
@@ -339,7 +338,7 @@ class TestTurbopufferIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
async def test_hybrid_search_with_real_tpuf(self, enable_turbopuffer):
async def test_hybrid_search_with_real_tpuf(self, default_user, enable_turbopuffer):
"""Test hybrid search functionality combining vector and full-text search"""
import uuid
@@ -366,13 +365,14 @@ class TestTurbopufferIntegration:
# Insert passages
await client.insert_archival_memories(
archive_id=archive_id, text_chunks=texts, embeddings=embeddings, passage_ids=passage_ids, organization_id=org_id
archive_id=archive_id, text_chunks=texts, passage_ids=passage_ids, organization_id=org_id, actor=default_user
)
# Test vector-only search
vector_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 6.0, 11.0], # similar to second passage embedding
actor=default_user,
query_text="python programming tutorial",
search_mode="vector",
top_k=3,
)
@@ -382,7 +382,7 @@ class TestTurbopufferIntegration:
# Test FTS-only search
fts_results = await client.query_passages(
archive_id=archive_id, query_text="Turbopuffer vector database", search_mode="fts", top_k=3
archive_id=archive_id, actor=default_user, query_text="Turbopuffer vector database", search_mode="fts", top_k=3
)
assert 0 < len(fts_results) <= 3
# should find passages mentioning Turbopuffer
@@ -393,7 +393,7 @@ class TestTurbopufferIntegration:
# Test hybrid search
hybrid_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[2.0, 7.0, 12.0],
actor=default_user,
query_text="vector search Turbopuffer",
search_mode="hybrid",
top_k=3,
@@ -412,7 +412,7 @@ class TestTurbopufferIntegration:
# Test with different weights
vector_heavy_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[0.0, 5.0, 10.0], # very similar to first passage
actor=default_user,
query_text="quick brown fox", # matches second passage
search_mode="hybrid",
top_k=3,
@@ -423,16 +423,13 @@ class TestTurbopufferIntegration:
# all results should have scores
assert all(isinstance(score, float) for _, score, _ in vector_heavy_results)
# Test error handling - missing text for hybrid mode (embedding provided but text missing)
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
await client.query_passages(archive_id=archive_id, query_embedding=[1.0, 2.0, 3.0], search_mode="hybrid", top_k=3)
# Test error handling - missing embedding for hybrid mode (text provided but embedding missing)
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
await client.query_passages(archive_id=archive_id, query_text="test", search_mode="hybrid", top_k=3)
# Test with different search modes
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="vector", top_k=3)
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="fts", top_k=3)
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="hybrid", top_k=3)
# Test explicit timestamp mode
timestamp_results = await client.query_passages(archive_id=archive_id, search_mode="timestamp", top_k=3)
timestamp_results = await client.query_passages(archive_id=archive_id, actor=default_user, search_mode="timestamp", top_k=3)
assert len(timestamp_results) <= 3
# Should return passages ordered by timestamp (most recent first)
assert all(isinstance(passage, Passage) for passage, _, _ in timestamp_results)
@@ -446,7 +443,7 @@ class TestTurbopufferIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
async def test_tag_filtering_with_real_tpuf(self, enable_turbopuffer):
async def test_tag_filtering_with_real_tpuf(self, default_user, enable_turbopuffer):
"""Test tag filtering functionality with AND and OR logic"""
import uuid
@@ -479,13 +476,13 @@ class TestTurbopufferIntegration:
passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts]
# Insert passages with tags
for i, (text, tags, embedding, passage_id) in enumerate(zip(texts, tag_sets, embeddings, passage_ids)):
for i, (text, tags, passage_id) in enumerate(zip(texts, tag_sets, passage_ids)):
await client.insert_archival_memories(
archive_id=archive_id,
text_chunks=[text],
embeddings=[embedding],
passage_ids=[passage_id],
organization_id=org_id,
actor=default_user,
tags=tags,
created_at=datetime.now(timezone.utc),
)
@@ -493,7 +490,8 @@ class TestTurbopufferIntegration:
# Test tag filtering with "any" mode (should find passages with any of the specified tags)
python_any_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 6.0, 11.0],
actor=default_user,
query_text="python programming",
search_mode="vector",
top_k=10,
tags=["python"],
@@ -511,7 +509,8 @@ class TestTurbopufferIntegration:
# Test tag filtering with "all" mode
python_tutorial_all_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 6.0, 11.0],
actor=default_user,
query_text="python tutorial",
search_mode="vector",
top_k=10,
tags=["python", "tutorial"],
@@ -528,6 +527,7 @@ class TestTurbopufferIntegration:
# Test tag filtering with FTS mode
js_fts_results = await client.query_passages(
archive_id=archive_id,
actor=default_user,
query_text="javascript",
search_mode="fts",
top_k=10,
@@ -545,7 +545,7 @@ class TestTurbopufferIntegration:
# Test hybrid search with tags
python_hybrid_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[2.0, 7.0, 12.0],
actor=default_user,
query_text="python programming",
search_mode="hybrid",
top_k=10,
@@ -569,7 +569,7 @@ class TestTurbopufferIntegration:
pass
@pytest.mark.asyncio
async def test_temporal_filtering_with_real_tpuf(self, enable_turbopuffer):
async def test_temporal_filtering_with_real_tpuf(self, default_user, enable_turbopuffer):
"""Test temporal filtering with date ranges"""
from datetime import datetime, timedelta, timezone
@@ -601,15 +601,14 @@ class TestTurbopufferIntegration:
# We need to generate embeddings for the passages
# For testing, we'll use simple dummy embeddings
for text, timestamp in test_passages:
dummy_embedding = [1.0, 2.0, 3.0] # Simple test embedding
passage_id = f"passage-{uuid.uuid4()}"
await client.insert_archival_memories(
archive_id=archive_id,
text_chunks=[text],
embeddings=[dummy_embedding],
passage_ids=[passage_id],
organization_id="test-org",
actor=default_user,
created_at=timestamp,
)
@@ -617,7 +616,8 @@ class TestTurbopufferIntegration:
three_days_ago = now - timedelta(days=3)
results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="meeting notes",
search_mode="vector",
top_k=10,
start_date=three_days_ago,
@@ -637,7 +637,8 @@ class TestTurbopufferIntegration:
two_weeks_ago = now - timedelta(days=14)
results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="meeting notes",
search_mode="vector",
top_k=10,
start_date=two_weeks_ago,
@@ -652,7 +653,8 @@ class TestTurbopufferIntegration:
# Test 3: Query with only end_date (everything before yesterday)
results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="meeting notes",
search_mode="vector",
top_k=10,
end_date=yesterday + timedelta(hours=12), # Middle of yesterday
@@ -667,6 +669,7 @@ class TestTurbopufferIntegration:
# Test 4: Test with FTS mode and date filtering
results = await client.query_passages(
archive_id=archive_id,
actor=default_user,
query_text="meeting notes project",
search_mode="fts",
top_k=10,
@@ -682,7 +685,7 @@ class TestTurbopufferIntegration:
# Test 5: Test with hybrid mode and date filtering
results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="sprint review",
search_mode="hybrid",
top_k=10,
@@ -934,11 +937,9 @@ class TestTurbopufferMessagesIntegration:
),
]
# Create messages without embedding_config
created = await server.message_manager.create_many_messages_async(
pydantic_msgs=messages,
actor=default_user,
embedding_config=None, # No config provided
)
assert len(created) == 2
@@ -1057,7 +1058,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding):
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test actual message embedding and storage in Turbopuffer"""
import uuid
from datetime import datetime, timezone
@@ -1087,9 +1088,9 @@ class TestTurbopufferMessagesIntegration:
success = await client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=org_id,
actor=default_user,
roles=roles,
created_ats=created_ats,
)
@@ -1097,11 +1098,8 @@ class TestTurbopufferMessagesIntegration:
assert success == True
# Verify we can query the messages
results = await client.query_messages(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
results = await client.query_messages_by_agent_id(
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, actor=default_user
)
assert len(results) == 3
@@ -1121,7 +1119,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding):
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test vector search on messages in Turbopuffer"""
import uuid
from datetime import datetime, timezone
@@ -1145,29 +1143,23 @@ class TestTurbopufferMessagesIntegration:
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
# Create embeddings that reflect content similarity
embeddings = [
[1.0, 0.0, 0.0], # Python programming
[0.0, 1.0, 0.0], # JavaScript web
[0.8, 0.0, 0.2], # ML with Python (similar to first)
]
# Insert messages
await client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=org_id,
actor=default_user,
roles=roles,
created_ats=created_ats,
)
# Search for Python-related messages using vector search
query_embedding = [0.9, 0.0, 0.1] # Similar to Python messages
results = await client.query_messages(
results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
query_embedding=query_embedding,
actor=default_user,
query_text="Python programming",
search_mode="vector",
top_k=2,
)
@@ -1187,7 +1179,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding):
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test hybrid search combining vector and FTS for messages"""
import uuid
from datetime import datetime, timezone
@@ -1211,30 +1203,22 @@ class TestTurbopufferMessagesIntegration:
roles = [MessageRole.assistant] * len(message_texts)
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
# Embeddings
embeddings = [
[0.1, 0.9, 0.0], # fox text
[0.9, 0.1, 0.0], # ML algorithms
[0.5, 0.5, 0.0], # Quick Python
[0.8, 0.2, 0.0], # Deep learning
]
# Insert messages
await client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=org_id,
actor=default_user,
roles=roles,
created_ats=created_ats,
)
# Hybrid search - vector similar to ML but text contains "quick"
results = await client.query_messages(
# Hybrid search - text search for "quick"
results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
query_embedding=[0.7, 0.3, 0.0], # Similar to ML messages
actor=default_user,
query_text="quick", # Text search for "quick"
search_mode="hybrid",
top_k=3,
@@ -1257,7 +1241,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding):
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test filtering messages by role"""
import uuid
from datetime import datetime, timezone
@@ -1283,26 +1267,21 @@ class TestTurbopufferMessagesIntegration:
roles = [role for _, role in message_data]
message_ids = [str(uuid.uuid4()) for _ in message_texts]
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]
# Insert messages
await client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=org_id,
actor=default_user,
roles=roles,
created_ats=created_ats,
)
# Query only user messages
user_results = await client.query_messages(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
roles=[MessageRole.user],
user_results = await client.query_messages_by_agent_id(
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, roles=[MessageRole.user], actor=default_user
)
assert len(user_results) == 2
@@ -1311,12 +1290,13 @@ class TestTurbopufferMessagesIntegration:
assert msg["text"] in ["I need help with Python", "Can you explain this?"]
# Query assistant and system messages
non_user_results = await client.query_messages(
non_user_results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
roles=[MessageRole.assistant, MessageRole.system],
actor=default_user,
)
assert len(non_user_results) == 3
@@ -1395,7 +1375,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
@@ -1409,7 +1388,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Python",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(python_results) > 0
assert any(msg.id == message_id for msg, metadata in python_results)
@@ -1419,7 +1397,6 @@ class TestTurbopufferMessagesIntegration:
message_id=message_id,
message_update=MessageUpdate(content="Updated content about JavaScript development"),
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
@@ -1432,7 +1409,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Python",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
# Should either find no results or results that don't include our message
assert not any(msg.id == message_id for msg, metadata in python_results_after)
@@ -1444,7 +1420,6 @@ class TestTurbopufferMessagesIntegration:
query_text="JavaScript",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(js_results) > 0
assert any(msg.id == message_id for msg, metadata in js_results)
@@ -1497,7 +1472,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
agent_a_messages.extend(msgs)
@@ -1514,7 +1488,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
agent_b_messages.extend(msgs)
@@ -1526,7 +1499,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Agent A",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(agent_a_search) == 5
@@ -1536,7 +1508,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Agent B",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(agent_b_search) == 3
@@ -1559,7 +1530,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Agent A",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(agent_a_final) == 2
# Verify the remaining messages are the correct ones
@@ -1574,7 +1544,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Agent B",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(agent_b_final) == 0
@@ -1583,84 +1552,6 @@ class TestTurbopufferMessagesIntegration:
await server.agent_manager.delete_agent_async(agent_a.id, default_user)
await server.agent_manager.delete_agent_async(agent_b.id, default_user)
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_crud_operations_without_embedding_config(self, server, default_user, sarah_agent, enable_message_embedding):
"""Test that CRUD operations handle missing embedding_config gracefully"""
from letta.schemas.message import MessageUpdate
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
# Create message WITH embedding_config
messages = await server.message_manager.create_many_messages_async(
pydantic_msgs=[
PydanticMessage(
role=MessageRole.user,
content=[TextContent(text="Message with searchable content about databases")],
agent_id=sarah_agent.id,
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
assert len(messages) == 1
message_id = messages[0].id
# Verify message is searchable initially
initial_search = await server.message_manager.search_messages_async(
agent_id=sarah_agent.id,
actor=default_user,
query_text="databases",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(initial_search) > 0
assert any(msg.id == message_id for msg, metadata in initial_search)
# Update message WITHOUT embedding_config - should update postgres but not turbopuffer
updated_message = await server.message_manager.update_message_by_id_async(
message_id=message_id,
message_update=MessageUpdate(content="Updated content about algorithms"),
actor=default_user,
embedding_config=None, # No config provided
)
# Verify postgres was updated
assert updated_message.id == message_id
updated_text = server.message_manager._extract_message_text(updated_message)
assert "algorithms" in updated_text
assert "databases" not in updated_text
# Original search term should STILL find the message (turbopuffer wasn't updated)
still_searchable = await server.message_manager.search_messages_async(
agent_id=sarah_agent.id,
actor=default_user,
query_text="databases",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(still_searchable) > 0
assert any(msg.id == message_id for msg, metadata in still_searchable)
# New content should NOT be searchable (wasn't re-indexed)
not_searchable = await server.message_manager.search_messages_async(
agent_id=sarah_agent.id,
actor=default_user,
query_text="algorithms",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
# Should either find no results or results that don't include our message
assert not any(msg.id == message_id for msg, metadata in not_searchable)
# Clean up
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_turbopuffer_failure_does_not_break_postgres(self, server, default_user, sarah_agent, enable_message_embedding):
@@ -1681,7 +1572,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
)
assert len(messages) == 1
@@ -1702,7 +1592,6 @@ class TestTurbopufferMessagesIntegration:
message_id=message_id,
message_update=MessageUpdate(content="Updated despite turbopuffer failure"),
actor=default_user,
embedding_config=embedding_config,
strict_mode=False, # Don't fail on turbopuffer errors - that's what we're testing!
)
@@ -1722,7 +1611,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=None, # Create without embedding to avoid mock issues
)
message_to_delete_id = messages2[0].id
@@ -1741,7 +1629,7 @@ class TestTurbopufferMessagesIntegration:
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=False)
async def wait_for_embedding(
self, agent_id: str, message_id: str, organization_id: str, max_wait: float = 10.0, poll_interval: float = 0.5
self, agent_id: str, message_id: str, organization_id: str, actor, max_wait: float = 10.0, poll_interval: float = 0.5
) -> bool:
"""Poll Turbopuffer directly to check if a message has been embedded.
@@ -1765,9 +1653,10 @@ class TestTurbopufferMessagesIntegration:
while asyncio.get_event_loop().time() - start_time < max_wait:
try:
# Query Turbopuffer directly using timestamp mode to get all messages
results = await client.query_messages(
results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=organization_id,
actor=actor,
search_mode="timestamp",
top_k=100, # Get more messages to ensure we find it
)
@@ -1800,7 +1689,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=False, # Background mode
)
@@ -1814,7 +1702,12 @@ class TestTurbopufferMessagesIntegration:
# Poll for embedding completion by querying Turbopuffer directly
embedded = await self.wait_for_embedding(
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
agent_id=sarah_agent.id,
message_id=message_id,
organization_id=default_user.organization_id,
actor=default_user,
max_wait=10.0,
poll_interval=0.5,
)
assert embedded, "Message was not embedded in Turbopuffer within timeout"
@@ -1825,7 +1718,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Python programming",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(search_results) > 0
assert any(msg.id == message_id for msg, _ in search_results)
@@ -1851,7 +1743,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True, # Ensure initial embedding
)
@@ -1865,7 +1756,6 @@ class TestTurbopufferMessagesIntegration:
query_text="databases",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert any(msg.id == message_id for msg, _ in initial_results)
@@ -1874,7 +1764,6 @@ class TestTurbopufferMessagesIntegration:
message_id=message_id,
message_update=MessageUpdate(content="Updated content about machine learning"),
actor=default_user,
embedding_config=embedding_config,
strict_mode=False, # Background mode
)
@@ -1890,7 +1779,12 @@ class TestTurbopufferMessagesIntegration:
# Poll for the update to be reflected in Turbopuffer
# We check by searching for the new content
embedded = await self.wait_for_embedding(
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
agent_id=sarah_agent.id,
message_id=message_id,
organization_id=default_user.organization_id,
actor=default_user,
max_wait=10.0,
poll_interval=0.5,
)
assert embedded, "Updated message was not re-embedded within timeout"
@@ -1901,7 +1795,6 @@ class TestTurbopufferMessagesIntegration:
query_text="machine learning",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert any(msg.id == message_id for msg, _ in new_results)
@@ -1914,7 +1807,6 @@ class TestTurbopufferMessagesIntegration:
query_text="databases",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
# The message shouldn't match the old search term anymore
if len(old_results) > 0:
@@ -1929,7 +1821,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test filtering messages by date range"""
import uuid
from datetime import datetime, timedelta, timezone
@@ -1959,21 +1851,17 @@ class TestTurbopufferMessagesIntegration:
await client.insert_messages(
agent_id=agent_id,
message_texts=[text],
embeddings=[[1.0, 2.0, 3.0]],
message_ids=[str(uuid.uuid4())],
organization_id=org_id,
actor=default_user,
roles=[MessageRole.assistant],
created_ats=[timestamp],
)
# Query messages from the last 3 days
three_days_ago = now - timedelta(days=3)
recent_results = await client.query_messages(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
start_date=three_days_ago,
recent_results = await client.query_messages_by_agent_id(
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, start_date=three_days_ago, actor=default_user
)
# Should get today's and yesterday's messages
@@ -1984,13 +1872,14 @@ class TestTurbopufferMessagesIntegration:
# Query messages between 2 weeks ago and 1 week ago
two_weeks_ago = now - timedelta(days=14)
week_results = await client.query_messages(
week_results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
start_date=two_weeks_ago,
end_date=last_week + timedelta(days=1), # Include last week's message
actor=default_user,
)
# Should get only last week's message
@@ -1998,10 +1887,11 @@ class TestTurbopufferMessagesIntegration:
assert week_results[0][0]["text"] == "Last week's message"
# Query with vector search and date filtering
filtered_vector_results = await client.query_messages(
filtered_vector_results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="message",
search_mode="vector",
top_k=10,
start_date=three_days_ago,
@@ -2101,7 +1991,7 @@ class TestNamespaceTracking:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_project_id_filtering(self, server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding):
"""Test that project_id filtering works correctly in query_messages"""
"""Test that project_id filtering works correctly in query_messages_by_agent_id"""
from letta.schemas.letta_message_content import TextContent
# Create two project IDs
@@ -2124,24 +2014,15 @@ class TestNamespaceTracking:
# Insert messages with their respective project IDs
tpuf_client = TurbopufferClient()
# Generate embeddings
from letta.llm_api.llm_client import LLMClient
embedding_client = LLMClient.create(
provider_type=sarah_agent.embedding_config.embedding_endpoint_type,
actor=default_user,
)
embeddings = await embedding_client.request_embeddings(
[message_a.content[0].text, message_b.content[0].text], sarah_agent.embedding_config
)
# Embeddings will be generated automatically by the client
# Insert message A with project_a_id
await tpuf_client.insert_messages(
agent_id=sarah_agent.id,
message_texts=[message_a.content[0].text],
embeddings=[embeddings[0]],
message_ids=[message_a.id],
organization_id=default_user.organization_id,
actor=default_user,
roles=[message_a.role],
created_ats=[message_a.created_at],
project_id=project_a_id,
@@ -2151,9 +2032,9 @@ class TestNamespaceTracking:
await tpuf_client.insert_messages(
agent_id=sarah_agent.id,
message_texts=[message_b.content[0].text],
embeddings=[embeddings[1]],
message_ids=[message_b.id],
organization_id=default_user.organization_id,
actor=default_user,
roles=[message_b.role],
created_ats=[message_b.created_at],
project_id=project_b_id,
@@ -2162,12 +2043,13 @@ class TestNamespaceTracking:
# Poll for message A with project_a_id filter
max_retries = 10
for i in range(max_retries):
results_a = await tpuf_client.query_messages(
results_a = await tpuf_client.query_messages_by_agent_id(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp", # Simple timestamp retrieval
top_k=10,
project_id=project_a_id,
actor=default_user,
)
if len(results_a) == 1 and results_a[0][0]["id"] == message_a.id:
break
@@ -2179,12 +2061,13 @@ class TestNamespaceTracking:
# Poll for message B with project_b_id filter
for i in range(max_retries):
results_b = await tpuf_client.query_messages(
results_b = await tpuf_client.query_messages_by_agent_id(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp",
top_k=10,
project_id=project_b_id,
actor=default_user,
)
if len(results_b) == 1 and results_b[0][0]["id"] == message_b.id:
break
@@ -2195,12 +2078,13 @@ class TestNamespaceTracking:
assert "JavaScript" in results_b[0][0]["text"]
# Query without project filter - should find both
results_all = await tpuf_client.query_messages(
results_all = await tpuf_client.query_messages_by_agent_id(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp",
top_k=10,
project_id=None, # No filter
actor=default_user,
)
assert len(results_all) >= 2 # May have other messages from setup