From ab4ccfca3184e7f2b6774afe943bcb4d0658d5fd Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 12 Jan 2026 16:04:21 -0800 Subject: [PATCH] feat: add tags support to blocks (#8474) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add tags support to blocks * fix: add timestamps and org scoping to blocks_tags Addresses PR feedback: 1. Migration: Added timestamps (created_at, updated_at), soft delete (is_deleted), audit fields (_created_by_id, _last_updated_by_id), and organization_id to blocks_tags table for filtering support. Follows SQLite baseline pattern (composite PK of block_id+tag, no separate id column) to avoid insert failures. 2. ORM: Relationship already correct with lazy="raise" to prevent implicit joins and passive_deletes=True for efficient CASCADE deletes. 3. Schema: Changed normalize_tags() from Any to dict for type safety. 4. SQLite: Added blocks_tags to SQLite baseline schema to prevent table-not-found errors. 5. Code: Updated all tag row inserts to include organization_id. 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta * fix: add ORM columns and update SQLite baseline for blocks_tags Fixes test failures (CompileError: Unconsumed column names: organization_id): 1. ORM: Added organization_id, timestamps, audit fields to BlocksTags ORM model to match database schema from migrations. 2. SQLite baseline: Added full column set to blocks_tags (organization_id, timestamps, audit fields) to match PostgreSQL schema. 3. Test: Added 'tags' to expected Block schema fields. This ensures SQLite and PostgreSQL have matching schemas and the ORM can consume all columns that the code inserts. 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta * revert change to existing alembic migration * fix: remove passive_deletes and SQLite support for blocks_tags 1. 
Removed passive_deletes=True from Block.tags relationship to match AgentsTags pattern (neither have ondelete CASCADE in DB schema). 2. Removed SQLite branch from _replace_block_pivot_rows_async since blocks_tags table is PostgreSQL-only (migration skips SQLite). 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta * api sync --------- Co-authored-by: Letta --- .../cf3c4d025dbc_add_blocks_tags_table.py | 58 ++++ letta/orm/__init__.py | 1 + letta/orm/block.py | 7 + letta/orm/blocks_tags.py | 37 +++ letta/schemas/block.py | 19 +- letta/server/rest_api/routers/v1/blocks.py | 34 ++- letta/server/rest_api/routers/v1/tags.py | 29 +- letta/services/agent_manager.py | 49 ++- letta/services/block_manager.py | 278 ++++++++++++++++-- tests/managers/test_agent_manager.py | 1 + tests/managers/test_block_manager.py | 138 +++++++++ 11 files changed, 617 insertions(+), 34 deletions(-) create mode 100644 alembic/versions/cf3c4d025dbc_add_blocks_tags_table.py create mode 100644 letta/orm/blocks_tags.py diff --git a/alembic/versions/cf3c4d025dbc_add_blocks_tags_table.py b/alembic/versions/cf3c4d025dbc_add_blocks_tags_table.py new file mode 100644 index 00000000..f2be82a7 --- /dev/null +++ b/alembic/versions/cf3c4d025dbc_add_blocks_tags_table.py @@ -0,0 +1,58 @@ +"""Add blocks tags table + +Revision ID: cf3c4d025dbc +Revises: 27de0f58e076 +Create Date: 2026-01-08 23:36:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op +from letta.settings import settings + +# revision identifiers, used by Alembic. 
+revision: str = "cf3c4d025dbc" +down_revision: Union[str, None] = "27de0f58e076" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Skip this migration for SQLite + if not settings.letta_pg_uri_no_default: + return + + # Create blocks_tags table with timestamps and org scoping for filtering + # Note: Matches agents_tags structure but follows SQLite baseline pattern (no separate id column) + op.create_table( + "blocks_tags", + sa.Column("block_id", sa.String(), nullable=False), + sa.Column("tag", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["block_id"], + ["block.id"], + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("block_id", "tag"), + sa.UniqueConstraint("block_id", "tag", name="unique_block_tag"), + ) + + +def downgrade() -> None: + # Skip this migration for SQLite + if not settings.letta_pg_uri_no_default: + return + + op.drop_table("blocks_tags") diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index c44ae3bd..47fe5b1d 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -6,6 +6,7 @@ from letta.orm.base import Base from letta.orm.block import Block from letta.orm.block_history import BlockHistory from letta.orm.blocks_agents import BlocksAgents +from letta.orm.blocks_tags import BlocksTags from letta.orm.conversation import Conversation from letta.orm.conversation_messages import ConversationMessage from 
letta.orm.file import FileMetadata diff --git a/letta/orm/block.py b/letta/orm/block.py index 6ac48655..08f5fb28 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -12,6 +12,7 @@ from letta.schemas.block import Block as PydanticBlock, Human, Persona if TYPE_CHECKING: from letta.orm import Organization + from letta.orm.blocks_tags import BlocksTags from letta.orm.identity import Identity @@ -82,6 +83,12 @@ class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin, TemplateEntityMixin back_populates="shared_blocks", passive_deletes=True, ) + tags: Mapped[List["BlocksTags"]] = relationship( + "BlocksTags", + back_populates="block", + cascade="all, delete-orphan", + lazy="raise", + ) def to_pydantic(self) -> Type: match self.label: diff --git a/letta/orm/blocks_tags.py b/letta/orm/blocks_tags.py new file mode 100644 index 00000000..23412df8 --- /dev/null +++ b/letta/orm/blocks_tags.py @@ -0,0 +1,37 @@ +from datetime import datetime +from typing import Optional + +from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, UniqueConstraint, func, text +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.base import Base + + +class BlocksTags(Base): + __tablename__ = "blocks_tags" + __table_args__ = ( + UniqueConstraint("block_id", "tag", name="unique_block_tag"), + Index("ix_blocks_tags_block_id_tag", "block_id", "tag"), + Index("ix_blocks_tags_tag_block_id", "tag", "block_id"), + ) + + # Primary key columns + block_id: Mapped[String] = mapped_column(String, ForeignKey("block.id"), primary_key=True) + tag: Mapped[str] = mapped_column(String, doc="The name of the tag associated with the block.", primary_key=True) + + # Organization scoping for filtering + organization_id: Mapped[str] = mapped_column(String, ForeignKey("organizations.id"), nullable=False) + + # Timestamps for filtering by date + created_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now()) + updated_at: 
Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now()) + + # Soft delete support + is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE")) + + # Audit fields + _created_by_id: Mapped[Optional[str]] = mapped_column(String, nullable=True) + _last_updated_by_id: Mapped[Optional[str]] = mapped_column(String, nullable=True) + + # Relationships + block: Mapped["Block"] = relationship("Block", back_populates="tags") diff --git a/letta/schemas/block.py b/letta/schemas/block.py index cec7c3c2..3103c1a9 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Optional +from typing import Any, List, Optional from pydantic import ConfigDict, Field, model_validator @@ -88,6 +88,17 @@ class Block(BaseBlock): created_by_id: Optional[str] = Field(None, description="The id of the user that made this Block.") last_updated_by_id: Optional[str] = Field(None, description="The id of the user that last updated this Block.") + # tags - using Optional with default [] to allow None input to become empty list + tags: Optional[List[str]] = Field(default=[], description="The tags associated with the block.") + + @model_validator(mode="before") + @classmethod + def normalize_tags(cls, data: dict) -> dict: + """Convert None tags to empty list.""" + if isinstance(data, dict) and data.get("tags") is None: + data["tags"] = [] + return data + class BlockResponse(Block): id: str = Field( @@ -142,6 +153,9 @@ class BlockUpdate(BaseBlock): value: Optional[str] = Field(None, description="Value of the block.") project_id: Optional[str] = Field(None, description="The associated project id.") + # tags + tags: Optional[List[str]] = Field(None, description="The tags to associate with the block.") + model_config = ConfigDict(extra="ignore") # Ignores extra fields @@ -157,6 +171,9 @@ class CreateBlock(BaseBlock): is_template: bool = False template_name: Optional[str] = 
Field(None, description="Name of the block if it is a template.") + # tags + tags: Optional[List[str]] = Field(None, description="The tags to associate with the block.") + @model_validator(mode="before") @classmethod def ensure_value_is_string(cls, data): diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 536e243d..d297ab7a 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -33,6 +33,11 @@ async def list_blocks( identity_id: IdentityIdQuery = None, identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"), project_id: Optional[str] = Query(None, description="Search blocks by project id"), + tags: Optional[List[str]] = Query(None, description="List of tags to filter blocks by"), + match_all_tags: bool = Query( + False, + description="If True, only returns blocks that match ALL given tags. Otherwise, return blocks that have ANY of the passed-in tags.", + ), limit: Optional[int] = Query(50, description="Number of blocks to return"), before: Optional[str] = Query( None, @@ -98,19 +103,44 @@ async def list_blocks( after=after, ascending=(order == "asc"), show_hidden_blocks=show_hidden_blocks, + tags=tags, + match_all_tags=match_all_tags, ) @router.get("/count", response_model=int, operation_id="count_blocks") async def count_blocks( + label: BlockLabelQuery = None, + templates_only: bool = Query(False, description="Whether to include only templates"), + name: BlockNameQuery = None, + tags: Optional[List[str]] = Query(None, description="List of tags to filter blocks by"), + match_all_tags: bool = Query( + False, + description="If True, only counts blocks that match ALL given tags. 
Otherwise, counts blocks that have ANY of the passed-in tags.", + ), + project_id: Optional[str] = Query(None, description="Search blocks by project id"), server: SyncServer = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), ): """ - Count all blocks created by a user. + Count all blocks with optional filtering. + Supports the same filters as list_blocks for consistent querying. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.block_manager.size_async(actor=actor) + + # If no filters are provided, use the simpler size_async method + if all(param is None or param is False for param in [label, templates_only, name, tags, project_id]): + return await server.block_manager.size_async(actor=actor) + + return await server.block_manager.count_blocks_async( + actor=actor, + label=label, + is_template=templates_only, + template_name=name, + tags=tags, + match_all_tags=match_all_tags, + project_id=project_id, + ) @router.post("/", response_model=BlockResponse, operation_id="create_block") diff --git a/letta/server/rest_api/routers/v1/tags.py b/letta/server/rest_api/routers/v1/tags.py index 52f2c75f..5ed14759 100644 --- a/letta/server/rest_api/routers/v1/tags.py +++ b/letta/server/rest_api/routers/v1/tags.py @@ -32,11 +32,34 @@ async def list_tags( headers: HeaderParams = Depends(get_headers), ): """ - Get the list of all agent tags that have been created. + Get the list of all tags (from agents and blocks) that have been created. 
""" actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) text_filter = name or query_text - tags = await server.agent_manager.list_tags_async( + + # Get tags from both agents and blocks + agent_tags = await server.agent_manager.list_tags_async( actor=actor, before=before, after=after, limit=limit, query_text=text_filter, ascending=(order == "asc") ) - return tags + block_tags = await server.block_manager.list_tags_async(actor=actor, query_text=text_filter) + + # Merge and deduplicate, then sort and apply pagination + all_tags = sorted(set(agent_tags) | set(block_tags), reverse=(order == "desc")) + + # Apply cursor-based pagination on merged results + if after: + if order == "asc": + all_tags = [t for t in all_tags if t > after] + else: + all_tags = [t for t in all_tags if t < after] + if before: + if order == "asc": + all_tags = [t for t in all_tags if t < before] + else: + all_tags = [t for t in all_tags if t > before] + + # Apply limit + if limit: + all_tags = all_tags[:limit] + + return all_tags diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 9f17a4f8..da2f5913 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -36,6 +36,7 @@ from letta.orm import ( ArchivalPassage, Block as BlockModel, BlocksAgents, + BlocksTags, Group as GroupModel, GroupsAgents, IdentitiesAgents, @@ -1929,7 +1930,10 @@ class AgentManager: agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) for block in agent.core_memory: if block.label == block_label: - return block.to_pydantic() + pydantic_block = block.to_pydantic() + tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block.id)) + pydantic_block.tags = [row[0] for row in tags_result.fetchall()] + return pydantic_block raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") @enforce_types @@ -1941,7 +1945,7 @@ class AgentManager: 
block_update: BlockUpdate, actor: PydanticUser, ) -> PydanticBlock: - """Gets a block attached to an agent by its label.""" + """Modifies a block attached to an agent by its label.""" async with db_registry.async_session() as session: matched_block = None agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) @@ -1954,6 +1958,9 @@ class AgentManager: update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + # Extract tags from update data (it's not a column on the block table) + new_tags = update_data.pop("tags", None) + # Validate limit constraints before updating validate_block_limit_constraint(update_data, matched_block) @@ -1961,7 +1968,23 @@ class AgentManager: setattr(matched_block, key, value) await matched_block.update_async(session, actor=actor) - return matched_block.to_pydantic() + + if new_tags is not None: + await BlockManager._replace_block_pivot_rows_async( + session, + BlocksTags.__table__, + matched_block.id, + [{"block_id": matched_block.id, "tag": tag} for tag in new_tags], + ) + + pydantic_block = matched_block.to_pydantic() + if new_tags is not None: + pydantic_block.tags = new_tags + else: + tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == matched_block.id)) + pydantic_block.tags = [row[0] for row in tags_result.fetchall()] + + return pydantic_block @enforce_types @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) @@ -3015,7 +3038,25 @@ class AgentManager: result = await session.execute(query) blocks = result.scalars().all() - return [block.to_pydantic() for block in blocks] + if not blocks: + return [] + + block_ids = [block.id for block in blocks] + tags_result = await session.execute(select(BlocksTags.block_id, BlocksTags.tag).where(BlocksTags.block_id.in_(block_ids))) + tags_by_block: Dict[str, List[str]] = {} + for row in tags_result.fetchall(): + block_id, tag = row + if block_id not in tags_by_block: 
+ tags_by_block[block_id] = [] + tags_by_block[block_id].append(tag) + + pydantic_blocks = [] + for block in blocks: + pydantic_block = block.to_pydantic() + pydantic_block.tags = tags_by_block.get(block.id, []) + pydantic_blocks.append(pydantic_block) + + return pydantic_blocks @enforce_types @trace_method diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 5229ba1e..3876bc40 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -2,8 +2,12 @@ import asyncio from datetime import datetime from typing import Dict, List, Optional -from sqlalchemy import and_, delete, func, or_, select +import sqlalchemy as sa +from sqlalchemy import and_, delete, exists, func, literal, or_, select +from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import noload +from sqlalchemy.sql.expression import tuple_ from letta.errors import LettaInvalidArgumentError from letta.log import get_logger @@ -11,7 +15,9 @@ from letta.orm.agent import Agent as AgentModel from letta.orm.block import Block as BlockModel from letta.orm.block_history import BlockHistory from letta.orm.blocks_agents import BlocksAgents +from letta.orm.blocks_tags import BlocksTags from letta.orm.errors import NoResultFound +from letta.orm.sqlalchemy_base import AccessType from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.block import Block as PydanticBlock, BlockUpdate @@ -84,6 +90,68 @@ def validate_block_creation(block_data: dict) -> None: class BlockManager: """Manager class to handle business logic related to Blocks.""" + # ====================================================================================================================== + # Helper methods for pivot tables + # ====================================================================================================================== + 
+ @staticmethod + async def _bulk_insert_block_pivot_async(session, table, rows: list[dict]): + """Bulk insert rows into a pivot table, ignoring conflicts.""" + if not rows: + return + + dialect = session.bind.dialect.name + if dialect == "postgresql": + stmt = pg_insert(table).values(rows).on_conflict_do_nothing() + elif dialect == "sqlite": + stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE") + else: + # fallback: filter out exact-duplicate dicts in Python + seen = set() + filtered = [] + for row in rows: + key = tuple(sorted(row.items())) + if key not in seen: + seen.add(key) + filtered.append(row) + stmt = sa.insert(table).values(filtered) + + await session.execute(stmt) + + @staticmethod + async def _replace_block_pivot_rows_async(session, table, block_id: str, rows: list[dict]): + """ + Replace all pivot rows for a block atomically using MERGE pattern. + Only supports PostgreSQL (blocks_tags table not supported on SQLite). + """ + dialect = session.bind.dialect.name + + if dialect == "postgresql": + if rows: + # separate upsert and delete operations + stmt = pg_insert(table).values(rows) + stmt = stmt.on_conflict_do_nothing() + await session.execute(stmt) + + # delete rows not in new set + pk_names = [c.name for c in table.primary_key.columns] + new_keys = [tuple(r[c] for c in pk_names) for r in rows] + await session.execute( + delete(table).where(table.c.block_id == block_id, ~tuple_(*[table.c[c] for c in pk_names]).in_(new_keys)) + ) + else: + # if no rows to insert, just delete all + await session.execute(delete(table).where(table.c.block_id == block_id)) + else: + # fallback: use original DELETE + INSERT pattern + await session.execute(delete(table).where(table.c.block_id == block_id)) + if rows: + await BlockManager._bulk_insert_block_pivot_async(session, table, rows) + + # ====================================================================================================================== + # Basic CRUD operations + # 
====================================================================================================================== + @enforce_types @trace_method async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: @@ -95,11 +163,23 @@ class BlockManager: else: async with db_registry.async_session() as session: data = block.model_dump(to_orm=True, exclude_none=True) + # Extract tags before creating the ORM model (tags is not a column) + tags = data.pop("tags", None) or [] + # Validate block creation constraints validate_block_creation(data) - block = BlockModel(**data, organization_id=actor.organization_id) - await block.create_async(session, actor=actor, no_commit=True, no_refresh=True) - pydantic_block = block.to_pydantic() + block_model = BlockModel(**data, organization_id=actor.organization_id) + await block_model.create_async(session, actor=actor, no_commit=True, no_refresh=True) + + if tags: + await self._bulk_insert_block_pivot_async( + session, + BlocksTags.__table__, + [{"block_id": block_model.id, "tag": tag, "organization_id": actor.organization_id} for tag in tags], + ) + + pydantic_block = block_model.to_pydantic() + pydantic_block.tags = tags # context manager now handles commits # await session.commit() return pydantic_block @@ -119,10 +199,13 @@ class BlockManager: return [] async with db_registry.async_session() as session: - # Validate all blocks before creating any validated_data = [] - for block in blocks: + tags_by_index: Dict[int, List[str]] = {} + for i, block in enumerate(blocks): block_data = block.model_dump(to_orm=True, exclude_none=True) + tags = block_data.pop("tags", None) or [] + if tags: + tags_by_index[i] = tags validate_block_creation(block_data) validated_data.append(block_data) @@ -130,9 +213,22 @@ class BlockManager: created_models = await BlockModel.batch_create_async( items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True ) - result = [m.to_pydantic() for m 
in created_models] - # context manager now handles commits - # await session.commit() + + all_tag_rows = [] + for i, model in enumerate(created_models): + if i in tags_by_index: + for tag in tags_by_index[i]: + all_tag_rows.append({"block_id": model.id, "tag": tag, "organization_id": actor.organization_id}) + + if all_tag_rows: + await self._bulk_insert_block_pivot_async(session, BlocksTags.__table__, all_tag_rows) + + result = [] + for i, model in enumerate(created_models): + pydantic_block = model.to_pydantic() + pydantic_block.tags = tags_by_index.get(i, []) + result.append(pydantic_block) + return result @enforce_types @@ -144,6 +240,9 @@ class BlockManager: block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + # Extract tags from update data (it's not a column on the block table) + new_tags = update_data.pop("tags", None) + # Validate limit constraints before updating validate_block_limit_constraint(update_data, block) @@ -151,7 +250,22 @@ class BlockManager: setattr(block, key, value) await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) + + if new_tags is not None: + await self._replace_block_pivot_rows_async( + session, + BlocksTags.__table__, + block_id, + [{"block_id": block_id, "tag": tag, "organization_id": block.organization_id} for tag in new_tags], + ) + pydantic_block = block.to_pydantic() + if new_tags is not None: + pydantic_block.tags = new_tags + else: + result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id)) + pydantic_block.tags = [row[0] for row in result.fetchall()] + # context manager now handles commits # await session.commit() return pydantic_block @@ -164,6 +278,8 @@ class BlockManager: async with db_registry.async_session() as session: # First, delete all references in blocks_agents table await 
session.execute(delete(BlocksAgents).where(BlocksAgents.block_id == block_id)) + # Also delete all tags associated with this block + await session.execute(delete(BlocksTags).where(BlocksTags.block_id == block_id)) await session.flush() # Then delete the block itself @@ -192,19 +308,18 @@ class BlockManager: connected_to_agents_count_eq: Optional[List[int]] = None, ascending: bool = True, show_hidden_blocks: Optional[bool] = None, + tags: Optional[List[str]] = None, + match_all_tags: bool = False, ) -> List[PydanticBlock]: """Async version of get_blocks method. Retrieve blocks based on various optional filters.""" - from sqlalchemy import select - from sqlalchemy.orm import noload - - from letta.orm.sqlalchemy_base import AccessType - async with db_registry.async_session() as session: # Start with a basic query query = select(BlockModel) # Explicitly avoid loading relationships - query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups)) + query = query.options( + noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags) + ) # Apply access control query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) @@ -295,6 +410,20 @@ class BlockManager: query = query.join(BlockModel.identities).filter(BlockModel.identities.property.mapper.class_.id == identity_id) needs_distinct = True + if tags: + if match_all_tags: + # Must match ALL tags - use subquery with having count + tag_subquery = ( + select(BlocksTags.block_id) + .where(BlocksTags.tag.in_(tags)) + .group_by(BlocksTags.block_id) + .having(func.count(BlocksTags.tag) == literal(len(tags))) + ) + query = query.where(BlockModel.id.in_(tag_subquery)) + else: + # Must match ANY tag + query = query.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags)))) + if after: result = (await session.execute(select(BlockModel.created_at, 
BlockModel.id).where(BlockModel.id == after))).first() if result: @@ -353,17 +482,39 @@ class BlockManager: result = await session.execute(query) blocks = result.scalars().all() - return [block.to_pydantic() for block in blocks] + if not blocks: + return [] + + block_ids = [block.id for block in blocks] + tags_result = await session.execute(select(BlocksTags.block_id, BlocksTags.tag).where(BlocksTags.block_id.in_(block_ids))) + tags_by_block: Dict[str, List[str]] = {} + for row in tags_result.fetchall(): + block_id, tag = row + if block_id not in tags_by_block: + tags_by_block[block_id] = [] + tags_by_block[block_id].append(tag) + + pydantic_blocks = [] + for block in blocks: + pydantic_block = block.to_pydantic() + pydantic_block.tags = tags_by_block.get(block.id, []) + pydantic_blocks.append(pydantic_block) + + return pydantic_blocks @enforce_types @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) @trace_method async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: - """Retrieve a block by its name.""" + """Retrieve a block by its ID, including tags.""" async with db_registry.async_session() as session: try: block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) - return block.to_pydantic() + pydantic_block = block.to_pydantic() + tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id)) + pydantic_block.tags = [row[0] for row in tags_result.fetchall()] + + return pydantic_block except NoResultFound: return None @@ -371,11 +522,6 @@ class BlockManager: @trace_method async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: """Retrieve blocks by their ids without loading unnecessary relationships. 
Async implementation.""" - from sqlalchemy import select - from sqlalchemy.orm import noload - - from letta.orm.sqlalchemy_base import AccessType - if not block_ids: return [] @@ -387,7 +533,9 @@ class BlockManager: query = query.where(BlockModel.id.in_(block_ids)) # Explicitly avoid loading relationships - query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups)) + query = query.options( + noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags) + ) # Apply access control if actor is provided if actor: @@ -522,6 +670,88 @@ class BlockManager: async with db_registry.async_session() as session: return await BlockModel.size_async(db_session=session, actor=actor) + @enforce_types + @trace_method + async def count_blocks_async( + self, + actor: PydanticUser, + label: Optional[str] = None, + is_template: Optional[bool] = None, + template_name: Optional[str] = None, + project_id: Optional[str] = None, + tags: Optional[List[str]] = None, + match_all_tags: bool = False, + ) -> int: + """ + Count blocks with optional filtering. Supports same filters as get_blocks_async. 
+ """ + async with db_registry.async_session() as session: + query = select(func.count(BlockModel.id)) + + # Apply access control + query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) + query = query.where(BlockModel.organization_id == actor.organization_id) + + # Apply filters + if label: + query = query.where(BlockModel.label == label) + if is_template is not None: + query = query.where(BlockModel.is_template == is_template) + if template_name: + query = query.where(BlockModel.template_name == template_name) + if project_id: + query = query.where(BlockModel.project_id == project_id) + + # Apply tag filtering + if tags: + if match_all_tags: + tag_subquery = ( + select(BlocksTags.block_id) + .where(BlocksTags.tag.in_(tags)) + .group_by(BlocksTags.block_id) + .having(func.count(BlocksTags.tag) == literal(len(tags))) + ) + query = query.where(BlockModel.id.in_(tag_subquery)) + else: + query = query.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags)))) + + result = await session.execute(query) + return result.scalar() or 0 + + @enforce_types + @trace_method + async def list_tags_async( + self, + actor: PydanticUser, + query_text: Optional[str] = None, + ) -> List[str]: + """ + Get all unique block tags for the actor's organization. + + Args: + actor: User performing the action. + query_text: Filter tags by text search. + + Returns: + List[str]: List of unique block tags. 
+ """ + async with db_registry.async_session() as session: + query = ( + select(BlocksTags.tag) + .join(BlockModel, BlocksTags.block_id == BlockModel.id) + .where(BlockModel.organization_id == actor.organization_id) + .distinct() + ) + + if query_text: + if settings.database_engine is DatabaseChoice.POSTGRES: + query = query.where(BlocksTags.tag.ilike(f"%{query_text}%")) + else: + query = query.where(func.lower(BlocksTags.tag).like(func.lower(f"%{query_text}%"))) + + result = await session.execute(query) + return [row[0] for row in result.fetchall()] + # Block History Functions @enforce_types diff --git a/tests/managers/test_agent_manager.py b/tests/managers/test_agent_manager.py index 72bcd2fe..15767114 100644 --- a/tests/managers/test_agent_manager.py +++ b/tests/managers/test_agent_manager.py @@ -1766,6 +1766,7 @@ async def test_agent_state_schema_unchanged(server: SyncServer): "hidden", "created_by_id", "last_updated_by_id", + "tags", } actual_block_fields = set(block_fields.keys()) if actual_block_fields != expected_block_fields: diff --git a/tests/managers/test_block_manager.py b/tests/managers/test_block_manager.py index f4007f24..65f6d3ba 100644 --- a/tests/managers/test_block_manager.py +++ b/tests/managers/test_block_manager.py @@ -1282,3 +1282,141 @@ async def test_redo_concurrency_stale(server: SyncServer, default_user): # an out-of-date version of the block with pytest.raises(StaleDataError): await block_manager.redo_checkpoint_block(block_id=block.id, actor=default_user, use_preloaded_block=block_s2) + + +# ====================================================================================================================== +# Block Tags Tests +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_block_tags_create_and_update(server: SyncServer, default_user): + """Test creating a block with tags and updating tags""" + block_manager = 
# ======================================================================================================================
# Block Tags Tests
# ======================================================================================================================


@pytest.mark.asyncio
async def test_block_tags_create_and_update(server: SyncServer, default_user):
    """Test creating a block with tags and updating tags"""
    from letta.schemas.block import BlockUpdate

    mgr = BlockManager()

    # Create a block carrying an initial tag set.
    created = await mgr.create_or_update_block_async(
        PydanticBlock(label="test_tags", value="Block with tags", tags=["tag1", "tag2", "important"]),
        actor=default_user,
    )
    assert set(created.tags) == {"tag1", "tag2", "important"}

    # Replacing the tag list on update swaps the full set.
    updated = await mgr.update_block_async(
        block_id=created.id,
        block_update=BlockUpdate(tags=["tag1", "new_tag"]),
        actor=default_user,
    )
    assert set(updated.tags) == {"tag1", "new_tag"}

    # An empty list clears every tag.
    cleared = await mgr.update_block_async(
        block_id=created.id,
        block_update=BlockUpdate(tags=[]),
        actor=default_user,
    )
    assert cleared.tags == []


@pytest.mark.asyncio
async def test_block_tags_filter_any(server: SyncServer, default_user):
    """Test filtering blocks by tags (match ANY)"""
    mgr = BlockManager()

    # Three blocks with overlapping tag sets.
    b_alpha_beta = await mgr.create_or_update_block_async(
        PydanticBlock(label="b1", value="v1", tags=["alpha", "beta"]),
        actor=default_user,
    )
    b_beta_gamma = await mgr.create_or_update_block_async(
        PydanticBlock(label="b2", value="v2", tags=["beta", "gamma"]),
        actor=default_user,
    )
    b_delta = await mgr.create_or_update_block_async(
        PydanticBlock(label="b3", value="v3", tags=["delta"]),
        actor=default_user,
    )

    # "beta" (match ANY) hits the first two blocks only.
    hits = {b.id for b in await mgr.get_blocks_async(actor=default_user, tags=["beta"], match_all_tags=False)}
    assert b_alpha_beta.id in hits
    assert b_beta_gamma.id in hits
    assert b_delta.id not in hits

    # "alpha" OR "delta" (match ANY) hits the first and third blocks only.
    hits = {b.id for b in await mgr.get_blocks_async(actor=default_user, tags=["alpha", "delta"], match_all_tags=False)}
    assert b_alpha_beta.id in hits
    assert b_beta_gamma.id not in hits
    assert b_delta.id in hits


@pytest.mark.asyncio
async def test_block_tags_filter_all(server: SyncServer, default_user):
    """Test filtering blocks by tags (match ALL)"""
    mgr = BlockManager()

    # Nested tag sets: {x,y,z} ⊃ {x,y} ⊃ {x}.
    b_xyz = await mgr.create_or_update_block_async(
        PydanticBlock(label="b1", value="v1", tags=["x", "y", "z"]),
        actor=default_user,
    )
    b_xy = await mgr.create_or_update_block_async(
        PydanticBlock(label="b2", value="v2", tags=["x", "y"]),
        actor=default_user,
    )
    b_x = await mgr.create_or_update_block_async(
        PydanticBlock(label="b3", value="v3", tags=["x"]),
        actor=default_user,
    )

    # "x" AND "y" (match ALL) requires both tags to be present.
    hits = {b.id for b in await mgr.get_blocks_async(actor=default_user, tags=["x", "y"], match_all_tags=True)}
    assert b_xyz.id in hits
    assert b_xy.id in hits
    assert b_x.id not in hits

    # "x" AND "y" AND "z" (match ALL) only matches the fully tagged block.
    hits = {b.id for b in await mgr.get_blocks_async(actor=default_user, tags=["x", "y", "z"], match_all_tags=True)}
    assert b_xyz.id in hits
    assert b_xy.id not in hits
    assert b_x.id not in hits


@pytest.mark.asyncio
async def test_block_tags_count(server: SyncServer, default_user):
    """Test counting blocks with tag filters"""
    mgr = BlockManager()

    # Two blocks share the "count_test" tag; a third does not.
    for label, value, tags in [
        ("c1", "v1", ["count_test", "a"]),
        ("c2", "v2", ["count_test", "b"]),
        ("c3", "v3", ["other"]),
    ]:
        await mgr.create_or_update_block_async(
            PydanticBlock(label=label, value=value, tags=tags),
            actor=default_user,
        )

    # Tag "count_test" (match ANY) counts exactly the two tagged blocks.
    assert await mgr.count_blocks_async(actor=default_user, tags=["count_test"], match_all_tags=False) == 2

    # Tags "count_test" AND "a" (match ALL) narrow the count to one block.
    assert await mgr.count_blocks_async(actor=default_user, tags=["count_test", "a"], match_all_tags=True) == 1