feat: Support arbitrary string tagging filtering [LET-3467] (#4285)
* Finish tagging
* Add comprehensive tags functionality
* Add fern autogen
* Create passage tags table
* Add indices
* Add comments explaining dual storage
* Fix alembic heads
* Fix alembic

---------

Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
"""Add tags to passages and create passage_tags junction table
|
||||
|
||||
Revision ID: 54c76f7cabca
|
||||
Revises: c41c87205254
|
||||
Create Date: 2025-08-28 15:13:01.549590
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from letta.settings import settings
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "54c76f7cabca"
|
||||
down_revision: Union[str, None] = "c41c87205254"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``passage_tags`` junction table and add JSON ``tags`` columns.

    Tags use a dual-storage design: a JSON column on the passage tables for
    fast retrieval alongside passage data, plus rows in ``passage_tags`` for
    efficient DISTINCT/COUNT queries and tag filtering.
    """
    # Pick a server-side timestamp default the active backend understands.
    if settings.letta_pg_uri_no_default:
        # PostgreSQL uses now()
        timestamp_default = sa.text("now()")
    else:
        # SQLite uses CURRENT_TIMESTAMP
        timestamp_default = sa.text("(CURRENT_TIMESTAMP)")

    op.create_table(
        "passage_tags",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("tag", sa.String(), nullable=False),
        sa.Column("passage_id", sa.String(), nullable=False),
        sa.Column("archive_id", sa.String(), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=timestamp_default, nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=timestamp_default, nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(["archive_id"], ["archives.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]),
        sa.ForeignKeyConstraint(["passage_id"], ["archival_passages.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("passage_id", "tag", name="uq_passage_tag"),
    )

    # Indexes supporting tag lookups scoped by archive and organization.
    for index_name, index_columns in (
        ("ix_passage_tags_archive_id", ["archive_id"]),
        ("ix_passage_tags_archive_tag", ["archive_id", "tag"]),
        ("ix_passage_tags_org_archive", ["organization_id", "archive_id"]),
        ("ix_passage_tags_tag", ["tag"]),
    ):
        op.create_index(index_name, "passage_tags", index_columns, unique=False)

    # JSON copy of the tags, stored directly on the passage rows for fast reads.
    op.add_column("archival_passages", sa.Column("tags", sa.JSON(), nullable=True))
    op.add_column("source_passages", sa.Column("tags", sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Reverse the tagging migration: drop JSON columns, indexes, then the table."""
    op.drop_column("source_passages", "tags")
    op.drop_column("archival_passages", "tags")
    # Drop the indexes before the table itself, mirroring upgrade() in reverse.
    for index_name in (
        "ix_passage_tags_tag",
        "ix_passage_tags_org_archive",
        "ix_passage_tags_archive_tag",
        "ix_passage_tags_archive_id",
    ):
        op.drop_index(index_name, table_name="passage_tags")
    op.drop_table("passage_tags")
|
||||
@@ -2,9 +2,10 @@
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import TagMatchMode
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.settings import settings
|
||||
|
||||
@@ -168,7 +169,8 @@ class TurbopufferClient:
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
||||
top_k: int = 10,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tag_match_mode: TagMatchMode = TagMatchMode.ANY,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
) -> List[Tuple[PydanticPassage, float]]:
|
||||
@@ -180,7 +182,8 @@ class TurbopufferClient:
|
||||
query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
|
||||
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
|
||||
top_k: Number of results to return
|
||||
filters: Optional filter conditions
|
||||
tags: Optional list of tags to filter by
|
||||
tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
|
||||
@@ -207,15 +210,19 @@ class TurbopufferClient:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
|
||||
# build filter conditions
|
||||
filter_conditions = []
|
||||
if filters:
|
||||
for key, value in filters.items():
|
||||
filter_conditions.append((key, "Eq", value))
|
||||
# build tag filter conditions
|
||||
tag_filter = None
|
||||
if tags:
|
||||
tag_conditions = []
|
||||
for tag in tags:
|
||||
tag_conditions.append((tag, "Eq", True))
|
||||
|
||||
base_filter = (
|
||||
("And", filter_conditions) if len(filter_conditions) > 1 else (filter_conditions[0] if filter_conditions else None)
|
||||
)
|
||||
if len(tag_conditions) == 1:
|
||||
tag_filter = tag_conditions[0]
|
||||
elif tag_match_mode == TagMatchMode.ALL:
|
||||
tag_filter = ("And", tag_conditions)
|
||||
else: # tag_match_mode == TagMatchMode.ANY
|
||||
tag_filter = ("Or", tag_conditions)
|
||||
|
||||
if search_mode == "vector":
|
||||
# single vector search query
|
||||
@@ -224,11 +231,11 @@ class TurbopufferClient:
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
}
|
||||
if base_filter:
|
||||
query_params["filters"] = base_filter
|
||||
if tag_filter:
|
||||
query_params["filters"] = tag_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, filters)
|
||||
return self._process_single_query_results(result, archive_id, tags)
|
||||
|
||||
elif search_mode == "fts":
|
||||
# single full-text search query
|
||||
@@ -237,11 +244,11 @@ class TurbopufferClient:
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
}
|
||||
if base_filter:
|
||||
query_params["filters"] = base_filter
|
||||
if tag_filter:
|
||||
query_params["filters"] = tag_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, filters, is_fts=True)
|
||||
return self._process_single_query_results(result, archive_id, tags, is_fts=True)
|
||||
|
||||
else: # hybrid mode
|
||||
# multi-query for both vector and FTS
|
||||
@@ -253,8 +260,8 @@ class TurbopufferClient:
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
}
|
||||
if base_filter:
|
||||
vector_query["filters"] = base_filter
|
||||
if tag_filter:
|
||||
vector_query["filters"] = tag_filter
|
||||
queries.append(vector_query)
|
||||
|
||||
# full-text search query
|
||||
@@ -263,16 +270,16 @@ class TurbopufferClient:
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
}
|
||||
if base_filter:
|
||||
fts_query["filters"] = base_filter
|
||||
if tag_filter:
|
||||
fts_query["filters"] = tag_filter
|
||||
queries.append(fts_query)
|
||||
|
||||
# execute multi-query
|
||||
response = await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
|
||||
|
||||
# process and combine results using reciprocal rank fusion
|
||||
vector_results = self._process_single_query_results(response.results[0], archive_id, filters)
|
||||
fts_results = self._process_single_query_results(response.results[1], archive_id, filters, is_fts=True)
|
||||
vector_results = self._process_single_query_results(response.results[0], archive_id, tags)
|
||||
fts_results = self._process_single_query_results(response.results[1], archive_id, tags, is_fts=True)
|
||||
|
||||
# combine results using reciprocal rank fusion
|
||||
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
|
||||
@@ -282,16 +289,16 @@ class TurbopufferClient:
|
||||
raise
|
||||
|
||||
def _process_single_query_results(
|
||||
self, result, archive_id: str, filters: Optional[Dict[str, Any]], is_fts: bool = False
|
||||
self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
|
||||
) -> List[Tuple[PydanticPassage, float]]:
|
||||
"""Process results from a single query into passage objects with scores."""
|
||||
passages_with_scores = []
|
||||
|
||||
for row in result.rows:
|
||||
# Build metadata including any filter conditions that were applied
|
||||
# Build metadata including any tag filters that were applied
|
||||
metadata = {}
|
||||
if filters:
|
||||
metadata["applied_filters"] = filters
|
||||
if tags:
|
||||
metadata["applied_tags"] = tags
|
||||
|
||||
# Create a passage with minimal fields - embeddings are not returned from Turbopuffer
|
||||
passage = PydanticPassage(
|
||||
@@ -300,7 +307,7 @@ class TurbopufferClient:
|
||||
organization_id=getattr(row, "organization_id", None),
|
||||
archive_id=archive_id, # use the archive_id from the query
|
||||
created_at=getattr(row, "created_at", None),
|
||||
metadata_=metadata, # Include filter conditions in metadata
|
||||
metadata_=metadata, # Include tag filters in metadata
|
||||
# Set required fields to empty/default values since we don't store embeddings
|
||||
embedding=[], # Empty embedding since we don't return it from Turbopuffer
|
||||
embedding_config=None, # No embedding config needed for retrieved passages
|
||||
|
||||
@@ -22,6 +22,7 @@ from letta.orm.mcp_server import MCPServer
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import ArchivalPassage, BasePassage, SourcePassage
|
||||
from letta.orm.passage_tag import PassageTag
|
||||
from letta.orm.prompt import Prompt
|
||||
from letta.orm.provider import Provider
|
||||
from letta.orm.provider_trace import ProviderTrace
|
||||
|
||||
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
|
||||
from letta.orm.llm_batch_job import LLMBatchJob
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.passage import ArchivalPassage, SourcePassage
|
||||
from letta.orm.passage_tag import PassageTag
|
||||
from letta.orm.provider import Provider
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.tool import Tool
|
||||
@@ -56,6 +57,7 @@ class Organization(SqlalchemyBase):
|
||||
archival_passages: Mapped[List["ArchivalPassage"]] = relationship(
|
||||
"ArchivalPassage", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
passage_tags: Mapped[List["PassageTag"]] = relationship("PassageTag", back_populates="organization", cascade="all, delete-orphan")
|
||||
archives: Mapped[List["Archive"]] = relationship("Archive", back_populates="organization", cascade="all, delete-orphan")
|
||||
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
|
||||
identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, Column, Index
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
|
||||
@@ -27,6 +27,8 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
|
||||
text: Mapped[str] = mapped_column(doc="Passage text content")
|
||||
embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
|
||||
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
|
||||
# dual storage: json column for fast retrieval, junction table for efficient queries
|
||||
tags: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="Tags associated with this passage")
|
||||
|
||||
# Vector embedding field based on database type
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
@@ -75,6 +77,11 @@ class ArchivalPassage(BasePassage, ArchiveMixin):
|
||||
|
||||
__tablename__ = "archival_passages"
|
||||
|
||||
# junction table for efficient tag queries (complements json column above)
|
||||
passage_tags: Mapped[List["PassageTag"]] = relationship(
|
||||
"PassageTag", back_populates="passage", cascade="all, delete-orphan", lazy="noload"
|
||||
)
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
return relationship("Organization", back_populates="archival_passages", lazy="selectin")
|
||||
|
||||
55
letta/orm/passage_tag.py
Normal file
55
letta/orm/passage_tag.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import ArchivalPassage
|
||||
|
||||
|
||||
class PassageTag(SqlalchemyBase, OrganizationMixin):
    """One (passage, tag) association row.

    Design: dual storage — each passage's tags live both in:
      1. a JSON column on the passages table (fast retrieval with passage data), and
      2. this junction table (efficient DISTINCT/COUNT queries and filtering).
    """

    __tablename__ = "passage_tags"

    __table_args__ = (
        # a given passage may carry each tag at most once
        UniqueConstraint("passage_id", "tag", name="uq_passage_tag"),
        # lookup paths: by archive, by tag, by (archive, tag), and by (org, archive)
        Index("ix_passage_tags_archive_id", "archive_id"),
        Index("ix_passage_tags_tag", "tag"),
        Index("ix_passage_tags_archive_tag", "archive_id", "tag"),
        Index("ix_passage_tags_org_archive", "organization_id", "archive_id"),
    )

    # surrogate primary key
    id: Mapped[str] = mapped_column(String, primary_key=True, doc="Unique identifier for the tag entry")

    # the tag string itself
    tag: Mapped[str] = mapped_column(String, nullable=False, doc="The tag value")

    # owning passage; rows are removed when the passage is deleted (CASCADE)
    passage_id: Mapped[str] = mapped_column(
        String, ForeignKey("archival_passages.id", ondelete="CASCADE"), nullable=False, doc="ID of the passage this tag belongs to"
    )

    # denormalized archive reference so tag queries can avoid a join through passages
    archive_id: Mapped[str] = mapped_column(
        String,
        ForeignKey("archives.id", ondelete="CASCADE"),
        nullable=False,
        doc="ID of the archive this passage belongs to (denormalized for efficient queries)",
    )

    # relationships
    passage: Mapped["ArchivalPassage"] = relationship("ArchivalPassage", back_populates="passage_tags", lazy="noload")

    organization: Mapped["Organization"] = relationship("Organization", back_populates="passage_tags", lazy="selectin")
|
||||
@@ -179,3 +179,10 @@ class VectorDBProvider(str, Enum):
|
||||
|
||||
NATIVE = "native"
|
||||
TPUF = "tpuf"
|
||||
|
||||
|
||||
class TagMatchMode(str, Enum):
    """How a set of query tags is matched against a passage's tags."""

    # passage matches when it carries at least one of the query tags
    ANY = "any"
    # passage matches only when it carries every query tag
    ALL = "all"
|
||||
|
||||
@@ -25,6 +25,7 @@ class PassageBase(OrmMetadataBase):
|
||||
file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
|
||||
file_name: Optional[str] = Field(None, description="The name of the file (only for source passages).")
|
||||
metadata: Optional[Dict] = Field({}, validation_alias="metadata_", description="The metadata of the passage.")
|
||||
tags: Optional[List[str]] = Field(None, description="Tags associated with this passage.")
|
||||
|
||||
|
||||
class Passage(PassageBase):
|
||||
|
||||
@@ -48,7 +48,7 @@ from letta.schemas.block import DEFAULT_BLOCKS
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderType, ToolType, VectorDBProvider
|
||||
from letta.schemas.enums import ProviderType, TagMatchMode, ToolType, VectorDBProvider
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.group import Group as PydanticGroup
|
||||
from letta.schemas.group import ManagerType
|
||||
@@ -2649,6 +2649,8 @@ class AgentManager:
|
||||
embed_query: bool = False,
|
||||
ascending: bool = True,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tag_match_mode: Optional[TagMatchMode] = None,
|
||||
) -> List[PydanticPassage]:
|
||||
"""Lists all passages attached to an agent."""
|
||||
# Check if we should use Turbopuffer for vector search
|
||||
@@ -2686,6 +2688,8 @@ class AgentManager:
|
||||
query_text=query_text, # pass text for potential hybrid search
|
||||
search_mode="hybrid", # use hybrid mode for better results
|
||||
top_k=limit,
|
||||
tags=tags,
|
||||
tag_match_mode=tag_match_mode or TagMatchMode.ANY,
|
||||
)
|
||||
|
||||
# Return just the passages (without scores)
|
||||
@@ -2719,7 +2723,30 @@ class AgentManager:
|
||||
passages = result.scalars().all()
|
||||
|
||||
# Convert to Pydantic models
|
||||
return [p.to_pydantic() for p in passages]
|
||||
pydantic_passages = [p.to_pydantic() for p in passages]
|
||||
|
||||
# TODO: Integrate tag filtering directly into the SQL query for better performance.
|
||||
# Currently using post-filtering which is less efficient but simpler to implement.
|
||||
# Future optimization: Add JOIN with passage_tags table and WHERE clause for tag filtering.
|
||||
if tags:
|
||||
filtered_passages = []
|
||||
for passage in pydantic_passages:
|
||||
if passage.tags:
|
||||
passage_tags = set(passage.tags)
|
||||
query_tags = set(tags)
|
||||
|
||||
if tag_match_mode == TagMatchMode.ALL:
|
||||
# ALL mode: passage must have all query tags
|
||||
if query_tags.issubset(passage_tags):
|
||||
filtered_passages.append(passage)
|
||||
else:
|
||||
# ANY mode (default): passage must have at least one query tag
|
||||
if query_tags.intersection(passage_tags):
|
||||
filtered_passages.append(passage)
|
||||
|
||||
return filtered_passages
|
||||
|
||||
return pydantic_passages
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import parse_and_chunk_text
|
||||
@@ -12,6 +14,7 @@ from letta.llm_api.llm_client import LLMClient
|
||||
from letta.orm import ArchivesAgents
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.passage import ArchivalPassage, SourcePassage
|
||||
from letta.orm.passage_tag import PassageTag
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
@@ -48,6 +51,44 @@ class PassageManager:
|
||||
def __init__(self):
|
||||
self.archive_manager = ArchiveManager()
|
||||
|
||||
async def _create_tags_for_passage(
    self,
    session: AsyncSession,
    passage_id: str,
    archive_id: str,
    organization_id: str,
    tags: List[str],
    actor: PydanticUser,
) -> List[PassageTag]:
    """Persist one junction-table row per tag for a passage.

    The junction table complements the tags stored in the passage's JSON
    column and enables efficient DISTINCT queries and tag-based filtering.

    Note: tags are already deduplicated before being passed to this method.
    """
    if not tags:
        return []

    # Build every row up front so they can be inserted in a single batch.
    rows = [
        PassageTag(
            id=f"passage-tag-{uuid.uuid4()}",
            tag=tag_value,
            passage_id=passage_id,
            archive_id=archive_id,
            organization_id=organization_id,
        )
        for tag_value in tags
    ]

    return await PassageTag.batch_create_async(
        items=rows,
        db_session=session,
        actor=actor,
    )
||||
|
||||
# AGENT PASSAGE METHODS
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -155,6 +196,12 @@ class PassageManager:
|
||||
raise ValueError("Agent passage cannot have source_id")
|
||||
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
|
||||
# Deduplicate tags if provided (for dual storage consistency)
|
||||
tags = data.get("tags")
|
||||
if tags:
|
||||
tags = list(set(tags))
|
||||
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
@@ -162,6 +209,7 @@ class PassageManager:
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"tags": tags,
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
@@ -182,6 +230,12 @@ class PassageManager:
|
||||
raise ValueError("Agent passage cannot have source_id")
|
||||
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
|
||||
# Deduplicate tags if provided (for dual storage consistency)
|
||||
tags = data.get("tags")
|
||||
if tags:
|
||||
tags = list(set(tags))
|
||||
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
@@ -189,6 +243,7 @@ class PassageManager:
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"tags": tags,
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
@@ -197,6 +252,18 @@ class PassageManager:
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
passage = await passage.create_async(session, actor=actor)
|
||||
|
||||
# dual storage: save tags to junction table for efficient queries
|
||||
if tags: # use the deduplicated tags variable
|
||||
await self._create_tags_for_passage(
|
||||
session=session,
|
||||
passage_id=passage.id,
|
||||
archive_id=passage.archive_id,
|
||||
organization_id=passage.organization_id,
|
||||
tags=tags, # pass deduplicated tags
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
return passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -211,6 +278,12 @@ class PassageManager:
|
||||
raise ValueError("Source passage cannot have archive_id")
|
||||
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
|
||||
# Deduplicate tags if provided (for dual storage consistency)
|
||||
tags = data.get("tags")
|
||||
if tags:
|
||||
tags = list(set(tags))
|
||||
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
@@ -218,6 +291,7 @@ class PassageManager:
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"tags": tags,
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
@@ -244,6 +318,12 @@ class PassageManager:
|
||||
raise ValueError("Source passage cannot have archive_id")
|
||||
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
|
||||
# Deduplicate tags if provided (for dual storage consistency)
|
||||
tags = data.get("tags")
|
||||
if tags:
|
||||
tags = list(set(tags))
|
||||
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
@@ -251,6 +331,7 @@ class PassageManager:
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"tags": tags,
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
@@ -310,6 +391,7 @@ class PassageManager:
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"tags": data.get("tags"),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
@@ -357,6 +439,7 @@ class PassageManager:
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"tags": data.get("tags"),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
@@ -396,6 +479,7 @@ class PassageManager:
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"tags": data.get("tags"),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
@@ -466,8 +550,19 @@ class PassageManager:
|
||||
agent_state: AgentState,
|
||||
text: str,
|
||||
actor: PydanticUser,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> List[PydanticPassage]:
|
||||
"""Insert passage(s) into archival memory"""
|
||||
"""Insert passage(s) into archival memory
|
||||
|
||||
Args:
|
||||
agent_state: Agent state for embedding configuration
|
||||
text: Text content to store as passages
|
||||
actor: User performing the operation
|
||||
tags: Optional list of tags to attach to all created passages
|
||||
|
||||
Returns:
|
||||
List of created passage objects
|
||||
"""
|
||||
|
||||
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
||||
embedding_client = LLMClient.create(
|
||||
@@ -500,6 +595,7 @@ class PassageManager:
|
||||
text=chunk_text,
|
||||
embedding=embedding,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
tags=tags,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
@@ -522,6 +618,7 @@ class PassageManager:
|
||||
embeddings=embeddings,
|
||||
passage_ids=passage_ids, # Use same IDs as SQL
|
||||
organization_id=actor.organization_id,
|
||||
tags=tags,
|
||||
created_at=passages[0].created_at if passages else None,
|
||||
)
|
||||
|
||||
@@ -590,6 +687,34 @@ class PassageManager:
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
# Handle tags update separately for junction table
|
||||
new_tags = update_data.pop("tags", None)
|
||||
if new_tags is not None:
|
||||
# Deduplicate tags
|
||||
if new_tags:
|
||||
new_tags = list(set(new_tags))
|
||||
|
||||
# Delete existing tags from junction table
|
||||
from sqlalchemy import delete
|
||||
|
||||
await session.execute(delete(PassageTag).where(PassageTag.passage_id == passage_id))
|
||||
|
||||
# Create new tags in junction table
|
||||
if new_tags:
|
||||
await self._create_tags_for_passage(
|
||||
session=session,
|
||||
passage_id=passage_id,
|
||||
archive_id=curr_passage.archive_id,
|
||||
organization_id=curr_passage.organization_id,
|
||||
tags=new_tags,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Update the tags on the passage object
|
||||
setattr(curr_passage, "tags", new_tags)
|
||||
|
||||
# Update other fields
|
||||
for key, value in update_data.items():
|
||||
setattr(curr_passage, key, value)
|
||||
|
||||
@@ -1067,3 +1192,69 @@ class PassageManager:
|
||||
)
|
||||
passages = result.scalars().all()
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
@enforce_types
@trace_method
async def get_unique_tags_for_archive_async(
    self,
    archive_id: str,
    actor: PydanticUser,
) -> List[str]:
    """Return the distinct tag values present in an archive, sorted ascending.

    Args:
        archive_id: ID of the archive
        actor: User performing the operation

    Returns:
        List of unique tag values
    """
    query = (
        select(PassageTag.tag)
        .distinct()
        .where(
            PassageTag.archive_id == archive_id,
            PassageTag.organization_id == actor.organization_id,
            PassageTag.is_deleted == False,  # noqa: E712 — SQLAlchemy needs the comparison expression
        )
        .order_by(PassageTag.tag)
    )

    async with db_registry.async_session() as session:
        result = await session.execute(query)
        return list(result.scalars().all())
|
||||
|
||||
@enforce_types
@trace_method
async def get_tag_counts_for_archive_async(
    self,
    archive_id: str,
    actor: PydanticUser,
) -> Dict[str, int]:
    """Return how many times each tag occurs within an archive.

    Args:
        archive_id: ID of the archive
        actor: User performing the operation

    Returns:
        Dictionary mapping tag values to their counts
    """
    query = (
        select(PassageTag.tag, func.count(PassageTag.id).label("count"))
        .where(
            PassageTag.archive_id == archive_id,
            PassageTag.organization_id == actor.organization_id,
            PassageTag.is_deleted == False,  # noqa: E712 — SQLAlchemy needs the comparison expression
        )
        .group_by(PassageTag.tag)
        .order_by(PassageTag.tag)
    )

    async with db_registry.async_session() as session:
        result = await session.execute(query)
        return {tag: count for tag, count in result.all()}
|
||||
|
||||
@@ -13,9 +13,7 @@ from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.tool_executor.tool_executor_base import ToolExecutor
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
@@ -143,7 +141,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
|
||||
try:
|
||||
# Get results using passage manager
|
||||
all_results = await AgentManager().query_agent_passages_async(
|
||||
all_results = await self.agent_manager.query_agent_passages_async(
|
||||
actor=actor,
|
||||
agent_id=agent_state.id,
|
||||
query_text=query,
|
||||
@@ -174,12 +172,12 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
await PassageManager().insert_passage(
|
||||
await self.passage_manager.insert_passage(
|
||||
agent_state=agent_state,
|
||||
text=content,
|
||||
actor=actor,
|
||||
)
|
||||
await AgentManager().rebuild_system_prompt_async(agent_id=agent_state.id, actor=actor, force=True)
|
||||
await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_state.id, actor=actor, force=True)
|
||||
return None
|
||||
|
||||
async def core_memory_append(self, agent_state: AgentState, actor: User, label: str, content: str) -> Optional[str]:
|
||||
@@ -198,7 +196,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
current_value = str(agent_state.memory.get_block(label).value)
|
||||
new_value = current_value + "\n" + str(content)
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
return None
|
||||
|
||||
async def core_memory_replace(
|
||||
@@ -227,7 +225,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")
|
||||
new_value = current_value.replace(str(old_content), str(new_content))
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
return None
|
||||
|
||||
async def memory_replace(self, agent_state: AgentState, actor: User, label: str, old_str: str, new_str: str) -> str:
|
||||
@@ -291,7 +289,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
# Write the new content to the block
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
|
||||
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
|
||||
# Create a snippet of the edited section
|
||||
SNIPPET_LINES = 3
|
||||
@@ -384,7 +382,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
# Write into the block
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
|
||||
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = f"The core memory block with label `{label}` has been edited. "
|
||||
@@ -437,7 +435,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
|
||||
agent_state.memory.update_block_value(label=label, value=new_memory)
|
||||
|
||||
await AgentManager().update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
await self.agent_manager.update_memory_if_changed_async(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = f"The core memory block with label `{label}` has been edited. "
|
||||
|
||||
@@ -64,6 +64,7 @@ from letta.schemas.enums import (
|
||||
ProviderType,
|
||||
SandboxType,
|
||||
StepStatus,
|
||||
TagMatchMode,
|
||||
ToolType,
|
||||
)
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
@@ -3166,6 +3167,7 @@ def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_a
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata={"type": "test_specific"},
|
||||
tags=["python", "test", "agent"],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -3174,6 +3176,7 @@ def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_a
|
||||
assert passage.text == "Test agent passage via specific method"
|
||||
assert passage.archive_id == archive.id
|
||||
assert passage.source_id is None
|
||||
assert sorted(passage.tags) == sorted(["python", "test", "agent"])
|
||||
|
||||
|
||||
def test_create_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
|
||||
@@ -3187,6 +3190,7 @@ def test_create_source_passage_specific(server: SyncServer, default_user, defaul
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata={"type": "test_specific"},
|
||||
tags=["document", "test", "source"],
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
@@ -3196,6 +3200,7 @@ def test_create_source_passage_specific(server: SyncServer, default_user, defaul
|
||||
assert passage.text == "Test source passage via specific method"
|
||||
assert passage.source_id == default_source.id
|
||||
assert passage.archive_id is None
|
||||
assert sorted(passage.tags) == sorted(["document", "test", "source"])
|
||||
|
||||
|
||||
def test_create_agent_passage_validation(server: SyncServer, default_user, default_source, sarah_agent):
|
||||
@@ -3509,6 +3514,7 @@ async def test_create_many_agent_passages_async(server: SyncServer, default_user
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1 * i],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
tags=["batch", f"item{i}"] if i % 2 == 0 else ["batch", "odd"],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
@@ -3520,6 +3526,8 @@ async def test_create_many_agent_passages_async(server: SyncServer, default_user
|
||||
assert passage.text == f"Batch agent passage {i}"
|
||||
assert passage.archive_id == archive.id
|
||||
assert passage.source_id is None
|
||||
expected_tags = ["batch", f"item{i}"] if i % 2 == 0 else ["batch", "odd"]
|
||||
assert passage.tags == expected_tags
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -3611,6 +3619,374 @@ def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sara
|
||||
assert any("size is deprecated" in str(warning.message) for warning in w)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passage_tags_functionality(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test comprehensive tag functionality for passages."""
|
||||
from letta.schemas.enums import TagMatchMode
|
||||
|
||||
# Get or create default archive for the agent
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
|
||||
# Create passages with different tag combinations
|
||||
test_passages = [
|
||||
{"text": "Python programming tutorial", "tags": ["python", "tutorial", "programming"]},
|
||||
{"text": "Machine learning with Python", "tags": ["python", "ml", "ai"]},
|
||||
{"text": "JavaScript web development", "tags": ["javascript", "web", "frontend"]},
|
||||
{"text": "Python data science guide", "tags": ["python", "tutorial", "data"]},
|
||||
{"text": "No tags passage", "tags": None},
|
||||
]
|
||||
|
||||
created_passages = []
|
||||
for test_data in test_passages:
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
text=test_data["text"],
|
||||
archive_id=archive.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
tags=test_data["tags"],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
created_passages.append(passage)
|
||||
|
||||
# Test that tags are properly stored (deduplicated)
|
||||
for i, passage in enumerate(created_passages):
|
||||
expected_tags = test_passages[i]["tags"]
|
||||
if expected_tags:
|
||||
assert set(passage.tags) == set(expected_tags)
|
||||
else:
|
||||
assert passage.tags is None
|
||||
|
||||
# Test querying with tag filtering (if Turbopuffer is enabled)
|
||||
if hasattr(server.agent_manager, "query_agent_passages_async"):
|
||||
# Test querying with python tag (should find 3 passages)
|
||||
python_results = await server.agent_manager.query_agent_passages_async(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
tags=["python"],
|
||||
tag_match_mode=TagMatchMode.ANY,
|
||||
)
|
||||
|
||||
python_texts = [p.text for p in python_results]
|
||||
assert len([t for t in python_texts if "Python" in t]) >= 2
|
||||
|
||||
# Test querying with multiple tags using ALL mode
|
||||
tutorial_python_results = await server.agent_manager.query_agent_passages_async(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
tags=["python", "tutorial"],
|
||||
tag_match_mode=TagMatchMode.ALL,
|
||||
)
|
||||
|
||||
tutorial_texts = [p.text for p in tutorial_python_results]
|
||||
expected_matches = [t for t in tutorial_texts if "tutorial" in t and "Python" in t]
|
||||
assert len(expected_matches) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_comprehensive_tag_functionality(disable_turbopuffer, server: SyncServer, sarah_agent, default_user):
|
||||
"""Comprehensive test for tag functionality including dual storage and junction table."""
|
||||
|
||||
# Test 1: Create passages with tags and verify they're stored in both places
|
||||
passages_with_tags = []
|
||||
test_tags = {
|
||||
"passage1": ["important", "documentation", "python"],
|
||||
"passage2": ["important", "testing"],
|
||||
"passage3": ["documentation", "api"],
|
||||
"passage4": ["python", "testing", "api"],
|
||||
"passage5": [], # Test empty tags
|
||||
}
|
||||
|
||||
for i, (passage_key, tags) in enumerate(test_tags.items(), 1):
|
||||
text = f"Test passage {i} for comprehensive tag testing"
|
||||
created_passages = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent,
|
||||
text=text,
|
||||
actor=default_user,
|
||||
tags=tags if tags else None,
|
||||
)
|
||||
assert len(created_passages) == 1
|
||||
passage = created_passages[0]
|
||||
|
||||
# Verify tags are stored in the JSON column (deduplicated)
|
||||
if tags:
|
||||
assert set(passage.tags) == set(tags)
|
||||
else:
|
||||
assert passage.tags is None
|
||||
passages_with_tags.append(passage)
|
||||
|
||||
# Test 2: Verify unique tags for archive
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id,
|
||||
agent_name=sarah_agent.name,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
unique_tags = await server.passage_manager.get_unique_tags_for_archive_async(
|
||||
archive_id=archive.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should have all unique tags: "important", "documentation", "python", "testing", "api"
|
||||
expected_unique_tags = {"important", "documentation", "python", "testing", "api"}
|
||||
assert set(unique_tags) == expected_unique_tags
|
||||
assert len(unique_tags) == 5
|
||||
|
||||
# Test 3: Verify tag counts
|
||||
tag_counts = await server.passage_manager.get_tag_counts_for_archive_async(
|
||||
archive_id=archive.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify counts
|
||||
assert tag_counts["important"] == 2 # passage1 and passage2
|
||||
assert tag_counts["documentation"] == 2 # passage1 and passage3
|
||||
assert tag_counts["python"] == 2 # passage1 and passage4
|
||||
assert tag_counts["testing"] == 2 # passage2 and passage4
|
||||
assert tag_counts["api"] == 2 # passage3 and passage4
|
||||
|
||||
# Test 4: Query passages with ANY tag matching
|
||||
any_results = await server.agent_manager.query_agent_passages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="test",
|
||||
limit=10,
|
||||
tags=["important", "api"],
|
||||
tag_match_mode=TagMatchMode.ANY,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should match passages with "important" OR "api" tags (passages 1, 2, 3, 4)
|
||||
[p.text for p in any_results]
|
||||
assert len(any_results) >= 4
|
||||
|
||||
# Test 5: Query passages with ALL tag matching
|
||||
all_results = await server.agent_manager.query_agent_passages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="test",
|
||||
limit=10,
|
||||
tags=["python", "testing"],
|
||||
tag_match_mode=TagMatchMode.ALL,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should only match passage4 which has both "python" AND "testing"
|
||||
all_passage_texts = [p.text for p in all_results]
|
||||
assert any("Test passage 4" in text for text in all_passage_texts)
|
||||
|
||||
# Test 6: Query with non-existent tags
|
||||
no_results = await server.agent_manager.query_agent_passages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="test",
|
||||
limit=10,
|
||||
tags=["nonexistent", "missing"],
|
||||
tag_match_mode=TagMatchMode.ANY,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should return no results
|
||||
assert len(no_results) == 0
|
||||
|
||||
# Test 7: Verify tags CAN be updated (with junction table properly maintained)
|
||||
first_passage = passages_with_tags[0]
|
||||
new_tags = ["updated", "modified", "changed"]
|
||||
update_data = PydanticPassage(
|
||||
id=first_passage.id,
|
||||
text="Updated text",
|
||||
tags=new_tags,
|
||||
organization_id=first_passage.organization_id,
|
||||
archive_id=first_passage.archive_id,
|
||||
embedding=first_passage.embedding,
|
||||
embedding_config=first_passage.embedding_config,
|
||||
)
|
||||
|
||||
# Update should work and tags should be updated
|
||||
updated = await server.passage_manager.update_agent_passage_by_id_async(
|
||||
passage_id=first_passage.id,
|
||||
passage=update_data,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Both text and tags should be updated
|
||||
assert updated.text == "Updated text"
|
||||
assert set(updated.tags) == set(new_tags)
|
||||
|
||||
# Verify tags are properly updated in junction table
|
||||
updated_unique_tags = await server.passage_manager.get_unique_tags_for_archive_async(
|
||||
archive_id=archive.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should include new tags and not include old "important", "documentation", "python" from passage1
|
||||
# But still have tags from other passages
|
||||
assert "updated" in updated_unique_tags
|
||||
assert "modified" in updated_unique_tags
|
||||
assert "changed" in updated_unique_tags
|
||||
|
||||
# Test 8: Delete a passage and verify cascade deletion of tags
|
||||
passage_to_delete = passages_with_tags[1] # passage2 with ["important", "testing"]
|
||||
|
||||
await server.passage_manager.delete_agent_passage_by_id_async(
|
||||
passage_id=passage_to_delete.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Get updated tag counts
|
||||
updated_tag_counts = await server.passage_manager.get_tag_counts_for_archive_async(
|
||||
archive_id=archive.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# "important" no longer exists (was in passage1 which was updated and passage2 which was deleted)
|
||||
assert "important" not in updated_tag_counts
|
||||
# "testing" count should decrease from 2 to 1 (only in passage4 now)
|
||||
assert updated_tag_counts["testing"] == 1
|
||||
|
||||
# Test 9: Batch create passages with tags
|
||||
batch_texts = [
|
||||
"Batch passage 1",
|
||||
"Batch passage 2",
|
||||
"Batch passage 3",
|
||||
]
|
||||
batch_tags = ["batch", "test", "multiple"]
|
||||
|
||||
batch_passages = []
|
||||
for text in batch_texts:
|
||||
passages = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent,
|
||||
text=text,
|
||||
actor=default_user,
|
||||
tags=batch_tags,
|
||||
)
|
||||
batch_passages.extend(passages)
|
||||
|
||||
# Verify all batch passages have the same tags
|
||||
for passage in batch_passages:
|
||||
assert set(passage.tags) == set(batch_tags)
|
||||
|
||||
# Test 10: Verify tag counts include batch passages
|
||||
final_tag_counts = await server.passage_manager.get_tag_counts_for_archive_async(
|
||||
archive_id=archive.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert final_tag_counts["batch"] == 3
|
||||
assert final_tag_counts["test"] == 3
|
||||
assert final_tag_counts["multiple"] == 3
|
||||
|
||||
# Test 11: Complex query with multiple tags and ALL matching
|
||||
complex_all_results = await server.agent_manager.query_agent_passages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="batch",
|
||||
limit=10,
|
||||
tags=["batch", "test", "multiple"],
|
||||
tag_match_mode=TagMatchMode.ALL,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should match all 3 batch passages
|
||||
assert len(complex_all_results) >= 3
|
||||
|
||||
# Test 12: Empty tag list should return all passages
|
||||
all_passages = await server.agent_manager.query_agent_passages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="passage",
|
||||
limit=50,
|
||||
tags=[],
|
||||
tag_match_mode=TagMatchMode.ANY,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should return passages based on text search only
|
||||
assert len(all_passages) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test edge cases for tag functionality."""
|
||||
|
||||
# Test 1: Very long tag names
|
||||
long_tag = "a" * 500 # 500 character tag
|
||||
passages = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent,
|
||||
text="Testing long tag names",
|
||||
actor=default_user,
|
||||
tags=[long_tag, "normal_tag"],
|
||||
)
|
||||
|
||||
assert len(passages) == 1
|
||||
assert long_tag in passages[0].tags
|
||||
|
||||
# Test 2: Special characters in tags
|
||||
special_tags = [
|
||||
"tag-with-dash",
|
||||
"tag_with_underscore",
|
||||
"tag.with.dots",
|
||||
"tag/with/slash",
|
||||
"tag:with:colon",
|
||||
"tag@with@at",
|
||||
"tag#with#hash",
|
||||
"tag with spaces",
|
||||
"CamelCaseTag",
|
||||
"数字标签",
|
||||
]
|
||||
|
||||
passages_special = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent,
|
||||
text="Testing special character tags",
|
||||
actor=default_user,
|
||||
tags=special_tags,
|
||||
)
|
||||
|
||||
assert len(passages_special) == 1
|
||||
assert set(passages_special[0].tags) == set(special_tags)
|
||||
|
||||
# Verify unique tags includes all special character tags
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id,
|
||||
agent_name=sarah_agent.name,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
unique_tags = await server.passage_manager.get_unique_tags_for_archive_async(
|
||||
archive_id=archive.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
for tag in special_tags:
|
||||
assert tag in unique_tags
|
||||
|
||||
# Test 3: Duplicate tags in input (should be deduplicated)
|
||||
duplicate_tags = ["tag1", "tag2", "tag1", "tag3", "tag2", "tag1"]
|
||||
passages_dup = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent,
|
||||
text="Testing duplicate tags",
|
||||
actor=default_user,
|
||||
tags=duplicate_tags,
|
||||
)
|
||||
|
||||
# Should only have unique tags (duplicates removed)
|
||||
assert len(passages_dup) == 1
|
||||
assert set(passages_dup[0].tags) == {"tag1", "tag2", "tag3"}
|
||||
assert len(passages_dup[0].tags) == 3 # Should be deduplicated
|
||||
|
||||
# Test 4: Case sensitivity in tags
|
||||
case_tags = ["Tag", "tag", "TAG", "tAg"]
|
||||
passages_case = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent,
|
||||
text="Testing case sensitive tags",
|
||||
actor=default_user,
|
||||
tags=case_tags,
|
||||
)
|
||||
|
||||
# All variations should be preserved (case-sensitive)
|
||||
assert len(passages_case) == 1
|
||||
assert set(passages_case[0].tags) == set(case_tags)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# User Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
from letta.config import LettaConfig
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.enums import TagMatchMode, VectorDBProvider
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
@@ -233,16 +233,14 @@ class TestTurbopufferIntegration:
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
# Query with organization filter
|
||||
# Query all passages (no tag filtering)
|
||||
query_vector = [0.15] * 1536
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id, query_embedding=query_vector, top_k=10, filters={"organization_id": "org-123"}
|
||||
)
|
||||
results = await client.query_passages(archive_id=archive_id, query_embedding=query_vector, top_k=10)
|
||||
|
||||
# Should only get passages from org-123
|
||||
assert len(results) >= 2 # At least the first two passages
|
||||
# Should get all passages
|
||||
assert len(results) == 3 # All three passages
|
||||
for passage, score in results:
|
||||
assert passage.organization_id == "org-123"
|
||||
assert passage.organization_id is not None
|
||||
|
||||
# Clean up
|
||||
await client.delete_passages(archive_id=archive_id, passage_ids=[d["id"] for d in test_data])
|
||||
@@ -394,6 +392,130 @@ class TestTurbopufferIntegration:
|
||||
except:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
|
||||
async def test_tag_filtering_with_real_tpuf(self, enable_turbopuffer):
|
||||
"""Test tag filtering functionality with AND and OR logic"""
|
||||
|
||||
import uuid
|
||||
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
client = TurbopufferClient()
|
||||
archive_id = f"test-tags-{datetime.now().timestamp()}"
|
||||
org_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Insert passages with different tag combinations
|
||||
texts = [
|
||||
"Python programming tutorial",
|
||||
"Machine learning with Python",
|
||||
"JavaScript web development",
|
||||
"Python data science tutorial",
|
||||
"React JavaScript framework",
|
||||
]
|
||||
|
||||
tag_sets = [
|
||||
["python", "tutorial"],
|
||||
["python", "ml"],
|
||||
["javascript", "web"],
|
||||
["python", "tutorial", "data"],
|
||||
["javascript", "react"],
|
||||
]
|
||||
|
||||
embeddings = [[float(i), float(i + 5), float(i + 10)] for i in range(len(texts))]
|
||||
passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts]
|
||||
|
||||
# Insert passages with tags
|
||||
for i, (text, tags, embedding, passage_id) in enumerate(zip(texts, tag_sets, embeddings, passage_ids)):
|
||||
await client.insert_archival_memories(
|
||||
archive_id=archive_id,
|
||||
text_chunks=[text],
|
||||
embeddings=[embedding],
|
||||
passage_ids=[passage_id],
|
||||
organization_id=org_id,
|
||||
tags=tags,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Test tag filtering with "any" mode (should find passages with any of the specified tags)
|
||||
python_any_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 6.0, 11.0],
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
tags=["python"],
|
||||
tag_match_mode=TagMatchMode.ANY,
|
||||
)
|
||||
|
||||
# Should find 3 passages with python tag
|
||||
python_passages = [passage for passage, _ in python_any_results]
|
||||
python_texts = [p.text for p in python_passages]
|
||||
assert len(python_passages) == 3
|
||||
assert "Python programming tutorial" in python_texts
|
||||
assert "Machine learning with Python" in python_texts
|
||||
assert "Python data science tutorial" in python_texts
|
||||
|
||||
# Test tag filtering with "all" mode
|
||||
python_tutorial_all_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 6.0, 11.0],
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
tags=["python", "tutorial"],
|
||||
tag_match_mode=TagMatchMode.ALL,
|
||||
)
|
||||
|
||||
# Should find 2 passages that have both python AND tutorial tags
|
||||
tutorial_passages = [passage for passage, _ in python_tutorial_all_results]
|
||||
tutorial_texts = [p.text for p in tutorial_passages]
|
||||
assert len(tutorial_passages) == 2
|
||||
assert "Python programming tutorial" in tutorial_texts
|
||||
assert "Python data science tutorial" in tutorial_texts
|
||||
|
||||
# Test tag filtering with FTS mode
|
||||
js_fts_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_text="javascript",
|
||||
search_mode="fts",
|
||||
top_k=10,
|
||||
tags=["javascript"],
|
||||
tag_match_mode=TagMatchMode.ANY,
|
||||
)
|
||||
|
||||
# Should find 2 passages with javascript tag
|
||||
js_passages = [passage for passage, _ in js_fts_results]
|
||||
js_texts = [p.text for p in js_passages]
|
||||
assert len(js_passages) == 2
|
||||
assert "JavaScript web development" in js_texts
|
||||
assert "React JavaScript framework" in js_texts
|
||||
|
||||
# Test hybrid search with tags
|
||||
python_hybrid_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[2.0, 7.0, 12.0],
|
||||
query_text="python programming",
|
||||
search_mode="hybrid",
|
||||
top_k=10,
|
||||
tags=["python"],
|
||||
tag_match_mode=TagMatchMode.ANY,
|
||||
vector_weight=0.6,
|
||||
fts_weight=0.4,
|
||||
)
|
||||
|
||||
# Should find python-tagged passages
|
||||
hybrid_passages = [passage for passage, _ in python_hybrid_results]
|
||||
hybrid_texts = [p.text for p in hybrid_passages]
|
||||
assert len(hybrid_passages) == 3
|
||||
assert all("Python" in text for text in hybrid_texts)
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
try:
|
||||
await client.delete_all_passages(archive_id)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
|
||||
class TestTurbopufferParametrized:
|
||||
|
||||
Reference in New Issue
Block a user