feat: Support basic upload/querying on tpuf [LET-3465] (#4255)
* wip implementing turbopuffer
* Move imports up
* Add type of archive
* Integrate turbopuffer functionality
* Debug turbopuffer tests failing
* Fix turbopuffer
* Run fern
* Fix multiple heads
@@ -0,0 +1,60 @@
"""Add vector_db_provider to archives table

Revision ID: 068588268b02
Revises: d5103ee17ed5
Create Date: 2025-08-27 13:16:29.428231

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

from letta.settings import settings

# revision identifiers, used by Alembic.
revision: str = "068588268b02"
down_revision: Union[str, None] = "887a4367b560"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    if settings.letta_pg_uri_no_default:
        # PostgreSQL - use enum type
        vectordbprovider = sa.Enum("NATIVE", "TPUF", name="vectordbprovider")
        vectordbprovider.create(op.get_bind(), checkfirst=True)

        # Add column as nullable first
        op.add_column("archives", sa.Column("vector_db_provider", vectordbprovider, nullable=True))

        # Backfill existing rows with NATIVE
        op.execute("UPDATE archives SET vector_db_provider = 'NATIVE' WHERE vector_db_provider IS NULL")

        # Make column non-nullable
        op.alter_column("archives", "vector_db_provider", nullable=False)
    else:
        # SQLite - use string type
        # Add column as nullable first
        op.add_column("archives", sa.Column("vector_db_provider", sa.String(), nullable=True))

        # Backfill existing rows with NATIVE
        op.execute("UPDATE archives SET vector_db_provider = 'NATIVE' WHERE vector_db_provider IS NULL")

        # For SQLite, we need to recreate the table to make column non-nullable
        # This is a limitation of SQLite ALTER TABLE
        # For simplicity, we'll leave it nullable in SQLite
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("archives", "vector_db_provider")

    if settings.letta_pg_uri_no_default:
        # Drop enum type for PostgreSQL
        vectordbprovider = sa.Enum("NATIVE", "TPUF", name="vectordbprovider")
        vectordbprovider.drop(op.get_bind(), checkfirst=True)
    # ### end Alembic commands ###
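For reference, a minimal sketch of applying this revision through Alembic's command API; the `alembic.ini` path is an assumption, and running `alembic upgrade head` from the shell is the usual equivalent:

# Sketch: apply/revert this migration programmatically (assumes an
# alembic.ini at the project root).
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "068588268b02")   # apply this revision
# command.downgrade(cfg, "-1")         # step back one revision to undo it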
@@ -484,7 +484,7 @@ class VoiceAgent(BaseAgent):
         if start_date and end_date and start_date > end_date:
             start_date, end_date = end_date, start_date

-        archival_results = await self.agent_manager.list_passages_async(
+        archival_results = await self.agent_manager.query_agent_passages_async(
             actor=self.actor,
             agent_id=self.agent_id,
             query_text=archival_query,
@@ -107,7 +107,7 @@ async def archival_memory_search(self: "Agent", query: str, page: Optional[int]

     try:
         # Get results using passage manager
-        all_results = await self.agent_manager.list_passages_async(
+        all_results = await self.agent_manager.query_agent_passages_async(
             actor=self.user,
             agent_id=self.agent_state.id,
             query_text=query,
letta/helpers/tpuf_client.py (new file, 276 lines)
@@ -0,0 +1,276 @@
"""Turbopuffer utilities for archival memory storage."""

import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple

from letta.otel.tracing import trace_method
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings

logger = logging.getLogger(__name__)


def should_use_tpuf() -> bool:
    return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)


class TurbopufferClient:
    """Client for managing archival memory with Turbopuffer vector database."""

    def __init__(self, api_key: Optional[str] = None, region: Optional[str] = None):
        """Initialize Turbopuffer client."""
        self.api_key = api_key or settings.tpuf_api_key
        self.region = region or settings.tpuf_region

        if not self.api_key:
            raise ValueError("Turbopuffer API key not provided")

    @trace_method
    def _get_namespace_name(self, archive_id: str) -> str:
        """Get namespace name for a specific archive."""
        # use archive_id as namespace to isolate different archives' memories
        # append environment suffix to namespace for isolation if environment is set
        environment = settings.environment
        if environment:
            namespace_name = f"{archive_id}_{environment.lower()}"
        else:
            namespace_name = archive_id
        return namespace_name

    @trace_method
    async def insert_archival_memories(
        self,
        archive_id: str,
        text_chunks: List[str],
        embeddings: List[List[float]],
        passage_ids: List[str],
        organization_id: str,
        tags: Optional[List[str]] = None,
        created_at: Optional[datetime] = None,
    ) -> List[PydanticPassage]:
        """Insert passages into Turbopuffer.

        Args:
            archive_id: ID of the archive
            text_chunks: List of text chunks to store
            embeddings: List of embedding vectors corresponding to text chunks
            passage_ids: List of passage IDs (must match 1:1 with text_chunks)
            organization_id: Organization ID for the passages
            tags: Optional list of tags to attach to all passages
            created_at: Optional timestamp for retroactive entries (defaults to current UTC time)

        Returns:
            List of PydanticPassage objects that were inserted
        """
        from turbopuffer import AsyncTurbopuffer

        namespace_name = self._get_namespace_name(archive_id)

        # handle timestamp - ensure UTC
        if created_at is None:
            timestamp = datetime.now(timezone.utc)
        else:
            # ensure the provided timestamp is timezone-aware and in UTC
            if created_at.tzinfo is None:
                # assume UTC if no timezone provided
                timestamp = created_at.replace(tzinfo=timezone.utc)
            else:
                # convert to UTC if in different timezone
                timestamp = created_at.astimezone(timezone.utc)

        # passage_ids must be provided for dual-write consistency
        if not passage_ids:
            raise ValueError("passage_ids must be provided for Turbopuffer insertion")
        if len(passage_ids) != len(text_chunks):
            raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
        if len(passage_ids) != len(embeddings):
            raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})")

        # prepare column-based data for turbopuffer - optimized for batch insert
        ids = []
        vectors = []
        texts = []
        organization_ids = []
        archive_ids = []
        created_ats = []
        passages = []

        # prepare tag columns
        tag_columns = {tag: [] for tag in (tags or [])}

        for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
            passage_id = passage_ids[idx]

            # append to columns
            ids.append(passage_id)
            vectors.append(embedding)
            texts.append(text)
            organization_ids.append(organization_id)
            archive_ids.append(archive_id)
            created_ats.append(timestamp)

            # append tag values
            for tag in tag_columns:
                tag_columns[tag].append(True)

            # Create PydanticPassage object
            passage = PydanticPassage(
                id=passage_id,
                text=text,
                organization_id=organization_id,
                archive_id=archive_id,
                created_at=timestamp,
                metadata_={},
                embedding=embedding,
                embedding_config=None,  # Will be set by caller if needed
            )
            passages.append(passage)

        # build column-based upsert data
        upsert_columns = {
            "id": ids,
            "vector": vectors,
            "text": texts,
            "organization_id": organization_ids,
            "archive_id": archive_ids,
            "created_at": created_ats,
        }

        # add tag columns if any
        upsert_columns.update(tag_columns)

        try:
            # Use AsyncTurbopuffer as a context manager for proper resource cleanup
            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
                namespace = client.namespace(namespace_name)
                # turbopuffer recommends column-based writes for performance
                await namespace.write(upsert_columns=upsert_columns, distance_metric="cosine_distance")
                logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}")
                return passages

        except Exception as e:
            logger.error(f"Failed to insert passages to Turbopuffer: {e}")
            # check if it's a duplicate ID error
            if "duplicate" in str(e).lower():
                logger.error("Duplicate passage IDs detected in batch")
            raise

    @trace_method
    async def query_passages(
        self, archive_id: str, query_embedding: List[float], top_k: int = 10, filters: Optional[Dict[str, Any]] = None
    ) -> List[Tuple[PydanticPassage, float]]:
        """Query passages from Turbopuffer."""
        from turbopuffer import AsyncTurbopuffer

        namespace_name = self._get_namespace_name(archive_id)

        try:
            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
                namespace = client.namespace(namespace_name)

                # build filter conditions
                filter_conditions = []
                if filters:
                    for key, value in filters.items():
                        filter_conditions.append((key, "Eq", value))

                query_params = {
                    "rank_by": ("vector", "ANN", query_embedding),
                    "top_k": top_k,
                    "include_attributes": ["text", "organization_id", "archive_id", "created_at"],
                }

                if filter_conditions:
                    query_params["filters"] = ("And", filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0]

                result = await namespace.query(**query_params)

                # convert results back to passages
                passages_with_scores = []
                # Turbopuffer returns a NamespaceQueryResponse with a rows attribute
                for row in result.rows:
                    # Build metadata including any filter conditions that were applied
                    metadata = {}
                    if filters:
                        metadata["applied_filters"] = filters

                    # Create a passage with minimal fields - embeddings are not returned from Turbopuffer
                    passage = PydanticPassage(
                        id=row.id,
                        text=getattr(row, "text", ""),
                        organization_id=getattr(row, "organization_id", None),
                        archive_id=archive_id,  # use the archive_id from the query
                        created_at=getattr(row, "created_at", None),
                        metadata_=metadata,  # Include filter conditions in metadata
                        # Set required fields to empty/default values since we don't store embeddings
                        embedding=[],  # Empty embedding since we don't return it from Turbopuffer
                        embedding_config=None,  # No embedding config needed for retrieved passages
                    )
                    # turbopuffer returns distance in $dist attribute, convert to similarity score
                    distance = getattr(row, "$dist", 0.0)
                    score = 1.0 - distance
                    passages_with_scores.append((passage, score))

                return passages_with_scores

        except Exception as e:
            logger.error(f"Failed to query passages from Turbopuffer: {e}")
            raise

    @trace_method
    async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
        """Delete a passage from Turbopuffer."""
        from turbopuffer import AsyncTurbopuffer

        namespace_name = self._get_namespace_name(archive_id)

        try:
            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
                namespace = client.namespace(namespace_name)
                # Use write API with deletes parameter as per Turbopuffer docs
                await namespace.write(deletes=[passage_id])
                logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}")
                return True
        except Exception as e:
            logger.error(f"Failed to delete passage from Turbopuffer: {e}")
            raise

    @trace_method
    async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
        """Delete multiple passages from Turbopuffer."""
        from turbopuffer import AsyncTurbopuffer

        if not passage_ids:
            return True

        namespace_name = self._get_namespace_name(archive_id)

        try:
            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
                namespace = client.namespace(namespace_name)
                # Use write API with deletes parameter as per Turbopuffer docs
                await namespace.write(deletes=passage_ids)
                logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}")
                return True
        except Exception as e:
            logger.error(f"Failed to delete passages from Turbopuffer: {e}")
            raise

    @trace_method
    async def delete_all_passages(self, archive_id: str) -> bool:
        """Delete all passages for an archive from Turbopuffer."""
        from turbopuffer import AsyncTurbopuffer

        namespace_name = self._get_namespace_name(archive_id)

        try:
            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
                namespace = client.namespace(namespace_name)
                # Turbopuffer has a delete_all() method on namespace
                await namespace.delete_all()
                logger.info(f"Successfully deleted all passages for archive {archive_id}")
                return True
        except Exception as e:
            logger.error(f"Failed to delete all passages from Turbopuffer: {e}")
            raise
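A minimal usage sketch of the client above, assuming `settings.use_tpuf` and `settings.tpuf_api_key` are configured; the IDs and the 1536-dimensional embeddings are hypothetical placeholders:

import asyncio

from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf


async def demo():
    if not should_use_tpuf():
        return  # tpuf disabled or no API key configured

    client = TurbopufferClient()
    # insert one passage (IDs below are illustrative)
    await client.insert_archival_memories(
        archive_id="archive-123",
        text_chunks=["hello world"],
        embeddings=[[0.1] * 1536],
        passage_ids=["passage-abc"],
        organization_id="org-1",
        tags=["demo"],
    )
    # query it back by vector similarity
    results = await client.query_passages("archive-123", query_embedding=[0.1] * 1536, top_k=5)
    for passage, score in results:
        print(passage.id, round(score, 3))
    # tear down the namespace
    await client.delete_all_passages("archive-123")


asyncio.run(demo())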
@@ -2,12 +2,13 @@ import uuid
 from datetime import datetime, timezone
 from typing import TYPE_CHECKING, List, Optional

-from sqlalchemy import JSON, Index, String
+from sqlalchemy import JSON, Enum, Index, String
 from sqlalchemy.orm import Mapped, mapped_column, relationship

 from letta.orm.mixins import OrganizationMixin
 from letta.orm.sqlalchemy_base import SqlalchemyBase
 from letta.schemas.archive import Archive as PydanticArchive
+from letta.schemas.enums import VectorDBProvider
 from letta.settings import DatabaseChoice, settings

 if TYPE_CHECKING:
@@ -38,6 +39,12 @@ class Archive(SqlalchemyBase, OrganizationMixin):
     # archive-specific fields
     name: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the archive")
     description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="A description of the archive")
+    vector_db_provider: Mapped[VectorDBProvider] = mapped_column(
+        Enum(VectorDBProvider),
+        nullable=False,
+        default=VectorDBProvider.NATIVE,
+        doc="The vector database provider used for this archive's passages",
+    )
     metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive")

     # relationships
@@ -3,6 +3,7 @@ from typing import Dict, Optional

 from pydantic import Field

+from letta.schemas.enums import VectorDBProvider
 from letta.schemas.letta_base import OrmMetadataBase

@@ -12,6 +13,9 @@ class ArchiveBase(OrmMetadataBase):
     name: str = Field(..., description="The name of the archive")
     description: Optional[str] = Field(None, description="A description of the archive")
     organization_id: str = Field(..., description="The organization this archive belongs to")
+    vector_db_provider: VectorDBProvider = Field(
+        default=VectorDBProvider.NATIVE, description="The vector database provider used for this archive's passages"
+    )
     metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata")
@@ -171,3 +171,10 @@ class StepStatus(str, Enum):
     SUCCESS = "success"
     FAILED = "failed"
     CANCELLED = "cancelled"
+
+
+class VectorDBProvider(str, Enum):
+    """Supported vector database providers for archival memory"""
+
+    NATIVE = "native"
+    TPUF = "tpuf"
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+from datetime import datetime
 from typing import TYPE_CHECKING, List, Optional

 from jinja2 import Template, TemplateSyntaxError
@@ -325,3 +326,5 @@ class RecallMemorySummary(BaseModel):

 class CreateArchivalMemory(BaseModel):
     text: str = Field(..., description="Text to write to archival memory.")
+    tags: Optional[List[str]] = Field(None, description="Optional list of tags to attach to the memory.")
+    created_at: Optional[datetime] = Field(None, description="Optional timestamp for the memory (defaults to current UTC time).")
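With the two new optional fields, a request body for the archival-memory create endpoint (wired up in the router change below) might look like this; the values are illustrative:

# Illustrative CreateArchivalMemory payload; tags and created_at may be omitted.
payload = {
    "text": "Met Sarah at the conference.",
    "tags": ["people", "events"],
    "created_at": "2025-08-27T13:16:29Z",  # ISO 8601; treated as UTC if naive
}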
@@ -34,7 +34,7 @@ from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaSt
 from letta.schemas.letta_response import LettaResponse
 from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
 from letta.schemas.message import MessageCreate
-from letta.schemas.passage import Passage, PassageUpdate
+from letta.schemas.passage import Passage
 from letta.schemas.run import Run
 from letta.schemas.source import Source
 from letta.schemas.tool import Tool
@@ -954,22 +954,9 @@ async def create_passage(
     """
     actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

-    return await server.insert_archival_memory_async(agent_id=agent_id, memory_contents=request.text, actor=actor)
-
-
-@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=list[Passage], operation_id="modify_passage")
-def modify_passage(
-    agent_id: str,
-    memory_id: str,
-    passage: PassageUpdate = Body(...),
-    server: "SyncServer" = Depends(get_letta_server),
-    actor_id: str | None = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
-):
-    """
-    Modify a memory in the agent's archival memory store.
-    """
-    actor = server.user_manager.get_user_or_default(user_id=actor_id)
-    return server.modify_archival_memory(agent_id=agent_id, memory_id=memory_id, passage=passage, actor=actor)
+    return await server.insert_archival_memory_async(
+        agent_id=agent_id, memory_contents=request.text, actor=actor, tags=request.tags, created_at=request.created_at
+    )


 # TODO(ethan): query or path parameter for memory_id?
@@ -351,7 +351,7 @@ async def list_folder_passages(
     List all passages associated with a data folder.
     """
     actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
-    return await server.agent_manager.list_passages_async(
+    return await server.agent_manager.query_source_passages_async(
         actor=actor,
         source_id=folder_id,
         after=after,
@@ -349,7 +349,7 @@ async def list_source_passages(
     List all passages associated with a data source.
     """
     actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
-    return await server.agent_manager.list_passages_async(
+    return await server.agent_manager.query_source_passages_async(
         actor=actor,
         source_id=source_id,
         after=after,
@@ -52,7 +52,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
 from letta.schemas.message import Message, MessageCreate, MessageUpdate
-from letta.schemas.passage import Passage, PassageUpdate
+from letta.schemas.passage import Passage
 from letta.schemas.pip_requirement import PipRequirement
 from letta.schemas.providers import (
     AnthropicProvider,
@@ -1114,7 +1114,7 @@ class SyncServer(Server):
         ascending: Optional[bool] = True,
     ) -> List[Passage]:
         # iterate over records
-        records = await self.agent_manager.list_agent_passages_async(
+        records = await self.agent_manager.query_agent_passages_async(
             actor=actor,
             agent_id=agent_id,
             after=after,
@@ -1125,18 +1125,18 @@ class SyncServer(Server):
         )
         return records

-    async def insert_archival_memory_async(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
+    async def insert_archival_memory_async(
+        self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
+    ) -> List[Passage]:
         # Get the agent object (loaded in memory)
         agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)

-        # Insert passages into the archive
+        # Use passage manager which handles dual-write to Turbopuffer if enabled
         passages = await self.passage_manager.insert_passage(agent_state=agent_state, text=memory_contents, actor=actor)

         return passages
+        # TODO: Add support for tags and created_at parameters
+        # Currently PassageManager.insert_passage doesn't support these parameters

     def modify_archival_memory(self, agent_id: str, memory_id: str, passage: PassageUpdate, actor: User) -> List[Passage]:
         passage = Passage(**passage.model_dump(exclude_unset=True, exclude_none=True))
         passages = self.passage_manager.update_passage_by_id(passage_id=memory_id, passage=passage, actor=actor)
         return passages

     async def delete_archival_memory_async(self, memory_id: str, actor: User):
@@ -1270,7 +1270,7 @@ class SyncServer(Server):
         await self.source_manager.delete_source(source_id=source_id, actor=actor)

         # delete data from passage store
-        passages_to_be_deleted = await self.agent_manager.list_passages_async(actor=actor, source_id=source_id, limit=None)
+        passages_to_be_deleted = await self.agent_manager.query_source_passages_async(actor=actor, source_id=source_id, limit=None)
         await self.passage_manager.delete_source_passages_async(actor=actor, passages=passages_to_be_deleted)

         # TODO: delete data from agent passage stores (?)
@@ -1316,27 +1316,6 @@ class SyncServer(Server):
     async def sleeptime_document_ingest_async(
         self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
     ) -> None:
         # TEMPORARILY DISABLE UNTIL V2
-        # sleeptime_agent_state = await self.create_document_sleeptime_agent_async(main_agent, source, actor, clear_history)
-        # sleeptime_agent = LettaAgent(
-        #     agent_id=sleeptime_agent_state.id,
-        #     message_manager=self.message_manager,
-        #     agent_manager=self.agent_manager,
-        #     block_manager=self.block_manager,
-        #     job_manager=self.job_manager,
-        #     passage_manager=self.passage_manager,
-        #     actor=actor,
-        #     step_manager=self.step_manager,
-        #     telemetry_manager=self.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
-        # )
-        # passages = await self.agent_manager.list_passages_async(actor=actor, source_id=source.id)
-        # for passage in passages:
-        #     await sleeptime_agent.step(
-        #         input_messages=[
-        #             MessageCreate(role="user", content=passage.text),
-        #         ]
-        #     )
-        # await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor)
         pass

     async def _remove_file_from_agent(self, agent_id: str, file_id: str, actor: User) -> None:
@@ -48,7 +48,7 @@ from letta.schemas.block import DEFAULT_BLOCKS
 from letta.schemas.block import Block as PydanticBlock
 from letta.schemas.block import BlockUpdate
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import ProviderType, ToolType
+from letta.schemas.enums import ProviderType, ToolType, VectorDBProvider
 from letta.schemas.file import FileMetadata as PydanticFileMetadata
 from letta.schemas.group import Group as PydanticGroup
 from letta.schemas.group import ManagerType
@@ -67,6 +67,7 @@ from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
 from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
 from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
 from letta.server.db import db_registry
+from letta.services.archive_manager import ArchiveManager
 from letta.services.block_manager import BlockManager
 from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator
 from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, TiktokenCounter
@@ -116,6 +117,7 @@ class AgentManager:
         self.passage_manager = PassageManager()
         self.identity_manager = IdentityManager()
         self.file_agent_manager = FileAgentManager()
+        self.archive_manager = ArchiveManager()

     @staticmethod
     def _should_exclude_model_from_base_tool_rules(model: str) -> bool:
@@ -2520,7 +2522,20 @@ class AgentManager:
         embedding_config: Optional[EmbeddingConfig] = None,
         agent_only: bool = False,
     ) -> List[PydanticPassage]:
-        """Lists all passages attached to an agent."""
+        """
+        DEPRECATED: Use query_source_passages_async or query_agent_passages_async instead.
+        This method is kept only for test compatibility and will be removed in a future version.
+
+        Lists all passages attached to an agent (combines both source and agent passages).
+        """
+        import warnings
+
+        warnings.warn(
+            "list_passages_async is deprecated. Use query_source_passages_async or query_agent_passages_async instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
         async with db_registry.async_session() as session:
             main_query = await build_passage_query(
                 actor=actor,
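The shim above only warns rather than failing; a sketch of asserting that behavior in a test (the fixture names are hypothetical, mirroring the fixtures used later in this diff):

import warnings

import pytest


@pytest.mark.asyncio
async def test_list_passages_async_warns(server, default_user, sarah_agent):
    # record warnings emitted by the deprecated call
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)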
@@ -2567,7 +2582,7 @@ class AgentManager:

     @enforce_types
     @trace_method
-    async def list_source_passages_async(
+    async def query_source_passages_async(
         self,
         actor: PydanticUser,
         agent_id: Optional[str] = None,
@@ -2615,7 +2630,7 @@ class AgentManager:

     @enforce_types
     @trace_method
-    async def list_agent_passages_async(
+    async def query_agent_passages_async(
         self,
         actor: PydanticUser,
         agent_id: Optional[str] = None,
@@ -2630,6 +2645,46 @@ class AgentManager:
         embedding_config: Optional[EmbeddingConfig] = None,
     ) -> List[PydanticPassage]:
         """Lists all passages attached to an agent."""
+        # Check if we should use Turbopuffer for vector search
+        if embed_query and agent_id and query_text and embedding_config:
+            # Get archive IDs for the agent
+            archive_ids = await self.get_agent_archive_ids_async(agent_id=agent_id, actor=actor)
+
+            if archive_ids:
+                # TODO: Remove this restriction once we support multiple archives with mixed vector DB providers
+                if len(archive_ids) > 1:
+                    raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported for vector search")
+
+                # Get archive to check vector_db_provider
+                archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_ids[0], actor=actor)
+
+                # Use Turbopuffer for vector search if archive is configured for TPUF
+                if archive.vector_db_provider == VectorDBProvider.TPUF:
+                    from letta.helpers.tpuf_client import TurbopufferClient
+                    from letta.llm_api.llm_client import LLMClient
+
+                    # Generate embedding for query
+                    embedding_client = LLMClient.create(
+                        provider_type=embedding_config.embedding_endpoint_type,
+                        actor=actor,
+                    )
+                    embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
+                    query_embedding = embeddings[0]
+
+                    # Query Turbopuffer
+                    tpuf_client = TurbopufferClient()
+                    passages_with_scores = await tpuf_client.query_passages(
+                        archive_id=archive_ids[0],
+                        query_embedding=query_embedding,
+                        top_k=limit,
+                    )
+
+                    # Return just the passages (without scores)
+                    return [passage for passage, _ in passages_with_scores]
+            else:
+                return []
+
+        # Fall back to SQL-based search for non-vector queries or NATIVE archives
         async with db_registry.async_session() as session:
             main_query = await build_agent_passage_query(
                 actor=actor,
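For reference, the Turbopuffer branch above is taken only when all four gating arguments are supplied; a call that exercises it looks like this (values are illustrative, mirroring the integration tests later in this diff):

# Sketch: all of embed_query, agent_id, query_text, and embedding_config
# must be provided to reach the Turbopuffer path.
passages = await agent_manager.query_agent_passages_async(
    actor=user,
    agent_id="agent-1",  # hypothetical ID
    query_text="meeting notes",
    embedding_config=agent_state.embedding_config,
    embed_query=True,
    limit=10,
)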
@@ -2,11 +2,13 @@ from typing import List, Optional

 from sqlalchemy import select

+from letta.helpers.tpuf_client import should_use_tpuf
 from letta.log import get_logger
 from letta.orm import ArchivalPassage
 from letta.orm import Archive as ArchiveModel
 from letta.orm import ArchivesAgents
 from letta.schemas.archive import Archive as PydanticArchive
+from letta.schemas.enums import VectorDBProvider
 from letta.schemas.user import User as PydanticUser
 from letta.server.db import db_registry
 from letta.utils import enforce_types
@@ -27,10 +29,14 @@ class ArchiveManager:
         """Create a new archive."""
         try:
             with db_registry.session() as session:
+                # determine vector db provider based on settings
+                vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
+
                 archive = ArchiveModel(
                     name=name,
                     description=description,
                     organization_id=actor.organization_id,
+                    vector_db_provider=vector_db_provider,
                 )
                 archive.create(session, actor=actor)
                 return archive.to_pydantic()
@@ -48,10 +54,14 @@ class ArchiveManager:
         """Create a new archive."""
        try:
             async with db_registry.async_session() as session:
+                # determine vector db provider based on settings
+                vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
+
                 archive = ArchiveModel(
                     name=name,
                     description=description,
                     organization_id=actor.organization_id,
+                    vector_db_provider=vector_db_provider,
                 )
                 await archive.create_async(session, actor=actor)
                 return archive.to_pydantic()
@@ -14,6 +14,7 @@ from letta.orm.errors import NoResultFound
 from letta.orm.passage import ArchivalPassage, SourcePassage
 from letta.otel.tracing import trace_method
 from letta.schemas.agent import AgentState
+from letta.schemas.enums import VectorDBProvider
 from letta.schemas.file import FileMetadata as PydanticFileMetadata
 from letta.schemas.passage import Passage as PydanticPassage
 from letta.schemas.user import User as PydanticUser
@@ -489,6 +490,8 @@ class PassageManager:
             embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)

             passages = []
+
+            # Always write to SQL database first
             for chunk_text, embedding in zip(text_chunks, embeddings):
                 passage = await self.create_agent_passage_async(
                     PydanticPassage(
@@ -502,6 +505,26 @@ class PassageManager:
                 )
                 passages.append(passage)

+            # If archive uses Turbopuffer, also write to Turbopuffer (dual-write)
+            if archive.vector_db_provider == VectorDBProvider.TPUF:
+                from letta.helpers.tpuf_client import TurbopufferClient
+
+                tpuf_client = TurbopufferClient()
+
+                # Extract IDs and texts from the created passages
+                passage_ids = [p.id for p in passages]
+                passage_texts = [p.text for p in passages]
+
+                # Insert to Turbopuffer with the same IDs as SQL
+                await tpuf_client.insert_archival_memories(
+                    archive_id=archive.id,
+                    text_chunks=passage_texts,
+                    embeddings=embeddings,
+                    passage_ids=passage_ids,  # Use same IDs as SQL
+                    organization_id=actor.organization_id,
+                    created_at=passages[0].created_at if passages else None,
+                )
+
             return passages

         except Exception as e:
@@ -655,7 +678,20 @@ class PassageManager:
         async with db_registry.async_session() as session:
             try:
                 passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                archive_id = passage.archive_id
+
+                # Delete from SQL first
                 await passage.hard_delete_async(session, actor=actor)
+
+                # Check if archive uses Turbopuffer and dual-delete
+                if archive_id:
+                    archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
+                    if archive.vector_db_provider == VectorDBProvider.TPUF:
+                        from letta.helpers.tpuf_client import TurbopufferClient
+
+                        tpuf_client = TurbopufferClient()
+                        await tpuf_client.delete_passage(archive_id=archive_id, passage_id=passage_id)
+
                 return True
             except NoResultFound:
                 raise NoResultFound(f"Agent passage with id {passage_id} not found.")
@@ -812,12 +848,34 @@ class PassageManager:
     @trace_method
     async def delete_agent_passages_async(
         self,
-        actor: PydanticUser,
         passages: List[PydanticPassage],
+        actor: PydanticUser,
     ) -> bool:
         """Delete multiple agent passages."""
         if not passages:
             return True

         async with db_registry.async_session() as session:
+            # Delete from SQL first
             await ArchivalPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
+
+            # Group passages by archive_id for efficient Turbopuffer deletion
+            passages_by_archive = {}
+            for passage in passages:
+                if passage.archive_id:
+                    if passage.archive_id not in passages_by_archive:
+                        passages_by_archive[passage.archive_id] = []
+                    passages_by_archive[passage.archive_id].append(passage.id)
+
+            # Check each archive and delete from Turbopuffer if needed
+            for archive_id, passage_ids in passages_by_archive.items():
+                archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
+                if archive.vector_db_provider == VectorDBProvider.TPUF:
+                    from letta.helpers.tpuf_client import TurbopufferClient
+
+                    tpuf_client = TurbopufferClient()
+                    await tpuf_client.delete_passages(archive_id=archive_id, passage_ids=passage_ids)
+
             return True

     @enforce_types
@@ -143,7 +143,7 @@ class LettaCoreToolExecutor(ToolExecutor):

     try:
         # Get results using passage manager
-        all_results = await AgentManager().list_agent_passages_async(
+        all_results = await AgentManager().query_agent_passages_async(
             actor=actor,
             agent_id=agent_state.id,
             query_text=query,
@@ -661,7 +661,7 @@ class LettaFileToolExecutor(ToolExecutor):
     async def _search_files_traditional(self, agent_state: AgentState, query: str, limit: int) -> str:
         """Traditional search using existing passage manager."""
         # Get semantic search results
-        passages = await self.agent_manager.list_source_passages_async(
+        passages = await self.agent_manager.query_source_passages_async(
             actor=self.actor,
             agent_id=agent_state.id,
             query_text=query,
@@ -300,7 +300,7 @@ class Settings(BaseSettings):
     # For tpuf - currently only for archival memories
     use_tpuf: bool = False
     tpuf_api_key: Optional[str] = None
-    tpuf_region: str = "gcp-us-central1.turbopuffer.com"
+    tpuf_region: str = "gcp-us-central1"

     # File processing timeout settings
     file_processing_timeout_minutes: int = 30
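Since `Settings` is a pydantic `BaseSettings` class, these fields can presumably be set through the environment; the exact variable names depend on the configured env prefix, so the names below are assumptions:

# Assumed environment variable names (depend on the Settings env prefix).
import os

os.environ["LETTA_USE_TPUF"] = "true"
os.environ["LETTA_TPUF_API_KEY"] = "tpuf_..."  # hypothetical key format
os.environ["LETTA_TPUF_REGION"] = "gcp-us-central1"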
@@ -104,6 +104,62 @@ def disable_pinecone() -> Generator[None, None, None]:
     settings.pinecone_api_key = original_pinecone_api_key


+@pytest.fixture
+def disable_turbopuffer() -> Generator[None, None, None]:
+    """
+    Temporarily disables Turbopuffer by setting `settings.use_tpuf` to False
+    and `settings.tpuf_api_key` to None for the duration of the test.
+    Also sets environment to DEV for testing.
+    Restores the original values afterward.
+    """
+    from letta.settings import settings
+
+    original_use_tpuf = settings.use_tpuf
+    original_tpuf_api_key = settings.tpuf_api_key
+    original_environment = settings.environment
+    settings.use_tpuf = False
+    settings.tpuf_api_key = None
+    settings.environment = "DEV"
+    yield
+    settings.use_tpuf = original_use_tpuf
+    settings.tpuf_api_key = original_tpuf_api_key
+    settings.environment = original_environment
+
+
+@pytest.fixture
+def turbopuffer_mode(request) -> Generator[None, None, None]:
+    """
+    Parametrizable fixture to enable/disable Turbopuffer mode.
+
+    Usage:
+        @pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
+        def test_function(turbopuffer_mode, ...):
+            # Test runs twice - once with Turbopuffer enabled, once disabled
+    """
+    from letta.settings import settings
+
+    enable_tpuf = request.param
+    original_use_tpuf = settings.use_tpuf
+    original_tpuf_api_key = settings.tpuf_api_key
+    original_environment = settings.environment
+
+    # Set environment to DEV for testing
+    settings.environment = "DEV"
+
+    if not enable_tpuf:
+        # Disable Turbopuffer by setting use_tpuf to False
+        settings.use_tpuf = False
+        settings.tpuf_api_key = None
+    # If enable_tpuf is True, leave the original settings unchanged
+
+    yield
+
+    # Restore original settings
+    settings.use_tpuf = original_use_tpuf
+    settings.tpuf_api_key = original_tpuf_api_key
+    settings.environment = original_environment
+
+
 @pytest.fixture
 def check_e2b_key_is_set():
     from letta.settings import tool_settings
@@ -2670,21 +2670,18 @@ async def test_refresh_memory_async(server: SyncServer, default_user):


 @pytest.mark.asyncio
-async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test basic listing functionality of agent passages"""

     all_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id)
     assert len(all_passages) == 5  # 3 source + 2 agent passages

-    source_passages = await server.agent_manager.list_source_passages_async(actor=default_user, agent_id=sarah_agent.id)
+    source_passages = await server.agent_manager.query_source_passages_async(actor=default_user, agent_id=sarah_agent.id)
     assert len(source_passages) == 3  # 3 source passages

     agent_passages = await server.agent_manager.list_agent_passages_async(actor=default_user, agent_id=sarah_agent.id)
     assert len(agent_passages) == 2  # 2 agent passages


 @pytest.mark.asyncio
-async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test ordering of agent passages"""

     # Test ascending order
@@ -2701,7 +2698,7 @@ async def test_agent_list_passages_ordering(server, default_user, sarah_agent, a


 @pytest.mark.asyncio
-async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test pagination of agent passages"""

     # Test limit
@@ -2742,7 +2739,7 @@ async def test_agent_list_passages_pagination(server, default_user, sarah_agent,


 @pytest.mark.asyncio
-async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test text search functionality of agent passages"""

     # Test text search for source passages
@@ -2759,7 +2756,7 @@ async def test_agent_list_passages_text_search(server, default_user, sarah_agent


 @pytest.mark.asyncio
-async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
+async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
     """Test text search functionality of agent passages"""

     # Test text search for agent passages
@@ -2768,7 +2765,7 @@ async def test_agent_list_passages_agent_only(server, default_user, sarah_agent,


 @pytest.mark.asyncio
-async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
+async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup, disable_turbopuffer):
     """Test filtering functionality of agent passages"""

     # Test source filtering
@@ -2804,7 +2801,9 @@ def mock_embed_model(mock_embeddings):
     return mock_model


-async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, default_file, mock_embed_model):
+async def test_agent_list_passages_vector_search(
+    server, default_user, sarah_agent, default_source, default_file, mock_embed_model, disable_turbopuffer
+):
     """Test vector search functionality of agent passages"""
     embed_model = mock_embed_model

tests/test_turbopuffer_integration.py (new file, 330 lines)
@@ -0,0 +1,330 @@
import uuid
from datetime import datetime, timezone

import pytest

from letta.config import LettaConfig
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import VectorDBProvider
from letta.server.server import SyncServer
from letta.settings import settings


@pytest.fixture(scope="module")
def server():
    """Server fixture for testing"""
    config = LettaConfig.load()
    config.save()
    server = SyncServer(init_with_default_org_and_user=False)
    return server


@pytest.fixture
async def sarah_agent(server, default_user):
    """Create a test agent named Sarah"""
    from letta.schemas.agent import CreateAgent
    from letta.schemas.embedding_config import EmbeddingConfig
    from letta.schemas.llm_config import LLMConfig

    agent = await server.agent_manager.create_agent_async(
        agent_create=CreateAgent(
            name="Sarah",
            memory_blocks=[],
            llm_config=LLMConfig.default_config("gpt-4o-mini"),
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
            include_base_tools=False,
        ),
        actor=default_user,
    )
    yield agent
    # Cleanup
    try:
        await server.agent_manager.delete_agent_async(agent.id, default_user)
    except Exception:
        pass


@pytest.fixture
def enable_turbopuffer():
    """Temporarily enable Turbopuffer for testing with a test API key"""
    original_use_tpuf = settings.use_tpuf
    original_api_key = settings.tpuf_api_key
    original_environment = settings.environment

    # Enable Turbopuffer with test key
    settings.use_tpuf = True
    # Use the existing tpuf_api_key if set, otherwise keep original
    if not settings.tpuf_api_key:
        settings.tpuf_api_key = original_api_key
    # Set environment to DEV for testing
    settings.environment = "DEV"

    yield

    # Restore original values
    settings.use_tpuf = original_use_tpuf
    settings.tpuf_api_key = original_api_key
    settings.environment = original_environment


class TestTurbopufferIntegration:
    """Test Turbopuffer integration functionality with real connections"""

    def test_should_use_tpuf_with_settings(self):
        """Test that should_use_tpuf correctly reads settings"""
        # Save original values
        original_use_tpuf = settings.use_tpuf
        original_api_key = settings.tpuf_api_key

        try:
            # Test when both are set
            settings.use_tpuf = True
            settings.tpuf_api_key = "test-key"
            assert should_use_tpuf() is True

            # Test when use_tpuf is False
            settings.use_tpuf = False
            assert should_use_tpuf() is False

            # Test when API key is missing
            settings.use_tpuf = True
            settings.tpuf_api_key = None
            assert should_use_tpuf() is False
        finally:
            # Restore original values
            settings.use_tpuf = original_use_tpuf
            settings.tpuf_api_key = original_api_key

    @pytest.mark.asyncio
    async def test_archive_creation_with_tpuf_enabled(self, server, default_user, enable_turbopuffer):
        """Test that archives are created with correct vector_db_provider when TPUF is enabled"""
        archive = await server.archive_manager.create_archive_async(name="Test Archive with TPUF", actor=default_user)
        assert archive.vector_db_provider == VectorDBProvider.TPUF
        # TODO: Add cleanup when delete_archive method is available

    @pytest.mark.asyncio
    async def test_archive_creation_with_tpuf_disabled(self, server, default_user, disable_turbopuffer):
        """Test that archives default to NATIVE when TPUF is disabled"""
        archive = await server.archive_manager.create_archive_async(name="Test Archive without TPUF", actor=default_user)
        assert archive.vector_db_provider == VectorDBProvider.NATIVE
        # TODO: Add cleanup when delete_archive method is available

    @pytest.mark.asyncio
    @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
    async def test_dual_write_and_query_with_real_tpuf(self, server, default_user, sarah_agent, enable_turbopuffer):
        """Test that passages are written to both SQL and Turbopuffer with real connection and can be queried"""

        # Create a TPUF-enabled archive
        archive = await server.archive_manager.create_archive_async(name="Test TPUF Archive for Real Dual Write", actor=default_user)
        assert archive.vector_db_provider == VectorDBProvider.TPUF

        # Attach the agent to the archive
        await server.archive_manager.attach_agent_to_archive_async(
            agent_id=sarah_agent.id, archive_id=archive.id, is_owner=True, actor=default_user
        )

        try:
            # Insert passages - this should trigger dual write
            test_passages = [
                "Turbopuffer is a vector database optimized for performance.",
                "This integration test verifies dual-write functionality.",
                "Metadata attributes should be properly stored in Turbopuffer.",
            ]

            for text in test_passages:
                passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=text, actor=default_user)
                assert passages is not None
                assert len(passages) > 0

            # Verify passages are in SQL - use agent_manager to list passages
            sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
            assert len(sql_passages) >= len(test_passages)
            for text in test_passages:
                assert any(p.text == text for p in sql_passages)

            # Test vector search which should use Turbopuffer
            embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")

            # Perform vector search
            vector_results = await server.agent_manager.query_agent_passages_async(
                actor=default_user,
                agent_id=sarah_agent.id,
                query_text="turbopuffer vector database",
                embedding_config=embedding_config,
                embed_query=True,
                limit=5,
            )

            # Should find relevant passages via Turbopuffer vector search
            assert len(vector_results) > 0
            # The most relevant result should be about Turbopuffer
            assert any("Turbopuffer" in p.text or "vector" in p.text for p in vector_results)

            # Test deletion - should delete from both
            passage_to_delete = sql_passages[0]
            await server.passage_manager.delete_agent_passages_async([passage_to_delete], default_user)

            # Verify deleted from SQL
            remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
            assert not any(p.id == passage_to_delete.id for p in remaining)

            # Verify vector search no longer returns deleted passage
            vector_results_after_delete = await server.agent_manager.query_agent_passages_async(
                actor=default_user,
                agent_id=sarah_agent.id,
                query_text=passage_to_delete.text,
                embedding_config=embedding_config,
                embed_query=True,
                limit=10,
            )
            assert not any(p.id == passage_to_delete.id for p in vector_results_after_delete)

        finally:
            # TODO: Clean up archive when delete_archive method is available
            pass

    @pytest.mark.asyncio
    async def test_turbopuffer_metadata_attributes(self, enable_turbopuffer):
        """Test that Turbopuffer properly stores and retrieves metadata attributes"""

        # Only run if we have a real API key
        if not settings.tpuf_api_key:
            pytest.skip("No Turbopuffer API key available")

        client = TurbopufferClient()
        archive_id = f"test-archive-{datetime.now().timestamp()}"

        try:
            # Insert passages with various metadata
            test_data = [
                {
                    "id": f"passage-{uuid.uuid4()}",
                    "text": "First test passage",
                    "vector": [0.1] * 1536,
                    "organization_id": "org-123",
                    "created_at": datetime.now(timezone.utc),
                },
                {
                    "id": f"passage-{uuid.uuid4()}",
                    "text": "Second test passage",
                    "vector": [0.2] * 1536,
                    "organization_id": "org-123",
                    "created_at": datetime.now(timezone.utc),
                },
                {
                    "id": f"passage-{uuid.uuid4()}",
                    "text": "Third test passage from different org",
                    "vector": [0.3] * 1536,
                    "organization_id": "org-456",
                    "created_at": datetime.now(timezone.utc),
                },
            ]

            # Insert all passages
            result = await client.insert_archival_memories(
                archive_id=archive_id,
                text_chunks=[d["text"] for d in test_data],
                embeddings=[d["vector"] for d in test_data],
                passage_ids=[d["id"] for d in test_data],
                organization_id="org-123",  # Default org
                created_at=datetime.now(timezone.utc),
            )

            assert len(result) == 3

            # Query with organization filter
            query_vector = [0.15] * 1536
            results = await client.query_passages(
                archive_id=archive_id, query_embedding=query_vector, top_k=10, filters={"organization_id": "org-123"}
            )

            # Should only get passages from org-123
            assert len(results) >= 2  # At least the first two passages
            for passage, score in results:
                assert passage.organization_id == "org-123"

            # Clean up
            await client.delete_passages(archive_id=archive_id, passage_ids=[d["id"] for d in test_data])

        except Exception as e:
            # Clean up on error
            try:
                await client.delete_all_passages(archive_id)
            except Exception:
                pass
            raise e

    @pytest.mark.asyncio
    async def test_native_only_operations(self, server, default_user, sarah_agent, disable_turbopuffer):
        """Test that operations work correctly when using only native PostgreSQL"""

        # Create archive (should be NATIVE since turbopuffer is disabled)
        archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
            agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
        )
        assert archive.vector_db_provider == VectorDBProvider.NATIVE

        # Insert passages - should only write to SQL
        text_content = "This is a test passage for native PostgreSQL only."
        passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=text_content, actor=default_user)

        assert passages is not None
        assert len(passages) > 0

        # List passages - should work from SQL
        sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
        assert any(p.text == text_content for p in sql_passages)

        # Vector search should use PostgreSQL pgvector
        embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
        vector_results = await server.agent_manager.query_agent_passages_async(
            actor=default_user,
            agent_id=sarah_agent.id,
            query_text="native postgresql",
            embedding_config=embedding_config,
            embed_query=True,
        )

        # Should still work with native PostgreSQL
        assert isinstance(vector_results, list)


@pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
class TestTurbopufferParametrized:
    """Test that functionality works with and without Turbopuffer enabled"""

    @pytest.mark.asyncio
    async def test_passage_operations_with_mode(self, turbopuffer_mode, server, default_user, sarah_agent):
        """Test that passage operations work in both modes"""

        # Get or create archive
        archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
            agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
        )

        # Check that vector_db_provider matches the mode
        if settings.use_tpuf and settings.tpuf_api_key:
            expected_provider = VectorDBProvider.TPUF
        else:
            expected_provider = VectorDBProvider.NATIVE
        assert archive.vector_db_provider == expected_provider

        # Test inserting a passage (should work in both modes)
        test_text = f"Test passage for {expected_provider} mode"
        passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=test_text, actor=default_user)

        assert passages is not None
        assert len(passages) > 0
        assert passages[0].text == test_text

        # List passages should work in both modes
        listed = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
        assert any(p.text == test_text for p in listed)

        # Delete should work in both modes
        await server.passage_manager.delete_agent_passages_async(passages, default_user)

        # Verify deletion
        remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
        assert not any(p.id == passages[0].id for p in remaining)