feat: Support arbitrary string tag filtering [LET-3467] (#4285)

* Finish tagging

* Add comprehensive tags functionality

* Add fern autogen

* Create passage tags table

* Add indices

* Add comments explaining dual storage

* Fix alembic heads

* Fix alembic

---------

Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
This commit is contained in:
Matthew Zhou
2025-08-28 16:57:36 -07:00
committed by GitHub
parent 2d19903252
commit c1f8c48818
13 changed files with 919 additions and 52 deletions

View File

@@ -0,0 +1,73 @@
"""Add tags to passages and create passage_tags junction table
Revision ID: 54c76f7cabca
Revises: c41c87205254
Create Date: 2025-08-28 15:13:01.549590
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.settings import settings
# revision identifiers, used by Alembic.
revision: str = "54c76f7cabca"
down_revision: Union[str, None] = "c41c87205254"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Database-specific timestamp defaults
if not settings.letta_pg_uri_no_default:
# SQLite uses CURRENT_TIMESTAMP
timestamp_default = sa.text("(CURRENT_TIMESTAMP)")
else:
# PostgreSQL uses now()
timestamp_default = sa.text("now()")
op.create_table(
"passage_tags",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag", sa.String(), nullable=False),
sa.Column("passage_id", sa.String(), nullable=False),
sa.Column("archive_id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=timestamp_default, nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=timestamp_default, nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(["archive_id"], ["archives.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.ForeignKeyConstraint(["passage_id"], ["archival_passages.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("passage_id", "tag", name="uq_passage_tag"),
)
op.create_index("ix_passage_tags_archive_id", "passage_tags", ["archive_id"], unique=False)
op.create_index("ix_passage_tags_archive_tag", "passage_tags", ["archive_id", "tag"], unique=False)
op.create_index("ix_passage_tags_org_archive", "passage_tags", ["organization_id", "archive_id"], unique=False)
op.create_index("ix_passage_tags_tag", "passage_tags", ["tag"], unique=False)
op.add_column("archival_passages", sa.Column("tags", sa.JSON(), nullable=True))
op.add_column("source_passages", sa.Column("tags", sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("source_passages", "tags")
op.drop_column("archival_passages", "tags")
op.drop_index("ix_passage_tags_tag", table_name="passage_tags")
op.drop_index("ix_passage_tags_org_archive", table_name="passage_tags")
op.drop_index("ix_passage_tags_archive_tag", table_name="passage_tags")
op.drop_index("ix_passage_tags_archive_id", table_name="passage_tags")
op.drop_table("passage_tags")
# ### end Alembic commands ###

View File

@@ -2,9 +2,10 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
from letta.otel.tracing import trace_method
from letta.schemas.enums import TagMatchMode
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings
@@ -168,7 +169,8 @@ class TurbopufferClient:
query_text: Optional[str] = None,
search_mode: str = "vector", # "vector", "fts", "hybrid"
top_k: int = 10,
filters: Optional[Dict[str, Any]] = None,
tags: Optional[List[str]] = None,
tag_match_mode: TagMatchMode = TagMatchMode.ANY,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
) -> List[Tuple[PydanticPassage, float]]:
@@ -180,7 +182,8 @@ class TurbopufferClient:
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
top_k: Number of results to return
filters: Optional filter conditions
tags: Optional list of tags to filter by
tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
@@ -207,15 +210,19 @@ class TurbopufferClient:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
namespace = client.namespace(namespace_name)
# build filter conditions
filter_conditions = []
if filters:
for key, value in filters.items():
filter_conditions.append((key, "Eq", value))
# build tag filter conditions
tag_filter = None
if tags:
tag_conditions = []
for tag in tags:
tag_conditions.append((tag, "Eq", True))
base_filter = (
("And", filter_conditions) if len(filter_conditions) > 1 else (filter_conditions[0] if filter_conditions else None)
)
if len(tag_conditions) == 1:
tag_filter = tag_conditions[0]
elif tag_match_mode == TagMatchMode.ALL:
tag_filter = ("And", tag_conditions)
else: # tag_match_mode == TagMatchMode.ANY
tag_filter = ("Or", tag_conditions)
if search_mode == "vector":
# single vector search query
@@ -224,11 +231,11 @@ class TurbopufferClient:
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
}
if base_filter:
query_params["filters"] = base_filter
if tag_filter:
query_params["filters"] = tag_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, filters)
return self._process_single_query_results(result, archive_id, tags)
elif search_mode == "fts":
# single full-text search query
@@ -237,11 +244,11 @@ class TurbopufferClient:
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
}
if base_filter:
query_params["filters"] = base_filter
if tag_filter:
query_params["filters"] = tag_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, filters, is_fts=True)
return self._process_single_query_results(result, archive_id, tags, is_fts=True)
else: # hybrid mode
# multi-query for both vector and FTS
@@ -253,8 +260,8 @@ class TurbopufferClient:
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
}
if base_filter:
vector_query["filters"] = base_filter
if tag_filter:
vector_query["filters"] = tag_filter
queries.append(vector_query)
# full-text search query
@@ -263,16 +270,16 @@ class TurbopufferClient:
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
}
if base_filter:
fts_query["filters"] = base_filter
if tag_filter:
fts_query["filters"] = tag_filter
queries.append(fts_query)
# execute multi-query
response = await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
# process and combine results using reciprocal rank fusion
vector_results = self._process_single_query_results(response.results[0], archive_id, filters)
fts_results = self._process_single_query_results(response.results[1], archive_id, filters, is_fts=True)
vector_results = self._process_single_query_results(response.results[0], archive_id, tags)
fts_results = self._process_single_query_results(response.results[1], archive_id, tags, is_fts=True)
# combine results using reciprocal rank fusion
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
@@ -282,16 +289,16 @@ class TurbopufferClient:
raise
def _process_single_query_results(
self, result, archive_id: str, filters: Optional[Dict[str, Any]], is_fts: bool = False
self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
) -> List[Tuple[PydanticPassage, float]]:
"""Process results from a single query into passage objects with scores."""
passages_with_scores = []
for row in result.rows:
# Build metadata including any filter conditions that were applied
# Build metadata including any tag filters that were applied
metadata = {}
if filters:
metadata["applied_filters"] = filters
if tags:
metadata["applied_tags"] = tags
# Create a passage with minimal fields - embeddings are not returned from Turbopuffer
passage = PydanticPassage(
@@ -300,7 +307,7 @@ class TurbopufferClient:
organization_id=getattr(row, "organization_id", None),
archive_id=archive_id, # use the archive_id from the query
created_at=getattr(row, "created_at", None),
metadata_=metadata, # Include filter conditions in metadata
metadata_=metadata, # Include tag filters in metadata
# Set required fields to empty/default values since we don't store embeddings
embedding=[], # Empty embedding since we don't return it from Turbopuffer
embedding_config=None, # No embedding config needed for retrieved passages

View File

@@ -22,6 +22,7 @@ from letta.orm.mcp_server import MCPServer
from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import ArchivalPassage, BasePassage, SourcePassage
from letta.orm.passage_tag import PassageTag
from letta.orm.prompt import Prompt
from letta.orm.provider import Provider
from letta.orm.provider_trace import ProviderTrace

View File

@@ -16,6 +16,7 @@ if TYPE_CHECKING:
from letta.orm.llm_batch_job import LLMBatchJob
from letta.orm.message import Message
from letta.orm.passage import ArchivalPassage, SourcePassage
from letta.orm.passage_tag import PassageTag
from letta.orm.provider import Provider
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.tool import Tool
@@ -56,6 +57,7 @@ class Organization(SqlalchemyBase):
archival_passages: Mapped[List["ArchivalPassage"]] = relationship(
"ArchivalPassage", back_populates="organization", cascade="all, delete-orphan"
)
passage_tags: Mapped[List["PassageTag"]] = relationship("PassageTag", back_populates="organization", cascade="all, delete-orphan")
archives: Mapped[List["Archive"]] = relationship("Archive", back_populates="organization", cascade="all, delete-orphan")
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan")

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, Column, Index
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
@@ -27,6 +27,8 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
text: Mapped[str] = mapped_column(doc="Passage text content")
embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
# dual storage: json column for fast retrieval, junction table for efficient queries
tags: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="Tags associated with this passage")
# Vector embedding field based on database type
if settings.database_engine is DatabaseChoice.POSTGRES:
@@ -75,6 +77,11 @@ class ArchivalPassage(BasePassage, ArchiveMixin):
__tablename__ = "archival_passages"
# junction table for efficient tag queries (complements json column above)
passage_tags: Mapped[List["PassageTag"]] = relationship(
"PassageTag", back_populates="passage", cascade="all, delete-orphan", lazy="noload"
)
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="archival_passages", lazy="selectin")

55
letta/orm/passage_tag.py Normal file
View File

@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.passage import ArchivalPassage
class PassageTag(SqlalchemyBase, OrganizationMixin):
"""Junction table for tags associated with passages.
Design: dual storage approach where tags are stored both in:
1. JSON column in passages table (fast retrieval with passage data)
2. This junction table (efficient DISTINCT/COUNT queries and filtering)
"""
__tablename__ = "passage_tags"
__table_args__ = (
# ensure uniqueness of tag per passage
UniqueConstraint("passage_id", "tag", name="uq_passage_tag"),
# indexes for efficient queries
Index("ix_passage_tags_archive_id", "archive_id"),
Index("ix_passage_tags_tag", "tag"),
Index("ix_passage_tags_archive_tag", "archive_id", "tag"),
Index("ix_passage_tags_org_archive", "organization_id", "archive_id"),
)
# primary key
id: Mapped[str] = mapped_column(String, primary_key=True, doc="Unique identifier for the tag entry")
# tag value
tag: Mapped[str] = mapped_column(String, nullable=False, doc="The tag value")
# foreign keys
passage_id: Mapped[str] = mapped_column(
String, ForeignKey("archival_passages.id", ondelete="CASCADE"), nullable=False, doc="ID of the passage this tag belongs to"
)
archive_id: Mapped[str] = mapped_column(
String,
ForeignKey("archives.id", ondelete="CASCADE"),
nullable=False,
doc="ID of the archive this passage belongs to (denormalized for efficient queries)",
)
# relationships
passage: Mapped["ArchivalPassage"] = relationship("ArchivalPassage", back_populates="passage_tags", lazy="noload")
organization: Mapped["Organization"] = relationship("Organization", back_populates="passage_tags", lazy="selectin")

View File

@@ -179,3 +179,10 @@ class VectorDBProvider(str, Enum):
NATIVE = "native"
TPUF = "tpuf"
class TagMatchMode(str, Enum):
"""Tag matching behavior for filtering"""
ANY = "any"
ALL = "all"

View File

@@ -25,6 +25,7 @@ class PassageBase(OrmMetadataBase):
file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
file_name: Optional[str] = Field(None, description="The name of the file (only for source passages).")
metadata: Optional[Dict] = Field({}, validation_alias="metadata_", description="The metadata of the passage.")
tags: Optional[List[str]] = Field(None, description="Tags associated with this passage.")
class Passage(PassageBase):

View File

@@ -48,7 +48,7 @@ from letta.schemas.block import DEFAULT_BLOCKS
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType, ToolType, VectorDBProvider
from letta.schemas.enums import ProviderType, TagMatchMode, ToolType, VectorDBProvider
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.group import Group as PydanticGroup
from letta.schemas.group import ManagerType
@@ -2649,6 +2649,8 @@ class AgentManager:
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
tags: Optional[List[str]] = None,
tag_match_mode: Optional[TagMatchMode] = None,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
# Check if we should use Turbopuffer for vector search
@@ -2686,6 +2688,8 @@ class AgentManager:
query_text=query_text, # pass text for potential hybrid search
search_mode="hybrid", # use hybrid mode for better results
top_k=limit,
tags=tags,
tag_match_mode=tag_match_mode or TagMatchMode.ANY,
)
# Return just the passages (without scores)
@@ -2719,7 +2723,30 @@ class AgentManager:
passages = result.scalars().all()
# Convert to Pydantic models
return [p.to_pydantic() for p in passages]
pydantic_passages = [p.to_pydantic() for p in passages]
# TODO: Integrate tag filtering directly into the SQL query for better performance.
# Currently using post-filtering which is less efficient but simpler to implement.
# Future optimization: Add JOIN with passage_tags table and WHERE clause for tag filtering.
if tags:
filtered_passages = []
for passage in pydantic_passages:
if passage.tags:
passage_tags = set(passage.tags)
query_tags = set(tags)
if tag_match_mode == TagMatchMode.ALL:
# ALL mode: passage must have all query tags
if query_tags.issubset(passage_tags):
filtered_passages.append(passage)
else:
# ANY mode (default): passage must have at least one query tag
if query_tags.intersection(passage_tags):
filtered_passages.append(passage)
return filtered_passages
return pydantic_passages
@enforce_types
@trace_method

View File

@@ -1,9 +1,11 @@
import uuid
from datetime import datetime, timezone
from functools import lru_cache
from typing import List, Optional
from typing import Dict, List, Optional
from openai import AsyncOpenAI, OpenAI
from sqlalchemy import select
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from letta.constants import MAX_EMBEDDING_DIM
from letta.embeddings import parse_and_chunk_text
@@ -12,6 +14,7 @@ from letta.llm_api.llm_client import LLMClient
from letta.orm import ArchivesAgents
from letta.orm.errors import NoResultFound
from letta.orm.passage import ArchivalPassage, SourcePassage
from letta.orm.passage_tag import PassageTag
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import VectorDBProvider
@@ -48,6 +51,44 @@ class PassageManager:
def __init__(self):
self.archive_manager = ArchiveManager()
async def _create_tags_for_passage(
self,
session: AsyncSession,
passage_id: str,
archive_id: str,
organization_id: str,
tags: List[str],
actor: PydanticUser,
) -> List[PassageTag]:
"""Create tag entries in junction table (complements tags stored in JSON column).
Junction table enables efficient DISTINCT queries and tag-based filtering.
Note: Tags are already deduplicated before being passed to this method.
"""
if not tags:
return []
tag_objects = []
for tag in tags:
tag_obj = PassageTag(
id=f"passage-tag-{uuid.uuid4()}",
tag=tag,
passage_id=passage_id,
archive_id=archive_id,
organization_id=organization_id,
)
tag_objects.append(tag_obj)
# batch create all tags
created_tags = await PassageTag.batch_create_async(
items=tag_objects,
db_session=session,
actor=actor,
)
return created_tags
# AGENT PASSAGE METHODS
@enforce_types
@trace_method
@@ -155,6 +196,12 @@ class PassageManager:
raise ValueError("Agent passage cannot have source_id")
data = pydantic_passage.model_dump(to_orm=True)
# Deduplicate tags if provided (for dual storage consistency)
tags = data.get("tags")
if tags:
tags = list(set(tags))
common_fields = {
"id": data.get("id"),
"text": data["text"],
@@ -162,6 +209,7 @@ class PassageManager:
"embedding_config": data["embedding_config"],
"organization_id": data["organization_id"],
"metadata_": data.get("metadata", {}),
"tags": tags,
"is_deleted": data.get("is_deleted", False),
"created_at": data.get("created_at", datetime.now(timezone.utc)),
}
@@ -182,6 +230,12 @@ class PassageManager:
raise ValueError("Agent passage cannot have source_id")
data = pydantic_passage.model_dump(to_orm=True)
# Deduplicate tags if provided (for dual storage consistency)
tags = data.get("tags")
if tags:
tags = list(set(tags))
common_fields = {
"id": data.get("id"),
"text": data["text"],
@@ -189,6 +243,7 @@ class PassageManager:
"embedding_config": data["embedding_config"],
"organization_id": data["organization_id"],
"metadata_": data.get("metadata", {}),
"tags": tags,
"is_deleted": data.get("is_deleted", False),
"created_at": data.get("created_at", datetime.now(timezone.utc)),
}
@@ -197,6 +252,18 @@ class PassageManager:
async with db_registry.async_session() as session:
passage = await passage.create_async(session, actor=actor)
# dual storage: save tags to junction table for efficient queries
if tags: # use the deduplicated tags variable
await self._create_tags_for_passage(
session=session,
passage_id=passage.id,
archive_id=passage.archive_id,
organization_id=passage.organization_id,
tags=tags, # pass deduplicated tags
actor=actor,
)
return passage.to_pydantic()
@enforce_types
@@ -211,6 +278,12 @@ class PassageManager:
raise ValueError("Source passage cannot have archive_id")
data = pydantic_passage.model_dump(to_orm=True)
# Deduplicate tags if provided (for dual storage consistency)
tags = data.get("tags")
if tags:
tags = list(set(tags))
common_fields = {
"id": data.get("id"),
"text": data["text"],
@@ -218,6 +291,7 @@ class PassageManager:
"embedding_config": data["embedding_config"],
"organization_id": data["organization_id"],
"metadata_": data.get("metadata", {}),
"tags": tags,
"is_deleted": data.get("is_deleted", False),
"created_at": data.get("created_at", datetime.now(timezone.utc)),
}
@@ -244,6 +318,12 @@ class PassageManager:
raise ValueError("Source passage cannot have archive_id")
data = pydantic_passage.model_dump(to_orm=True)
# Deduplicate tags if provided (for dual storage consistency)
tags = data.get("tags")
if tags:
tags = list(set(tags))
common_fields = {
"id": data.get("id"),
"text": data["text"],
@@ -251,6 +331,7 @@ class PassageManager:
"embedding_config": data["embedding_config"],
"organization_id": data["organization_id"],
"metadata_": data.get("metadata", {}),
"tags": tags,
"is_deleted": data.get("is_deleted", False),
"created_at": data.get("created_at", datetime.now(timezone.utc)),
}
@@ -310,6 +391,7 @@ class PassageManager:
"embedding_config": data["embedding_config"],
"organization_id": data["organization_id"],
"metadata_": data.get("metadata", {}),
"tags": data.get("tags"),
"is_deleted": data.get("is_deleted", False),
"created_at": data.get("created_at", datetime.now(timezone.utc)),
}
@@ -357,6 +439,7 @@ class PassageManager:
"embedding_config": data["embedding_config"],
"organization_id": data["organization_id"],
"metadata_": data.get("metadata", {}),
"tags": data.get("tags"),
"is_deleted": data.get("is_deleted", False),
"created_at": data.get("created_at", datetime.now(timezone.utc)),
}
@@ -396,6 +479,7 @@ class PassageManager:
"embedding_config": data["embedding_config"],
"organization_id": data["organization_id"],
"metadata_": data.get("metadata", {}),
"tags": data.get("tags"),
"is_deleted": data.get("is_deleted", False),
"created_at": data.get("created_at", datetime.now(timezone.utc)),
}
@@ -466,8 +550,19 @@ class PassageManager:
agent_state: AgentState,
text: str,
actor: PydanticUser,
tags: Optional[List[str]] = None,
) -> List[PydanticPassage]:
"""Insert passage(s) into archival memory"""
"""Insert passage(s) into archival memory
Args:
agent_state: Agent state for embedding configuration
text: Text content to store as passages
actor: User performing the operation
tags: Optional list of tags to attach to all created passages
Returns:
List of created passage objects
"""
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
embedding_client = LLMClient.create(
@@ -500,6 +595,7 @@ class PassageManager:
text=chunk_text,
embedding=embedding,
embedding_config=agent_state.embedding_config,
tags=tags,
),
actor=actor,
)
@@ -522,6 +618,7 @@ class PassageManager:
embeddings=embeddings,
passage_ids=passage_ids, # Use same IDs as SQL
organization_id=actor.organization_id,
tags=tags,
created_at=passages[0].created_at if passages else None,
)
@@ -590,6 +687,34 @@ class PassageManager:
# Update the database record with values from the provided record
update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Handle tags update separately for junction table
new_tags = update_data.pop("tags", None)
if new_tags is not None:
# Deduplicate tags
if new_tags:
new_tags = list(set(new_tags))
# Delete existing tags from junction table
from sqlalchemy import delete
await session.execute(delete(PassageTag).where(PassageTag.passage_id == passage_id))
# Create new tags in junction table
if new_tags:
await self._create_tags_for_passage(
session=session,
passage_id=passage_id,
archive_id=curr_passage.archive_id,
organization_id=curr_passage.organization_id,
tags=new_tags,
actor=actor,
)
# Update the tags on the passage object
setattr(curr_passage, "tags", new_tags)
# Update other fields
for key, value in update_data.items():
setattr(curr_passage, key, value)
@@ -1067,3 +1192,69 @@ class PassageManager:
)
passages = result.scalars().all()
return [p.to_pydantic() for p in passages]
@enforce_types
@trace_method
async def get_unique_tags_for_archive_async(
self,
archive_id: str,
actor: PydanticUser,
) -> List[str]:
"""Get all unique tags for an archive.
Args:
archive_id: ID of the archive
actor: User performing the operation
Returns:
List of unique tag values
"""
async with db_registry.async_session() as session:
stmt = (
select(PassageTag.tag)
.distinct()
.where(
PassageTag.archive_id == archive_id,
PassageTag.organization_id == actor.organization_id,
PassageTag.is_deleted == False,
)
.order_by(PassageTag.tag)
)
result = await session.execute(stmt)
tags = result.scalars().all()
return list(tags)
@enforce_types
@trace_method
async def get_tag_counts_for_archive_async(
self,
archive_id: str,
actor: PydanticUser,
) -> Dict[str, int]:
"""Get tag counts for an archive.
Args:
archive_id: ID of the archive
actor: User performing the operation
Returns:
Dictionary mapping tag values to their counts
"""
async with db_registry.async_session() as session:
stmt = (
select(PassageTag.tag, func.count(PassageTag.id).label("count"))
.where(
PassageTag.archive_id == archive_id,
PassageTag.organization_id == actor.organization_id,
PassageTag.is_deleted == False,
)
.group_by(PassageTag.tag)
.order_by(PassageTag.tag)
)
result = await session.execute(stmt)
rows = result.all()
return {row.tag: row.count for row in rows}

View File

@@ -13,9 +13,7 @@ from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.tool_executor.tool_executor_base import ToolExecutor
from letta.utils import get_friendly_error_msg
@@ -143,7 +141,7 @@ class LettaCoreToolExecutor(ToolExecutor):
try:
# Get results using passage manager
all_results = await AgentManager().query_agent_passages_async(
all_results = await self.agent_manager.query_agent_passages_async(
actor=actor,
agent_id=agent_state.id,
query_text=query,
@@ -174,12 +172,12 @@ class LettaCoreToolExecutor(ToolExecutor):
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
await PassageManager().insert_passage(
await self.passage_manager.insert_passage(
agent_state=agent_state,
text=content,
actor=actor,
)
await AgentManager().rebuild_system_prompt_async(agent_id=agent_state.id, actor=actor, force=True)
await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_state.id, actor=actor, force=True)
return None
async def core_memory_append(self, agent_state: AgentState, actor: User, label: str, content: str) -> Optional[str]:
@@ -198,7 +196,7 @@ class LettaCoreToolExecutor(ToolExecutor):
current_value = str(agent_state.memory.get_block(label).value)
new_value = current_value + "\n" + str(content)
agent_state.memory.update_block_value(label=label, value=new_value)
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
return None
async def core_memory_replace(
@@ -227,7 +225,7 @@ class LettaCoreToolExecutor(ToolExecutor):
raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")
new_value = current_value.replace(str(old_content), str(new_content))
agent_state.memory.update_block_value(label=label, value=new_value)
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
return None
async def memory_replace(self, agent_state: AgentState, actor: User, label: str, old_str: str, new_str: str) -> str:
@@ -291,7 +289,7 @@ class LettaCoreToolExecutor(ToolExecutor):
# Write the new content to the block
agent_state.memory.update_block_value(label=label, value=new_value)
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
# Create a snippet of the edited section
SNIPPET_LINES = 3
@@ -384,7 +382,7 @@ class LettaCoreToolExecutor(ToolExecutor):
# Write into the block
agent_state.memory.update_block_value(label=label, value=new_value)
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
# Prepare the success message
success_msg = f"The core memory block with label `{label}` has been edited. "
@@ -437,7 +435,7 @@ class LettaCoreToolExecutor(ToolExecutor):
agent_state.memory.update_block_value(label=label, value=new_memory)
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
# Prepare the success message
success_msg = f"The core memory block with label `{label}` has been edited. "

View File

@@ -64,6 +64,7 @@ from letta.schemas.enums import (
ProviderType,
SandboxType,
StepStatus,
TagMatchMode,
ToolType,
)
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
@@ -3166,6 +3167,7 @@ def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_a
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata={"type": "test_specific"},
tags=["python", "test", "agent"],
),
actor=default_user,
)
@@ -3174,6 +3176,7 @@ def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_a
assert passage.text == "Test agent passage via specific method"
assert passage.archive_id == archive.id
assert passage.source_id is None
assert sorted(passage.tags) == sorted(["python", "test", "agent"])
def test_create_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
@@ -3187,6 +3190,7 @@ def test_create_source_passage_specific(server: SyncServer, default_user, defaul
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata={"type": "test_specific"},
tags=["document", "test", "source"],
),
file_metadata=default_file,
actor=default_user,
@@ -3196,6 +3200,7 @@ def test_create_source_passage_specific(server: SyncServer, default_user, defaul
assert passage.text == "Test source passage via specific method"
assert passage.source_id == default_source.id
assert passage.archive_id is None
assert sorted(passage.tags) == sorted(["document", "test", "source"])
def test_create_agent_passage_validation(server: SyncServer, default_user, default_source, sarah_agent):
@@ -3509,6 +3514,7 @@ async def test_create_many_agent_passages_async(server: SyncServer, default_user
organization_id=default_user.organization_id,
embedding=[0.1 * i],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
tags=["batch", f"item{i}"] if i % 2 == 0 else ["batch", "odd"],
)
for i in range(3)
]
@@ -3520,6 +3526,8 @@ async def test_create_many_agent_passages_async(server: SyncServer, default_user
assert passage.text == f"Batch agent passage {i}"
assert passage.archive_id == archive.id
assert passage.source_id is None
expected_tags = ["batch", f"item{i}"] if i % 2 == 0 else ["batch", "odd"]
assert passage.tags == expected_tags
@pytest.mark.asyncio
@@ -3611,6 +3619,374 @@ def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sara
assert any("size is deprecated" in str(warning.message) for warning in w)
@pytest.mark.asyncio
async def test_passage_tags_functionality(server: SyncServer, default_user, sarah_agent):
    """Test comprehensive tag functionality for passages."""
    from letta.schemas.enums import TagMatchMode

    # Resolve the agent's default archive (created lazily on first use).
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
        agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
    )

    # Fixtures covering overlapping tag sets, disjoint tags, and no tags at all.
    test_passages = [
        {"text": "Python programming tutorial", "tags": ["python", "tutorial", "programming"]},
        {"text": "Machine learning with Python", "tags": ["python", "ml", "ai"]},
        {"text": "JavaScript web development", "tags": ["javascript", "web", "frontend"]},
        {"text": "Python data science guide", "tags": ["python", "tutorial", "data"]},
        {"text": "No tags passage", "tags": None},
    ]

    created_passages = [
        await server.passage_manager.create_agent_passage_async(
            PydanticPassage(
                text=fixture["text"],
                archive_id=archive.id,
                organization_id=default_user.organization_id,
                embedding=[0.1, 0.2, 0.3],
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                tags=fixture["tags"],
            ),
            actor=default_user,
        )
        for fixture in test_passages
    ]

    # Stored tags match the input order-insensitively (deduplicated);
    # a passage created with tags=None reports None rather than [].
    for fixture, passage in zip(test_passages, created_passages):
        expected = fixture["tags"]
        if not expected:
            assert passage.tags is None
        else:
            assert set(passage.tags) == set(expected)

    # Tag-filtered querying is only exercised when the manager exposes the API
    # (i.e. when Turbopuffer support is enabled).
    if not hasattr(server.agent_manager, "query_agent_passages_async"):
        return

    # ANY mode: passages carrying the "python" tag (should find 3 passages).
    python_results = await server.agent_manager.query_agent_passages_async(
        actor=default_user,
        agent_id=sarah_agent.id,
        tags=["python"],
        tag_match_mode=TagMatchMode.ANY,
    )
    python_texts = [p.text for p in python_results]
    assert len([t for t in python_texts if "Python" in t]) >= 2

    # ALL mode: passages carrying both "python" AND "tutorial".
    tutorial_python_results = await server.agent_manager.query_agent_passages_async(
        actor=default_user,
        agent_id=sarah_agent.id,
        tags=["python", "tutorial"],
        tag_match_mode=TagMatchMode.ALL,
    )
    expected_matches = [
        p.text for p in tutorial_python_results if "tutorial" in p.text and "Python" in p.text
    ]
    assert len(expected_matches) >= 1
@pytest.mark.asyncio
async def test_comprehensive_tag_functionality(disable_turbopuffer, server: SyncServer, sarah_agent, default_user):
    """Comprehensive test for tag functionality including dual storage and junction table.

    Exercises the full tag lifecycle against the SQL fallback path (Turbopuffer
    disabled): creation, aggregation (unique tags / counts), ANY/ALL query
    filtering, update, cascade deletion, and batch insertion.
    """
    # Test 1: Create passages with tags and verify they're stored in both places
    # (the tags JSON column on the passage and the passage_tags junction table).
    passages_with_tags = []
    test_tags = {
        "passage1": ["important", "documentation", "python"],
        "passage2": ["important", "testing"],
        "passage3": ["documentation", "api"],
        "passage4": ["python", "testing", "api"],
        "passage5": [],  # Test empty tags
    }
    # NOTE: the dict keys are only labels for the reader; iteration order is
    # insertion order, so passage i gets test_tags["passage{i}"].
    for i, tags in enumerate(test_tags.values(), 1):
        text = f"Test passage {i} for comprehensive tag testing"
        created_passages = await server.passage_manager.insert_passage(
            agent_state=sarah_agent,
            text=text,
            actor=default_user,
            tags=tags if tags else None,
        )
        assert len(created_passages) == 1
        passage = created_passages[0]
        # Verify tags are stored in the JSON column (deduplicated)
        if tags:
            assert set(passage.tags) == set(tags)
        else:
            assert passage.tags is None
        passages_with_tags.append(passage)

    # Test 2: Verify unique tags for archive
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
        agent_id=sarah_agent.id,
        agent_name=sarah_agent.name,
        actor=default_user,
    )
    unique_tags = await server.passage_manager.get_unique_tags_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )
    # Should have all unique tags: "important", "documentation", "python", "testing", "api"
    expected_unique_tags = {"important", "documentation", "python", "testing", "api"}
    assert set(unique_tags) == expected_unique_tags
    assert len(unique_tags) == 5

    # Test 3: Verify tag counts — each tag appears in exactly two fixture passages.
    tag_counts = await server.passage_manager.get_tag_counts_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )
    assert tag_counts["important"] == 2  # passage1 and passage2
    assert tag_counts["documentation"] == 2  # passage1 and passage3
    assert tag_counts["python"] == 2  # passage1 and passage4
    assert tag_counts["testing"] == 2  # passage2 and passage4
    assert tag_counts["api"] == 2  # passage3 and passage4

    # Test 4: Query passages with ANY tag matching
    any_results = await server.agent_manager.query_agent_passages_async(
        agent_id=sarah_agent.id,
        query_text="test",
        limit=10,
        tags=["important", "api"],
        tag_match_mode=TagMatchMode.ANY,
        actor=default_user,
    )
    # Should match passages with "important" OR "api" tags (passages 1, 2, 3, 4)
    assert len(any_results) >= 4

    # Test 5: Query passages with ALL tag matching
    all_results = await server.agent_manager.query_agent_passages_async(
        agent_id=sarah_agent.id,
        query_text="test",
        limit=10,
        tags=["python", "testing"],
        tag_match_mode=TagMatchMode.ALL,
        actor=default_user,
    )
    # Should only match passage4 which has both "python" AND "testing"
    all_passage_texts = [p.text for p in all_results]
    assert any("Test passage 4" in text for text in all_passage_texts)

    # Test 6: Query with non-existent tags
    no_results = await server.agent_manager.query_agent_passages_async(
        agent_id=sarah_agent.id,
        query_text="test",
        limit=10,
        tags=["nonexistent", "missing"],
        tag_match_mode=TagMatchMode.ANY,
        actor=default_user,
    )
    # Should return no results
    assert len(no_results) == 0

    # Test 7: Verify tags CAN be updated (with junction table properly maintained)
    first_passage = passages_with_tags[0]
    new_tags = ["updated", "modified", "changed"]
    update_data = PydanticPassage(
        id=first_passage.id,
        text="Updated text",
        tags=new_tags,
        organization_id=first_passage.organization_id,
        archive_id=first_passage.archive_id,
        embedding=first_passage.embedding,
        embedding_config=first_passage.embedding_config,
    )
    # Update should work and tags should be updated
    updated = await server.passage_manager.update_agent_passage_by_id_async(
        passage_id=first_passage.id,
        passage=update_data,
        actor=default_user,
    )
    # Both text and tags should be updated
    assert updated.text == "Updated text"
    assert set(updated.tags) == set(new_tags)
    # Verify tags are properly updated in junction table
    updated_unique_tags = await server.passage_manager.get_unique_tags_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )
    # Should include new tags and not include old "important", "documentation", "python" from passage1
    # But still have tags from other passages
    assert "updated" in updated_unique_tags
    assert "modified" in updated_unique_tags
    assert "changed" in updated_unique_tags

    # Test 8: Delete a passage and verify cascade deletion of tags
    passage_to_delete = passages_with_tags[1]  # passage2 with ["important", "testing"]
    await server.passage_manager.delete_agent_passage_by_id_async(
        passage_id=passage_to_delete.id,
        actor=default_user,
    )
    # Get updated tag counts
    updated_tag_counts = await server.passage_manager.get_tag_counts_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )
    # "important" no longer exists (was in passage1 which was updated and passage2 which was deleted)
    assert "important" not in updated_tag_counts
    # "testing" count should decrease from 2 to 1 (only in passage4 now)
    assert updated_tag_counts["testing"] == 1

    # Test 9: Batch create passages with tags
    batch_texts = [
        "Batch passage 1",
        "Batch passage 2",
        "Batch passage 3",
    ]
    batch_tags = ["batch", "test", "multiple"]
    batch_passages = []
    for text in batch_texts:
        passages = await server.passage_manager.insert_passage(
            agent_state=sarah_agent,
            text=text,
            actor=default_user,
            tags=batch_tags,
        )
        batch_passages.extend(passages)
    # Verify all batch passages have the same tags
    for passage in batch_passages:
        assert set(passage.tags) == set(batch_tags)

    # Test 10: Verify tag counts include batch passages
    final_tag_counts = await server.passage_manager.get_tag_counts_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )
    assert final_tag_counts["batch"] == 3
    assert final_tag_counts["test"] == 3
    assert final_tag_counts["multiple"] == 3

    # Test 11: Complex query with multiple tags and ALL matching
    complex_all_results = await server.agent_manager.query_agent_passages_async(
        agent_id=sarah_agent.id,
        query_text="batch",
        limit=10,
        tags=["batch", "test", "multiple"],
        tag_match_mode=TagMatchMode.ALL,
        actor=default_user,
    )
    # Should match all 3 batch passages
    assert len(complex_all_results) >= 3

    # Test 12: Empty tag list should return all passages
    all_passages = await server.agent_manager.query_agent_passages_async(
        agent_id=sarah_agent.id,
        query_text="passage",
        limit=50,
        tags=[],
        tag_match_mode=TagMatchMode.ANY,
        actor=default_user,
    )
    # Should return passages based on text search only
    assert len(all_passages) > 0
@pytest.mark.asyncio
async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_agent, default_user):
    """Test edge cases for tag functionality."""

    async def _insert(text, tags):
        # All edge cases insert through the same code path; factor the call out.
        return await server.passage_manager.insert_passage(
            agent_state=sarah_agent,
            text=text,
            actor=default_user,
            tags=tags,
        )

    # Test 1: Very long tag names
    oversized_tag = "a" * 500  # 500 character tag
    long_result = await _insert("Testing long tag names", [oversized_tag, "normal_tag"])
    assert len(long_result) == 1
    assert oversized_tag in long_result[0].tags

    # Test 2: Special characters in tags — punctuation, whitespace, mixed case,
    # and non-ASCII text must all round-trip unchanged.
    special_tags = [
        "tag-with-dash",
        "tag_with_underscore",
        "tag.with.dots",
        "tag/with/slash",
        "tag:with:colon",
        "tag@with@at",
        "tag#with#hash",
        "tag with spaces",
        "CamelCaseTag",
        "数字标签",
    ]
    special_result = await _insert("Testing special character tags", special_tags)
    assert len(special_result) == 1
    assert set(special_result[0].tags) == set(special_tags)

    # Verify the archive-level unique-tag listing includes every special tag.
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
        agent_id=sarah_agent.id,
        agent_name=sarah_agent.name,
        actor=default_user,
    )
    unique_tags = await server.passage_manager.get_unique_tags_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )
    assert all(tag in unique_tags for tag in special_tags)

    # Test 3: Duplicate tags in input (should be deduplicated)
    dup_result = await _insert(
        "Testing duplicate tags", ["tag1", "tag2", "tag1", "tag3", "tag2", "tag1"]
    )
    assert len(dup_result) == 1
    assert set(dup_result[0].tags) == {"tag1", "tag2", "tag3"}
    assert len(dup_result[0].tags) == 3  # Should be deduplicated

    # Test 4: Case sensitivity in tags — all variations are distinct and preserved.
    case_variants = ["Tag", "tag", "TAG", "tAg"]
    case_result = await _insert("Testing case sensitive tags", case_variants)
    assert len(case_result) == 1
    assert set(case_result[0].tags) == set(case_variants)
# ======================================================================================================================
# User Manager Tests
# ======================================================================================================================

View File

@@ -6,7 +6,7 @@ import pytest
from letta.config import LettaConfig
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import VectorDBProvider
from letta.schemas.enums import TagMatchMode, VectorDBProvider
from letta.server.server import SyncServer
from letta.settings import settings
@@ -233,16 +233,14 @@ class TestTurbopufferIntegration:
assert len(result) == 3
# Query with organization filter
# Query all passages (no tag filtering)
query_vector = [0.15] * 1536
results = await client.query_passages(
archive_id=archive_id, query_embedding=query_vector, top_k=10, filters={"organization_id": "org-123"}
)
results = await client.query_passages(archive_id=archive_id, query_embedding=query_vector, top_k=10)
# Should only get passages from org-123
assert len(results) >= 2 # At least the first two passages
# Should get all passages
assert len(results) == 3 # All three passages
for passage, score in results:
assert passage.organization_id == "org-123"
assert passage.organization_id is not None
# Clean up
await client.delete_passages(archive_id=archive_id, passage_ids=[d["id"] for d in test_data])
@@ -394,6 +392,130 @@ class TestTurbopufferIntegration:
except:
pass
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
async def test_tag_filtering_with_real_tpuf(self, enable_turbopuffer):
    """Test tag filtering functionality with AND and OR logic.

    Runs against a real Turbopuffer backend (skipped without an API key) and
    verifies ANY/ALL tag filters across vector, FTS, and hybrid search modes.
    """
    import uuid

    from letta.helpers.tpuf_client import TurbopufferClient

    client = TurbopufferClient()
    # Timestamped archive id so concurrent/repeated runs never collide.
    archive_id = f"test-tags-{datetime.now().timestamp()}"
    org_id = str(uuid.uuid4())
    try:
        # Insert passages with different tag combinations
        texts = [
            "Python programming tutorial",
            "Machine learning with Python",
            "JavaScript web development",
            "Python data science tutorial",
            "React JavaScript framework",
        ]
        tag_sets = [
            ["python", "tutorial"],
            ["python", "ml"],
            ["javascript", "web"],
            ["python", "tutorial", "data"],
            ["javascript", "react"],
        ]
        # Distinct, deterministic embeddings so vector queries rank predictably.
        embeddings = [[float(i), float(i + 5), float(i + 10)] for i in range(len(texts))]
        passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts]
        # Insert passages with tags
        for text, tags, embedding, passage_id in zip(texts, tag_sets, embeddings, passage_ids):
            await client.insert_archival_memories(
                archive_id=archive_id,
                text_chunks=[text],
                embeddings=[embedding],
                passage_ids=[passage_id],
                organization_id=org_id,
                tags=tags,
                created_at=datetime.now(timezone.utc),
            )

        # Test tag filtering with "any" mode (should find passages with any of the specified tags)
        python_any_results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[1.0, 6.0, 11.0],
            search_mode="vector",
            top_k=10,
            tags=["python"],
            tag_match_mode=TagMatchMode.ANY,
        )
        # Should find 3 passages with python tag
        python_passages = [passage for passage, _ in python_any_results]
        python_texts = [p.text for p in python_passages]
        assert len(python_passages) == 3
        assert "Python programming tutorial" in python_texts
        assert "Machine learning with Python" in python_texts
        assert "Python data science tutorial" in python_texts

        # Test tag filtering with "all" mode
        python_tutorial_all_results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[1.0, 6.0, 11.0],
            search_mode="vector",
            top_k=10,
            tags=["python", "tutorial"],
            tag_match_mode=TagMatchMode.ALL,
        )
        # Should find 2 passages that have both python AND tutorial tags
        tutorial_passages = [passage for passage, _ in python_tutorial_all_results]
        tutorial_texts = [p.text for p in tutorial_passages]
        assert len(tutorial_passages) == 2
        assert "Python programming tutorial" in tutorial_texts
        assert "Python data science tutorial" in tutorial_texts

        # Test tag filtering with FTS mode
        js_fts_results = await client.query_passages(
            archive_id=archive_id,
            query_text="javascript",
            search_mode="fts",
            top_k=10,
            tags=["javascript"],
            tag_match_mode=TagMatchMode.ANY,
        )
        # Should find 2 passages with javascript tag
        js_passages = [passage for passage, _ in js_fts_results]
        js_texts = [p.text for p in js_passages]
        assert len(js_passages) == 2
        assert "JavaScript web development" in js_texts
        assert "React JavaScript framework" in js_texts

        # Test hybrid search with tags
        python_hybrid_results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[2.0, 7.0, 12.0],
            query_text="python programming",
            search_mode="hybrid",
            top_k=10,
            tags=["python"],
            tag_match_mode=TagMatchMode.ANY,
            vector_weight=0.6,
            fts_weight=0.4,
        )
        # Should find python-tagged passages
        hybrid_passages = [passage for passage, _ in python_hybrid_results]
        hybrid_texts = [p.text for p in hybrid_passages]
        assert len(hybrid_passages) == 3
        assert all("Python" in text for text in hybrid_texts)
    finally:
        # Best-effort cleanup of remote state; narrow to Exception so
        # KeyboardInterrupt/SystemExit are never swallowed (was a bare except).
        try:
            await client.delete_all_passages(archive_id)
        except Exception:
            pass
@pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
class TestTurbopufferParametrized: