From 2e3cabc080e91641b5b2f8fefad958b5e1cb2977 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 5 Sep 2025 14:28:27 -0700 Subject: [PATCH] feat: Add search messages endpoint [LET-4144] (#4434) * Add search messages endpoint * Run fern autogen and fix tests --- letta/agents/base_agent.py | 1 - letta/agents/helpers.py | 1 - letta/agents/letta_agent.py | 7 +- letta/helpers/tpuf_client.py | 236 ++++++++++++-- letta/schemas/message.py | 23 ++ letta/server/rest_api/routers/v1/agents.py | 38 ++- letta/services/agent_manager.py | 11 +- letta/services/agent_serialization_manager.py | 1 - letta/services/message_manager.py | 159 +++++----- letta/services/passage_manager.py | 3 +- letta/services/summarizer/summarizer.py | 1 - .../tool_executor/core_tool_executor.py | 1 - tests/integration_test_turbopuffer.py | 292 ++++++------------ 13 files changed, 462 insertions(+), 312 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 4ada4a94..99715e0b 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -178,7 +178,6 @@ class BaseAgent(ABC): curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor, - embedding_config=agent_state.embedding_config, project_id=agent_state.project_id, ) return [new_system_message] + in_context_messages[1:] diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 011675eb..d81fe6d0 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -117,7 +117,6 @@ async def _prepare_in_context_messages_async( new_in_context_messages = await message_manager.create_many_messages_async( create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor), actor=actor, - embedding_config=agent_state.embedding_config, project_id=agent_state.project_id, ) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 2d7baa8f..19adb37f 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -497,7 +497,6 @@ class LettaAgent(BaseAgent): await self.message_manager.create_many_messages_async( initial_messages, actor=self.actor, - embedding_config=agent_state.embedding_config, project_id=agent_state.project_id, ) elif step_progression <= StepProgression.LOGGED_TRACE: @@ -828,7 +827,6 @@ class LettaAgent(BaseAgent): await self.message_manager.create_many_messages_async( initial_messages, actor=self.actor, - embedding_config=agent_state.embedding_config, project_id=agent_state.project_id, ) elif step_progression <= StepProgression.LOGGED_TRACE: @@ -1267,7 +1265,6 @@ class LettaAgent(BaseAgent): await self.message_manager.create_many_messages_async( initial_messages, actor=self.actor, - embedding_config=agent_state.embedding_config, project_id=agent_state.project_id, ) elif step_progression <= StepProgression.LOGGED_TRACE: @@ -1676,7 +1673,7 @@ class LettaAgent(BaseAgent): ) messages_to_persist = (initial_messages or []) + tool_call_messages persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id + messages_to_persist, actor=self.actor, project_id=agent_state.project_id ) return persisted_messages, continue_stepping, stop_reason @@ -1788,7 +1785,7 @@ class LettaAgent(BaseAgent): messages_to_persist = (initial_messages or []) + tool_call_messages persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id + messages_to_persist, actor=self.actor, project_id=agent_state.project_id ) if run_id: diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 2920ad62..f3a088cf 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -4,26 +4,37 @@ import logging from datetime import datetime, timezone from typing import Any, Callable, List, Optional, Tuple +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE from letta.otel.tracing import trace_method +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole, TagMatchMode from letta.schemas.passage import Passage as PydanticPassage -from letta.settings import settings +from letta.settings import model_settings, settings logger = logging.getLogger(__name__) def should_use_tpuf() -> bool: + # We need OpenAI since we default to their embedding model return bool(settings.use_tpuf) and bool(settings.tpuf_api_key) def should_use_tpuf_for_messages() -> bool: """Check if Turbopuffer should be used for messages.""" - return should_use_tpuf() and bool(settings.embed_all_messages) + return should_use_tpuf() and bool(settings.embed_all_messages) and bool(model_settings.openai_api_key) class TurbopufferClient: """Client for managing archival memory with Turbopuffer vector database.""" + default_embedding_config = EmbeddingConfig( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + ) + def __init__(self, api_key: str = None, region: str = None): """Initialize Turbopuffer client.""" self.api_key = api_key or settings.tpuf_api_key @@ -38,6 +49,26 @@ class TurbopufferClient: if not self.api_key: raise ValueError("Turbopuffer API key not provided") + @trace_method + async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]: + """Generate embeddings using the default embedding configuration. + + Args: + texts: List of texts to embed + actor: User actor for embedding generation + + Returns: + List of embedding vectors + """ + from letta.llm_api.llm_client import LLMClient + + embedding_client = LLMClient.create( + provider_type=self.default_embedding_config.embedding_endpoint_type, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings(texts, self.default_embedding_config) + return embeddings + @trace_method async def _get_archive_namespace_name(self, archive_id: str) -> str: """Get namespace name for a specific archive.""" @@ -61,9 +92,9 @@ class TurbopufferClient: self, archive_id: str, text_chunks: List[str], - embeddings: List[List[float]], passage_ids: List[str], organization_id: str, + actor: "PydanticUser", tags: Optional[List[str]] = None, created_at: Optional[datetime] = None, ) -> List[PydanticPassage]: @@ -72,9 +103,9 @@ class TurbopufferClient: Args: archive_id: ID of the archive text_chunks: List of text chunks to store - embeddings: List of embedding vectors corresponding to text chunks passage_ids: List of passage IDs (must match 1:1 with text_chunks) organization_id: Organization ID for the passages + actor: User actor for embedding generation tags: Optional list of tags to attach to all passages created_at: Optional timestamp for retroactive entries (defaults to current UTC time) @@ -83,6 +114,9 @@ class TurbopufferClient: """ from turbopuffer import AsyncTurbopuffer + # generate embeddings using the default config + embeddings = await self._generate_embeddings(text_chunks, actor) + namespace_name = await self._get_archive_namespace_name(archive_id) # handle timestamp - ensure UTC @@ -102,8 +136,6 @@ class TurbopufferClient: raise ValueError("passage_ids must be provided for Turbopuffer insertion") if len(passage_ids) != len(text_chunks): raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})") - if len(passage_ids) != len(embeddings): - raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})") # prepare column-based data for turbopuffer - optimized for batch insert ids = [] @@ -137,7 +169,7 @@ class TurbopufferClient: metadata_={}, tags=tags or [], # Include tags in the passage embedding=embedding, - embedding_config=None, # Will be set by caller if needed + embedding_config=self.default_embedding_config, # Will be set by caller if needed ) passages.append(passage) @@ -177,9 +209,9 @@ class TurbopufferClient: self, agent_id: str, message_texts: List[str], - embeddings: List[List[float]], message_ids: List[str], organization_id: str, + actor: "PydanticUser", roles: List[MessageRole], created_ats: List[datetime], project_id: Optional[str] = None, @@ -189,9 +221,9 @@ class TurbopufferClient: Args: agent_id: ID of the agent message_texts: List of message text content to store - embeddings: List of embedding vectors corresponding to message texts message_ids: List of message IDs (must match 1:1 with message_texts) organization_id: Organization ID for the messages + actor: User actor for embedding generation roles: List of message roles corresponding to each message created_ats: List of creation timestamps for each message project_id: Optional project ID for all messages @@ -201,6 +233,9 @@ class TurbopufferClient: """ from turbopuffer import AsyncTurbopuffer + # generate embeddings using the default config + embeddings = await self._generate_embeddings(message_texts, actor) + namespace_name = await self._get_message_namespace_name(agent_id, organization_id) # validation checks @@ -208,8 +243,6 @@ class TurbopufferClient: raise ValueError("message_ids must be provided for Turbopuffer insertion") if len(message_ids) != len(message_texts): raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})") - if len(message_ids) != len(embeddings): - raise ValueError(f"message_ids length ({len(message_ids)}) must match embeddings length ({len(embeddings)})") if len(message_ids) != len(roles): raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})") if len(message_ids) != len(created_ats): @@ -390,7 +423,7 @@ class TurbopufferClient: async def query_passages( self, archive_id: str, - query_embedding: Optional[List[float]] = None, + actor: "PydanticUser", query_text: Optional[str] = None, search_mode: str = "vector", # "vector", "fts", "hybrid" top_k: int = 10, @@ -405,8 +438,8 @@ class TurbopufferClient: Args: archive_id: ID of the archive - query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes) - query_text: Text query for full-text search (required for "fts" and "hybrid" modes) + actor: User actor for embedding generation + query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes) search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector") top_k: Number of results to return tags: Optional list of tags to filter by @@ -419,6 +452,12 @@ class TurbopufferClient: Returns: List of (passage, score, metadata) tuples with relevance rankings """ + # generate embedding for vector/hybrid search if query_text is provided + query_embedding = None + if query_text and search_mode in ["vector", "hybrid"]: + embeddings = await self._generate_embeddings([query_text], actor) + query_embedding = embeddings[0] + # Check if we should fallback to timestamp-based retrieval if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: # Fallback to retrieving most recent passages when no search query is provided @@ -519,11 +558,11 @@ class TurbopufferClient: raise @trace_method - async def query_messages( + async def query_messages_by_agent_id( self, agent_id: str, organization_id: str, - query_embedding: Optional[List[float]] = None, + actor: "PydanticUser", query_text: Optional[str] = None, search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp" top_k: int = 10, @@ -539,8 +578,8 @@ class TurbopufferClient: Args: agent_id: ID of the agent (used for filtering results) organization_id: Organization ID for namespace lookup - query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes) - query_text: Text query for full-text search (required for "fts" and "hybrid" modes) + actor: User actor for embedding generation + query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes) search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector") top_k: Number of results to return roles: Optional list of message roles to filter by @@ -556,6 +595,12 @@ class TurbopufferClient: - score is the final relevance score - metadata contains individual scores and ranking information """ + # generate embedding for vector/hybrid search if query_text is provided + query_embedding = None + if query_text and search_mode in ["vector", "hybrid"]: + embeddings = await self._generate_embeddings([query_text], actor) + query_embedding = embeddings[0] + # Check if we should fallback to timestamp-based retrieval if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: # Fallback to retrieving most recent messages when no search query is provided @@ -658,6 +703,159 @@ class TurbopufferClient: logger.error(f"Failed to query messages from Turbopuffer: {e}") raise + async def query_messages_by_org_id( + self, + organization_id: str, + actor: "PydanticUser", + query_text: Optional[str] = None, + search_mode: str = "hybrid", # "vector", "fts", "hybrid" + top_k: int = 10, + roles: Optional[List[MessageRole]] = None, + project_id: Optional[str] = None, + vector_weight: float = 0.5, + fts_weight: float = 0.5, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> List[Tuple[dict, float, dict]]: + """Query messages from Turbopuffer across an entire organization. + + Args: + organization_id: Organization ID for namespace lookup (required) + actor: User actor for embedding generation + query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes) + search_mode: Search mode - "vector", "fts", or "hybrid" (default: "hybrid") + top_k: Number of results to return + roles: Optional list of message roles to filter by + project_id: Optional project ID to filter messages by + vector_weight: Weight for vector search results in hybrid mode (default: 0.5) + fts_weight: Weight for FTS results in hybrid mode (default: 0.5) + start_date: Optional datetime to filter messages created after this date + end_date: Optional datetime to filter messages created on or before this date (inclusive) + + Returns: + List of (message_dict, score, metadata) tuples where: + - message_dict contains id, text, role, created_at, agent_id + - score is the final relevance score (RRF score for hybrid, rank-based for single mode) + - metadata contains individual scores and ranking information + """ + # generate embedding for vector/hybrid search if query_text is provided + query_embedding = None + if query_text and search_mode in ["vector", "hybrid"]: + embeddings = await self._generate_embeddings([query_text], actor) + query_embedding = embeddings[0] + # namespace is org-scoped + namespace_name = f"letta_messages_{organization_id}" + + # build filters + all_filters = [] + + # role filter + if roles: + role_values = [r.value for r in roles] + if len(role_values) == 1: + all_filters.append(("role", "Eq", role_values[0])) + else: + all_filters.append(("role", "In", role_values)) + + # project filter + if project_id: + all_filters.append(("project_id", "Eq", project_id)) + + # date filters + if start_date: + all_filters.append(("created_at", "Gte", start_date)) + if end_date: + # make end_date inclusive of the entire day + if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0: + from datetime import timedelta + + end_date = end_date + timedelta(days=1) - timedelta(microseconds=1) + all_filters.append(("created_at", "Lte", end_date)) + + # combine filters + final_filter = None + if len(all_filters) == 1: + final_filter = all_filters[0] + elif len(all_filters) > 1: + final_filter = ("And", all_filters) + + try: + # execute query + result = await self._execute_query( + namespace_name=namespace_name, + search_mode=search_mode, + query_embedding=query_embedding, + query_text=query_text, + top_k=top_k, + include_attributes=["text", "organization_id", "agent_id", "role", "created_at"], + filters=final_filter, + vector_weight=vector_weight, + fts_weight=fts_weight, + ) + + # process results based on search mode + if search_mode == "hybrid": + # for hybrid mode, we get a multi-query response + vector_results = self._process_message_query_results(result.results[0]) + fts_results = self._process_message_query_results(result.results[1]) + + # use existing RRF method - it already returns metadata with ranks + results_with_metadata = self._reciprocal_rank_fusion( + vector_results=vector_results, + fts_results=fts_results, + get_id_func=lambda msg_dict: msg_dict["id"], + vector_weight=vector_weight, + fts_weight=fts_weight, + top_k=top_k, + ) + + # add raw scores to metadata if available + vector_scores = {} + for row in result.results[0].rows: + if hasattr(row, "dist"): + vector_scores[row.id] = row.dist + + fts_scores = {} + for row in result.results[1].rows: + if hasattr(row, "score"): + fts_scores[row.id] = row.score + + # enhance metadata with raw scores + enhanced_results = [] + for msg_dict, rrf_score, metadata in results_with_metadata: + msg_id = msg_dict["id"] + if msg_id in vector_scores: + metadata["vector_score"] = vector_scores[msg_id] + if msg_id in fts_scores: + metadata["fts_score"] = fts_scores[msg_id] + enhanced_results.append((msg_dict, rrf_score, metadata)) + + return enhanced_results + else: + # for single queries (vector or fts) + results = self._process_message_query_results(result) + results_with_metadata = [] + for idx, msg_dict in enumerate(results): + metadata = { + "combined_score": 1.0 / (idx + 1), + "search_mode": search_mode, + f"{search_mode}_rank": idx + 1, + } + + # add raw score if available + if hasattr(result.rows[idx], "dist"): + metadata["vector_score"] = result.rows[idx].dist + elif hasattr(result.rows[idx], "score"): + metadata["fts_score"] = result.rows[idx].score + + results_with_metadata.append((msg_dict, metadata["combined_score"], metadata)) + + return results_with_metadata + + except Exception as e: + logger.error(f"Failed to query messages from Turbopuffer: {e}") + raise + def _process_message_query_results(self, result) -> List[dict]: """Process results from a message query into message dicts. @@ -703,7 +901,7 @@ class TurbopufferClient: tags=passage_tags, # Set the actual tags from the passage # Set required fields to empty/default values since we don't store embeddings embedding=[], # Empty embedding since we don't return it from Turbopuffer - embedding_config=None, # No embedding config needed for retrieved passages + embedding_config=self.default_embedding_config, # No embedding config needed for retrieved passages ) # handle score based on search type diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 929b95c2..769389ae 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -1187,3 +1187,26 @@ class ToolReturn(BaseModel): stdout: Optional[List[str]] = Field(default=None, description="Captured stdout (e.g. prints, logs) from the tool invocation") stderr: Optional[List[str]] = Field(default=None, description="Captured stderr from the tool invocation") # func_return: Optional[Any] = Field(None, description="The function return object") + + +class MessageSearchRequest(BaseModel): + """Request model for searching messages across the organization""" + + query_text: Optional[str] = Field(None, description="Text query for full-text search") + search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use") + roles: Optional[List[MessageRole]] = Field(None, description="Filter messages by role") + project_id: Optional[str] = Field(None, description="Filter messages by project ID") + limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100) + start_date: Optional[datetime] = Field(None, description="Filter messages created after this date") + end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date") + + +class MessageSearchResult(BaseModel): + """Result from a message search operation with scoring details.""" + + message: Message = Field(..., description="The message content and metadata") + fts_score: Optional[float] = Field(None, description="Full-text search (BM25) score if FTS was used") + fts_rank: Optional[int] = Field(None, description="Full-text search rank position if FTS was used") + vector_score: Optional[float] = Field(None, description="Vector similarity score if vector search was used") + vector_rank: Optional[int] = Field(None, description="Vector search rank position if vector search was used") + rrf_score: float = Field(..., description="Reciprocal Rank Fusion combined score") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 13da1b29..f29bcce5 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -39,7 +39,7 @@ from letta.schemas.memory import ( CreateArchivalMemory, Memory, ) -from letta.schemas.message import MessageCreate +from letta.schemas.message import MessageCreate, MessageSearchRequest, MessageSearchResult from letta.schemas.passage import Passage from letta.schemas.run import Run from letta.schemas.source import Source @@ -1498,6 +1498,42 @@ async def cancel_agent_run( return results +@router.post("/messages/search", response_model=List[MessageSearchResult], operation_id="search_messages") +async def search_messages( + request: MessageSearchRequest = Body(...), + server: SyncServer = Depends(get_letta_server), + actor_id: str | None = Header(None, alias="user_id"), +): + """ + Search messages across the entire organization with optional project filtering. + Returns messages with FTS/vector ranks and total RRF score. + + Requires message embedding and Turbopuffer to be enabled. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + + # get embedding config from the default agent if needed + # check if any agents exist in the org + agent_count = await server.agent_manager.size_async(actor=actor) + if agent_count == 0: + raise HTTPException(status_code=400, detail="No agents found in organization to derive embedding configuration from") + + try: + results = await server.message_manager.search_messages_org_async( + actor=actor, + query_text=request.query_text, + search_mode=request.search_mode, + roles=request.roles, + project_id=request.project_id, + limit=request.limit, + start_date=request.start_date, + end_date=request.end_date, + ) + return results + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + async def _process_message_background( run_id: str, server: SyncServer, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 60272936..00a85a05 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -719,9 +719,7 @@ class AgentManager: # Only create messages if we initialized with messages if not _init_with_no_messages: - await self.message_manager.create_many_messages_async( - pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config, project_id=result.project_id - ) + await self.message_manager.create_many_messages_async(pydantic_msgs=init_messages, actor=actor, project_id=result.project_id) return result @enforce_types @@ -1834,7 +1832,6 @@ class AgentManager: message_id=curr_system_message.id, message_update=MessageUpdate(**temp_message.model_dump()), actor=actor, - embedding_config=agent_state.embedding_config, project_id=agent_state.project_id, ) else: @@ -1889,9 +1886,7 @@ class AgentManager: self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser ) -> PydanticAgentState: agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) - messages = await self.message_manager.create_many_messages_async( - messages, actor=actor, embedding_config=agent.embedding_config, project_id=agent.project_id - ) + messages = await self.message_manager.create_many_messages_async(messages, actor=actor, project_id=agent.project_id) message_ids = agent.message_ids or [] message_ids += [m.id for m in messages] return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor) @@ -2692,7 +2687,6 @@ class AgentManager: # use hybrid search to combine vector and full-text search passages_with_scores = await tpuf_client.query_passages( archive_id=archive_ids[0], - query_embedding=query_embedding, query_text=query_text, # pass text for potential hybrid search search_mode="hybrid", # use hybrid mode for better results top_k=limit, @@ -2700,6 +2694,7 @@ class AgentManager: tag_match_mode=tag_match_mode or TagMatchMode.ANY, start_date=start_date, end_date=end_date, + actor=actor, ) # Return full tuples with metadata diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index 0bdcf5c6..992c50ea 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -678,7 +678,6 @@ class AgentSerializationManager: created_messages = await self.message_manager.create_many_messages_async( pydantic_msgs=messages, actor=actor, - embedding_config=created_agent.embedding_config, project_id=created_agent.project_id, ) imported_count += len(created_messages) diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index c5a977b8..6f774337 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -11,11 +11,10 @@ from letta.orm.agent import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.message import Message as MessageModel from letta.otel.tracing import trace_method -from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole from letta.schemas.letta_message import LettaMessageUpdateUnion from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent -from letta.schemas.message import Message as PydanticMessage, MessageUpdate +from letta.schemas.message import Message as PydanticMessage, MessageSearchResult, MessageUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.file_manager import FileManager @@ -314,7 +313,6 @@ class MessageManager: self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser, - embedding_config: Optional[EmbeddingConfig] = None, strict_mode: bool = False, project_id: Optional[str] = None, ) -> List[PydanticMessage]: @@ -324,7 +322,6 @@ class MessageManager: Args: pydantic_msgs: List of Pydantic message models to create actor: User performing the action - embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer strict_mode: If True, wait for embedding to complete; if False, run in background project_id: Optional project ID for the messages (for Turbopuffer indexing) @@ -365,20 +362,20 @@ class MessageManager: result = [msg.to_pydantic() for msg in created_messages] await session.commit() - # embed messages in turbopuffer if enabled and embedding_config provided + # embed messages in turbopuffer if enabled from letta.helpers.tpuf_client import should_use_tpuf_for_messages - if should_use_tpuf_for_messages() and embedding_config and result: + if should_use_tpuf_for_messages() and result: # extract agent_id from the first message (all should have same agent_id) agent_id = result[0].agent_id if agent_id: if strict_mode: # wait for embedding to complete - await self._embed_messages_background(result, embedding_config, actor, agent_id, project_id) + await self._embed_messages_background(result, actor, agent_id, project_id) else: # fire and forget - run embedding in background fire_and_forget( - self._embed_messages_background(result, embedding_config, actor, agent_id, project_id), + self._embed_messages_background(result, actor, agent_id, project_id), task_name=f"embed_messages_for_agent_{agent_id}", ) @@ -387,7 +384,6 @@ class MessageManager: async def _embed_messages_background( self, messages: List[PydanticMessage], - embedding_config: EmbeddingConfig, actor: PydanticUser, agent_id: str, project_id: Optional[str] = None, @@ -396,14 +392,12 @@ class MessageManager: Args: messages: List of messages to embed - embedding_config: Embedding configuration actor: User performing the action agent_id: Agent ID for the messages project_id: Optional project ID for the messages """ try: from letta.helpers.tpuf_client import TurbopufferClient - from letta.llm_api.llm_client import LLMClient # extract text content from each message message_texts = [] @@ -423,21 +417,14 @@ class MessageManager: created_ats.append(msg.created_at) if message_texts: - # generate embeddings using provided config - embedding_client = LLMClient.create( - provider_type=embedding_config.embedding_endpoint_type, - actor=actor, - ) - embeddings = await embedding_client.request_embeddings(message_texts, embedding_config) - - # insert to turbopuffer + # insert to turbopuffer - TurbopufferClient will generate embeddings internally tpuf_client = TurbopufferClient() await tpuf_client.insert_messages( agent_id=agent_id, message_texts=message_texts, - embeddings=embeddings, message_ids=message_ids, organization_id=actor.organization_id, + actor=actor, roles=roles, created_ats=created_ats, project_id=project_id, @@ -550,7 +537,6 @@ class MessageManager: message_id: str, message_update: MessageUpdate, actor: PydanticUser, - embedding_config: Optional[EmbeddingConfig] = None, strict_mode: bool = False, project_id: Optional[str] = None, ) -> PydanticMessage: @@ -562,7 +548,6 @@ class MessageManager: message_id: ID of the message to update message_update: Update data for the message actor: User performing the action - embedding_config: Optional embedding configuration for Turbopuffer strict_mode: If True, wait for embedding update to complete; if False, run in background project_id: Optional project ID for the message (for Turbopuffer indexing) """ @@ -582,7 +567,7 @@ class MessageManager: # update message in turbopuffer if enabled (delete and re-insert) from letta.helpers.tpuf_client import should_use_tpuf_for_messages - if should_use_tpuf_for_messages() and embedding_config and pydantic_message.agent_id: + if should_use_tpuf_for_messages() and pydantic_message.agent_id: # extract text content from updated message text = self._extract_message_text(pydantic_message) @@ -590,51 +575,42 @@ class MessageManager: if text: if strict_mode: # wait for embedding update to complete - await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id) + await self._update_message_embedding_background(pydantic_message, text, actor, project_id) else: # fire and forget - run embedding update in background fire_and_forget( - self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id), + self._update_message_embedding_background(pydantic_message, text, actor, project_id), task_name=f"update_message_embedding_{message_id}", ) return pydantic_message async def _update_message_embedding_background( - self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser, project_id: Optional[str] = None + self, message: PydanticMessage, text: str, actor: PydanticUser, project_id: Optional[str] = None ) -> None: """Background task to update a message's embedding in Turbopuffer. Args: message: The updated message text: Extracted text content from the message - embedding_config: Embedding configuration actor: User performing the action project_id: Optional project ID for the message """ try: from letta.helpers.tpuf_client import TurbopufferClient - from letta.llm_api.llm_client import LLMClient tpuf_client = TurbopufferClient() # delete old message from turbopuffer await tpuf_client.delete_messages(agent_id=message.agent_id, organization_id=actor.organization_id, message_ids=[message.id]) - # generate new embedding - embedding_client = LLMClient.create( - provider_type=embedding_config.embedding_endpoint_type, - actor=actor, - ) - embeddings = await embedding_client.request_embeddings([text], embedding_config) - - # re-insert with updated content + # re-insert with updated content - TurbopufferClient will generate embeddings internally await tpuf_client.insert_messages( agent_id=message.agent_id, message_texts=[text], - embeddings=embeddings, message_ids=[message.id], organization_id=actor.organization_id, + actor=actor, roles=[message.role], created_ats=[message.created_at], project_id=project_id, @@ -1119,13 +1095,11 @@ class MessageManager: agent_id: str, actor: PydanticUser, query_text: Optional[str] = None, - query_embedding: Optional[List[float]] = None, search_mode: str = "hybrid", roles: Optional[List[MessageRole]] = None, limit: int = 50, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, - embedding_config: Optional[EmbeddingConfig] = None, ) -> List[Tuple[PydanticMessage, dict]]: """ Search messages using Turbopuffer if enabled, otherwise fall back to SQL search. @@ -1133,14 +1107,12 @@ class MessageManager: Args: agent_id: ID of the agent whose messages to search actor: User performing the search - query_text: Text query for full-text search - query_embedding: Optional pre-computed embedding for vector search + query_text: Text query (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes) search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "hybrid") roles: Optional list of message roles to filter by limit: Maximum number of results to return start_date: Optional filter for messages created after this date end_date: Optional filter for messages created on or before this date (inclusive) - embedding_config: Optional embedding configuration for generating query embedding Returns: List of tuples (message, metadata) where metadata contains relevance scores @@ -1150,36 +1122,12 @@ class MessageManager: # check if we should use turbopuffer if should_use_tpuf_for_messages(): try: - # generate embedding if needed and not provided - if search_mode in ["vector", "hybrid"] and query_embedding is None and query_text: - if not embedding_config: - # fall back to SQL search if no embedding config - logger.warning("No embedding config provided for vector search, falling back to SQL") - return await self.list_messages_for_agent_async( - agent_id=agent_id, - actor=actor, - query_text=query_text, - roles=roles, - limit=limit, - ascending=False, - ) - - # generate embedding from query text - from letta.llm_api.llm_client import LLMClient - - embedding_client = LLMClient.create( - provider_type=embedding_config.embedding_endpoint_type, - actor=actor, - ) - embeddings = await embedding_client.request_embeddings([query_text], embedding_config) - query_embedding = embeddings[0] - - # use turbopuffer for search + # use turbopuffer for search - TurbopufferClient will generate embeddings internally tpuf_client = TurbopufferClient() - results = await tpuf_client.query_messages( + results = await tpuf_client.query_messages_by_agent_id( agent_id=agent_id, organization_id=actor.organization_id, - query_embedding=query_embedding, + actor=actor, query_text=query_text, search_mode=search_mode, top_k=limit, @@ -1255,3 +1203,76 @@ class MessageManager: } message_tuples.append((message, metadata)) return message_tuples + + async def search_messages_org_async( + self, + actor: PydanticUser, + query_text: Optional[str] = None, + search_mode: str = "hybrid", + roles: Optional[List[MessageRole]] = None, + project_id: Optional[str] = None, + limit: int = 50, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> List[MessageSearchResult]: + """ + Search messages across entire organization using Turbopuffer. + + Args: + actor: User performing the search (must have org access) + query_text: Text query for full-text search + search_mode: "vector", "fts", or "hybrid" (default: "hybrid") + roles: Optional list of message roles to filter by + project_id: Optional project ID to filter messages by + limit: Maximum number of results to return + start_date: Optional filter for messages created after this date + end_date: Optional filter for messages created on or before this date (inclusive) + + Returns: + List of MessageSearchResult objects with scoring details + + Raises: + ValueError: If message embedding or Turbopuffer is not enabled + """ + from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages + + # check if turbopuffer is enabled + # TODO: extend to non-Turbopuffer in the future. + if not should_use_tpuf_for_messages(): + raise ValueError("Message search requires message embedding, OpenAI, and Turbopuffer to be enabled.") + + # use turbopuffer for search - TurbopufferClient will generate embeddings internally + tpuf_client = TurbopufferClient() + results = await tpuf_client.query_messages_by_org_id( + organization_id=actor.organization_id, + actor=actor, + query_text=query_text, + search_mode=search_mode, + top_k=limit, + roles=roles, + project_id=project_id, + start_date=start_date, + end_date=end_date, + ) + + # convert results to MessageSearchResult objects + if not results: + return [] + + # create message mapping + message_ids = [msg_dict["id"] for msg_dict, _, _ in results] + messages = await self.get_messages_by_ids_async(message_ids=message_ids, actor=actor) + message_mapping = {message.id: message for message in messages} + + # create search results using list comprehension + return [ + MessageSearchResult( + message=message_mapping.get(msg_dict["id"]), + fts_score=metadata.get("fts_score"), + fts_rank=metadata.get("fts_rank"), + vector_score=metadata.get("vector_score"), + vector_rank=metadata.get("vector_rank"), + rrf_score=rrf_score, + ) + for msg_dict, rrf_score, metadata in results + ] diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 29033094..a5201554 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -623,12 +623,13 @@ class PassageManager: passage_texts = [p.text for p in passages] # Insert to Turbopuffer with the same IDs as SQL + # TurbopufferClient will generate embeddings internally using default config await tpuf_client.insert_archival_memories( archive_id=archive.id, text_chunks=passage_texts, - embeddings=embeddings, passage_ids=passage_ids, # Use same IDs as SQL organization_id=actor.organization_id, + actor=actor, tags=tags, created_at=passages[0].created_at if passages else None, ) diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index fdcf0327..d6e0fe01 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -195,7 +195,6 @@ class Summarizer: await self.message_manager.create_many_messages_async( pydantic_msgs=[summary_message_obj], actor=self.actor, - embedding_config=agent_state.embedding_config, project_id=agent_state.project_id, ) diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index 7b07f17c..04bac420 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -163,7 +163,6 @@ class LettaCoreToolExecutor(ToolExecutor): limit=search_limit, start_date=start_datetime, end_date=end_datetime, - embedding_config=agent_state.embedding_config, ) if len(message_results) == 0: diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index fe88522a..27dc983d 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -233,7 +233,7 @@ class TestTurbopufferIntegration: pass @pytest.mark.asyncio - async def test_turbopuffer_metadata_attributes(self, enable_turbopuffer): + async def test_turbopuffer_metadata_attributes(self, default_user, enable_turbopuffer): """Test that Turbopuffer properly stores and retrieves metadata attributes""" # Only run if we have a real API key @@ -273,17 +273,16 @@ class TestTurbopufferIntegration: result = await client.insert_archival_memories( archive_id=archive_id, text_chunks=[d["text"] for d in test_data], - embeddings=[d["vector"] for d in test_data], passage_ids=[d["id"] for d in test_data], organization_id="org-123", # Default org + actor=default_user, created_at=datetime.now(timezone.utc), ) assert len(result) == 3 # Query all passages (no tag filtering) - query_vector = [0.15] * 1536 - results = await client.query_passages(archive_id=archive_id, query_embedding=query_vector, top_k=10) + results = await client.query_passages(archive_id=archive_id, actor=default_user, top_k=10) # Should get all passages assert len(results) == 3 # All three passages @@ -339,7 +338,7 @@ class TestTurbopufferIntegration: @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing") - async def test_hybrid_search_with_real_tpuf(self, enable_turbopuffer): + async def test_hybrid_search_with_real_tpuf(self, default_user, enable_turbopuffer): """Test hybrid search functionality combining vector and full-text search""" import uuid @@ -366,13 +365,14 @@ class TestTurbopufferIntegration: # Insert passages await client.insert_archival_memories( - archive_id=archive_id, text_chunks=texts, embeddings=embeddings, passage_ids=passage_ids, organization_id=org_id + archive_id=archive_id, text_chunks=texts, passage_ids=passage_ids, organization_id=org_id, actor=default_user ) # Test vector-only search vector_results = await client.query_passages( archive_id=archive_id, - query_embedding=[1.0, 6.0, 11.0], # similar to second passage embedding + actor=default_user, + query_text="python programming tutorial", search_mode="vector", top_k=3, ) @@ -382,7 +382,7 @@ class TestTurbopufferIntegration: # Test FTS-only search fts_results = await client.query_passages( - archive_id=archive_id, query_text="Turbopuffer vector database", search_mode="fts", top_k=3 + archive_id=archive_id, actor=default_user, query_text="Turbopuffer vector database", search_mode="fts", top_k=3 ) assert 0 < len(fts_results) <= 3 # should find passages mentioning Turbopuffer @@ -393,7 +393,7 @@ class TestTurbopufferIntegration: # Test hybrid search hybrid_results = await client.query_passages( archive_id=archive_id, - query_embedding=[2.0, 7.0, 12.0], + actor=default_user, query_text="vector search Turbopuffer", search_mode="hybrid", top_k=3, @@ -412,7 +412,7 @@ class TestTurbopufferIntegration: # Test with different weights vector_heavy_results = await client.query_passages( archive_id=archive_id, - query_embedding=[0.0, 5.0, 10.0], # very similar to first passage + actor=default_user, query_text="quick brown fox", # matches second passage search_mode="hybrid", top_k=3, @@ -423,16 +423,13 @@ class TestTurbopufferIntegration: # all results should have scores assert all(isinstance(score, float) for _, score, _ in vector_heavy_results) - # Test error handling - missing text for hybrid mode (embedding provided but text missing) - with pytest.raises(ValueError, match="Both query_embedding and query_text are required"): - await client.query_passages(archive_id=archive_id, query_embedding=[1.0, 2.0, 3.0], search_mode="hybrid", top_k=3) - - # Test error handling - missing embedding for hybrid mode (text provided but embedding missing) - with pytest.raises(ValueError, match="Both query_embedding and query_text are required"): - await client.query_passages(archive_id=archive_id, query_text="test", search_mode="hybrid", top_k=3) + # Test with different search modes + await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="vector", top_k=3) + await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="fts", top_k=3) + await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="hybrid", top_k=3) # Test explicit timestamp mode - timestamp_results = await client.query_passages(archive_id=archive_id, search_mode="timestamp", top_k=3) + timestamp_results = await client.query_passages(archive_id=archive_id, actor=default_user, search_mode="timestamp", top_k=3) assert len(timestamp_results) <= 3 # Should return passages ordered by timestamp (most recent first) assert all(isinstance(passage, Passage) for passage, _, _ in timestamp_results) @@ -446,7 +443,7 @@ class TestTurbopufferIntegration: @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing") - async def test_tag_filtering_with_real_tpuf(self, enable_turbopuffer): + async def test_tag_filtering_with_real_tpuf(self, default_user, enable_turbopuffer): """Test tag filtering functionality with AND and OR logic""" import uuid @@ -479,13 +476,13 @@ class TestTurbopufferIntegration: passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts] # Insert passages with tags - for i, (text, tags, embedding, passage_id) in enumerate(zip(texts, tag_sets, embeddings, passage_ids)): + for i, (text, tags, passage_id) in enumerate(zip(texts, tag_sets, passage_ids)): await client.insert_archival_memories( archive_id=archive_id, text_chunks=[text], - embeddings=[embedding], passage_ids=[passage_id], organization_id=org_id, + actor=default_user, tags=tags, created_at=datetime.now(timezone.utc), ) @@ -493,7 +490,8 @@ class TestTurbopufferIntegration: # Test tag filtering with "any" mode (should find passages with any of the specified tags) python_any_results = await client.query_passages( archive_id=archive_id, - query_embedding=[1.0, 6.0, 11.0], + actor=default_user, + query_text="python programming", search_mode="vector", top_k=10, tags=["python"], @@ -511,7 +509,8 @@ class TestTurbopufferIntegration: # Test tag filtering with "all" mode python_tutorial_all_results = await client.query_passages( archive_id=archive_id, - query_embedding=[1.0, 6.0, 11.0], + actor=default_user, + query_text="python tutorial", search_mode="vector", top_k=10, tags=["python", "tutorial"], @@ -528,6 +527,7 @@ class TestTurbopufferIntegration: # Test tag filtering with FTS mode js_fts_results = await client.query_passages( archive_id=archive_id, + actor=default_user, query_text="javascript", search_mode="fts", top_k=10, @@ -545,7 +545,7 @@ class TestTurbopufferIntegration: # Test hybrid search with tags python_hybrid_results = await client.query_passages( archive_id=archive_id, - query_embedding=[2.0, 7.0, 12.0], + actor=default_user, query_text="python programming", search_mode="hybrid", top_k=10, @@ -569,7 +569,7 @@ class TestTurbopufferIntegration: pass @pytest.mark.asyncio - async def test_temporal_filtering_with_real_tpuf(self, enable_turbopuffer): + async def test_temporal_filtering_with_real_tpuf(self, default_user, enable_turbopuffer): """Test temporal filtering with date ranges""" from datetime import datetime, timedelta, timezone @@ -601,15 +601,14 @@ class TestTurbopufferIntegration: # We need to generate embeddings for the passages # For testing, we'll use simple dummy embeddings for text, timestamp in test_passages: - dummy_embedding = [1.0, 2.0, 3.0] # Simple test embedding passage_id = f"passage-{uuid.uuid4()}" await client.insert_archival_memories( archive_id=archive_id, text_chunks=[text], - embeddings=[dummy_embedding], passage_ids=[passage_id], organization_id="test-org", + actor=default_user, created_at=timestamp, ) @@ -617,7 +616,8 @@ class TestTurbopufferIntegration: three_days_ago = now - timedelta(days=3) results = await client.query_passages( archive_id=archive_id, - query_embedding=[1.0, 2.0, 3.0], + actor=default_user, + query_text="meeting notes", search_mode="vector", top_k=10, start_date=three_days_ago, @@ -637,7 +637,8 @@ class TestTurbopufferIntegration: two_weeks_ago = now - timedelta(days=14) results = await client.query_passages( archive_id=archive_id, - query_embedding=[1.0, 2.0, 3.0], + actor=default_user, + query_text="meeting notes", search_mode="vector", top_k=10, start_date=two_weeks_ago, @@ -652,7 +653,8 @@ class TestTurbopufferIntegration: # Test 3: Query with only end_date (everything before yesterday) results = await client.query_passages( archive_id=archive_id, - query_embedding=[1.0, 2.0, 3.0], + actor=default_user, + query_text="meeting notes", search_mode="vector", top_k=10, end_date=yesterday + timedelta(hours=12), # Middle of yesterday @@ -667,6 +669,7 @@ class TestTurbopufferIntegration: # Test 4: Test with FTS mode and date filtering results = await client.query_passages( archive_id=archive_id, + actor=default_user, query_text="meeting notes project", search_mode="fts", top_k=10, @@ -682,7 +685,7 @@ class TestTurbopufferIntegration: # Test 5: Test with hybrid mode and date filtering results = await client.query_passages( archive_id=archive_id, - query_embedding=[1.0, 2.0, 3.0], + actor=default_user, query_text="sprint review", search_mode="hybrid", top_k=10, @@ -934,11 +937,9 @@ class TestTurbopufferMessagesIntegration: ), ] - # Create messages without embedding_config created = await server.message_manager.create_many_messages_async( pydantic_msgs=messages, actor=default_user, - embedding_config=None, # No config provided ) assert len(created) == 2 @@ -1057,7 +1058,7 @@ class TestTurbopufferMessagesIntegration: @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") - async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding): + async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding, default_user): """Test actual message embedding and storage in Turbopuffer""" import uuid from datetime import datetime, timezone @@ -1087,9 +1088,9 @@ class TestTurbopufferMessagesIntegration: success = await client.insert_messages( agent_id=agent_id, message_texts=message_texts, - embeddings=embeddings, message_ids=message_ids, organization_id=org_id, + actor=default_user, roles=roles, created_ats=created_ats, ) @@ -1097,11 +1098,8 @@ class TestTurbopufferMessagesIntegration: assert success == True # Verify we can query the messages - results = await client.query_messages( - agent_id=agent_id, - organization_id=org_id, - search_mode="timestamp", - top_k=10, + results = await client.query_messages_by_agent_id( + agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, actor=default_user ) assert len(results) == 3 @@ -1121,7 +1119,7 @@ class TestTurbopufferMessagesIntegration: @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") - async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding): + async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding, default_user): """Test vector search on messages in Turbopuffer""" import uuid from datetime import datetime, timezone @@ -1145,29 +1143,23 @@ class TestTurbopufferMessagesIntegration: created_ats = [datetime.now(timezone.utc) for _ in message_texts] # Create embeddings that reflect content similarity - embeddings = [ - [1.0, 0.0, 0.0], # Python programming - [0.0, 1.0, 0.0], # JavaScript web - [0.8, 0.0, 0.2], # ML with Python (similar to first) - ] - # Insert messages await client.insert_messages( agent_id=agent_id, message_texts=message_texts, - embeddings=embeddings, message_ids=message_ids, organization_id=org_id, + actor=default_user, roles=roles, created_ats=created_ats, ) # Search for Python-related messages using vector search - query_embedding = [0.9, 0.0, 0.1] # Similar to Python messages - results = await client.query_messages( + results = await client.query_messages_by_agent_id( agent_id=agent_id, organization_id=org_id, - query_embedding=query_embedding, + actor=default_user, + query_text="Python programming", search_mode="vector", top_k=2, ) @@ -1187,7 +1179,7 @@ class TestTurbopufferMessagesIntegration: @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") - async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding): + async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding, default_user): """Test hybrid search combining vector and FTS for messages""" import uuid from datetime import datetime, timezone @@ -1211,30 +1203,22 @@ class TestTurbopufferMessagesIntegration: roles = [MessageRole.assistant] * len(message_texts) created_ats = [datetime.now(timezone.utc) for _ in message_texts] - # Embeddings - embeddings = [ - [0.1, 0.9, 0.0], # fox text - [0.9, 0.1, 0.0], # ML algorithms - [0.5, 0.5, 0.0], # Quick Python - [0.8, 0.2, 0.0], # Deep learning - ] - # Insert messages await client.insert_messages( agent_id=agent_id, message_texts=message_texts, - embeddings=embeddings, message_ids=message_ids, organization_id=org_id, + actor=default_user, roles=roles, created_ats=created_ats, ) - # Hybrid search - vector similar to ML but text contains "quick" - results = await client.query_messages( + # Hybrid search - text search for "quick" + results = await client.query_messages_by_agent_id( agent_id=agent_id, organization_id=org_id, - query_embedding=[0.7, 0.3, 0.0], # Similar to ML messages + actor=default_user, query_text="quick", # Text search for "quick" search_mode="hybrid", top_k=3, @@ -1257,7 +1241,7 @@ class TestTurbopufferMessagesIntegration: @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") - async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding): + async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding, default_user): """Test filtering messages by role""" import uuid from datetime import datetime, timezone @@ -1283,26 +1267,21 @@ class TestTurbopufferMessagesIntegration: roles = [role for _, role in message_data] message_ids = [str(uuid.uuid4()) for _ in message_texts] created_ats = [datetime.now(timezone.utc) for _ in message_texts] - embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))] # Insert messages await client.insert_messages( agent_id=agent_id, message_texts=message_texts, - embeddings=embeddings, message_ids=message_ids, organization_id=org_id, + actor=default_user, roles=roles, created_ats=created_ats, ) # Query only user messages - user_results = await client.query_messages( - agent_id=agent_id, - organization_id=org_id, - search_mode="timestamp", - top_k=10, - roles=[MessageRole.user], + user_results = await client.query_messages_by_agent_id( + agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, roles=[MessageRole.user], actor=default_user ) assert len(user_results) == 2 @@ -1311,12 +1290,13 @@ class TestTurbopufferMessagesIntegration: assert msg["text"] in ["I need help with Python", "Can you explain this?"] # Query assistant and system messages - non_user_results = await client.query_messages( + non_user_results = await client.query_messages_by_agent_id( agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, roles=[MessageRole.assistant, MessageRole.system], + actor=default_user, ) assert len(non_user_results) == 3 @@ -1395,7 +1375,6 @@ class TestTurbopufferMessagesIntegration: ) ], actor=default_user, - embedding_config=embedding_config, strict_mode=True, ) @@ -1409,7 +1388,6 @@ class TestTurbopufferMessagesIntegration: query_text="Python", search_mode="fts", limit=10, - embedding_config=embedding_config, ) assert len(python_results) > 0 assert any(msg.id == message_id for msg, metadata in python_results) @@ -1419,7 +1397,6 @@ class TestTurbopufferMessagesIntegration: message_id=message_id, message_update=MessageUpdate(content="Updated content about JavaScript development"), actor=default_user, - embedding_config=embedding_config, strict_mode=True, ) @@ -1432,7 +1409,6 @@ class TestTurbopufferMessagesIntegration: query_text="Python", search_mode="fts", limit=10, - embedding_config=embedding_config, ) # Should either find no results or results that don't include our message assert not any(msg.id == message_id for msg, metadata in python_results_after) @@ -1444,7 +1420,6 @@ class TestTurbopufferMessagesIntegration: query_text="JavaScript", search_mode="fts", limit=10, - embedding_config=embedding_config, ) assert len(js_results) > 0 assert any(msg.id == message_id for msg, metadata in js_results) @@ -1497,7 +1472,6 @@ class TestTurbopufferMessagesIntegration: ) ], actor=default_user, - embedding_config=embedding_config, strict_mode=True, ) agent_a_messages.extend(msgs) @@ -1514,7 +1488,6 @@ class TestTurbopufferMessagesIntegration: ) ], actor=default_user, - embedding_config=embedding_config, strict_mode=True, ) agent_b_messages.extend(msgs) @@ -1526,7 +1499,6 @@ class TestTurbopufferMessagesIntegration: query_text="Agent A", search_mode="fts", limit=10, - embedding_config=embedding_config, ) assert len(agent_a_search) == 5 @@ -1536,7 +1508,6 @@ class TestTurbopufferMessagesIntegration: query_text="Agent B", search_mode="fts", limit=10, - embedding_config=embedding_config, ) assert len(agent_b_search) == 3 @@ -1559,7 +1530,6 @@ class TestTurbopufferMessagesIntegration: query_text="Agent A", search_mode="fts", limit=10, - embedding_config=embedding_config, ) assert len(agent_a_final) == 2 # Verify the remaining messages are the correct ones @@ -1574,7 +1544,6 @@ class TestTurbopufferMessagesIntegration: query_text="Agent B", search_mode="fts", limit=10, - embedding_config=embedding_config, ) assert len(agent_b_final) == 0 @@ -1583,84 +1552,6 @@ class TestTurbopufferMessagesIntegration: await server.agent_manager.delete_agent_async(agent_a.id, default_user) await server.agent_manager.delete_agent_async(agent_b.id, default_user) - @pytest.mark.asyncio - @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") - async def test_crud_operations_without_embedding_config(self, server, default_user, sarah_agent, enable_message_embedding): - """Test that CRUD operations handle missing embedding_config gracefully""" - from letta.schemas.message import MessageUpdate - - embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai") - - # Create message WITH embedding_config - messages = await server.message_manager.create_many_messages_async( - pydantic_msgs=[ - PydanticMessage( - role=MessageRole.user, - content=[TextContent(text="Message with searchable content about databases")], - agent_id=sarah_agent.id, - ) - ], - actor=default_user, - embedding_config=embedding_config, - strict_mode=True, - ) - - assert len(messages) == 1 - message_id = messages[0].id - - # Verify message is searchable initially - initial_search = await server.message_manager.search_messages_async( - agent_id=sarah_agent.id, - actor=default_user, - query_text="databases", - search_mode="fts", - limit=10, - embedding_config=embedding_config, - ) - assert len(initial_search) > 0 - assert any(msg.id == message_id for msg, metadata in initial_search) - - # Update message WITHOUT embedding_config - should update postgres but not turbopuffer - updated_message = await server.message_manager.update_message_by_id_async( - message_id=message_id, - message_update=MessageUpdate(content="Updated content about algorithms"), - actor=default_user, - embedding_config=None, # No config provided - ) - - # Verify postgres was updated - assert updated_message.id == message_id - updated_text = server.message_manager._extract_message_text(updated_message) - assert "algorithms" in updated_text - assert "databases" not in updated_text - - # Original search term should STILL find the message (turbopuffer wasn't updated) - still_searchable = await server.message_manager.search_messages_async( - agent_id=sarah_agent.id, - actor=default_user, - query_text="databases", - search_mode="fts", - limit=10, - embedding_config=embedding_config, - ) - assert len(still_searchable) > 0 - assert any(msg.id == message_id for msg, metadata in still_searchable) - - # New content should NOT be searchable (wasn't re-indexed) - not_searchable = await server.message_manager.search_messages_async( - agent_id=sarah_agent.id, - actor=default_user, - query_text="algorithms", - search_mode="fts", - limit=10, - embedding_config=embedding_config, - ) - # Should either find no results or results that don't include our message - assert not any(msg.id == message_id for msg, metadata in not_searchable) - - # Clean up - await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True) - @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") async def test_turbopuffer_failure_does_not_break_postgres(self, server, default_user, sarah_agent, enable_message_embedding): @@ -1681,7 +1572,6 @@ class TestTurbopufferMessagesIntegration: ) ], actor=default_user, - embedding_config=embedding_config, ) assert len(messages) == 1 @@ -1702,7 +1592,6 @@ class TestTurbopufferMessagesIntegration: message_id=message_id, message_update=MessageUpdate(content="Updated despite turbopuffer failure"), actor=default_user, - embedding_config=embedding_config, strict_mode=False, # Don't fail on turbopuffer errors - that's what we're testing! ) @@ -1722,7 +1611,6 @@ class TestTurbopufferMessagesIntegration: ) ], actor=default_user, - embedding_config=None, # Create without embedding to avoid mock issues ) message_to_delete_id = messages2[0].id @@ -1741,7 +1629,7 @@ class TestTurbopufferMessagesIntegration: await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=False) async def wait_for_embedding( - self, agent_id: str, message_id: str, organization_id: str, max_wait: float = 10.0, poll_interval: float = 0.5 + self, agent_id: str, message_id: str, organization_id: str, actor, max_wait: float = 10.0, poll_interval: float = 0.5 ) -> bool: """Poll Turbopuffer directly to check if a message has been embedded. @@ -1765,9 +1653,10 @@ class TestTurbopufferMessagesIntegration: while asyncio.get_event_loop().time() - start_time < max_wait: try: # Query Turbopuffer directly using timestamp mode to get all messages - results = await client.query_messages( + results = await client.query_messages_by_agent_id( agent_id=agent_id, organization_id=organization_id, + actor=actor, search_mode="timestamp", top_k=100, # Get more messages to ensure we find it ) @@ -1800,7 +1689,6 @@ class TestTurbopufferMessagesIntegration: ) ], actor=default_user, - embedding_config=embedding_config, strict_mode=False, # Background mode ) @@ -1814,7 +1702,12 @@ class TestTurbopufferMessagesIntegration: # Poll for embedding completion by querying Turbopuffer directly embedded = await self.wait_for_embedding( - agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5 + agent_id=sarah_agent.id, + message_id=message_id, + organization_id=default_user.organization_id, + actor=default_user, + max_wait=10.0, + poll_interval=0.5, ) assert embedded, "Message was not embedded in Turbopuffer within timeout" @@ -1825,7 +1718,6 @@ class TestTurbopufferMessagesIntegration: query_text="Python programming", search_mode="fts", limit=10, - embedding_config=embedding_config, ) assert len(search_results) > 0 assert any(msg.id == message_id for msg, _ in search_results) @@ -1851,7 +1743,6 @@ class TestTurbopufferMessagesIntegration: ) ], actor=default_user, - embedding_config=embedding_config, strict_mode=True, # Ensure initial embedding ) @@ -1865,7 +1756,6 @@ class TestTurbopufferMessagesIntegration: query_text="databases", search_mode="fts", limit=10, - embedding_config=embedding_config, ) assert any(msg.id == message_id for msg, _ in initial_results) @@ -1874,7 +1764,6 @@ class TestTurbopufferMessagesIntegration: message_id=message_id, message_update=MessageUpdate(content="Updated content about machine learning"), actor=default_user, - embedding_config=embedding_config, strict_mode=False, # Background mode ) @@ -1890,7 +1779,12 @@ class TestTurbopufferMessagesIntegration: # Poll for the update to be reflected in Turbopuffer # We check by searching for the new content embedded = await self.wait_for_embedding( - agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5 + agent_id=sarah_agent.id, + message_id=message_id, + organization_id=default_user.organization_id, + actor=default_user, + max_wait=10.0, + poll_interval=0.5, ) assert embedded, "Updated message was not re-embedded within timeout" @@ -1901,7 +1795,6 @@ class TestTurbopufferMessagesIntegration: query_text="machine learning", search_mode="fts", limit=10, - embedding_config=embedding_config, ) assert any(msg.id == message_id for msg, _ in new_results) @@ -1914,7 +1807,6 @@ class TestTurbopufferMessagesIntegration: query_text="databases", search_mode="fts", limit=10, - embedding_config=embedding_config, ) # The message shouldn't match the old search term anymore if len(old_results) > 0: @@ -1929,7 +1821,7 @@ class TestTurbopufferMessagesIntegration: @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") - async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding): + async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding, default_user): """Test filtering messages by date range""" import uuid from datetime import datetime, timedelta, timezone @@ -1959,21 +1851,17 @@ class TestTurbopufferMessagesIntegration: await client.insert_messages( agent_id=agent_id, message_texts=[text], - embeddings=[[1.0, 2.0, 3.0]], message_ids=[str(uuid.uuid4())], organization_id=org_id, + actor=default_user, roles=[MessageRole.assistant], created_ats=[timestamp], ) # Query messages from the last 3 days three_days_ago = now - timedelta(days=3) - recent_results = await client.query_messages( - agent_id=agent_id, - organization_id=org_id, - search_mode="timestamp", - top_k=10, - start_date=three_days_ago, + recent_results = await client.query_messages_by_agent_id( + agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, start_date=three_days_ago, actor=default_user ) # Should get today's and yesterday's messages @@ -1984,13 +1872,14 @@ class TestTurbopufferMessagesIntegration: # Query messages between 2 weeks ago and 1 week ago two_weeks_ago = now - timedelta(days=14) - week_results = await client.query_messages( + week_results = await client.query_messages_by_agent_id( agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, start_date=two_weeks_ago, end_date=last_week + timedelta(days=1), # Include last week's message + actor=default_user, ) # Should get only last week's message @@ -1998,10 +1887,11 @@ class TestTurbopufferMessagesIntegration: assert week_results[0][0]["text"] == "Last week's message" # Query with vector search and date filtering - filtered_vector_results = await client.query_messages( + filtered_vector_results = await client.query_messages_by_agent_id( agent_id=agent_id, organization_id=org_id, - query_embedding=[1.0, 2.0, 3.0], + actor=default_user, + query_text="message", search_mode="vector", top_k=10, start_date=three_days_ago, @@ -2101,7 +1991,7 @@ class TestNamespaceTracking: @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") async def test_message_project_id_filtering(self, server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding): - """Test that project_id filtering works correctly in query_messages""" + """Test that project_id filtering works correctly in query_messages_by_agent_id""" from letta.schemas.letta_message_content import TextContent # Create two project IDs @@ -2124,24 +2014,15 @@ class TestNamespaceTracking: # Insert messages with their respective project IDs tpuf_client = TurbopufferClient() - # Generate embeddings - from letta.llm_api.llm_client import LLMClient - - embedding_client = LLMClient.create( - provider_type=sarah_agent.embedding_config.embedding_endpoint_type, - actor=default_user, - ) - embeddings = await embedding_client.request_embeddings( - [message_a.content[0].text, message_b.content[0].text], sarah_agent.embedding_config - ) + # Embeddings will be generated automatically by the client # Insert message A with project_a_id await tpuf_client.insert_messages( agent_id=sarah_agent.id, message_texts=[message_a.content[0].text], - embeddings=[embeddings[0]], message_ids=[message_a.id], organization_id=default_user.organization_id, + actor=default_user, roles=[message_a.role], created_ats=[message_a.created_at], project_id=project_a_id, @@ -2151,9 +2032,9 @@ class TestNamespaceTracking: await tpuf_client.insert_messages( agent_id=sarah_agent.id, message_texts=[message_b.content[0].text], - embeddings=[embeddings[1]], message_ids=[message_b.id], organization_id=default_user.organization_id, + actor=default_user, roles=[message_b.role], created_ats=[message_b.created_at], project_id=project_b_id, @@ -2162,12 +2043,13 @@ class TestNamespaceTracking: # Poll for message A with project_a_id filter max_retries = 10 for i in range(max_retries): - results_a = await tpuf_client.query_messages( + results_a = await tpuf_client.query_messages_by_agent_id( agent_id=sarah_agent.id, organization_id=default_user.organization_id, search_mode="timestamp", # Simple timestamp retrieval top_k=10, project_id=project_a_id, + actor=default_user, ) if len(results_a) == 1 and results_a[0][0]["id"] == message_a.id: break @@ -2179,12 +2061,13 @@ class TestNamespaceTracking: # Poll for message B with project_b_id filter for i in range(max_retries): - results_b = await tpuf_client.query_messages( + results_b = await tpuf_client.query_messages_by_agent_id( agent_id=sarah_agent.id, organization_id=default_user.organization_id, search_mode="timestamp", top_k=10, project_id=project_b_id, + actor=default_user, ) if len(results_b) == 1 and results_b[0][0]["id"] == message_b.id: break @@ -2195,12 +2078,13 @@ class TestNamespaceTracking: assert "JavaScript" in results_b[0][0]["text"] # Query without project filter - should find both - results_all = await tpuf_client.query_messages( + results_all = await tpuf_client.query_messages_by_agent_id( agent_id=sarah_agent.id, organization_id=default_user.organization_id, search_mode="timestamp", top_k=10, project_id=None, # No filter + actor=default_user, ) assert len(results_all) >= 2 # May have other messages from setup