From 651671cb83895a015a1684e2971df2a9157e72ea Mon Sep 17 00:00:00 2001
From: Matthew Zhou
Date: Thu, 28 Aug 2025 10:39:16 -0700
Subject: [PATCH] feat: Support basic upload/querying on tpuf [LET-3465] (#4255)

* wip implementing turbopuffer

* Move imports up

* Add type of archive

* Integrate turbopuffer functionality

* Debug turbopuffer tests failing

* Fix turbopuffer

* Run fern

* Fix multiple heads
---
 ...dd_vector_db_provider_to_archives_table.py |  60 ++++
 letta/agents/voice_agent.py                   |   2 +-
 letta/functions/function_sets/base.py         |   2 +-
 letta/helpers/tpuf_client.py                  | 276 +++++++++++++++
 letta/orm/archive.py                          |   9 +-
 letta/schemas/archive.py                      |   4 +
 letta/schemas/enums.py                        |   7 +
 letta/schemas/memory.py                       |   3 +
 letta/server/rest_api/routers/v1/agents.py    |  21 +-
 letta/server/rest_api/routers/v1/folders.py   |   2 +-
 letta/server/rest_api/routers/v1/sources.py   |   2 +-
 letta/server/server.py                        |  39 +--
 letta/services/agent_manager.py               |  63 +++-
 letta/services/archive_manager.py             |  10 +
 letta/services/passage_manager.py             |  60 +++-
 .../tool_executor/core_tool_executor.py       |   2 +-
 .../tool_executor/files_tool_executor.py      |   2 +-
 letta/settings.py                             |   2 +-
 tests/conftest.py                             |  56 +++
 tests/test_managers.py                        |  21 +-
 tests/test_turbopuffer_integration.py         | 330 ++++++++++++++++++
 21 files changed, 902 insertions(+), 71 deletions(-)
 create mode 100644 alembic/versions/068588268b02_add_vector_db_provider_to_archives_table.py
 create mode 100644 letta/helpers/tpuf_client.py
 create mode 100644 tests/test_turbopuffer_integration.py
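Aside (not part of the patch): everything below is opt-in. The Turbopuffer path only activates when both `use_tpuf` and `tpuf_api_key` are set in Letta's settings, which the new `should_use_tpuf()` helper checks before any archive is stamped with the TPUF provider. A minimal sketch of the gating, with a hypothetical key:

    from letta.settings import settings
    from letta.helpers.tpuf_client import should_use_tpuf

    settings.use_tpuf = True
    settings.tpuf_api_key = "tpuf-secret"  # hypothetical key; normally supplied via deployment config

    # True only when both values are truthy; archives created from here on
    # get VectorDBProvider.TPUF instead of VectorDBProvider.NATIVE.
    assert should_use_tpuf()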
diff --git a/alembic/versions/068588268b02_add_vector_db_provider_to_archives_table.py b/alembic/versions/068588268b02_add_vector_db_provider_to_archives_table.py
new file mode 100644
index 00000000..f7f0dca7
--- /dev/null
+++ b/alembic/versions/068588268b02_add_vector_db_provider_to_archives_table.py
@@ -0,0 +1,60 @@
+"""Add vector_db_provider to archives table
+
+Revision ID: 068588268b02
+Revises: 887a4367b560
+Create Date: 2025-08-27 13:16:29.428231
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+from letta.settings import settings
+
+# revision identifiers, used by Alembic.
+revision: str = "068588268b02"
+down_revision: Union[str, None] = "887a4367b560"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    if settings.letta_pg_uri_no_default:
+        # PostgreSQL - use enum type
+        vectordbprovider = sa.Enum("NATIVE", "TPUF", name="vectordbprovider")
+        vectordbprovider.create(op.get_bind(), checkfirst=True)
+
+        # Add column as nullable first
+        op.add_column("archives", sa.Column("vector_db_provider", vectordbprovider, nullable=True))
+
+        # Backfill existing rows with NATIVE
+        op.execute("UPDATE archives SET vector_db_provider = 'NATIVE' WHERE vector_db_provider IS NULL")
+
+        # Make column non-nullable
+        op.alter_column("archives", "vector_db_provider", nullable=False)
+    else:
+        # SQLite - use string type
+        # Add column as nullable first
+        op.add_column("archives", sa.Column("vector_db_provider", sa.String(), nullable=True))
+
+        # Backfill existing rows with NATIVE
+        op.execute("UPDATE archives SET vector_db_provider = 'NATIVE' WHERE vector_db_provider IS NULL")
+
+        # For SQLite, we need to recreate the table to make column non-nullable
+        # This is a limitation of SQLite ALTER TABLE
+        # For simplicity, we'll leave it nullable in SQLite
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("archives", "vector_db_provider")
+
+    if settings.letta_pg_uri_no_default:
+        # Drop enum type for PostgreSQL
+        vectordbprovider = sa.Enum("NATIVE", "TPUF", name="vectordbprovider")
+        vectordbprovider.drop(op.get_bind(), checkfirst=True)
+    # ### end Alembic commands ###
diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py
index 3b00dfae..5959fed7 100644
--- a/letta/agents/voice_agent.py
+++ b/letta/agents/voice_agent.py
@@ -484,7 +484,7 @@ class VoiceAgent(BaseAgent):
         if start_date and end_date and start_date > end_date:
             start_date, end_date = end_date, start_date
 
-        archival_results = await self.agent_manager.list_passages_async(
+        archival_results = await self.agent_manager.query_agent_passages_async(
             actor=self.actor,
             agent_id=self.agent_id,
             query_text=archival_query,
diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py
index d0e2e94c..9fcb2fdb 100644
--- a/letta/functions/function_sets/base.py
+++ b/letta/functions/function_sets/base.py
@@ -107,7 +107,7 @@ async def archival_memory_search(self: "Agent", query: str, page: Optional[int]
 
     try:
         # Get results using passage manager
-        all_results = await self.agent_manager.list_passages_async(
+        all_results = await self.agent_manager.query_agent_passages_async(
            actor=self.user,
             agent_id=self.agent_state.id,
             query_text=query,
diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py
new file mode 100644
index 00000000..47edf568
--- /dev/null
+++ b/letta/helpers/tpuf_client.py
@@ -0,0 +1,276 @@
+"""Turbopuffer utilities for archival memory storage."""
+
+import logging
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional, Tuple
+
+from letta.otel.tracing import trace_method
+from letta.schemas.passage import Passage as PydanticPassage
+from letta.settings import settings
+
+logger = logging.getLogger(__name__)
+
+
+def should_use_tpuf() -> bool:
+    return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)
+
+
+class TurbopufferClient:
+    """Client for managing archival memory with Turbopuffer vector database."""
+
+    def __init__(self, api_key: Optional[str] = None, region: Optional[str] = None):
+        """Initialize Turbopuffer client."""
+        self.api_key = api_key or settings.tpuf_api_key
+        self.region = region or settings.tpuf_region
+
+        if not self.api_key:
+            raise ValueError("Turbopuffer API key not provided")
+
+    @trace_method
+    def _get_namespace_name(self, archive_id: str) -> str:
+        """Get namespace name for a specific archive."""
+        # use archive_id as namespace to isolate different archives' memories
+        # append environment suffix to namespace for isolation if environment is set
+        environment = settings.environment
+        if environment:
+            namespace_name = f"{archive_id}_{environment.lower()}"
+        else:
+            namespace_name = archive_id
+        return namespace_name
+
+    @trace_method
+    async def insert_archival_memories(
+        self,
+        archive_id: str,
+        text_chunks: List[str],
+        embeddings: List[List[float]],
+        passage_ids: List[str],
+        organization_id: str,
+        tags: Optional[List[str]] = None,
+        created_at: Optional[datetime] = None,
+    ) -> List[PydanticPassage]:
+        """Insert passages into Turbopuffer.
+
+        Args:
+            archive_id: ID of the archive
+            text_chunks: List of text chunks to store
+            embeddings: List of embedding vectors corresponding to text chunks
+            passage_ids: List of passage IDs (must match 1:1 with text_chunks)
+            organization_id: Organization ID for the passages
+            tags: Optional list of tags to attach to all passages
+            created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
+
+        Returns:
+            List of PydanticPassage objects that were inserted
+        """
+        from turbopuffer import AsyncTurbopuffer
+
+        namespace_name = self._get_namespace_name(archive_id)
+
+        # handle timestamp - ensure UTC
+        if created_at is None:
+            timestamp = datetime.now(timezone.utc)
+        else:
+            # ensure the provided timestamp is timezone-aware and in UTC
+            if created_at.tzinfo is None:
+                # assume UTC if no timezone provided
+                timestamp = created_at.replace(tzinfo=timezone.utc)
+            else:
+                # convert to UTC if in different timezone
+                timestamp = created_at.astimezone(timezone.utc)
+
+        # passage_ids must be provided for dual-write consistency
+        if not passage_ids:
+            raise ValueError("passage_ids must be provided for Turbopuffer insertion")
+        if len(passage_ids) != len(text_chunks):
+            raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
+        if len(passage_ids) != len(embeddings):
+            raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})")
+
+        # prepare column-based data for turbopuffer - optimized for batch insert
+        ids = []
+        vectors = []
+        texts = []
+        organization_ids = []
+        archive_ids = []
+        created_ats = []
+        passages = []
+
+        # prepare tag columns
+        tag_columns = {tag: [] for tag in (tags or [])}
+
+        for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
+            passage_id = passage_ids[idx]
+
+            # append to columns
+            ids.append(passage_id)
+            vectors.append(embedding)
+            texts.append(text)
+            organization_ids.append(organization_id)
+            archive_ids.append(archive_id)
+            created_ats.append(timestamp)
+
+            # append tag values
+            for tag in tag_columns:
+                tag_columns[tag].append(True)
+
+            # Create PydanticPassage object
+            passage = PydanticPassage(
+                id=passage_id,
+                text=text,
+                organization_id=organization_id,
+                archive_id=archive_id,
+                created_at=timestamp,
+                metadata_={},
+                embedding=embedding,
+                embedding_config=None,  # Will be set by caller if needed
+            )
+            passages.append(passage)
+
+        # build column-based upsert data
+        upsert_columns = {
+            "id": ids,
+            "vector": vectors,
+            "text": texts,
+            "organization_id": organization_ids,
+            "archive_id": archive_ids,
+            "created_at": created_ats,
+        }
+
+        # add tag columns if any
+        upsert_columns.update(tag_columns)
+
+        try:
+            # Use AsyncTurbopuffer as a context manager for proper resource cleanup
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # turbopuffer recommends column-based writes for performance
+                await namespace.write(upsert_columns=upsert_columns, distance_metric="cosine_distance")
+                logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}")
+                return passages
+
+        except Exception as e:
+            logger.error(f"Failed to insert passages to Turbopuffer: {e}")
+            # check if it's a duplicate ID error
+            if "duplicate" in str(e).lower():
+                logger.error("Duplicate passage IDs detected in batch")
+            raise
+
+    @trace_method
+    async def query_passages(
+        self, archive_id: str, query_embedding: List[float], top_k: int = 10, filters: Optional[Dict[str, Any]] = None
+    ) -> List[Tuple[PydanticPassage, float]]:
+        """Query passages from Turbopuffer."""
+        from turbopuffer import AsyncTurbopuffer
+
+        namespace_name = self._get_namespace_name(archive_id)
+
+        try:
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+
+                # build filter conditions
+                filter_conditions = []
+                if filters:
+                    for key, value in filters.items():
+                        filter_conditions.append((key, "Eq", value))
+
+                query_params = {
+                    "rank_by": ("vector", "ANN", query_embedding),
+                    "top_k": top_k,
+                    "include_attributes": ["text", "organization_id", "archive_id", "created_at"],
+                }
+
+                if filter_conditions:
+                    query_params["filters"] = ("And", filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0]
+
+                result = await namespace.query(**query_params)
+
+                # convert results back to passages
+                passages_with_scores = []
+                # Turbopuffer returns a NamespaceQueryResponse with a rows attribute
+                for row in result.rows:
+                    # Build metadata including any filter conditions that were applied
+                    metadata = {}
+                    if filters:
+                        metadata["applied_filters"] = filters
+
+                    # Create a passage with minimal fields - embeddings are not returned from Turbopuffer
+                    passage = PydanticPassage(
+                        id=row.id,
+                        text=getattr(row, "text", ""),
+                        organization_id=getattr(row, "organization_id", None),
+                        archive_id=archive_id,  # use the archive_id from the query
+                        created_at=getattr(row, "created_at", None),
+                        metadata_=metadata,  # Include filter conditions in metadata
+                        # Set required fields to empty/default values since we don't store embeddings
+                        embedding=[],  # Empty embedding since we don't return it from Turbopuffer
+                        embedding_config=None,  # No embedding config needed for retrieved passages
+                    )
+                    # turbopuffer returns distance in $dist attribute, convert to similarity score
+                    distance = getattr(row, "$dist", 0.0)
+                    score = 1.0 - distance
+                    passages_with_scores.append((passage, score))
+
+                return passages_with_scores
+
+        except Exception as e:
+            logger.error(f"Failed to query passages from Turbopuffer: {e}")
+            raise
+
+    @trace_method
+    async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
+        """Delete a passage from Turbopuffer."""
+        from turbopuffer import AsyncTurbopuffer
+
+        namespace_name = self._get_namespace_name(archive_id)
+
+        try:
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # Use write API with deletes parameter as per Turbopuffer docs
+                await namespace.write(deletes=[passage_id])
+                logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}")
+                return True
+        except Exception as e:
+            logger.error(f"Failed to delete passage from Turbopuffer: {e}")
+            raise
+
+    @trace_method
+    async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
+        """Delete multiple passages from Turbopuffer."""
+        from turbopuffer import AsyncTurbopuffer
+
+        if not passage_ids:
+            return True
+
+        namespace_name = self._get_namespace_name(archive_id)
+
+        try:
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # Use write API with deletes parameter as per Turbopuffer docs
+                await namespace.write(deletes=passage_ids)
+                logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}")
+                return True
+        except Exception as e:
+            logger.error(f"Failed to delete passages from Turbopuffer: {e}")
+            raise
+
+    @trace_method
+    async def delete_all_passages(self, archive_id: str) -> bool:
+        """Delete all passages for an archive from Turbopuffer."""
+        from turbopuffer import AsyncTurbopuffer
+
+        namespace_name = self._get_namespace_name(archive_id)
+
+        try:
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # Turbopuffer has a delete_all() method on namespace
+                await namespace.delete_all()
+                logger.info(f"Successfully deleted all passages for archive {archive_id}")
+                return True
+        except Exception as e:
+            logger.error(f"Failed to delete all passages from Turbopuffer: {e}")
+            raise
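Aside (not part of the patch): to make the client's contract concrete, here is a minimal usage sketch. The archive/organization IDs and the 1536-dimensional dummy vectors are hypothetical; note the 1:1 alignment between passage_ids, text_chunks, and embeddings that insert_archival_memories enforces for dual-write consistency:

    import asyncio
    from datetime import datetime, timezone

    from letta.helpers.tpuf_client import TurbopufferClient


    async def demo() -> None:
        client = TurbopufferClient()  # falls back to settings.tpuf_api_key / settings.tpuf_region

        # Insert: ids, texts, and vectors must line up index for index.
        await client.insert_archival_memories(
            archive_id="archive-demo",
            text_chunks=["The rover landed at dawn."],
            embeddings=[[0.1] * 1536],  # dummy embedding
            passage_ids=["passage-demo-1"],
            organization_id="org-demo",
            tags=["mission"],
            created_at=datetime.now(timezone.utc),
        )

        # Query: returns (Passage, score) tuples, where score = 1 - cosine distance.
        results = await client.query_passages(
            archive_id="archive-demo",
            query_embedding=[0.1] * 1536,
            top_k=5,
            filters={"organization_id": "org-demo"},
        )
        for passage, score in results:
            print(f"{score:.3f} {passage.text}")

        # Cleanup: drops every document in the archive's namespace.
        await client.delete_all_passages("archive-demo")


    asyncio.run(demo())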
diff --git a/letta/orm/archive.py b/letta/orm/archive.py
index 67badc01..e94fda40 100644
--- a/letta/orm/archive.py
+++ b/letta/orm/archive.py
@@ -2,12 +2,13 @@ import uuid
 from datetime import datetime, timezone
 from typing import TYPE_CHECKING, List, Optional
 
-from sqlalchemy import JSON, Index, String
+from sqlalchemy import JSON, Enum, Index, String
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from letta.orm.mixins import OrganizationMixin
 from letta.orm.sqlalchemy_base import SqlalchemyBase
 from letta.schemas.archive import Archive as PydanticArchive
+from letta.schemas.enums import VectorDBProvider
 from letta.settings import DatabaseChoice, settings
 
 if TYPE_CHECKING:
@@ -38,6 +39,12 @@ class Archive(SqlalchemyBase, OrganizationMixin):
     # archive-specific fields
     name: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the archive")
     description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="A description of the archive")
+    vector_db_provider: Mapped[VectorDBProvider] = mapped_column(
+        Enum(VectorDBProvider),
+        nullable=False,
+        default=VectorDBProvider.NATIVE,
+        doc="The vector database provider used for this archive's passages",
+    )
     metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive")
 
     # relationships
diff --git a/letta/schemas/archive.py b/letta/schemas/archive.py
index 965708bb..55727e92 100644
--- a/letta/schemas/archive.py
+++ b/letta/schemas/archive.py
@@ -3,6 +3,7 @@ from typing import Dict, Optional
 
 from pydantic import Field
 
+from letta.schemas.enums import VectorDBProvider
 from letta.schemas.letta_base import OrmMetadataBase
 
 
@@ -12,6 +13,9 @@ class ArchiveBase(OrmMetadataBase):
     name: str = Field(..., description="The name of the archive")
     description: Optional[str] = Field(None, description="A description of the archive")
     organization_id: str = Field(..., description="The organization this archive belongs to")
+    vector_db_provider: VectorDBProvider = Field(
+        default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages"
+    )
     metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata")
diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py
index de4a48ec..a9d1f320 100644
--- a/letta/schemas/enums.py
+++ b/letta/schemas/enums.py
@@ -171,3 +171,10 @@ class StepStatus(str, Enum):
     SUCCESS = "success"
     FAILED = "failed"
     CANCELLED = "cancelled"
+
+
+class VectorDBProvider(str, Enum):
+    """Supported vector database providers for archival memory"""
+
+    NATIVE = "native"
+    TPUF = "tpuf"
diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py
index 802f2292..ec4c7ebb 100644
--- a/letta/schemas/memory.py
+++ b/letta/schemas/memory.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+from datetime import datetime
 from typing import TYPE_CHECKING, List, Optional
 
 from jinja2 import Template, TemplateSyntaxError
@@ -325,3 +326,5 @@ class RecallMemorySummary(BaseModel):
 
 class CreateArchivalMemory(BaseModel):
     text: str = Field(..., description="Text to write to archival memory.")
+    tags: Optional[List[str]] = Field(None, description="Optional list of tags to attach to the memory.")
+    created_at: Optional[datetime] = Field(None, description="Optional timestamp for the memory (defaults to current UTC time).")
diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index 313ee0be..d9c2b3c8 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -34,7 +34,7 @@ from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaSt
 from letta.schemas.letta_response import LettaResponse
 from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
 from letta.schemas.message import MessageCreate
-from letta.schemas.passage import Passage, PassageUpdate
+from letta.schemas.passage import Passage
 from letta.schemas.run import Run
 from letta.schemas.source import Source
 from letta.schemas.tool import Tool
@@ -954,22 +954,9 @@ async def create_passage(
     """
     actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
 
-    return await server.insert_archival_memory_async(agent_id=agent_id, memory_contents=request.text, actor=actor)
-
-
-@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=list[Passage], operation_id="modify_passage")
-def modify_passage(
-    agent_id: str,
-    memory_id: str,
-    passage: PassageUpdate = Body(...),
-    server: "SyncServer" = Depends(get_letta_server),
-    actor_id: str | None = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
-):
-    """
-    Modify a memory in the agent's archival memory store.
-    """
-    actor = server.user_manager.get_user_or_default(user_id=actor_id)
-    return server.modify_archival_memory(agent_id=agent_id, memory_id=memory_id, passage=passage, actor=actor)
+    return await server.insert_archival_memory_async(
+        agent_id=agent_id, memory_contents=request.text, actor=actor, tags=request.tags, created_at=request.created_at
+    )
 
 
 # TODO(ethan): query or path parameter for memory_id?
""" actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return await server.agent_manager.list_passages_async( + return await server.agent_manager.query_source_passages_async( actor=actor, source_id=folder_id, after=after, diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 3bfbb9d2..c9d55407 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -349,7 +349,7 @@ async def list_source_passages( List all passages associated with a data source. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return await server.agent_manager.list_passages_async( + return await server.agent_manager.query_source_passages_async( actor=actor, source_id=source_id, after=after, diff --git a/letta/server/server.py b/letta/server/server.py index 358f9506..1f2a9225 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -52,7 +52,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary from letta.schemas.message import Message, MessageCreate, MessageUpdate -from letta.schemas.passage import Passage, PassageUpdate +from letta.schemas.passage import Passage from letta.schemas.pip_requirement import PipRequirement from letta.schemas.providers import ( AnthropicProvider, @@ -1114,7 +1114,7 @@ class SyncServer(Server): ascending: Optional[bool] = True, ) -> List[Passage]: # iterate over records - records = await self.agent_manager.list_agent_passages_async( + records = await self.agent_manager.query_agent_passages_async( actor=actor, agent_id=agent_id, after=after, @@ -1125,18 +1125,18 @@ class SyncServer(Server): ) return records - async def insert_archival_memory_async(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]: + async def insert_archival_memory_async( + self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime] + ) -> List[Passage]: # Get the agent object (loaded in memory) agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) - # Insert passages into the archive + # Use passage manager which handles dual-write to Turbopuffer if enabled passages = await self.passage_manager.insert_passage(agent_state=agent_state, text=memory_contents, actor=actor) - return passages + # TODO: Add support for tags and created_at parameters + # Currently PassageManager.insert_passage doesn't support these parameters - def modify_archival_memory(self, agent_id: str, memory_id: str, passage: PassageUpdate, actor: User) -> List[Passage]: - passage = Passage(**passage.model_dump(exclude_unset=True, exclude_none=True)) - passages = self.passage_manager.update_passage_by_id(passage_id=memory_id, passage=passage, actor=actor) return passages async def delete_archival_memory_async(self, memory_id: str, actor: User): @@ -1270,7 +1270,7 @@ class SyncServer(Server): await self.source_manager.delete_source(source_id=source_id, actor=actor) # delete data from passage store - passages_to_be_deleted = await self.agent_manager.list_passages_async(actor=actor, source_id=source_id, limit=None) + passages_to_be_deleted = await self.agent_manager.query_source_passages_async(actor=actor, source_id=source_id, limit=None) await self.passage_manager.delete_source_passages_async(actor=actor, 
diff --git a/letta/server/server.py b/letta/server/server.py
index 358f9506..1f2a9225 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -52,7 +52,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
 from letta.schemas.message import Message, MessageCreate, MessageUpdate
-from letta.schemas.passage import Passage, PassageUpdate
+from letta.schemas.passage import Passage
 from letta.schemas.pip_requirement import PipRequirement
 from letta.schemas.providers import (
     AnthropicProvider,
@@ -1114,7 +1114,7 @@ class SyncServer(Server):
         ascending: Optional[bool] = True,
     ) -> List[Passage]:
         # iterate over records
-        records = await self.agent_manager.list_agent_passages_async(
+        records = await self.agent_manager.query_agent_passages_async(
             actor=actor,
             agent_id=agent_id,
             after=after,
@@ -1125,18 +1125,18 @@ class SyncServer(Server):
         )
         return records
 
-    async def insert_archival_memory_async(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
+    async def insert_archival_memory_async(
+        self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
+    ) -> List[Passage]:
         # Get the agent object (loaded in memory)
         agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
 
-        # Insert passages into the archive
+        # Use passage manager which handles dual-write to Turbopuffer if enabled
         passages = await self.passage_manager.insert_passage(agent_state=agent_state, text=memory_contents, actor=actor)
-        return passages
+        # TODO: Add support for tags and created_at parameters
+        # Currently PassageManager.insert_passage doesn't support these parameters
 
-    def modify_archival_memory(self, agent_id: str, memory_id: str, passage: PassageUpdate, actor: User) -> List[Passage]:
-        passage = Passage(**passage.model_dump(exclude_unset=True, exclude_none=True))
-        passages = self.passage_manager.update_passage_by_id(passage_id=memory_id, passage=passage, actor=actor)
         return passages
 
     async def delete_archival_memory_async(self, memory_id: str, actor: User):
@@ -1270,7 +1270,7 @@ class SyncServer(Server):
         await self.source_manager.delete_source(source_id=source_id, actor=actor)
 
         # delete data from passage store
-        passages_to_be_deleted = await self.agent_manager.list_passages_async(actor=actor, source_id=source_id, limit=None)
+        passages_to_be_deleted = await self.agent_manager.query_source_passages_async(actor=actor, source_id=source_id, limit=None)
         await self.passage_manager.delete_source_passages_async(actor=actor, passages=passages_to_be_deleted)
 
         # TODO: delete data from agent passage stores (?)
@@ -1316,27 +1316,6 @@ class SyncServer(Server):
     async def sleeptime_document_ingest_async(
         self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
     ) -> None:
-        # TEMPORARILY DISABLE UNTIL V2
-        # sleeptime_agent_state = await self.create_document_sleeptime_agent_async(main_agent, source, actor, clear_history)
-        # sleeptime_agent = LettaAgent(
-        #     agent_id=sleeptime_agent_state.id,
-        #     message_manager=self.message_manager,
-        #     agent_manager=self.agent_manager,
-        #     block_manager=self.block_manager,
-        #     job_manager=self.job_manager,
-        #     passage_manager=self.passage_manager,
-        #     actor=actor,
-        #     step_manager=self.step_manager,
-        #     telemetry_manager=self.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
-        # )
-        # passages = await self.agent_manager.list_passages_async(actor=actor, source_id=source.id)
-        # for passage in passages:
-        #     await sleeptime_agent.step(
-        #         input_messages=[
-        #             MessageCreate(role="user", content=passage.text),
-        #         ]
-        #     )
-        # await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor)
         pass
 
     async def _remove_file_from_agent(self, agent_id: str, file_id: str, actor: User) -> None:
diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py
index 46d25e16..6a772b1f 100644
--- a/letta/services/agent_manager.py
+++ b/letta/services/agent_manager.py
@@ -48,7 +48,7 @@ from letta.schemas.block import DEFAULT_BLOCKS
 from letta.schemas.block import Block as PydanticBlock
 from letta.schemas.block import BlockUpdate
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import ProviderType, ToolType
+from letta.schemas.enums import ProviderType, ToolType, VectorDBProvider
 from letta.schemas.file import FileMetadata as PydanticFileMetadata
 from letta.schemas.group import Group as PydanticGroup
 from letta.schemas.group import ManagerType
@@ -67,6 +67,7 @@ from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
 from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
 from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
 from letta.server.db import db_registry
+from letta.services.archive_manager import ArchiveManager
 from letta.services.block_manager import BlockManager
 from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator
 from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, TiktokenCounter
@@ -116,6 +117,7 @@ class AgentManager:
         self.passage_manager = PassageManager()
         self.identity_manager = IdentityManager()
         self.file_agent_manager = FileAgentManager()
+        self.archive_manager = ArchiveManager()
 
     @staticmethod
     def _should_exclude_model_from_base_tool_rules(model: str) -> bool:
@@ -2520,7 +2522,20 @@ class AgentManager:
         embedding_config: Optional[EmbeddingConfig] = None,
         agent_only: bool = False,
     ) -> List[PydanticPassage]:
-        """Lists all passages attached to an agent."""
+        """
+        DEPRECATED: Use query_source_passages_async or query_agent_passages_async instead.
+        This method is kept only for test compatibility and will be removed in a future version.
+
+        Lists all passages attached to an agent (combines both source and agent passages).
+        """
+        import warnings
+
+        warnings.warn(
+            "list_passages_async is deprecated. Use query_source_passages_async or query_agent_passages_async instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
         async with db_registry.async_session() as session:
             main_query = await build_passage_query(
                 actor=actor,
@@ -2567,7 +2582,7 @@ class AgentManager:
 
     @enforce_types
     @trace_method
-    async def list_source_passages_async(
+    async def query_source_passages_async(
         self,
         actor: PydanticUser,
         agent_id: Optional[str] = None,
@@ -2615,7 +2630,7 @@ class AgentManager:
 
     @enforce_types
     @trace_method
-    async def list_agent_passages_async(
+    async def query_agent_passages_async(
         self,
         actor: PydanticUser,
         agent_id: Optional[str] = None,
@@ -2630,6 +2645,46 @@ class AgentManager:
         embedding_config: Optional[EmbeddingConfig] = None,
     ) -> List[PydanticPassage]:
         """Lists all passages attached to an agent."""
+        # Check if we should use Turbopuffer for vector search
+        if embed_query and agent_id and query_text and embedding_config:
+            # Get archive IDs for the agent
+            archive_ids = await self.get_agent_archive_ids_async(agent_id=agent_id, actor=actor)
+
+            if archive_ids:
+                # TODO: Remove this restriction once we support multiple archives with mixed vector DB providers
+                if len(archive_ids) > 1:
+                    raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported for vector search")
+
+                # Get archive to check vector_db_provider
+                archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_ids[0], actor=actor)
+
+                # Use Turbopuffer for vector search if archive is configured for TPUF
+                if archive.vector_db_provider == VectorDBProvider.TPUF:
+                    from letta.helpers.tpuf_client import TurbopufferClient
+                    from letta.llm_api.llm_client import LLMClient
+
+                    # Generate embedding for query
+                    embedding_client = LLMClient.create(
+                        provider_type=embedding_config.embedding_endpoint_type,
+                        actor=actor,
+                    )
+                    embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
+                    query_embedding = embeddings[0]
+
+                    # Query Turbopuffer
+                    tpuf_client = TurbopufferClient()
+                    passages_with_scores = await tpuf_client.query_passages(
+                        archive_id=archive_ids[0],
+                        query_embedding=query_embedding,
+                        top_k=limit,
+                    )
+
+                    # Return just the passages (without scores)
+                    return [passage for passage, _ in passages_with_scores]
+            else:
+                return []
+
+        # Fall back to SQL-based search for non-vector queries or NATIVE archives
         async with db_registry.async_session() as session:
             main_query = await build_agent_passage_query(
                 actor=actor,
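Aside (not part of the patch): the renamed query_agent_passages_async now has two behaviors worth calling out, sketched below with assumed agent/user objects. A vector query (embed_query=True plus query_text and an embedding_config) against a TPUF-backed archive embeds the query text and is answered by Turbopuffer; every other call, including plain listing, falls through to the existing SQL path:

    # Vector search: for a TPUF archive this is served by Turbopuffer, with
    # `limit` mapped to the vector search's top_k.
    results = await agent_manager.query_agent_passages_async(
        actor=user,
        agent_id=agent.id,
        query_text="notes about the turbopuffer rollout",
        embedding_config=agent.embedding_config,
        embed_query=True,
        limit=5,
    )

    # Plain listing: no embedding requested, so this always goes to SQL,
    # whatever the archive's vector_db_provider is.
    recent = await agent_manager.query_agent_passages_async(
        actor=user,
        agent_id=agent.id,
        limit=10,
    )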
diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py
index 86a9e546..e547ac07 100644
--- a/letta/services/archive_manager.py
+++ b/letta/services/archive_manager.py
@@ -2,11 +2,13 @@ from typing import List, Optional
 
 from sqlalchemy import select
 
+from letta.helpers.tpuf_client import should_use_tpuf
 from letta.log import get_logger
 from letta.orm import ArchivalPassage
 from letta.orm import Archive as ArchiveModel
 from letta.orm import ArchivesAgents
 from letta.schemas.archive import Archive as PydanticArchive
+from letta.schemas.enums import VectorDBProvider
 from letta.schemas.user import User as PydanticUser
 from letta.server.db import db_registry
 from letta.utils import enforce_types
@@ -27,10 +29,14 @@ class ArchiveManager:
         """Create a new archive."""
         try:
             with db_registry.session() as session:
+                # determine vector db provider based on settings
+                vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
+
                 archive = ArchiveModel(
                     name=name,
                     description=description,
                     organization_id=actor.organization_id,
+                    vector_db_provider=vector_db_provider,
                 )
                 archive.create(session, actor=actor)
                 return archive.to_pydantic()
@@ -48,10 +54,14 @@ class ArchiveManager:
         """Create a new archive."""
         try:
             async with db_registry.async_session() as session:
+                # determine vector db provider based on settings
+                vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
+
                 archive = ArchiveModel(
                     name=name,
                     description=description,
                     organization_id=actor.organization_id,
+                    vector_db_provider=vector_db_provider,
                 )
                 await archive.create_async(session, actor=actor)
                 return archive.to_pydantic()
diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py
index 39cd07ce..7a9e825e 100644
--- a/letta/services/passage_manager.py
+++ b/letta/services/passage_manager.py
@@ -14,6 +14,7 @@ from letta.orm.errors import NoResultFound
 from letta.orm.passage import ArchivalPassage, SourcePassage
 from letta.otel.tracing import trace_method
 from letta.schemas.agent import AgentState
+from letta.schemas.enums import VectorDBProvider
 from letta.schemas.file import FileMetadata as PydanticFileMetadata
 from letta.schemas.passage import Passage as PydanticPassage
 from letta.schemas.user import User as PydanticUser
@@ -489,6 +490,8 @@ class PassageManager:
             embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)
 
             passages = []
+
+            # Always write to SQL database first
             for chunk_text, embedding in zip(text_chunks, embeddings):
                 passage = await self.create_agent_passage_async(
                     PydanticPassage(
@@ -502,6 +505,26 @@ class PassageManager:
                 )
                 passages.append(passage)
 
+            # If archive uses Turbopuffer, also write to Turbopuffer (dual-write)
+            if archive.vector_db_provider == VectorDBProvider.TPUF:
+                from letta.helpers.tpuf_client import TurbopufferClient
+
+                tpuf_client = TurbopufferClient()
+
+                # Extract IDs and texts from the created passages
+                passage_ids = [p.id for p in passages]
+                passage_texts = [p.text for p in passages]
+
+                # Insert to Turbopuffer with the same IDs as SQL
+                await tpuf_client.insert_archival_memories(
+                    archive_id=archive.id,
+                    text_chunks=passage_texts,
+                    embeddings=embeddings,
+                    passage_ids=passage_ids,  # Use same IDs as SQL
+                    organization_id=actor.organization_id,
+                    created_at=passages[0].created_at if passages else None,
+                )
+
             return passages
 
         except Exception as e:
@@ -655,7 +678,20 @@ class PassageManager:
         async with db_registry.async_session() as session:
             try:
                 passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                archive_id = passage.archive_id
+
+                # Delete from SQL first
                 await passage.hard_delete_async(session, actor=actor)
+
+                # Check if archive uses Turbopuffer and dual-delete
+                if archive_id:
+                    archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
+                    if archive.vector_db_provider == VectorDBProvider.TPUF:
+                        from letta.helpers.tpuf_client import TurbopufferClient
+
+                        tpuf_client = TurbopufferClient()
+                        await tpuf_client.delete_passage(archive_id=archive_id, passage_id=passage_id)
+
                 return True
             except NoResultFound:
                 raise NoResultFound(f"Agent passage with id {passage_id} not found.")
@@ -812,12 +848,34 @@ class PassageManager:
     @trace_method
     async def delete_agent_passages_async(
         self,
-        actor: PydanticUser,
         passages: List[PydanticPassage],
+        actor: PydanticUser,
     ) -> bool:
         """Delete multiple agent passages."""
+        if not passages:
+            return True
+
         async with db_registry.async_session() as session:
+            # Delete from SQL first
             await ArchivalPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
+
+            # Group passages by archive_id for efficient Turbopuffer deletion
+            passages_by_archive = {}
+            for passage in passages:
+                if passage.archive_id:
+                    if passage.archive_id not in passages_by_archive:
+                        passages_by_archive[passage.archive_id] = []
+                    passages_by_archive[passage.archive_id].append(passage.id)
+
+            # Check each archive and delete from Turbopuffer if needed
+            for archive_id, passage_ids in passages_by_archive.items():
+                archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
+                if archive.vector_db_provider == VectorDBProvider.TPUF:
+                    from letta.helpers.tpuf_client import TurbopufferClient
+
+                    tpuf_client = TurbopufferClient()
+                    await tpuf_client.delete_passages(archive_id=archive_id, passage_ids=passage_ids)
+
         return True
 
     @enforce_types
diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py
index 35da4c28..1bd01a6a 100644
--- a/letta/services/tool_executor/core_tool_executor.py
+++ b/letta/services/tool_executor/core_tool_executor.py
@@ -143,7 +143,7 @@ class LettaCoreToolExecutor(ToolExecutor):
 
         try:
             # Get results using passage manager
-            all_results = await AgentManager().list_agent_passages_async(
+            all_results = await AgentManager().query_agent_passages_async(
                 actor=actor,
                 agent_id=agent_state.id,
                 query_text=query,
diff --git a/letta/services/tool_executor/files_tool_executor.py b/letta/services/tool_executor/files_tool_executor.py
index 047a7072..c3daa9c4 100644
--- a/letta/services/tool_executor/files_tool_executor.py
+++ b/letta/services/tool_executor/files_tool_executor.py
@@ -661,7 +661,7 @@ class LettaFileToolExecutor(ToolExecutor):
     async def _search_files_traditional(self, agent_state: AgentState, query: str, limit: int) -> str:
         """Traditional search using existing passage manager."""
         # Get semantic search results
-        passages = await self.agent_manager.list_source_passages_async(
+        passages = await self.agent_manager.query_source_passages_async(
             actor=self.actor,
             agent_id=agent_state.id,
             query_text=query,
diff --git a/letta/settings.py b/letta/settings.py
index 87a87534..6f9c3531 100644
--- a/letta/settings.py
+++ b/letta/settings.py
@@ -300,7 +300,7 @@ class Settings(BaseSettings):
     # For tpuf - currently only for archival memories
     use_tpuf: bool = False
     tpuf_api_key: Optional[str] = None
-    tpuf_region: str = "gcp-us-central1.turbopuffer.com"
+    tpuf_region: str = "gcp-us-central1"
 
     # File processing timeout settings
     file_processing_timeout_minutes: int = 30
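Aside (not part of the patch): taken together, the manager changes above give the dual-write flow a single entry point. A sketch of the end-to-end path, assuming an initialized server, an agent_state, and a user from the usual fixtures:

    # SQL rows are created first; if the agent's archive has
    # vector_db_provider == TPUF, the same passages are then mirrored to
    # Turbopuffer under identical IDs.
    passages = await server.passage_manager.insert_passage(
        agent_state=agent_state,
        text="Shipped the Turbopuffer integration today.",
        actor=user,
    )

    # Deletion follows the same order: SQL first, then the matching
    # Turbopuffer documents, grouped per archive.
    await server.passage_manager.delete_agent_passages_async(passages, user)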
diff --git a/tests/conftest.py b/tests/conftest.py
index 751ec806..c88175a0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -104,6 +104,62 @@ def disable_pinecone() -> Generator[None, None, None]:
         settings.pinecone_api_key = original_pinecone_api_key
+
+
+@pytest.fixture
+def disable_turbopuffer() -> Generator[None, None, None]:
+    """
+    Temporarily disables Turbopuffer by setting `settings.use_tpuf` to False
+    and `settings.tpuf_api_key` to None for the duration of the test.
+    Also sets environment to DEV for testing.
+    Restores the original values afterward.
+    """
+    from letta.settings import settings
+
+    original_use_tpuf = settings.use_tpuf
+    original_tpuf_api_key = settings.tpuf_api_key
+    original_environment = settings.environment
+    settings.use_tpuf = False
+    settings.tpuf_api_key = None
+    settings.environment = "DEV"
+    yield
+    settings.use_tpuf = original_use_tpuf
+    settings.tpuf_api_key = original_tpuf_api_key
+    settings.environment = original_environment
+
+
+@pytest.fixture
+def turbopuffer_mode(request) -> Generator[None, None, None]:
+    """
+    Parametrizable fixture to enable/disable Turbopuffer mode.
+
+    Usage:
+        @pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
+        def test_function(turbopuffer_mode, ...):
+            # Test runs twice - once with Turbopuffer enabled, once disabled
+    """
+    from letta.settings import settings
+
+    enable_tpuf = request.param
+    original_use_tpuf = settings.use_tpuf
+    original_tpuf_api_key = settings.tpuf_api_key
+    original_environment = settings.environment
+
+    # Set environment to DEV for testing
+    settings.environment = "DEV"
+
+    if not enable_tpuf:
+        # Disable Turbopuffer by setting use_tpuf to False
+        settings.use_tpuf = False
+        settings.tpuf_api_key = None
+    # If enable_tpuf is True, leave the original settings unchanged
+
+    yield
+
+    # Restore original settings
+    settings.use_tpuf = original_use_tpuf
+    settings.tpuf_api_key = original_tpuf_api_key
+    settings.environment = original_environment
+
+
 @pytest.fixture
 def check_e2b_key_is_set():
     from letta.settings import tool_settings
diff --git a/tests/test_managers.py b/tests/test_managers.py
index fd2f32f0..5a15f6e4 100644
--- a/tests/test_managers.py
+++ b/tests/test_managers.py
@@ -2670,21 +2670,18 @@ async def test_refresh_memory_async(server: SyncServer, default_user):
 
 
 @pytest.mark.asyncio
-async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test basic listing functionality of agent passages"""
 
     all_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id)
     assert len(all_passages) == 5  # 3 source + 2 agent passages
 
-    source_passages = await server.agent_manager.list_source_passages_async(actor=default_user, agent_id=sarah_agent.id)
+    source_passages = await server.agent_manager.query_source_passages_async(actor=default_user, agent_id=sarah_agent.id)
     assert len(source_passages) == 3  # 3 source + 2 agent passages
 
-    agent_passages = await server.agent_manager.list_agent_passages_async(actor=default_user, agent_id=sarah_agent.id)
-    assert len(agent_passages) == 2  # 3 source + 2 agent passages
-
 
 @pytest.mark.asyncio
-async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test ordering of agent passages"""
 
     # Test ascending order
@@ -2701,7 +2698,7 @@ async def test_agent_list_passages_ordering(server, default_user, sarah_agent, a
 
 
 @pytest.mark.asyncio
-async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test pagination of agent passages"""
 
     # Test limit
@@ -2742,7 +2739,7 @@ async def test_agent_list_passages_pagination(server, default_user, sarah_agent,
 
 
 @pytest.mark.asyncio
-async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test text search functionality of agent passages"""
 
     # Test text search for source passages
@@ -2759,7 +2756,7 @@ async def test_agent_list_passages_text_search(server, default_user, sarah_agent
 
 
 @pytest.mark.asyncio
-async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test text search functionality of agent passages"""
 
     # Test text search for agent passages
@@ -2768,7 +2765,7 @@ async def test_agent_list_passages_agent_only(server, default_user, sarah_agent,
 
 
 @pytest.mark.asyncio
-async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
+async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup, disable_turbopuffer):
     """Test filtering functionality of agent passages"""
 
     # Test source filtering
@@ -2804,7 +2801,9 @@ def mock_embed_model(mock_embeddings):
     return mock_model
 
 
-async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, default_file, mock_embed_model):
+async def test_agent_list_passages_vector_search(
+    server, default_user, sarah_agent, default_source, default_file, mock_embed_model, disable_turbopuffer
+):
     """Test vector search functionality of agent passages"""
 
     embed_model = mock_embed_model
diff --git a/tests/test_turbopuffer_integration.py b/tests/test_turbopuffer_integration.py
new file mode 100644
index 00000000..201bae6f
--- /dev/null
+++ b/tests/test_turbopuffer_integration.py
@@ -0,0 +1,330 @@
+import uuid
+from datetime import datetime, timezone
+
+import pytest
+
+from letta.config import LettaConfig
+from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import VectorDBProvider
+from letta.server.server import SyncServer
+from letta.settings import settings
+
+
+@pytest.fixture(scope="module")
+def server():
+    """Server fixture for testing"""
+    config = LettaConfig.load()
+    config.save()
+    server = SyncServer(init_with_default_org_and_user=False)
+    return server
+
+
+@pytest.fixture
+async def sarah_agent(server, default_user):
+    """Create a test agent named Sarah"""
+    from letta.schemas.agent import CreateAgent
+    from letta.schemas.embedding_config import EmbeddingConfig
+    from letta.schemas.llm_config import LLMConfig
+
+    agent = await server.agent_manager.create_agent_async(
+        agent_create=CreateAgent(
+            name="Sarah",
+            memory_blocks=[],
+            llm_config=LLMConfig.default_config("gpt-4o-mini"),
+            embedding_config=EmbeddingConfig.default_config(provider="openai"),
+            include_base_tools=False,
+        ),
+        actor=default_user,
+    )
+    yield agent
+    # Cleanup
+    try:
+        await server.agent_manager.delete_agent_async(agent.id, default_user)
+    except:
+        pass
+
+
+@pytest.fixture
+def enable_turbopuffer():
+    """Temporarily enable Turbopuffer for testing with a test API key"""
+    original_use_tpuf = settings.use_tpuf
+    original_api_key = settings.tpuf_api_key
+    original_environment = settings.environment
+
+    # Enable Turbopuffer with test key
+    settings.use_tpuf = True
+    # Use the existing tpuf_api_key if set, otherwise keep original
+    if not settings.tpuf_api_key:
+        settings.tpuf_api_key = original_api_key
+    # Set environment to DEV for testing
+    settings.environment = "DEV"
+
+    yield
+
+    # Restore original values
+    settings.use_tpuf = original_use_tpuf
+    settings.tpuf_api_key = original_api_key
+    settings.environment = original_environment
+
+
+class TestTurbopufferIntegration:
+    """Test Turbopuffer integration functionality with real connections"""
+
+    def test_should_use_tpuf_with_settings(self):
+        """Test that should_use_tpuf correctly reads settings"""
+        # Save original values
+        original_use_tpuf = settings.use_tpuf
+        original_api_key = settings.tpuf_api_key
+
+        try:
+            # Test when both are set
+            settings.use_tpuf = True
+            settings.tpuf_api_key = "test-key"
+            assert should_use_tpuf() is True
+
+            # Test when use_tpuf is False
+            settings.use_tpuf = False
+            assert should_use_tpuf() is False
+
+            # Test when API key is missing
+            settings.use_tpuf = True
+            settings.tpuf_api_key = None
+            assert should_use_tpuf() is False
+        finally:
+            # Restore original values
+            settings.use_tpuf = original_use_tpuf
+            settings.tpuf_api_key = original_api_key
+
+    @pytest.mark.asyncio
+    async def test_archive_creation_with_tpuf_enabled(self, server, default_user, enable_turbopuffer):
+        """Test that archives are created with correct vector_db_provider when TPUF is enabled"""
+        archive = await server.archive_manager.create_archive_async(name="Test Archive with TPUF", actor=default_user)
+        assert archive.vector_db_provider == VectorDBProvider.TPUF
+        # TODO: Add cleanup when delete_archive method is available
+
+    @pytest.mark.asyncio
+    async def test_archive_creation_with_tpuf_disabled(self, server, default_user, disable_turbopuffer):
+        """Test that archives default to NATIVE when TPUF is disabled"""
+        archive = await server.archive_manager.create_archive_async(name="Test Archive without TPUF", actor=default_user)
+        assert archive.vector_db_provider == VectorDBProvider.NATIVE
+        # TODO: Add cleanup when delete_archive method is available
+
+    @pytest.mark.asyncio
+    @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
+    async def test_dual_write_and_query_with_real_tpuf(self, server, default_user, sarah_agent, enable_turbopuffer):
+        """Test that passages are written to both SQL and Turbopuffer with real connection and can be queried"""
+
+        # Create a TPUF-enabled archive
+        archive = await server.archive_manager.create_archive_async(name="Test TPUF Archive for Real Dual Write", actor=default_user)
+        assert archive.vector_db_provider == VectorDBProvider.TPUF
+
+        # Attach the agent to the archive
+        await server.archive_manager.attach_agent_to_archive_async(
+            agent_id=sarah_agent.id, archive_id=archive.id, is_owner=True, actor=default_user
+        )
+
+        try:
+            # Insert passages - this should trigger dual write
+            test_passages = [
+                "Turbopuffer is a vector database optimized for performance.",
+                "This integration test verifies dual-write functionality.",
+                "Metadata attributes should be properly stored in Turbopuffer.",
+            ]
+
+            for text in test_passages:
+                passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=text, actor=default_user)
+                assert passages is not None
+                assert len(passages) > 0
+
+            # Verify passages are in SQL - use agent_manager to list passages
+            sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
+            assert len(sql_passages) >= len(test_passages)
+            for text in test_passages:
+                assert any(p.text == text for p in sql_passages)
+
+            # Test vector search which should use Turbopuffer
+            embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
+
+            # Perform vector search
+            vector_results = await server.agent_manager.query_agent_passages_async(
+                actor=default_user,
+                agent_id=sarah_agent.id,
+                query_text="turbopuffer vector database",
+                embedding_config=embedding_config,
+                embed_query=True,
+                limit=5,
+            )
+
+            # Should find relevant passages via Turbopuffer vector search
+            assert len(vector_results) > 0
+            # The most relevant result should be about Turbopuffer
+            assert any("Turbopuffer" in p.text or "vector" in p.text for p in vector_results)
+
+            # Test deletion - should delete from both
+            passage_to_delete = sql_passages[0]
+            await server.passage_manager.delete_agent_passages_async([passage_to_delete], default_user)
+
+            # Verify deleted from SQL
+            remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
+            assert not any(p.id == passage_to_delete.id for p in remaining)
+
+            # Verify vector search no longer returns deleted passage
+            vector_results_after_delete = await server.agent_manager.query_agent_passages_async(
+                actor=default_user,
+                agent_id=sarah_agent.id,
+                query_text=passage_to_delete.text,
+                embedding_config=embedding_config,
+                embed_query=True,
+                limit=10,
+            )
+            assert not any(p.id == passage_to_delete.id for p in vector_results_after_delete)
+
+        finally:
+            # TODO: Clean up archive when delete_archive method is available
+            pass
+
+    @pytest.mark.asyncio
+    async def test_turbopuffer_metadata_attributes(self, enable_turbopuffer):
+        """Test that Turbopuffer properly stores and retrieves metadata attributes"""
+
+        # Only run if we have a real API key
+        if not settings.tpuf_api_key:
+            pytest.skip("No Turbopuffer API key available")
+
+        client = TurbopufferClient()
+        archive_id = f"test-archive-{datetime.now().timestamp()}"
+
+        try:
+            # Insert passages with various metadata
+            test_data = [
+                {
+                    "id": f"passage-{uuid.uuid4()}",
+                    "text": "First test passage",
+                    "vector": [0.1] * 1536,
+                    "organization_id": "org-123",
+                    "created_at": datetime.now(timezone.utc),
+                },
+                {
+                    "id": f"passage-{uuid.uuid4()}",
+                    "text": "Second test passage",
+                    "vector": [0.2] * 1536,
+                    "organization_id": "org-123",
+                    "created_at": datetime.now(timezone.utc),
+                },
+                {
+                    "id": f"passage-{uuid.uuid4()}",
+                    "text": "Third test passage from different org",
+                    "vector": [0.3] * 1536,
+                    "organization_id": "org-456",
+                    "created_at": datetime.now(timezone.utc),
+                },
+            ]
+
+            # Insert all passages
+            result = await client.insert_archival_memories(
+                archive_id=archive_id,
+                text_chunks=[d["text"] for d in test_data],
+                embeddings=[d["vector"] for d in test_data],
+                passage_ids=[d["id"] for d in test_data],
+                organization_id="org-123",  # Default org
+                created_at=datetime.now(timezone.utc),
+            )
+
+            assert len(result) == 3
+
+            # Query with organization filter
+            query_vector = [0.15] * 1536
+            results = await client.query_passages(
+                archive_id=archive_id, query_embedding=query_vector, top_k=10, filters={"organization_id": "org-123"}
+            )
+
+            # Should only get passages from org-123
+            assert len(results) >= 2  # At least the first two passages
+            for passage, score in results:
+                assert passage.organization_id == "org-123"
+
+            # Clean up
+            await client.delete_passages(archive_id=archive_id, passage_ids=[d["id"] for d in test_data])
+
+        except Exception as e:
+            # Clean up on error
+            try:
+                await client.delete_all_passages(archive_id)
+            except:
+                pass
+            raise e
+
+    @pytest.mark.asyncio
+    async def test_native_only_operations(self, server, default_user, sarah_agent, disable_turbopuffer):
+        """Test that operations work correctly when using only native PostgreSQL"""
+
+        # Create archive (should be NATIVE since turbopuffer is disabled)
+        archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
+            agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
+        )
+        assert archive.vector_db_provider == VectorDBProvider.NATIVE
+
+        # Insert passages - should only write to SQL
+        text_content = "This is a test passage for native PostgreSQL only."
+        passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=text_content, actor=default_user)
+
+        assert passages is not None
+        assert len(passages) > 0
+
+        # List passages - should work from SQL
+        sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
+        assert any(p.text == text_content for p in sql_passages)
+
+        # Vector search should use PostgreSQL pgvector
+        embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
+        vector_results = await server.agent_manager.query_agent_passages_async(
+            actor=default_user,
+            agent_id=sarah_agent.id,
+            query_text="native postgresql",
+            embedding_config=embedding_config,
+            embed_query=True,
+        )
+
+        # Should still work with native PostgreSQL
+        assert isinstance(vector_results, list)
+
+
+@pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
+class TestTurbopufferParametrized:
+    """Test that functionality works with and without Turbopuffer enabled"""
+
+    @pytest.mark.asyncio
+    async def test_passage_operations_with_mode(self, turbopuffer_mode, server, default_user, sarah_agent):
+        """Test that passage operations work in both modes"""
+
+        # Get or create archive
+        archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
+            agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
+        )
+
+        # Check that vector_db_provider matches the mode
+        if settings.use_tpuf and settings.tpuf_api_key:
+            expected_provider = VectorDBProvider.TPUF
+        else:
+            expected_provider = VectorDBProvider.NATIVE
+        assert archive.vector_db_provider == expected_provider
+
+        # Test inserting a passage (should work in both modes)
+        test_text = f"Test passage for {expected_provider} mode"
+        passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=test_text, actor=default_user)
+
+        assert passages is not None
+        assert len(passages) > 0
+        assert passages[0].text == test_text
+
+        # List passages should work in both modes
+        listed = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
+        assert any(p.text == test_text for p in listed)
+
+        # Delete should work in both modes
+        await server.passage_manager.delete_agent_passages_async(passages, default_user)
+
+        # Verify deletion
+        remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
+        assert not any(p.id == passages[0].id for p in remaining)
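Aside (not part of the patch): the parametrized fixture makes dual-mode coverage cheap for future tests. A minimal sketch of reusing the pattern in a new test (the test body itself is hypothetical; fixtures come from tests/conftest.py and this module):

    import pytest


    @pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
    @pytest.mark.asyncio
    async def test_my_archival_feature(turbopuffer_mode, server, default_user, sarah_agent):
        # Runs twice: once with Turbopuffer left enabled (exercised only when an
        # API key is configured) and once pinned to the native pgvector path.
        passages = await server.passage_manager.insert_passage(
            agent_state=sarah_agent, text="dual-mode smoke test", actor=default_user
        )
        assert passages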