feat: Embed all messages in turbopuffer [LET-4143] (#4352)
* wip * Finish embedding * Fix ruff and tests
This commit is contained in:
@@ -2,10 +2,10 @@
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import TagMatchMode
|
||||
from letta.schemas.enums import MessageRole, TagMatchMode
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.settings import settings
|
||||
|
||||
@@ -16,6 +16,11 @@ def should_use_tpuf() -> bool:
|
||||
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)
|
||||
|
||||
|
||||
def should_use_tpuf_for_messages() -> bool:
|
||||
"""Check if Turbopuffer should be used for messages."""
|
||||
return should_use_tpuf() and bool(settings.embed_all_messages)
|
||||
|
||||
|
||||
class TurbopufferClient:
|
||||
"""Client for managing archival memory with Turbopuffer vector database."""
|
||||
|
||||
@@ -28,7 +33,7 @@ class TurbopufferClient:
|
||||
raise ValueError("Turbopuffer API key not provided")
|
||||
|
||||
@trace_method
|
||||
def _get_namespace_name(self, archive_id: str) -> str:
|
||||
def _get_archive_namespace_name(self, archive_id: str) -> str:
|
||||
"""Get namespace name for a specific archive."""
|
||||
# use archive_id as namespace to isolate different archives' memories
|
||||
# append environment suffix to namespace for isolation if environment is set
|
||||
@@ -39,6 +44,18 @@ class TurbopufferClient:
|
||||
namespace_name = archive_id
|
||||
return namespace_name
|
||||
|
||||
@trace_method
|
||||
def _get_message_namespace_name(self, agent_id: str) -> str:
|
||||
"""Get namespace name for a specific agent's messages."""
|
||||
# use agent_id as namespace to isolate different agents' messages
|
||||
# append environment suffix to namespace for isolation if environment is set
|
||||
environment = settings.environment
|
||||
if environment:
|
||||
namespace_name = f"messages_{agent_id}_{environment.lower()}"
|
||||
else:
|
||||
namespace_name = f"messages_{agent_id}"
|
||||
return namespace_name
|
||||
|
||||
@trace_method
|
||||
async def insert_archival_memories(
|
||||
self,
|
||||
@@ -66,7 +83,7 @@ class TurbopufferClient:
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = self._get_namespace_name(archive_id)
|
||||
namespace_name = self._get_archive_namespace_name(archive_id)
|
||||
|
||||
# handle timestamp - ensure UTC
|
||||
if created_at is None:
|
||||
@@ -155,6 +172,212 @@ class TurbopufferClient:
|
||||
logger.error("Duplicate passage IDs detected in batch")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def insert_messages(
|
||||
self,
|
||||
agent_id: str,
|
||||
message_texts: List[str],
|
||||
embeddings: List[List[float]],
|
||||
message_ids: List[str],
|
||||
organization_id: str,
|
||||
roles: List[MessageRole],
|
||||
created_ats: List[datetime],
|
||||
) -> bool:
|
||||
"""Insert messages into Turbopuffer.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
message_texts: List of message text content to store
|
||||
embeddings: List of embedding vectors corresponding to message texts
|
||||
message_ids: List of message IDs (must match 1:1 with message_texts)
|
||||
organization_id: Organization ID for the messages
|
||||
roles: List of message roles corresponding to each message
|
||||
created_ats: List of creation timestamps for each message
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = self._get_message_namespace_name(agent_id)
|
||||
|
||||
# validation checks
|
||||
if not message_ids:
|
||||
raise ValueError("message_ids must be provided for Turbopuffer insertion")
|
||||
if len(message_ids) != len(message_texts):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
|
||||
if len(message_ids) != len(embeddings):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match embeddings length ({len(embeddings)})")
|
||||
if len(message_ids) != len(roles):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
|
||||
if len(message_ids) != len(created_ats):
|
||||
raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
|
||||
|
||||
# prepare column-based data for turbopuffer - optimized for batch insert
|
||||
ids = []
|
||||
vectors = []
|
||||
texts = []
|
||||
organization_ids = []
|
||||
agent_ids = []
|
||||
message_roles = []
|
||||
created_at_timestamps = []
|
||||
|
||||
for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)):
|
||||
message_id = message_ids[idx]
|
||||
|
||||
# ensure the provided timestamp is timezone-aware and in UTC
|
||||
if created_at.tzinfo is None:
|
||||
# assume UTC if no timezone provided
|
||||
timestamp = created_at.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# convert to UTC if in different timezone
|
||||
timestamp = created_at.astimezone(timezone.utc)
|
||||
|
||||
# append to columns
|
||||
ids.append(message_id)
|
||||
vectors.append(embedding)
|
||||
texts.append(text)
|
||||
organization_ids.append(organization_id)
|
||||
agent_ids.append(agent_id)
|
||||
message_roles.append(role.value)
|
||||
created_at_timestamps.append(timestamp)
|
||||
|
||||
# build column-based upsert data
|
||||
upsert_columns = {
|
||||
"id": ids,
|
||||
"vector": vectors,
|
||||
"text": texts,
|
||||
"organization_id": organization_ids,
|
||||
"agent_id": agent_ids,
|
||||
"role": message_roles,
|
||||
"created_at": created_at_timestamps,
|
||||
}
|
||||
|
||||
try:
|
||||
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
# turbopuffer recommends column-based writes for performance
|
||||
await namespace.write(
|
||||
upsert_columns=upsert_columns,
|
||||
distance_metric="cosine_distance",
|
||||
schema={"text": {"type": "string", "full_text_search": True}},
|
||||
)
|
||||
logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert messages to Turbopuffer: {e}")
|
||||
# check if it's a duplicate ID error
|
||||
if "duplicate" in str(e).lower():
|
||||
logger.error("Duplicate message IDs detected in batch")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def _execute_query(
|
||||
self,
|
||||
namespace_name: str,
|
||||
search_mode: str,
|
||||
query_embedding: Optional[List[float]],
|
||||
query_text: Optional[str],
|
||||
top_k: int,
|
||||
include_attributes: List[str],
|
||||
filters: Optional[Any] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
) -> Any:
|
||||
"""Generic query execution for Turbopuffer.
|
||||
|
||||
Args:
|
||||
namespace_name: Turbopuffer namespace to query
|
||||
search_mode: "vector", "fts", "hybrid", or "timestamp"
|
||||
query_embedding: Embedding for vector search
|
||||
query_text: Text for full-text search
|
||||
top_k: Number of results to return
|
||||
include_attributes: Attributes to include in results
|
||||
filters: Turbopuffer filter expression
|
||||
vector_weight: Weight for vector search in hybrid mode
|
||||
fts_weight: Weight for FTS in hybrid mode
|
||||
|
||||
Returns:
|
||||
Raw Turbopuffer query results or multi-query response
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
from turbopuffer.types import QueryParam
|
||||
|
||||
# validate inputs based on search mode
|
||||
if search_mode == "vector" and query_embedding is None:
|
||||
raise ValueError("query_embedding is required for vector search mode")
|
||||
if search_mode == "fts" and query_text is None:
|
||||
raise ValueError("query_text is required for FTS search mode")
|
||||
if search_mode == "hybrid":
|
||||
if query_embedding is None or query_text is None:
|
||||
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
|
||||
if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
|
||||
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
|
||||
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
|
||||
if search_mode == "timestamp":
|
||||
# retrieve most recent items by timestamp
|
||||
query_params = {
|
||||
"rank_by": ("created_at", "desc"),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
query_params["filters"] = filters
|
||||
return await namespace.query(**query_params)
|
||||
|
||||
elif search_mode == "vector":
|
||||
# vector search query
|
||||
query_params = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
query_params["filters"] = filters
|
||||
return await namespace.query(**query_params)
|
||||
|
||||
elif search_mode == "fts":
|
||||
# full-text search query
|
||||
query_params = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
query_params["filters"] = filters
|
||||
return await namespace.query(**query_params)
|
||||
|
||||
else: # hybrid mode
|
||||
queries = []
|
||||
|
||||
# vector search query
|
||||
vector_query = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
vector_query["filters"] = filters
|
||||
queries.append(vector_query)
|
||||
|
||||
# full-text search query
|
||||
fts_query = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": include_attributes,
|
||||
}
|
||||
if filters:
|
||||
fts_query["filters"] = filters
|
||||
queries.append(fts_query)
|
||||
|
||||
# execute multi-query
|
||||
return await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
|
||||
|
||||
@trace_method
|
||||
async def query_passages(
|
||||
self,
|
||||
@@ -188,147 +411,213 @@ class TurbopufferClient:
|
||||
Returns:
|
||||
List of (passage, score) tuples
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
from turbopuffer.types import QueryParam
|
||||
|
||||
# validate inputs based on search mode first
|
||||
if search_mode == "vector" and query_embedding is None:
|
||||
raise ValueError("query_embedding is required for vector search mode")
|
||||
if search_mode == "fts" and query_text is None:
|
||||
raise ValueError("query_text is required for FTS search mode")
|
||||
if search_mode == "hybrid":
|
||||
if query_embedding is None or query_text is None:
|
||||
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
|
||||
if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
|
||||
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
|
||||
|
||||
# Check if we should fallback to timestamp-based retrieval
|
||||
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
|
||||
# Fallback to retrieving most recent passages when no search query is provided
|
||||
search_mode = "timestamp"
|
||||
|
||||
namespace_name = self._get_namespace_name(archive_id)
|
||||
namespace_name = self._get_archive_namespace_name(archive_id)
|
||||
|
||||
# build tag filter conditions
|
||||
tag_filter = None
|
||||
if tags:
|
||||
if tag_match_mode == TagMatchMode.ALL:
|
||||
# For ALL mode, need to check each tag individually with Contains
|
||||
tag_conditions = []
|
||||
for tag in tags:
|
||||
tag_conditions.append(("tags", "Contains", tag))
|
||||
if len(tag_conditions) == 1:
|
||||
tag_filter = tag_conditions[0]
|
||||
else:
|
||||
tag_filter = ("And", tag_conditions)
|
||||
else: # tag_match_mode == TagMatchMode.ANY
|
||||
# For ANY mode, use ContainsAny to match any of the tags
|
||||
tag_filter = ("tags", "ContainsAny", tags)
|
||||
|
||||
# build date filter conditions
|
||||
date_filters = []
|
||||
if start_date:
|
||||
date_filters.append(("created_at", "Gte", start_date))
|
||||
if end_date:
|
||||
date_filters.append(("created_at", "Lte", end_date))
|
||||
|
||||
# combine all filters
|
||||
all_filters = []
|
||||
if tag_filter:
|
||||
all_filters.append(tag_filter)
|
||||
if date_filters:
|
||||
all_filters.extend(date_filters)
|
||||
|
||||
# create final filter expression
|
||||
final_filter = None
|
||||
if len(all_filters) == 1:
|
||||
final_filter = all_filters[0]
|
||||
elif len(all_filters) > 1:
|
||||
final_filter = ("And", all_filters)
|
||||
|
||||
try:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
# use generic query executor
|
||||
result = await self._execute_query(
|
||||
namespace_name=namespace_name,
|
||||
search_mode=search_mode,
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text,
|
||||
top_k=top_k,
|
||||
include_attributes=["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
filters=final_filter,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
)
|
||||
|
||||
# build tag filter conditions
|
||||
tag_filter = None
|
||||
if tags:
|
||||
if tag_match_mode == TagMatchMode.ALL:
|
||||
# For ALL mode, need to check each tag individually with Contains
|
||||
tag_conditions = []
|
||||
for tag in tags:
|
||||
tag_conditions.append(("tags", "Contains", tag))
|
||||
if len(tag_conditions) == 1:
|
||||
tag_filter = tag_conditions[0]
|
||||
else:
|
||||
tag_filter = ("And", tag_conditions)
|
||||
else: # tag_match_mode == TagMatchMode.ANY
|
||||
# For ANY mode, use ContainsAny to match any of the tags
|
||||
tag_filter = ("tags", "ContainsAny", tags)
|
||||
|
||||
# build date filter conditions
|
||||
date_filters = []
|
||||
if start_date:
|
||||
# Turbopuffer expects datetime objects directly for comparison
|
||||
date_filters.append(("created_at", "Gte", start_date))
|
||||
if end_date:
|
||||
# Turbopuffer expects datetime objects directly for comparison
|
||||
date_filters.append(("created_at", "Lte", end_date))
|
||||
|
||||
# combine all filters
|
||||
all_filters = []
|
||||
if tag_filter:
|
||||
all_filters.append(tag_filter)
|
||||
if date_filters:
|
||||
all_filters.extend(date_filters)
|
||||
|
||||
# create final filter expression
|
||||
final_filter = None
|
||||
if len(all_filters) == 1:
|
||||
final_filter = all_filters[0]
|
||||
elif len(all_filters) > 1:
|
||||
final_filter = ("And", all_filters)
|
||||
|
||||
if search_mode == "timestamp":
|
||||
# Fallback: retrieve most recent passages by timestamp
|
||||
query_params = {
|
||||
"rank_by": ("created_at", "desc"), # Order by created_at in descending order
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if final_filter:
|
||||
query_params["filters"] = final_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, tags)
|
||||
|
||||
elif search_mode == "vector":
|
||||
# single vector search query
|
||||
query_params = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if final_filter:
|
||||
query_params["filters"] = final_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, tags)
|
||||
|
||||
elif search_mode == "fts":
|
||||
# single full-text search query
|
||||
query_params = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if final_filter:
|
||||
query_params["filters"] = final_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, tags, is_fts=True)
|
||||
|
||||
else: # hybrid mode
|
||||
# multi-query for both vector and FTS
|
||||
queries = []
|
||||
|
||||
# vector search query
|
||||
vector_query = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if final_filter:
|
||||
vector_query["filters"] = final_filter
|
||||
queries.append(vector_query)
|
||||
|
||||
# full-text search query
|
||||
fts_query = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if final_filter:
|
||||
fts_query["filters"] = final_filter
|
||||
queries.append(fts_query)
|
||||
|
||||
# execute multi-query
|
||||
response = await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
|
||||
|
||||
# process and combine results using reciprocal rank fusion
|
||||
vector_results = self._process_single_query_results(response.results[0], archive_id, tags)
|
||||
fts_results = self._process_single_query_results(response.results[1], archive_id, tags, is_fts=True)
|
||||
|
||||
# combine results using reciprocal rank fusion
|
||||
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
|
||||
# process results based on search mode
|
||||
if search_mode == "hybrid":
|
||||
# for hybrid mode, we get a multi-query response
|
||||
vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
|
||||
fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
|
||||
# use backwards-compatible wrapper which calls generic RRF
|
||||
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
|
||||
else:
|
||||
# for single queries (vector, fts, timestamp)
|
||||
is_fts = search_mode == "fts"
|
||||
return self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query passages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def query_messages(
|
||||
self,
|
||||
agent_id: str,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
|
||||
top_k: int = 10,
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
) -> List[Tuple[dict, float]]:
|
||||
"""Query messages from Turbopuffer using vector search, full-text search, or hybrid search.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
|
||||
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
|
||||
search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector")
|
||||
top_k: Number of results to return
|
||||
roles: Optional list of message roles to filter by
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime to filter messages created after this date
|
||||
end_date: Optional datetime to filter messages created before this date
|
||||
|
||||
Returns:
|
||||
List of (message_dict, score) tuples where message_dict contains id, text, role, created_at
|
||||
"""
|
||||
# Check if we should fallback to timestamp-based retrieval
|
||||
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
|
||||
# Fallback to retrieving most recent messages when no search query is provided
|
||||
search_mode = "timestamp"
|
||||
|
||||
namespace_name = self._get_message_namespace_name(agent_id)
|
||||
|
||||
# build role filter conditions
|
||||
role_filter = None
|
||||
if roles:
|
||||
role_values = [r.value for r in roles]
|
||||
if len(role_values) == 1:
|
||||
role_filter = ("role", "Eq", role_values[0])
|
||||
else:
|
||||
role_filter = ("role", "In", role_values)
|
||||
|
||||
# build date filter conditions
|
||||
date_filters = []
|
||||
if start_date:
|
||||
date_filters.append(("created_at", "Gte", start_date))
|
||||
if end_date:
|
||||
date_filters.append(("created_at", "Lte", end_date))
|
||||
|
||||
# combine all filters
|
||||
all_filters = []
|
||||
if role_filter:
|
||||
all_filters.append(role_filter)
|
||||
if date_filters:
|
||||
all_filters.extend(date_filters)
|
||||
|
||||
# create final filter expression
|
||||
final_filter = None
|
||||
if len(all_filters) == 1:
|
||||
final_filter = all_filters[0]
|
||||
elif len(all_filters) > 1:
|
||||
final_filter = ("And", all_filters)
|
||||
|
||||
try:
|
||||
# use generic query executor
|
||||
result = await self._execute_query(
|
||||
namespace_name=namespace_name,
|
||||
search_mode=search_mode,
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text,
|
||||
top_k=top_k,
|
||||
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
|
||||
filters=final_filter,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
)
|
||||
|
||||
# process results based on search mode
|
||||
if search_mode == "hybrid":
|
||||
# for hybrid mode, we get a multi-query response
|
||||
vector_results = self._process_message_query_results(result.results[0])
|
||||
fts_results = self._process_message_query_results(result.results[1], is_fts=True)
|
||||
# use generic RRF with lambda to extract ID from dict
|
||||
return self._generic_reciprocal_rank_fusion(
|
||||
vector_results=vector_results,
|
||||
fts_results=fts_results,
|
||||
get_id_func=lambda msg_dict: msg_dict["id"],
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
top_k=top_k,
|
||||
)
|
||||
else:
|
||||
# for single queries (vector, fts, timestamp)
|
||||
is_fts = search_mode == "fts"
|
||||
return self._process_message_query_results(result, is_fts=is_fts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query messages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
def _process_message_query_results(self, result, is_fts: bool = False) -> List[Tuple[dict, float]]:
|
||||
"""Process results from a message query into message dicts with scores."""
|
||||
messages_with_scores = []
|
||||
|
||||
for row in result.rows:
|
||||
# Build message dict with key fields
|
||||
message_dict = {
|
||||
"id": row.id,
|
||||
"text": getattr(row, "text", ""),
|
||||
"organization_id": getattr(row, "organization_id", None),
|
||||
"agent_id": getattr(row, "agent_id", None),
|
||||
"role": getattr(row, "role", None),
|
||||
"created_at": getattr(row, "created_at", None),
|
||||
}
|
||||
|
||||
# handle score based on search type
|
||||
if is_fts:
|
||||
# for FTS, use the BM25 score directly (higher is better)
|
||||
score = getattr(row, "$score", 0.0)
|
||||
else:
|
||||
# for vector search, convert distance to similarity score
|
||||
distance = getattr(row, "$dist", 0.0)
|
||||
score = 1.0 - distance
|
||||
|
||||
messages_with_scores.append((message_dict, score))
|
||||
|
||||
return messages_with_scores
|
||||
|
||||
def _process_single_query_results(
|
||||
self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
|
||||
) -> List[Tuple[PydanticPassage, float]]:
|
||||
@@ -369,6 +658,56 @@ class TurbopufferClient:
|
||||
|
||||
return passages_with_scores
|
||||
|
||||
def _generic_reciprocal_rank_fusion(
|
||||
self,
|
||||
vector_results: List[Tuple[Any, float]],
|
||||
fts_results: List[Tuple[Any, float]],
|
||||
get_id_func: Callable[[Any], str],
|
||||
vector_weight: float,
|
||||
fts_weight: float,
|
||||
top_k: int,
|
||||
) -> List[Tuple[Any, float]]:
|
||||
"""Generic RRF implementation that works with any object type.
|
||||
|
||||
RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank))
|
||||
where k is a constant (typically 60) to avoid division by zero
|
||||
|
||||
Args:
|
||||
vector_results: List of (item, score) tuples from vector search
|
||||
fts_results: List of (item, score) tuples from FTS
|
||||
get_id_func: Function to extract ID from an item
|
||||
vector_weight: Weight for vector search results
|
||||
fts_weight: Weight for FTS results
|
||||
top_k: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of (item, score) tuples sorted by RRF score
|
||||
"""
|
||||
k = 60 # standard RRF constant
|
||||
|
||||
# create rank mappings using the get_id_func
|
||||
vector_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(vector_results)}
|
||||
fts_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(fts_results)}
|
||||
|
||||
# combine all unique items
|
||||
all_items = {}
|
||||
for item, _ in vector_results:
|
||||
all_items[get_id_func(item)] = item
|
||||
for item, _ in fts_results:
|
||||
all_items[get_id_func(item)] = item
|
||||
|
||||
# calculate RRF scores
|
||||
rrf_scores = {}
|
||||
for item_id in all_items:
|
||||
vector_score = vector_weight / (k + vector_ranks.get(item_id, k + top_k))
|
||||
fts_score = fts_weight / (k + fts_ranks.get(item_id, k + top_k))
|
||||
rrf_scores[item_id] = vector_score + fts_score
|
||||
|
||||
# sort by RRF score and return top_k
|
||||
sorted_results = sorted([(all_items[iid], score) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True)
|
||||
|
||||
return sorted_results[:top_k]
|
||||
|
||||
def _reciprocal_rank_fusion(
|
||||
self,
|
||||
vector_results: List[Tuple[PydanticPassage, float]],
|
||||
@@ -377,42 +716,22 @@ class TurbopufferClient:
|
||||
fts_weight: float,
|
||||
top_k: int,
|
||||
) -> List[Tuple[PydanticPassage, float]]:
|
||||
"""Combine vector and FTS results using Reciprocal Rank Fusion (RRF).
|
||||
|
||||
RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank))
|
||||
where k is a constant (typically 60) to avoid division by zero
|
||||
"""
|
||||
k = 60 # standard RRF constant
|
||||
|
||||
# create rank mappings
|
||||
vector_ranks = {passage.id: rank + 1 for rank, (passage, _) in enumerate(vector_results)}
|
||||
fts_ranks = {passage.id: rank + 1 for rank, (passage, _) in enumerate(fts_results)}
|
||||
|
||||
# combine all unique passage IDs
|
||||
all_passages = {}
|
||||
for passage, _ in vector_results:
|
||||
all_passages[passage.id] = passage
|
||||
for passage, _ in fts_results:
|
||||
all_passages[passage.id] = passage
|
||||
|
||||
# calculate RRF scores
|
||||
rrf_scores = {}
|
||||
for passage_id in all_passages:
|
||||
vector_score = vector_weight / (k + vector_ranks.get(passage_id, k + top_k))
|
||||
fts_score = fts_weight / (k + fts_ranks.get(passage_id, k + top_k))
|
||||
rrf_scores[passage_id] = vector_score + fts_score
|
||||
|
||||
# sort by RRF score and return top_k
|
||||
sorted_results = sorted([(all_passages[pid], score) for pid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True)
|
||||
|
||||
return sorted_results[:top_k]
|
||||
"""Wrapper for backwards compatibility - uses generic RRF for passages."""
|
||||
return self._generic_reciprocal_rank_fusion(
|
||||
vector_results=vector_results,
|
||||
fts_results=fts_results,
|
||||
get_id_func=lambda p: p.id,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
@trace_method
|
||||
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
|
||||
"""Delete a passage from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = self._get_namespace_name(archive_id)
|
||||
namespace_name = self._get_archive_namespace_name(archive_id)
|
||||
|
||||
try:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
@@ -433,7 +752,7 @@ class TurbopufferClient:
|
||||
if not passage_ids:
|
||||
return True
|
||||
|
||||
namespace_name = self._get_namespace_name(archive_id)
|
||||
namespace_name = self._get_archive_namespace_name(archive_id)
|
||||
|
||||
try:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
@@ -451,7 +770,7 @@ class TurbopufferClient:
|
||||
"""Delete all passages for an archive from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = self._get_namespace_name(archive_id)
|
||||
namespace_name = self._get_archive_namespace_name(archive_id)
|
||||
|
||||
try:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
@@ -463,3 +782,21 @@ class TurbopufferClient:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete all passages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def delete_all_messages(self, agent_id: str) -> bool:
|
||||
"""Delete all messages for an agent from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = self._get_message_namespace_name(agent_id)
|
||||
|
||||
try:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
# Turbopuffer has a delete_all() method on namespace
|
||||
await namespace.delete_all()
|
||||
logger.info(f"Successfully deleted all messages for agent {agent_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete all messages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from sqlalchemy import delete, exists, func, select, text
|
||||
@@ -11,7 +12,7 @@ from letta.orm.message import Message as MessageModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType
|
||||
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage, MessageUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
@@ -30,6 +31,34 @@ class MessageManager:
|
||||
"""Initialize the MessageManager."""
|
||||
self.file_manager = FileManager()
|
||||
|
||||
def _extract_message_text(self, message: PydanticMessage) -> str:
|
||||
"""Extract text content from a message's complex content structure.
|
||||
|
||||
Args:
|
||||
message: The message to extract text from
|
||||
|
||||
Returns:
|
||||
Concatenated text content from the message
|
||||
"""
|
||||
# TODO: Make this much more complex/extend to beyond text content
|
||||
if not message.content:
|
||||
return ""
|
||||
|
||||
# handle string content (legacy)
|
||||
if isinstance(message.content, str):
|
||||
return message.content
|
||||
|
||||
# handle list of content items
|
||||
text_parts = []
|
||||
for content_item in message.content:
|
||||
if isinstance(content_item, TextContent):
|
||||
text_parts.append(content_item.text)
|
||||
elif hasattr(content_item, "text"):
|
||||
# handle other content types that might have text
|
||||
text_parts.append(content_item.text)
|
||||
|
||||
return " ".join(text_parts)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
|
||||
@@ -125,13 +154,19 @@ class MessageManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_many_messages_async(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]:
|
||||
async def create_many_messages_async(
|
||||
self,
|
||||
pydantic_msgs: List[PydanticMessage],
|
||||
actor: PydanticUser,
|
||||
embedding_config: Optional[dict] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Create multiple messages in a single database transaction asynchronously.
|
||||
|
||||
Args:
|
||||
pydantic_msgs: List of Pydantic message models to create
|
||||
actor: User performing the action
|
||||
embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer
|
||||
|
||||
Returns:
|
||||
List of created Pydantic message models
|
||||
@@ -169,6 +204,62 @@ class MessageManager:
|
||||
created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True)
|
||||
result = [msg.to_pydantic() for msg in created_messages]
|
||||
await session.commit()
|
||||
|
||||
# embed messages in turbopuffer if enabled and embedding_config provided
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages() and embedding_config and result:
|
||||
try:
|
||||
# extract agent_id from the first message (all should have same agent_id)
|
||||
agent_id = result[0].agent_id
|
||||
if agent_id:
|
||||
# extract text content from each message
|
||||
message_texts = []
|
||||
message_ids = []
|
||||
roles = []
|
||||
created_ats = []
|
||||
|
||||
for msg in result:
|
||||
text = self._extract_message_text(msg)
|
||||
if text: # only embed messages with text content
|
||||
message_texts.append(text)
|
||||
message_ids.append(msg.id)
|
||||
roles.append(msg.role)
|
||||
created_ats.append(msg.created_at)
|
||||
|
||||
if message_texts:
|
||||
# generate embeddings using provided config
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
# extract provider info from embedding_config
|
||||
embedding_provider = embedding_config.get("provider", "openai")
|
||||
embedding_api_key = embedding_config.get("api_key")
|
||||
embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1")
|
||||
|
||||
embedding_client = LLMClient(
|
||||
llm_provider_type=embedding_provider,
|
||||
api_key=embedding_api_key,
|
||||
endpoint=embedding_endpoint,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)
|
||||
|
||||
# insert to turbopuffer
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=actor.organization_id,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
|
||||
except Exception as e:
|
||||
# log error but don't fail the message creation
|
||||
logger.error(f"Failed to embed messages in Turbopuffer: {e}")
|
||||
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@@ -672,3 +763,116 @@ class MessageManager:
|
||||
|
||||
# return the number of rows deleted
|
||||
return result.rowcount
|
||||
|
||||
@enforce_types
@trace_method
async def search_messages_async(
    self,
    agent_id: str,
    actor: PydanticUser,
    query_text: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    search_mode: str = "hybrid",
    roles: Optional[List[MessageRole]] = None,
    limit: int = 50,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    embedding_config: Optional[dict] = None,
) -> List[PydanticMessage]:
    """
    Search messages using Turbopuffer if enabled, otherwise fall back to SQL search.

    Args:
        agent_id: ID of the agent whose messages to search
        actor: User performing the search
        query_text: Text query for full-text search
        query_embedding: Optional pre-computed embedding for vector search
        search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "hybrid")
        roles: Optional list of message roles to filter by
        limit: Maximum number of results to return
        start_date: Optional filter for messages created after this date
        end_date: Optional filter for messages created before this date
        embedding_config: Optional embedding configuration for generating query embedding

    Returns:
        List of matching messages
    """
    from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages

    # every fallback path runs the exact same SQL query; build its arguments
    # once so the three fallback sites cannot drift apart
    sql_fallback_kwargs = dict(
        agent_id=agent_id,
        actor=actor,
        query_text=query_text,
        roles=roles,
        limit=limit,
        ascending=False,
    )

    # use sql-based search when turbopuffer message embedding is disabled
    if not should_use_tpuf_for_messages():
        return await self.list_messages_for_agent_async(**sql_fallback_kwargs)

    try:
        # generate an embedding when the mode needs one and none was provided
        if search_mode in ["vector", "hybrid"] and query_embedding is None and query_text:
            if not embedding_config:
                # cannot embed the query without a config; degrade to SQL search
                logger.warning("No embedding config provided for vector search, falling back to SQL")
                return await self.list_messages_for_agent_async(**sql_fallback_kwargs)

            # generate embedding from query text
            from letta.llm_api.llm_client import LLMClient

            embedding_client = LLMClient(
                llm_provider_type=embedding_config.get("provider", "openai"),
                api_key=embedding_config.get("api_key"),
                endpoint=embedding_config.get("endpoint", "https://api.openai.com/v1"),
                actor=actor,
            )
            embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
            query_embedding = embeddings[0]

        # use turbopuffer for search
        tpuf_client = TurbopufferClient()
        results = await tpuf_client.query_messages(
            agent_id=agent_id,
            query_embedding=query_embedding,
            query_text=query_text,
            search_mode=search_mode,
            top_k=limit,
            roles=roles,
            start_date=start_date,
            end_date=end_date,
        )

        # fetch full message objects from the database using the IDs,
        # preserving the ranking order returned by turbopuffer
        message_ids = [msg_dict["id"] for msg_dict, _ in results]
        if not message_ids:
            return []
        messages = await self.get_messages_by_ids_async(message_ids, actor)
        message_dict = {msg.id: msg for msg in messages}
        return [message_dict[msg_id] for msg_id in message_ids if msg_id in message_dict]

    except Exception as e:
        # a turbopuffer outage must not break search entirely; degrade to SQL
        logger.error(f"Failed to search messages with Turbopuffer, falling back to SQL: {e}")
        return await self.list_messages_for_agent_async(**sql_fallback_kwargs)
|
||||
|
||||
@@ -301,6 +301,7 @@ class Settings(BaseSettings):
|
||||
use_tpuf: bool = False
|
||||
tpuf_api_key: Optional[str] = None
|
||||
tpuf_region: str = "gcp-us-central1"
|
||||
embed_all_messages: bool = False
|
||||
|
||||
# File processing timeout settings
|
||||
file_processing_timeout_minutes: int = 30
|
||||
|
||||
@@ -4,9 +4,11 @@ from datetime import datetime, timezone
|
||||
import pytest
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf, should_use_tpuf_for_messages
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import TagMatchMode, VectorDBProvider
|
||||
from letta.schemas.enums import MessageRole, TagMatchMode, VectorDBProvider
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
@@ -69,6 +71,48 @@ def enable_turbopuffer():
|
||||
settings.environment = original_environment
|
||||
|
||||
|
||||
@pytest.fixture
def enable_message_embedding():
    """Enable both Turbopuffer and message embedding"""
    # snapshot the settings so they can be restored after the test
    saved = (
        settings.use_tpuf,
        settings.tpuf_api_key,
        settings.embed_all_messages,
        settings.environment,
    )

    settings.use_tpuf = True
    settings.tpuf_api_key = settings.tpuf_api_key or "test-key"
    settings.embed_all_messages = True
    settings.environment = "DEV"

    yield

    # restore the snapshot taken above
    (
        settings.use_tpuf,
        settings.tpuf_api_key,
        settings.embed_all_messages,
        settings.environment,
    ) = saved
|
||||
|
||||
|
||||
@pytest.fixture
def disable_turbopuffer():
    """Ensure Turbopuffer is disabled for testing"""
    # snapshot the two flags so they can be restored after the test
    saved = (settings.use_tpuf, settings.embed_all_messages)

    settings.use_tpuf = False
    settings.embed_all_messages = False

    yield

    # restore the snapshot taken above
    settings.use_tpuf, settings.embed_all_messages = saved
|
||||
|
||||
|
||||
@pytest.fixture
def sample_embedding_config():
    """Provide a sample embedding configuration"""
    # the built-in "letta" default keeps the tests provider-agnostic
    return EmbeddingConfig.default_config(model_name="letta")
|
||||
|
||||
|
||||
class TestTurbopufferIntegration:
|
||||
"""Test Turbopuffer integration functionality with real connections"""
|
||||
|
||||
@@ -374,14 +418,6 @@ class TestTurbopufferIntegration:
|
||||
# all results should have scores
|
||||
assert all(isinstance(score, float) for _, score in vector_heavy_results)
|
||||
|
||||
# Test error handling - missing embedding for vector mode
|
||||
with pytest.raises(ValueError, match="query_embedding is required for vector search mode"):
|
||||
await client.query_passages(archive_id=archive_id, search_mode="vector", top_k=3)
|
||||
|
||||
# Test error handling - missing text for FTS mode
|
||||
with pytest.raises(ValueError, match="query_text is required for FTS search mode"):
|
||||
await client.query_passages(archive_id=archive_id, search_mode="fts", top_k=3)
|
||||
|
||||
# Test error handling - missing text for hybrid mode (embedding provided but text missing)
|
||||
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
|
||||
await client.query_passages(archive_id=archive_id, query_embedding=[1.0, 2.0, 3.0], search_mode="hybrid", top_k=3)
|
||||
@@ -751,3 +787,645 @@ class TestTurbopufferParametrized:
|
||||
# Clean up
|
||||
await server.passage_manager.delete_agent_passages_async(recent_passage, default_user)
|
||||
await server.passage_manager.delete_agent_passages_async(old_passage, default_user)
|
||||
|
||||
|
||||
class TestTurbopufferMessagesIntegration:
|
||||
"""Test Turbopuffer message embedding functionality"""
|
||||
|
||||
def test_should_use_tpuf_for_messages_settings(self):
    """Test that should_use_tpuf_for_messages correctly checks both use_tpuf AND embed_all_messages"""
    # snapshot the settings so they can be restored after the test
    saved = (settings.use_tpuf, settings.tpuf_api_key, settings.embed_all_messages)

    # (use_tpuf, tpuf_api_key, embed_all_messages) -> expected result
    cases = [
        (True, "test-key", True, True),    # everything enabled
        (False, "test-key", True, False),  # use_tpuf disabled
        (True, "test-key", False, False),  # embed_all_messages disabled
        (False, "test-key", False, False), # both disabled
        (True, None, True, False),         # API key missing
    ]

    try:
        for use_tpuf, api_key, embed_all, expected in cases:
            settings.use_tpuf = use_tpuf
            settings.tpuf_api_key = api_key
            settings.embed_all_messages = embed_all
            assert should_use_tpuf_for_messages() is expected
    finally:
        # restore the snapshot taken above
        settings.use_tpuf, settings.tpuf_api_key, settings.embed_all_messages = saved
|
||||
|
||||
def test_message_text_extraction(self, server, default_user):
    """Test extraction of text from various message content structures"""
    manager = server.message_manager

    # (role, content, expected extracted text)
    cases = [
        # single TextContent item
        (MessageRole.user, [TextContent(text="Simple text content")], "Simple text content"),
        # another single TextContent item
        (MessageRole.user, [TextContent(text="Single text content")], "Single text content"),
        # multiple TextContent items are joined with single spaces
        (
            MessageRole.user,
            [
                TextContent(text="First part"),
                TextContent(text="Second part"),
                TextContent(text="Third part"),
            ],
            "First part Second part Third part",
        ),
        # no content at all
        (MessageRole.system, None, ""),
        # empty content list
        (MessageRole.assistant, [], ""),
    ]

    for role, content, expected in cases:
        msg = PydanticMessage(role=role, content=content, agent_id="test-agent")
        assert manager._extract_message_text(msg) == expected
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_embedding_without_config(self, server, default_user, sarah_agent, enable_message_embedding):
    """Test that messages are NOT embedded without embedding_config even when tpuf is enabled"""
    manager = server.message_manager

    # small local factory keeps the message construction in one place
    def _msg(role, text):
        return PydanticMessage(role=role, content=[TextContent(text=text)], agent_id=sarah_agent.id)

    # create messages WITHOUT embedding_config
    created = await manager.create_many_messages_async(
        pydantic_msgs=[
            _msg(MessageRole.user, "Test message without embedding config"),
            _msg(MessageRole.assistant, "Response without embedding config"),
        ],
        actor=default_user,
        embedding_config=None,  # No config provided
    )

    assert len(created) == 2
    assert all(msg.agent_id == sarah_agent.id for msg in created)

    # the messages must still be written to SQL even though embedding was skipped
    sql_messages = await manager.list_messages_for_agent_async(
        agent_id=sarah_agent.id,
        actor=default_user,
        limit=10,
    )
    assert len(sql_messages) >= 2

    # clean up
    await manager.delete_messages_by_ids_async([msg.id for msg in created], default_user)
|
||||
|
||||
@pytest.mark.asyncio
async def test_generic_reciprocal_rank_fusion(self):
    """Test the generic RRF function with different object types"""
    from letta.helpers.tpuf_client import TurbopufferClient

    client = TurbopufferClient()

    # --- passage objects (backward compatibility) ---
    def _passage(passage_id, text):
        # local factory: all passage fields except id/text are identical here
        return Passage(
            id=passage_id,
            text=text,
            organization_id="org1",
            archive_id="archive1",
            created_at=datetime.now(timezone.utc),
            metadata_={},
            tags=[],
            embedding=[],
            embedding_config=None,
        )

    p1_id = "passage-78d49031-8502-49c1-a970-45663e9f6e07"
    p2_id = "passage-90df8386-4caf-49cc-acbc-d71526de6f77"
    passage1 = _passage(p1_id, "First passage")
    passage2 = _passage(p2_id, "Second passage")

    # fuse via the passage wrapper function
    combined = client._reciprocal_rank_fusion(
        vector_results=[(passage1, 0.9), (passage2, 0.7)],
        fts_results=[(passage2, 0.8), (passage1, 0.6)],
        vector_weight=0.5,
        fts_weight=0.5,
        top_k=2,
    )

    # both passages should survive fusion
    assert len(combined) == 2
    assert {p.id for p, _ in combined} == {p1_id, p2_id}

    # --- message dicts via the generic function ---
    msg1 = {"id": "m1", "text": "First message"}
    msg2 = {"id": "m2", "text": "Second message"}
    msg3 = {"id": "m3", "text": "Third message"}

    combined_msgs = client._generic_reciprocal_rank_fusion(
        vector_results=[(msg1, 0.95), (msg2, 0.85), (msg3, 0.75)],
        fts_results=[(msg2, 0.90), (msg3, 0.80), (msg1, 0.70)],
        get_id_func=lambda m: m["id"],
        vector_weight=0.6,
        fts_weight=0.4,
        top_k=3,
    )

    assert len(combined_msgs) == 3
    assert {m["id"] for m, _ in combined_msgs} == {"m1", "m2", "m3"}

    # edge case: both result lists empty
    empty_combined = client._generic_reciprocal_rank_fusion(
        vector_results=[],
        fts_results=[],
        get_id_func=lambda x: x["id"],
        vector_weight=0.5,
        fts_weight=0.5,
        top_k=10,
    )
    assert len(empty_combined) == 0

    # edge case: only one non-empty result list
    single_combined = client._generic_reciprocal_rank_fusion(
        vector_results=[(msg1, 0.9)],
        fts_results=[],
        get_id_func=lambda m: m["id"],
        vector_weight=0.5,
        fts_weight=0.5,
        top_k=10,
    )
    assert len(single_combined) == 1
    assert single_combined[0][0]["id"] == "m1"
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding):
    """Test actual message embedding and storage in Turbopuffer"""
    import uuid
    from datetime import datetime, timezone

    from letta.helpers.tpuf_client import TurbopufferClient
    from letta.schemas.enums import MessageRole

    client = TurbopufferClient()
    agent_id = f"test-agent-{uuid.uuid4()}"
    org_id = str(uuid.uuid4())

    try:
        # Prepare test messages
        message_texts = [
            "Hello, how can I help you today?",
            "I need help with Python programming.",
            "Sure, what specific Python topic?",
        ]
        message_ids = [str(uuid.uuid4()) for _ in message_texts]
        roles = [MessageRole.assistant, MessageRole.user, MessageRole.assistant]
        created_ats = [datetime.now(timezone.utc) for _ in message_texts]

        # Generate embeddings (dummy values are enough for storage/retrieval)
        embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]

        # Insert messages into Turbopuffer
        success = await client.insert_messages(
            agent_id=agent_id,
            message_texts=message_texts,
            embeddings=embeddings,
            message_ids=message_ids,
            organization_id=org_id,
            roles=roles,
            created_ats=created_ats,
        )

        # identity comparison for a boolean result (was `== True`, ruff E712)
        assert success is True

        # Verify we can query the messages back
        results = await client.query_messages(
            agent_id=agent_id,
            search_mode="timestamp",
            top_k=10,
        )

        assert len(results) == 3
        # Results should be ordered by timestamp (most recent first)
        for msg_dict, score in results:
            assert msg_dict["agent_id"] == agent_id
            assert msg_dict["organization_id"] == org_id
            assert msg_dict["text"] in message_texts
            assert msg_dict["role"] in ["assistant", "user"]

    finally:
        # Clean up namespace (best-effort, but a bare `except:` would also
        # swallow SystemExit/KeyboardInterrupt — catch Exception only)
        try:
            await client.delete_all_messages(agent_id)
        except Exception:
            pass
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding):
    """Test vector search on messages in Turbopuffer"""
    import uuid
    from datetime import datetime, timezone

    from letta.helpers.tpuf_client import TurbopufferClient
    from letta.schemas.enums import MessageRole

    client = TurbopufferClient()
    agent_id = f"test-agent-{uuid.uuid4()}"
    org_id = str(uuid.uuid4())

    try:
        # Insert messages with different embeddings
        message_texts = [
            "Python is a great programming language",
            "JavaScript is used for web development",
            "Machine learning with Python is powerful",
        ]
        message_ids = [str(uuid.uuid4()) for _ in message_texts]
        roles = [MessageRole.assistant] * len(message_texts)
        created_ats = [datetime.now(timezone.utc) for _ in message_texts]

        # Create embeddings that reflect content similarity
        embeddings = [
            [1.0, 0.0, 0.0],  # Python programming
            [0.0, 1.0, 0.0],  # JavaScript web
            [0.8, 0.0, 0.2],  # ML with Python (similar to first)
        ]

        # Insert messages
        await client.insert_messages(
            agent_id=agent_id,
            message_texts=message_texts,
            embeddings=embeddings,
            message_ids=message_ids,
            organization_id=org_id,
            roles=roles,
            created_ats=created_ats,
        )

        # Search for Python-related messages using vector search
        query_embedding = [0.9, 0.0, 0.1]  # Similar to Python messages
        results = await client.query_messages(
            agent_id=agent_id,
            query_embedding=query_embedding,
            search_mode="vector",
            top_k=2,
        )

        assert len(results) == 2
        # Should return the two Python-related messages
        result_texts = [msg["text"] for msg, _ in results]
        assert "Python is a great programming language" in result_texts
        assert "Machine learning with Python is powerful" in result_texts

    finally:
        # Clean up namespace (best-effort; catch Exception only — a bare
        # `except:` would also swallow SystemExit/KeyboardInterrupt)
        try:
            await client.delete_all_messages(agent_id)
        except Exception:
            pass
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding):
    """Test hybrid search combining vector and FTS for messages"""
    import uuid
    from datetime import datetime, timezone

    from letta.helpers.tpuf_client import TurbopufferClient
    from letta.schemas.enums import MessageRole

    client = TurbopufferClient()
    agent_id = f"test-agent-{uuid.uuid4()}"
    org_id = str(uuid.uuid4())

    try:
        # Insert diverse messages
        message_texts = [
            "The quick brown fox jumps over the lazy dog",
            "Machine learning algorithms are fascinating",
            "Quick tutorial on Python programming",
            "Deep learning with neural networks",
        ]
        message_ids = [str(uuid.uuid4()) for _ in message_texts]
        roles = [MessageRole.assistant] * len(message_texts)
        created_ats = [datetime.now(timezone.utc) for _ in message_texts]

        # Embeddings roughly separating "ML" from "fox" content
        embeddings = [
            [0.1, 0.9, 0.0],  # fox text
            [0.9, 0.1, 0.0],  # ML algorithms
            [0.5, 0.5, 0.0],  # Quick Python
            [0.8, 0.2, 0.0],  # Deep learning
        ]

        # Insert messages
        await client.insert_messages(
            agent_id=agent_id,
            message_texts=message_texts,
            embeddings=embeddings,
            message_ids=message_ids,
            organization_id=org_id,
            roles=roles,
            created_ats=created_ats,
        )

        # Hybrid search - vector similar to ML but text contains "quick"
        results = await client.query_messages(
            agent_id=agent_id,
            query_embedding=[0.7, 0.3, 0.0],  # Similar to ML messages
            query_text="quick",  # Text search for "quick"
            search_mode="hybrid",
            top_k=3,
            vector_weight=0.5,
            fts_weight=0.5,
        )

        assert len(results) > 0
        # Should get a mix of results based on both vector and text similarity
        result_texts = [msg["text"] for msg, _ in results]
        # At least one result should contain "quick" due to FTS
        assert any("quick" in text.lower() for text in result_texts)

    finally:
        # Clean up namespace (best-effort; catch Exception only — a bare
        # `except:` would also swallow SystemExit/KeyboardInterrupt)
        try:
            await client.delete_all_messages(agent_id)
        except Exception:
            pass
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding):
    """Test filtering messages by role"""
    import uuid
    from datetime import datetime, timezone

    from letta.helpers.tpuf_client import TurbopufferClient
    from letta.schemas.enums import MessageRole

    client = TurbopufferClient()
    agent_id = f"test-agent-{uuid.uuid4()}"
    org_id = str(uuid.uuid4())

    try:
        # Insert messages with different roles
        message_data = [
            ("Hello! How can I help?", MessageRole.assistant),
            ("I need help with Python", MessageRole.user),
            ("Here's a Python example", MessageRole.assistant),
            ("Can you explain this?", MessageRole.user),
            ("System message here", MessageRole.system),
        ]

        message_texts = [text for text, _ in message_data]
        roles = [role for _, role in message_data]
        message_ids = [str(uuid.uuid4()) for _ in message_texts]
        created_ats = [datetime.now(timezone.utc) for _ in message_texts]
        embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]

        # Insert messages
        await client.insert_messages(
            agent_id=agent_id,
            message_texts=message_texts,
            embeddings=embeddings,
            message_ids=message_ids,
            organization_id=org_id,
            roles=roles,
            created_ats=created_ats,
        )

        # Query only user messages
        user_results = await client.query_messages(
            agent_id=agent_id,
            search_mode="timestamp",
            top_k=10,
            roles=[MessageRole.user],
        )

        assert len(user_results) == 2
        for msg, _ in user_results:
            assert msg["role"] == "user"
            assert msg["text"] in ["I need help with Python", "Can you explain this?"]

        # Query assistant and system messages
        non_user_results = await client.query_messages(
            agent_id=agent_id,
            search_mode="timestamp",
            top_k=10,
            roles=[MessageRole.assistant, MessageRole.system],
        )

        assert len(non_user_results) == 3
        for msg, _ in non_user_results:
            assert msg["role"] in ["assistant", "system"]

    finally:
        # Clean up namespace (best-effort; catch Exception only — a bare
        # `except:` would also swallow SystemExit/KeyboardInterrupt)
        try:
            await client.delete_all_messages(agent_id)
        except Exception:
            pass
|
||||
|
||||
@pytest.mark.asyncio
async def test_message_search_fallback_to_sql(self, server, default_user, sarah_agent):
    """Test that message search falls back to SQL when Turbopuffer is disabled"""
    # Save original settings
    original_use_tpuf = settings.use_tpuf
    original_embed_messages = settings.embed_all_messages

    try:
        # Disable Turbopuffer for messages
        settings.use_tpuf = False
        settings.embed_all_messages = False

        # Create a message to be found by the SQL search
        await server.message_manager.create_many_messages_async(
            pydantic_msgs=[
                PydanticMessage(
                    role=MessageRole.user,
                    content=[TextContent(text="Test message for SQL fallback")],
                    agent_id=sarah_agent.id,
                )
            ],
            actor=default_user,
        )

        # Search should use SQL backend (not Turbopuffer)
        results = await server.message_manager.search_messages_async(
            actor=default_user,
            agent_id=sarah_agent.id,
            query_text="fallback",
            limit=10,
        )

        # Should return results from SQL search
        assert len(results) > 0
        # At least one returned message must actually contain the query term.
        # (any() replaces the original for/else + `assert False`, which is
        # stripped when Python runs with -O)
        texts = [server.message_manager._extract_message_text(msg) for msg in results]
        assert any("fallback" in text.lower() for text in texts), "No messages containing 'fallback' found"

    finally:
        # Restore settings
        settings.use_tpuf = original_use_tpuf
        settings.embed_all_messages = original_embed_messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):
|
||||
"""Test filtering messages by date range"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
from letta.schemas.enums import MessageRole
|
||||
|
||||
client = TurbopufferClient()
|
||||
agent_id = f"test-agent-{uuid.uuid4()}"
|
||||
org_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Create messages with different timestamps
|
||||
now = datetime.now(timezone.utc)
|
||||
yesterday = now - timedelta(days=1)
|
||||
last_week = now - timedelta(days=7)
|
||||
last_month = now - timedelta(days=30)
|
||||
|
||||
message_data = [
|
||||
("Today's message", now),
|
||||
("Yesterday's message", yesterday),
|
||||
("Last week's message", last_week),
|
||||
("Last month's message", last_month),
|
||||
]
|
||||
|
||||
for text, timestamp in message_data:
|
||||
await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=[text],
|
||||
embeddings=[[1.0, 2.0, 3.0]],
|
||||
message_ids=[str(uuid.uuid4())],
|
||||
organization_id=org_id,
|
||||
roles=[MessageRole.assistant],
|
||||
created_ats=[timestamp],
|
||||
)
|
||||
|
||||
# Query messages from the last 3 days
|
||||
three_days_ago = now - timedelta(days=3)
|
||||
recent_results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
start_date=three_days_ago,
|
||||
)
|
||||
|
||||
# Should get today's and yesterday's messages
|
||||
assert len(recent_results) == 2
|
||||
result_texts = [msg["text"] for msg, _ in recent_results]
|
||||
assert "Today's message" in result_texts
|
||||
assert "Yesterday's message" in result_texts
|
||||
|
||||
# Query messages between 2 weeks ago and 1 week ago
|
||||
two_weeks_ago = now - timedelta(days=14)
|
||||
week_results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
start_date=two_weeks_ago,
|
||||
end_date=last_week + timedelta(days=1), # Include last week's message
|
||||
)
|
||||
|
||||
# Should get only last week's message
|
||||
assert len(week_results) == 1
|
||||
assert week_results[0][0]["text"] == "Last week's message"
|
||||
|
||||
# Query with vector search and date filtering
|
||||
filtered_vector_results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
start_date=three_days_ago,
|
||||
)
|
||||
|
||||
# Should get only recent messages
|
||||
assert len(filtered_vector_results) == 2
|
||||
for msg, _ in filtered_vector_results:
|
||||
assert msg["text"] in ["Today's message", "Yesterday's message"]
|
||||
|
||||
finally:
|
||||
# Clean up namespace
|
||||
try:
|
||||
await client.delete_all_messages(agent_id)
|
||||
except:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user