feat: Add turbopuffer embedder by default [LET-4253] (#4476)
* Adapt to turbopuffer embedder * Make turbopuffer search more efficient over all source ids * Combine turbopuffer and pinecone hybrid * Fix test sources
This commit is contained in:
@@ -16,12 +16,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def should_use_tpuf() -> bool:
|
||||
# We need OpenAI since we default to their embedding model
|
||||
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)
|
||||
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key) and bool(model_settings.openai_api_key)
|
||||
|
||||
|
||||
def should_use_tpuf_for_messages() -> bool:
|
||||
"""Check if Turbopuffer should be used for messages."""
|
||||
return should_use_tpuf() and bool(settings.embed_all_messages) and bool(model_settings.openai_api_key)
|
||||
return should_use_tpuf() and bool(settings.embed_all_messages)
|
||||
|
||||
|
||||
class TurbopufferClient:
|
||||
@@ -1113,3 +1113,309 @@ class TurbopufferClient:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete all messages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
# file/source passage methods
|
||||
|
||||
@trace_method
|
||||
async def _get_file_passages_namespace_name(self, organization_id: str) -> str:
|
||||
"""Get namespace name for file passages (org-scoped).
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID for namespace generation
|
||||
|
||||
Returns:
|
||||
The org-scoped namespace name for file passages
|
||||
"""
|
||||
environment = settings.environment
|
||||
if environment:
|
||||
namespace_name = f"file_passages_{organization_id}_{environment.lower()}"
|
||||
else:
|
||||
namespace_name = f"file_passages_{organization_id}"
|
||||
|
||||
return namespace_name
|
||||
|
||||
@trace_method
|
||||
async def insert_file_passages(
|
||||
self,
|
||||
source_id: str,
|
||||
file_id: str,
|
||||
text_chunks: List[str],
|
||||
organization_id: str,
|
||||
actor: "PydanticUser",
|
||||
created_at: Optional[datetime] = None,
|
||||
) -> List[PydanticPassage]:
|
||||
"""Insert file passages into Turbopuffer using org-scoped namespace.
|
||||
|
||||
Args:
|
||||
source_id: ID of the source containing the file
|
||||
file_id: ID of the file
|
||||
text_chunks: List of text chunks to store
|
||||
organization_id: Organization ID for the passages
|
||||
actor: User actor for embedding generation
|
||||
created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
|
||||
|
||||
Returns:
|
||||
List of PydanticPassage objects that were inserted
|
||||
"""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
if not text_chunks:
|
||||
return []
|
||||
|
||||
# generate embeddings using the default config
|
||||
embeddings = await self._generate_embeddings(text_chunks, actor)
|
||||
|
||||
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
||||
|
||||
# handle timestamp - ensure UTC
|
||||
if created_at is None:
|
||||
timestamp = datetime.now(timezone.utc)
|
||||
else:
|
||||
# ensure the provided timestamp is timezone-aware and in UTC
|
||||
if created_at.tzinfo is None:
|
||||
# assume UTC if no timezone provided
|
||||
timestamp = created_at.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# convert to UTC if in different timezone
|
||||
timestamp = created_at.astimezone(timezone.utc)
|
||||
|
||||
# prepare column-based data for turbopuffer - optimized for batch insert
|
||||
ids = []
|
||||
vectors = []
|
||||
texts = []
|
||||
organization_ids = []
|
||||
source_ids = []
|
||||
file_ids = []
|
||||
created_ats = []
|
||||
passages = []
|
||||
|
||||
for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
file_id=file_id,
|
||||
source_id=source_id,
|
||||
embedding=embedding,
|
||||
embedding_config=self.default_embedding_config,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
# append to columns
|
||||
ids.append(passage.id)
|
||||
vectors.append(embedding)
|
||||
texts.append(text)
|
||||
organization_ids.append(organization_id)
|
||||
source_ids.append(source_id)
|
||||
file_ids.append(file_id)
|
||||
created_ats.append(timestamp)
|
||||
|
||||
# build column-based upsert data
|
||||
upsert_columns = {
|
||||
"id": ids,
|
||||
"vector": vectors,
|
||||
"text": texts,
|
||||
"organization_id": organization_ids,
|
||||
"source_id": source_ids,
|
||||
"file_id": file_ids,
|
||||
"created_at": created_ats,
|
||||
}
|
||||
|
||||
try:
|
||||
# use AsyncTurbopuffer as a context manager for proper resource cleanup
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
# turbopuffer recommends column-based writes for performance
|
||||
await namespace.write(
|
||||
upsert_columns=upsert_columns,
|
||||
distance_metric="cosine_distance",
|
||||
schema={"text": {"type": "string", "full_text_search": True}},
|
||||
)
|
||||
logger.info(f"Successfully inserted {len(ids)} file passages to Turbopuffer for source {source_id}, file {file_id}")
|
||||
return passages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert file passages to Turbopuffer: {e}")
|
||||
# check if it's a duplicate ID error
|
||||
if "duplicate" in str(e).lower():
|
||||
logger.error("Duplicate passage IDs detected in batch")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def query_file_passages(
|
||||
self,
|
||||
source_ids: List[str],
|
||||
organization_id: str,
|
||||
actor: "PydanticUser",
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
||||
top_k: int = 10,
|
||||
file_id: Optional[str] = None, # optional filter by specific file
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
) -> List[Tuple[PydanticPassage, float, dict]]:
|
||||
"""Query file passages from Turbopuffer using org-scoped namespace.
|
||||
|
||||
Args:
|
||||
source_ids: List of source IDs to query
|
||||
organization_id: Organization ID for namespace lookup
|
||||
actor: User actor for embedding generation
|
||||
query_text: Text query for search
|
||||
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
|
||||
top_k: Number of results to return
|
||||
file_id: Optional file ID to filter results to a specific file
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
|
||||
Returns:
|
||||
List of (passage, score, metadata) tuples with relevance rankings
|
||||
"""
|
||||
# generate embedding for vector/hybrid search if query_text is provided
|
||||
query_embedding = None
|
||||
if query_text and search_mode in ["vector", "hybrid"]:
|
||||
embeddings = await self._generate_embeddings([query_text], actor)
|
||||
query_embedding = embeddings[0]
|
||||
|
||||
# check if we should fallback to timestamp-based retrieval
|
||||
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
|
||||
# fallback to retrieving most recent passages when no search query is provided
|
||||
search_mode = "timestamp"
|
||||
|
||||
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
||||
|
||||
# build filters - always filter by source_ids
|
||||
if len(source_ids) == 1:
|
||||
# single source_id, use Eq for efficiency
|
||||
filters = [("source_id", "Eq", source_ids[0])]
|
||||
else:
|
||||
# multiple source_ids, use In operator
|
||||
filters = [("source_id", "In", source_ids)]
|
||||
|
||||
# add file filter if specified
|
||||
if file_id:
|
||||
filters.append(("file_id", "Eq", file_id))
|
||||
|
||||
# combine filters
|
||||
final_filter = filters[0] if len(filters) == 1 else ("And", filters)
|
||||
|
||||
try:
|
||||
# use generic query executor
|
||||
result = await self._execute_query(
|
||||
namespace_name=namespace_name,
|
||||
search_mode=search_mode,
|
||||
query_embedding=query_embedding,
|
||||
query_text=query_text,
|
||||
top_k=top_k,
|
||||
include_attributes=["text", "organization_id", "source_id", "file_id", "created_at"],
|
||||
filters=final_filter,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
)
|
||||
|
||||
# process results based on search mode
|
||||
if search_mode == "hybrid":
|
||||
# for hybrid mode, we get a multi-query response
|
||||
vector_results = self._process_file_query_results(result.results[0])
|
||||
fts_results = self._process_file_query_results(result.results[1], is_fts=True)
|
||||
# use RRF and include metadata with ranks
|
||||
results_with_metadata = self._reciprocal_rank_fusion(
|
||||
vector_results=[passage for passage, _ in vector_results],
|
||||
fts_results=[passage for passage, _ in fts_results],
|
||||
get_id_func=lambda p: p.id,
|
||||
vector_weight=vector_weight,
|
||||
fts_weight=fts_weight,
|
||||
top_k=top_k,
|
||||
)
|
||||
return results_with_metadata
|
||||
else:
|
||||
# for single queries (vector, fts, timestamp) - add basic metadata
|
||||
is_fts = search_mode == "fts"
|
||||
results = self._process_file_query_results(result, is_fts=is_fts)
|
||||
# add simple metadata for single search modes
|
||||
results_with_metadata = []
|
||||
for idx, (passage, score) in enumerate(results):
|
||||
metadata = {
|
||||
"combined_score": score,
|
||||
f"{search_mode}_rank": idx + 1, # add the rank for this search mode
|
||||
}
|
||||
results_with_metadata.append((passage, score, metadata))
|
||||
return results_with_metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query file passages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
def _process_file_query_results(self, result, is_fts: bool = False) -> List[Tuple[PydanticPassage, float]]:
|
||||
"""Process results from a file query into passage objects with scores."""
|
||||
passages_with_scores = []
|
||||
|
||||
for row in result.rows:
|
||||
# build metadata
|
||||
metadata = {}
|
||||
|
||||
# create a passage with minimal fields - embeddings are not returned from Turbopuffer
|
||||
passage = PydanticPassage(
|
||||
id=row.id,
|
||||
text=getattr(row, "text", ""),
|
||||
organization_id=getattr(row, "organization_id", None),
|
||||
source_id=getattr(row, "source_id", None), # get source_id from the row
|
||||
file_id=getattr(row, "file_id", None),
|
||||
created_at=getattr(row, "created_at", None),
|
||||
metadata_=metadata,
|
||||
tags=[],
|
||||
# set required fields to empty/default values since we don't store embeddings
|
||||
embedding=[], # empty embedding since we don't return it from Turbopuffer
|
||||
embedding_config=self.default_embedding_config,
|
||||
)
|
||||
|
||||
# handle score based on search type
|
||||
if is_fts:
|
||||
# for FTS, use the BM25 score directly (higher is better)
|
||||
score = getattr(row, "$score", 0.0)
|
||||
else:
|
||||
# for vector search, convert distance to similarity score
|
||||
distance = getattr(row, "$dist", 0.0)
|
||||
score = 1.0 - distance
|
||||
|
||||
passages_with_scores.append((passage, score))
|
||||
|
||||
return passages_with_scores
|
||||
|
||||
@trace_method
|
||||
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
|
||||
"""Delete all passages for a specific file from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
||||
|
||||
try:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
# use delete_by_filter to only delete passages for this file
|
||||
# need to filter by both source_id and file_id
|
||||
filter_expr = ("And", [("source_id", "Eq", source_id), ("file_id", "Eq", file_id)])
|
||||
result = await namespace.write(delete_by_filter=filter_expr)
|
||||
logger.info(
|
||||
f"Successfully deleted passages for file {file_id} from source {source_id} (deleted {result.rows_affected} rows)"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file passages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
|
||||
"""Delete all passages for a source from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
||||
|
||||
try:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
# delete all passages for this source
|
||||
result = await namespace.write(delete_by_filter=("source_id", "Eq", source_id))
|
||||
logger.info(f"Successfully deleted all passages for source {source_id} (deleted {result.rows_affected} rows)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete source passages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.helpers.pinecone_utils import (
|
||||
delete_source_records_from_pinecone_index,
|
||||
should_use_pinecone,
|
||||
)
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -191,7 +192,13 @@ async def delete_folder(
|
||||
files = await server.file_manager.list_files(folder_id, actor)
|
||||
file_ids = [f.id for f in files]
|
||||
|
||||
if should_use_pinecone():
|
||||
if should_use_tpuf():
|
||||
logger.info(f"Deleting folder {folder_id} from Turbopuffer")
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_source_passages(source_id=folder_id, organization_id=actor.organization_id)
|
||||
elif should_use_pinecone():
|
||||
logger.info(f"Deleting folder {folder_id} from pinecone index")
|
||||
await delete_source_records_from_pinecone_index(source_id=folder_id, actor=actor)
|
||||
|
||||
@@ -450,7 +457,13 @@ async def delete_file_from_folder(
|
||||
|
||||
await server.remove_file_from_context_windows(source_id=folder_id, file_id=deleted_file.id, actor=actor)
|
||||
|
||||
if should_use_pinecone():
|
||||
if should_use_tpuf():
|
||||
logger.info(f"Deleting file {file_id} from Turbopuffer")
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_file_passages(source_id=folder_id, file_id=file_id, organization_id=actor.organization_id)
|
||||
elif should_use_pinecone():
|
||||
logger.info(f"Deleting file {file_id} from pinecone index")
|
||||
await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor)
|
||||
|
||||
@@ -496,10 +509,15 @@ async def load_file_to_source_cloud(
|
||||
else:
|
||||
file_parser = MarkitdownFileParser()
|
||||
|
||||
using_pinecone = should_use_pinecone()
|
||||
if using_pinecone:
|
||||
# determine which embedder to use - turbopuffer takes precedence
|
||||
if should_use_tpuf():
|
||||
from letta.services.file_processor.embedder.turbopuffer_embedder import TurbopufferEmbedder
|
||||
|
||||
embedder = TurbopufferEmbedder(embedding_config=embedding_config)
|
||||
elif should_use_pinecone():
|
||||
embedder = PineconeEmbedder(embedding_config=embedding_config)
|
||||
else:
|
||||
embedder = OpenAIEmbedder(embedding_config=embedding_config)
|
||||
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor, using_pinecone=using_pinecone)
|
||||
|
||||
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor)
|
||||
await file_processor.process(agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata)
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.helpers.pinecone_utils import (
|
||||
delete_source_records_from_pinecone_index,
|
||||
should_use_pinecone,
|
||||
)
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -189,7 +190,13 @@ async def delete_source(
|
||||
files = await server.file_manager.list_files(source_id, actor)
|
||||
file_ids = [f.id for f in files]
|
||||
|
||||
if should_use_pinecone():
|
||||
if should_use_tpuf():
|
||||
logger.info(f"Deleting source {source_id} from Turbopuffer")
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_source_passages(source_id=source_id, organization_id=actor.organization_id)
|
||||
elif should_use_pinecone():
|
||||
logger.info(f"Deleting source {source_id} from pinecone index")
|
||||
await delete_source_records_from_pinecone_index(source_id=source_id, actor=actor)
|
||||
|
||||
@@ -435,7 +442,13 @@ async def delete_file_from_source(
|
||||
|
||||
await server.remove_file_from_context_windows(source_id=source_id, file_id=deleted_file.id, actor=actor)
|
||||
|
||||
if should_use_pinecone():
|
||||
if should_use_tpuf():
|
||||
logger.info(f"Deleting file {file_id} from Turbopuffer")
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_file_passages(source_id=source_id, file_id=file_id, organization_id=actor.organization_id)
|
||||
elif should_use_pinecone():
|
||||
logger.info(f"Deleting file {file_id} from pinecone index")
|
||||
await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor)
|
||||
|
||||
@@ -481,10 +494,15 @@ async def load_file_to_source_cloud(
|
||||
else:
|
||||
file_parser = MarkitdownFileParser()
|
||||
|
||||
using_pinecone = should_use_pinecone()
|
||||
if using_pinecone:
|
||||
# determine which embedder to use - turbopuffer takes precedence
|
||||
if should_use_tpuf():
|
||||
from letta.services.file_processor.embedder.turbopuffer_embedder import TurbopufferEmbedder
|
||||
|
||||
embedder = TurbopufferEmbedder(embedding_config=embedding_config)
|
||||
elif should_use_pinecone():
|
||||
embedder = PineconeEmbedder(embedding_config=embedding_config)
|
||||
else:
|
||||
embedder = OpenAIEmbedder(embedding_config=embedding_config)
|
||||
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor, using_pinecone=using_pinecone)
|
||||
|
||||
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor)
|
||||
await file_processor.process(agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata)
|
||||
|
||||
@@ -12,6 +12,7 @@ from letta.errors import (
|
||||
AgentNotFoundForExportError,
|
||||
)
|
||||
from letta.helpers.pinecone_utils import should_use_pinecone
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.agent_file import (
|
||||
@@ -29,7 +30,7 @@ from letta.schemas.agent_file import (
|
||||
)
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
from letta.schemas.enums import FileProcessingStatus, VectorDBProvider
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.group import Group, GroupCreate
|
||||
from letta.schemas.mcp import MCPServer
|
||||
@@ -90,7 +91,6 @@ class AgentSerializationManager:
|
||||
self.file_agent_manager = file_agent_manager
|
||||
self.message_manager = message_manager
|
||||
self.file_parser = MistralFileParser() if settings.mistral_api_key else MarkitdownFileParser()
|
||||
self.using_pinecone = should_use_pinecone()
|
||||
|
||||
# ID mapping state for export
|
||||
self._db_to_file_ids: Dict[str, str] = {}
|
||||
@@ -588,7 +588,12 @@ class AgentSerializationManager:
|
||||
if schema.files and any(f.content for f in schema.files):
|
||||
# Use override embedding config if provided, otherwise use agent's config
|
||||
embedder_config = override_embedding_config if override_embedding_config else schema.agents[0].embedding_config
|
||||
if should_use_pinecone():
|
||||
# determine which embedder to use - turbopuffer takes precedence
|
||||
if should_use_tpuf():
|
||||
from letta.services.file_processor.embedder.turbopuffer_embedder import TurbopufferEmbedder
|
||||
|
||||
embedder = TurbopufferEmbedder(embedding_config=embedder_config)
|
||||
elif should_use_pinecone():
|
||||
embedder = PineconeEmbedder(embedding_config=embedder_config)
|
||||
else:
|
||||
embedder = OpenAIEmbedder(embedding_config=embedder_config)
|
||||
@@ -596,7 +601,6 @@ class AgentSerializationManager:
|
||||
file_parser=self.file_parser,
|
||||
embedder=embedder,
|
||||
actor=actor,
|
||||
using_pinecone=self.using_pinecone,
|
||||
)
|
||||
|
||||
for file_schema in schema.files:
|
||||
|
||||
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.user import User
|
||||
|
||||
@@ -11,6 +12,10 @@ logger = get_logger(__name__)
|
||||
class BaseEmbedder(ABC):
|
||||
"""Abstract base class for embedding generation"""
|
||||
|
||||
def __init__(self):
|
||||
# Default to NATIVE, subclasses will override this
|
||||
self.vector_db_type = VectorDBProvider.NATIVE
|
||||
|
||||
@abstractmethod
|
||||
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
|
||||
"""Generate embeddings for chunks with batching and concurrent processing"""
|
||||
|
||||
@@ -19,6 +19,10 @@ class OpenAIEmbedder(BaseEmbedder):
|
||||
"""OpenAI-based embedding generation"""
|
||||
|
||||
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
|
||||
super().__init__()
|
||||
# OpenAI embedder uses the native vector db (PostgreSQL)
|
||||
# self.vector_db_type already set to VectorDBProvider.NATIVE by parent
|
||||
|
||||
self.default_embedding_config = (
|
||||
EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai")
|
||||
if model_settings.openai_api_key
|
||||
|
||||
@@ -4,6 +4,7 @@ from letta.helpers.pinecone_utils import upsert_file_records_to_pinecone_index
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.user import User
|
||||
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
|
||||
@@ -20,6 +21,10 @@ class PineconeEmbedder(BaseEmbedder):
|
||||
"""Pinecone-based embedding generation"""
|
||||
|
||||
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
|
||||
super().__init__()
|
||||
# set the vector db type for pinecone
|
||||
self.vector_db_type = VectorDBProvider.PINECONE
|
||||
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone package is not installed. Install it with: pip install pinecone")
|
||||
|
||||
@@ -28,7 +33,6 @@ class PineconeEmbedder(BaseEmbedder):
|
||||
embedding_config = EmbeddingConfig.default_config(provider="pinecone")
|
||||
|
||||
self.embedding_config = embedding_config
|
||||
super().__init__()
|
||||
|
||||
@trace_method
|
||||
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.user import User
|
||||
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TurbopufferEmbedder(BaseEmbedder):
|
||||
"""Turbopuffer-based embedding generation and storage"""
|
||||
|
||||
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
|
||||
super().__init__()
|
||||
# set the vector db type for turbopuffer
|
||||
self.vector_db_type = VectorDBProvider.TPUF
|
||||
# use the default embedding config from TurbopufferClient if not provided
|
||||
self.embedding_config = embedding_config or TurbopufferClient.default_embedding_config
|
||||
self.tpuf_client = TurbopufferClient()
|
||||
|
||||
@trace_method
|
||||
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
|
||||
"""Generate embeddings and store in Turbopuffer, then return Passage objects"""
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
logger.info(f"Generating embeddings for {len(chunks)} chunks using Turbopuffer")
|
||||
log_event(
|
||||
"turbopuffer_embedder.generation_started",
|
||||
{
|
||||
"total_chunks": len(chunks),
|
||||
"file_id": file_id,
|
||||
"source_id": source_id,
|
||||
"embedding_model": self.embedding_config.embedding_model,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# insert passages to Turbopuffer - it will handle embedding generation internally
|
||||
passages = await self.tpuf_client.insert_file_passages(
|
||||
source_id=source_id,
|
||||
file_id=file_id,
|
||||
text_chunks=chunks,
|
||||
organization_id=actor.organization_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully generated and stored {len(passages)} passages in Turbopuffer")
|
||||
log_event(
|
||||
"turbopuffer_embedder.generation_completed",
|
||||
{
|
||||
"passages_created": len(passages),
|
||||
"total_chunks_processed": len(chunks),
|
||||
"file_id": file_id,
|
||||
"source_id": source_id,
|
||||
},
|
||||
)
|
||||
return passages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embeddings with Turbopuffer: {str(e)}")
|
||||
log_event(
|
||||
"turbopuffer_embedder.generation_failed",
|
||||
{"error": str(e), "error_type": type(e).__name__, "file_id": file_id, "source_id": source_id},
|
||||
)
|
||||
raise
|
||||
@@ -6,7 +6,7 @@ from letta.log import get_logger
|
||||
from letta.otel.context import get_ctx_attributes
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
from letta.schemas.enums import FileProcessingStatus, VectorDBProvider
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.user import User
|
||||
@@ -30,7 +30,6 @@ class FileProcessor:
|
||||
file_parser: FileParser,
|
||||
embedder: BaseEmbedder,
|
||||
actor: User,
|
||||
using_pinecone: bool,
|
||||
max_file_size: int = 50 * 1024 * 1024, # 50MB default
|
||||
):
|
||||
self.file_parser = file_parser
|
||||
@@ -42,7 +41,8 @@ class FileProcessor:
|
||||
self.job_manager = JobManager()
|
||||
self.agent_manager = AgentManager()
|
||||
self.actor = actor
|
||||
self.using_pinecone = using_pinecone
|
||||
# get vector db type from the embedder
|
||||
self.vector_db_type = embedder.vector_db_type
|
||||
|
||||
async def _chunk_and_embed_with_fallback(self, file_metadata: FileMetadata, ocr_response, source_id: str) -> List:
|
||||
"""Chunk text and generate embeddings with fallback to default chunker if needed"""
|
||||
@@ -218,7 +218,7 @@ class FileProcessor:
|
||||
source_id=source_id,
|
||||
)
|
||||
|
||||
if not self.using_pinecone:
|
||||
if self.vector_db_type == VectorDBProvider.NATIVE:
|
||||
all_passages = await self.passage_manager.create_many_source_passages_async(
|
||||
passages=all_passages,
|
||||
file_metadata=file_metadata,
|
||||
@@ -241,7 +241,8 @@ class FileProcessor:
|
||||
)
|
||||
|
||||
# update job status
|
||||
if not self.using_pinecone:
|
||||
# pinecone completes slowly, so gets updated later
|
||||
if self.vector_db_type != VectorDBProvider.PINECONE:
|
||||
await self.file_manager.update_file_status(
|
||||
file_id=file_metadata.id,
|
||||
actor=self.actor,
|
||||
@@ -317,14 +318,15 @@ class FileProcessor:
|
||||
)
|
||||
|
||||
# Create passages in database (unless using Pinecone)
|
||||
if not self.using_pinecone:
|
||||
if self.vector_db_type == VectorDBProvider.NATIVE:
|
||||
all_passages = await self.passage_manager.create_many_source_passages_async(
|
||||
passages=all_passages, file_metadata=file_metadata, actor=self.actor
|
||||
)
|
||||
log_event("file_processor.import_passages_created", {"filename": filename, "total_passages": len(all_passages)})
|
||||
|
||||
# Update file status to completed (valid transition from EMBEDDING)
|
||||
if not self.using_pinecone:
|
||||
# pinecone completes slowly, so gets updated later
|
||||
if self.vector_db_type != VectorDBProvider.PINECONE:
|
||||
await self.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import and_, exists, select
|
||||
|
||||
from letta.helpers.pinecone_utils import should_use_pinecone
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
@@ -18,6 +19,18 @@ from letta.utils import enforce_types, printd
|
||||
|
||||
|
||||
class SourceManager:
|
||||
def _get_vector_db_provider(self) -> VectorDBProvider:
|
||||
"""
|
||||
determine which vector db provider to use based on configuration.
|
||||
turbopuffer takes precedence when available.
|
||||
"""
|
||||
if should_use_tpuf():
|
||||
return VectorDBProvider.TPUF
|
||||
elif should_use_pinecone():
|
||||
return VectorDBProvider.PINECONE
|
||||
else:
|
||||
return VectorDBProvider.NATIVE
|
||||
|
||||
"""Manager class to handle business logic related to Sources."""
|
||||
|
||||
@trace_method
|
||||
@@ -52,7 +65,7 @@ class SourceManager:
|
||||
if db_source:
|
||||
return db_source
|
||||
else:
|
||||
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
|
||||
vector_db_provider = self._get_vector_db_provider()
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Provide default embedding config if not given
|
||||
@@ -96,7 +109,7 @@ class SourceManager:
|
||||
Returns:
|
||||
List of created/updated sources
|
||||
"""
|
||||
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
|
||||
vector_db_provider = self._get_vector_db_provider()
|
||||
for pydantic_source in pydantic_sources:
|
||||
pydantic_source.vector_db_provider = vector_db_provider
|
||||
|
||||
|
||||
@@ -5,10 +5,13 @@ from typing import Any, Dict, List, Optional
|
||||
from letta.constants import PINECONE_TEXT_FIELD_NAME
|
||||
from letta.functions.types import FileOpenRequest
|
||||
from letta.helpers.pinecone_utils import search_pinecone_index, should_use_pinecone
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
@@ -554,18 +557,140 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
|
||||
self.logger.info(f"Semantic search started for agent {agent_state.id} with query '{query}' (limit: {limit})")
|
||||
|
||||
# Check if Pinecone is enabled and use it if available
|
||||
if should_use_pinecone():
|
||||
return await self._search_files_pinecone(agent_state, query, limit)
|
||||
else:
|
||||
return await self._search_files_traditional(agent_state, query, limit)
|
||||
# Check which vector DB to use - Turbopuffer takes precedence
|
||||
attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
|
||||
attached_tpuf_sources = [source for source in attached_sources if source.vector_db_provider == VectorDBProvider.TPUF]
|
||||
attached_pinecone_sources = [source for source in attached_sources if source.vector_db_provider == VectorDBProvider.PINECONE]
|
||||
|
||||
async def _search_files_pinecone(self, agent_state: AgentState, query: str, limit: int) -> str:
|
||||
if not attached_tpuf_sources and not attached_pinecone_sources:
|
||||
return await self._search_files_native(agent_state, query, limit)
|
||||
|
||||
results = []
|
||||
|
||||
# If both have items, we half the limit roughly
|
||||
# TODO: This is very hacky bc it skips the re-ranking - but this is a temporary stopgap while we think about migrating data
|
||||
|
||||
if attached_tpuf_sources and attached_pinecone_sources:
|
||||
limit = max(limit // 2, 1)
|
||||
|
||||
if should_use_tpuf() and attached_tpuf_sources:
|
||||
tpuf_result = await self._search_files_turbopuffer(agent_state, attached_tpuf_sources, query, limit)
|
||||
results.append(tpuf_result)
|
||||
|
||||
if should_use_pinecone() and attached_pinecone_sources:
|
||||
pinecone_result = await self._search_files_pinecone(agent_state, attached_pinecone_sources, query, limit)
|
||||
results.append(pinecone_result)
|
||||
|
||||
# combine results from both sources
|
||||
if results:
|
||||
return "\n\n".join(results)
|
||||
|
||||
# fallback if no results from either source
|
||||
return "No results found"
|
||||
|
||||
async def _search_files_turbopuffer(self, agent_state: AgentState, attached_sources: List[Source], query: str, limit: int) -> str:
|
||||
"""Search files using Turbopuffer vector database."""
|
||||
|
||||
# Get attached sources
|
||||
source_ids = [source.id for source in attached_sources]
|
||||
if not source_ids:
|
||||
return "No valid source IDs found for attached files"
|
||||
|
||||
# Get all attached files for this agent
|
||||
file_agents = await self.files_agents_manager.list_files_for_agent(
|
||||
agent_id=agent_state.id, per_file_view_window_char_limit=agent_state.per_file_view_window_char_limit, actor=self.actor
|
||||
)
|
||||
if not file_agents:
|
||||
return "No files are currently attached to search"
|
||||
|
||||
# Create a map of file_id to file_name for quick lookup
|
||||
file_map = {fa.file_id: fa.file_name for fa in file_agents}
|
||||
|
||||
results = []
|
||||
total_hits = 0
|
||||
files_with_matches = {}
|
||||
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# Query Turbopuffer for all sources at once
|
||||
search_results = await tpuf_client.query_file_passages(
|
||||
source_ids=source_ids, # pass all source_ids as a list
|
||||
organization_id=self.actor.organization_id,
|
||||
actor=self.actor,
|
||||
query_text=query,
|
||||
search_mode="hybrid", # use hybrid search for best results
|
||||
top_k=limit,
|
||||
)
|
||||
|
||||
# Process search results
|
||||
for passage, score, metadata in search_results:
|
||||
if total_hits >= limit:
|
||||
break
|
||||
|
||||
total_hits += 1
|
||||
|
||||
# get file name from our map
|
||||
file_name = file_map.get(passage.file_id, "Unknown File")
|
||||
|
||||
# group by file name
|
||||
if file_name not in files_with_matches:
|
||||
files_with_matches[file_name] = []
|
||||
files_with_matches[file_name].append({"text": passage.text, "score": score, "passage_id": passage.id})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Turbopuffer search failed: {str(e)}")
|
||||
raise e
|
||||
|
||||
if not files_with_matches:
|
||||
return f"No semantic matches found in Turbopuffer for query: '{query}'"
|
||||
|
||||
# Format results
|
||||
passage_num = 0
|
||||
for file_name, matches in files_with_matches.items():
|
||||
for match in matches:
|
||||
passage_num += 1
|
||||
|
||||
# format each passage with terminal-style header
|
||||
score_display = f"(score: {match['score']:.3f})"
|
||||
passage_header = f"\n=== {file_name} (passage #{passage_num}) {score_display} ==="
|
||||
|
||||
# format the passage text
|
||||
passage_text = match["text"].strip()
|
||||
lines = passage_text.splitlines()
|
||||
formatted_lines = []
|
||||
for line in lines[:20]: # limit to first 20 lines per passage
|
||||
formatted_lines.append(f" {line}")
|
||||
|
||||
if len(lines) > 20:
|
||||
formatted_lines.append(f" ... [truncated {len(lines) - 20} more lines]")
|
||||
|
||||
passage_content = "\n".join(formatted_lines)
|
||||
results.append(f"{passage_header}\n{passage_content}")
|
||||
|
||||
# mark access for files that had matches
|
||||
if files_with_matches:
|
||||
matched_file_names = [name for name in files_with_matches.keys() if name != "Unknown File"]
|
||||
if matched_file_names:
|
||||
await self.files_agents_manager.mark_access_bulk(agent_id=agent_state.id, file_names=matched_file_names, actor=self.actor)
|
||||
|
||||
# create summary header
|
||||
file_count = len(files_with_matches)
|
||||
summary = f"Found {total_hits} Turbopuffer matches in {file_count} file{'s' if file_count != 1 else ''} for query: '{query}'"
|
||||
|
||||
# combine all results
|
||||
formatted_results = [summary, "=" * len(summary)] + results
|
||||
|
||||
self.logger.info(f"Turbopuffer search completed: {total_hits} matches across {file_count} files")
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
async def _search_files_pinecone(self, agent_state: AgentState, attached_sources: List[Source], query: str, limit: int) -> str:
|
||||
"""Search files using Pinecone vector database."""
|
||||
|
||||
# Extract unique source_ids
|
||||
# TODO: Inefficient
|
||||
attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
|
||||
source_ids = [source.id for source in attached_sources]
|
||||
if not source_ids:
|
||||
return "No valid source IDs found for attached files"
|
||||
@@ -658,7 +783,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
self.logger.info(f"Pinecone search completed: {total_hits} matches across {file_count} files")
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
async def _search_files_traditional(self, agent_state: AgentState, query: str, limit: int) -> str:
|
||||
async def _search_files_native(self, agent_state: AgentState, query: str, limit: int) -> str:
|
||||
"""Traditional search using existing passage manager."""
|
||||
# Get semantic search results
|
||||
passages = await self.agent_manager.query_source_passages_async(
|
||||
|
||||
@@ -258,7 +258,7 @@ class TestFileProcessorWithPinecone:
|
||||
embedder = PineconeEmbedder()
|
||||
|
||||
# Create file processor with Pinecone enabled
|
||||
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=mock_actor, using_pinecone=True)
|
||||
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=mock_actor)
|
||||
|
||||
# Track file manager update calls
|
||||
update_calls = []
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta_client.types import AgentState
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
|
||||
from letta.helpers.pinecone_utils import should_use_pinecone
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
from letta.schemas.enums import FileProcessingStatus, ToolType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.user import User
|
||||
@@ -95,7 +96,7 @@ def agent_state(disable_pinecone, client: LettaSDKClient):
|
||||
# Tests
|
||||
|
||||
|
||||
def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient):
|
||||
def test_auto_attach_detach_files_tools(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test automatic attachment and detachment of file tools when managing agent sources."""
|
||||
# Create agent with basic configuration
|
||||
agent = client.agents.create(
|
||||
@@ -168,6 +169,7 @@ def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient
|
||||
)
|
||||
def test_file_upload_creates_source_blocks_correctly(
|
||||
disable_pinecone,
|
||||
disable_turbopuffer,
|
||||
client: LettaSDKClient,
|
||||
agent_state: AgentState,
|
||||
file_path: str,
|
||||
@@ -237,7 +239,9 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
settings.mistral_api_key = original_mistral_key
|
||||
|
||||
|
||||
def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_attach_existing_files_creates_source_blocks_correctly(
|
||||
disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState
|
||||
):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
@@ -302,7 +306,9 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone,
|
||||
assert "<directories>" not in raw_system_message_after_detach
|
||||
|
||||
|
||||
def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_delete_source_removes_source_blocks_correctly(
|
||||
disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState
|
||||
):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
@@ -360,7 +366,7 @@ def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client:
|
||||
assert not any("test" in b.value for b in blocks)
|
||||
|
||||
|
||||
def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_open_close_file_correctly(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -463,7 +469,7 @@ def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDK
|
||||
print("✓ File successfully opened with different range - content differs as expected")
|
||||
|
||||
|
||||
def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_search_files_correctly(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -509,7 +515,7 @@ def test_agent_uses_search_files_correctly(disable_pinecone, client: LettaSDKCli
|
||||
assert all(tr.status == "success" for tr in tool_returns), f"Tool call failed {tr}"
|
||||
|
||||
|
||||
def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_grep_correctly_basic(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -551,7 +557,7 @@ def test_agent_uses_grep_correctly_basic(disable_pinecone, client: LettaSDKClien
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
|
||||
def test_agent_uses_grep_correctly_advanced(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_agent_uses_grep_correctly_advanced(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -599,7 +605,7 @@ def test_agent_uses_grep_correctly_advanced(disable_pinecone, client: LettaSDKCl
|
||||
assert "511:" in tool_return_message.tool_return
|
||||
|
||||
|
||||
def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient):
|
||||
def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that creating an agent with source_ids parameter correctly creates source blocks."""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
@@ -642,7 +648,7 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(disable_pi
|
||||
assert file_tools == set(FILES_TOOLS)
|
||||
|
||||
|
||||
def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
def test_view_ranges_have_metadata(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
@@ -705,7 +711,7 @@ def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, age
|
||||
)
|
||||
|
||||
|
||||
def test_duplicate_file_renaming(disable_pinecone, client: LettaSDKClient):
|
||||
def test_duplicate_file_renaming(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that duplicate files are renamed with count-based suffixes (e.g., file.txt, file (1).txt, file (2).txt)"""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_duplicate_source", embedding="openai/text-embedding-3-small")
|
||||
@@ -744,7 +750,7 @@ def test_duplicate_file_renaming(disable_pinecone, client: LettaSDKClient):
|
||||
print(f" File {i + 1}: original='{file.original_file_name}' → renamed='{file.file_name}'")
|
||||
|
||||
|
||||
def test_duplicate_file_handling_replace(disable_pinecone, client: LettaSDKClient):
|
||||
def test_duplicate_file_handling_replace(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that DuplicateFileHandling.REPLACE replaces existing files with same name"""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_replace_source", embedding="openai/text-embedding-3-small")
|
||||
@@ -826,7 +832,7 @@ def test_duplicate_file_handling_replace(disable_pinecone, client: LettaSDKClien
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
def test_upload_file_with_custom_name(disable_pinecone, client: LettaSDKClient):
|
||||
def test_upload_file_with_custom_name(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that uploading a file with a custom name overrides the original filename"""
|
||||
# Create agent
|
||||
agent_state = client.agents.create(
|
||||
@@ -907,7 +913,7 @@ def test_upload_file_with_custom_name(disable_pinecone, client: LettaSDKClient):
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient):
|
||||
def test_open_files_schema_descriptions(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that open_files tool schema contains correct descriptions from docstring"""
|
||||
|
||||
# Get the open_files tool
|
||||
@@ -990,7 +996,7 @@ def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient
|
||||
assert length_prop["type"] == "integer"
|
||||
|
||||
|
||||
def test_grep_files_schema_descriptions(disable_pinecone, client: LettaSDKClient):
|
||||
def test_grep_files_schema_descriptions(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that grep_files tool schema contains correct descriptions from docstring"""
|
||||
|
||||
# Get the grep_files tool
|
||||
@@ -1076,10 +1082,174 @@ def test_grep_files_schema_descriptions(disable_pinecone, client: LettaSDKClient
|
||||
assert "Navigation hint for next page if more matches exist" in description
|
||||
|
||||
|
||||
def test_agent_open_file(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
|
||||
"""Test client.agents.open_file() function"""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
# Attach source to agent
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Upload a file
|
||||
file_path = "tests/data/test.txt"
|
||||
file_metadata = upload_file_and_wait(client, source.id, file_path)
|
||||
|
||||
# Basic test open_file function
|
||||
closed_files = client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id)
|
||||
assert len(closed_files) == 0
|
||||
|
||||
system = get_raw_system_message(client, agent_state.id)
|
||||
assert '<file status="open" name="test_source/test.txt">' in system
|
||||
assert "[Viewing file start (out of 1 lines)]" in system
|
||||
|
||||
|
||||
def test_agent_close_file(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
|
||||
"""Test client.agents.close_file() function"""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
# Attach source to agent
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Upload a file
|
||||
file_path = "tests/data/test.txt"
|
||||
file_metadata = upload_file_and_wait(client, source.id, file_path)
|
||||
|
||||
# First open the file
|
||||
client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id)
|
||||
|
||||
# Test close_file function
|
||||
client.agents.files.close(agent_id=agent_state.id, file_id=file_metadata.id)
|
||||
|
||||
system = get_raw_system_message(client, agent_state.id)
|
||||
assert '<file status="closed" name="test_source/test.txt">' in system
|
||||
|
||||
|
||||
def test_agent_close_all_open_files(disable_pinecone, disable_turbopuffer, client: LettaSDKClient, agent_state: AgentState):
|
||||
"""Test client.agents.close_all_open_files() function"""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
# Attach source to agent
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Upload multiple files
|
||||
file_paths = ["tests/data/test.txt", "tests/data/test.md"]
|
||||
file_metadatas = []
|
||||
for file_path in file_paths:
|
||||
file_metadata = upload_file_and_wait(client, source.id, file_path)
|
||||
file_metadatas.append(file_metadata)
|
||||
# Open each file
|
||||
client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id)
|
||||
|
||||
system = get_raw_system_message(client, agent_state.id)
|
||||
assert '<file status="open"' in system
|
||||
|
||||
# Test close_all_open_files function
|
||||
result = client.agents.files.close_all(agent_id=agent_state.id)
|
||||
|
||||
# Verify result is a list of strings
|
||||
assert isinstance(result, list), f"Expected list, got {type(result)}"
|
||||
assert all(isinstance(item, str) for item in result), "All items in result should be strings"
|
||||
|
||||
system = get_raw_system_message(client, agent_state.id)
|
||||
assert '<file status="open"' not in system
|
||||
|
||||
|
||||
def test_file_processing_timeout(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that files in non-terminal states are moved to error after timeout"""
|
||||
# Create a source
|
||||
source = client.sources.create(name="test_timeout_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
# Upload a file
|
||||
file_path = "tests/data/test.txt"
|
||||
with open(file_path, "rb") as f:
|
||||
file_metadata = client.sources.files.upload(source_id=source.id, file=f)
|
||||
|
||||
# Get the file ID
|
||||
file_id = file_metadata.id
|
||||
|
||||
# Test the is_terminal_state method directly (this doesn't require server mocking)
|
||||
assert FileProcessingStatus.COMPLETED.is_terminal_state() == True
|
||||
assert FileProcessingStatus.ERROR.is_terminal_state() == True
|
||||
assert FileProcessingStatus.PARSING.is_terminal_state() == False
|
||||
assert FileProcessingStatus.EMBEDDING.is_terminal_state() == False
|
||||
assert FileProcessingStatus.PENDING.is_terminal_state() == False
|
||||
|
||||
# For testing the actual timeout logic, we can check the current file status
|
||||
current_file = client.sources.get_file_metadata(source_id=source.id, file_id=file_id)
|
||||
|
||||
# Convert string status to enum for testing
|
||||
status_enum = FileProcessingStatus(current_file.processing_status)
|
||||
|
||||
# Verify that files in terminal states are not affected by timeout checks
|
||||
if status_enum.is_terminal_state():
|
||||
# This is the expected behavior - files that completed processing shouldn't timeout
|
||||
print(f"File {file_id} is in terminal state: {current_file.processing_status}")
|
||||
assert status_enum in [FileProcessingStatus.COMPLETED, FileProcessingStatus.ERROR]
|
||||
else:
|
||||
# If file is still processing, it should eventually complete or timeout
|
||||
# In a real scenario, we'd wait and check, but for unit tests we just verify the logic exists
|
||||
print(f"File {file_id} is still processing: {current_file.processing_status}")
|
||||
assert status_enum in [FileProcessingStatus.PENDING, FileProcessingStatus.PARSING, FileProcessingStatus.EMBEDDING]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_file_processing_timeout_logic():
|
||||
"""Test the timeout logic directly without server dependencies"""
|
||||
from datetime import timezone
|
||||
|
||||
# Test scenario: file created 35 minutes ago, timeout is 30 minutes
|
||||
old_time = datetime.now(timezone.utc) - timedelta(minutes=35)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
timeout_minutes = 30
|
||||
|
||||
# Calculate timeout threshold
|
||||
timeout_threshold = current_time - timedelta(minutes=timeout_minutes)
|
||||
|
||||
# Verify timeout logic
|
||||
assert old_time < timeout_threshold, "File created 35 minutes ago should be past 30-minute timeout"
|
||||
|
||||
# Test edge case: file created exactly at timeout
|
||||
edge_time = current_time - timedelta(minutes=timeout_minutes)
|
||||
assert not (edge_time < timeout_threshold), "File created exactly at timeout should not trigger timeout"
|
||||
|
||||
# Test recent file
|
||||
recent_time = current_time - timedelta(minutes=10)
|
||||
assert not (recent_time < timeout_threshold), "Recent file should not trigger timeout"
|
||||
|
||||
|
||||
def test_letta_free_embedding(disable_pinecone, disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test creating a source with letta/letta-free embedding and uploading a file"""
|
||||
# create a source with letta-free embedding
|
||||
source = client.sources.create(name="test_letta_free_source", embedding="letta/letta-free")
|
||||
|
||||
# verify source was created with correct embedding
|
||||
assert source.name == "test_letta_free_source"
|
||||
print("\n\n\n\ntest")
|
||||
print(source.embedding_config)
|
||||
# assert source.embedding_config.embedding_model == "letta-free"
|
||||
|
||||
# upload test.txt file
|
||||
file_path = "tests/data/test.txt"
|
||||
file_metadata = upload_file_and_wait(client, source.id, file_path)
|
||||
|
||||
# verify file was uploaded successfully
|
||||
assert file_metadata.processing_status == "completed"
|
||||
assert file_metadata.source_id == source.id
|
||||
assert file_metadata.file_name == "test.txt"
|
||||
|
||||
# verify file appears in source files list
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].id == file_metadata.id
|
||||
|
||||
# cleanup
|
||||
client.sources.delete(source_id=source.id)
|
||||
|
||||
|
||||
# --- Pinecone Tests ---
|
||||
|
||||
|
||||
def test_pinecone_search_files_tool(client: LettaSDKClient):
|
||||
def test_pinecone_search_files_tool(disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that search_files tool uses Pinecone when enabled"""
|
||||
from letta.helpers.pinecone_utils import should_use_pinecone
|
||||
|
||||
@@ -1130,7 +1300,7 @@ def test_pinecone_search_files_tool(client: LettaSDKClient):
|
||||
)
|
||||
|
||||
|
||||
def test_pinecone_list_files_status(client: LettaSDKClient):
|
||||
def test_pinecone_list_files_status(disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that list_source_files properly syncs embedding status with Pinecone"""
|
||||
if not should_use_pinecone():
|
||||
pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")
|
||||
@@ -1165,7 +1335,7 @@ def test_pinecone_list_files_status(client: LettaSDKClient):
|
||||
client.sources.delete(source_id=source.id)
|
||||
|
||||
|
||||
def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
|
||||
def test_pinecone_lifecycle_file_and_source_deletion(disable_turbopuffer, client: LettaSDKClient):
|
||||
"""Test that file and source deletion removes records from Pinecone"""
|
||||
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
|
||||
|
||||
@@ -1236,167 +1406,196 @@ def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
|
||||
)
|
||||
|
||||
|
||||
def test_agent_open_file(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
"""Test client.agents.open_file() function"""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
# Attach source to agent
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Upload a file
|
||||
file_path = "tests/data/test.txt"
|
||||
file_metadata = upload_file_and_wait(client, source.id, file_path)
|
||||
|
||||
# Basic test open_file function
|
||||
closed_files = client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id)
|
||||
assert len(closed_files) == 0
|
||||
|
||||
system = get_raw_system_message(client, agent_state.id)
|
||||
assert '<file status="open" name="test_source/test.txt">' in system
|
||||
assert "[Viewing file start (out of 1 lines)]" in system
|
||||
# --- End Pinecone Tests ---
|
||||
|
||||
|
||||
def test_agent_close_file(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
"""Test client.agents.close_file() function"""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
# --- Turbopuffer Tests ---
|
||||
def test_turbopuffer_search_files_tool(disable_pinecone, client: LettaSDKClient):
|
||||
"""Test that search_files tool uses Turbopuffer when enabled"""
|
||||
agent = client.agents.create(
|
||||
name="test_turbopuffer_agent",
|
||||
memory_blocks=[
|
||||
CreateBlock(label="human", value="username: testuser"),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Attach source to agent
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
source = client.sources.create(name="test_turbopuffer_source", embedding="openai/text-embedding-3-small")
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent.id)
|
||||
|
||||
# Upload a file
|
||||
file_path = "tests/data/test.txt"
|
||||
file_metadata = upload_file_and_wait(client, source.id, file_path)
|
||||
file_path = "tests/data/long_test.txt"
|
||||
upload_file_and_wait(client, source.id, file_path)
|
||||
|
||||
# First open the file
|
||||
client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id)
|
||||
search_response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreate(role="user", content="Use the semantic_search_files tool to search for 'electoral history' in the files.")],
|
||||
)
|
||||
|
||||
# Test close_file function
|
||||
client.agents.files.close(agent_id=agent_state.id, file_id=file_metadata.id)
|
||||
tool_calls = [msg for msg in search_response.messages if msg.message_type == "tool_call_message"]
|
||||
assert len(tool_calls) > 0, "No tool calls found"
|
||||
assert any(tc.tool_call.name == "semantic_search_files" for tc in tool_calls), "semantic_search_files not called"
|
||||
|
||||
system = get_raw_system_message(client, agent_state.id)
|
||||
assert '<file status="closed" name="test_source/test.txt">' in system
|
||||
tool_returns = [msg for msg in search_response.messages if msg.message_type == "tool_return_message"]
|
||||
assert len(tool_returns) > 0, "No tool returns found"
|
||||
assert all(tr.status == "success" for tr in tool_returns), "Tool call failed"
|
||||
|
||||
search_results = tool_returns[0].tool_return
|
||||
print(f"Turbopuffer search results: {search_results}")
|
||||
assert "electoral" in search_results.lower() or "history" in search_results.lower(), (
|
||||
f"Search results should contain relevant content: {search_results}"
|
||||
)
|
||||
|
||||
client.agents.delete(agent_id=agent.id)
|
||||
client.sources.delete(source_id=source.id)
|
||||
|
||||
|
||||
def test_agent_close_all_open_files(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
"""Test client.agents.close_all_open_files() function"""
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
def test_turbopuffer_file_processing_status(disable_pinecone, client: LettaSDKClient):
|
||||
"""Test that file processing completes successfully with Turbopuffer"""
|
||||
print("Testing Turbopuffer file processing status")
|
||||
|
||||
# Attach source to agent
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
source = client.sources.create(name="test_tpuf_file_status", embedding="openai/text-embedding-3-small")
|
||||
|
||||
# Upload multiple files
|
||||
file_paths = ["tests/data/test.txt", "tests/data/test.md"]
|
||||
file_metadatas = []
|
||||
file_paths = ["tests/data/long_test.txt", "tests/data/test.md"]
|
||||
uploaded_files = []
|
||||
for file_path in file_paths:
|
||||
file_metadata = upload_file_and_wait(client, source.id, file_path)
|
||||
file_metadatas.append(file_metadata)
|
||||
# Open each file
|
||||
client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata.id)
|
||||
uploaded_files.append(file_metadata)
|
||||
assert file_metadata.processing_status == "completed", f"File {file_path} should be completed"
|
||||
|
||||
system = get_raw_system_message(client, agent_state.id)
|
||||
assert '<file status="open"' in system
|
||||
files_list = client.sources.files.list(source_id=source.id, limit=100)
|
||||
|
||||
# Test close_all_open_files function
|
||||
result = client.agents.files.close_all(agent_id=agent_state.id)
|
||||
assert len(files_list) == len(uploaded_files), f"Expected {len(uploaded_files)} files, got {len(files_list)}"
|
||||
|
||||
# Verify result is a list of strings
|
||||
assert isinstance(result, list), f"Expected list, got {type(result)}"
|
||||
assert all(isinstance(item, str) for item in result), "All items in result should be strings"
|
||||
for file_metadata in files_list:
|
||||
assert file_metadata.processing_status == "completed", f"File {file_metadata.file_name} should show completed status"
|
||||
|
||||
system = get_raw_system_message(client, agent_state.id)
|
||||
assert '<file status="open"' not in system
|
||||
if file_metadata.total_chunks and file_metadata.total_chunks > 0:
|
||||
assert file_metadata.chunks_embedded == file_metadata.total_chunks, (
|
||||
f"File {file_metadata.file_name} should have all chunks embedded: {file_metadata.chunks_embedded}/{file_metadata.total_chunks}"
|
||||
)
|
||||
|
||||
|
||||
def test_file_processing_timeout(disable_pinecone, client: LettaSDKClient):
    """Test that files in non-terminal states are moved to error after timeout.

    Verifies FileProcessingStatus.is_terminal_state() classification directly,
    then checks that the uploaded file's current status is one of the known
    states for its terminal / non-terminal class.
    """
    # Create a source to attach the file to.
    source = client.sources.create(name="test_timeout_source", embedding="openai/text-embedding-3-small")

    # Upload a file (deliberately without waiting for processing to finish,
    # so the file may still be in a non-terminal state below).
    file_path = "tests/data/test.txt"
    with open(file_path, "rb") as f:
        file_metadata = client.sources.files.upload(source_id=source.id, file=f)

    # Get the file ID
    file_id = file_metadata.id

    # Test the is_terminal_state method directly (this doesn't require server mocking).
    # Use truthiness asserts rather than `== True` / `== False` comparisons.
    assert FileProcessingStatus.COMPLETED.is_terminal_state()
    assert FileProcessingStatus.ERROR.is_terminal_state()
    assert not FileProcessingStatus.PARSING.is_terminal_state()
    assert not FileProcessingStatus.EMBEDDING.is_terminal_state()
    assert not FileProcessingStatus.PENDING.is_terminal_state()

    # For testing the actual timeout logic, we can check the current file status
    current_file = client.sources.get_file_metadata(source_id=source.id, file_id=file_id)

    # Convert string status to enum for testing
    status_enum = FileProcessingStatus(current_file.processing_status)

    # Verify that files in terminal states are not affected by timeout checks
    if status_enum.is_terminal_state():
        # This is the expected behavior - files that completed processing shouldn't timeout
        print(f"File {file_id} is in terminal state: {current_file.processing_status}")
        assert status_enum in [FileProcessingStatus.COMPLETED, FileProcessingStatus.ERROR]
    else:
        # If file is still processing, it should eventually complete or timeout
        # In a real scenario, we'd wait and check, but for unit tests we just verify the logic exists
        print(f"File {file_id} is still processing: {current_file.processing_status}")
        assert status_enum in [FileProcessingStatus.PENDING, FileProcessingStatus.PARSING, FileProcessingStatus.EMBEDDING]
|
||||
@pytest.mark.unit
def test_file_processing_timeout_logic():
    """Exercise the file-processing timeout comparison logic with no server dependencies."""
    from datetime import timezone

    timeout_minutes = 30
    now = datetime.now(timezone.utc)

    # A file times out when its creation time is strictly older than
    # (now - timeout_minutes).
    timeout_threshold = now - timedelta(minutes=timeout_minutes)

    # Scenario: file created 35 minutes ago is past the 30-minute timeout.
    stale_time = now - timedelta(minutes=35)
    assert stale_time < timeout_threshold, "File created 35 minutes ago should be past 30-minute timeout"

    # Edge case: a file created exactly at the threshold must NOT time out
    # (strict comparison).
    boundary_time = now - timedelta(minutes=timeout_minutes)
    assert not (boundary_time < timeout_threshold), "File created exactly at timeout should not trigger timeout"

    # A recent file (10 minutes old) must not time out either.
    fresh_time = now - timedelta(minutes=10)
    assert not (fresh_time < timeout_threshold), "Recent file should not trigger timeout"
|
||||
def test_letta_free_embedding(disable_pinecone, client: LettaSDKClient):
    """Test creating a source with letta/letta-free embedding and uploading a file"""
    # create a source with letta-free embedding
    source = client.sources.create(name="test_letta_free_source", embedding="letta/letta-free")

    # verify source was created with correct name
    assert source.name == "test_letta_free_source"
    # TODO: re-enable once the reported embedding model name is confirmed stable:
    # assert source.embedding_config.embedding_model == "letta-free"

    # upload test.txt file and wait for processing to complete
    file_path = "tests/data/test.txt"
    file_metadata = upload_file_and_wait(client, source.id, file_path)

    # verify file was uploaded successfully
    assert file_metadata.processing_status == "completed"
    assert file_metadata.source_id == source.id
    assert file_metadata.file_name == "test.txt"

    # verify file appears in source files list
    files = client.sources.files.list(source_id=source.id, limit=1)
    assert len(files) == 1
    assert files[0].id == file_metadata.id

    # cleanup
    client.sources.delete(source_id=source.id)
||||
def test_turbopuffer_lifecycle_file_and_source_deletion(disable_pinecone, client: LettaSDKClient):
    """Test that file and source deletion removes records from Turbopuffer"""
    source = client.sources.create(name="test_tpuf_lifecycle", embedding="openai/text-embedding-3-small")

    file_paths = ["tests/data/test.txt", "tests/data/test.md"]
    uploaded_files = [upload_file_and_wait(client, source.id, path) for path in file_paths]

    user = User(name="temp", organization_id=DEFAULT_ORG_ID)
    tpuf_client = TurbopufferClient()

    def query_passages(target_file_id):
        # Fetch Turbopuffer passages for one file within this source's namespace.
        return asyncio.run(
            tpuf_client.query_file_passages(
                source_ids=[source.id], organization_id=user.organization_id, actor=user, file_id=target_file_id, top_k=100
            )
        )

    # test file-level deletion
    if len(uploaded_files) > 1:
        file_to_delete = uploaded_files[0]

        passages_before = query_passages(file_to_delete.id)
        print(f"Found {len(passages_before)} passages for file before deletion")
        assert len(passages_before) > 0, "Should have passages before deletion"

        client.sources.files.delete(source_id=source.id, file_id=file_to_delete.id)

        time.sleep(2)

        passages_after = query_passages(file_to_delete.id)
        print(f"Found {len(passages_after)} passages for file after deletion")
        assert len(passages_after) == 0, f"File passages should be removed from Turbopuffer after deletion, but found {len(passages_after)}"

    # test source-level deletion
    remaining_passages_before = []
    for remaining_file in uploaded_files[1:]:
        remaining_passages_before.extend(query_passages(remaining_file.id))

    print(f"Found {len(remaining_passages_before)} passages for remaining files before source deletion")
    assert len(remaining_passages_before) > 0, "Should have passages for remaining files"

    client.sources.delete(source_id=source.id)

    time.sleep(3)

    remaining_passages_after = []
    for remaining_file in uploaded_files[1:]:
        try:
            remaining_passages_after.extend(query_passages(remaining_file.id))
        except Exception as e:
            print(f"Expected error querying deleted source: {e}")

    print(f"Found {len(remaining_passages_after)} passages for files after source deletion")
    assert len(remaining_passages_after) == 0, (
        f"All source passages should be removed from Turbopuffer after source deletion, but found {len(remaining_passages_after)}"
    )
||||
def test_turbopuffer_multiple_sources(disable_pinecone, client: LettaSDKClient):
    """Test that Turbopuffer correctly isolates passages by source in org-scoped namespace"""
    source1 = client.sources.create(name="test_tpuf_source1", embedding="openai/text-embedding-3-small")
    source2 = client.sources.create(name="test_tpuf_source2", embedding="openai/text-embedding-3-small")

    file1_metadata = upload_file_and_wait(client, source1.id, "tests/data/test.txt")
    file2_metadata = upload_file_and_wait(client, source2.id, "tests/data/test.md")

    user = User(name="temp", organization_id=DEFAULT_ORG_ID)
    tpuf_client = TurbopufferClient()

    def fetch_passages(sid):
        # All passages for one source from the shared org-scoped namespace.
        return asyncio.run(
            tpuf_client.query_file_passages(source_ids=[sid], organization_id=user.organization_id, actor=user, top_k=100)
        )

    source1_passages = fetch_passages(source1.id)
    source2_passages = fetch_passages(source2.id)

    print(f"Source1 has {len(source1_passages)} passages")
    print(f"Source2 has {len(source2_passages)} passages")

    assert len(source1_passages) > 0, "Source1 should have passages"
    assert len(source2_passages) > 0, "Source2 should have passages"

    # Every passage returned for a source must belong to that source and its file.
    for passage, _, _ in source1_passages:
        assert passage.source_id == source1.id, f"Passage should belong to source1, but has source_id={passage.source_id}"
        assert passage.file_id == file1_metadata.id, f"Passage should belong to file1, but has file_id={passage.file_id}"

    for passage, _, _ in source2_passages:
        assert passage.source_id == source2.id, f"Passage should belong to source2, but has source_id={passage.source_id}"
        assert passage.file_id == file2_metadata.id, f"Passage should belong to file2, but has file_id={passage.file_id}"

    # delete source1 and verify source2 is unaffected
    client.sources.delete(source_id=source1.id)
    time.sleep(2)

    source2_passages_after = fetch_passages(source2.id)

    assert len(source2_passages_after) == len(source2_passages), (
        f"Source2 should still have all passages after source1 deletion: {len(source2_passages_after)} vs {len(source2_passages)}"
    )

    client.sources.delete(source_id=source2.id)
||||
# --- End Turbopuffer Tests ---
|
||||
|
||||