diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 144e443b..64021b09 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -2,10 +2,10 @@ import logging from datetime import datetime, timezone -from typing import List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple from letta.otel.tracing import trace_method -from letta.schemas.enums import TagMatchMode +from letta.schemas.enums import MessageRole, TagMatchMode from letta.schemas.passage import Passage as PydanticPassage from letta.settings import settings @@ -16,6 +16,11 @@ def should_use_tpuf() -> bool: return bool(settings.use_tpuf) and bool(settings.tpuf_api_key) +def should_use_tpuf_for_messages() -> bool: + """Check if Turbopuffer should be used for messages.""" + return should_use_tpuf() and bool(settings.embed_all_messages) + + class TurbopufferClient: """Client for managing archival memory with Turbopuffer vector database.""" @@ -28,7 +33,7 @@ class TurbopufferClient: raise ValueError("Turbopuffer API key not provided") @trace_method - def _get_namespace_name(self, archive_id: str) -> str: + def _get_archive_namespace_name(self, archive_id: str) -> str: """Get namespace name for a specific archive.""" # use archive_id as namespace to isolate different archives' memories # append environment suffix to namespace for isolation if environment is set @@ -39,6 +44,18 @@ class TurbopufferClient: namespace_name = archive_id return namespace_name + @trace_method + def _get_message_namespace_name(self, agent_id: str) -> str: + """Get namespace name for a specific agent's messages.""" + # use agent_id as namespace to isolate different agents' messages + # append environment suffix to namespace for isolation if environment is set + environment = settings.environment + if environment: + namespace_name = f"messages_{agent_id}_{environment.lower()}" + else: + namespace_name = f"messages_{agent_id}" + return namespace_name + @trace_method 
async def insert_archival_memories( self, @@ -66,7 +83,7 @@ class TurbopufferClient: """ from turbopuffer import AsyncTurbopuffer - namespace_name = self._get_namespace_name(archive_id) + namespace_name = self._get_archive_namespace_name(archive_id) # handle timestamp - ensure UTC if created_at is None: @@ -155,6 +172,212 @@ class TurbopufferClient: logger.error("Duplicate passage IDs detected in batch") raise + @trace_method + async def insert_messages( + self, + agent_id: str, + message_texts: List[str], + embeddings: List[List[float]], + message_ids: List[str], + organization_id: str, + roles: List[MessageRole], + created_ats: List[datetime], + ) -> bool: + """Insert messages into Turbopuffer. + + Args: + agent_id: ID of the agent + message_texts: List of message text content to store + embeddings: List of embedding vectors corresponding to message texts + message_ids: List of message IDs (must match 1:1 with message_texts) + organization_id: Organization ID for the messages + roles: List of message roles corresponding to each message + created_ats: List of creation timestamps for each message + + Returns: + True if successful + """ + from turbopuffer import AsyncTurbopuffer + + namespace_name = self._get_message_namespace_name(agent_id) + + # validation checks + if not message_ids: + raise ValueError("message_ids must be provided for Turbopuffer insertion") + if len(message_ids) != len(message_texts): + raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})") + if len(message_ids) != len(embeddings): + raise ValueError(f"message_ids length ({len(message_ids)}) must match embeddings length ({len(embeddings)})") + if len(message_ids) != len(roles): + raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})") + if len(message_ids) != len(created_ats): + raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})") + 
+ # prepare column-based data for turbopuffer - optimized for batch insert + ids = [] + vectors = [] + texts = [] + organization_ids = [] + agent_ids = [] + message_roles = [] + created_at_timestamps = [] + + for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)): + message_id = message_ids[idx] + + # ensure the provided timestamp is timezone-aware and in UTC + if created_at.tzinfo is None: + # assume UTC if no timezone provided + timestamp = created_at.replace(tzinfo=timezone.utc) + else: + # convert to UTC if in different timezone + timestamp = created_at.astimezone(timezone.utc) + + # append to columns + ids.append(message_id) + vectors.append(embedding) + texts.append(text) + organization_ids.append(organization_id) + agent_ids.append(agent_id) + message_roles.append(role.value) + created_at_timestamps.append(timestamp) + + # build column-based upsert data + upsert_columns = { + "id": ids, + "vector": vectors, + "text": texts, + "organization_id": organization_ids, + "agent_id": agent_ids, + "role": message_roles, + "created_at": created_at_timestamps, + } + + try: + # Use AsyncTurbopuffer as a context manager for proper resource cleanup + async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: + namespace = client.namespace(namespace_name) + # turbopuffer recommends column-based writes for performance + await namespace.write( + upsert_columns=upsert_columns, + distance_metric="cosine_distance", + schema={"text": {"type": "string", "full_text_search": True}}, + ) + logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}") + return True + + except Exception as e: + logger.error(f"Failed to insert messages to Turbopuffer: {e}") + # check if it's a duplicate ID error + if "duplicate" in str(e).lower(): + logger.error("Duplicate message IDs detected in batch") + raise + + @trace_method + async def _execute_query( + self, + namespace_name: str, + 
search_mode: str, + query_embedding: Optional[List[float]], + query_text: Optional[str], + top_k: int, + include_attributes: List[str], + filters: Optional[Any] = None, + vector_weight: float = 0.5, + fts_weight: float = 0.5, + ) -> Any: + """Generic query execution for Turbopuffer. + + Args: + namespace_name: Turbopuffer namespace to query + search_mode: "vector", "fts", "hybrid", or "timestamp" + query_embedding: Embedding for vector search + query_text: Text for full-text search + top_k: Number of results to return + include_attributes: Attributes to include in results + filters: Turbopuffer filter expression + vector_weight: Weight for vector search in hybrid mode + fts_weight: Weight for FTS in hybrid mode + + Returns: + Raw Turbopuffer query results or multi-query response + """ + from turbopuffer import AsyncTurbopuffer + from turbopuffer.types import QueryParam + + # validate inputs based on search mode + if search_mode == "vector" and query_embedding is None: + raise ValueError("query_embedding is required for vector search mode") + if search_mode == "fts" and query_text is None: + raise ValueError("query_text is required for FTS search mode") + if search_mode == "hybrid": + if query_embedding is None or query_text is None: + raise ValueError("Both query_embedding and query_text are required for hybrid search mode") + if search_mode not in ["vector", "fts", "hybrid", "timestamp"]: + raise ValueError(f"Invalid search_mode: {search_mode}. 
Must be 'vector', 'fts', 'hybrid', or 'timestamp'") + + async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: + namespace = client.namespace(namespace_name) + + if search_mode == "timestamp": + # retrieve most recent items by timestamp + query_params = { + "rank_by": ("created_at", "desc"), + "top_k": top_k, + "include_attributes": include_attributes, + } + if filters: + query_params["filters"] = filters + return await namespace.query(**query_params) + + elif search_mode == "vector": + # vector search query + query_params = { + "rank_by": ("vector", "ANN", query_embedding), + "top_k": top_k, + "include_attributes": include_attributes, + } + if filters: + query_params["filters"] = filters + return await namespace.query(**query_params) + + elif search_mode == "fts": + # full-text search query + query_params = { + "rank_by": ("text", "BM25", query_text), + "top_k": top_k, + "include_attributes": include_attributes, + } + if filters: + query_params["filters"] = filters + return await namespace.query(**query_params) + + else: # hybrid mode + queries = [] + + # vector search query + vector_query = { + "rank_by": ("vector", "ANN", query_embedding), + "top_k": top_k, + "include_attributes": include_attributes, + } + if filters: + vector_query["filters"] = filters + queries.append(vector_query) + + # full-text search query + fts_query = { + "rank_by": ("text", "BM25", query_text), + "top_k": top_k, + "include_attributes": include_attributes, + } + if filters: + fts_query["filters"] = filters + queries.append(fts_query) + + # execute multi-query + return await namespace.multi_query(queries=[QueryParam(**q) for q in queries]) + @trace_method async def query_passages( self, @@ -188,147 +411,213 @@ class TurbopufferClient: Returns: List of (passage, score) tuples """ - from turbopuffer import AsyncTurbopuffer - from turbopuffer.types import QueryParam - - # validate inputs based on search mode first - if search_mode == "vector" and query_embedding is 
None: - raise ValueError("query_embedding is required for vector search mode") - if search_mode == "fts" and query_text is None: - raise ValueError("query_text is required for FTS search mode") - if search_mode == "hybrid": - if query_embedding is None or query_text is None: - raise ValueError("Both query_embedding and query_text are required for hybrid search mode") - if search_mode not in ["vector", "fts", "hybrid", "timestamp"]: - raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'") - # Check if we should fallback to timestamp-based retrieval if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: # Fallback to retrieving most recent passages when no search query is provided search_mode = "timestamp" - namespace_name = self._get_namespace_name(archive_id) + namespace_name = self._get_archive_namespace_name(archive_id) + + # build tag filter conditions + tag_filter = None + if tags: + if tag_match_mode == TagMatchMode.ALL: + # For ALL mode, need to check each tag individually with Contains + tag_conditions = [] + for tag in tags: + tag_conditions.append(("tags", "Contains", tag)) + if len(tag_conditions) == 1: + tag_filter = tag_conditions[0] + else: + tag_filter = ("And", tag_conditions) + else: # tag_match_mode == TagMatchMode.ANY + # For ANY mode, use ContainsAny to match any of the tags + tag_filter = ("tags", "ContainsAny", tags) + + # build date filter conditions + date_filters = [] + if start_date: + date_filters.append(("created_at", "Gte", start_date)) + if end_date: + date_filters.append(("created_at", "Lte", end_date)) + + # combine all filters + all_filters = [] + if tag_filter: + all_filters.append(tag_filter) + if date_filters: + all_filters.extend(date_filters) + + # create final filter expression + final_filter = None + if len(all_filters) == 1: + final_filter = all_filters[0] + elif len(all_filters) > 1: + final_filter = ("And", all_filters) try: - async with 
AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: - namespace = client.namespace(namespace_name) + # use generic query executor + result = await self._execute_query( + namespace_name=namespace_name, + search_mode=search_mode, + query_embedding=query_embedding, + query_text=query_text, + top_k=top_k, + include_attributes=["text", "organization_id", "archive_id", "created_at", "tags"], + filters=final_filter, + vector_weight=vector_weight, + fts_weight=fts_weight, + ) - # build tag filter conditions - tag_filter = None - if tags: - if tag_match_mode == TagMatchMode.ALL: - # For ALL mode, need to check each tag individually with Contains - tag_conditions = [] - for tag in tags: - tag_conditions.append(("tags", "Contains", tag)) - if len(tag_conditions) == 1: - tag_filter = tag_conditions[0] - else: - tag_filter = ("And", tag_conditions) - else: # tag_match_mode == TagMatchMode.ANY - # For ANY mode, use ContainsAny to match any of the tags - tag_filter = ("tags", "ContainsAny", tags) - - # build date filter conditions - date_filters = [] - if start_date: - # Turbopuffer expects datetime objects directly for comparison - date_filters.append(("created_at", "Gte", start_date)) - if end_date: - # Turbopuffer expects datetime objects directly for comparison - date_filters.append(("created_at", "Lte", end_date)) - - # combine all filters - all_filters = [] - if tag_filter: - all_filters.append(tag_filter) - if date_filters: - all_filters.extend(date_filters) - - # create final filter expression - final_filter = None - if len(all_filters) == 1: - final_filter = all_filters[0] - elif len(all_filters) > 1: - final_filter = ("And", all_filters) - - if search_mode == "timestamp": - # Fallback: retrieve most recent passages by timestamp - query_params = { - "rank_by": ("created_at", "desc"), # Order by created_at in descending order - "top_k": top_k, - "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], - } - if final_filter: 
- query_params["filters"] = final_filter - - result = await namespace.query(**query_params) - return self._process_single_query_results(result, archive_id, tags) - - elif search_mode == "vector": - # single vector search query - query_params = { - "rank_by": ("vector", "ANN", query_embedding), - "top_k": top_k, - "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], - } - if final_filter: - query_params["filters"] = final_filter - - result = await namespace.query(**query_params) - return self._process_single_query_results(result, archive_id, tags) - - elif search_mode == "fts": - # single full-text search query - query_params = { - "rank_by": ("text", "BM25", query_text), - "top_k": top_k, - "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], - } - if final_filter: - query_params["filters"] = final_filter - - result = await namespace.query(**query_params) - return self._process_single_query_results(result, archive_id, tags, is_fts=True) - - else: # hybrid mode - # multi-query for both vector and FTS - queries = [] - - # vector search query - vector_query = { - "rank_by": ("vector", "ANN", query_embedding), - "top_k": top_k, - "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], - } - if final_filter: - vector_query["filters"] = final_filter - queries.append(vector_query) - - # full-text search query - fts_query = { - "rank_by": ("text", "BM25", query_text), - "top_k": top_k, - "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], - } - if final_filter: - fts_query["filters"] = final_filter - queries.append(fts_query) - - # execute multi-query - response = await namespace.multi_query(queries=[QueryParam(**q) for q in queries]) - - # process and combine results using reciprocal rank fusion - vector_results = self._process_single_query_results(response.results[0], archive_id, tags) - fts_results = 
self._process_single_query_results(response.results[1], archive_id, tags, is_fts=True) - - # combine results using reciprocal rank fusion - return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k) + # process results based on search mode + if search_mode == "hybrid": + # for hybrid mode, we get a multi-query response + vector_results = self._process_single_query_results(result.results[0], archive_id, tags) + fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True) + # use backwards-compatible wrapper which calls generic RRF + return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k) + else: + # for single queries (vector, fts, timestamp) + is_fts = search_mode == "fts" + return self._process_single_query_results(result, archive_id, tags, is_fts=is_fts) except Exception as e: logger.error(f"Failed to query passages from Turbopuffer: {e}") raise + @trace_method + async def query_messages( + self, + agent_id: str, + query_embedding: Optional[List[float]] = None, + query_text: Optional[str] = None, + search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp" + top_k: int = 10, + roles: Optional[List[MessageRole]] = None, + vector_weight: float = 0.5, + fts_weight: float = 0.5, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> List[Tuple[dict, float]]: + """Query messages from Turbopuffer using vector search, full-text search, or hybrid search. 
+ + Args: + agent_id: ID of the agent + query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes) + query_text: Text query for full-text search (required for "fts" and "hybrid" modes) + search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector") + top_k: Number of results to return + roles: Optional list of message roles to filter by + vector_weight: Weight for vector search results in hybrid mode (default: 0.5) + fts_weight: Weight for FTS results in hybrid mode (default: 0.5) + start_date: Optional datetime to filter messages created after this date + end_date: Optional datetime to filter messages created before this date + + Returns: + List of (message_dict, score) tuples where message_dict contains id, text, role, created_at + """ + # Check if we should fallback to timestamp-based retrieval + if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: + # Fallback to retrieving most recent messages when no search query is provided + search_mode = "timestamp" + + namespace_name = self._get_message_namespace_name(agent_id) + + # build role filter conditions + role_filter = None + if roles: + role_values = [r.value for r in roles] + if len(role_values) == 1: + role_filter = ("role", "Eq", role_values[0]) + else: + role_filter = ("role", "In", role_values) + + # build date filter conditions + date_filters = [] + if start_date: + date_filters.append(("created_at", "Gte", start_date)) + if end_date: + date_filters.append(("created_at", "Lte", end_date)) + + # combine all filters + all_filters = [] + if role_filter: + all_filters.append(role_filter) + if date_filters: + all_filters.extend(date_filters) + + # create final filter expression + final_filter = None + if len(all_filters) == 1: + final_filter = all_filters[0] + elif len(all_filters) > 1: + final_filter = ("And", all_filters) + + try: + # use generic query executor + result = await self._execute_query( + 
namespace_name=namespace_name, + search_mode=search_mode, + query_embedding=query_embedding, + query_text=query_text, + top_k=top_k, + include_attributes=["text", "organization_id", "agent_id", "role", "created_at"], + filters=final_filter, + vector_weight=vector_weight, + fts_weight=fts_weight, + ) + + # process results based on search mode + if search_mode == "hybrid": + # for hybrid mode, we get a multi-query response + vector_results = self._process_message_query_results(result.results[0]) + fts_results = self._process_message_query_results(result.results[1], is_fts=True) + # use generic RRF with lambda to extract ID from dict + return self._generic_reciprocal_rank_fusion( + vector_results=vector_results, + fts_results=fts_results, + get_id_func=lambda msg_dict: msg_dict["id"], + vector_weight=vector_weight, + fts_weight=fts_weight, + top_k=top_k, + ) + else: + # for single queries (vector, fts, timestamp) + is_fts = search_mode == "fts" + return self._process_message_query_results(result, is_fts=is_fts) + + except Exception as e: + logger.error(f"Failed to query messages from Turbopuffer: {e}") + raise + + def _process_message_query_results(self, result, is_fts: bool = False) -> List[Tuple[dict, float]]: + """Process results from a message query into message dicts with scores.""" + messages_with_scores = [] + + for row in result.rows: + # Build message dict with key fields + message_dict = { + "id": row.id, + "text": getattr(row, "text", ""), + "organization_id": getattr(row, "organization_id", None), + "agent_id": getattr(row, "agent_id", None), + "role": getattr(row, "role", None), + "created_at": getattr(row, "created_at", None), + } + + # handle score based on search type + if is_fts: + # for FTS, use the BM25 score directly (higher is better) + score = getattr(row, "$score", 0.0) + else: + # for vector search, convert distance to similarity score + distance = getattr(row, "$dist", 0.0) + score = 1.0 - distance + + 
messages_with_scores.append((message_dict, score)) + + return messages_with_scores + def _process_single_query_results( self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False ) -> List[Tuple[PydanticPassage, float]]: @@ -369,6 +658,56 @@ class TurbopufferClient: return passages_with_scores + def _generic_reciprocal_rank_fusion( + self, + vector_results: List[Tuple[Any, float]], + fts_results: List[Tuple[Any, float]], + get_id_func: Callable[[Any], str], + vector_weight: float, + fts_weight: float, + top_k: int, + ) -> List[Tuple[Any, float]]: + """Generic RRF implementation that works with any object type. + + RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank)) + where k is a constant (typically 60) to avoid division by zero + + Args: + vector_results: List of (item, score) tuples from vector search + fts_results: List of (item, score) tuples from FTS + get_id_func: Function to extract ID from an item + vector_weight: Weight for vector search results + fts_weight: Weight for FTS results + top_k: Number of results to return + + Returns: + List of (item, score) tuples sorted by RRF score + """ + k = 60 # standard RRF constant + + # create rank mappings using the get_id_func + vector_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(vector_results)} + fts_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(fts_results)} + + # combine all unique items + all_items = {} + for item, _ in vector_results: + all_items[get_id_func(item)] = item + for item, _ in fts_results: + all_items[get_id_func(item)] = item + + # calculate RRF scores + rrf_scores = {} + for item_id in all_items: + vector_score = vector_weight / (k + vector_ranks.get(item_id, k + top_k)) + fts_score = fts_weight / (k + fts_ranks.get(item_id, k + top_k)) + rrf_scores[item_id] = vector_score + fts_score + + # sort by RRF score and return top_k + sorted_results = sorted([(all_items[iid], score) for iid, score in 
rrf_scores.items()], key=lambda x: x[1], reverse=True) + + return sorted_results[:top_k] + def _reciprocal_rank_fusion( self, vector_results: List[Tuple[PydanticPassage, float]], @@ -377,42 +716,22 @@ class TurbopufferClient: fts_weight: float, top_k: int, ) -> List[Tuple[PydanticPassage, float]]: - """Combine vector and FTS results using Reciprocal Rank Fusion (RRF). - - RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank)) - where k is a constant (typically 60) to avoid division by zero - """ - k = 60 # standard RRF constant - - # create rank mappings - vector_ranks = {passage.id: rank + 1 for rank, (passage, _) in enumerate(vector_results)} - fts_ranks = {passage.id: rank + 1 for rank, (passage, _) in enumerate(fts_results)} - - # combine all unique passage IDs - all_passages = {} - for passage, _ in vector_results: - all_passages[passage.id] = passage - for passage, _ in fts_results: - all_passages[passage.id] = passage - - # calculate RRF scores - rrf_scores = {} - for passage_id in all_passages: - vector_score = vector_weight / (k + vector_ranks.get(passage_id, k + top_k)) - fts_score = fts_weight / (k + fts_ranks.get(passage_id, k + top_k)) - rrf_scores[passage_id] = vector_score + fts_score - - # sort by RRF score and return top_k - sorted_results = sorted([(all_passages[pid], score) for pid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True) - - return sorted_results[:top_k] + """Wrapper for backwards compatibility - uses generic RRF for passages.""" + return self._generic_reciprocal_rank_fusion( + vector_results=vector_results, + fts_results=fts_results, + get_id_func=lambda p: p.id, + vector_weight=vector_weight, + fts_weight=fts_weight, + top_k=top_k, + ) @trace_method async def delete_passage(self, archive_id: str, passage_id: str) -> bool: """Delete a passage from Turbopuffer.""" from turbopuffer import AsyncTurbopuffer - namespace_name = self._get_namespace_name(archive_id) + namespace_name = 
self._get_archive_namespace_name(archive_id) try: async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: @@ -433,7 +752,7 @@ class TurbopufferClient: if not passage_ids: return True - namespace_name = self._get_namespace_name(archive_id) + namespace_name = self._get_archive_namespace_name(archive_id) try: async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: @@ -451,7 +770,7 @@ class TurbopufferClient: """Delete all passages for an archive from Turbopuffer.""" from turbopuffer import AsyncTurbopuffer - namespace_name = self._get_namespace_name(archive_id) + namespace_name = self._get_archive_namespace_name(archive_id) try: async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: @@ -463,3 +782,21 @@ class TurbopufferClient: except Exception as e: logger.error(f"Failed to delete all passages from Turbopuffer: {e}") raise + + @trace_method + async def delete_all_messages(self, agent_id: str) -> bool: + """Delete all messages for an agent from Turbopuffer.""" + from turbopuffer import AsyncTurbopuffer + + namespace_name = self._get_message_namespace_name(agent_id) + + try: + async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: + namespace = client.namespace(namespace_name) + # Turbopuffer has a delete_all() method on namespace + await namespace.delete_all() + logger.info(f"Successfully deleted all messages for agent {agent_id}") + return True + except Exception as e: + logger.error(f"Failed to delete all messages from Turbopuffer: {e}") + raise diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 4e552548..267a8a79 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -1,5 +1,6 @@ import json import uuid +from datetime import datetime from typing import List, Optional, Sequence from sqlalchemy import delete, exists, func, select, text @@ -11,7 +12,7 @@ from letta.orm.message import Message as 
MessageModel from letta.otel.tracing import trace_method from letta.schemas.enums import MessageRole from letta.schemas.letta_message import LettaMessageUpdateUnion -from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType +from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent from letta.schemas.message import Message as PydanticMessage, MessageUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry @@ -30,6 +31,34 @@ class MessageManager: """Initialize the MessageManager.""" self.file_manager = FileManager() + def _extract_message_text(self, message: PydanticMessage) -> str: + """Extract text content from a message's complex content structure. + + Args: + message: The message to extract text from + + Returns: + Concatenated text content from the message + """ + # TODO: Make this much more complex/extend to beyond text content + if not message.content: + return "" + + # handle string content (legacy) + if isinstance(message.content, str): + return message.content + + # handle list of content items + text_parts = [] + for content_item in message.content: + if isinstance(content_item, TextContent): + text_parts.append(content_item.text) + elif hasattr(content_item, "text"): + # handle other content types that might have text + text_parts.append(content_item.text) + + return " ".join(text_parts) + @enforce_types @trace_method def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]: @@ -125,13 +154,19 @@ class MessageManager: @enforce_types @trace_method - async def create_many_messages_async(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]: + async def create_many_messages_async( + self, + pydantic_msgs: List[PydanticMessage], + actor: PydanticUser, + embedding_config: Optional[dict] = None, + ) -> List[PydanticMessage]: """ Create multiple messages in a 
single database transaction asynchronously. Args: pydantic_msgs: List of Pydantic message models to create actor: User performing the action + embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer Returns: List of created Pydantic message models @@ -169,6 +204,62 @@ class MessageManager: created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True) result = [msg.to_pydantic() for msg in created_messages] await session.commit() + + # embed messages in turbopuffer if enabled and embedding_config provided + from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages + + if should_use_tpuf_for_messages() and embedding_config and result: + try: + # extract agent_id from the first message (all should have same agent_id) + agent_id = result[0].agent_id + if agent_id: + # extract text content from each message + message_texts = [] + message_ids = [] + roles = [] + created_ats = [] + + for msg in result: + text = self._extract_message_text(msg) + if text: # only embed messages with text content + message_texts.append(text) + message_ids.append(msg.id) + roles.append(msg.role) + created_ats.append(msg.created_at) + + if message_texts: + # generate embeddings using provided config + from letta.llm_api.llm_client import LLMClient + + # extract provider info from embedding_config + embedding_provider = embedding_config.get("provider", "openai") + embedding_api_key = embedding_config.get("api_key") + embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1") + + embedding_client = LLMClient( + llm_provider_type=embedding_provider, + api_key=embedding_api_key, + endpoint=embedding_endpoint, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings(message_texts, embedding_config) + + # insert to turbopuffer + tpuf_client = TurbopufferClient() + await tpuf_client.insert_messages( + agent_id=agent_id, + 
message_texts=message_texts, + embeddings=embeddings, + message_ids=message_ids, + organization_id=actor.organization_id, + roles=roles, + created_ats=created_ats, + ) + logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}") + except Exception as e: + # log error but don't fail the message creation + logger.error(f"Failed to embed messages in Turbopuffer: {e}") + return result @enforce_types @@ -672,3 +763,116 @@ class MessageManager: # return the number of rows deleted return result.rowcount + + @enforce_types + @trace_method + async def search_messages_async( + self, + agent_id: str, + actor: PydanticUser, + query_text: Optional[str] = None, + query_embedding: Optional[List[float]] = None, + search_mode: str = "hybrid", + roles: Optional[List[MessageRole]] = None, + limit: int = 50, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + embedding_config: Optional[dict] = None, + ) -> List[PydanticMessage]: + """ + Search messages using Turbopuffer if enabled, otherwise fall back to SQL search. 
+ + Args: + agent_id: ID of the agent whose messages to search + actor: User performing the search + query_text: Text query for full-text search + query_embedding: Optional pre-computed embedding for vector search + search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "hybrid") + roles: Optional list of message roles to filter by + limit: Maximum number of results to return + start_date: Optional filter for messages created after this date + end_date: Optional filter for messages created before this date + embedding_config: Optional embedding configuration for generating query embedding + + Returns: + List of matching messages + """ + from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages + + # check if we should use turbopuffer + if should_use_tpuf_for_messages(): + try: + # generate embedding if needed and not provided + if search_mode in ["vector", "hybrid"] and query_embedding is None and query_text: + if not embedding_config: + # fall back to SQL search if no embedding config + logger.warning("No embedding config provided for vector search, falling back to SQL") + return await self.list_messages_for_agent_async( + agent_id=agent_id, + actor=actor, + query_text=query_text, + roles=roles, + limit=limit, + ascending=False, + ) + + # generate embedding from query text + from letta.llm_api.llm_client import LLMClient + + embedding_provider = embedding_config.get("provider", "openai") + embedding_api_key = embedding_config.get("api_key") + embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1") + + embedding_client = LLMClient( + llm_provider_type=embedding_provider, + api_key=embedding_api_key, + endpoint=embedding_endpoint, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings([query_text], embedding_config) + query_embedding = embeddings[0] + + # use turbopuffer for search + tpuf_client = TurbopufferClient() + results = await tpuf_client.query_messages( + 
agent_id=agent_id, + query_embedding=query_embedding, + query_text=query_text, + search_mode=search_mode, + top_k=limit, + roles=roles, + start_date=start_date, + end_date=end_date, + ) + + # fetch full message objects from database using the IDs + message_ids = [msg_dict["id"] for msg_dict, _ in results] + if message_ids: + messages = await self.get_messages_by_ids_async(message_ids, actor) + # maintain the order from turbopuffer results + message_dict = {msg.id: msg for msg in messages} + return [message_dict[msg_id] for msg_id in message_ids if msg_id in message_dict] + else: + return [] + + except Exception as e: + logger.error(f"Failed to search messages with Turbopuffer, falling back to SQL: {e}") + # fall back to SQL search + return await self.list_messages_for_agent_async( + agent_id=agent_id, + actor=actor, + query_text=query_text, + roles=roles, + limit=limit, + ascending=False, + ) + else: + # use sql-based search + return await self.list_messages_for_agent_async( + agent_id=agent_id, + actor=actor, + query_text=query_text, + roles=roles, + limit=limit, + ascending=False, + ) diff --git a/letta/settings.py b/letta/settings.py index 81b54da6..fe56d5cd 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -301,6 +301,7 @@ class Settings(BaseSettings): use_tpuf: bool = False tpuf_api_key: Optional[str] = None tpuf_region: str = "gcp-us-central1" + embed_all_messages: bool = False # File processing timeout settings file_processing_timeout_minutes: int = 30 diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index 5cbda488..ef11e53b 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -4,9 +4,11 @@ from datetime import datetime, timezone import pytest from letta.config import LettaConfig -from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf +from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf, should_use_tpuf_for_messages from 
@pytest.fixture
def enable_message_embedding():
    """Enable both Turbopuffer and message embedding for the duration of a test."""
    saved = (settings.use_tpuf, settings.tpuf_api_key, settings.embed_all_messages, settings.environment)

    settings.use_tpuf = True
    settings.tpuf_api_key = settings.tpuf_api_key or "test-key"
    settings.embed_all_messages = True
    settings.environment = "DEV"

    try:
        yield
    finally:
        settings.use_tpuf, settings.tpuf_api_key, settings.embed_all_messages, settings.environment = saved


@pytest.fixture
def disable_turbopuffer():
    """Ensure Turbopuffer is disabled for the duration of a test."""
    saved = (settings.use_tpuf, settings.embed_all_messages)

    settings.use_tpuf = False
    settings.embed_all_messages = False

    try:
        yield
    finally:
        settings.use_tpuf, settings.embed_all_messages = saved


@pytest.fixture
def sample_embedding_config():
    """Provide a sample embedding configuration."""
    return EmbeddingConfig.default_config(model_name="letta")
class TestTurbopufferMessagesIntegration:
    """Test Turbopuffer message embedding functionality"""

    def test_should_use_tpuf_for_messages_settings(self):
        """should_use_tpuf_for_messages requires use_tpuf, an API key, AND embed_all_messages."""
        saved = (settings.use_tpuf, settings.tpuf_api_key, settings.embed_all_messages)
        try:
            # (use_tpuf, tpuf_api_key, embed_all_messages) -> expected result
            scenarios = [
                (True, "test-key", True, True),  # everything on
                (False, "test-key", True, False),  # use_tpuf off
                (True, "test-key", False, False),  # embed_all_messages off
                (False, "test-key", False, False),  # both off
                (True, None, True, False),  # API key missing
            ]
            for use_tpuf, api_key, embed_all, expected in scenarios:
                settings.use_tpuf = use_tpuf
                settings.tpuf_api_key = api_key
                settings.embed_all_messages = embed_all
                assert should_use_tpuf_for_messages() is expected
        finally:
            settings.use_tpuf, settings.tpuf_api_key, settings.embed_all_messages = saved

    def test_message_text_extraction(self, server, default_user):
        """Test extraction of text from various message content structures."""
        manager = server.message_manager

        def extract(content, role=MessageRole.user):
            # build a minimal message and run it through the extractor
            msg = PydanticMessage(role=role, content=content, agent_id="test-agent")
            return manager._extract_message_text(msg)

        # single TextContent item
        assert extract([TextContent(text="Simple text content")]) == "Simple text content"
        assert extract([TextContent(text="Single text content")]) == "Single text content"

        # multiple TextContent items are joined with single spaces
        parts = [
            TextContent(text="First part"),
            TextContent(text="Second part"),
            TextContent(text="Third part"),
        ]
        assert extract(parts) == "First part Second part Third part"

        # missing or empty content yields the empty string
        assert extract(None, role=MessageRole.system) == ""
        assert extract([], role=MessageRole.assistant) == ""

    @pytest.mark.asyncio
    @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
    async def test_message_embedding_without_config(self, server, default_user, sarah_agent, enable_message_embedding):
        """Messages are NOT embedded without an embedding_config, even when tpuf is enabled."""
        to_create = [
            PydanticMessage(
                role=MessageRole.user,
                content=[TextContent(text="Test message without embedding config")],
                agent_id=sarah_agent.id,
            ),
            PydanticMessage(
                role=MessageRole.assistant,
                content=[TextContent(text="Response without embedding config")],
                agent_id=sarah_agent.id,
            ),
        ]

        # creating without an embedding_config must still succeed (SQL-only write)
        created = await server.message_manager.create_many_messages_async(
            pydantic_msgs=to_create,
            actor=default_user,
            embedding_config=None,  # no config provided on purpose
        )

        assert len(created) == 2
        for msg in created:
            assert msg.agent_id == sarah_agent.id

        # messages must still be visible through the SQL path
        sql_messages = await server.message_manager.list_messages_for_agent_async(
            agent_id=sarah_agent.id,
            actor=default_user,
            limit=10,
        )
        assert len(sql_messages) >= 2

        # clean up
        await server.message_manager.delete_messages_by_ids_async([msg.id for msg in created], default_user)

    @pytest.mark.asyncio
    async def test_generic_reciprocal_rank_fusion(self):
        """Test the generic RRF function with different object types."""
        from letta.helpers.tpuf_client import TurbopufferClient

        client = TurbopufferClient()

        def make_passage(passage_id, text):
            # minimal passage for RRF ranking — embeddings are irrelevant here
            return Passage(
                id=passage_id,
                text=text,
                organization_id="org1",
                archive_id="archive1",
                created_at=datetime.now(timezone.utc),
                metadata_={},
                tags=[],
                embedding=[],
                embedding_config=None,
            )

        p1_id = "passage-78d49031-8502-49c1-a970-45663e9f6e07"
        p2_id = "passage-90df8386-4caf-49cc-acbc-d71526de6f77"
        passage1 = make_passage(p1_id, "First passage")
        passage2 = make_passage(p2_id, "Second passage")

        # passages via the backward-compatible wrapper
        combined = client._reciprocal_rank_fusion(
            vector_results=[(passage1, 0.9), (passage2, 0.7)],
            fts_results=[(passage2, 0.8), (passage1, 0.6)],
            vector_weight=0.5,
            fts_weight=0.5,
            top_k=2,
        )
        assert len(combined) == 2
        assert {p.id for p, _ in combined} == {p1_id, p2_id}

        # message dicts via the generic function
        msg1 = {"id": "m1", "text": "First message"}
        msg2 = {"id": "m2", "text": "Second message"}
        msg3 = {"id": "m3", "text": "Third message"}

        combined_msgs = client._generic_reciprocal_rank_fusion(
            vector_results=[(msg1, 0.95), (msg2, 0.85), (msg3, 0.75)],
            fts_results=[(msg2, 0.90), (msg3, 0.80), (msg1, 0.70)],
            get_id_func=lambda m: m["id"],
            vector_weight=0.6,
            fts_weight=0.4,
            top_k=3,
        )
        assert len(combined_msgs) == 3
        assert {m["id"] for m, _ in combined_msgs} == {"m1", "m2", "m3"}

        # edge case: both result lists empty
        empty_combined = client._generic_reciprocal_rank_fusion(
            vector_results=[],
            fts_results=[],
            get_id_func=lambda x: x["id"],
            vector_weight=0.5,
            fts_weight=0.5,
            top_k=10,
        )
        assert len(empty_combined) == 0

        # edge case: only one backend returned anything
        single_combined = client._generic_reciprocal_rank_fusion(
            vector_results=[(msg1, 0.9)],
            fts_results=[],
            get_id_func=lambda m: m["id"],
            vector_weight=0.5,
            fts_weight=0.5,
            top_k=10,
        )
        assert len(single_combined) == 1
        assert single_combined[0][0]["id"] == "m1"
test_message_dual_write_with_real_tpuf(self, enable_message_embedding): + """Test actual message embedding and storage in Turbopuffer""" + import uuid + from datetime import datetime, timezone + + from letta.helpers.tpuf_client import TurbopufferClient + from letta.schemas.enums import MessageRole + + client = TurbopufferClient() + agent_id = f"test-agent-{uuid.uuid4()}" + org_id = str(uuid.uuid4()) + + try: + # Prepare test messages + message_texts = [ + "Hello, how can I help you today?", + "I need help with Python programming.", + "Sure, what specific Python topic?", + ] + message_ids = [str(uuid.uuid4()) for _ in message_texts] + roles = [MessageRole.assistant, MessageRole.user, MessageRole.assistant] + created_ats = [datetime.now(timezone.utc) for _ in message_texts] + + # Generate embeddings (dummy for test) + embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))] + + # Insert messages into Turbopuffer + success = await client.insert_messages( + agent_id=agent_id, + message_texts=message_texts, + embeddings=embeddings, + message_ids=message_ids, + organization_id=org_id, + roles=roles, + created_ats=created_ats, + ) + + assert success == True + + # Verify we can query the messages + results = await client.query_messages( + agent_id=agent_id, + search_mode="timestamp", + top_k=10, + ) + + assert len(results) == 3 + # Results should be ordered by timestamp (most recent first) + for msg_dict, score in results: + assert msg_dict["agent_id"] == agent_id + assert msg_dict["organization_id"] == org_id + assert msg_dict["text"] in message_texts + assert msg_dict["role"] in ["assistant", "user"] + + finally: + # Clean up namespace + try: + await client.delete_all_messages(agent_id) + except: + pass + + @pytest.mark.asyncio + @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") + async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding): + """Test vector search on 
messages in Turbopuffer""" + import uuid + from datetime import datetime, timezone + + from letta.helpers.tpuf_client import TurbopufferClient + from letta.schemas.enums import MessageRole + + client = TurbopufferClient() + agent_id = f"test-agent-{uuid.uuid4()}" + org_id = str(uuid.uuid4()) + + try: + # Insert messages with different embeddings + message_texts = [ + "Python is a great programming language", + "JavaScript is used for web development", + "Machine learning with Python is powerful", + ] + message_ids = [str(uuid.uuid4()) for _ in message_texts] + roles = [MessageRole.assistant] * len(message_texts) + created_ats = [datetime.now(timezone.utc) for _ in message_texts] + + # Create embeddings that reflect content similarity + embeddings = [ + [1.0, 0.0, 0.0], # Python programming + [0.0, 1.0, 0.0], # JavaScript web + [0.8, 0.0, 0.2], # ML with Python (similar to first) + ] + + # Insert messages + await client.insert_messages( + agent_id=agent_id, + message_texts=message_texts, + embeddings=embeddings, + message_ids=message_ids, + organization_id=org_id, + roles=roles, + created_ats=created_ats, + ) + + # Search for Python-related messages using vector search + query_embedding = [0.9, 0.0, 0.1] # Similar to Python messages + results = await client.query_messages( + agent_id=agent_id, + query_embedding=query_embedding, + search_mode="vector", + top_k=2, + ) + + assert len(results) == 2 + # Should return Python-related messages first + result_texts = [msg["text"] for msg, _ in results] + assert "Python is a great programming language" in result_texts + assert "Machine learning with Python is powerful" in result_texts + + finally: + # Clean up namespace + try: + await client.delete_all_messages(agent_id) + except: + pass + + @pytest.mark.asyncio + @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") + async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding): + """Test hybrid search combining 
vector and FTS for messages""" + import uuid + from datetime import datetime, timezone + + from letta.helpers.tpuf_client import TurbopufferClient + from letta.schemas.enums import MessageRole + + client = TurbopufferClient() + agent_id = f"test-agent-{uuid.uuid4()}" + org_id = str(uuid.uuid4()) + + try: + # Insert diverse messages + message_texts = [ + "The quick brown fox jumps over the lazy dog", + "Machine learning algorithms are fascinating", + "Quick tutorial on Python programming", + "Deep learning with neural networks", + ] + message_ids = [str(uuid.uuid4()) for _ in message_texts] + roles = [MessageRole.assistant] * len(message_texts) + created_ats = [datetime.now(timezone.utc) for _ in message_texts] + + # Embeddings + embeddings = [ + [0.1, 0.9, 0.0], # fox text + [0.9, 0.1, 0.0], # ML algorithms + [0.5, 0.5, 0.0], # Quick Python + [0.8, 0.2, 0.0], # Deep learning + ] + + # Insert messages + await client.insert_messages( + agent_id=agent_id, + message_texts=message_texts, + embeddings=embeddings, + message_ids=message_ids, + organization_id=org_id, + roles=roles, + created_ats=created_ats, + ) + + # Hybrid search - vector similar to ML but text contains "quick" + results = await client.query_messages( + agent_id=agent_id, + query_embedding=[0.7, 0.3, 0.0], # Similar to ML messages + query_text="quick", # Text search for "quick" + search_mode="hybrid", + top_k=3, + vector_weight=0.5, + fts_weight=0.5, + ) + + assert len(results) > 0 + # Should get a mix of results based on both vector and text similarity + result_texts = [msg["text"] for msg, _ in results] + # At least one result should contain "quick" due to FTS + assert any("quick" in text.lower() for text in result_texts) + + finally: + # Clean up namespace + try: + await client.delete_all_messages(agent_id) + except: + pass + + @pytest.mark.asyncio + @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") + async def test_message_role_filtering_with_real_tpuf(self, 
enable_message_embedding): + """Test filtering messages by role""" + import uuid + from datetime import datetime, timezone + + from letta.helpers.tpuf_client import TurbopufferClient + from letta.schemas.enums import MessageRole + + client = TurbopufferClient() + agent_id = f"test-agent-{uuid.uuid4()}" + org_id = str(uuid.uuid4()) + + try: + # Insert messages with different roles + message_data = [ + ("Hello! How can I help?", MessageRole.assistant), + ("I need help with Python", MessageRole.user), + ("Here's a Python example", MessageRole.assistant), + ("Can you explain this?", MessageRole.user), + ("System message here", MessageRole.system), + ] + + message_texts = [text for text, _ in message_data] + roles = [role for _, role in message_data] + message_ids = [str(uuid.uuid4()) for _ in message_texts] + created_ats = [datetime.now(timezone.utc) for _ in message_texts] + embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))] + + # Insert messages + await client.insert_messages( + agent_id=agent_id, + message_texts=message_texts, + embeddings=embeddings, + message_ids=message_ids, + organization_id=org_id, + roles=roles, + created_ats=created_ats, + ) + + # Query only user messages + user_results = await client.query_messages( + agent_id=agent_id, + search_mode="timestamp", + top_k=10, + roles=[MessageRole.user], + ) + + assert len(user_results) == 2 + for msg, _ in user_results: + assert msg["role"] == "user" + assert msg["text"] in ["I need help with Python", "Can you explain this?"] + + # Query assistant and system messages + non_user_results = await client.query_messages( + agent_id=agent_id, + search_mode="timestamp", + top_k=10, + roles=[MessageRole.assistant, MessageRole.system], + ) + + assert len(non_user_results) == 3 + for msg, _ in non_user_results: + assert msg["role"] in ["assistant", "system"] + + finally: + # Clean up namespace + try: + await client.delete_all_messages(agent_id) + except: + pass + + 
@pytest.mark.asyncio + async def test_message_search_fallback_to_sql(self, server, default_user, sarah_agent): + """Test that message search falls back to SQL when Turbopuffer is disabled""" + # Save original settings + original_use_tpuf = settings.use_tpuf + original_embed_messages = settings.embed_all_messages + + try: + # Disable Turbopuffer for messages + settings.use_tpuf = False + settings.embed_all_messages = False + + # Create messages + messages = await server.message_manager.create_many_messages_async( + pydantic_msgs=[ + PydanticMessage( + role=MessageRole.user, + content=[TextContent(text="Test message for SQL fallback")], + agent_id=sarah_agent.id, + ) + ], + actor=default_user, + ) + + # Search should use SQL backend (not Turbopuffer) + results = await server.message_manager.search_messages_async( + actor=default_user, + agent_id=sarah_agent.id, + query_text="fallback", + limit=10, + ) + + # Should return results from SQL search + assert len(results) > 0 + # Extract text from messages and check for "fallback" + for msg in results: + text = server.message_manager._extract_message_text(msg) + if "fallback" in text.lower(): + break + else: + assert False, "No messages containing 'fallback' found" + + finally: + # Restore settings + settings.use_tpuf = original_use_tpuf + settings.embed_all_messages = original_embed_messages + + @pytest.mark.asyncio + @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") + async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding): + """Test filtering messages by date range""" + import uuid + from datetime import datetime, timedelta, timezone + + from letta.helpers.tpuf_client import TurbopufferClient + from letta.schemas.enums import MessageRole + + client = TurbopufferClient() + agent_id = f"test-agent-{uuid.uuid4()}" + org_id = str(uuid.uuid4()) + + try: + # Create messages with different timestamps + now = datetime.now(timezone.utc) + yesterday = now - 
timedelta(days=1) + last_week = now - timedelta(days=7) + last_month = now - timedelta(days=30) + + message_data = [ + ("Today's message", now), + ("Yesterday's message", yesterday), + ("Last week's message", last_week), + ("Last month's message", last_month), + ] + + for text, timestamp in message_data: + await client.insert_messages( + agent_id=agent_id, + message_texts=[text], + embeddings=[[1.0, 2.0, 3.0]], + message_ids=[str(uuid.uuid4())], + organization_id=org_id, + roles=[MessageRole.assistant], + created_ats=[timestamp], + ) + + # Query messages from the last 3 days + three_days_ago = now - timedelta(days=3) + recent_results = await client.query_messages( + agent_id=agent_id, + search_mode="timestamp", + top_k=10, + start_date=three_days_ago, + ) + + # Should get today's and yesterday's messages + assert len(recent_results) == 2 + result_texts = [msg["text"] for msg, _ in recent_results] + assert "Today's message" in result_texts + assert "Yesterday's message" in result_texts + + # Query messages between 2 weeks ago and 1 week ago + two_weeks_ago = now - timedelta(days=14) + week_results = await client.query_messages( + agent_id=agent_id, + search_mode="timestamp", + top_k=10, + start_date=two_weeks_ago, + end_date=last_week + timedelta(days=1), # Include last week's message + ) + + # Should get only last week's message + assert len(week_results) == 1 + assert week_results[0][0]["text"] == "Last week's message" + + # Query with vector search and date filtering + filtered_vector_results = await client.query_messages( + agent_id=agent_id, + query_embedding=[1.0, 2.0, 3.0], + search_mode="vector", + top_k=10, + start_date=three_days_ago, + ) + + # Should get only recent messages + assert len(filtered_vector_results) == 2 + for msg, _ in filtered_vector_results: + assert msg["text"] in ["Today's message", "Yesterday's message"] + + finally: + # Clean up namespace + try: + await client.delete_all_messages(agent_id) + except: + pass