feat: Add turbopuffer embedder by default [LET-4253] (#4476)

* Adapt to turbopuffer embedder

* Make turbopuffer search more efficient over all source ids

* Combine turbopuffer and pinecone hybrid

* Fix test sources
This commit is contained in:
Matthew Zhou
2025-09-08 18:46:41 -07:00
committed by GitHub
parent 0f383ed776
commit 516f2963e0
13 changed files with 968 additions and 199 deletions

View File

@@ -16,12 +16,12 @@ logger = logging.getLogger(__name__)
def should_use_tpuf() -> bool:
# We need OpenAI since we default to their embedding model
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key) and bool(model_settings.openai_api_key)
def should_use_tpuf_for_messages() -> bool:
"""Check if Turbopuffer should be used for messages."""
return should_use_tpuf() and bool(settings.embed_all_messages) and bool(model_settings.openai_api_key)
return should_use_tpuf() and bool(settings.embed_all_messages)
class TurbopufferClient:
@@ -1113,3 +1113,309 @@ class TurbopufferClient:
except Exception as e:
logger.error(f"Failed to delete all messages from Turbopuffer: {e}")
raise
# file/source passage methods
@trace_method
async def _get_file_passages_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped Turbopuffer namespace name for file passages.

    Args:
        organization_id: Organization whose file passages live in the namespace.

    Returns:
        The namespace name, suffixed with the deployment environment when one
        is configured.
    """
    env = settings.environment
    # environment suffix keeps e.g. staging and prod data in separate namespaces
    suffix = f"_{env.lower()}" if env else ""
    return f"file_passages_{organization_id}{suffix}"
@trace_method
async def insert_file_passages(
    self,
    source_id: str,
    file_id: str,
    text_chunks: List[str],
    organization_id: str,
    actor: "PydanticUser",
    created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
    """Insert file passages into Turbopuffer using org-scoped namespace.

    Embeds each chunk with the client's default embedding config, then writes
    all rows in a single column-based upsert (Turbopuffer's recommended fast
    path for batch writes).

    Args:
        source_id: ID of the source containing the file
        file_id: ID of the file
        text_chunks: List of text chunks to store
        organization_id: Organization ID for the passages
        actor: User actor for embedding generation
        created_at: Optional timestamp for retroactive entries (defaults to current UTC time)

    Returns:
        List of PydanticPassage objects that were inserted

    Raises:
        Exception: re-raised after logging if embedding or the Turbopuffer
            write fails.
    """
    from turbopuffer import AsyncTurbopuffer

    # nothing to embed or store
    if not text_chunks:
        return []
    # generate embeddings using the default config
    embeddings = await self._generate_embeddings(text_chunks, actor)
    namespace_name = await self._get_file_passages_namespace_name(organization_id)
    # handle timestamp - ensure UTC
    if created_at is None:
        timestamp = datetime.now(timezone.utc)
    else:
        # ensure the provided timestamp is timezone-aware and in UTC
        if created_at.tzinfo is None:
            # assume UTC if no timezone provided
            timestamp = created_at.replace(tzinfo=timezone.utc)
        else:
            # convert to UTC if in different timezone
            timestamp = created_at.astimezone(timezone.utc)
    # prepare column-based data for turbopuffer - optimized for batch insert
    ids = []
    vectors = []
    texts = []
    organization_ids = []
    source_ids = []
    file_ids = []
    created_ats = []
    passages = []
    for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
        # NOTE(review): the Passage is built with actor.organization_id while the
        # Turbopuffer row below stores the organization_id parameter — these are
        # presumably always equal; confirm, otherwise the DB row and the returned
        # object can disagree.
        passage = PydanticPassage(
            text=text,
            file_id=file_id,
            source_id=source_id,
            embedding=embedding,
            embedding_config=self.default_embedding_config,
            organization_id=actor.organization_id,
        )
        passages.append(passage)
        # append to columns (passage.id is generated by the Pydantic model default)
        ids.append(passage.id)
        vectors.append(embedding)
        texts.append(text)
        organization_ids.append(organization_id)
        source_ids.append(source_id)
        file_ids.append(file_id)
        # every row in the batch shares the same timestamp
        created_ats.append(timestamp)
    # build column-based upsert data
    upsert_columns = {
        "id": ids,
        "vector": vectors,
        "text": texts,
        "organization_id": organization_ids,
        "source_id": source_ids,
        "file_id": file_ids,
        "created_at": created_ats,
    }
    try:
        # use AsyncTurbopuffer as a context manager for proper resource cleanup
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            namespace = client.namespace(namespace_name)
            # turbopuffer recommends column-based writes for performance;
            # enabling full_text_search on "text" supports FTS/hybrid queries later
            await namespace.write(
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
        logger.info(f"Successfully inserted {len(ids)} file passages to Turbopuffer for source {source_id}, file {file_id}")
        return passages
    except Exception as e:
        logger.error(f"Failed to insert file passages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate passage IDs detected in batch")
        raise
@trace_method
async def query_file_passages(
    self,
    source_ids: List[str],
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid"
    top_k: int = 10,
    file_id: Optional[str] = None,  # optional filter by specific file
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> List[Tuple[PydanticPassage, float, dict]]:
    """Query file passages from Turbopuffer using org-scoped namespace.

    Args:
        source_ids: List of source IDs to query
        organization_id: Organization ID for namespace lookup
        actor: User actor for embedding generation
        query_text: Text query for search
        search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
        top_k: Number of results to return
        file_id: Optional file ID to filter results to a specific file
        vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5)

    Returns:
        List of (passage, score, metadata) tuples with relevance rankings

    Raises:
        Exception: re-raised after logging if the Turbopuffer query fails.
    """
    # generate embedding for vector/hybrid search if query_text is provided
    query_embedding = None
    if query_text and search_mode in ("vector", "hybrid"):
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0]

    # with no query text and no embedding there is nothing to rank against,
    # so fall back to retrieving the most recent passages
    if query_embedding is None and query_text is None and search_mode != "timestamp":
        search_mode = "timestamp"

    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    # build filters - always scope to the requested sources
    if len(source_ids) == 1:
        # single source_id: Eq is cheaper than In
        filters = [("source_id", "Eq", source_ids[0])]
    else:
        filters = [("source_id", "In", source_ids)]
    # add file filter if specified
    if file_id:
        filters.append(("file_id", "Eq", file_id))
    # a single filter is passed as-is; multiple filters are And-ed together
    final_filter = filters[0] if len(filters) == 1 else ("And", filters)

    try:
        # use generic query executor
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "organization_id", "source_id", "file_id", "created_at"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )
        if search_mode == "hybrid":
            # hybrid returns a multi-query response: [vector results, fts results]
            vector_results = self._process_file_query_results(result.results[0])
            fts_results = self._process_file_query_results(result.results[1], is_fts=True)
            # fuse the two ranked lists with reciprocal rank fusion; metadata
            # produced by RRF includes the per-mode ranks
            results_with_metadata = self._reciprocal_rank_fusion(
                vector_results=[passage for passage, _ in vector_results],
                fts_results=[passage for passage, _ in fts_results],
                get_id_func=lambda p: p.id,
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )
            return results_with_metadata
        else:
            # single-query modes (vector, fts, timestamp): attach basic metadata
            is_fts = search_mode == "fts"
            results = self._process_file_query_results(result, is_fts=is_fts)
            results_with_metadata = []
            for idx, (passage, score) in enumerate(results):
                metadata = {
                    "combined_score": score,
                    f"{search_mode}_rank": idx + 1,  # rank within this search mode
                }
                results_with_metadata.append((passage, score, metadata))
            return results_with_metadata
    except Exception as e:
        logger.error(f"Failed to query file passages from Turbopuffer: {e}")
        raise
def _process_file_query_results(self, result, is_fts: bool = False) -> List[Tuple[PydanticPassage, float]]:
    """Convert raw Turbopuffer rows into (passage, score) pairs.

    FTS rows carry a BM25 `$score` (higher is better); vector rows carry a
    `$dist` cosine distance, which is converted to a similarity via 1 - dist.
    Embeddings are not returned by Turbopuffer, so each passage gets an empty
    embedding plus the client's default embedding config.
    """
    scored: List[Tuple[PydanticPassage, float]] = []
    for row in result.rows:
        # minimal passage reconstruction from the stored attributes
        passage = PydanticPassage(
            id=row.id,
            text=getattr(row, "text", ""),
            organization_id=getattr(row, "organization_id", None),
            source_id=getattr(row, "source_id", None),
            file_id=getattr(row, "file_id", None),
            created_at=getattr(row, "created_at", None),
            metadata_={},
            tags=[],
            # required fields defaulted: embeddings are never round-tripped
            embedding=[],
            embedding_config=self.default_embedding_config,
        )
        if is_fts:
            # BM25 relevance, already "higher is better"
            score = getattr(row, "$score", 0.0)
        else:
            # distance -> similarity
            score = 1.0 - getattr(row, "$dist", 0.0)
        scored.append((passage, score))
    return scored
@trace_method
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
    """Delete every Turbopuffer passage belonging to one file within a source."""
    from turbopuffer import AsyncTurbopuffer

    namespace_name = await self._get_file_passages_namespace_name(organization_id)
    # restrict deletion to rows matching both the source and the file
    file_filter = ("And", [("source_id", "Eq", source_id), ("file_id", "Eq", file_id)])
    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            result = await client.namespace(namespace_name).write(delete_by_filter=file_filter)
            logger.info(
                f"Successfully deleted passages for file {file_id} from source {source_id} (deleted {result.rows_affected} rows)"
            )
            return True
    except Exception as e:
        logger.error(f"Failed to delete file passages from Turbopuffer: {e}")
        raise
@trace_method
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
    """Delete every Turbopuffer passage stored for a source."""
    from turbopuffer import AsyncTurbopuffer

    namespace_name = await self._get_file_passages_namespace_name(organization_id)
    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            # a single filtered write removes all rows for the source
            result = await client.namespace(namespace_name).write(delete_by_filter=("source_id", "Eq", source_id))
            logger.info(f"Successfully deleted all passages for source {source_id} (deleted {result.rows_affected} rows)")
            return True
    except Exception as e:
        logger.error(f"Failed to delete source passages from Turbopuffer: {e}")
        raise

View File

@@ -15,6 +15,7 @@ from letta.helpers.pinecone_utils import (
delete_source_records_from_pinecone_index,
should_use_pinecone,
)
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -191,7 +192,13 @@ async def delete_folder(
files = await server.file_manager.list_files(folder_id, actor)
file_ids = [f.id for f in files]
if should_use_pinecone():
if should_use_tpuf():
logger.info(f"Deleting folder {folder_id} from Turbopuffer")
from letta.helpers.tpuf_client import TurbopufferClient
tpuf_client = TurbopufferClient()
await tpuf_client.delete_source_passages(source_id=folder_id, organization_id=actor.organization_id)
elif should_use_pinecone():
logger.info(f"Deleting folder {folder_id} from pinecone index")
await delete_source_records_from_pinecone_index(source_id=folder_id, actor=actor)
@@ -450,7 +457,13 @@ async def delete_file_from_folder(
await server.remove_file_from_context_windows(source_id=folder_id, file_id=deleted_file.id, actor=actor)
if should_use_pinecone():
if should_use_tpuf():
logger.info(f"Deleting file {file_id} from Turbopuffer")
from letta.helpers.tpuf_client import TurbopufferClient
tpuf_client = TurbopufferClient()
await tpuf_client.delete_file_passages(source_id=folder_id, file_id=file_id, organization_id=actor.organization_id)
elif should_use_pinecone():
logger.info(f"Deleting file {file_id} from pinecone index")
await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor)
@@ -496,10 +509,15 @@ async def load_file_to_source_cloud(
else:
file_parser = MarkitdownFileParser()
using_pinecone = should_use_pinecone()
if using_pinecone:
# determine which embedder to use - turbopuffer takes precedence
if should_use_tpuf():
from letta.services.file_processor.embedder.turbopuffer_embedder import TurbopufferEmbedder
embedder = TurbopufferEmbedder(embedding_config=embedding_config)
elif should_use_pinecone():
embedder = PineconeEmbedder(embedding_config=embedding_config)
else:
embedder = OpenAIEmbedder(embedding_config=embedding_config)
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor, using_pinecone=using_pinecone)
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor)
await file_processor.process(agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata)

View File

@@ -15,6 +15,7 @@ from letta.helpers.pinecone_utils import (
delete_source_records_from_pinecone_index,
should_use_pinecone,
)
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -189,7 +190,13 @@ async def delete_source(
files = await server.file_manager.list_files(source_id, actor)
file_ids = [f.id for f in files]
if should_use_pinecone():
if should_use_tpuf():
logger.info(f"Deleting source {source_id} from Turbopuffer")
from letta.helpers.tpuf_client import TurbopufferClient
tpuf_client = TurbopufferClient()
await tpuf_client.delete_source_passages(source_id=source_id, organization_id=actor.organization_id)
elif should_use_pinecone():
logger.info(f"Deleting source {source_id} from pinecone index")
await delete_source_records_from_pinecone_index(source_id=source_id, actor=actor)
@@ -435,7 +442,13 @@ async def delete_file_from_source(
await server.remove_file_from_context_windows(source_id=source_id, file_id=deleted_file.id, actor=actor)
if should_use_pinecone():
if should_use_tpuf():
logger.info(f"Deleting file {file_id} from Turbopuffer")
from letta.helpers.tpuf_client import TurbopufferClient
tpuf_client = TurbopufferClient()
await tpuf_client.delete_file_passages(source_id=source_id, file_id=file_id, organization_id=actor.organization_id)
elif should_use_pinecone():
logger.info(f"Deleting file {file_id} from pinecone index")
await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor)
@@ -481,10 +494,15 @@ async def load_file_to_source_cloud(
else:
file_parser = MarkitdownFileParser()
using_pinecone = should_use_pinecone()
if using_pinecone:
# determine which embedder to use - turbopuffer takes precedence
if should_use_tpuf():
from letta.services.file_processor.embedder.turbopuffer_embedder import TurbopufferEmbedder
embedder = TurbopufferEmbedder(embedding_config=embedding_config)
elif should_use_pinecone():
embedder = PineconeEmbedder(embedding_config=embedding_config)
else:
embedder = OpenAIEmbedder(embedding_config=embedding_config)
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor, using_pinecone=using_pinecone)
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor)
await file_processor.process(agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata)

View File

@@ -12,6 +12,7 @@ from letta.errors import (
AgentNotFoundForExportError,
)
from letta.helpers.pinecone_utils import should_use_pinecone
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.agent_file import (
@@ -29,7 +30,7 @@ from letta.schemas.agent_file import (
)
from letta.schemas.block import Block
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import FileProcessingStatus
from letta.schemas.enums import FileProcessingStatus, VectorDBProvider
from letta.schemas.file import FileMetadata
from letta.schemas.group import Group, GroupCreate
from letta.schemas.mcp import MCPServer
@@ -90,7 +91,6 @@ class AgentSerializationManager:
self.file_agent_manager = file_agent_manager
self.message_manager = message_manager
self.file_parser = MistralFileParser() if settings.mistral_api_key else MarkitdownFileParser()
self.using_pinecone = should_use_pinecone()
# ID mapping state for export
self._db_to_file_ids: Dict[str, str] = {}
@@ -588,7 +588,12 @@ class AgentSerializationManager:
if schema.files and any(f.content for f in schema.files):
# Use override embedding config if provided, otherwise use agent's config
embedder_config = override_embedding_config if override_embedding_config else schema.agents[0].embedding_config
if should_use_pinecone():
# determine which embedder to use - turbopuffer takes precedence
if should_use_tpuf():
from letta.services.file_processor.embedder.turbopuffer_embedder import TurbopufferEmbedder
embedder = TurbopufferEmbedder(embedding_config=embedder_config)
elif should_use_pinecone():
embedder = PineconeEmbedder(embedding_config=embedder_config)
else:
embedder = OpenAIEmbedder(embedding_config=embedder_config)
@@ -596,7 +601,6 @@ class AgentSerializationManager:
file_parser=self.file_parser,
embedder=embedder,
actor=actor,
using_pinecone=self.using_pinecone,
)
for file_schema in schema.files:

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import List
from letta.log import get_logger
from letta.schemas.enums import VectorDBProvider
from letta.schemas.passage import Passage
from letta.schemas.user import User
@@ -11,6 +12,10 @@ logger = get_logger(__name__)
class BaseEmbedder(ABC):
"""Abstract base class for embedding generation"""
def __init__(self):
# Default to NATIVE, subclasses will override this
self.vector_db_type = VectorDBProvider.NATIVE
@abstractmethod
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
"""Generate embeddings for chunks with batching and concurrent processing"""

View File

@@ -19,6 +19,10 @@ class OpenAIEmbedder(BaseEmbedder):
"""OpenAI-based embedding generation"""
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
super().__init__()
# OpenAI embedder uses the native vector db (PostgreSQL)
# self.vector_db_type already set to VectorDBProvider.NATIVE by parent
self.default_embedding_config = (
EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai")
if model_settings.openai_api_key

View File

@@ -4,6 +4,7 @@ from letta.helpers.pinecone_utils import upsert_file_records_to_pinecone_index
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import VectorDBProvider
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
@@ -20,6 +21,10 @@ class PineconeEmbedder(BaseEmbedder):
"""Pinecone-based embedding generation"""
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
super().__init__()
# set the vector db type for pinecone
self.vector_db_type = VectorDBProvider.PINECONE
if not PINECONE_AVAILABLE:
raise ImportError("Pinecone package is not installed. Install it with: pip install pinecone")
@@ -28,7 +33,6 @@ class PineconeEmbedder(BaseEmbedder):
embedding_config = EmbeddingConfig.default_config(provider="pinecone")
self.embedding_config = embedding_config
super().__init__()
@trace_method
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:

View File

@@ -0,0 +1,71 @@
from typing import List, Optional
from letta.helpers.tpuf_client import TurbopufferClient
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import VectorDBProvider
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
logger = get_logger(__name__)
class TurbopufferEmbedder(BaseEmbedder):
    """Embedder that delegates both embedding generation and storage to Turbopuffer."""

    def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
        super().__init__()
        # mark this embedder as backed by the Turbopuffer vector db
        self.vector_db_type = VectorDBProvider.TPUF
        # fall back to the client's default config when none is supplied
        self.embedding_config = embedding_config or TurbopufferClient.default_embedding_config
        self.tpuf_client = TurbopufferClient()

    @trace_method
    async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
        """Generate embeddings and store in Turbopuffer, then return Passage objects"""
        if not chunks:
            return []

        logger.info(f"Generating embeddings for {len(chunks)} chunks using Turbopuffer")
        started_payload = {
            "total_chunks": len(chunks),
            "file_id": file_id,
            "source_id": source_id,
            "embedding_model": self.embedding_config.embedding_model,
        }
        log_event("turbopuffer_embedder.generation_started", started_payload)

        try:
            # the client embeds the chunks itself and persists them in one call
            passages = await self.tpuf_client.insert_file_passages(
                source_id=source_id,
                file_id=file_id,
                text_chunks=chunks,
                organization_id=actor.organization_id,
                actor=actor,
            )
            logger.info(f"Successfully generated and stored {len(passages)} passages in Turbopuffer")
            completed_payload = {
                "passages_created": len(passages),
                "total_chunks_processed": len(chunks),
                "file_id": file_id,
                "source_id": source_id,
            }
            log_event("turbopuffer_embedder.generation_completed", completed_payload)
            return passages
        except Exception as e:
            logger.error(f"Failed to generate embeddings with Turbopuffer: {str(e)}")
            failed_payload = {"error": str(e), "error_type": type(e).__name__, "file_id": file_id, "source_id": source_id}
            log_event("turbopuffer_embedder.generation_failed", failed_payload)
            raise

View File

@@ -6,7 +6,7 @@ from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.tracing import log_event, trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import FileProcessingStatus
from letta.schemas.enums import FileProcessingStatus, VectorDBProvider
from letta.schemas.file import FileMetadata
from letta.schemas.passage import Passage
from letta.schemas.user import User
@@ -30,7 +30,6 @@ class FileProcessor:
file_parser: FileParser,
embedder: BaseEmbedder,
actor: User,
using_pinecone: bool,
max_file_size: int = 50 * 1024 * 1024, # 50MB default
):
self.file_parser = file_parser
@@ -42,7 +41,8 @@ class FileProcessor:
self.job_manager = JobManager()
self.agent_manager = AgentManager()
self.actor = actor
self.using_pinecone = using_pinecone
# get vector db type from the embedder
self.vector_db_type = embedder.vector_db_type
async def _chunk_and_embed_with_fallback(self, file_metadata: FileMetadata, ocr_response, source_id: str) -> List:
"""Chunk text and generate embeddings with fallback to default chunker if needed"""
@@ -218,7 +218,7 @@ class FileProcessor:
source_id=source_id,
)
if not self.using_pinecone:
if self.vector_db_type == VectorDBProvider.NATIVE:
all_passages = await self.passage_manager.create_many_source_passages_async(
passages=all_passages,
file_metadata=file_metadata,
@@ -241,7 +241,8 @@ class FileProcessor:
)
# update job status
if not self.using_pinecone:
# pinecone completes slowly, so gets updated later
if self.vector_db_type != VectorDBProvider.PINECONE:
await self.file_manager.update_file_status(
file_id=file_metadata.id,
actor=self.actor,
@@ -317,14 +318,15 @@ class FileProcessor:
)
# Create passages in database (unless using Pinecone)
if not self.using_pinecone:
if self.vector_db_type == VectorDBProvider.NATIVE:
all_passages = await self.passage_manager.create_many_source_passages_async(
passages=all_passages, file_metadata=file_metadata, actor=self.actor
)
log_event("file_processor.import_passages_created", {"filename": filename, "total_passages": len(all_passages)})
# Update file status to completed (valid transition from EMBEDDING)
if not self.using_pinecone:
# pinecone completes slowly, so gets updated later
if self.vector_db_type != VectorDBProvider.PINECONE:
await self.file_manager.update_file_status(
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED
)

View File

@@ -3,6 +3,7 @@ from typing import List, Optional, Union
from sqlalchemy import and_, exists, select
from letta.helpers.pinecone_utils import should_use_pinecone
from letta.helpers.tpuf_client import should_use_tpuf
from letta.orm import Agent as AgentModel
from letta.orm.errors import NoResultFound
@@ -18,6 +19,18 @@ from letta.utils import enforce_types, printd
class SourceManager:
def _get_vector_db_provider(self) -> VectorDBProvider:
    """Pick the configured vector db provider.

    Turbopuffer takes precedence over Pinecone; the native (PostgreSQL)
    provider is the final fallback.
    """
    if should_use_tpuf():
        return VectorDBProvider.TPUF
    if should_use_pinecone():
        return VectorDBProvider.PINECONE
    return VectorDBProvider.NATIVE
"""Manager class to handle business logic related to Sources."""
@trace_method
@@ -52,7 +65,7 @@ class SourceManager:
if db_source:
return db_source
else:
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
vector_db_provider = self._get_vector_db_provider()
async with db_registry.async_session() as session:
# Provide default embedding config if not given
@@ -96,7 +109,7 @@ class SourceManager:
Returns:
List of created/updated sources
"""
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
vector_db_provider = self._get_vector_db_provider()
for pydantic_source in pydantic_sources:
pydantic_source.vector_db_provider = vector_db_provider

View File

@@ -5,10 +5,13 @@ from typing import Any, Dict, List, Optional
from letta.constants import PINECONE_TEXT_FIELD_NAME
from letta.functions.types import FileOpenRequest
from letta.helpers.pinecone_utils import search_pinecone_index, should_use_pinecone
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import VectorDBProvider
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
@@ -554,18 +557,140 @@ class LettaFileToolExecutor(ToolExecutor):
self.logger.info(f"Semantic search started for agent {agent_state.id} with query '{query}' (limit: {limit})")
# Check if Pinecone is enabled and use it if available
if should_use_pinecone():
return await self._search_files_pinecone(agent_state, query, limit)
else:
return await self._search_files_traditional(agent_state, query, limit)
# Check which vector DB to use - Turbopuffer takes precedence
attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
attached_tpuf_sources = [source for source in attached_sources if source.vector_db_provider == VectorDBProvider.TPUF]
attached_pinecone_sources = [source for source in attached_sources if source.vector_db_provider == VectorDBProvider.PINECONE]
async def _search_files_pinecone(self, agent_state: AgentState, query: str, limit: int) -> str:
if not attached_tpuf_sources and not attached_pinecone_sources:
return await self._search_files_native(agent_state, query, limit)
results = []
# If both have items, we roughly halve the limit
# TODO: This is very hacky bc it skips the re-ranking - but this is a temporary stopgap while we think about migrating data
if attached_tpuf_sources and attached_pinecone_sources:
limit = max(limit // 2, 1)
if should_use_tpuf() and attached_tpuf_sources:
tpuf_result = await self._search_files_turbopuffer(agent_state, attached_tpuf_sources, query, limit)
results.append(tpuf_result)
if should_use_pinecone() and attached_pinecone_sources:
pinecone_result = await self._search_files_pinecone(agent_state, attached_pinecone_sources, query, limit)
results.append(pinecone_result)
# combine results from both sources
if results:
return "\n\n".join(results)
# fallback if no results from either source
return "No results found"
async def _search_files_turbopuffer(self, agent_state: AgentState, attached_sources: List[Source], query: str, limit: int) -> str:
    """Search files using Turbopuffer vector database.

    Runs one hybrid (vector + FTS) query across all given sources, groups the
    hits by file, and formats them as a terminal-style text report. Files with
    matches get their access marked for the agent.

    Args:
        agent_state: Agent whose attached files are searched.
        attached_sources: Sources (already filtered to Turbopuffer-backed) to query.
        query: Free-text search query.
        limit: Maximum number of passages to return.

    Returns:
        A human-readable summary string (also used for "no results" cases).

    Raises:
        Exception: re-raised after logging if the Turbopuffer query fails.
    """
    # Get attached sources
    source_ids = [source.id for source in attached_sources]
    if not source_ids:
        return "No valid source IDs found for attached files"
    # Get all attached files for this agent
    file_agents = await self.files_agents_manager.list_files_for_agent(
        agent_id=agent_state.id, per_file_view_window_char_limit=agent_state.per_file_view_window_char_limit, actor=self.actor
    )
    if not file_agents:
        return "No files are currently attached to search"
    # Create a map of file_id to file_name for quick lookup
    file_map = {fa.file_id: fa.file_name for fa in file_agents}
    results = []
    total_hits = 0
    # file_name -> list of {"text", "score", "passage_id"} dicts
    files_with_matches = {}
    try:
        from letta.helpers.tpuf_client import TurbopufferClient

        tpuf_client = TurbopufferClient()
        # Query Turbopuffer for all sources at once
        search_results = await tpuf_client.query_file_passages(
            source_ids=source_ids,  # pass all source_ids as a list
            organization_id=self.actor.organization_id,
            actor=self.actor,
            query_text=query,
            search_mode="hybrid",  # use hybrid search for best results
            top_k=limit,
        )
        # Process search results
        for passage, score, metadata in search_results:
            # defensive cap; top_k already limits the result count
            if total_hits >= limit:
                break
            total_hits += 1
            # get file name from our map; passage may reference a file the
            # agent no longer has attached
            file_name = file_map.get(passage.file_id, "Unknown File")
            # group by file name
            if file_name not in files_with_matches:
                files_with_matches[file_name] = []
            files_with_matches[file_name].append({"text": passage.text, "score": score, "passage_id": passage.id})
    except Exception as e:
        self.logger.error(f"Turbopuffer search failed: {str(e)}")
        raise e
    if not files_with_matches:
        return f"No semantic matches found in Turbopuffer for query: '{query}'"
    # Format results
    passage_num = 0
    for file_name, matches in files_with_matches.items():
        for match in matches:
            passage_num += 1
            # format each passage with terminal-style header
            score_display = f"(score: {match['score']:.3f})"
            passage_header = f"\n=== {file_name} (passage #{passage_num}) {score_display} ==="
            # format the passage text
            passage_text = match["text"].strip()
            lines = passage_text.splitlines()
            formatted_lines = []
            for line in lines[:20]:  # limit to first 20 lines per passage
                formatted_lines.append(f"  {line}")
            if len(lines) > 20:
                formatted_lines.append(f"  ... [truncated {len(lines) - 20} more lines]")
            passage_content = "\n".join(formatted_lines)
            results.append(f"{passage_header}\n{passage_content}")
    # mark access for files that had matches
    if files_with_matches:
        # skip the placeholder bucket for passages whose file is unknown
        matched_file_names = [name for name in files_with_matches.keys() if name != "Unknown File"]
        if matched_file_names:
            await self.files_agents_manager.mark_access_bulk(agent_id=agent_state.id, file_names=matched_file_names, actor=self.actor)
    # create summary header
    file_count = len(files_with_matches)
    summary = f"Found {total_hits} Turbopuffer matches in {file_count} file{'s' if file_count != 1 else ''} for query: '{query}'"
    # combine all results
    formatted_results = [summary, "=" * len(summary)] + results
    self.logger.info(f"Turbopuffer search completed: {total_hits} matches across {file_count} files")
    return "\n".join(formatted_results)
async def _search_files_pinecone(self, agent_state: AgentState, attached_sources: List[Source], query: str, limit: int) -> str:
"""Search files using Pinecone vector database."""
# Extract unique source_ids
# TODO: Inefficient
attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
source_ids = [source.id for source in attached_sources]
if not source_ids:
return "No valid source IDs found for attached files"
@@ -658,7 +783,7 @@ class LettaFileToolExecutor(ToolExecutor):
self.logger.info(f"Pinecone search completed: {total_hits} matches across {file_count} files")
return "\n".join(formatted_results)
async def _search_files_traditional(self, agent_state: AgentState, query: str, limit: int) -> str:
async def _search_files_native(self, agent_state: AgentState, query: str, limit: int) -> str:
"""Traditional search using existing passage manager."""
# Get semantic search results
passages = await self.agent_manager.query_source_passages_async(

View File

@@ -258,7 +258,7 @@ class TestFileProcessorWithPinecone:
embedder = PineconeEmbedder()
# Create file processor with Pinecone enabled
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=mock_actor, using_pinecone=True)
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=mock_actor)
# Track file manager update calls
update_calls = []

