diff --git a/alembic/versions/54c76f7cabca_add_tags_to_passages_and_create_passage_.py b/alembic/versions/54c76f7cabca_add_tags_to_passages_and_create_passage_.py new file mode 100644 index 00000000..0cfa65f5 --- /dev/null +++ b/alembic/versions/54c76f7cabca_add_tags_to_passages_and_create_passage_.py @@ -0,0 +1,73 @@ +"""Add tags to passages and create passage_tags junction table + +Revision ID: 54c76f7cabca +Revises: c41c87205254 +Create Date: 2025-08-28 15:13:01.549590 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op +from letta.settings import settings + +# revision identifiers, used by Alembic. +revision: str = "54c76f7cabca" +down_revision: Union[str, None] = "c41c87205254" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + # Database-specific timestamp defaults + if not settings.letta_pg_uri_no_default: + # SQLite uses CURRENT_TIMESTAMP + timestamp_default = sa.text("(CURRENT_TIMESTAMP)") + else: + # PostgreSQL uses now() + timestamp_default = sa.text("now()") + + op.create_table( + "passage_tags", + sa.Column("id", sa.String(), nullable=False), + sa.Column("tag", sa.String(), nullable=False), + sa.Column("passage_id", sa.String(), nullable=False), + sa.Column("archive_id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=timestamp_default, nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=timestamp_default, nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["archive_id"], ["archives.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.ForeignKeyConstraint(["passage_id"], ["archival_passages.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("passage_id", "tag", name="uq_passage_tag"), + ) + op.create_index("ix_passage_tags_archive_id", "passage_tags", ["archive_id"], unique=False) + op.create_index("ix_passage_tags_archive_tag", "passage_tags", ["archive_id", "tag"], unique=False) + op.create_index("ix_passage_tags_org_archive", "passage_tags", ["organization_id", "archive_id"], unique=False) + op.create_index("ix_passage_tags_tag", "passage_tags", ["tag"], unique=False) + op.add_column("archival_passages", sa.Column("tags", sa.JSON(), nullable=True)) + op.add_column("source_passages", sa.Column("tags", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("source_passages", "tags") + op.drop_column("archival_passages", "tags") + op.drop_index("ix_passage_tags_tag", table_name="passage_tags") + op.drop_index("ix_passage_tags_org_archive", table_name="passage_tags") + op.drop_index("ix_passage_tags_archive_tag", table_name="passage_tags") + op.drop_index("ix_passage_tags_archive_id", table_name="passage_tags") + op.drop_table("passage_tags") + # ### end Alembic commands ### diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 2ed4b4ea..3519dd60 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -2,9 +2,10 @@ import logging from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple from letta.otel.tracing import trace_method +from letta.schemas.enums import TagMatchMode from letta.schemas.passage import Passage as PydanticPassage from letta.settings import settings @@ -168,7 +169,8 @@ class TurbopufferClient: query_text: Optional[str] = None, search_mode: str = "vector", # "vector", "fts", "hybrid" top_k: int = 10, - filters: Optional[Dict[str, Any]] = None, + tags: Optional[List[str]] = None, + tag_match_mode: TagMatchMode = TagMatchMode.ANY, vector_weight: float = 0.5, fts_weight: float = 0.5, ) -> List[Tuple[PydanticPassage, float]]: @@ -180,7 +182,8 @@ class TurbopufferClient: query_text: Text query for full-text search (required for "fts" and "hybrid" modes) search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector") top_k: Number of results to return - filters: Optional filter conditions + tags: Optional list of tags to filter by + tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY vector_weight: Weight for vector search results in hybrid mode (default: 0.5) fts_weight: Weight for FTS results in hybrid mode (default: 0.5) @@ -207,15 +210,19 @@ class TurbopufferClient: async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: namespace = client.namespace(namespace_name) - # build filter conditions - filter_conditions = [] - if filters: - for key, value in filters.items(): - filter_conditions.append((key, "Eq", value)) + # build tag filter conditions + tag_filter = None + if tags: + tag_conditions = [] + for tag in tags: + tag_conditions.append((tag, "Eq", True)) - base_filter = ( - ("And", filter_conditions) if len(filter_conditions) > 1 else (filter_conditions[0] if filter_conditions else None) - ) + if len(tag_conditions) == 1: + tag_filter = tag_conditions[0] + elif tag_match_mode == TagMatchMode.ALL: + tag_filter = ("And", tag_conditions) + else: # tag_match_mode == TagMatchMode.ANY + tag_filter = ("Or", tag_conditions) if search_mode == "vector": # single vector search query @@ -224,11 +231,11 @@ class TurbopufferClient: "top_k": top_k, "include_attributes": ["text", "organization_id", "archive_id", "created_at"], } - if base_filter: - query_params["filters"] = base_filter + if tag_filter: + query_params["filters"] = tag_filter result = await namespace.query(**query_params) - return self._process_single_query_results(result, archive_id, filters) + return self._process_single_query_results(result, archive_id, tags) elif search_mode == "fts": # single full-text search query @@ -237,11 +244,11 @@ class TurbopufferClient: "top_k": top_k, "include_attributes": ["text", "organization_id", "archive_id", "created_at"], } - if base_filter: - query_params["filters"] = base_filter + if tag_filter: + query_params["filters"] = tag_filter result = await namespace.query(**query_params) - return self._process_single_query_results(result, archive_id, filters, is_fts=True) + return self._process_single_query_results(result, archive_id, tags, is_fts=True) else: # hybrid mode # multi-query for both vector and FTS @@ -253,8 +260,8 @@ class TurbopufferClient: "top_k": top_k, "include_attributes": ["text", "organization_id", "archive_id", "created_at"], } - if base_filter: - vector_query["filters"] = base_filter + if tag_filter: + vector_query["filters"] = tag_filter queries.append(vector_query) # full-text search query @@ -263,16 +270,16 @@ class TurbopufferClient: "top_k": top_k, "include_attributes": ["text", "organization_id", "archive_id", "created_at"], } - if base_filter: - fts_query["filters"] = base_filter + if tag_filter: + fts_query["filters"] = tag_filter queries.append(fts_query) # execute multi-query response = await namespace.multi_query(queries=[QueryParam(**q) for q in queries]) # process and combine results using reciprocal rank fusion - vector_results = self._process_single_query_results(response.results[0], archive_id, filters) - fts_results = self._process_single_query_results(response.results[1], archive_id, filters, is_fts=True) + vector_results = self._process_single_query_results(response.results[0], archive_id, tags) + fts_results = self._process_single_query_results(response.results[1], archive_id, tags, is_fts=True) # combine results using reciprocal rank fusion return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k) @@ -282,16 +289,16 @@ class TurbopufferClient: raise def _process_single_query_results( - self, result, archive_id: str, filters: Optional[Dict[str, Any]], is_fts: bool = False + self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False ) -> List[Tuple[PydanticPassage, float]]: """Process results from a single query into passage objects with scores.""" passages_with_scores = [] for row in result.rows: - # Build metadata including any filter conditions that were applied + # Build metadata including any tag filters that were applied metadata = {} - if filters: - metadata["applied_filters"] = filters + if tags: + metadata["applied_tags"] = tags # Create a passage with minimal fields - embeddings are not returned from Turbopuffer passage = PydanticPassage( @@ -300,7 +307,7 @@ class TurbopufferClient: organization_id=getattr(row, "organization_id", None), archive_id=archive_id, # use the archive_id from the query created_at=getattr(row, "created_at", None), - metadata_=metadata, # Include filter conditions in metadata + metadata_=metadata, # Include tag filters in metadata # Set required fields to empty/default values since we don't store embeddings embedding=[], # Empty embedding since we don't return it from Turbopuffer embedding_config=None, # No embedding config needed for retrieved passages diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index e5dcd47e..f2d1bd15 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -22,6 +22,7 @@ from letta.orm.mcp_server import MCPServer from letta.orm.message import Message from letta.orm.organization import Organization from letta.orm.passage import ArchivalPassage, BasePassage, SourcePassage +from letta.orm.passage_tag import PassageTag from letta.orm.prompt import Prompt from letta.orm.provider import Provider from letta.orm.provider_trace import ProviderTrace diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 6e63df5a..57ab0c52 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from letta.orm.llm_batch_job import LLMBatchJob from letta.orm.message import Message from letta.orm.passage import ArchivalPassage, SourcePassage + from letta.orm.passage_tag import PassageTag from letta.orm.provider import Provider from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.tool import Tool @@ -56,6 +57,7 @@ class Organization(SqlalchemyBase): archival_passages: Mapped[List["ArchivalPassage"]] = relationship( "ArchivalPassage", back_populates="organization", cascade="all, delete-orphan" ) + passage_tags: Mapped[List["PassageTag"]] = relationship("PassageTag", back_populates="organization", cascade="all, delete-orphan") archives: Mapped[List["Archive"]] = relationship("Archive", back_populates="organization", cascade="all, delete-orphan") providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan") identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/orm/passage.py b/letta/orm/passage.py index 9507ffc0..cf17bc83 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional from sqlalchemy import JSON, Column, Index from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship @@ -27,6 +27,8 @@ class BasePassage(SqlalchemyBase, OrganizationMixin): text: Mapped[str] = mapped_column(doc="Passage text content") embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration") metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata") + # dual storage: json column for fast retrieval, junction table for efficient queries + tags: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="Tags associated with this passage") # Vector embedding field based on database type if settings.database_engine is DatabaseChoice.POSTGRES: @@ -75,6 +77,11 @@ class ArchivalPassage(BasePassage, ArchiveMixin): __tablename__ = "archival_passages" + # junction table for efficient tag queries (complements json column above) + passage_tags: Mapped[List["PassageTag"]] = relationship( + "PassageTag", back_populates="passage", cascade="all, delete-orphan", lazy="noload" + ) + @declared_attr def organization(cls) -> Mapped["Organization"]: return relationship("Organization", back_populates="archival_passages", lazy="selectin") diff --git a/letta/orm/passage_tag.py b/letta/orm/passage_tag.py new file mode 100644 index 00000000..45f24f0a --- /dev/null +++ b/letta/orm/passage_tag.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from sqlalchemy import ForeignKey, Index, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase + +if TYPE_CHECKING: + from letta.orm.organization import Organization + from letta.orm.passage import ArchivalPassage + + +class PassageTag(SqlalchemyBase, OrganizationMixin): + """Junction table for tags associated with passages. + + Design: dual storage approach where tags are stored both in: + 1. JSON column in passages table (fast retrieval with passage data) + 2. This junction table (efficient DISTINCT/COUNT queries and filtering) + """ + + __tablename__ = "passage_tags" + + __table_args__ = ( + # ensure uniqueness of tag per passage + UniqueConstraint("passage_id", "tag", name="uq_passage_tag"), + # indexes for efficient queries + Index("ix_passage_tags_archive_id", "archive_id"), + Index("ix_passage_tags_tag", "tag"), + Index("ix_passage_tags_archive_tag", "archive_id", "tag"), + Index("ix_passage_tags_org_archive", "organization_id", "archive_id"), + ) + + # primary key + id: Mapped[str] = mapped_column(String, primary_key=True, doc="Unique identifier for the tag entry") + + # tag value + tag: Mapped[str] = mapped_column(String, nullable=False, doc="The tag value") + + # foreign keys + passage_id: Mapped[str] = mapped_column( + String, ForeignKey("archival_passages.id", ondelete="CASCADE"), nullable=False, doc="ID of the passage this tag belongs to" + ) + + archive_id: Mapped[str] = mapped_column( + String, + ForeignKey("archives.id", ondelete="CASCADE"), + nullable=False, + doc="ID of the archive this passage belongs to (denormalized for efficient queries)", + ) + + # relationships + passage: Mapped["ArchivalPassage"] = relationship("ArchivalPassage", back_populates="passage_tags", lazy="noload") + + organization: Mapped["Organization"] = relationship("Organization", back_populates="passage_tags", lazy="selectin") diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index b8b62802..da8182bb 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -179,3 +179,10 @@ class VectorDBProvider(str, Enum): NATIVE = "native" TPUF = "tpuf" + + +class TagMatchMode(str, Enum): + """Tag matching behavior for filtering""" + + ANY = "any" + ALL = "all" diff --git a/letta/schemas/passage.py b/letta/schemas/passage.py index 57ab3f3c..fdaac2f2 100644 --- a/letta/schemas/passage.py +++ b/letta/schemas/passage.py @@ -25,6 +25,7 @@ class PassageBase(OrmMetadataBase): file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.") file_name: Optional[str] = Field(None, description="The name of the file (only for source passages).") metadata: Optional[Dict] = Field({}, validation_alias="metadata_", description="The metadata of the passage.") + tags: Optional[List[str]] = Field(None, description="Tags associated with this passage.") class Passage(PassageBase): diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 735bc588..cdc3a678 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -48,7 +48,7 @@ from letta.schemas.block import DEFAULT_BLOCKS from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import ProviderType, ToolType, VectorDBProvider +from letta.schemas.enums import ProviderType, TagMatchMode, ToolType, VectorDBProvider from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.group import Group as PydanticGroup from letta.schemas.group import ManagerType @@ -2649,6 +2649,8 @@ class AgentManager: embed_query: bool = False, ascending: bool = True, embedding_config: Optional[EmbeddingConfig] = None, + tags: Optional[List[str]] = None, + tag_match_mode: Optional[TagMatchMode] = None, ) -> List[PydanticPassage]: """Lists all passages attached to an agent.""" # Check if we should use Turbopuffer for vector search @@ -2686,6 +2688,8 @@ class AgentManager: query_text=query_text, # pass text for potential hybrid search search_mode="hybrid", # use hybrid mode for better results top_k=limit, + tags=tags, + tag_match_mode=tag_match_mode or TagMatchMode.ANY, ) # Return just the passages (without scores) @@ -2719,7 +2723,30 @@ class AgentManager: passages = result.scalars().all() # Convert to Pydantic models - return [p.to_pydantic() for p in passages] + pydantic_passages = [p.to_pydantic() for p in passages] + + # TODO: Integrate tag filtering directly into the SQL query for better performance. + # Currently using post-filtering which is less efficient but simpler to implement. + # Future optimization: Add JOIN with passage_tags table and WHERE clause for tag filtering. + if tags: + filtered_passages = [] + for passage in pydantic_passages: + if passage.tags: + passage_tags = set(passage.tags) + query_tags = set(tags) + + if tag_match_mode == TagMatchMode.ALL: + # ALL mode: passage must have all query tags + if query_tags.issubset(passage_tags): + filtered_passages.append(passage) + else: + # ANY mode (default): passage must have at least one query tag + if query_tags.intersection(passage_tags): + filtered_passages.append(passage) + + return filtered_passages + + return pydantic_passages @enforce_types @trace_method diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 7a9e825e..453d1e0e 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -1,9 +1,11 @@ +import uuid from datetime import datetime, timezone from functools import lru_cache -from typing import List, Optional +from typing import Dict, List, Optional from openai import AsyncOpenAI, OpenAI -from sqlalchemy import select +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession from letta.constants import MAX_EMBEDDING_DIM from letta.embeddings import parse_and_chunk_text @@ -12,6 +14,7 @@ from letta.llm_api.llm_client import LLMClient from letta.orm import ArchivesAgents from letta.orm.errors import NoResultFound from letta.orm.passage import ArchivalPassage, SourcePassage +from letta.orm.passage_tag import PassageTag from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState from letta.schemas.enums import VectorDBProvider @@ -48,6 +51,44 @@ class PassageManager: def __init__(self): self.archive_manager = ArchiveManager() + async def _create_tags_for_passage( + self, + session: AsyncSession, + passage_id: str, + archive_id: str, + organization_id: str, + tags: List[str], + actor: PydanticUser, + ) -> List[PassageTag]: + """Create tag entries in junction table (complements tags stored in JSON column). + + Junction table enables efficient DISTINCT queries and tag-based filtering. + + Note: Tags are already deduplicated before being passed to this method. + """ + if not tags: + return [] + + tag_objects = [] + for tag in tags: + tag_obj = PassageTag( + id=f"passage-tag-{uuid.uuid4()}", + tag=tag, + passage_id=passage_id, + archive_id=archive_id, + organization_id=organization_id, + ) + tag_objects.append(tag_obj) + + # batch create all tags + created_tags = await PassageTag.batch_create_async( + items=tag_objects, + db_session=session, + actor=actor, + ) + + return created_tags + # AGENT PASSAGE METHODS @enforce_types @trace_method @@ -155,6 +196,12 @@ class PassageManager: raise ValueError("Agent passage cannot have source_id") data = pydantic_passage.model_dump(to_orm=True) + + # Deduplicate tags if provided (for dual storage consistency) + tags = data.get("tags") + if tags: + tags = list(set(tags)) + common_fields = { "id": data.get("id"), "text": data["text"], @@ -162,6 +209,7 @@ class PassageManager: "embedding_config": data["embedding_config"], "organization_id": data["organization_id"], "metadata_": data.get("metadata", {}), + "tags": tags, "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } @@ -182,6 +230,12 @@ class PassageManager: raise ValueError("Agent passage cannot have source_id") data = pydantic_passage.model_dump(to_orm=True) + + # Deduplicate tags if provided (for dual storage consistency) + tags = data.get("tags") + if tags: + tags = list(set(tags)) + common_fields = { "id": data.get("id"), "text": data["text"], @@ -189,6 +243,7 @@ class PassageManager: "embedding_config": data["embedding_config"], "organization_id": data["organization_id"], "metadata_": data.get("metadata", {}), + "tags": tags, "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } @@ -197,6 +252,18 @@ class PassageManager: async with db_registry.async_session() as session: passage = await passage.create_async(session, actor=actor) + + # dual storage: save tags to junction table for efficient queries + if tags: # use the deduplicated tags variable + await self._create_tags_for_passage( + session=session, + passage_id=passage.id, + archive_id=passage.archive_id, + organization_id=passage.organization_id, + tags=tags, # pass deduplicated tags + actor=actor, + ) + return passage.to_pydantic() @enforce_types @@ -211,6 +278,12 @@ class PassageManager: raise ValueError("Source passage cannot have archive_id") data = pydantic_passage.model_dump(to_orm=True) + + # Deduplicate tags if provided (for dual storage consistency) + tags = data.get("tags") + if tags: + tags = list(set(tags)) + common_fields = { "id": data.get("id"), "text": data["text"], @@ -218,6 +291,7 @@ class PassageManager: "embedding_config": data["embedding_config"], "organization_id": data["organization_id"], "metadata_": data.get("metadata", {}), + "tags": tags, "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } @@ -244,6 +318,12 @@ class PassageManager: raise ValueError("Source passage cannot have archive_id") data = pydantic_passage.model_dump(to_orm=True) + + # Deduplicate tags if provided (for dual storage consistency) + tags = data.get("tags") + if tags: + tags = list(set(tags)) + common_fields = { "id": data.get("id"), "text": data["text"], @@ -251,6 +331,7 @@ class PassageManager: "embedding_config": data["embedding_config"], "organization_id": data["organization_id"], "metadata_": data.get("metadata", {}), + "tags": tags, "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } @@ -310,6 +391,7 @@ class PassageManager: "embedding_config": data["embedding_config"], "organization_id": data["organization_id"], "metadata_": data.get("metadata", {}), + "tags": data.get("tags"), "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } @@ -357,6 +439,7 @@ class PassageManager: "embedding_config": data["embedding_config"], "organization_id": data["organization_id"], "metadata_": data.get("metadata", {}), + "tags": data.get("tags"), "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } @@ -396,6 +479,7 @@ class PassageManager: "embedding_config": data["embedding_config"], "organization_id": data["organization_id"], "metadata_": data.get("metadata", {}), + "tags": data.get("tags"), "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } @@ -466,8 +550,19 @@ class PassageManager: agent_state: AgentState, text: str, actor: PydanticUser, + tags: Optional[List[str]] = None, ) -> List[PydanticPassage]: - """Insert passage(s) into archival memory""" + """Insert passage(s) into archival memory + + Args: + agent_state: Agent state for embedding configuration + text: Text content to store as passages + actor: User performing the operation + tags: Optional list of tags to attach to all created passages + + Returns: + List of created passage objects + """ embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size embedding_client = LLMClient.create( @@ -500,6 +595,7 @@ class PassageManager: text=chunk_text, embedding=embedding, embedding_config=agent_state.embedding_config, + tags=tags, ), actor=actor, ) @@ -522,6 +618,7 @@ class PassageManager: embeddings=embeddings, passage_ids=passage_ids, # Use same IDs as SQL organization_id=actor.organization_id, + tags=tags, created_at=passages[0].created_at if passages else None, ) @@ -590,6 +687,34 @@ class PassageManager: # Update the database record with values from the provided record update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + + # Handle tags update separately for junction table + new_tags = update_data.pop("tags", None) + if new_tags is not None: + # Deduplicate tags + if new_tags: + new_tags = list(set(new_tags)) + + # Delete existing tags from junction table + from sqlalchemy import delete + + await session.execute(delete(PassageTag).where(PassageTag.passage_id == passage_id)) + + # Create new tags in junction table + if new_tags: + await self._create_tags_for_passage( + session=session, + passage_id=passage_id, + archive_id=curr_passage.archive_id, + organization_id=curr_passage.organization_id, + tags=new_tags, + actor=actor, + ) + + # Update the tags on the passage object + setattr(curr_passage, "tags", new_tags) + + # Update other fields for key, value in update_data.items(): setattr(curr_passage, key, value) @@ -1067,3 +1192,69 @@ class PassageManager: ) passages = result.scalars().all() return [p.to_pydantic() for p in passages] + + @enforce_types + @trace_method + async def get_unique_tags_for_archive_async( + self, + archive_id: str, + actor: PydanticUser, + ) -> List[str]: + """Get all unique tags for an archive. + + Args: + archive_id: ID of the archive + actor: User performing the operation + + Returns: + List of unique tag values + """ + async with db_registry.async_session() as session: + stmt = ( + select(PassageTag.tag) + .distinct() + .where( + PassageTag.archive_id == archive_id, + PassageTag.organization_id == actor.organization_id, + PassageTag.is_deleted == False, + ) + .order_by(PassageTag.tag) + ) + + result = await session.execute(stmt) + tags = result.scalars().all() + + return list(tags) + + @enforce_types + @trace_method + async def get_tag_counts_for_archive_async( + self, + archive_id: str, + actor: PydanticUser, + ) -> Dict[str, int]: + """Get tag counts for an archive. + + Args: + archive_id: ID of the archive + actor: User performing the operation + + Returns: + Dictionary mapping tag values to their counts + """ + async with db_registry.async_session() as session: + stmt = ( + select(PassageTag.tag, func.count(PassageTag.id).label("count")) + .where( + PassageTag.archive_id == archive_id, + PassageTag.organization_id == actor.organization_id, + PassageTag.is_deleted == False, + ) + .group_by(PassageTag.tag) + .order_by(PassageTag.tag) + ) + + result = await session.execute(stmt) + rows = result.all() + + return {row.tag: row.count for row in rows} diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index 1bd01a6a..26e64c1b 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -13,9 +13,7 @@ from letta.schemas.sandbox_config import SandboxConfig from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User -from letta.services.agent_manager import AgentManager from letta.services.message_manager import MessageManager -from letta.services.passage_manager import PassageManager from letta.services.tool_executor.tool_executor_base import ToolExecutor from letta.utils import get_friendly_error_msg @@ -143,7 +141,7 @@ class LettaCoreToolExecutor(ToolExecutor): try: # Get results using passage manager - all_results = await AgentManager().query_agent_passages_async( + all_results = await self.agent_manager.query_agent_passages_async( actor=actor, agent_id=agent_state.id, query_text=query, @@ -174,12 +172,12 @@ class LettaCoreToolExecutor(ToolExecutor): Returns: Optional[str]: None is always returned as this function does not produce a response. """ - await PassageManager().insert_passage( + await self.passage_manager.insert_passage( agent_state=agent_state, text=content, actor=actor, ) - await AgentManager().rebuild_system_prompt_async(agent_id=agent_state.id, actor=actor, force=True) + await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_state.id, actor=actor, force=True) return None async def core_memory_append(self, agent_state: AgentState, actor: User, label: str, content: str) -> Optional[str]: @@ -198,7 +196,7 @@ class LettaCoreToolExecutor(ToolExecutor): current_value = str(agent_state.memory.get_block(label).value) new_value = current_value + "\n" + str(content) agent_state.memory.update_block_value(label=label, value=new_value) - await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) + await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) return None async def core_memory_replace( @@ -227,7 +225,7 @@ class LettaCoreToolExecutor(ToolExecutor): raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'") new_value = current_value.replace(str(old_content), str(new_content)) agent_state.memory.update_block_value(label=label, value=new_value) - await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) + await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) return None async def memory_replace(self, agent_state: AgentState, actor: User, label: str, old_str: str, new_str: str) -> str: @@ -291,7 +289,7 @@ class LettaCoreToolExecutor(ToolExecutor): # Write the new content to the block agent_state.memory.update_block_value(label=label, value=new_value) - await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) + await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) # Create a snippet of the edited section SNIPPET_LINES = 3 @@ -384,7 +382,7 @@ class LettaCoreToolExecutor(ToolExecutor): # Write into the block agent_state.memory.update_block_value(label=label, value=new_value) - await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) + await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) # Prepare the success message success_msg = f"The core memory block with label `{label}` has been edited. " @@ -437,7 +435,7 @@ class LettaCoreToolExecutor(ToolExecutor): agent_state.memory.update_block_value(label=label, value=new_memory) - await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) + await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) # Prepare the success message success_msg = f"The core memory block with label `{label}` has been edited. " diff --git a/tests/test_managers.py b/tests/test_managers.py index f715f6b8..225e2c01 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -64,6 +64,7 @@ from letta.schemas.enums import ( ProviderType, SandboxType, StepStatus, + TagMatchMode, ToolType, ) from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate @@ -3166,6 +3167,7 @@ def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_a embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, metadata={"type": "test_specific"}, + tags=["python", "test", "agent"], ), actor=default_user, ) @@ -3174,6 +3176,7 @@ def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_a assert passage.text == "Test agent passage via specific method" assert passage.archive_id == archive.id assert passage.source_id is None + assert sorted(passage.tags) == sorted(["python", "test", "agent"]) def test_create_source_passage_specific(server: SyncServer, default_user, default_file, default_source): @@ -3187,6 +3190,7 @@ def test_create_source_passage_specific(server: SyncServer, default_user, defaul embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, metadata={"type": "test_specific"}, + tags=["document", "test", "source"], ), file_metadata=default_file, actor=default_user, @@ -3196,6 +3200,7 @@ def test_create_source_passage_specific(server: SyncServer, default_user, defaul assert passage.text == "Test source passage via specific method" assert passage.source_id == default_source.id assert passage.archive_id is None + assert sorted(passage.tags) == sorted(["document", "test", "source"]) def test_create_agent_passage_validation(server: SyncServer, default_user, default_source, sarah_agent): @@ -3509,6 +3514,7 @@ async def test_create_many_agent_passages_async(server: SyncServer, default_user organization_id=default_user.organization_id, embedding=[0.1 * i], embedding_config=DEFAULT_EMBEDDING_CONFIG, + tags=["batch", f"item{i}"] if i % 2 == 0 else ["batch", "odd"], ) for i in range(3) ] @@ -3520,6 +3526,8 @@ async def test_create_many_agent_passages_async(server: SyncServer, default_user assert passage.text == f"Batch agent passage {i}" assert passage.archive_id == archive.id assert passage.source_id is None + expected_tags = ["batch", f"item{i}"] if i % 2 == 0 else ["batch", "odd"] + assert passage.tags == expected_tags @pytest.mark.asyncio @@ -3611,6 +3619,374 @@ def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sara assert any("size is deprecated" in str(warning.message) for warning in w) +@pytest.mark.asyncio +async def test_passage_tags_functionality(server: SyncServer, default_user, sarah_agent): + """Test comprehensive tag functionality for passages.""" + from letta.schemas.enums import TagMatchMode + + # Get or create default archive for the agent + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + + # Create passages with different tag combinations + test_passages = [ + {"text": "Python programming tutorial", "tags": ["python", "tutorial", "programming"]}, + {"text": "Machine learning with Python", "tags": ["python", "ml", "ai"]}, + {"text": "JavaScript web development", "tags": ["javascript", "web", "frontend"]}, + {"text": "Python data science guide", "tags": ["python", "tutorial", "data"]}, + {"text": "No tags passage", "tags": None}, + ] + + created_passages = [] + for test_data in test_passages: + passage = await server.passage_manager.create_agent_passage_async( + PydanticPassage( + text=test_data["text"], + archive_id=archive.id, + organization_id=default_user.organization_id, + embedding=[0.1, 0.2, 0.3], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + tags=test_data["tags"], + ), + actor=default_user, + ) + created_passages.append(passage) + + # Test that tags are properly stored (deduplicated) + for i, passage in enumerate(created_passages): + expected_tags = test_passages[i]["tags"] + if expected_tags: + assert set(passage.tags) == set(expected_tags) + else: + assert passage.tags is None + + # Test querying with tag filtering (if Turbopuffer is enabled) + if hasattr(server.agent_manager, "query_agent_passages_async"): + # Test querying with python tag (should find 3 passages) + python_results = await server.agent_manager.query_agent_passages_async( + actor=default_user, + agent_id=sarah_agent.id, + tags=["python"], + tag_match_mode=TagMatchMode.ANY, + ) + + python_texts = [p.text for p in python_results] + assert len([t for t in python_texts if "Python" in t]) >= 2 + + # Test querying with multiple tags using ALL mode + tutorial_python_results = await server.agent_manager.query_agent_passages_async( + actor=default_user, + agent_id=sarah_agent.id, + tags=["python", "tutorial"], + tag_match_mode=TagMatchMode.ALL, + ) + + tutorial_texts = [p.text for p in tutorial_python_results] + expected_matches = [t for t in tutorial_texts if "tutorial" in t and "Python" in t] + assert len(expected_matches) >= 1 + + +@pytest.mark.asyncio +async def test_comprehensive_tag_functionality(disable_turbopuffer, server: SyncServer, sarah_agent, default_user): + """Comprehensive test for tag functionality including dual storage and junction table.""" + + # Test 1: Create passages with tags and verify they're stored in both places + passages_with_tags = [] + test_tags = { + "passage1": ["important", "documentation", "python"], + "passage2": ["important", "testing"], + "passage3": ["documentation", "api"], + "passage4": ["python", "testing", "api"], + "passage5": [], # Test empty tags + } + + for i, (passage_key, tags) in enumerate(test_tags.items(), 1): + text = f"Test passage {i} for comprehensive tag testing" + created_passages = await server.passage_manager.insert_passage( + agent_state=sarah_agent, + text=text, + actor=default_user, + tags=tags if tags else None, + ) + assert len(created_passages) == 1 + passage = created_passages[0] + + # Verify tags are stored in the JSON column (deduplicated) + if tags: + assert set(passage.tags) == set(tags) + else: + assert passage.tags is None + passages_with_tags.append(passage) + + # Test 2: Verify unique tags for archive + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=sarah_agent.id, + agent_name=sarah_agent.name, + actor=default_user, + ) + + unique_tags = await server.passage_manager.get_unique_tags_for_archive_async( + archive_id=archive.id, + actor=default_user, + ) + + # Should have all unique tags: "important", "documentation", "python", "testing", "api" + expected_unique_tags = {"important", "documentation", "python", "testing", "api"} + assert set(unique_tags) == expected_unique_tags + assert len(unique_tags) == 5 + + # Test 3: Verify tag counts + tag_counts = await server.passage_manager.get_tag_counts_for_archive_async( + archive_id=archive.id, + actor=default_user, + ) + + # Verify counts + assert tag_counts["important"] == 2 # passage1 and passage2 + assert tag_counts["documentation"] == 2 # passage1 and passage3 + assert tag_counts["python"] == 2 # passage1 and passage4 + assert tag_counts["testing"] == 2 # passage2 and passage4 + assert tag_counts["api"] == 2 # passage3 and passage4 + + # Test 4: Query passages with ANY tag matching + any_results = await server.agent_manager.query_agent_passages_async( + agent_id=sarah_agent.id, + query_text="test", + limit=10, + tags=["important", "api"], + tag_match_mode=TagMatchMode.ANY, + actor=default_user, + ) + + # Should match passages with "important" OR "api" tags (passages 1, 2, 3, 4) + [p.text for p in any_results] + assert len(any_results) >= 4 + + # Test 5: Query passages with ALL tag matching + all_results = await server.agent_manager.query_agent_passages_async( + agent_id=sarah_agent.id, + query_text="test", + limit=10, + tags=["python", "testing"], + tag_match_mode=TagMatchMode.ALL, + actor=default_user, + ) + + # Should only match passage4 which has both "python" AND "testing" + all_passage_texts = [p.text for p in all_results] + assert any("Test passage 4" in text for text in all_passage_texts) + + # Test 6: Query with non-existent tags + no_results = await server.agent_manager.query_agent_passages_async( + agent_id=sarah_agent.id, + query_text="test", + limit=10, + tags=["nonexistent", "missing"], + tag_match_mode=TagMatchMode.ANY, + actor=default_user, + ) + + # Should return no results + assert len(no_results) == 0 + + # Test 7: Verify tags CAN be updated (with junction table properly maintained) + first_passage = passages_with_tags[0] + new_tags = ["updated", "modified", "changed"] + update_data = PydanticPassage( + id=first_passage.id, + text="Updated text", + tags=new_tags, + organization_id=first_passage.organization_id, + archive_id=first_passage.archive_id, + embedding=first_passage.embedding, + embedding_config=first_passage.embedding_config, + ) + + # Update should work and tags should be updated + updated = await server.passage_manager.update_agent_passage_by_id_async( + passage_id=first_passage.id, + passage=update_data, + actor=default_user, + ) + + # Both text and tags should be updated + assert updated.text == "Updated text" + assert set(updated.tags) == set(new_tags) + + # Verify tags are properly updated in junction table + updated_unique_tags = await server.passage_manager.get_unique_tags_for_archive_async( + archive_id=archive.id, + actor=default_user, + ) + + # Should include new tags and not include old "important", "documentation", "python" from passage1 + # But still have tags from other passages + assert "updated" in updated_unique_tags + assert "modified" in updated_unique_tags + assert "changed" in updated_unique_tags + + # Test 8: Delete a passage and verify cascade deletion of tags + passage_to_delete = passages_with_tags[1] # passage2 with ["important", "testing"] + + await server.passage_manager.delete_agent_passage_by_id_async( + passage_id=passage_to_delete.id, + actor=default_user, + ) + + # Get updated tag counts + updated_tag_counts = await server.passage_manager.get_tag_counts_for_archive_async( + archive_id=archive.id, + actor=default_user, + ) + + # "important" no longer exists (was in passage1 which was updated and passage2 which was deleted) + assert "important" not in updated_tag_counts + # "testing" count should decrease from 2 to 1 (only in passage4 now) + assert updated_tag_counts["testing"] == 1 + + # Test 9: Batch create passages with tags + batch_texts = [ + "Batch passage 1", + "Batch passage 2", + "Batch passage 3", + ] + batch_tags = ["batch", "test", "multiple"] + + batch_passages = [] + for text in batch_texts: + passages = await server.passage_manager.insert_passage( + agent_state=sarah_agent, + text=text, + actor=default_user, + tags=batch_tags, + ) + batch_passages.extend(passages) + + # Verify all batch passages have the same tags + for passage in batch_passages: + assert set(passage.tags) == set(batch_tags) + + # Test 10: Verify tag counts include batch passages + final_tag_counts = await server.passage_manager.get_tag_counts_for_archive_async( + archive_id=archive.id, + actor=default_user, + ) + + assert final_tag_counts["batch"] == 3 + assert final_tag_counts["test"] == 3 + assert final_tag_counts["multiple"] == 3 + + # Test 11: Complex query with multiple tags and ALL matching + complex_all_results = await server.agent_manager.query_agent_passages_async( + agent_id=sarah_agent.id, + query_text="batch", + limit=10, + tags=["batch", "test", "multiple"], + tag_match_mode=TagMatchMode.ALL, + actor=default_user, + ) + + # Should match all 3 batch passages + assert len(complex_all_results) >= 3 + + # Test 12: Empty tag list should return all passages + all_passages = await server.agent_manager.query_agent_passages_async( + agent_id=sarah_agent.id, + query_text="passage", + limit=50, + tags=[], + tag_match_mode=TagMatchMode.ANY, + actor=default_user, + ) + + # Should return passages based on text search only + assert len(all_passages) > 0 + + +@pytest.mark.asyncio +async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_agent, default_user): + """Test edge cases for tag functionality.""" + + # Test 1: Very long tag names + long_tag = "a" * 500 # 500 character tag + passages = await server.passage_manager.insert_passage( + agent_state=sarah_agent, + text="Testing long tag names", + actor=default_user, + tags=[long_tag, "normal_tag"], + ) + + assert len(passages) == 1 + assert long_tag in passages[0].tags + + # Test 2: Special characters in tags + special_tags = [ + "tag-with-dash", + "tag_with_underscore", + "tag.with.dots", + "tag/with/slash", + "tag:with:colon", + "tag@with@at", + "tag#with#hash", + "tag with spaces", + "CamelCaseTag", + "数字标签", + ] + + passages_special = await server.passage_manager.insert_passage( + agent_state=sarah_agent, + text="Testing special character tags", + actor=default_user, + tags=special_tags, + ) + + assert len(passages_special) == 1 + assert set(passages_special[0].tags) == set(special_tags) + + # Verify unique tags includes all special character tags + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=sarah_agent.id, + agent_name=sarah_agent.name, + actor=default_user, + ) + + unique_tags = await server.passage_manager.get_unique_tags_for_archive_async( + archive_id=archive.id, + actor=default_user, + ) + + for tag in special_tags: + assert tag in unique_tags + + # Test 3: Duplicate tags in input (should be deduplicated) + duplicate_tags = ["tag1", "tag2", "tag1", "tag3", "tag2", "tag1"] + passages_dup = await server.passage_manager.insert_passage( + agent_state=sarah_agent, + text="Testing duplicate tags", + actor=default_user, + tags=duplicate_tags, + ) + + # Should only have unique tags (duplicates removed) + assert len(passages_dup) == 1 + assert set(passages_dup[0].tags) == {"tag1", "tag2", "tag3"} + assert len(passages_dup[0].tags) == 3 # Should be deduplicated + + # Test 4: Case sensitivity in tags + case_tags = ["Tag", "tag", "TAG", "tAg"] + passages_case = await server.passage_manager.insert_passage( + agent_state=sarah_agent, + text="Testing case sensitive tags", + actor=default_user, + tags=case_tags, + ) + + # All variations should be preserved (case-sensitive) + assert len(passages_case) == 1 + assert set(passages_case[0].tags) == set(case_tags) + + # ====================================================================================================================== # User Manager Tests # ====================================================================================================================== diff --git a/tests/test_turbopuffer_integration.py b/tests/test_turbopuffer_integration.py index 3f9edc58..f95e986e 100644 --- a/tests/test_turbopuffer_integration.py +++ b/tests/test_turbopuffer_integration.py @@ -6,7 +6,7 @@ import pytest from letta.config import LettaConfig from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import VectorDBProvider +from letta.schemas.enums import TagMatchMode, VectorDBProvider from letta.server.server import SyncServer from letta.settings import settings @@ -233,16 +233,14 @@ class TestTurbopufferIntegration: assert len(result) == 3 - # Query with organization filter + # Query all passages (no tag filtering) query_vector = [0.15] * 1536 - results = await client.query_passages( - archive_id=archive_id, query_embedding=query_vector, top_k=10, filters={"organization_id": "org-123"} - ) + results = await client.query_passages(archive_id=archive_id, query_embedding=query_vector, top_k=10) - # Should only get passages from org-123 - assert len(results) >= 2 # At least the first two passages + # Should get all passages + assert len(results) == 3 # All three passages for passage, score in results: - assert passage.organization_id == "org-123" + assert passage.organization_id is not None # Clean up await client.delete_passages(archive_id=archive_id, passage_ids=[d["id"] for d in test_data]) @@ -394,6 +392,130 @@ class TestTurbopufferIntegration: except: pass + @pytest.mark.asyncio + @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing") + async def test_tag_filtering_with_real_tpuf(self, enable_turbopuffer): + """Test tag filtering functionality with AND and OR logic""" + + import uuid + + from letta.helpers.tpuf_client import TurbopufferClient + + client = TurbopufferClient() + archive_id = f"test-tags-{datetime.now().timestamp()}" + org_id = str(uuid.uuid4()) + + try: + # Insert passages with different tag combinations + texts = [ + "Python programming tutorial", + "Machine learning with Python", + "JavaScript web development", + "Python data science tutorial", + "React JavaScript framework", + ] + + tag_sets = [ + ["python", "tutorial"], + ["python", "ml"], + ["javascript", "web"], + ["python", "tutorial", "data"], + ["javascript", "react"], + ] + + embeddings = [[float(i), float(i + 5), float(i + 10)] for i in range(len(texts))] + passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts] + + # Insert passages with tags + for i, (text, tags, embedding, passage_id) in enumerate(zip(texts, tag_sets, embeddings, passage_ids)): + await client.insert_archival_memories( + archive_id=archive_id, + text_chunks=[text], + embeddings=[embedding], + passage_ids=[passage_id], + organization_id=org_id, + tags=tags, + created_at=datetime.now(timezone.utc), + ) + + # Test tag filtering with "any" mode (should find passages with any of the specified tags) + python_any_results = await client.query_passages( + archive_id=archive_id, + query_embedding=[1.0, 6.0, 11.0], + search_mode="vector", + top_k=10, + tags=["python"], + tag_match_mode=TagMatchMode.ANY, + ) + + # Should find 3 passages with python tag + python_passages = [passage for passage, _ in python_any_results] + python_texts = [p.text for p in python_passages] + assert len(python_passages) == 3 + assert "Python programming tutorial" in python_texts + assert "Machine learning with Python" in python_texts + assert "Python data science tutorial" in python_texts + + # Test tag filtering with "all" mode + python_tutorial_all_results = await client.query_passages( + archive_id=archive_id, + query_embedding=[1.0, 6.0, 11.0], + search_mode="vector", + top_k=10, + tags=["python", "tutorial"], + tag_match_mode=TagMatchMode.ALL, + ) + + # Should find 2 passages that have both python AND tutorial tags + tutorial_passages = [passage for passage, _ in python_tutorial_all_results] + tutorial_texts = [p.text for p in tutorial_passages] + assert len(tutorial_passages) == 2 + assert "Python programming tutorial" in tutorial_texts + assert "Python data science tutorial" in tutorial_texts + + # Test tag filtering with FTS mode + js_fts_results = await client.query_passages( + archive_id=archive_id, + query_text="javascript", + search_mode="fts", + top_k=10, + tags=["javascript"], + tag_match_mode=TagMatchMode.ANY, + ) + + # Should find 2 passages with javascript tag + js_passages = [passage for passage, _ in js_fts_results] + js_texts = [p.text for p in js_passages] + assert len(js_passages) == 2 + assert "JavaScript web development" in js_texts + assert "React JavaScript framework" in js_texts + + # Test hybrid search with tags + python_hybrid_results = await client.query_passages( + archive_id=archive_id, + query_embedding=[2.0, 7.0, 12.0], + query_text="python programming", + search_mode="hybrid", + top_k=10, + tags=["python"], + tag_match_mode=TagMatchMode.ANY, + vector_weight=0.6, + fts_weight=0.4, + ) + + # Should find python-tagged passages + hybrid_passages = [passage for passage, _ in python_hybrid_results] + hybrid_texts = [p.text for p in hybrid_passages] + assert len(hybrid_passages) == 3 + assert all("Python" in text for text in hybrid_texts) + + finally: + # Clean up + try: + await client.delete_all_passages(archive_id) + except: + pass + @pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True) class TestTurbopufferParametrized: