feat: Support basic upload/querying on tpuf [LET-3465] (#4255)

* wip implementing turbopuffer

* Move imports up

* Add type of archive

* Integrate turbopuffer functionality

* Debug turbopuffer tests failing

* Fix turbopuffer

* Run fern

* Fix multiple heads
Matthew Zhou
2025-08-28 10:39:16 -07:00
committed by GitHub
parent a8ffae6f8d
commit 651671cb83
21 changed files with 902 additions and 71 deletions

View File

@@ -0,0 +1,60 @@
"""Add vector_db_provider to archives table
Revision ID: 068588268b02
Revises: 887a4367b560
Create Date: 2025-08-27 13:16:29.428231
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.settings import settings
# revision identifiers, used by Alembic.
revision: str = "068588268b02"
down_revision: Union[str, None] = "887a4367b560"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
if settings.letta_pg_uri_no_default:
# PostgreSQL - use enum type
vectordbprovider = sa.Enum("NATIVE", "TPUF", name="vectordbprovider")
vectordbprovider.create(op.get_bind(), checkfirst=True)
# Add column as nullable first
op.add_column("archives", sa.Column("vector_db_provider", vectordbprovider, nullable=True))
# Backfill existing rows with NATIVE
op.execute("UPDATE archives SET vector_db_provider = 'NATIVE' WHERE vector_db_provider IS NULL")
# Make column non-nullable
op.alter_column("archives", "vector_db_provider", nullable=False)
else:
# SQLite - use string type
# Add column as nullable first
op.add_column("archives", sa.Column("vector_db_provider", sa.String(), nullable=True))
# Backfill existing rows with NATIVE
op.execute("UPDATE archives SET vector_db_provider = 'NATIVE' WHERE vector_db_provider IS NULL")
# For SQLite, we need to recreate the table to make column non-nullable
# This is a limitation of SQLite ALTER TABLE
# For simplicity, we'll leave it nullable in SQLite
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("archives", "vector_db_provider")
if settings.letta_pg_uri_no_default:
# Drop enum type for PostgreSQL
vectordbprovider = sa.Enum("NATIVE", "TPUF", name="vectordbprovider")
vectordbprovider.drop(op.get_bind(), checkfirst=True)
# ### end Alembic commands ###

View File

@@ -484,7 +484,7 @@ class VoiceAgent(BaseAgent):
if start_date and end_date and start_date > end_date:
start_date, end_date = end_date, start_date
- archival_results = await self.agent_manager.list_passages_async(
+ archival_results = await self.agent_manager.query_agent_passages_async(
actor=self.actor,
agent_id=self.agent_id,
query_text=archival_query,

View File

@@ -107,7 +107,7 @@ async def archival_memory_search(self: "Agent", query: str, page: Optional[int]
try:
# Get results using passage manager
- all_results = await self.agent_manager.list_passages_async(
+ all_results = await self.agent_manager.query_agent_passages_async(
actor=self.user,
agent_id=self.agent_state.id,
query_text=query,

View File

@@ -0,0 +1,276 @@
"""Turbopuffer utilities for archival memory storage."""
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
from letta.otel.tracing import trace_method
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings
logger = logging.getLogger(__name__)
def should_use_tpuf() -> bool:
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)
class TurbopufferClient:
"""Client for managing archival memory with Turbopuffer vector database."""
def __init__(self, api_key: Optional[str] = None, region: Optional[str] = None):
"""Initialize Turbopuffer client."""
self.api_key = api_key or settings.tpuf_api_key
self.region = region or settings.tpuf_region
if not self.api_key:
raise ValueError("Turbopuffer API key not provided")
@trace_method
def _get_namespace_name(self, archive_id: str) -> str:
"""Get namespace name for a specific archive."""
# use archive_id as namespace to isolate different archives' memories
# append environment suffix to namespace for isolation if environment is set
environment = settings.environment
if environment:
namespace_name = f"{archive_id}_{environment.lower()}"
else:
namespace_name = archive_id
return namespace_name
@trace_method
async def insert_archival_memories(
self,
archive_id: str,
text_chunks: List[str],
embeddings: List[List[float]],
passage_ids: List[str],
organization_id: str,
tags: Optional[List[str]] = None,
created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
"""Insert passages into Turbopuffer.
Args:
archive_id: ID of the archive
text_chunks: List of text chunks to store
embeddings: List of embedding vectors corresponding to text chunks
passage_ids: List of passage IDs (must match 1:1 with text_chunks)
organization_id: Organization ID for the passages
tags: Optional list of tags to attach to all passages
created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
Returns:
List of PydanticPassage objects that were inserted
"""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
# handle timestamp - ensure UTC
if created_at is None:
timestamp = datetime.now(timezone.utc)
else:
# ensure the provided timestamp is timezone-aware and in UTC
if created_at.tzinfo is None:
# assume UTC if no timezone provided
timestamp = created_at.replace(tzinfo=timezone.utc)
else:
# convert to UTC if in different timezone
timestamp = created_at.astimezone(timezone.utc)
# passage_ids must be provided for dual-write consistency
if not passage_ids:
raise ValueError("passage_ids must be provided for Turbopuffer insertion")
if len(passage_ids) != len(text_chunks):
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
if len(passage_ids) != len(embeddings):
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})")
# prepare column-based data for turbopuffer - optimized for batch insert
ids = []
vectors = []
texts = []
organization_ids = []
archive_ids = []
created_ats = []
passages = []
# prepare tag columns
tag_columns = {tag: [] for tag in (tags or [])}
for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
passage_id = passage_ids[idx]
# append to columns
ids.append(passage_id)
vectors.append(embedding)
texts.append(text)
organization_ids.append(organization_id)
archive_ids.append(archive_id)
created_ats.append(timestamp)
# append tag values
for tag in tag_columns:
tag_columns[tag].append(True)
# Create PydanticPassage object
passage = PydanticPassage(
id=passage_id,
text=text,
organization_id=organization_id,
archive_id=archive_id,
created_at=timestamp,
metadata_={},
embedding=embedding,
embedding_config=None, # Will be set by caller if needed
)
passages.append(passage)
# build column-based upsert data
upsert_columns = {
"id": ids,
"vector": vectors,
"text": texts,
"organization_id": organization_ids,
"archive_id": archive_ids,
"created_at": created_ats,
}
# add tag columns if any
upsert_columns.update(tag_columns)
try:
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# turbopuffer recommends column-based writes for performance
await namespace.write(upsert_columns=upsert_columns, distance_metric="cosine_distance")
logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}")
return passages
except Exception as e:
logger.error(f"Failed to insert passages to Turbopuffer: {e}")
# check if it's a duplicate ID error
if "duplicate" in str(e).lower():
logger.error("Duplicate passage IDs detected in batch")
raise
@trace_method
async def query_passages(
self, archive_id: str, query_embedding: List[float], top_k: int = 10, filters: Optional[Dict[str, Any]] = None
) -> List[Tuple[PydanticPassage, float]]:
"""Query passages from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# build filter conditions
filter_conditions = []
if filters:
for key, value in filters.items():
filter_conditions.append((key, "Eq", value))
query_params = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
}
if filter_conditions:
query_params["filters"] = ("And", filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0]
result = await namespace.query(**query_params)
# convert results back to passages
passages_with_scores = []
# Turbopuffer returns a NamespaceQueryResponse with a rows attribute
for row in result.rows:
# Build metadata including any filter conditions that were applied
metadata = {}
if filters:
metadata["applied_filters"] = filters
# Create a passage with minimal fields - embeddings are not returned from Turbopuffer
passage = PydanticPassage(
id=row.id,
text=getattr(row, "text", ""),
organization_id=getattr(row, "organization_id", None),
archive_id=archive_id, # use the archive_id from the query
created_at=getattr(row, "created_at", None),
metadata_=metadata, # Include filter conditions in metadata
# Set required fields to empty/default values since we don't store embeddings
embedding=[], # Empty embedding since we don't return it from Turbopuffer
embedding_config=None, # No embedding config needed for retrieved passages
)
# turbopuffer returns distance in $dist attribute, convert to similarity score
distance = getattr(row, "$dist", 0.0)
score = 1.0 - distance
passages_with_scores.append((passage, score))
return passages_with_scores
except Exception as e:
logger.error(f"Failed to query passages from Turbopuffer: {e}")
raise
@trace_method
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
"""Delete a passage from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# Use write API with deletes parameter as per Turbopuffer docs
await namespace.write(deletes=[passage_id])
logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}")
return True
except Exception as e:
logger.error(f"Failed to delete passage from Turbopuffer: {e}")
raise
@trace_method
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
"""Delete multiple passages from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
if not passage_ids:
return True
namespace_name = self._get_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# Use write API with deletes parameter as per Turbopuffer docs
await namespace.write(deletes=passage_ids)
logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}")
return True
except Exception as e:
logger.error(f"Failed to delete passages from Turbopuffer: {e}")
raise
@trace_method
async def delete_all_passages(self, archive_id: str) -> bool:
"""Delete all passages for an archive from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = self._get_namespace_name(archive_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# Turbopuffer has a delete_all() method on namespace
await namespace.delete_all()
logger.info(f"Successfully deleted all passages for archive {archive_id}")
return True
except Exception as e:
logger.error(f"Failed to delete all passages from Turbopuffer: {e}")
raise
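For orientation, a minimal usage sketch of the client defined above, assuming settings.use_tpuf and settings.tpuf_api_key are configured; the archive ID, passage ID, organization ID, and 1536-dim placeholder vector are illustrative only:

import asyncio

from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf


async def demo() -> None:
    if not should_use_tpuf():
        return  # Turbopuffer disabled or no API key configured
    client = TurbopufferClient()
    archive_id = "archive-demo"  # hypothetical archive ID
    embedding = [0.1] * 1536  # placeholder vector; real callers embed via the LLM client
    # dual-write contract: passage_ids must match text_chunks 1:1
    passages = await client.insert_archival_memories(
        archive_id=archive_id,
        text_chunks=["hello from archival memory"],
        embeddings=[embedding],
        passage_ids=["passage-demo-1"],
        organization_id="org-demo",
    )
    # ANN query returns (passage, similarity) tuples, where similarity = 1 - cosine distance
    for passage, score in await client.query_passages(archive_id, embedding, top_k=5):
        print(passage.id, score)
    await client.delete_passages(archive_id=archive_id, passage_ids=[p.id for p in passages])


asyncio.run(demo())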

View File

@@ -2,12 +2,13 @@ import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, List, Optional
- from sqlalchemy import JSON, Index, String
+ from sqlalchemy import JSON, Enum, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.archive import Archive as PydanticArchive
from letta.schemas.enums import VectorDBProvider
from letta.settings import DatabaseChoice, settings
if TYPE_CHECKING:
@@ -38,6 +39,12 @@ class Archive(SqlalchemyBase, OrganizationMixin):
# archive-specific fields
name: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the archive")
description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="A description of the archive")
vector_db_provider: Mapped[VectorDBProvider] = mapped_column(
Enum(VectorDBProvider),
nullable=False,
default=VectorDBProvider.NATIVE,
doc="The vector database provider used for this archive's passages",
)
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive")
# relationships

View File

@@ -3,6 +3,7 @@ from typing import Dict, Optional
from pydantic import Field
from letta.schemas.enums import VectorDBProvider
from letta.schemas.letta_base import OrmMetadataBase
@@ -12,6 +13,9 @@ class ArchiveBase(OrmMetadataBase):
name: str = Field(..., description="The name of the archive")
description: Optional[str] = Field(None, description="A description of the archive")
organization_id: str = Field(..., description="The organization this archive belongs to")
vector_db_provider: VectorDBProvider = Field(
default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages"
)
metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata")

View File

@@ -171,3 +171,10 @@ class StepStatus(str, Enum):
SUCCESS = "success"
FAILED = "failed"
CANCELLED = "cancelled"
class VectorDBProvider(str, Enum):
"""Supported vector database providers for archival memory"""
NATIVE = "native"
TPUF = "tpuf"
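Because VectorDBProvider mixes in str, members compare and serialize as plain strings, which is what lets the ORM column and the Pydantic schema above round-trip the value; a quick illustration:

from letta.schemas.enums import VectorDBProvider

assert VectorDBProvider.TPUF == "tpuf"  # str-enum members equal their string value
assert VectorDBProvider("native") is VectorDBProvider.NATIVE  # and parse back from strings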

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
from datetime import datetime
from typing import TYPE_CHECKING, List, Optional
from jinja2 import Template, TemplateSyntaxError
@@ -325,3 +326,5 @@ class RecallMemorySummary(BaseModel):
class CreateArchivalMemory(BaseModel):
text: str = Field(..., description="Text to write to archival memory.")
tags: Optional[List[str]] = Field(None, description="Optional list of tags to attach to the memory.")
created_at: Optional[datetime] = Field(None, description="Optional timestamp for the memory (defaults to current UTC time).")
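A small sketch of the extended request schema; the text and tags are made up, and the timezone-aware timestamp exercises the new retroactive-entry path (naive datetimes are assumed UTC by the Turbopuffer client):

from datetime import datetime, timezone

from letta.schemas.memory import CreateArchivalMemory

request = CreateArchivalMemory(
    text="Met Ada at the conference.",
    tags=["people", "events"],  # optional tags attached to the memory
    created_at=datetime(2025, 8, 1, tzinfo=timezone.utc),  # retroactive timestamp
)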

View File

@@ -34,7 +34,7 @@ from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaSt
from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
from letta.schemas.message import MessageCreate
- from letta.schemas.passage import Passage, PassageUpdate
+ from letta.schemas.passage import Passage
from letta.schemas.run import Run
from letta.schemas.source import Source
from letta.schemas.tool import Tool
@@ -954,22 +954,9 @@ async def create_passage(
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
- return await server.insert_archival_memory_async(agent_id=agent_id, memory_contents=request.text, actor=actor)
- @router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=list[Passage], operation_id="modify_passage")
- def modify_passage(
- agent_id: str,
- memory_id: str,
- passage: PassageUpdate = Body(...),
- server: "SyncServer" = Depends(get_letta_server),
- actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
- ):
- """
- Modify a memory in the agent's archival memory store.
- """
- actor = server.user_manager.get_user_or_default(user_id=actor_id)
- return server.modify_archival_memory(agent_id=agent_id, memory_id=memory_id, passage=passage, actor=actor)
+ return await server.insert_archival_memory_async(
+ agent_id=agent_id, memory_contents=request.text, actor=actor, tags=request.tags, created_at=request.created_at
+ )
# TODO(ethan): query or path parameter for memory_id?

View File

@@ -351,7 +351,7 @@ async def list_folder_passages(
List all passages associated with a data folder.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
- return await server.agent_manager.list_passages_async(
+ return await server.agent_manager.query_source_passages_async(
actor=actor,
source_id=folder_id,
after=after,

View File

@@ -349,7 +349,7 @@ async def list_source_passages(
List all passages associated with a data source.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
- return await server.agent_manager.list_passages_async(
+ return await server.agent_manager.query_source_passages_async(
actor=actor,
source_id=source_id,
after=after,

View File

@@ -52,7 +52,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageUpdate
- from letta.schemas.passage import Passage, PassageUpdate
+ from letta.schemas.passage import Passage
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.providers import (
AnthropicProvider,
@@ -1114,7 +1114,7 @@ class SyncServer(Server):
ascending: Optional[bool] = True,
) -> List[Passage]:
# iterate over records
- records = await self.agent_manager.list_agent_passages_async(
+ records = await self.agent_manager.query_agent_passages_async(
actor=actor,
agent_id=agent_id,
after=after,
@@ -1125,18 +1125,18 @@ class SyncServer(Server):
)
return records
- async def insert_archival_memory_async(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
+ async def insert_archival_memory_async(
+ self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
+ ) -> List[Passage]:
# Get the agent object (loaded in memory)
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
# Insert passages into the archive
# Use passage manager which handles dual-write to Turbopuffer if enabled
passages = await self.passage_manager.insert_passage(agent_state=agent_state, text=memory_contents, actor=actor)
return passages
# TODO: Add support for tags and created_at parameters
# Currently PassageManager.insert_passage doesn't support these parameters
- def modify_archival_memory(self, agent_id: str, memory_id: str, passage: PassageUpdate, actor: User) -> List[Passage]:
- passage = Passage(**passage.model_dump(exclude_unset=True, exclude_none=True))
- passages = self.passage_manager.update_passage_by_id(passage_id=memory_id, passage=passage, actor=actor)
- return passages
async def delete_archival_memory_async(self, memory_id: str, actor: User):
@@ -1270,7 +1270,7 @@ class SyncServer(Server):
await self.source_manager.delete_source(source_id=source_id, actor=actor)
# delete data from passage store
- passages_to_be_deleted = await self.agent_manager.list_passages_async(actor=actor, source_id=source_id, limit=None)
+ passages_to_be_deleted = await self.agent_manager.query_source_passages_async(actor=actor, source_id=source_id, limit=None)
await self.passage_manager.delete_source_passages_async(actor=actor, passages=passages_to_be_deleted)
# TODO: delete data from agent passage stores (?)
@@ -1316,27 +1316,6 @@ class SyncServer(Server):
async def sleeptime_document_ingest_async(
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
) -> None:
- # TEMPORARILY DISABLE UNTIL V2
- # sleeptime_agent_state = await self.create_document_sleeptime_agent_async(main_agent, source, actor, clear_history)
- # sleeptime_agent = LettaAgent(
- # agent_id=sleeptime_agent_state.id,
- # message_manager=self.message_manager,
- # agent_manager=self.agent_manager,
- # block_manager=self.block_manager,
- # job_manager=self.job_manager,
- # passage_manager=self.passage_manager,
- # actor=actor,
- # step_manager=self.step_manager,
- # telemetry_manager=self.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
- # )
- # passages = await self.agent_manager.list_passages_async(actor=actor, source_id=source.id)
- # for passage in passages:
- # await sleeptime_agent.step(
- # input_messages=[
- # MessageCreate(role="user", content=passage.text),
- # ]
- # )
- # await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor)
pass
async def _remove_file_from_agent(self, agent_id: str, file_id: str, actor: User) -> None:

View File

@@ -48,7 +48,7 @@ from letta.schemas.block import DEFAULT_BLOCKS
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
- from letta.schemas.enums import ProviderType, ToolType
+ from letta.schemas.enums import ProviderType, ToolType, VectorDBProvider
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.group import Group as PydanticGroup
from letta.schemas.group import ManagerType
@@ -67,6 +67,7 @@ from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.db import db_registry
from letta.services.archive_manager import ArchiveManager
from letta.services.block_manager import BlockManager
from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator
from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, TiktokenCounter
@@ -116,6 +117,7 @@ class AgentManager:
self.passage_manager = PassageManager()
self.identity_manager = IdentityManager()
self.file_agent_manager = FileAgentManager()
self.archive_manager = ArchiveManager()
@staticmethod
def _should_exclude_model_from_base_tool_rules(model: str) -> bool:
@@ -2520,7 +2522,20 @@ class AgentManager:
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
"""
DEPRECATED: Use query_source_passages_async or query_agent_passages_async instead.
This method is kept only for test compatibility and will be removed in a future version.
Lists all passages attached to an agent (combines both source and agent passages).
"""
import warnings
warnings.warn(
"list_passages_async is deprecated. Use query_source_passages_async or query_agent_passages_async instead.",
DeprecationWarning,
stacklevel=2,
)
async with db_registry.async_session() as session:
main_query = await build_passage_query(
actor=actor,
@@ -2567,7 +2582,7 @@ class AgentManager:
@enforce_types
@trace_method
- async def list_source_passages_async(
+ async def query_source_passages_async(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
@@ -2615,7 +2630,7 @@ class AgentManager:
@enforce_types
@trace_method
- async def list_agent_passages_async(
+ async def query_agent_passages_async(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
@@ -2630,6 +2645,46 @@ class AgentManager:
embedding_config: Optional[EmbeddingConfig] = None,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
# Check if we should use Turbopuffer for vector search
if embed_query and agent_id and query_text and embedding_config:
# Get archive IDs for the agent
archive_ids = await self.get_agent_archive_ids_async(agent_id=agent_id, actor=actor)
if archive_ids:
# TODO: Remove this restriction once we support multiple archives with mixed vector DB providers
if len(archive_ids) > 1:
raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported for vector search")
# Get archive to check vector_db_provider
archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_ids[0], actor=actor)
# Use Turbopuffer for vector search if archive is configured for TPUF
if archive.vector_db_provider == VectorDBProvider.TPUF:
from letta.helpers.tpuf_client import TurbopufferClient
from letta.llm_api.llm_client import LLMClient
# Generate embedding for query
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
query_embedding = embeddings[0]
# Query Turbopuffer
tpuf_client = TurbopufferClient()
passages_with_scores = await tpuf_client.query_passages(
archive_id=archive_ids[0],
query_embedding=query_embedding,
top_k=limit,
)
# Return just the passages (without scores)
return [passage for passage, _ in passages_with_scores]
else:
return []
# Fall back to SQL-based search for non-vector queries or NATIVE archives
async with db_registry.async_session() as session:
main_query = await build_agent_passage_query(
actor=actor,
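A hedged sketch of calling the renamed entry point; the manager, user, and agent objects are placeholders, but the parameters mirror the diff above (embed_query=True together with query_text and embedding_config is what routes a TPUF-backed archive to Turbopuffer, while NATIVE archives fall through to the SQL path):

from typing import List

from letta.schemas.passage import Passage


async def search_archival(agent_manager, user, agent) -> List[Passage]:
    # vector search: TPUF archives answer via Turbopuffer ANN, NATIVE via the SQL query builder
    return await agent_manager.query_agent_passages_async(
        actor=user,
        agent_id=agent.id,
        query_text="what did we discuss about databases?",
        embedding_config=agent.embedding_config,
        embed_query=True,
        limit=5,
    )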

View File

@@ -2,11 +2,13 @@ from typing import List, Optional
from sqlalchemy import select
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.orm import ArchivalPassage
from letta.orm import Archive as ArchiveModel
from letta.orm import ArchivesAgents
from letta.schemas.archive import Archive as PydanticArchive
from letta.schemas.enums import VectorDBProvider
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
@@ -27,10 +29,14 @@ class ArchiveManager:
"""Create a new archive."""
try:
with db_registry.session() as session:
# determine vector db provider based on settings
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
archive = ArchiveModel(
name=name,
description=description,
organization_id=actor.organization_id,
vector_db_provider=vector_db_provider,
)
archive.create(session, actor=actor)
return archive.to_pydantic()
@@ -48,10 +54,14 @@ class ArchiveManager:
"""Create a new archive."""
try:
async with db_registry.async_session() as session:
# determine vector db provider based on settings
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
archive = ArchiveModel(
name=name,
description=description,
organization_id=actor.organization_id,
vector_db_provider=vector_db_provider,
)
await archive.create_async(session, actor=actor)
return archive.to_pydantic()

View File

@@ -14,6 +14,7 @@ from letta.orm.errors import NoResultFound
from letta.orm.passage import ArchivalPassage, SourcePassage
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import VectorDBProvider
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.user import User as PydanticUser
@@ -489,6 +490,8 @@ class PassageManager:
embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)
passages = []
# Always write to SQL database first
for chunk_text, embedding in zip(text_chunks, embeddings):
passage = await self.create_agent_passage_async(
PydanticPassage(
@@ -502,6 +505,26 @@ class PassageManager:
)
passages.append(passage)
# If archive uses Turbopuffer, also write to Turbopuffer (dual-write)
if archive.vector_db_provider == VectorDBProvider.TPUF:
from letta.helpers.tpuf_client import TurbopufferClient
tpuf_client = TurbopufferClient()
# Extract IDs and texts from the created passages
passage_ids = [p.id for p in passages]
passage_texts = [p.text for p in passages]
# Insert to Turbopuffer with the same IDs as SQL
await tpuf_client.insert_archival_memories(
archive_id=archive.id,
text_chunks=passage_texts,
embeddings=embeddings,
passage_ids=passage_ids, # Use same IDs as SQL
organization_id=actor.organization_id,
created_at=passages[0].created_at if passages else None,
)
return passages
except Exception as e:
@@ -655,7 +678,20 @@ class PassageManager:
async with db_registry.async_session() as session:
try:
passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
archive_id = passage.archive_id
# Delete from SQL first
await passage.hard_delete_async(session, actor=actor)
# Check if archive uses Turbopuffer and dual-delete
if archive_id:
archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
if archive.vector_db_provider == VectorDBProvider.TPUF:
from letta.helpers.tpuf_client import TurbopufferClient
tpuf_client = TurbopufferClient()
await tpuf_client.delete_passage(archive_id=archive_id, passage_id=passage_id)
return True
except NoResultFound:
raise NoResultFound(f"Agent passage with id {passage_id} not found.")
@@ -812,12 +848,34 @@ class PassageManager:
@trace_method
async def delete_agent_passages_async(
self,
- actor: PydanticUser,
passages: List[PydanticPassage],
+ actor: PydanticUser,
) -> bool:
"""Delete multiple agent passages."""
if not passages:
return True
async with db_registry.async_session() as session:
# Delete from SQL first
await ArchivalPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
# Group passages by archive_id for efficient Turbopuffer deletion
passages_by_archive = {}
for passage in passages:
if passage.archive_id:
if passage.archive_id not in passages_by_archive:
passages_by_archive[passage.archive_id] = []
passages_by_archive[passage.archive_id].append(passage.id)
# Check each archive and delete from Turbopuffer if needed
for archive_id, passage_ids in passages_by_archive.items():
archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
if archive.vector_db_provider == VectorDBProvider.TPUF:
from letta.helpers.tpuf_client import TurbopufferClient
tpuf_client = TurbopufferClient()
await tpuf_client.delete_passages(archive_id=archive_id, passage_ids=passage_ids)
return True
@enforce_types

View File

@@ -143,7 +143,7 @@ class LettaCoreToolExecutor(ToolExecutor):
try:
# Get results using passage manager
- all_results = await AgentManager().list_agent_passages_async(
+ all_results = await AgentManager().query_agent_passages_async(
actor=actor,
agent_id=agent_state.id,
query_text=query,

View File

@@ -661,7 +661,7 @@ class LettaFileToolExecutor(ToolExecutor):
async def _search_files_traditional(self, agent_state: AgentState, query: str, limit: int) -> str:
"""Traditional search using existing passage manager."""
# Get semantic search results
- passages = await self.agent_manager.list_source_passages_async(
+ passages = await self.agent_manager.query_source_passages_async(
actor=self.actor,
agent_id=agent_state.id,
query_text=query,

View File

@@ -300,7 +300,7 @@ class Settings(BaseSettings):
# For tpuf - currently only for archival memories
use_tpuf: bool = False
tpuf_api_key: Optional[str] = None
tpuf_region: str = "gcp-us-central1.turbopuffer.com"
tpuf_region: str = "gcp-us-central1"
# File processing timeout settings
file_processing_timeout_minutes: int = 30
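A minimal sketch of flipping these settings in-process, mirroring what the test fixtures below do; deployments would normally set the equivalent environment variables instead, and the API key shown is a placeholder:

from letta.settings import settings

settings.use_tpuf = True  # opt archives created from now on into Turbopuffer
settings.tpuf_api_key = "tpuf-..."  # placeholder; never hard-code a real key
settings.tpuf_region = "gcp-us-central1"  # region name only, not the full hostname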

View File

@@ -104,6 +104,62 @@ def disable_pinecone() -> Generator[None, None, None]:
settings.pinecone_api_key = original_pinecone_api_key
@pytest.fixture
def disable_turbopuffer() -> Generator[None, None, None]:
"""
Temporarily disables Turbopuffer by setting `settings.use_tpuf` to False
and `settings.tpuf_api_key` to None for the duration of the test.
Also sets environment to DEV for testing.
Restores the original values afterward.
"""
from letta.settings import settings
original_use_tpuf = settings.use_tpuf
original_tpuf_api_key = settings.tpuf_api_key
original_environment = settings.environment
settings.use_tpuf = False
settings.tpuf_api_key = None
settings.environment = "DEV"
yield
settings.use_tpuf = original_use_tpuf
settings.tpuf_api_key = original_tpuf_api_key
settings.environment = original_environment
@pytest.fixture
def turbopuffer_mode(request) -> Generator[None, None, None]:
"""
Parametrizable fixture to enable/disable Turbopuffer mode.
Usage:
@pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
def test_function(turbopuffer_mode, ...):
# Test runs twice - once with Turbopuffer enabled, once disabled
"""
from letta.settings import settings
enable_tpuf = request.param
original_use_tpuf = settings.use_tpuf
original_tpuf_api_key = settings.tpuf_api_key
original_environment = settings.environment
# Set environment to DEV for testing
settings.environment = "DEV"
if not enable_tpuf:
# Disable Turbopuffer by setting use_tpuf to False
settings.use_tpuf = False
settings.tpuf_api_key = None
# If enable_tpuf is True, leave the original settings unchanged
yield
# Restore original settings
settings.use_tpuf = original_use_tpuf
settings.tpuf_api_key = original_tpuf_api_key
settings.environment = original_environment
@pytest.fixture
def check_e2b_key_is_set():
from letta.settings import tool_settings

View File

@@ -2670,21 +2670,18 @@ async def test_refresh_memory_async(server: SyncServer, default_user):
@pytest.mark.asyncio
- async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
+ async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
"""Test basic listing functionality of agent passages"""
all_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id)
assert len(all_passages) == 5 # 3 source + 2 agent passages
- source_passages = await server.agent_manager.list_source_passages_async(actor=default_user, agent_id=sarah_agent.id)
+ source_passages = await server.agent_manager.query_source_passages_async(actor=default_user, agent_id=sarah_agent.id)
assert len(source_passages) == 3 # 3 source passages
agent_passages = await server.agent_manager.list_agent_passages_async(actor=default_user, agent_id=sarah_agent.id)
assert len(agent_passages) == 2 # 2 agent passages
@pytest.mark.asyncio
- async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
+ async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
"""Test ordering of agent passages"""
# Test ascending order
@@ -2701,7 +2698,7 @@ async def test_agent_list_passages_ordering(server, default_user, sarah_agent, a
@pytest.mark.asyncio
- async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
+ async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
"""Test pagination of agent passages"""
# Test limit
@@ -2742,7 +2739,7 @@ async def test_agent_list_passages_pagination(server, default_user, sarah_agent,
@pytest.mark.asyncio
- async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
+ async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
"""Test text search functionality of agent passages"""
# Test text search for source passages
@@ -2759,7 +2756,7 @@ async def test_agent_list_passages_text_search(server, default_user, sarah_agent
@pytest.mark.asyncio
- async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
+ async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
"""Test text search functionality of agent passages"""
# Test text search for agent passages
@@ -2768,7 +2765,7 @@ async def test_agent_list_passages_agent_only(server, default_user, sarah_agent,
@pytest.mark.asyncio
- async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
+ async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup, disable_turbopuffer):
"""Test filtering functionality of agent passages"""
# Test source filtering
@@ -2804,7 +2801,9 @@ def mock_embed_model(mock_embeddings):
return mock_model
- async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, default_file, mock_embed_model):
+ async def test_agent_list_passages_vector_search(
+ server, default_user, sarah_agent, default_source, default_file, mock_embed_model, disable_turbopuffer
+ ):
"""Test vector search functionality of agent passages"""
embed_model = mock_embed_model

View File

@@ -0,0 +1,330 @@
import uuid
from datetime import datetime, timezone
import pytest
from letta.config import LettaConfig
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import VectorDBProvider
from letta.server.server import SyncServer
from letta.settings import settings
@pytest.fixture(scope="module")
def server():
"""Server fixture for testing"""
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=False)
return server
@pytest.fixture
async def sarah_agent(server, default_user):
"""Create a test agent named Sarah"""
from letta.schemas.agent import CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
agent = await server.agent_manager.create_agent_async(
agent_create=CreateAgent(
name="Sarah",
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=default_user,
)
yield agent
# Cleanup
try:
await server.agent_manager.delete_agent_async(agent.id, default_user)
except Exception:
pass
@pytest.fixture
def enable_turbopuffer():
"""Temporarily enable Turbopuffer for testing with a test API key"""
original_use_tpuf = settings.use_tpuf
original_api_key = settings.tpuf_api_key
original_environment = settings.environment
# Enable Turbopuffer with test key
settings.use_tpuf = True
# Use the existing tpuf_api_key if set, otherwise keep original
if not settings.tpuf_api_key:
settings.tpuf_api_key = original_api_key
# Set environment to DEV for testing
settings.environment = "DEV"
yield
# Restore original values
settings.use_tpuf = original_use_tpuf
settings.tpuf_api_key = original_api_key
settings.environment = original_environment
class TestTurbopufferIntegration:
"""Test Turbopuffer integration functionality with real connections"""
def test_should_use_tpuf_with_settings(self):
"""Test that should_use_tpuf correctly reads settings"""
# Save original values
original_use_tpuf = settings.use_tpuf
original_api_key = settings.tpuf_api_key
try:
# Test when both are set
settings.use_tpuf = True
settings.tpuf_api_key = "test-key"
assert should_use_tpuf() is True
# Test when use_tpuf is False
settings.use_tpuf = False
assert should_use_tpuf() is False
# Test when API key is missing
settings.use_tpuf = True
settings.tpuf_api_key = None
assert should_use_tpuf() is False
finally:
# Restore original values
settings.use_tpuf = original_use_tpuf
settings.tpuf_api_key = original_api_key
@pytest.mark.asyncio
async def test_archive_creation_with_tpuf_enabled(self, server, default_user, enable_turbopuffer):
"""Test that archives are created with correct vector_db_provider when TPUF is enabled"""
archive = await server.archive_manager.create_archive_async(name="Test Archive with TPUF", actor=default_user)
assert archive.vector_db_provider == VectorDBProvider.TPUF
# TODO: Add cleanup when delete_archive method is available
@pytest.mark.asyncio
async def test_archive_creation_with_tpuf_disabled(self, server, default_user, disable_turbopuffer):
"""Test that archives default to NATIVE when TPUF is disabled"""
archive = await server.archive_manager.create_archive_async(name="Test Archive without TPUF", actor=default_user)
assert archive.vector_db_provider == VectorDBProvider.NATIVE
# TODO: Add cleanup when delete_archive method is available
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
async def test_dual_write_and_query_with_real_tpuf(self, server, default_user, sarah_agent, enable_turbopuffer):
"""Test that passages are written to both SQL and Turbopuffer with real connection and can be queried"""
# Create a TPUF-enabled archive
archive = await server.archive_manager.create_archive_async(name="Test TPUF Archive for Real Dual Write", actor=default_user)
assert archive.vector_db_provider == VectorDBProvider.TPUF
# Attach the agent to the archive
await server.archive_manager.attach_agent_to_archive_async(
agent_id=sarah_agent.id, archive_id=archive.id, is_owner=True, actor=default_user
)
try:
# Insert passages - this should trigger dual write
test_passages = [
"Turbopuffer is a vector database optimized for performance.",
"This integration test verifies dual-write functionality.",
"Metadata attributes should be properly stored in Turbopuffer.",
]
for text in test_passages:
passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=text, actor=default_user)
assert passages is not None
assert len(passages) > 0
# Verify passages are in SQL - use agent_manager to list passages
sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert len(sql_passages) >= len(test_passages)
for text in test_passages:
assert any(p.text == text for p in sql_passages)
# Test vector search which should use Turbopuffer
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
# Perform vector search
vector_results = await server.agent_manager.query_agent_passages_async(
actor=default_user,
agent_id=sarah_agent.id,
query_text="turbopuffer vector database",
embedding_config=embedding_config,
embed_query=True,
limit=5,
)
# Should find relevant passages via Turbopuffer vector search
assert len(vector_results) > 0
# The most relevant result should be about Turbopuffer
assert any("Turbopuffer" in p.text or "vector" in p.text for p in vector_results)
# Test deletion - should delete from both
passage_to_delete = sql_passages[0]
await server.passage_manager.delete_agent_passages_async([passage_to_delete], default_user)
# Verify deleted from SQL
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert not any(p.id == passage_to_delete.id for p in remaining)
# Verify vector search no longer returns deleted passage
vector_results_after_delete = await server.agent_manager.query_agent_passages_async(
actor=default_user,
agent_id=sarah_agent.id,
query_text=passage_to_delete.text,
embedding_config=embedding_config,
embed_query=True,
limit=10,
)
assert not any(p.id == passage_to_delete.id for p in vector_results_after_delete)
finally:
# TODO: Clean up archive when delete_archive method is available
pass
@pytest.mark.asyncio
async def test_turbopuffer_metadata_attributes(self, enable_turbopuffer):
"""Test that Turbopuffer properly stores and retrieves metadata attributes"""
# Only run if we have a real API key
if not settings.tpuf_api_key:
pytest.skip("No Turbopuffer API key available")
client = TurbopufferClient()
archive_id = f"test-archive-{datetime.now().timestamp()}"
try:
# Insert passages with various metadata
test_data = [
{
"id": f"passage-{uuid.uuid4()}",
"text": "First test passage",
"vector": [0.1] * 1536,
"organization_id": "org-123",
"created_at": datetime.now(timezone.utc),
},
{
"id": f"passage-{uuid.uuid4()}",
"text": "Second test passage",
"vector": [0.2] * 1536,
"organization_id": "org-123",
"created_at": datetime.now(timezone.utc),
},
{
"id": f"passage-{uuid.uuid4()}",
"text": "Third test passage from different org",
"vector": [0.3] * 1536,
"organization_id": "org-456",
"created_at": datetime.now(timezone.utc),
},
]
# Insert all passages
result = await client.insert_archival_memories(
archive_id=archive_id,
text_chunks=[d["text"] for d in test_data],
embeddings=[d["vector"] for d in test_data],
passage_ids=[d["id"] for d in test_data],
organization_id="org-123", # Default org
created_at=datetime.now(timezone.utc),
)
assert len(result) == 3
# Query with organization filter
query_vector = [0.15] * 1536
results = await client.query_passages(
archive_id=archive_id, query_embedding=query_vector, top_k=10, filters={"organization_id": "org-123"}
)
# Should only get passages from org-123
assert len(results) >= 2 # At least the first two passages
for passage, score in results:
assert passage.organization_id == "org-123"
# Clean up
await client.delete_passages(archive_id=archive_id, passage_ids=[d["id"] for d in test_data])
except Exception as e:
# Clean up on error
try:
await client.delete_all_passages(archive_id)
except Exception:
pass
raise e
@pytest.mark.asyncio
async def test_native_only_operations(self, server, default_user, sarah_agent, disable_turbopuffer):
"""Test that operations work correctly when using only native PostgreSQL"""
# Create archive (should be NATIVE since turbopuffer is disabled)
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
assert archive.vector_db_provider == VectorDBProvider.NATIVE
# Insert passages - should only write to SQL
text_content = "This is a test passage for native PostgreSQL only."
passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=text_content, actor=default_user)
assert passages is not None
assert len(passages) > 0
# List passages - should work from SQL
sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert any(p.text == text_content for p in sql_passages)
# Vector search should use PostgreSQL pgvector
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
vector_results = await server.agent_manager.query_agent_passages_async(
actor=default_user,
agent_id=sarah_agent.id,
query_text="native postgresql",
embedding_config=embedding_config,
embed_query=True,
)
# Should still work with native PostgreSQL
assert isinstance(vector_results, list)
@pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
class TestTurbopufferParametrized:
"""Test that functionality works with and without Turbopuffer enabled"""
@pytest.mark.asyncio
async def test_passage_operations_with_mode(self, turbopuffer_mode, server, default_user, sarah_agent):
"""Test that passage operations work in both modes"""
# Get or create archive
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
)
# Check that vector_db_provider matches the mode
if settings.use_tpuf and settings.tpuf_api_key:
expected_provider = VectorDBProvider.TPUF
else:
expected_provider = VectorDBProvider.NATIVE
assert archive.vector_db_provider == expected_provider
# Test inserting a passage (should work in both modes)
test_text = f"Test passage for {expected_provider} mode"
passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=test_text, actor=default_user)
assert passages is not None
assert len(passages) > 0
assert passages[0].text == test_text
# List passages should work in both modes
listed = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert any(p.text == test_text for p in listed)
# Delete should work in both modes
await server.passage_manager.delete_agent_passages_async(passages, default_user)
# Verify deletion
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert not any(p.id == passages[0].id for p in remaining)