View File

@@ -13,6 +13,7 @@ from letta_client.types import AgentState
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
from letta.helpers.pinecone_utils import should_use_pinecone
from letta.helpers.tpuf_client import TurbopufferClient
from letta.schemas.enums import FileProcessingStatus, ToolType
from letta.schemas.message import MessageCreate
from letta.schemas.user import User
@@ -95,7 +96,7 @@ def agent_state(disable_pinecone, client: LettaSDKClient):
# Tests
def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient):
def test_auto_attach_detach_files_tools(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
"""Test automatic attachment and detachment of file tools when managing agent sources."""
# Create agent with basic configuration
agent = client.agents.create(
@@ -168,6 +169,7 @@ def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient
)
def test_file_upload_creates_source_blocks_correctly(
disable_pinecone,
disable_turbopuffer,
client: LettaSDKClient,
agent_state: AgentState,
file_path: str,
@@ -237,7 +239,9 @@ def test_file_upload_creates_source_blocks_correctly(
settings.mistral_api_key = original_mistral_key
def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
def test_attach_existing_files_creates_source_blocks_correctly(
disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState
):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
assert len(client.sources.list()) == 1
@@ -302,7 +306,9 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone,
assert "<directories>" not in raw_system_message_after_detach
def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
def test_delete_source_removes_source_blocks_correctly(
disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState
):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
assert len(client.sources.list()) == 1
@@ -360,7 +366,7 @@ def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client:
assert not any("test" in b.value for b in blocks)
def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_open_close_file_correctly(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -463,7 +469,7 @@ def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDK
print("✓ File successfully opened with different range - content differs as expected")
def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_search_files_correctly(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -509,7 +515,7 @@ def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKCli
assert all(tr.status == "success" for tr in tool_returns), f"Tool call failed {tr}"
def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_grep_correctly_basic(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -551,7 +557,7 @@ def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClien
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
def test_agent_uses_grep_correctly_advanced(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
def test_agent_uses_grep_correctly_advanced(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -599,7 +605,7 @@ def test_agent_uses_grep_correctly_advanced(disable_pinecone, client: LettaSDKCl
assert "511:" in tool_return_message.tool_return
def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient):
def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
"""Test that creating an agent with source_ids parameter correctly creates source blocks."""
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -642,7 +648,7 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pi
assert file_tools == set(FILES_TOOLS)
def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
def test_view_ranges_have_metadata(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
@@ -705,7 +711,7 @@ def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, age
)
def test_duplicate_file_renaming(disable_pinecone, client: LettaSDKClient):
def test_duplicate_file_renaming(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
"""Test that duplicate files are renamed with count-based suffixes (e.g., file.txt, file (1).txt, file (2).txt)"""
# Create a new source
source = client.sources.create(name="test_duplicate_source", embedding="openai/text-embedding-3-small")
@@ -744,7 +750,7 @@ def test_duplicate_file_renaming(disable_pinecone, client: LettaSDKClient):
print(f" File {i + 1}: original='{file.original_file_name}' → renamed='{file.file_name}'")
def test_duplicate_file_handling_replace(disable_pinecone, client: LettaSDKClient):
def test_duplicate_file_handling_replace(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
"""Test that DuplicateFileHandling.REPLACE replaces existing files with same name"""
# Create a new source
source = client.sources.create(name="test_replace_source", embedding="openai/text-embedding-3-small")
@@ -826,7 +832,7 @@ def test_duplicate_file_handling_replace(disable_pinecone, client: LettaSDKClien
os.unlink(temp_file_path)
def test_upload_file_with_custom_name(disable_pinecone, client: LettaSDKClient):
def test_upload_file_with_custom_name(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
"""Test that uploading a file with a custom name overrides the original filename"""
# Create agent
agent_state = client.agents.create(
@@ -907,7 +913,7 @@ def test_upload_file_with_custom_name(disable_pinecone, client: LettaSDKClient):
os.unlink(temp_file_path)
def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient):
def test_open_files_schema_descriptions(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
"""Test that open_files tool schema contains correct descriptions from docstring"""
# Get the open_files tool
@@ -990,7 +996,7 @@ def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient
assert length_prop["type"] == "integer"
def test_grep_files_schema_descriptions(disable_pinecone, client: LettaSDKClient):
def test_grep_files_schema_descriptions(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
"""Test that grep_files tool schema contains correct descriptions from docstring"""
# Get the grep_files tool
@@ -1076,10 +1082,174 @@ def test_grep_files_schema_descriptions(disable_pinecone, client: LettaSDKClient
assert "Navigation hint for next page if more matches exist" in description
def test_agent_open_file(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
    """Test client.agents.open_file() function"""
    # Set up: a fresh source attached to the agent, with one uploaded file.
    source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
    client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
    uploaded = upload_file_and_wait(client, source.id, "tests/data/test.txt")

    # Opening the file should not force-close any other files.
    closed_files = client.agents.files.open(agent_id=agent_state.id, file_id=uploaded.id)
    assert len(closed_files) == 0

    # The raw system prompt should now render the file as open.
    raw_system = get_raw_system_message(client, agent_state.id)
    assert '<file status="open" name="test_source/test.txt">' in raw_system
    assert "[Viewing file start (out of 1 lines)]" in raw_system
def test_agent_close_file(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
    """Test client.agents.close_file() function"""
    # Set up: a fresh source attached to the agent, with one uploaded file.
    source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
    client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
    uploaded = upload_file_and_wait(client, source.id, "tests/data/test.txt")

    # Open the file first, then exercise close.
    client.agents.files.open(agent_id=agent_state.id, file_id=uploaded.id)
    client.agents.files.close(agent_id=agent_state.id, file_id=uploaded.id)

    # The raw system prompt should render the file as closed again.
    raw_system = get_raw_system_message(client, agent_state.id)
    assert '<file status="closed" name="test_source/test.txt">' in raw_system
def test_agent_close_all_open_files(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
    """Test client.agents.close_all_open_files() function"""
    # Set up: a fresh source attached to the agent.
    source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
    client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)

    # Upload and open several files so close_all has work to do.
    file_metadatas = []
    for path in ["tests/data/test.txt", "tests/data/test.md"]:
        uploaded = upload_file_and_wait(client, source.id, path)
        file_metadatas.append(uploaded)
        client.agents.files.open(agent_id=agent_state.id, file_id=uploaded.id)

    assert '<file status="open"' in get_raw_system_message(client, agent_state.id)

    # Close everything at once; the API returns a list of closed file names.
    result = client.agents.files.close_all(agent_id=agent_state.id)
    assert isinstance(result, list), f"Expected list, got {type(result)}"
    assert all(isinstance(item, str) for item in result), "All items in result should be strings"

    # No file should be rendered open any more.
    assert '<file status="open"' not in get_raw_system_message(client, agent_state.id)
def test_file_processing_timeout(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
    """Test that files in non-terminal states are moved to error after timeout"""
    # Create a source and upload a file without waiting for processing to finish.
    source = client.sources.create(name="test_timeout_source", embedding="openai/text-embedding-3-small")
    file_path = "tests/data/test.txt"
    with open(file_path, "rb") as f:
        file_metadata = client.sources.files.upload(source_id=source.id, file=f)
    file_id = file_metadata.id

    # Exercise is_terminal_state() directly (this doesn't require server mocking):
    # COMPLETED/ERROR are terminal; the in-flight states are not.
    # (Plain truthiness asserts instead of `== True` / `== False`, per PEP 8.)
    assert FileProcessingStatus.COMPLETED.is_terminal_state()
    assert FileProcessingStatus.ERROR.is_terminal_state()
    assert not FileProcessingStatus.PARSING.is_terminal_state()
    assert not FileProcessingStatus.EMBEDDING.is_terminal_state()
    assert not FileProcessingStatus.PENDING.is_terminal_state()

    # For the actual timeout logic, inspect whatever state the file is in now.
    current_file = client.sources.get_file_metadata(source_id=source.id, file_id=file_id)
    # Convert string status to enum for testing
    status_enum = FileProcessingStatus(current_file.processing_status)
    if status_enum.is_terminal_state():
        # Expected behavior: files that completed processing must never time out.
        print(f"File {file_id} is in terminal state: {current_file.processing_status}")
        assert status_enum in [FileProcessingStatus.COMPLETED, FileProcessingStatus.ERROR]
    else:
        # Still processing: it should eventually complete or time out; for a unit
        # test we only verify the status is one of the known non-terminal states.
        print(f"File {file_id} is still processing: {current_file.processing_status}")
        assert status_enum in [FileProcessingStatus.PENDING, FileProcessingStatus.PARSING, FileProcessingStatus.EMBEDDING]
@pytest.mark.unit
def test_file_processing_timeout_logic():
    """Test the timeout logic directly without server dependencies"""
    from datetime import timezone

    timeout_minutes = 30

    # Scenario: a file created 35 minutes ago, checked against a 30-minute timeout.
    stale_created = datetime.now(timezone.utc) - timedelta(minutes=35)
    now = datetime.now(timezone.utc)
    cutoff = now - timedelta(minutes=timeout_minutes)

    assert stale_created < cutoff, "File created 35 minutes ago should be past 30-minute timeout"

    # Edge case: a file created exactly at the threshold must NOT count as timed out.
    boundary_created = now - timedelta(minutes=timeout_minutes)
    assert not (boundary_created < cutoff), "File created exactly at timeout should not trigger timeout"

    # A recent file is safely inside the window.
    fresh_created = now - timedelta(minutes=10)
    assert not (fresh_created < cutoff), "Recent file should not trigger timeout"
def test_letta_free_embedding(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
    """Test creating a source with letta/letta-free embedding and uploading a file"""
    # create a source with letta-free embedding
    source = client.sources.create(name="test_letta_free_source", embedding="letta/letta-free")

    # verify the source was created with the requested name and carries an
    # embedding config (previously only printed via leftover debug output)
    assert source.name == "test_letta_free_source"
    assert source.embedding_config is not None

    # upload test.txt file and wait for processing to complete
    file_path = "tests/data/test.txt"
    file_metadata = upload_file_and_wait(client, source.id, file_path)

    # verify file was uploaded successfully
    assert file_metadata.processing_status == "completed"
    assert file_metadata.source_id == source.id
    assert file_metadata.file_name == "test.txt"

    # verify file appears in source files list
    files = client.sources.files.list(source_id=source.id, limit=1)
    assert len(files) == 1
    assert files[0].id == file_metadata.id

    # cleanup
    client.sources.delete(source_id=source.id)
# --- Pinecone Tests ---
def test_pinecone_search_files_tool(client: LettaSDKClient):
def test_pinecone_search_files_tool(disable_turbopuffer, client: LettaSDKClient):
"""Test that search_files tool uses Pinecone when enabled"""
from letta.helpers.pinecone_utils import should_use_pinecone
@@ -1130,7 +1300,7 @@ def test_pinecone_search_files_tool(client: LettaSDKClient):
)
def test_pinecone_list_files_status(client: LettaSDKClient):
def test_pinecone_list_files_status(disable_turbopuffer, client: LettaSDKClient):
"""Test that list_source_files properly syncs embedding status with Pinecone"""
if not should_use_pinecone():
pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")
@@ -1165,7 +1335,7 @@ def test_pinecone_list_files_status(client: LettaSDKClient):
client.sources.delete(source_id=source.id)
def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
def test_pinecone_lifecycle_file_and_source_deletion(disable_turbopuffer, client: LettaSDKClient):
"""Test that file and source deletion removes records from Pinecone"""
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
@@ -1236,167 +1406,196 @@ def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
)
def test_agent_open_file(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
"""Test client.agents.open_file() function"""
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
# Attach source to agent
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
# Upload a file
file_path = "tests/data/test.txt"
file_metadata = upload_file_and_wait(client, source.id, file_path)
# Basic test open_file function
closed_files = client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id)
assert len(closed_files) == 0
system = get_raw_system_message(client, agent_state.id)
assert '<file status="open" name="test_source/test.txt">' in system
assert "[Viewing file start (out of 1 lines)]" in system
# --- End Pinecone Tests ---
def test_agent_close_file(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
"""Test client.agents.close_file() function"""
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
# --- Turbopuffer Tests ---
def test_turbopuffer_search_files_tool(disable_pinecone, client: LettaSDKClient):
"""Test that search_files tool uses Turbopuffer when enabled"""
agent = client.agents.create(
name="test_turbopuffer_agent",
memory_blocks=[
CreateBlock(label="human", value="username: testuser"),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
)
# Attach source to agent
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
source = client.sources.create(name="test_turbopuffer_source", embedding="openai/text-embedding-3-small")
client.agents.sources.attach(source_id=source.id, agent_id=agent.id)
# Upload a file
file_path = "tests/data/test.txt"
file_metadata = upload_file_and_wait(client, source.id, file_path)
file_path = "tests/data/long_test.txt"
upload_file_and_wait(client, source.id, file_path)
# First open the file
client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id)
search_response = client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Use the semantic_search_files tool to search for 'electoral history' in the files.")],
)
# Test close_file function
client.agents.files.close(agent_id=agent_state.id, file_id=file_metadata.id)
tool_calls = [msg for msg in search_response.messages if msg.message_type == "tool_call_message"]
assert len(tool_calls) > 0, "No tool calls found"
assert any(tc.tool_call.name == "semantic_search_files" for tc in tool_calls), "semantic_search_files not called"
system = get_raw_system_message(client, agent_state.id)
assert '<file status="closed" name="test_source/test.txt">' in system
tool_returns = [msg for msg in search_response.messages if msg.message_type == "tool_return_message"]
assert len(tool_returns) > 0, "No tool returns found"
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
search_results = tool_returns[0].tool_return
print(f"Turbopuffer search results: {search_results}")
assert "electoral" in search_results.lower() or "history" in search_results.lower(), (
f"Search results should contain relevant content: {search_results}"
)
client.agents.delete(agent_id=agent.id)
client.sources.delete(source_id=source.id)
def test_agent_close_all_open_files(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
"""Test client.agents.close_all_open_files() function"""
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
def test_turbopuffer_file_processing_status(disable_pinecone, client: LettaSDKClient):
"""Test that file processing completes successfully with Turbopuffer"""
print("Testing Turbopuffer file processing status")
# Attach source to agent
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
source = client.sources.create(name="test_tpuf_file_status", embedding="openai/text-embedding-3-small")
# Upload multiple files
file_paths = ["tests/data/test.txt", "tests/data/test.md"]
file_metadatas = []
file_paths = ["tests/data/long_test.txt", "tests/data/test.md"]
uploaded_files = []
for file_path in file_paths:
file_metadata = upload_file_and_wait(client, source.id, file_path)
file_metadatas.append(file_metadata)
# Open each file
client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id)
uploaded_files.append(file_metadata)
assert file_metadata.processing_status == "completed", f"File {file_path} should be completed"
system = get_raw_system_message(client, agent_state.id)
assert '<file status="open"' in system
files_list = client.sources.files.list(source_id=source.id, limit=100)
# Test close_all_open_files function
result = client.agents.files.close_all(agent_id=agent_state.id)
assert len(files_list) == len(uploaded_files), f"Expected {len(uploaded_files)} files, got {len(files_list)}"
# Verify result is a list of strings
assert isinstance(result, list), f"Expected list, got {type(result)}"
assert all(isinstance(item, str) for item in result), "All items in result should be strings"
for file_metadata in files_list:
assert file_metadata.processing_status == "completed", f"File {file_metadata.file_name} should show completed status"
system = get_raw_system_message(client, agent_state.id)
assert '<file status="open"' not in system
if file_metadata.total_chunks and file_metadata.total_chunks > 0:
assert file_metadata.chunks_embedded == file_metadata.total_chunks, (
f"File {file_metadata.file_name} should have all chunks embedded: {file_metadata.chunks_embedded}/{file_metadata.total_chunks}"
)
def test_file_processing_timeout(disable_pinecone, client: LettaSDKClient):
"""Test that files in non-terminal states are moved to error after timeout"""
# Create a source
source = client.sources.create(name="test_timeout_source", embedding="openai/text-embedding-3-small")
# Upload a file
file_path = "tests/data/test.txt"
with open(file_path, "rb") as f:
file_metadata = client.sources.files.upload(source_id=source.id, file=f)
# Get the file ID
file_id = file_metadata.id
# Test the is_terminal_state method directly (this doesn't require server mocking)
assert FileProcessingStatus.COMPLETED.is_terminal_state() == True
assert FileProcessingStatus.ERROR.is_terminal_state() == True
assert FileProcessingStatus.PARSING.is_terminal_state() == False
assert FileProcessingStatus.EMBEDDING.is_terminal_state() == False
assert FileProcessingStatus.PENDING.is_terminal_state() == False
# For testing the actual timeout logic, we can check the current file status
current_file = client.sources.get_file_metadata(source_id=source.id, file_id=file_id)
# Convert string status to enum for testing
status_enum = FileProcessingStatus(current_file.processing_status)
# Verify that files in terminal states are not affected by timeout checks
if status_enum.is_terminal_state():
# This is the expected behavior - files that completed processing shouldn't timeout
print(f"File {file_id} is in terminal state: {current_file.processing_status}")
assert status_enum in [FileProcessingStatus.COMPLETED, FileProcessingStatus.ERROR]
else:
# If file is still processing, it should eventually complete or timeout
# In a real scenario, we'd wait and check, but for unit tests we just verify the logic exists
print(f"File {file_id} is still processing: {current_file.processing_status}")
assert status_enum in [FileProcessingStatus.PENDING, FileProcessingStatus.PARSING, FileProcessingStatus.EMBEDDING]
@pytest.mark.unit
def test_file_processing_timeout_logic():
"""Test the timeout logic directly without server dependencies"""
from datetime import timezone
# Test scenario: file created 35 minutes ago, timeout is 30 minutes
old_time = datetime.now(timezone.utc) - timedelta(minutes=35)
current_time = datetime.now(timezone.utc)
timeout_minutes = 30
# Calculate timeout threshold
timeout_threshold = current_time - timedelta(minutes=timeout_minutes)
# Verify timeout logic
assert old_time < timeout_threshold, "File created 35 minutes ago should be past 30-minute timeout"
# Test edge case: file created exactly at timeout
edge_time = current_time - timedelta(minutes=timeout_minutes)
assert not (edge_time < timeout_threshold), "File created exactly at timeout should not trigger timeout"
# Test recent file
recent_time = current_time - timedelta(minutes=10)
assert not (recent_time < timeout_threshold), "Recent file should not trigger timeout"
def test_letta_free_embedding(disable_pinecone, client: LettaSDKClient):
"""Test creating a source with letta/letta-free embedding and uploading a file"""
# create a source with letta-free embedding
source = client.sources.create(name="test_letta_free_source", embedding="letta/letta-free")
# verify source was created with correct embedding
assert source.name == "test_letta_free_source"
print("\n\n\n\ntest")
print(source.embedding_config)
# assert source.embedding_config.embedding_model == "letta-free"
# upload test.txt file
file_path = "tests/data/test.txt"
file_metadata = upload_file_and_wait(client, source.id, file_path)
# verify file was uploaded successfully
assert file_metadata.processing_status == "completed"
assert file_metadata.source_id == source.id
assert file_metadata.file_name == "test.txt"
# verify file appears in source files list
files = client.sources.files.list(source_id=source.id, limit=1)
assert len(files) == 1
assert files[0].id == file_metadata.id
# cleanup
client.sources.delete(source_id=source.id)
def test_turbopuffer_lifecycle_file_and_source_deletion(disable_pinecone, client: LettaSDKClient):
    """Test that file and source deletion removes records from Turbopuffer"""
    source = client.sources.create(name="test_tpuf_lifecycle", embedding="openai/text-embedding-3-small")

    # Upload two files so file-level deletion can be tested independently of source deletion.
    uploaded_files = [upload_file_and_wait(client, source.id, path) for path in ["tests/data/test.txt", "tests/data/test.md"]]

    user = User(name="temp", organization_id=DEFAULT_ORG_ID)
    tpuf_client = TurbopufferClient()

    def _query_passages(file_id):
        # Fetch all Turbopuffer passages for one file within this source.
        return asyncio.run(
            tpuf_client.query_file_passages(
                source_ids=[source.id], organization_id=user.organization_id, actor=user, file_id=file_id, top_k=100
            )
        )

    # test file-level deletion
    if len(uploaded_files) > 1:
        file_to_delete = uploaded_files[0]
        passages_before = _query_passages(file_to_delete.id)
        print(f"Found {len(passages_before)} passages for file before deletion")
        assert len(passages_before) > 0, "Should have passages before deletion"

        client.sources.files.delete(source_id=source.id, file_id=file_to_delete.id)
        time.sleep(2)

        passages_after = _query_passages(file_to_delete.id)
        print(f"Found {len(passages_after)} passages for file after deletion")
        assert len(passages_after) == 0, f"File passages should be removed from Turbopuffer after deletion, but found {len(passages_after)}"

    # test source-level deletion
    remaining_passages_before = []
    for file_metadata in uploaded_files[1:]:
        remaining_passages_before.extend(_query_passages(file_metadata.id))
    print(f"Found {len(remaining_passages_before)} passages for remaining files before source deletion")
    assert len(remaining_passages_before) > 0, "Should have passages for remaining files"

    client.sources.delete(source_id=source.id)
    time.sleep(3)

    remaining_passages_after = []
    for file_metadata in uploaded_files[1:]:
        try:
            remaining_passages_after.extend(_query_passages(file_metadata.id))
        except Exception as e:
            # Querying a deleted source may legitimately fail rather than return empty.
            print(f"Expected error querying deleted source: {e}")
    print(f"Found {len(remaining_passages_after)} passages for files after source deletion")
    assert len(remaining_passages_after) == 0, (
        f"All source passages should be removed from Turbopuffer after source deletion, but found {len(remaining_passages_after)}"
    )
def test_turbopuffer_multiple_sources(disable_pinecone, client: LettaSDKClient):
    """Test that Turbopuffer correctly isolates passages by source in org-scoped namespace"""
    source1 = client.sources.create(name="test_tpuf_source1", embedding="openai/text-embedding-3-small")
    source2 = client.sources.create(name="test_tpuf_source2", embedding="openai/text-embedding-3-small")
    file1_metadata = upload_file_and_wait(client, source1.id, "tests/data/test.txt")
    file2_metadata = upload_file_and_wait(client, source2.id, "tests/data/test.md")

    user = User(name="temp", organization_id=DEFAULT_ORG_ID)
    tpuf_client = TurbopufferClient()

    def _passages_for_source(src_id):
        # All passages stored for one source within the org-scoped namespace.
        return asyncio.run(
            tpuf_client.query_file_passages(source_ids=[src_id], organization_id=user.organization_id, actor=user, top_k=100)
        )

    source1_passages = _passages_for_source(source1.id)
    source2_passages = _passages_for_source(source2.id)
    print(f"Source1 has {len(source1_passages)} passages")
    print(f"Source2 has {len(source2_passages)} passages")
    assert len(source1_passages) > 0, "Source1 should have passages"
    assert len(source2_passages) > 0, "Source2 should have passages"

    # Every passage returned for a source must carry that source's ids only.
    for passage, _, _ in source1_passages:
        assert passage.source_id == source1.id, f"Passage should belong to source1, but has source_id={passage.source_id}"
        assert passage.file_id == file1_metadata.id, f"Passage should belong to file1, but has file_id={passage.file_id}"
    for passage, _, _ in source2_passages:
        assert passage.source_id == source2.id, f"Passage should belong to source2, but has source_id={passage.source_id}"
        assert passage.file_id == file2_metadata.id, f"Passage should belong to file2, but has file_id={passage.file_id}"

    # delete source1 and verify source2 is unaffected
    client.sources.delete(source_id=source1.id)
    time.sleep(2)
    source2_passages_after = _passages_for_source(source2.id)
    assert len(source2_passages_after) == len(source2_passages), (
        f"Source2 should still have all passages after source1 deletion: {len(source2_passages_after)} vs {len(source2_passages)}"
    )
    client.sources.delete(source_id=source2.id)
# --- End Turbopuffer Tests ---