From 516f2963e0f78dad3744dd03353ca8dc9bca848b Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 8 Sep 2025 18:46:41 -0700 Subject: [PATCH] feat: Add turbopuffer embedder by default [LET-4253] (#4476) * Adapt to turbopuffer embedder * Make turbopuffer search more efficient over all source ids * Combine turbopuffer and pinecone hybrid * Fix test sources --- letta/helpers/tpuf_client.py | 310 ++++++++++- letta/server/rest_api/routers/v1/folders.py | 28 +- letta/server/rest_api/routers/v1/sources.py | 28 +- letta/services/agent_serialization_manager.py | 12 +- .../file_processor/embedder/base_embedder.py | 5 + .../embedder/openai_embedder.py | 4 + .../embedder/pinecone_embedder.py | 6 +- .../embedder/turbopuffer_embedder.py | 71 +++ .../services/file_processor/file_processor.py | 16 +- letta/services/source_manager.py | 17 +- .../tool_executor/files_tool_executor.py | 141 ++++- tests/test_file_processor.py | 2 +- tests/test_sources.py | 527 ++++++++++++------ 13 files changed, 968 insertions(+), 199 deletions(-) create mode 100644 letta/services/file_processor/embedder/turbopuffer_embedder.py diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index afdebec5..e7e2c8b0 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -16,12 +16,12 @@ logger = logging.getLogger(__name__) def should_use_tpuf() -> bool: # We need OpenAI since we default to their embedding model - return bool(settings.use_tpuf) and bool(settings.tpuf_api_key) + return bool(settings.use_tpuf) and bool(settings.tpuf_api_key) and bool(model_settings.openai_api_key) def should_use_tpuf_for_messages() -> bool: """Check if Turbopuffer should be used for messages.""" - return should_use_tpuf() and bool(settings.embed_all_messages) and bool(model_settings.openai_api_key) + return should_use_tpuf() and bool(settings.embed_all_messages) class TurbopufferClient: @@ -1113,3 +1113,309 @@ class TurbopufferClient: except Exception as e: logger.error(f"Failed to delete all messages from Turbopuffer: {e}") raise + + # file/source passage methods + + @trace_method + async def _get_file_passages_namespace_name(self, organization_id: str) -> str: + """Get namespace name for file passages (org-scoped). + + Args: + organization_id: Organization ID for namespace generation + + Returns: + The org-scoped namespace name for file passages + """ + environment = settings.environment + if environment: + namespace_name = f"file_passages_{organization_id}_{environment.lower()}" + else: + namespace_name = f"file_passages_{organization_id}" + + return namespace_name + + @trace_method + async def insert_file_passages( + self, + source_id: str, + file_id: str, + text_chunks: List[str], + organization_id: str, + actor: "PydanticUser", + created_at: Optional[datetime] = None, + ) -> List[PydanticPassage]: + """Insert file passages into Turbopuffer using org-scoped namespace. + + Args: + source_id: ID of the source containing the file + file_id: ID of the file + text_chunks: List of text chunks to store + organization_id: Organization ID for the passages + actor: User actor for embedding generation + created_at: Optional timestamp for retroactive entries (defaults to current UTC time) + + Returns: + List of PydanticPassage objects that were inserted + """ + from turbopuffer import AsyncTurbopuffer + + if not text_chunks: + return [] + + # generate embeddings using the default config + embeddings = await self._generate_embeddings(text_chunks, actor) + + namespace_name = await self._get_file_passages_namespace_name(organization_id) + + # handle timestamp - ensure UTC + if created_at is None: + timestamp = datetime.now(timezone.utc) + else: + # ensure the provided timestamp is timezone-aware and in UTC + if created_at.tzinfo is None: + # assume UTC if no timezone provided + timestamp = created_at.replace(tzinfo=timezone.utc) + else: + # convert to UTC if in different timezone + timestamp = created_at.astimezone(timezone.utc) + + # prepare column-based data for turbopuffer - optimized for batch insert + ids = [] + vectors = [] + texts = [] + organization_ids = [] + source_ids = [] + file_ids = [] + created_ats = [] + passages = [] + + for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)): + passage = PydanticPassage( + text=text, + file_id=file_id, + source_id=source_id, + embedding=embedding, + embedding_config=self.default_embedding_config, + organization_id=actor.organization_id, + ) + passages.append(passage) + + # append to columns + ids.append(passage.id) + vectors.append(embedding) + texts.append(text) + organization_ids.append(organization_id) + source_ids.append(source_id) + file_ids.append(file_id) + created_ats.append(timestamp) + + # build column-based upsert data + upsert_columns = { + "id": ids, + "vector": vectors, + "text": texts, + "organization_id": organization_ids, + "source_id": source_ids, + "file_id": file_ids, + "created_at": created_ats, + } + + try: + # use AsyncTurbopuffer as a context manager for proper resource cleanup + async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: + namespace = client.namespace(namespace_name) + # turbopuffer recommends column-based writes for performance + await namespace.write( + upsert_columns=upsert_columns, + distance_metric="cosine_distance", + schema={"text": {"type": "string", "full_text_search": True}}, + ) + logger.info(f"Successfully inserted {len(ids)} file passages to Turbopuffer for source {source_id}, file {file_id}") + return passages + + except Exception as e: + logger.error(f"Failed to insert file passages to Turbopuffer: {e}") + # check if it's a duplicate ID error + if "duplicate" in str(e).lower(): + logger.error("Duplicate passage IDs detected in batch") + raise + + @trace_method + async def query_file_passages( + self, + source_ids: List[str], + organization_id: str, + actor: "PydanticUser", + query_text: Optional[str] = None, + search_mode: str = "vector", # "vector", "fts", "hybrid" + top_k: int = 10, + file_id: Optional[str] = None, # optional filter by specific file + vector_weight: float = 0.5, + fts_weight: float = 0.5, + ) -> List[Tuple[PydanticPassage, float, dict]]: + """Query file passages from Turbopuffer using org-scoped namespace. + + Args: + source_ids: List of source IDs to query + organization_id: Organization ID for namespace lookup + actor: User actor for embedding generation + query_text: Text query for search + search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector") + top_k: Number of results to return + file_id: Optional file ID to filter results to a specific file + vector_weight: Weight for vector search results in hybrid mode (default: 0.5) + fts_weight: Weight for FTS results in hybrid mode (default: 0.5) + + Returns: + List of (passage, score, metadata) tuples with relevance rankings + """ + # generate embedding for vector/hybrid search if query_text is provided + query_embedding = None + if query_text and search_mode in ["vector", "hybrid"]: + embeddings = await self._generate_embeddings([query_text], actor) + query_embedding = embeddings[0] + + # check if we should fallback to timestamp-based retrieval + if query_embedding is None and query_text is None and search_mode not in ["timestamp"]: + # fallback to retrieving most recent passages when no search query is provided + search_mode = "timestamp" + + namespace_name = await self._get_file_passages_namespace_name(organization_id) + + # build filters - always filter by source_ids + if len(source_ids) == 1: + # single source_id, use Eq for efficiency + filters = [("source_id", "Eq", source_ids[0])] + else: + # multiple source_ids, use In operator + filters = [("source_id", "In", source_ids)] + + # add file filter if specified + if file_id: + filters.append(("file_id", "Eq", file_id)) + + # combine filters + final_filter = filters[0] if len(filters) == 1 else ("And", filters) + + try: + # use generic query executor + result = await self._execute_query( + namespace_name=namespace_name, + search_mode=search_mode, + query_embedding=query_embedding, + query_text=query_text, + top_k=top_k, + include_attributes=["text", "organization_id", "source_id", "file_id", "created_at"], + filters=final_filter, + vector_weight=vector_weight, + fts_weight=fts_weight, + ) + + # process results based on search mode + if search_mode == "hybrid": + # for hybrid mode, we get a multi-query response + vector_results = self._process_file_query_results(result.results[0]) + fts_results = self._process_file_query_results(result.results[1], is_fts=True) + # use RRF and include metadata with ranks + results_with_metadata = self._reciprocal_rank_fusion( + vector_results=[passage for passage, _ in vector_results], + fts_results=[passage for passage, _ in fts_results], + get_id_func=lambda p: p.id, + vector_weight=vector_weight, + fts_weight=fts_weight, + top_k=top_k, + ) + return results_with_metadata + else: + # for single queries (vector, fts, timestamp) - add basic metadata + is_fts = search_mode == "fts" + results = self._process_file_query_results(result, is_fts=is_fts) + # add simple metadata for single search modes + results_with_metadata = [] + for idx, (passage, score) in enumerate(results): + metadata = { + "combined_score": score, + f"{search_mode}_rank": idx + 1, # add the rank for this search mode + } + results_with_metadata.append((passage, score, metadata)) + return results_with_metadata + + except Exception as e: + logger.error(f"Failed to query file passages from Turbopuffer: {e}") + raise + + def _process_file_query_results(self, result, is_fts: bool = False) -> List[Tuple[PydanticPassage, float]]: + """Process results from a file query into passage objects with scores.""" + passages_with_scores = [] + + for row in result.rows: + # build metadata + metadata = {} + + # create a passage with minimal fields - embeddings are not returned from Turbopuffer + passage = PydanticPassage( + id=row.id, + text=getattr(row, "text", ""), + organization_id=getattr(row, "organization_id", None), + source_id=getattr(row, "source_id", None), # get source_id from the row + file_id=getattr(row, "file_id", None), + created_at=getattr(row, "created_at", None), + metadata_=metadata, + tags=[], + # set required fields to empty/default values since we don't store embeddings + embedding=[], # empty embedding since we don't return it from Turbopuffer + embedding_config=self.default_embedding_config, + ) + + # handle score based on search type + if is_fts: + # for FTS, use the BM25 score directly (higher is better) + score = getattr(row, "$score", 0.0) + else: + # for vector search, convert distance to similarity score + distance = getattr(row, "$dist", 0.0) + score = 1.0 - distance + + passages_with_scores.append((passage, score)) + + return passages_with_scores + + @trace_method + async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool: + """Delete all passages for a specific file from Turbopuffer.""" + from turbopuffer import AsyncTurbopuffer + + namespace_name = await self._get_file_passages_namespace_name(organization_id) + + try: + async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: + namespace = client.namespace(namespace_name) + # use delete_by_filter to only delete passages for this file + # need to filter by both source_id and file_id + filter_expr = ("And", [("source_id", "Eq", source_id), ("file_id", "Eq", file_id)]) + result = await namespace.write(delete_by_filter=filter_expr) + logger.info( + f"Successfully deleted passages for file {file_id} from source {source_id} (deleted {result.rows_affected} rows)" + ) + return True + except Exception as e: + logger.error(f"Failed to delete file passages from Turbopuffer: {e}") + raise + + @trace_method + async def delete_source_passages(self, source_id: str, organization_id: str) -> bool: + """Delete all passages for a source from Turbopuffer.""" + from turbopuffer import AsyncTurbopuffer + + namespace_name = await self._get_file_passages_namespace_name(organization_id) + + try: + async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: + namespace = client.namespace(namespace_name) + # delete all passages for this source + result = await namespace.write(delete_by_filter=("source_id", "Eq", source_id)) + logger.info(f"Successfully deleted all passages for source {source_id} (deleted {result.rows_affected} rows)") + return True + except Exception as e: + logger.error(f"Failed to delete source passages from Turbopuffer: {e}") + raise diff --git a/letta/server/rest_api/routers/v1/folders.py b/letta/server/rest_api/routers/v1/folders.py index dcf98474..84a59723 100644 --- a/letta/server/rest_api/routers/v1/folders.py +++ b/letta/server/rest_api/routers/v1/folders.py @@ -15,6 +15,7 @@ from letta.helpers.pinecone_utils import ( delete_source_records_from_pinecone_index, should_use_pinecone, ) +from letta.helpers.tpuf_client import should_use_tpuf from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState @@ -191,7 +192,13 @@ async def delete_folder( files = await server.file_manager.list_files(folder_id, actor) file_ids = [f.id for f in files] - if should_use_pinecone(): + if should_use_tpuf(): + logger.info(f"Deleting folder {folder_id} from Turbopuffer") + from letta.helpers.tpuf_client import TurbopufferClient + + tpuf_client = TurbopufferClient() + await tpuf_client.delete_source_passages(source_id=folder_id, organization_id=actor.organization_id) + elif should_use_pinecone(): logger.info(f"Deleting folder {folder_id} from pinecone index") await delete_source_records_from_pinecone_index(source_id=folder_id, actor=actor) @@ -450,7 +457,13 @@ async def delete_file_from_folder( await server.remove_file_from_context_windows(source_id=folder_id, file_id=deleted_file.id, actor=actor) - if should_use_pinecone(): + if should_use_tpuf(): + logger.info(f"Deleting file {file_id} from Turbopuffer") + from letta.helpers.tpuf_client import TurbopufferClient + + tpuf_client = TurbopufferClient() + await tpuf_client.delete_file_passages(source_id=folder_id, file_id=file_id, organization_id=actor.organization_id) + elif should_use_pinecone(): logger.info(f"Deleting file {file_id} from pinecone index") await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor) @@ -496,10 +509,15 @@ async def load_file_to_source_cloud( else: file_parser = MarkitdownFileParser() - using_pinecone = should_use_pinecone() - if using_pinecone: + # determine which embedder to use - turbopuffer takes precedence + if should_use_tpuf(): + from letta.services.file_processor.embedder.turbopuffer_embedder import TurbopufferEmbedder + + embedder = TurbopufferEmbedder(embedding_config=embedding_config) + elif should_use_pinecone(): embedder = PineconeEmbedder(embedding_config=embedding_config) else: embedder = OpenAIEmbedder(embedding_config=embedding_config) - file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor, using_pinecone=using_pinecone) + + file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor) await file_processor.process(agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index c9d55407..a5fee7b8 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -15,6 +15,7 @@ from letta.helpers.pinecone_utils import ( delete_source_records_from_pinecone_index, should_use_pinecone, ) +from letta.helpers.tpuf_client import should_use_tpuf from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState @@ -189,7 +190,13 @@ async def delete_source( files = await server.file_manager.list_files(source_id, actor) file_ids = [f.id for f in files] - if should_use_pinecone(): + if should_use_tpuf(): + logger.info(f"Deleting source {source_id} from Turbopuffer") + from letta.helpers.tpuf_client import TurbopufferClient + + tpuf_client = TurbopufferClient() + await tpuf_client.delete_source_passages(source_id=source_id, organization_id=actor.organization_id) + elif should_use_pinecone(): logger.info(f"Deleting source {source_id} from pinecone index") await delete_source_records_from_pinecone_index(source_id=source_id, actor=actor) @@ -435,7 +442,13 @@ async def delete_file_from_source( await server.remove_file_from_context_windows(source_id=source_id, file_id=deleted_file.id, actor=actor) - if should_use_pinecone(): + if should_use_tpuf(): + logger.info(f"Deleting file {file_id} from Turbopuffer") + from letta.helpers.tpuf_client import TurbopufferClient + + tpuf_client = TurbopufferClient() + await tpuf_client.delete_file_passages(source_id=source_id, file_id=file_id, organization_id=actor.organization_id) + elif should_use_pinecone(): logger.info(f"Deleting file {file_id} from pinecone index") await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor) @@ -481,10 +494,15 @@ async def load_file_to_source_cloud( else: file_parser = MarkitdownFileParser() - using_pinecone = should_use_pinecone() - if using_pinecone: + # determine which embedder to use - turbopuffer takes precedence + if should_use_tpuf(): + from letta.services.file_processor.embedder.turbopuffer_embedder import TurbopufferEmbedder + + embedder = TurbopufferEmbedder(embedding_config=embedding_config) + elif should_use_pinecone(): embedder = PineconeEmbedder(embedding_config=embedding_config) else: embedder = OpenAIEmbedder(embedding_config=embedding_config) - file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor, using_pinecone=using_pinecone) + + file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor) await file_processor.process(agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata) diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index bc0e07c3..0cbabe4c 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -12,6 +12,7 @@ from letta.errors import ( AgentNotFoundForExportError, ) from letta.helpers.pinecone_utils import should_use_pinecone +from letta.helpers.tpuf_client import should_use_tpuf from letta.log import get_logger from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.agent_file import ( @@ -29,7 +30,7 @@ from letta.schemas.agent_file import ( ) from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import FileProcessingStatus +from letta.schemas.enums import FileProcessingStatus, VectorDBProvider from letta.schemas.file import FileMetadata from letta.schemas.group import Group, GroupCreate from letta.schemas.mcp import MCPServer @@ -90,7 +91,6 @@ class AgentSerializationManager: self.file_agent_manager = file_agent_manager self.message_manager = message_manager self.file_parser = MistralFileParser() if settings.mistral_api_key else MarkitdownFileParser() - self.using_pinecone = should_use_pinecone() # ID mapping state for export self._db_to_file_ids: Dict[str, str] = {} @@ -588,7 +588,12 @@ class AgentSerializationManager: if schema.files and any(f.content for f in schema.files): # Use override embedding config if provided, otherwise use agent's config embedder_config = override_embedding_config if override_embedding_config else schema.agents[0].embedding_config - if should_use_pinecone(): + # determine which embedder to use - turbopuffer takes precedence + if should_use_tpuf(): + from letta.services.file_processor.embedder.turbopuffer_embedder import TurbopufferEmbedder + + embedder = TurbopufferEmbedder(embedding_config=embedder_config) + elif should_use_pinecone(): embedder = PineconeEmbedder(embedding_config=embedder_config) else: embedder = OpenAIEmbedder(embedding_config=embedder_config) @@ -596,7 +601,6 @@ class AgentSerializationManager: file_parser=self.file_parser, embedder=embedder, actor=actor, - using_pinecone=self.using_pinecone, ) for file_schema in schema.files: diff --git a/letta/services/file_processor/embedder/base_embedder.py b/letta/services/file_processor/embedder/base_embedder.py index b9310c3e..b2a6408b 100644 --- a/letta/services/file_processor/embedder/base_embedder.py +++ b/letta/services/file_processor/embedder/base_embedder.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from typing import List from letta.log import get_logger +from letta.schemas.enums import VectorDBProvider from letta.schemas.passage import Passage from letta.schemas.user import User @@ -11,6 +12,10 @@ logger = get_logger(__name__) class BaseEmbedder(ABC): """Abstract base class for embedding generation""" + def __init__(self): + # Default to NATIVE, subclasses will override this + self.vector_db_type = VectorDBProvider.NATIVE + @abstractmethod async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]: """Generate embeddings for chunks with batching and concurrent processing""" diff --git a/letta/services/file_processor/embedder/openai_embedder.py b/letta/services/file_processor/embedder/openai_embedder.py index b55ba936..77adbd85 100644 --- a/letta/services/file_processor/embedder/openai_embedder.py +++ b/letta/services/file_processor/embedder/openai_embedder.py @@ -19,6 +19,10 @@ class OpenAIEmbedder(BaseEmbedder): """OpenAI-based embedding generation""" def __init__(self, embedding_config: Optional[EmbeddingConfig] = None): + super().__init__() + # OpenAI embedder uses the native vector db (PostgreSQL) + # self.vector_db_type already set to VectorDBProvider.NATIVE by parent + self.default_embedding_config = ( EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai") if model_settings.openai_api_key diff --git a/letta/services/file_processor/embedder/pinecone_embedder.py b/letta/services/file_processor/embedder/pinecone_embedder.py index c218807e..f11aafed 100644 --- a/letta/services/file_processor/embedder/pinecone_embedder.py +++ b/letta/services/file_processor/embedder/pinecone_embedder.py @@ -4,6 +4,7 @@ from letta.helpers.pinecone_utils import upsert_file_records_to_pinecone_index from letta.log import get_logger from letta.otel.tracing import log_event, trace_method from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import VectorDBProvider from letta.schemas.passage import Passage from letta.schemas.user import User from letta.services.file_processor.embedder.base_embedder import BaseEmbedder @@ -20,6 +21,10 @@ class PineconeEmbedder(BaseEmbedder): """Pinecone-based embedding generation""" def __init__(self, embedding_config: Optional[EmbeddingConfig] = None): + super().__init__() + # set the vector db type for pinecone + self.vector_db_type = VectorDBProvider.PINECONE + if not PINECONE_AVAILABLE: raise ImportError("Pinecone package is not installed. Install it with: pip install pinecone") @@ -28,7 +33,6 @@ class PineconeEmbedder(BaseEmbedder): embedding_config = EmbeddingConfig.default_config(provider="pinecone") self.embedding_config = embedding_config - super().__init__() @trace_method async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]: diff --git a/letta/services/file_processor/embedder/turbopuffer_embedder.py b/letta/services/file_processor/embedder/turbopuffer_embedder.py new file mode 100644 index 00000000..c17b28c3 --- /dev/null +++ b/letta/services/file_processor/embedder/turbopuffer_embedder.py @@ -0,0 +1,71 @@ +from typing import List, Optional + +from letta.helpers.tpuf_client import TurbopufferClient +from letta.log import get_logger +from letta.otel.tracing import log_event, trace_method +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import VectorDBProvider +from letta.schemas.passage import Passage +from letta.schemas.user import User +from letta.services.file_processor.embedder.base_embedder import BaseEmbedder + +logger = get_logger(__name__) + + +class TurbopufferEmbedder(BaseEmbedder): + """Turbopuffer-based embedding generation and storage""" + + def __init__(self, embedding_config: Optional[EmbeddingConfig] = None): + super().__init__() + # set the vector db type for turbopuffer + self.vector_db_type = VectorDBProvider.TPUF + # use the default embedding config from TurbopufferClient if not provided + self.embedding_config = embedding_config or TurbopufferClient.default_embedding_config + self.tpuf_client = TurbopufferClient() + + @trace_method + async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]: + """Generate embeddings and store in Turbopuffer, then return Passage objects""" + if not chunks: + return [] + + logger.info(f"Generating embeddings for {len(chunks)} chunks using Turbopuffer") + log_event( + "turbopuffer_embedder.generation_started", + { + "total_chunks": len(chunks), + "file_id": file_id, + "source_id": source_id, + "embedding_model": self.embedding_config.embedding_model, + }, + ) + + try: + # insert passages to Turbopuffer - it will handle embedding generation internally + passages = await self.tpuf_client.insert_file_passages( + source_id=source_id, + file_id=file_id, + text_chunks=chunks, + organization_id=actor.organization_id, + actor=actor, + ) + + logger.info(f"Successfully generated and stored {len(passages)} passages in Turbopuffer") + log_event( + "turbopuffer_embedder.generation_completed", + { + "passages_created": len(passages), + "total_chunks_processed": len(chunks), + "file_id": file_id, + "source_id": source_id, + }, + ) + return passages + + except Exception as e: + logger.error(f"Failed to generate embeddings with Turbopuffer: {str(e)}") + log_event( + "turbopuffer_embedder.generation_failed", + {"error": str(e), "error_type": type(e).__name__, "file_id": file_id, "source_id": source_id}, + ) + raise diff --git a/letta/services/file_processor/file_processor.py b/letta/services/file_processor/file_processor.py index 6e63a1ec..529ea70d 100644 --- a/letta/services/file_processor/file_processor.py +++ b/letta/services/file_processor/file_processor.py @@ -6,7 +6,7 @@ from letta.log import get_logger from letta.otel.context import get_ctx_attributes from letta.otel.tracing import log_event, trace_method from letta.schemas.agent import AgentState -from letta.schemas.enums import FileProcessingStatus +from letta.schemas.enums import FileProcessingStatus, VectorDBProvider from letta.schemas.file import FileMetadata from letta.schemas.passage import Passage from letta.schemas.user import User @@ -30,7 +30,6 @@ class FileProcessor: file_parser: FileParser, embedder: BaseEmbedder, actor: User, - using_pinecone: bool, max_file_size: int = 50 * 1024 * 1024, # 50MB default ): self.file_parser = file_parser @@ -42,7 +41,8 @@ class FileProcessor: self.job_manager = JobManager() self.agent_manager = AgentManager() self.actor = actor - self.using_pinecone = using_pinecone + # get vector db type from the embedder + self.vector_db_type = embedder.vector_db_type async def _chunk_and_embed_with_fallback(self, file_metadata: FileMetadata, ocr_response, source_id: str) -> List: """Chunk text and generate embeddings with fallback to default chunker if needed""" @@ -218,7 +218,7 @@ class FileProcessor: source_id=source_id, ) - if not self.using_pinecone: + if self.vector_db_type == VectorDBProvider.NATIVE: all_passages = await self.passage_manager.create_many_source_passages_async( passages=all_passages, file_metadata=file_metadata, @@ -241,7 +241,8 @@ class FileProcessor: ) # update job status - if not self.using_pinecone: + # pinecone completes slowly, so gets updated later + if self.vector_db_type != VectorDBProvider.PINECONE: await self.file_manager.update_file_status( file_id=file_metadata.id, actor=self.actor, @@ -317,14 +318,15 @@ class FileProcessor: ) # Create passages in database (unless using Pinecone) - if not self.using_pinecone: + if self.vector_db_type == VectorDBProvider.NATIVE: all_passages = await self.passage_manager.create_many_source_passages_async( passages=all_passages, file_metadata=file_metadata, actor=self.actor ) log_event("file_processor.import_passages_created", {"filename": filename, "total_passages": len(all_passages)}) # Update file status to completed (valid transition from EMBEDDING) - if not self.using_pinecone: + # pinecone completes slowly, so gets updated later + if self.vector_db_type != VectorDBProvider.PINECONE: await self.file_manager.update_file_status( file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED ) diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index efe9e650..8f10baeb 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -3,6 +3,7 @@ from typing import List, Optional, Union from sqlalchemy import and_, exists, select +from letta.helpers.pinecone_utils import should_use_pinecone from letta.helpers.tpuf_client import should_use_tpuf from letta.orm import Agent as AgentModel from letta.orm.errors import NoResultFound @@ -18,6 +19,18 @@ from letta.utils import enforce_types, printd class SourceManager: + def _get_vector_db_provider(self) -> VectorDBProvider: + """ + determine which vector db provider to use based on configuration. + turbopuffer takes precedence when available. + """ + if should_use_tpuf(): + return VectorDBProvider.TPUF + elif should_use_pinecone(): + return VectorDBProvider.PINECONE + else: + return VectorDBProvider.NATIVE + """Manager class to handle business logic related to Sources.""" @trace_method @@ -52,7 +65,7 @@ class SourceManager: if db_source: return db_source else: - vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE + vector_db_provider = self._get_vector_db_provider() async with db_registry.async_session() as session: # Provide default embedding config if not given @@ -96,7 +109,7 @@ class SourceManager: Returns: List of created/updated sources """ - vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE + vector_db_provider = self._get_vector_db_provider() for pydantic_source in pydantic_sources: pydantic_source.vector_db_provider = vector_db_provider diff --git a/letta/services/tool_executor/files_tool_executor.py b/letta/services/tool_executor/files_tool_executor.py index 43a3fd97..251d0320 100644 --- a/letta/services/tool_executor/files_tool_executor.py +++ b/letta/services/tool_executor/files_tool_executor.py @@ -5,10 +5,13 @@ from typing import Any, Dict, List, Optional from letta.constants import PINECONE_TEXT_FIELD_NAME from letta.functions.types import FileOpenRequest from letta.helpers.pinecone_utils import search_pinecone_index, should_use_pinecone +from letta.helpers.tpuf_client import should_use_tpuf from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState +from letta.schemas.enums import VectorDBProvider from letta.schemas.sandbox_config import SandboxConfig +from letta.schemas.source import Source from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User @@ -554,18 +557,140 @@ class LettaFileToolExecutor(ToolExecutor): self.logger.info(f"Semantic search started for agent {agent_state.id} with query '{query}' (limit: {limit})") - # Check if Pinecone is enabled and use it if available - if should_use_pinecone(): - return await self._search_files_pinecone(agent_state, query, limit) - else: - return await self._search_files_traditional(agent_state, query, limit) + # Check which vector DB to use - Turbopuffer takes precedence + attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor) + attached_tpuf_sources = [source for source in attached_sources if source.vector_db_provider == VectorDBProvider.TPUF] + attached_pinecone_sources = [source for source in attached_sources if source.vector_db_provider == VectorDBProvider.PINECONE] - async def _search_files_pinecone(self, agent_state: AgentState, query: str, limit: int) -> str: + if not attached_tpuf_sources and not attached_pinecone_sources: + return await self._search_files_native(agent_state, query, limit) + + results = [] + + # If both have items, we half the limit roughly + # TODO: This is very hacky bc it skips the re-ranking - but this is a temporary stopgap while we think about migrating data + + if attached_tpuf_sources and attached_pinecone_sources: + limit = max(limit // 2, 1) + + if should_use_tpuf() and attached_tpuf_sources: + tpuf_result = await self._search_files_turbopuffer(agent_state, attached_tpuf_sources, query, limit) + results.append(tpuf_result) + + if should_use_pinecone() and attached_pinecone_sources: + pinecone_result = await self._search_files_pinecone(agent_state, attached_pinecone_sources, query, limit) + results.append(pinecone_result) + + # combine results from both sources + if results: + return "\n\n".join(results) + + # fallback if no results from either source + return "No results found" + + async def _search_files_turbopuffer(self, agent_state: AgentState, attached_sources: List[Source], query: str, limit: int) -> str: + """Search files using Turbopuffer vector database.""" + + # Get attached sources + source_ids = [source.id for source in attached_sources] + if not source_ids: + return "No valid source IDs found for attached files" + + # Get all attached files for this agent + file_agents = await self.files_agents_manager.list_files_for_agent( + agent_id=agent_state.id, per_file_view_window_char_limit=agent_state.per_file_view_window_char_limit, actor=self.actor + ) + if not file_agents: + return "No files are currently attached to search" + + # Create a map of file_id to file_name for quick lookup + file_map = {fa.file_id: fa.file_name for fa in file_agents} + + results = [] + total_hits = 0 + files_with_matches = {} + + try: + from letta.helpers.tpuf_client import TurbopufferClient + + tpuf_client = TurbopufferClient() + + # Query Turbopuffer for all sources at once + search_results = await tpuf_client.query_file_passages( + source_ids=source_ids, # pass all source_ids as a list + organization_id=self.actor.organization_id, + actor=self.actor, + query_text=query, + search_mode="hybrid", # use hybrid search for best results + top_k=limit, + ) + + # Process search results + for passage, score, metadata in search_results: + if total_hits >= limit: + break + + total_hits += 1 + + # get file name from our map + file_name = file_map.get(passage.file_id, "Unknown File") + + # group by file name + if file_name not in files_with_matches: + files_with_matches[file_name] = [] + files_with_matches[file_name].append({"text": passage.text, "score": score, "passage_id": passage.id}) + + except Exception as e: + self.logger.error(f"Turbopuffer search failed: {str(e)}") + raise e + + if not files_with_matches: + return f"No semantic matches found in Turbopuffer for query: '{query}'" + + # Format results + passage_num = 0 + for file_name, matches in files_with_matches.items(): + for match in matches: + passage_num += 1 + + # format each passage with terminal-style header + score_display = f"(score: {match['score']:.3f})" + passage_header = f"\n=== {file_name} (passage #{passage_num}) {score_display} ===" + + # format the passage text + passage_text = match["text"].strip() + lines = passage_text.splitlines() + formatted_lines = [] + for line in lines[:20]: # limit to first 20 lines per passage + formatted_lines.append(f" {line}") + + if len(lines) > 20: + formatted_lines.append(f" ... [truncated {len(lines) - 20} more lines]") + + passage_content = "\n".join(formatted_lines) + results.append(f"{passage_header}\n{passage_content}") + + # mark access for files that had matches + if files_with_matches: + matched_file_names = [name for name in files_with_matches.keys() if name != "Unknown File"] + if matched_file_names: + await self.files_agents_manager.mark_access_bulk(agent_id=agent_state.id, file_names=matched_file_names, actor=self.actor) + + # create summary header + file_count = len(files_with_matches) + summary = f"Found {total_hits} Turbopuffer matches in {file_count} file{'s' if file_count != 1 else ''} for query: '{query}'" + + # combine all results + formatted_results = [summary, "=" * len(summary)] + results + + self.logger.info(f"Turbopuffer search completed: {total_hits} matches across {file_count} files") + return "\n".join(formatted_results) + + async def _search_files_pinecone(self, agent_state: AgentState, attached_sources: List[Source], query: str, limit: int) -> str: """Search files using Pinecone vector database.""" # Extract unique source_ids # TODO: Inefficient - attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor) source_ids = [source.id for source in attached_sources] if not source_ids: return "No valid source IDs found for attached files" @@ -658,7 +783,7 @@ class LettaFileToolExecutor(ToolExecutor): self.logger.info(f"Pinecone search completed: {total_hits} matches across {file_count} files") return "\n".join(formatted_results) - async def _search_files_traditional(self, agent_state: AgentState, query: str, limit: int) -> str: + async def _search_files_native(self, agent_state: AgentState, query: str, limit: int) -> str: """Traditional search using existing passage manager.""" # Get semantic search results passages = await self.agent_manager.query_source_passages_async( diff --git a/tests/test_file_processor.py b/tests/test_file_processor.py index 15ba5010..39dd790a 100644 --- a/tests/test_file_processor.py +++ b/tests/test_file_processor.py @@ -258,7 +258,7 @@ class TestFileProcessorWithPinecone: embedder = PineconeEmbedder() # Create file processor with Pinecone enabled - file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=mock_actor, using_pinecone=True) + file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=mock_actor) # Track file manager update calls update_calls = [] diff --git a/tests/test_sources.py b/tests/test_sources.py index 497405f6..f71422d8 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -13,6 +13,7 @@ from letta_client.types import AgentState from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS from letta.helpers.pinecone_utils import should_use_pinecone +from letta.helpers.tpuf_client import TurbopufferClient from letta.schemas.enums import FileProcessingStatus, ToolType from letta.schemas.message import MessageCreate from letta.schemas.user import User @@ -95,7 +96,7 @@ def agent_state(disable_pinecone, client: LettaSDKClient): # Tests -def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient): +def test_auto_attach_detach_files_tools(disable_pinecone, disable_turbopuffer, client: LettaSDKClient): """Test automatic attachment and detachment of file tools when managing agent sources.""" # Create agent with basic configuration agent = client.agents.create( @@ -168,6 +169,7 @@ def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient ) def test_file_upload_creates_source_blocks_correctly( disable_pinecone, + disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState, file_path: str, @@ -237,7 +239,9 @@ def test_file_upload_creates_source_blocks_correctly( settings.mistral_api_key = original_mistral_key -def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): +def test_attach_existing_files_creates_source_blocks_correctly( + disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState +): # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 @@ -302,7 +306,9 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, assert "" not in raw_system_message_after_detach -def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): +def test_delete_source_removes_source_blocks_correctly( + disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState +): # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 @@ -360,7 +366,7 @@ def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: assert not any("test" in b.value for b in blocks) -def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): +def test_agent_uses_open_close_file_correctly(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState): # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") @@ -463,7 +469,7 @@ def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDK print("✓ File successfully opened with different range - content differs as expected") -def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): +def test_agent_uses_search_files_correctly(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState): # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") @@ -509,7 +515,7 @@ def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKCli assert all(tr.status == "success" for tr in tool_returns), f"Tool call failed {tr}" -def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): +def test_agent_uses_grep_correctly_basic(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState): # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") @@ -551,7 +557,7 @@ def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClien assert all(tr.status == "success" for tr in tool_returns), "Tool call failed" -def test_agent_uses_grep_correctly_advanced(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): +def test_agent_uses_grep_correctly_advanced(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState): # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") @@ -599,7 +605,7 @@ def test_agent_uses_grep_correctly_advanced(disable_pinecone, client: LettaSDKCl assert "511:" in tool_return_message.tool_return -def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient): +def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pinecone, disable_turbopuffer, client: LettaSDKClient): """Test that creating an agent with source_ids parameter correctly creates source blocks.""" # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") @@ -642,7 +648,7 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pi assert file_tools == set(FILES_TOOLS) -def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): +def test_view_ranges_have_metadata(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState): # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") @@ -705,7 +711,7 @@ def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, age ) -def test_duplicate_file_renaming(disable_pinecone, client: LettaSDKClient): +def test_duplicate_file_renaming(disable_pinecone, disable_turbopuffer, client: LettaSDKClient): """Test that duplicate files are renamed with count-based suffixes (e.g., file.txt, file (1).txt, file (2).txt)""" # Create a new source source = client.sources.create(name="test_duplicate_source", embedding="openai/text-embedding-3-small") @@ -744,7 +750,7 @@ def test_duplicate_file_renaming(disable_pinecone, client: LettaSDKClient): print(f" File {i + 1}: original='{file.original_file_name}' → renamed='{file.file_name}'") -def test_duplicate_file_handling_replace(disable_pinecone, client: LettaSDKClient): +def test_duplicate_file_handling_replace(disable_pinecone, disable_turbopuffer, client: LettaSDKClient): """Test that DuplicateFileHandling.REPLACE replaces existing files with same name""" # Create a new source source = client.sources.create(name="test_replace_source", embedding="openai/text-embedding-3-small") @@ -826,7 +832,7 @@ def test_duplicate_file_handling_replace(disable_pinecone, client: LettaSDKClien os.unlink(temp_file_path) -def test_upload_file_with_custom_name(disable_pinecone, client: LettaSDKClient): +def test_upload_file_with_custom_name(disable_pinecone, disable_turbopuffer, client: LettaSDKClient): """Test that uploading a file with a custom name overrides the original filename""" # Create agent agent_state = client.agents.create( @@ -907,7 +913,7 @@ def test_upload_file_with_custom_name(disable_pinecone, client: LettaSDKClient): os.unlink(temp_file_path) -def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient): +def test_open_files_schema_descriptions(disable_pinecone, disable_turbopuffer, client: LettaSDKClient): """Test that open_files tool schema contains correct descriptions from docstring""" # Get the open_files tool @@ -990,7 +996,7 @@ def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient assert length_prop["type"] == "integer" -def test_grep_files_schema_descriptions(disable_pinecone, client: LettaSDKClient): +def test_grep_files_schema_descriptions(disable_pinecone, disable_turbopuffer, client: LettaSDKClient): """Test that grep_files tool schema contains correct descriptions from docstring""" # Get the grep_files tool @@ -1076,10 +1082,174 @@ def test_grep_files_schema_descriptions(disable_pinecone, client: LettaSDKClient assert "Navigation hint for next page if more matches exist" in description +def test_agent_open_file(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState): + """Test client.agents.open_file() function""" + # Create a new source + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") + + # Attach source to agent + client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + + # Upload a file + file_path = "tests/data/test.txt" + file_metadata = upload_file_and_wait(client, source.id, file_path) + + # Basic test open_file function + closed_files = client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id) + assert len(closed_files) == 0 + + system = get_raw_system_message(client, agent_state.id) + assert '' in system + assert "[Viewing file start (out of 1 lines)]" in system + + +def test_agent_close_file(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState): + """Test client.agents.close_file() function""" + # Create a new source + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") + + # Attach source to agent + client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + + # Upload a file + file_path = "tests/data/test.txt" + file_metadata = upload_file_and_wait(client, source.id, file_path) + + # First open the file + client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id) + + # Test close_file function + client.agents.files.close(agent_id=agent_state.id, file_id=file_metadata.id) + + system = get_raw_system_message(client, agent_state.id) + assert '' in system + + +def test_agent_close_all_open_files(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState): + """Test client.agents.close_all_open_files() function""" + # Create a new source + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") + + # Attach source to agent + client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + + # Upload multiple files + file_paths = ["tests/data/test.txt", "tests/data/test.md"] + file_metadatas = [] + for file_path in file_paths: + file_metadata = upload_file_and_wait(client, source.id, file_path) + file_metadatas.append(file_metadata) + # Open each file + client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id) + + system = get_raw_system_message(client, agent_state.id) + assert '' in system - assert "[Viewing file start (out of 1 lines)]" in system +# --- End Pinecone Tests --- -def test_agent_close_file(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): - """Test client.agents.close_file() function""" - # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") +# --- Turbopuffer Tests --- +def test_turbopuffer_search_files_tool(disable_pinecone, client: LettaSDKClient): + """Test that search_files tool uses Turbopuffer when enabled""" + agent = client.agents.create( + name="test_turbopuffer_agent", + memory_blocks=[ + CreateBlock(label="human", value="username: testuser"), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + ) - # Attach source to agent - client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + source = client.sources.create(name="test_turbopuffer_source", embedding="openai/text-embedding-3-small") + client.agents.sources.attach(source_id=source.id, agent_id=agent.id) - # Upload a file - file_path = "tests/data/test.txt" - file_metadata = upload_file_and_wait(client, source.id, file_path) + file_path = "tests/data/long_test.txt" + upload_file_and_wait(client, source.id, file_path) - # First open the file - client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id) + search_response = client.agents.messages.create( + agent_id=agent.id, + messages=[MessageCreate(role="user", content="Use the semantic_search_files tool to search for 'electoral history' in the files.")], + ) - # Test close_file function - client.agents.files.close(agent_id=agent_state.id, file_id=file_metadata.id) + tool_calls = [msg for msg in search_response.messages if msg.message_type == "tool_call_message"] + assert len(tool_calls) > 0, "No tool calls found" + assert any(tc.tool_call.name == "semantic_search_files" for tc in tool_calls), "semantic_search_files not called" - system = get_raw_system_message(client, agent_state.id) - assert '' in system + tool_returns = [msg for msg in search_response.messages if msg.message_type == "tool_return_message"] + assert len(tool_returns) > 0, "No tool returns found" + assert all(tr.status == "success" for tr in tool_returns), "Tool call failed" + + search_results = tool_returns[0].tool_return + print(f"Turbopuffer search results: {search_results}") + assert "electoral" in search_results.lower() or "history" in search_results.lower(), ( + f"Search results should contain relevant content: {search_results}" + ) + + client.agents.delete(agent_id=agent.id) + client.sources.delete(source_id=source.id) -def test_agent_close_all_open_files(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): - """Test client.agents.close_all_open_files() function""" - # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") +def test_turbopuffer_file_processing_status(disable_pinecone, client: LettaSDKClient): + """Test that file processing completes successfully with Turbopuffer""" + print("Testing Turbopuffer file processing status") - # Attach source to agent - client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + source = client.sources.create(name="test_tpuf_file_status", embedding="openai/text-embedding-3-small") - # Upload multiple files - file_paths = ["tests/data/test.txt", "tests/data/test.md"] - file_metadatas = [] + file_paths = ["tests/data/long_test.txt", "tests/data/test.md"] + uploaded_files = [] for file_path in file_paths: file_metadata = upload_file_and_wait(client, source.id, file_path) - file_metadatas.append(file_metadata) - # Open each file - client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id) + uploaded_files.append(file_metadata) + assert file_metadata.processing_status == "completed", f"File {file_path} should be completed" - system = get_raw_system_message(client, agent_state.id) - assert ' 0: + assert file_metadata.chunks_embedded == file_metadata.total_chunks, ( + f"File {file_metadata.file_name} should have all chunks embedded: {file_metadata.chunks_embedded}/{file_metadata.total_chunks}" + ) - -def test_file_processing_timeout(disable_pinecone, client: LettaSDKClient): - """Test that files in non-terminal states are moved to error after timeout""" - # Create a source - source = client.sources.create(name="test_timeout_source", embedding="openai/text-embedding-3-small") - - # Upload a file - file_path = "tests/data/test.txt" - with open(file_path, "rb") as f: - file_metadata = client.sources.files.upload(source_id=source.id, file=f) - - # Get the file ID - file_id = file_metadata.id - - # Test the is_terminal_state method directly (this doesn't require server mocking) - assert FileProcessingStatus.COMPLETED.is_terminal_state() == True - assert FileProcessingStatus.ERROR.is_terminal_state() == True - assert FileProcessingStatus.PARSING.is_terminal_state() == False - assert FileProcessingStatus.EMBEDDING.is_terminal_state() == False - assert FileProcessingStatus.PENDING.is_terminal_state() == False - - # For testing the actual timeout logic, we can check the current file status - current_file = client.sources.get_file_metadata(source_id=source.id, file_id=file_id) - - # Convert string status to enum for testing - status_enum = FileProcessingStatus(current_file.processing_status) - - # Verify that files in terminal states are not affected by timeout checks - if status_enum.is_terminal_state(): - # This is the expected behavior - files that completed processing shouldn't timeout - print(f"File {file_id} is in terminal state: {current_file.processing_status}") - assert status_enum in [FileProcessingStatus.COMPLETED, FileProcessingStatus.ERROR] - else: - # If file is still processing, it should eventually complete or timeout - # In a real scenario, we'd wait and check, but for unit tests we just verify the logic exists - print(f"File {file_id} is still processing: {current_file.processing_status}") - assert status_enum in [FileProcessingStatus.PENDING, FileProcessingStatus.PARSING, FileProcessingStatus.EMBEDDING] - - -@pytest.mark.unit -def test_file_processing_timeout_logic(): - """Test the timeout logic directly without server dependencies""" - from datetime import timezone - - # Test scenario: file created 35 minutes ago, timeout is 30 minutes - old_time = datetime.now(timezone.utc) - timedelta(minutes=35) - current_time = datetime.now(timezone.utc) - timeout_minutes = 30 - - # Calculate timeout threshold - timeout_threshold = current_time - timedelta(minutes=timeout_minutes) - - # Verify timeout logic - assert old_time < timeout_threshold, "File created 35 minutes ago should be past 30-minute timeout" - - # Test edge case: file created exactly at timeout - edge_time = current_time - timedelta(minutes=timeout_minutes) - assert not (edge_time < timeout_threshold), "File created exactly at timeout should not trigger timeout" - - # Test recent file - recent_time = current_time - timedelta(minutes=10) - assert not (recent_time < timeout_threshold), "Recent file should not trigger timeout" - - -def test_letta_free_embedding(disable_pinecone, client: LettaSDKClient): - """Test creating a source with letta/letta-free embedding and uploading a file""" - # create a source with letta-free embedding - source = client.sources.create(name="test_letta_free_source", embedding="letta/letta-free") - - # verify source was created with correct embedding - assert source.name == "test_letta_free_source" - print("\n\n\n\ntest") - print(source.embedding_config) - # assert source.embedding_config.embedding_model == "letta-free" - - # upload test.txt file - file_path = "tests/data/test.txt" - file_metadata = upload_file_and_wait(client, source.id, file_path) - - # verify file was uploaded successfully - assert file_metadata.processing_status == "completed" - assert file_metadata.source_id == source.id - assert file_metadata.file_name == "test.txt" - - # verify file appears in source files list - files = client.sources.files.list(source_id=source.id, limit=1) - assert len(files) == 1 - assert files[0].id == file_metadata.id - - # cleanup client.sources.delete(source_id=source.id) + + +def test_turbopuffer_lifecycle_file_and_source_deletion(disable_pinecone, client: LettaSDKClient): + """Test that file and source deletion removes records from Turbopuffer""" + source = client.sources.create(name="test_tpuf_lifecycle", embedding="openai/text-embedding-3-small") + + file_paths = ["tests/data/test.txt", "tests/data/test.md"] + uploaded_files = [] + for file_path in file_paths: + file_metadata = upload_file_and_wait(client, source.id, file_path) + uploaded_files.append(file_metadata) + + user = User(name="temp", organization_id=DEFAULT_ORG_ID) + tpuf_client = TurbopufferClient() + + # test file-level deletion + if len(uploaded_files) > 1: + file_to_delete = uploaded_files[0] + + passages_before = asyncio.run( + tpuf_client.query_file_passages( + source_ids=[source.id], organization_id=user.organization_id, actor=user, file_id=file_to_delete.id, top_k=100 + ) + ) + print(f"Found {len(passages_before)} passages for file before deletion") + assert len(passages_before) > 0, "Should have passages before deletion" + + client.sources.files.delete(source_id=source.id, file_id=file_to_delete.id) + + time.sleep(2) + + passages_after = asyncio.run( + tpuf_client.query_file_passages( + source_ids=[source.id], organization_id=user.organization_id, actor=user, file_id=file_to_delete.id, top_k=100 + ) + ) + print(f"Found {len(passages_after)} passages for file after deletion") + assert len(passages_after) == 0, f"File passages should be removed from Turbopuffer after deletion, but found {len(passages_after)}" + + # test source-level deletion + remaining_passages_before = [] + for file_metadata in uploaded_files[1:]: + passages = asyncio.run( + tpuf_client.query_file_passages( + source_ids=[source.id], organization_id=user.organization_id, actor=user, file_id=file_metadata.id, top_k=100 + ) + ) + remaining_passages_before.extend(passages) + + print(f"Found {len(remaining_passages_before)} passages for remaining files before source deletion") + assert len(remaining_passages_before) > 0, "Should have passages for remaining files" + + client.sources.delete(source_id=source.id) + + time.sleep(3) + + remaining_passages_after = [] + for file_metadata in uploaded_files[1:]: + try: + passages = asyncio.run( + tpuf_client.query_file_passages( + source_ids=[source.id], organization_id=user.organization_id, actor=user, file_id=file_metadata.id, top_k=100 + ) + ) + remaining_passages_after.extend(passages) + except Exception as e: + print(f"Expected error querying deleted source: {e}") + + print(f"Found {len(remaining_passages_after)} passages for files after source deletion") + assert len(remaining_passages_after) == 0, ( + f"All source passages should be removed from Turbopuffer after source deletion, but found {len(remaining_passages_after)}" + ) + + +def test_turbopuffer_multiple_sources(disable_pinecone, client: LettaSDKClient): + """Test that Turbopuffer correctly isolates passages by source in org-scoped namespace""" + source1 = client.sources.create(name="test_tpuf_source1", embedding="openai/text-embedding-3-small") + source2 = client.sources.create(name="test_tpuf_source2", embedding="openai/text-embedding-3-small") + + file1_metadata = upload_file_and_wait(client, source1.id, "tests/data/test.txt") + file2_metadata = upload_file_and_wait(client, source2.id, "tests/data/test.md") + + user = User(name="temp", organization_id=DEFAULT_ORG_ID) + tpuf_client = TurbopufferClient() + + source1_passages = asyncio.run( + tpuf_client.query_file_passages(source_ids=[source1.id], organization_id=user.organization_id, actor=user, top_k=100) + ) + + source2_passages = asyncio.run( + tpuf_client.query_file_passages(source_ids=[source2.id], organization_id=user.organization_id, actor=user, top_k=100) + ) + + print(f"Source1 has {len(source1_passages)} passages") + print(f"Source2 has {len(source2_passages)} passages") + + assert len(source1_passages) > 0, "Source1 should have passages" + assert len(source2_passages) > 0, "Source2 should have passages" + + for passage, _, _ in source1_passages: + assert passage.source_id == source1.id, f"Passage should belong to source1, but has source_id={passage.source_id}" + assert passage.file_id == file1_metadata.id, f"Passage should belong to file1, but has file_id={passage.file_id}" + + for passage, _, _ in source2_passages: + assert passage.source_id == source2.id, f"Passage should belong to source2, but has source_id={passage.source_id}" + assert passage.file_id == file2_metadata.id, f"Passage should belong to file2, but has file_id={passage.file_id}" + + # delete source1 and verify source2 is unaffected + client.sources.delete(source_id=source1.id) + time.sleep(2) + + source2_passages_after = asyncio.run( + tpuf_client.query_file_passages(source_ids=[source2.id], organization_id=user.organization_id, actor=user, top_k=100) + ) + + assert len(source2_passages_after) == len(source2_passages), ( + f"Source2 should still have all passages after source1 deletion: {len(source2_passages_after)} vs {len(source2_passages)}" + ) + + client.sources.delete(source_id=source2.id) + + +# --- End Turbopuffer Tests ---