feat: Add search messages endpoint [LET-4144] (#4434)
* Add search messages endpoint * Run fern autogen and fix tests
This commit is contained in:
@@ -178,7 +178,6 @@ class BaseAgent(ABC):
|
||||
curr_system_message.id,
|
||||
message_update=MessageUpdate(content=new_system_message_str),
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
return [new_system_message] + in_context_messages[1:]
|
||||
|
||||
@@ -117,7 +117,6 @@ async def _prepare_in_context_messages_async(
|
||||
new_in_context_messages = await message_manager.create_many_messages_async(
|
||||
create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
|
||||
actor=actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -497,7 +497,6 @@ class LettaAgent(BaseAgent):
|
||||
await self.message_manager.create_many_messages_async(
|
||||
initial_messages,
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
@@ -828,7 +827,6 @@ class LettaAgent(BaseAgent):
|
||||
await self.message_manager.create_many_messages_async(
|
||||
initial_messages,
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
@@ -1267,7 +1265,6 @@ class LettaAgent(BaseAgent):
|
||||
await self.message_manager.create_many_messages_async(
|
||||
initial_messages,
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
@@ -1676,7 +1673,7 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id
|
||||
messages_to_persist, actor=self.actor, project_id=agent_state.project_id
|
||||
)
|
||||
return persisted_messages, continue_stepping, stop_reason
|
||||
|
||||
@@ -1788,7 +1785,7 @@ class LettaAgent(BaseAgent):
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id
|
||||
messages_to_persist, actor=self.actor, project_id=agent_state.project_id
|
||||
)
|
||||
|
||||
if run_id:
|
||||
|
||||
@@ -4,26 +4,37 @@ import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole, TagMatchMode
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.settings import settings
|
||||
from letta.settings import model_settings, settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def should_use_tpuf() -> bool:
|
||||
# We need OpenAI since we default to their embedding model
|
||||
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)
|
||||
|
||||
|
||||
def should_use_tpuf_for_messages() -> bool:
|
||||
"""Check if Turbopuffer should be used for messages."""
|
||||
return should_use_tpuf() and bool(settings.embed_all_messages)
|
||||
return should_use_tpuf() and bool(settings.embed_all_messages) and bool(model_settings.openai_api_key)
|
||||
|
||||
|
||||
class TurbopufferClient:
|
||||
"""Client for managing archival memory with Turbopuffer vector database."""
|
||||
|
||||
default_embedding_config = EmbeddingConfig(
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
)
|
||||
|
||||
def __init__(self, api_key: str = None, region: str = None):
|
||||
"""Initialize Turbopuffer client."""
|
||||
self.api_key = api_key or settings.tpuf_api_key
|
||||
@@ -38,6 +49,26 @@ class TurbopufferClient:
|
||||
if not self.api_key:
|
||||
raise ValueError("Turbopuffer API key not provided")
|
||||
|
||||
@trace_method
|
||||
async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]:
|
||||
"""Generate embeddings using the default embedding configuration.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
actor: User actor for embedding generation
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
"""
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=self.default_embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(texts, self.default_embedding_config)
|
||||
return embeddings
|
||||
|
||||
@trace_method
|
||||
async def _get_archive_namespace_name(self, archive_id: str) -> str:
|
||||
"""Get namespace name for a specific archive."""
|
||||
@@ -61,9 +92,9 @@ class TurbopufferClient:
|
||||
self,
|
||||
archive_id: str,
|
||||
text_chunks: List[str],
|
||||
embeddings: List[List[float]],
|
||||
passage_ids: List[str],
|
||||
organization_id: str,
|
||||
actor: "PydanticUser",
|
||||
tags: Optional[List[str]] = None,
|
||||
created_at: Optional[datetime] = None,
|
||||
) -> List[PydanticPassage]:
|
||||
@@ -72,9 +103,9 @@ class TurbopufferClient:
|
||||
Args:
|
||||
archive_id: ID of the archive
|
||||
text_chunks: List of text chunks to store
|
||||
embeddings: List of embedding vectors corresponding to text chunks
|
||||
passage_ids: List of passage IDs (must match 1:1 with text_chunks)
|
||||
organization_id: Organization ID for the passages
|
||||
actor: User actor for embedding generation
|
||||
tags: Optional list of tags to attach to all passages
|
||||
created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
|
||||
|
||||
@@ -83,6 +114,9 @@ class TurbopufferClient:
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
# generate embeddings using the default config
|
||||
embeddings = await self._generate_embeddings(text_chunks, actor)
|
||||
|
||||
namespace_name = await self._get_archive_namespace_name(archive_id)
|
||||
|
||||
# handle timestamp - ensure UTC
|
||||
@@ -102,8 +136,6 @@ class TurbopufferClient:
|
||||
raise ValueError("passage_ids must be provided for Turbopuffer insertion")
|
||||
if len(passage_ids) != len(text_chunks):
|
||||
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
|
||||
if len(passage_ids) != len(embeddings):
|
||||
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})")
|
||||
|
||||
# prepare column-based data for turbopuffer - optimized for batch insert
|
||||
ids = []
|
||||
@@ -137,7 +169,7 @@ class TurbopufferClient:
|
||||
metadata_={},
|
||||
tags=tags or [], # Include tags in the passage
|
||||
embedding=embedding,
|
||||
embedding_config=None, # Will be set by caller if needed
|
||||
embedding_config=self.default_embedding_config, # Will be set by caller if needed
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
@@ -177,9 +209,9 @@ class TurbopufferClient:
|
||||
self,
|
||||
agent_id: str,
|
||||
message_texts: List[str],
|
||||
embeddings: List[List[float]],
|
||||
message_ids: List[str],
|
||||
organization_id: str,
|
||||
actor: "PydanticUser",
|
||||
roles: List[MessageRole],
|
||||
created_ats: List[datetime],
|
||||
project_id: Optional[str] = None,
|
||||
@@ -189,9 +221,9 @@ class TurbopufferClient:
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
message_texts: List of message text content to store
|
||||
embeddings: List of embedding vectors corresponding to message texts
|
||||
message_ids: List of message IDs (must match 1:1 with message_texts)
|
||||
organization_id: Organization ID for the messages
|
||||
actor: User actor for embedding generation
|
||||
roles: List of message roles corresponding to each message
|
||||
created_ats: List of creation timestamps for each message
|
||||
project_id: Optional project ID for all messages
|
||||
@@ -201,6 +233,9 @@ class TurbopufferClient:
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
# generate embeddings using the default config
|
||||
embeddings = await self._generate_embeddings(message_texts, actor)
|
||||
|
||||
namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
|
||||
|
||||
# validation checks
|
||||
@@ -208,8 +243,6 @@ class TurbopufferClient:
|
||||
raise ValueError("message_ids must be provided for Turbopuffer insertion")
|
||||
if len(message_ids) != len(message_texts):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
|
||||
if len(message_ids) != len(embeddings):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match embeddings length ({len(embeddings)})")
|
||||
if len(message_ids) != len(roles):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
|
||||
if len(message_ids) != len(created_ats):
|
||||
@@ -390,7 +423,7 @@ class TurbopufferClient:
|
||||
async def query_passages(
|
||||
self,
|
||||
archive_id: str,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
actor: "PydanticUser",
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
||||
top_k: int = 10,
|
||||
@@ -405,8 +438,8 @@ class TurbopufferClient:
|
||||
|
||||
Args:
|
||||
archive_id: ID of the archive
|
||||
query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
|
||||
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
|
||||
actor: User actor for embedding generation
|
||||
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
|
||||
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
|
||||
top_k: Number of results to return
|
||||
tags: Optional list of tags to filter by
|
||||
@@ -419,6 +452,12 @@ class TurbopufferClient:
|
||||
Returns:
|
||||
List of (passage, score, metadata) tuples with relevance rankings
|
||||
"""
|
||||
# generate embedding for vector/hybrid search if query_text is provided
|
||||
query_embedding = None
|
||||
if query_text and search_mode in ["vector", "hybrid"]:
|
||||
embeddings = await self._generate_embeddings([query_text], actor)
|
||||
query_embedding = embeddings[0]
|
||||
|
||||
# Check if we should fallback to timestamp-based retrieval
|
||||
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
|
||||
# Fallback to retrieving most recent passages when no search query is provided
|
||||
@@ -519,11 +558,11 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def query_messages(
|
||||
async def query_messages_by_agent_id(
|
||||
self,
|
||||
agent_id: str,
|
||||
organization_id: str,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
actor: "PydanticUser",
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
|
||||
top_k: int = 10,
|
||||
@@ -539,8 +578,8 @@ class TurbopufferClient:
|
||||
Args:
|
||||
agent_id: ID of the agent (used for filtering results)
|
||||
organization_id: Organization ID for namespace lookup
|
||||
query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
|
||||
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
|
||||
actor: User actor for embedding generation
|
||||
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
|
||||
search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector")
|
||||
top_k: Number of results to return
|
||||
roles: Optional list of message roles to filter by
|
||||
@@ -556,6 +595,12 @@ class TurbopufferClient:
|
||||
- score is the final relevance score
|
||||
- metadata contains individual scores and ranking information
|
||||
"""
|
||||
# generate embedding for vector/hybrid search if query_text is provided
|
||||
query_embedding = None
|
||||
if query_text and search_mode in ["vector", "hybrid"]:
|
||||
embeddings = await self._generate_embeddings([query_text], actor)
|
||||
query_embedding = embeddings[0]
|
||||
|
||||
# Check if we should fallback to timestamp-based retrieval
|
||||
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
|
||||
# Fallback to retrieving most recent messages when no search query is provided
|
||||
@@ -658,6 +703,159 @@ class TurbopufferClient:
|
||||
logger.error(f"Failed to query messages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
async def query_messages_by_org_id(
|
||||
self,
|
||||
organization_id: str,
|
||||
actor: "PydanticUser",
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "hybrid", # "vector", "fts", "hybrid"
|
||||
top_k: int = 10,
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
) -> List[Tuple[dict, float, dict]]:
|
||||
"""Query messages from Turbopuffer across an entire organization.
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID for namespace lookup (required)
|
||||
actor: User actor for embedding generation
|
||||
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
|
||||
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "hybrid")
|
||||
top_k: Number of results to return
|
||||
roles: Optional list of message roles to filter by
|
||||
project_id: Optional project ID to filter messages by
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime to filter messages created after this date
|
||||
end_date: Optional datetime to filter messages created on or before this date (inclusive)
|
||||
|
||||
Returns:
|
||||
List of (message_dict, score, metadata) tuples where:
|
||||
- message_dict contains id, text, role, created_at, agent_id
|
||||
- score is the final relevance score (RRF score for hybrid, rank-based for single mode)
|
||||
- metadata contains individual scores and ranking information
|
||||
"""
|
||||
# generate embedding for vector/hybrid search if query_text is provided
|
||||
query_embedding = None
|
||||
if query_text and search_mode in ["vector", "hybrid"]:
|
||||
embeddings = await self._generate_embeddings([query_text], actor)
|
||||
query_embedding = embeddings[0]
|
||||
# namespace is org-scoped
|
||||
namespace_name = f"letta_messages_{organization_id}"
|
||||
|
||||
# build filters
|
||||
all_filters = []
|
||||
|
||||
# role filter
|
||||
if roles:
|
||||
role_values = [r.value for r in roles]
|
||||
if len(role_values) == 1:
|
||||
all_filters.append(("role", "Eq", role_values[0]))
|
||||
else:
|
||||
all_filters.append(("role", "In", role_values))
|
||||
|
||||
# project filter
|
||||
if project_id:
|
||||
all_filters.append(("project_id", "Eq", project_id))
|
||||
|
||||
# date filters
|
||||
if start_date:
|
||||
all_filters.append(("created_at", "Gte", start_date))
|
||||
if end_date:
|
||||
# make end_date inclusive of the entire day
|
||||
if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
|
||||
from datetime import timedelta
|
||||
|
||||
end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
|
||||
all_filters.append(("created_at", "Lte", end_date))
|
||||
|
||||
# combine filters
|
||||
final_filter = None
|
||||
if len(all_filters) == 1:
|
||||
final_filter = all_filters[0]
|
||||
elif len(all_filters) > 1:
|
||||
final_filter = ("And", all_filters)
|
||||
|
||||
try:
|
||||
# execute query
|
||||
result = await self._execute_query(
|
||||
namespace_name=namespace_name,
|
||||
search_mode=search_mode,
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text,
|
||||
top_k=top_k,
|
||||
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
|
||||
filters=final_filter,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
)
|
||||
|
||||
# process results based on search mode
|
||||
if search_mode == "hybrid":
|
||||
# for hybrid mode, we get a multi-query response
|
||||
vector_results = self._process_message_query_results(result.results[0])
|
||||
fts_results = self._process_message_query_results(result.results[1])
|
||||
|
||||
# use existing RRF method - it already returns metadata with ranks
|
||||
results_with_metadata = self._reciprocal_rank_fusion(
|
||||
vector_results=vector_results,
|
||||
fts_results=fts_results,
|
||||
get_id_func=lambda msg_dict: msg_dict["id"],
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
# add raw scores to metadata if available
|
||||
vector_scores = {}
|
||||
for row in result.results[0].rows:
|
||||
if hasattr(row, "dist"):
|
||||
vector_scores[row.id] = row.dist
|
||||
|
||||
fts_scores = {}
|
||||
for row in result.results[1].rows:
|
||||
if hasattr(row, "score"):
|
||||
fts_scores[row.id] = row.score
|
||||
|
||||
# enhance metadata with raw scores
|
||||
enhanced_results = []
|
||||
for msg_dict, rrf_score, metadata in results_with_metadata:
|
||||
msg_id = msg_dict["id"]
|
||||
if msg_id in vector_scores:
|
||||
metadata["vector_score"] = vector_scores[msg_id]
|
||||
if msg_id in fts_scores:
|
||||
metadata["fts_score"] = fts_scores[msg_id]
|
||||
enhanced_results.append((msg_dict, rrf_score, metadata))
|
||||
|
||||
return enhanced_results
|
||||
else:
|
||||
# for single queries (vector or fts)
|
||||
results = self._process_message_query_results(result)
|
||||
results_with_metadata = []
|
||||
for idx, msg_dict in enumerate(results):
|
||||
metadata = {
|
||||
"combined_score": 1.0 / (idx + 1),
|
||||
"search_mode": search_mode,
|
||||
f"{search_mode}_rank": idx + 1,
|
||||
}
|
||||
|
||||
# add raw score if available
|
||||
if hasattr(result.rows[idx], "dist"):
|
||||
metadata["vector_score"] = result.rows[idx].dist
|
||||
elif hasattr(result.rows[idx], "score"):
|
||||
metadata["fts_score"] = result.rows[idx].score
|
||||
|
||||
results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))
|
||||
|
||||
return results_with_metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query messages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
def _process_message_query_results(self, result) -> List[dict]:
|
||||
"""Process results from a message query into message dicts.
|
||||
|
||||
@@ -703,7 +901,7 @@ class TurbopufferClient:
|
||||
tags=passage_tags, # Set the actual tags from the passage
|
||||
# Set required fields to empty/default values since we don't store embeddings
|
||||
embedding=[], # Empty embedding since we don't return it from Turbopuffer
|
||||
embedding_config=None, # No embedding config needed for retrieved passages
|
||||
embedding_config=self.default_embedding_config, # No embedding config needed for retrieved passages
|
||||
)
|
||||
|
||||
# handle score based on search type
|
||||
|
||||
@@ -1187,3 +1187,26 @@ class ToolReturn(BaseModel):
|
||||
stdout: Optional[List[str]] = Field(default=None, description="Captured stdout (e.g. prints, logs) from the tool invocation")
|
||||
stderr: Optional[List[str]] = Field(default=None, description="Captured stderr from the tool invocation")
|
||||
# func_return: Optional[Any] = Field(None, description="The function return object")
|
||||
|
||||
|
||||
class MessageSearchRequest(BaseModel):
|
||||
"""Request model for searching messages across the organization"""
|
||||
|
||||
query_text: Optional[str] = Field(None, description="Text query for full-text search")
|
||||
search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use")
|
||||
roles: Optional[List[MessageRole]] = Field(None, description="Filter messages by role")
|
||||
project_id: Optional[str] = Field(None, description="Filter messages by project ID")
|
||||
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
|
||||
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
|
||||
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
|
||||
|
||||
|
||||
class MessageSearchResult(BaseModel):
|
||||
"""Result from a message search operation with scoring details."""
|
||||
|
||||
message: Message = Field(..., description="The message content and metadata")
|
||||
fts_score: Optional[float] = Field(None, description="Full-text search (BM25) score if FTS was used")
|
||||
fts_rank: Optional[int] = Field(None, description="Full-text search rank position if FTS was used")
|
||||
vector_score: Optional[float] = Field(None, description="Vector similarity score if vector search was used")
|
||||
vector_rank: Optional[int] = Field(None, description="Vector search rank position if vector search was used")
|
||||
rrf_score: float = Field(..., description="Reciprocal Rank Fusion combined score")
|
||||
|
||||
@@ -39,7 +39,7 @@ from letta.schemas.memory import (
|
||||
CreateArchivalMemory,
|
||||
Memory,
|
||||
)
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.message import MessageCreate, MessageSearchRequest, MessageSearchResult
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.source import Source
|
||||
@@ -1498,6 +1498,42 @@ async def cancel_agent_run(
|
||||
return results
|
||||
|
||||
|
||||
@router.post("/messages/search", response_model=List[MessageSearchResult], operation_id="search_messages")
|
||||
async def search_messages(
|
||||
request: MessageSearchRequest = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Search messages across the entire organization with optional project filtering.
|
||||
Returns messages with FTS/vector ranks and total RRF score.
|
||||
|
||||
Requires message embedding and Turbopuffer to be enabled.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# get embedding config from the default agent if needed
|
||||
# check if any agents exist in the org
|
||||
agent_count = await server.agent_manager.size_async(actor=actor)
|
||||
if agent_count == 0:
|
||||
raise HTTPException(status_code=400, detail="No agents found in organization to derive embedding configuration from")
|
||||
|
||||
try:
|
||||
results = await server.message_manager.search_messages_org_async(
|
||||
actor=actor,
|
||||
query_text=request.query_text,
|
||||
search_mode=request.search_mode,
|
||||
roles=request.roles,
|
||||
project_id=request.project_id,
|
||||
limit=request.limit,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
)
|
||||
return results
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
async def _process_message_background(
|
||||
run_id: str,
|
||||
server: SyncServer,
|
||||
|
||||
@@ -719,9 +719,7 @@ class AgentManager:
|
||||
|
||||
# Only create messages if we initialized with messages
|
||||
if not _init_with_no_messages:
|
||||
await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config, project_id=result.project_id
|
||||
)
|
||||
await self.message_manager.create_many_messages_async(pydantic_msgs=init_messages, actor=actor, project_id=result.project_id)
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@@ -1834,7 +1832,6 @@ class AgentManager:
|
||||
message_id=curr_system_message.id,
|
||||
message_update=MessageUpdate(**temp_message.model_dump()),
|
||||
actor=actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
else:
|
||||
@@ -1889,9 +1886,7 @@ class AgentManager:
|
||||
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
|
||||
) -> PydanticAgentState:
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
messages = await self.message_manager.create_many_messages_async(
|
||||
messages, actor=actor, embedding_config=agent.embedding_config, project_id=agent.project_id
|
||||
)
|
||||
messages = await self.message_manager.create_many_messages_async(messages, actor=actor, project_id=agent.project_id)
|
||||
message_ids = agent.message_ids or []
|
||||
message_ids += [m.id for m in messages]
|
||||
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
@@ -2692,7 +2687,6 @@ class AgentManager:
|
||||
# use hybrid search to combine vector and full-text search
|
||||
passages_with_scores = await tpuf_client.query_passages(
|
||||
archive_id=archive_ids[0],
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text, # pass text for potential hybrid search
|
||||
search_mode="hybrid", # use hybrid mode for better results
|
||||
top_k=limit,
|
||||
@@ -2700,6 +2694,7 @@ class AgentManager:
|
||||
tag_match_mode=tag_match_mode or TagMatchMode.ANY,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Return full tuples with metadata
|
||||
|
||||
@@ -678,7 +678,6 @@ class AgentSerializationManager:
|
||||
created_messages = await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=messages,
|
||||
actor=actor,
|
||||
embedding_config=created_agent.embedding_config,
|
||||
project_id=created_agent.project_id,
|
||||
)
|
||||
imported_count += len(created_messages)
|
||||
|
||||
@@ -11,11 +11,10 @@ from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage, MessageUpdate
|
||||
from letta.schemas.message import Message as PydanticMessage, MessageSearchResult, MessageUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.file_manager import FileManager
|
||||
@@ -314,7 +313,6 @@ class MessageManager:
|
||||
self,
|
||||
pydantic_msgs: List[PydanticMessage],
|
||||
actor: PydanticUser,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
strict_mode: bool = False,
|
||||
project_id: Optional[str] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
@@ -324,7 +322,6 @@ class MessageManager:
|
||||
Args:
|
||||
pydantic_msgs: List of Pydantic message models to create
|
||||
actor: User performing the action
|
||||
embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer
|
||||
strict_mode: If True, wait for embedding to complete; if False, run in background
|
||||
project_id: Optional project ID for the messages (for Turbopuffer indexing)
|
||||
|
||||
@@ -365,20 +362,20 @@ class MessageManager:
|
||||
result = [msg.to_pydantic() for msg in created_messages]
|
||||
await session.commit()
|
||||
|
||||
# embed messages in turbopuffer if enabled and embedding_config provided
|
||||
# embed messages in turbopuffer if enabled
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages() and embedding_config and result:
|
||||
if should_use_tpuf_for_messages() and result:
|
||||
# extract agent_id from the first message (all should have same agent_id)
|
||||
agent_id = result[0].agent_id
|
||||
if agent_id:
|
||||
if strict_mode:
|
||||
# wait for embedding to complete
|
||||
await self._embed_messages_background(result, embedding_config, actor, agent_id, project_id)
|
||||
await self._embed_messages_background(result, actor, agent_id, project_id)
|
||||
else:
|
||||
# fire and forget - run embedding in background
|
||||
fire_and_forget(
|
||||
self._embed_messages_background(result, embedding_config, actor, agent_id, project_id),
|
||||
self._embed_messages_background(result, actor, agent_id, project_id),
|
||||
task_name=f"embed_messages_for_agent_{agent_id}",
|
||||
)
|
||||
|
||||
@@ -387,7 +384,6 @@ class MessageManager:
|
||||
async def _embed_messages_background(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
embedding_config: EmbeddingConfig,
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
project_id: Optional[str] = None,
|
||||
@@ -396,14 +392,12 @@ class MessageManager:
|
||||
|
||||
Args:
|
||||
messages: List of messages to embed
|
||||
embedding_config: Embedding configuration
|
||||
actor: User performing the action
|
||||
agent_id: Agent ID for the messages
|
||||
project_id: Optional project ID for the messages
|
||||
"""
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
# extract text content from each message
|
||||
message_texts = []
|
||||
@@ -423,21 +417,14 @@ class MessageManager:
|
||||
created_ats.append(msg.created_at)
|
||||
|
||||
if message_texts:
|
||||
# generate embeddings using provided config
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)
|
||||
|
||||
# insert to turbopuffer
|
||||
# insert to turbopuffer - TurbopufferClient will generate embeddings internally
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=actor.organization_id,
|
||||
actor=actor,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
project_id=project_id,
|
||||
@@ -550,7 +537,6 @@ class MessageManager:
|
||||
message_id: str,
|
||||
message_update: MessageUpdate,
|
||||
actor: PydanticUser,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
strict_mode: bool = False,
|
||||
project_id: Optional[str] = None,
|
||||
) -> PydanticMessage:
|
||||
@@ -562,7 +548,6 @@ class MessageManager:
|
||||
message_id: ID of the message to update
|
||||
message_update: Update data for the message
|
||||
actor: User performing the action
|
||||
embedding_config: Optional embedding configuration for Turbopuffer
|
||||
strict_mode: If True, wait for embedding update to complete; if False, run in background
|
||||
project_id: Optional project ID for the message (for Turbopuffer indexing)
|
||||
"""
|
||||
@@ -582,7 +567,7 @@ class MessageManager:
|
||||
# update message in turbopuffer if enabled (delete and re-insert)
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages() and embedding_config and pydantic_message.agent_id:
|
||||
if should_use_tpuf_for_messages() and pydantic_message.agent_id:
|
||||
# extract text content from updated message
|
||||
text = self._extract_message_text(pydantic_message)
|
||||
|
||||
@@ -590,51 +575,42 @@ class MessageManager:
|
||||
if text:
|
||||
if strict_mode:
|
||||
# wait for embedding update to complete
|
||||
await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id)
|
||||
await self._update_message_embedding_background(pydantic_message, text, actor, project_id)
|
||||
else:
|
||||
# fire and forget - run embedding update in background
|
||||
fire_and_forget(
|
||||
self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id),
|
||||
self._update_message_embedding_background(pydantic_message, text, actor, project_id),
|
||||
task_name=f"update_message_embedding_{message_id}",
|
||||
)
|
||||
|
||||
return pydantic_message
|
||||
|
||||
async def _update_message_embedding_background(
|
||||
self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser, project_id: Optional[str] = None
|
||||
self, message: PydanticMessage, text: str, actor: PydanticUser, project_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""Background task to update a message's embedding in Turbopuffer.
|
||||
|
||||
Args:
|
||||
message: The updated message
|
||||
text: Extracted text content from the message
|
||||
embedding_config: Embedding configuration
|
||||
actor: User performing the action
|
||||
project_id: Optional project ID for the message
|
||||
"""
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# delete old message from turbopuffer
|
||||
await tpuf_client.delete_messages(agent_id=message.agent_id, organization_id=actor.organization_id, message_ids=[message.id])
|
||||
|
||||
# generate new embedding
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings([text], embedding_config)
|
||||
|
||||
# re-insert with updated content
|
||||
# re-insert with updated content - TurbopufferClient will generate embeddings internally
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=message.agent_id,
|
||||
message_texts=[text],
|
||||
embeddings=embeddings,
|
||||
message_ids=[message.id],
|
||||
organization_id=actor.organization_id,
|
||||
actor=actor,
|
||||
roles=[message.role],
|
||||
created_ats=[message.created_at],
|
||||
project_id=project_id,
|
||||
@@ -1119,13 +1095,11 @@ class MessageManager:
|
||||
agent_id: str,
|
||||
actor: PydanticUser,
|
||||
query_text: Optional[str] = None,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
search_mode: str = "hybrid",
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
limit: int = 50,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
) -> List[Tuple[PydanticMessage, dict]]:
|
||||
"""
|
||||
Search messages using Turbopuffer if enabled, otherwise fall back to SQL search.
|
||||
@@ -1133,14 +1107,12 @@ class MessageManager:
|
||||
Args:
|
||||
agent_id: ID of the agent whose messages to search
|
||||
actor: User performing the search
|
||||
query_text: Text query for full-text search
|
||||
query_embedding: Optional pre-computed embedding for vector search
|
||||
query_text: Text query (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
|
||||
search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "hybrid")
|
||||
roles: Optional list of message roles to filter by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional filter for messages created after this date
|
||||
end_date: Optional filter for messages created on or before this date (inclusive)
|
||||
embedding_config: Optional embedding configuration for generating query embedding
|
||||
|
||||
Returns:
|
||||
List of tuples (message, metadata) where metadata contains relevance scores
|
||||
@@ -1150,36 +1122,12 @@ class MessageManager:
|
||||
# check if we should use turbopuffer
|
||||
if should_use_tpuf_for_messages():
|
||||
try:
|
||||
# generate embedding if needed and not provided
|
||||
if search_mode in ["vector", "hybrid"] and query_embedding is None and query_text:
|
||||
if not embedding_config:
|
||||
# fall back to SQL search if no embedding config
|
||||
logger.warning("No embedding config provided for vector search, falling back to SQL")
|
||||
return await self.list_messages_for_agent_async(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
query_text=query_text,
|
||||
roles=roles,
|
||||
limit=limit,
|
||||
ascending=False,
|
||||
)
|
||||
|
||||
# generate embedding from query text
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
|
||||
query_embedding = embeddings[0]
|
||||
|
||||
# use turbopuffer for search
|
||||
# use turbopuffer for search - TurbopufferClient will generate embeddings internally
|
||||
tpuf_client = TurbopufferClient()
|
||||
results = await tpuf_client.query_messages(
|
||||
results = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=actor.organization_id,
|
||||
query_embedding=query_embedding,
|
||||
actor=actor,
|
||||
query_text=query_text,
|
||||
search_mode=search_mode,
|
||||
top_k=limit,
|
||||
@@ -1255,3 +1203,76 @@ class MessageManager:
|
||||
}
|
||||
message_tuples.append((message, metadata))
|
||||
return message_tuples
|
||||
|
||||
async def search_messages_org_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "hybrid",
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
) -> List[MessageSearchResult]:
|
||||
"""
|
||||
Search messages across entire organization using Turbopuffer.
|
||||
|
||||
Args:
|
||||
actor: User performing the search (must have org access)
|
||||
query_text: Text query for full-text search
|
||||
search_mode: "vector", "fts", or "hybrid" (default: "hybrid")
|
||||
roles: Optional list of message roles to filter by
|
||||
project_id: Optional project ID to filter messages by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional filter for messages created after this date
|
||||
end_date: Optional filter for messages created on or before this date (inclusive)
|
||||
|
||||
Returns:
|
||||
List of MessageSearchResult objects with scoring details
|
||||
|
||||
Raises:
|
||||
ValueError: If message embedding or Turbopuffer is not enabled
|
||||
"""
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
|
||||
# check if turbopuffer is enabled
|
||||
# TODO: extend to non-Turbopuffer in the future.
|
||||
if not should_use_tpuf_for_messages():
|
||||
raise ValueError("Message search requires message embedding, OpenAI, and Turbopuffer to be enabled.")
|
||||
|
||||
# use turbopuffer for search - TurbopufferClient will generate embeddings internally
|
||||
tpuf_client = TurbopufferClient()
|
||||
results = await tpuf_client.query_messages_by_org_id(
|
||||
organization_id=actor.organization_id,
|
||||
actor=actor,
|
||||
query_text=query_text,
|
||||
search_mode=search_mode,
|
||||
top_k=limit,
|
||||
roles=roles,
|
||||
project_id=project_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# convert results to MessageSearchResult objects
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# create message mapping
|
||||
message_ids = [msg_dict["id"] for msg_dict, _, _ in results]
|
||||
messages = await self.get_messages_by_ids_async(message_ids=message_ids, actor=actor)
|
||||
message_mapping = {message.id: message for message in messages}
|
||||
|
||||
# create search results using list comprehension
|
||||
return [
|
||||
MessageSearchResult(
|
||||
message=message_mapping.get(msg_dict["id"]),
|
||||
fts_score=metadata.get("fts_score"),
|
||||
fts_rank=metadata.get("fts_rank"),
|
||||
vector_score=metadata.get("vector_score"),
|
||||
vector_rank=metadata.get("vector_rank"),
|
||||
rrf_score=rrf_score,
|
||||
)
|
||||
for msg_dict, rrf_score, metadata in results
|
||||
]
|
||||
|
||||
@@ -623,12 +623,13 @@ class PassageManager:
|
||||
passage_texts = [p.text for p in passages]
|
||||
|
||||
# Insert to Turbopuffer with the same IDs as SQL
|
||||
# TurbopufferClient will generate embeddings internally using default config
|
||||
await tpuf_client.insert_archival_memories(
|
||||
archive_id=archive.id,
|
||||
text_chunks=passage_texts,
|
||||
embeddings=embeddings,
|
||||
passage_ids=passage_ids, # Use same IDs as SQL
|
||||
organization_id=actor.organization_id,
|
||||
actor=actor,
|
||||
tags=tags,
|
||||
created_at=passages[0].created_at if passages else None,
|
||||
)
|
||||
|
||||
@@ -195,7 +195,6 @@ class Summarizer:
|
||||
await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[summary_message_obj],
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -163,7 +163,6 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
limit=search_limit,
|
||||
start_date=start_datetime,
|
||||
end_date=end_datetime,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
)
|
||||
|
||||
if len(message_results) == 0:
|
||||
|
||||
@@ -233,7 +233,7 @@ class TestTurbopufferIntegration:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turbopuffer_metadata_attributes(self, enable_turbopuffer):
|
||||
async def test_turbopuffer_metadata_attributes(self, default_user, enable_turbopuffer):
|
||||
"""Test that Turbopuffer properly stores and retrieves metadata attributes"""
|
||||
|
||||
# Only run if we have a real API key
|
||||
@@ -273,17 +273,16 @@ class TestTurbopufferIntegration:
|
||||
result = await client.insert_archival_memories(
|
||||
archive_id=archive_id,
|
||||
text_chunks=[d["text"] for d in test_data],
|
||||
embeddings=[d["vector"] for d in test_data],
|
||||
passage_ids=[d["id"] for d in test_data],
|
||||
organization_id="org-123", # Default org
|
||||
actor=default_user,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
# Query all passages (no tag filtering)
|
||||
query_vector = [0.15] * 1536
|
||||
results = await client.query_passages(archive_id=archive_id, query_embedding=query_vector, top_k=10)
|
||||
results = await client.query_passages(archive_id=archive_id, actor=default_user, top_k=10)
|
||||
|
||||
# Should get all passages
|
||||
assert len(results) == 3 # All three passages
|
||||
@@ -339,7 +338,7 @@ class TestTurbopufferIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
|
||||
async def test_hybrid_search_with_real_tpuf(self, enable_turbopuffer):
|
||||
async def test_hybrid_search_with_real_tpuf(self, default_user, enable_turbopuffer):
|
||||
"""Test hybrid search functionality combining vector and full-text search"""
|
||||
|
||||
import uuid
|
||||
@@ -366,13 +365,14 @@ class TestTurbopufferIntegration:
|
||||
|
||||
# Insert passages
|
||||
await client.insert_archival_memories(
|
||||
archive_id=archive_id, text_chunks=texts, embeddings=embeddings, passage_ids=passage_ids, organization_id=org_id
|
||||
archive_id=archive_id, text_chunks=texts, passage_ids=passage_ids, organization_id=org_id, actor=default_user
|
||||
)
|
||||
|
||||
# Test vector-only search
|
||||
vector_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 6.0, 11.0], # similar to second passage embedding
|
||||
actor=default_user,
|
||||
query_text="python programming tutorial",
|
||||
search_mode="vector",
|
||||
top_k=3,
|
||||
)
|
||||
@@ -382,7 +382,7 @@ class TestTurbopufferIntegration:
|
||||
|
||||
# Test FTS-only search
|
||||
fts_results = await client.query_passages(
|
||||
archive_id=archive_id, query_text="Turbopuffer vector database", search_mode="fts", top_k=3
|
||||
archive_id=archive_id, actor=default_user, query_text="Turbopuffer vector database", search_mode="fts", top_k=3
|
||||
)
|
||||
assert 0 < len(fts_results) <= 3
|
||||
# should find passages mentioning Turbopuffer
|
||||
@@ -393,7 +393,7 @@ class TestTurbopufferIntegration:
|
||||
# Test hybrid search
|
||||
hybrid_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[2.0, 7.0, 12.0],
|
||||
actor=default_user,
|
||||
query_text="vector search Turbopuffer",
|
||||
search_mode="hybrid",
|
||||
top_k=3,
|
||||
@@ -412,7 +412,7 @@ class TestTurbopufferIntegration:
|
||||
# Test with different weights
|
||||
vector_heavy_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[0.0, 5.0, 10.0], # very similar to first passage
|
||||
actor=default_user,
|
||||
query_text="quick brown fox", # matches second passage
|
||||
search_mode="hybrid",
|
||||
top_k=3,
|
||||
@@ -423,16 +423,13 @@ class TestTurbopufferIntegration:
|
||||
# all results should have scores
|
||||
assert all(isinstance(score, float) for _, score, _ in vector_heavy_results)
|
||||
|
||||
# Test error handling - missing text for hybrid mode (embedding provided but text missing)
|
||||
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
|
||||
await client.query_passages(archive_id=archive_id, query_embedding=[1.0, 2.0, 3.0], search_mode="hybrid", top_k=3)
|
||||
|
||||
# Test error handling - missing embedding for hybrid mode (text provided but embedding missing)
|
||||
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
|
||||
await client.query_passages(archive_id=archive_id, query_text="test", search_mode="hybrid", top_k=3)
|
||||
# Test with different search modes
|
||||
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="vector", top_k=3)
|
||||
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="fts", top_k=3)
|
||||
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="hybrid", top_k=3)
|
||||
|
||||
# Test explicit timestamp mode
|
||||
timestamp_results = await client.query_passages(archive_id=archive_id, search_mode="timestamp", top_k=3)
|
||||
timestamp_results = await client.query_passages(archive_id=archive_id, actor=default_user, search_mode="timestamp", top_k=3)
|
||||
assert len(timestamp_results) <= 3
|
||||
# Should return passages ordered by timestamp (most recent first)
|
||||
assert all(isinstance(passage, Passage) for passage, _, _ in timestamp_results)
|
||||
@@ -446,7 +443,7 @@ class TestTurbopufferIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
|
||||
async def test_tag_filtering_with_real_tpuf(self, enable_turbopuffer):
|
||||
async def test_tag_filtering_with_real_tpuf(self, default_user, enable_turbopuffer):
|
||||
"""Test tag filtering functionality with AND and OR logic"""
|
||||
|
||||
import uuid
|
||||
@@ -479,13 +476,13 @@ class TestTurbopufferIntegration:
|
||||
passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts]
|
||||
|
||||
# Insert passages with tags
|
||||
for i, (text, tags, embedding, passage_id) in enumerate(zip(texts, tag_sets, embeddings, passage_ids)):
|
||||
for i, (text, tags, passage_id) in enumerate(zip(texts, tag_sets, passage_ids)):
|
||||
await client.insert_archival_memories(
|
||||
archive_id=archive_id,
|
||||
text_chunks=[text],
|
||||
embeddings=[embedding],
|
||||
passage_ids=[passage_id],
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
tags=tags,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -493,7 +490,8 @@ class TestTurbopufferIntegration:
|
||||
# Test tag filtering with "any" mode (should find passages with any of the specified tags)
|
||||
python_any_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 6.0, 11.0],
|
||||
actor=default_user,
|
||||
query_text="python programming",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
tags=["python"],
|
||||
@@ -511,7 +509,8 @@ class TestTurbopufferIntegration:
|
||||
# Test tag filtering with "all" mode
|
||||
python_tutorial_all_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 6.0, 11.0],
|
||||
actor=default_user,
|
||||
query_text="python tutorial",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
tags=["python", "tutorial"],
|
||||
@@ -528,6 +527,7 @@ class TestTurbopufferIntegration:
|
||||
# Test tag filtering with FTS mode
|
||||
js_fts_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
actor=default_user,
|
||||
query_text="javascript",
|
||||
search_mode="fts",
|
||||
top_k=10,
|
||||
@@ -545,7 +545,7 @@ class TestTurbopufferIntegration:
|
||||
# Test hybrid search with tags
|
||||
python_hybrid_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[2.0, 7.0, 12.0],
|
||||
actor=default_user,
|
||||
query_text="python programming",
|
||||
search_mode="hybrid",
|
||||
top_k=10,
|
||||
@@ -569,7 +569,7 @@ class TestTurbopufferIntegration:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_filtering_with_real_tpuf(self, enable_turbopuffer):
|
||||
async def test_temporal_filtering_with_real_tpuf(self, default_user, enable_turbopuffer):
|
||||
"""Test temporal filtering with date ranges"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
@@ -601,15 +601,14 @@ class TestTurbopufferIntegration:
|
||||
# We need to generate embeddings for the passages
|
||||
# For testing, we'll use simple dummy embeddings
|
||||
for text, timestamp in test_passages:
|
||||
dummy_embedding = [1.0, 2.0, 3.0] # Simple test embedding
|
||||
passage_id = f"passage-{uuid.uuid4()}"
|
||||
|
||||
await client.insert_archival_memories(
|
||||
archive_id=archive_id,
|
||||
text_chunks=[text],
|
||||
embeddings=[dummy_embedding],
|
||||
passage_ids=[passage_id],
|
||||
organization_id="test-org",
|
||||
actor=default_user,
|
||||
created_at=timestamp,
|
||||
)
|
||||
|
||||
@@ -617,7 +616,8 @@ class TestTurbopufferIntegration:
|
||||
three_days_ago = now - timedelta(days=3)
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="meeting notes",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
start_date=three_days_ago,
|
||||
@@ -637,7 +637,8 @@ class TestTurbopufferIntegration:
|
||||
two_weeks_ago = now - timedelta(days=14)
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="meeting notes",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
start_date=two_weeks_ago,
|
||||
@@ -652,7 +653,8 @@ class TestTurbopufferIntegration:
|
||||
# Test 3: Query with only end_date (everything before yesterday)
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="meeting notes",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
end_date=yesterday + timedelta(hours=12), # Middle of yesterday
|
||||
@@ -667,6 +669,7 @@ class TestTurbopufferIntegration:
|
||||
# Test 4: Test with FTS mode and date filtering
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
actor=default_user,
|
||||
query_text="meeting notes project",
|
||||
search_mode="fts",
|
||||
top_k=10,
|
||||
@@ -682,7 +685,7 @@ class TestTurbopufferIntegration:
|
||||
# Test 5: Test with hybrid mode and date filtering
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="sprint review",
|
||||
search_mode="hybrid",
|
||||
top_k=10,
|
||||
@@ -934,11 +937,9 @@ class TestTurbopufferMessagesIntegration:
|
||||
),
|
||||
]
|
||||
|
||||
# Create messages without embedding_config
|
||||
created = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=messages,
|
||||
actor=default_user,
|
||||
embedding_config=None, # No config provided
|
||||
)
|
||||
|
||||
assert len(created) == 2
|
||||
@@ -1057,7 +1058,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test actual message embedding and storage in Turbopuffer"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -1087,9 +1088,9 @@ class TestTurbopufferMessagesIntegration:
|
||||
success = await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
@@ -1097,11 +1098,8 @@ class TestTurbopufferMessagesIntegration:
|
||||
assert success == True
|
||||
|
||||
# Verify we can query the messages
|
||||
results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, actor=default_user
|
||||
)
|
||||
|
||||
assert len(results) == 3
|
||||
@@ -1121,7 +1119,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test vector search on messages in Turbopuffer"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -1145,29 +1143,23 @@ class TestTurbopufferMessagesIntegration:
|
||||
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
|
||||
|
||||
# Create embeddings that reflect content similarity
|
||||
embeddings = [
|
||||
[1.0, 0.0, 0.0], # Python programming
|
||||
[0.0, 1.0, 0.0], # JavaScript web
|
||||
[0.8, 0.0, 0.2], # ML with Python (similar to first)
|
||||
]
|
||||
|
||||
# Insert messages
|
||||
await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
|
||||
# Search for Python-related messages using vector search
|
||||
query_embedding = [0.9, 0.0, 0.1] # Similar to Python messages
|
||||
results = await client.query_messages(
|
||||
results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
query_embedding=query_embedding,
|
||||
actor=default_user,
|
||||
query_text="Python programming",
|
||||
search_mode="vector",
|
||||
top_k=2,
|
||||
)
|
||||
@@ -1187,7 +1179,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test hybrid search combining vector and FTS for messages"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -1211,30 +1203,22 @@ class TestTurbopufferMessagesIntegration:
|
||||
roles = [MessageRole.assistant] * len(message_texts)
|
||||
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
|
||||
|
||||
# Embeddings
|
||||
embeddings = [
|
||||
[0.1, 0.9, 0.0], # fox text
|
||||
[0.9, 0.1, 0.0], # ML algorithms
|
||||
[0.5, 0.5, 0.0], # Quick Python
|
||||
[0.8, 0.2, 0.0], # Deep learning
|
||||
]
|
||||
|
||||
# Insert messages
|
||||
await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
|
||||
# Hybrid search - vector similar to ML but text contains "quick"
|
||||
results = await client.query_messages(
|
||||
# Hybrid search - text search for "quick"
|
||||
results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
query_embedding=[0.7, 0.3, 0.0], # Similar to ML messages
|
||||
actor=default_user,
|
||||
query_text="quick", # Text search for "quick"
|
||||
search_mode="hybrid",
|
||||
top_k=3,
|
||||
@@ -1257,7 +1241,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test filtering messages by role"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -1283,26 +1267,21 @@ class TestTurbopufferMessagesIntegration:
|
||||
roles = [role for _, role in message_data]
|
||||
message_ids = [str(uuid.uuid4()) for _ in message_texts]
|
||||
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
|
||||
embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]
|
||||
|
||||
# Insert messages
|
||||
await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
|
||||
# Query only user messages
|
||||
user_results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
roles=[MessageRole.user],
|
||||
user_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, roles=[MessageRole.user], actor=default_user
|
||||
)
|
||||
|
||||
assert len(user_results) == 2
|
||||
@@ -1311,12 +1290,13 @@ class TestTurbopufferMessagesIntegration:
|
||||
assert msg["text"] in ["I need help with Python", "Can you explain this?"]
|
||||
|
||||
# Query assistant and system messages
|
||||
non_user_results = await client.query_messages(
|
||||
non_user_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
roles=[MessageRole.assistant, MessageRole.system],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(non_user_results) == 3
|
||||
@@ -1395,7 +1375,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
|
||||
@@ -1409,7 +1388,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Python",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(python_results) > 0
|
||||
assert any(msg.id == message_id for msg, metadata in python_results)
|
||||
@@ -1419,7 +1397,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated content about JavaScript development"),
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
|
||||
@@ -1432,7 +1409,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Python",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# Should either find no results or results that don't include our message
|
||||
assert not any(msg.id == message_id for msg, metadata in python_results_after)
|
||||
@@ -1444,7 +1420,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="JavaScript",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(js_results) > 0
|
||||
assert any(msg.id == message_id for msg, metadata in js_results)
|
||||
@@ -1497,7 +1472,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
agent_a_messages.extend(msgs)
|
||||
@@ -1514,7 +1488,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
agent_b_messages.extend(msgs)
|
||||
@@ -1526,7 +1499,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Agent A",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_a_search) == 5
|
||||
|
||||
@@ -1536,7 +1508,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Agent B",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_b_search) == 3
|
||||
|
||||
@@ -1559,7 +1530,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Agent A",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_a_final) == 2
|
||||
# Verify the remaining messages are the correct ones
|
||||
@@ -1574,7 +1544,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Agent B",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_b_final) == 0
|
||||
|
||||
@@ -1583,84 +1552,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
await server.agent_manager.delete_agent_async(agent_a.id, default_user)
|
||||
await server.agent_manager.delete_agent_async(agent_b.id, default_user)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_crud_operations_without_embedding_config(self, server, default_user, sarah_agent, enable_message_embedding):
|
||||
"""Test that CRUD operations handle missing embedding_config gracefully"""
|
||||
from letta.schemas.message import MessageUpdate
|
||||
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create message WITH embedding_config
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Message with searchable content about databases")],
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
message_id = messages[0].id
|
||||
|
||||
# Verify message is searchable initially
|
||||
initial_search = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(initial_search) > 0
|
||||
assert any(msg.id == message_id for msg, metadata in initial_search)
|
||||
|
||||
# Update message WITHOUT embedding_config - should update postgres but not turbopuffer
|
||||
updated_message = await server.message_manager.update_message_by_id_async(
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated content about algorithms"),
|
||||
actor=default_user,
|
||||
embedding_config=None, # No config provided
|
||||
)
|
||||
|
||||
# Verify postgres was updated
|
||||
assert updated_message.id == message_id
|
||||
updated_text = server.message_manager._extract_message_text(updated_message)
|
||||
assert "algorithms" in updated_text
|
||||
assert "databases" not in updated_text
|
||||
|
||||
# Original search term should STILL find the message (turbopuffer wasn't updated)
|
||||
still_searchable = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(still_searchable) > 0
|
||||
assert any(msg.id == message_id for msg, metadata in still_searchable)
|
||||
|
||||
# New content should NOT be searchable (wasn't re-indexed)
|
||||
not_searchable = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="algorithms",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# Should either find no results or results that don't include our message
|
||||
assert not any(msg.id == message_id for msg, metadata in not_searchable)
|
||||
|
||||
# Clean up
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_turbopuffer_failure_does_not_break_postgres(self, server, default_user, sarah_agent, enable_message_embedding):
|
||||
@@ -1681,7 +1572,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -1702,7 +1592,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated despite turbopuffer failure"),
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=False, # Don't fail on turbopuffer errors - that's what we're testing!
|
||||
)
|
||||
|
||||
@@ -1722,7 +1611,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=None, # Create without embedding to avoid mock issues
|
||||
)
|
||||
message_to_delete_id = messages2[0].id
|
||||
|
||||
@@ -1741,7 +1629,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=False)
|
||||
|
||||
async def wait_for_embedding(
|
||||
self, agent_id: str, message_id: str, organization_id: str, max_wait: float = 10.0, poll_interval: float = 0.5
|
||||
self, agent_id: str, message_id: str, organization_id: str, actor, max_wait: float = 10.0, poll_interval: float = 0.5
|
||||
) -> bool:
|
||||
"""Poll Turbopuffer directly to check if a message has been embedded.
|
||||
|
||||
@@ -1765,9 +1653,10 @@ class TestTurbopufferMessagesIntegration:
|
||||
while asyncio.get_event_loop().time() - start_time < max_wait:
|
||||
try:
|
||||
# Query Turbopuffer directly using timestamp mode to get all messages
|
||||
results = await client.query_messages(
|
||||
results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=organization_id,
|
||||
actor=actor,
|
||||
search_mode="timestamp",
|
||||
top_k=100, # Get more messages to ensure we find it
|
||||
)
|
||||
@@ -1800,7 +1689,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=False, # Background mode
|
||||
)
|
||||
|
||||
@@ -1814,7 +1702,12 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
# Poll for embedding completion by querying Turbopuffer directly
|
||||
embedded = await self.wait_for_embedding(
|
||||
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
|
||||
agent_id=sarah_agent.id,
|
||||
message_id=message_id,
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
max_wait=10.0,
|
||||
poll_interval=0.5,
|
||||
)
|
||||
assert embedded, "Message was not embedded in Turbopuffer within timeout"
|
||||
|
||||
@@ -1825,7 +1718,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Python programming",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(search_results) > 0
|
||||
assert any(msg.id == message_id for msg, _ in search_results)
|
||||
@@ -1851,7 +1743,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True, # Ensure initial embedding
|
||||
)
|
||||
|
||||
@@ -1865,7 +1756,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert any(msg.id == message_id for msg, _ in initial_results)
|
||||
|
||||
@@ -1874,7 +1764,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated content about machine learning"),
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=False, # Background mode
|
||||
)
|
||||
|
||||
@@ -1890,7 +1779,12 @@ class TestTurbopufferMessagesIntegration:
|
||||
# Poll for the update to be reflected in Turbopuffer
|
||||
# We check by searching for the new content
|
||||
embedded = await self.wait_for_embedding(
|
||||
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
|
||||
agent_id=sarah_agent.id,
|
||||
message_id=message_id,
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
max_wait=10.0,
|
||||
poll_interval=0.5,
|
||||
)
|
||||
assert embedded, "Updated message was not re-embedded within timeout"
|
||||
|
||||
@@ -1901,7 +1795,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="machine learning",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert any(msg.id == message_id for msg, _ in new_results)
|
||||
|
||||
@@ -1914,7 +1807,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# The message shouldn't match the old search term anymore
|
||||
if len(old_results) > 0:
|
||||
@@ -1929,7 +1821,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test filtering messages by date range"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -1959,21 +1851,17 @@ class TestTurbopufferMessagesIntegration:
|
||||
await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=[text],
|
||||
embeddings=[[1.0, 2.0, 3.0]],
|
||||
message_ids=[str(uuid.uuid4())],
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=[MessageRole.assistant],
|
||||
created_ats=[timestamp],
|
||||
)
|
||||
|
||||
# Query messages from the last 3 days
|
||||
three_days_ago = now - timedelta(days=3)
|
||||
recent_results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
start_date=three_days_ago,
|
||||
recent_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, start_date=three_days_ago, actor=default_user
|
||||
)
|
||||
|
||||
# Should get today's and yesterday's messages
|
||||
@@ -1984,13 +1872,14 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
# Query messages between 2 weeks ago and 1 week ago
|
||||
two_weeks_ago = now - timedelta(days=14)
|
||||
week_results = await client.query_messages(
|
||||
week_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
start_date=two_weeks_ago,
|
||||
end_date=last_week + timedelta(days=1), # Include last week's message
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should get only last week's message
|
||||
@@ -1998,10 +1887,11 @@ class TestTurbopufferMessagesIntegration:
|
||||
assert week_results[0][0]["text"] == "Last week's message"
|
||||
|
||||
# Query with vector search and date filtering
|
||||
filtered_vector_results = await client.query_messages(
|
||||
filtered_vector_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="message",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
start_date=three_days_ago,
|
||||
@@ -2101,7 +1991,7 @@ class TestNamespaceTracking:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_project_id_filtering(self, server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding):
|
||||
"""Test that project_id filtering works correctly in query_messages"""
|
||||
"""Test that project_id filtering works correctly in query_messages_by_agent_id"""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
|
||||
# Create two project IDs
|
||||
@@ -2124,24 +2014,15 @@ class TestNamespaceTracking:
|
||||
# Insert messages with their respective project IDs
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# Generate embeddings
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=sarah_agent.embedding_config.embedding_endpoint_type,
|
||||
actor=default_user,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(
|
||||
[message_a.content[0].text, message_b.content[0].text], sarah_agent.embedding_config
|
||||
)
|
||||
# Embeddings will be generated automatically by the client
|
||||
|
||||
# Insert message A with project_a_id
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
message_texts=[message_a.content[0].text],
|
||||
embeddings=[embeddings[0]],
|
||||
message_ids=[message_a.id],
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
roles=[message_a.role],
|
||||
created_ats=[message_a.created_at],
|
||||
project_id=project_a_id,
|
||||
@@ -2151,9 +2032,9 @@ class TestNamespaceTracking:
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
message_texts=[message_b.content[0].text],
|
||||
embeddings=[embeddings[1]],
|
||||
message_ids=[message_b.id],
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
roles=[message_b.role],
|
||||
created_ats=[message_b.created_at],
|
||||
project_id=project_b_id,
|
||||
@@ -2162,12 +2043,13 @@ class TestNamespaceTracking:
|
||||
# Poll for message A with project_a_id filter
|
||||
max_retries = 10
|
||||
for i in range(max_retries):
|
||||
results_a = await tpuf_client.query_messages(
|
||||
results_a = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp", # Simple timestamp retrieval
|
||||
top_k=10,
|
||||
project_id=project_a_id,
|
||||
actor=default_user,
|
||||
)
|
||||
if len(results_a) == 1 and results_a[0][0]["id"] == message_a.id:
|
||||
break
|
||||
@@ -2179,12 +2061,13 @@ class TestNamespaceTracking:
|
||||
|
||||
# Poll for message B with project_b_id filter
|
||||
for i in range(max_retries):
|
||||
results_b = await tpuf_client.query_messages(
|
||||
results_b = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
project_id=project_b_id,
|
||||
actor=default_user,
|
||||
)
|
||||
if len(results_b) == 1 and results_b[0][0]["id"] == message_b.id:
|
||||
break
|
||||
@@ -2195,12 +2078,13 @@ class TestNamespaceTracking:
|
||||
assert "JavaScript" in results_b[0][0]["text"]
|
||||
|
||||
# Query without project filter - should find both
|
||||
results_all = await tpuf_client.query_messages(
|
||||
results_all = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
project_id=None, # No filter
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(results_all) >= 2 # May have other messages from setup
|
||||
|
||||
Reference in New Issue
Block a user