feat: add tags support to blocks (#8474)

* feat: add tags support to blocks

* fix: add timestamps and org scoping to blocks_tags

Addresses PR feedback:

1. Migration: Added timestamps (created_at, updated_at), soft delete
   (is_deleted), audit fields (_created_by_id, _last_updated_by_id),
   and organization_id to blocks_tags table for filtering support.
   Follows SQLite baseline pattern (composite PK of block_id+tag, no
   separate id column) to avoid insert failures.

2. ORM: Relationship already correct with lazy="raise" to prevent
   implicit joins and passive_deletes=True for efficient CASCADE deletes.

3. Schema: Changed normalize_tags() from Any to dict for type safety.

4. SQLite: Added blocks_tags to SQLite baseline schema to prevent
   table-not-found errors.

5. Code: Updated all tag row inserts to include organization_id.

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix: add ORM columns and update SQLite baseline for blocks_tags

Fixes test failures (CompileError: Unconsumed column names: organization_id):

1. ORM: Added organization_id, timestamps, audit fields to BlocksTags
   ORM model to match database schema from migrations.

2. SQLite baseline: Added full column set to blocks_tags (organization_id,
   timestamps, audit fields) to match PostgreSQL schema.

3. Test: Added 'tags' to expected Block schema fields.

This ensures SQLite and PostgreSQL have matching schemas and the ORM
can consume all columns that the code inserts.

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* revert change to existing alembic migration

* fix: remove passive_deletes and SQLite support for blocks_tags

1. Removed passive_deletes=True from Block.tags relationship to match
   AgentsTags pattern (neither have ondelete CASCADE in DB schema).

2. Removed SQLite branch from _replace_block_pivot_rows_async since
   blocks_tags table is PostgreSQL-only (migration skips SQLite).

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* api sync

---------

Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
cthomas
2026-01-12 16:04:21 -08:00
committed by Sarah Wooders
parent c550457b60
commit ab4ccfca31
11 changed files with 617 additions and 34 deletions

View File

@@ -0,0 +1,58 @@
"""Add blocks tags table
Revision ID: cf3c4d025dbc
Revises: 27de0f58e076
Create Date: 2026-01-08 23:36:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.settings import settings
# revision identifiers, used by Alembic.
revision: str = "cf3c4d025dbc"
down_revision: Union[str, None] = "27de0f58e076"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the ``blocks_tags`` pivot table (PostgreSQL only)."""
    # Skip this migration for SQLite
    if not settings.letta_pg_uri_no_default:
        return

    # Create blocks_tags table with timestamps and org scoping for filtering
    # Note: Matches agents_tags structure but follows SQLite baseline pattern (no separate id column)
    op.create_table(
        "blocks_tags",
        sa.Column("block_id", sa.String(), nullable=False),
        sa.Column("tag", sa.String(), nullable=False),
        # Nullable timestamps with server-side defaults, so inserts need not supply them.
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        # Audit columns; populated by the ORM layer, not the database.
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["block_id"],
            ["block.id"],
        ),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("block_id", "tag"),
        # NOTE(review): this unique constraint duplicates the composite primary key
        # above; it is kept so the DB matches the ORM's __table_args__, but it is
        # redundant at the database level — confirm before dropping either.
        sa.UniqueConstraint("block_id", "tag", name="unique_block_tag"),
    )
def downgrade() -> None:
    """Drop the ``blocks_tags`` table; no-op on SQLite, where it was never created."""
    # Only PostgreSQL deployments got the table in upgrade(), so only they drop it.
    if settings.letta_pg_uri_no_default:
        op.drop_table("blocks_tags")

View File

@@ -6,6 +6,7 @@ from letta.orm.base import Base
from letta.orm.block import Block
from letta.orm.block_history import BlockHistory
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.blocks_tags import BlocksTags
from letta.orm.conversation import Conversation
from letta.orm.conversation_messages import ConversationMessage
from letta.orm.file import FileMetadata

View File

@@ -12,6 +12,7 @@ from letta.schemas.block import Block as PydanticBlock, Human, Persona
if TYPE_CHECKING:
from letta.orm import Organization
from letta.orm.blocks_tags import BlocksTags
from letta.orm.identity import Identity
@@ -82,6 +83,12 @@ class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin, TemplateEntityMixin
back_populates="shared_blocks",
passive_deletes=True,
)
tags: Mapped[List["BlocksTags"]] = relationship(
"BlocksTags",
back_populates="block",
cascade="all, delete-orphan",
lazy="raise",
)
def to_pydantic(self) -> Type:
match self.label:

37
letta/orm/blocks_tags.py Normal file
View File

@@ -0,0 +1,37 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, UniqueConstraint, func, text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.base import Base
class BlocksTags(Base):
    """Pivot table associating free-form string tags with blocks.

    The composite primary key (block_id, tag) guarantees a tag appears at most
    once per block; the paired indexes support lookups from either direction
    (all tags of a block, all blocks with a tag).
    """

    __tablename__ = "blocks_tags"
    __table_args__ = (
        UniqueConstraint("block_id", "tag", name="unique_block_tag"),
        Index("ix_blocks_tags_block_id_tag", "block_id", "tag"),
        Index("ix_blocks_tags_tag_block_id", "tag", "block_id"),
    )

    # Primary key columns
    # Fix: annotate with the Python type `str`, not the SQLAlchemy type object
    # `String` — Mapped[...] expects a Python type; the column's SQL type is
    # already given explicitly to mapped_column().
    block_id: Mapped[str] = mapped_column(String, ForeignKey("block.id"), primary_key=True)
    tag: Mapped[str] = mapped_column(String, doc="The name of the tag associated with the block.", primary_key=True)

    # Organization scoping for filtering
    organization_id: Mapped[str] = mapped_column(String, ForeignKey("organizations.id"), nullable=False)

    # Timestamps for filtering by date
    # NOTE(review): updated_at has no onupdate=func.now(), so it will not advance
    # on row updates unless set explicitly — confirm whether that is intended.
    created_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now())
    updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now())

    # Soft delete support
    is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE"))

    # Audit fields
    _created_by_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
    _last_updated_by_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)

    # Relationships
    block: Mapped["Block"] = relationship("Block", back_populates="tags")

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Optional
from typing import Any, List, Optional
from pydantic import ConfigDict, Field, model_validator
@@ -88,6 +88,17 @@ class Block(BaseBlock):
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Block.")
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that last updated this Block.")
# tags - using Optional with default [] to allow None input to become empty list
tags: Optional[List[str]] = Field(default=[], description="The tags associated with the block.")
@model_validator(mode="before")
@classmethod
def normalize_tags(cls, data: dict) -> dict:
    """Coerce a ``tags`` value of None into an empty list before validation."""
    # Non-dict payloads (e.g. already-constructed models) pass through untouched.
    if not isinstance(data, dict):
        return data
    if data.get("tags") is None:
        data["tags"] = []
    return data
class BlockResponse(Block):
id: str = Field(
@@ -142,6 +153,9 @@ class BlockUpdate(BaseBlock):
value: Optional[str] = Field(None, description="Value of the block.")
project_id: Optional[str] = Field(None, description="The associated project id.")
# tags
tags: Optional[List[str]] = Field(None, description="The tags to associate with the block.")
model_config = ConfigDict(extra="ignore") # Ignores extra fields
@@ -157,6 +171,9 @@ class CreateBlock(BaseBlock):
is_template: bool = False
template_name: Optional[str] = Field(None, description="Name of the block if it is a template.")
# tags
tags: Optional[List[str]] = Field(None, description="The tags to associate with the block.")
@model_validator(mode="before")
@classmethod
def ensure_value_is_string(cls, data):

View File

@@ -33,6 +33,11 @@ async def list_blocks(
identity_id: IdentityIdQuery = None,
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
project_id: Optional[str] = Query(None, description="Search blocks by project id"),
tags: Optional[List[str]] = Query(None, description="List of tags to filter blocks by"),
match_all_tags: bool = Query(
False,
description="If True, only returns blocks that match ALL given tags. Otherwise, return blocks that have ANY of the passed-in tags.",
),
limit: Optional[int] = Query(50, description="Number of blocks to return"),
before: Optional[str] = Query(
None,
@@ -98,19 +103,44 @@ async def list_blocks(
after=after,
ascending=(order == "asc"),
show_hidden_blocks=show_hidden_blocks,
tags=tags,
match_all_tags=match_all_tags,
)
@router.get("/count", response_model=int, operation_id="count_blocks")
async def count_blocks(
label: BlockLabelQuery = None,
templates_only: bool = Query(False, description="Whether to include only templates"),
name: BlockNameQuery = None,
tags: Optional[List[str]] = Query(None, description="List of tags to filter blocks by"),
match_all_tags: bool = Query(
False,
description="If True, only counts blocks that match ALL given tags. Otherwise, counts blocks that have ANY of the passed-in tags.",
),
project_id: Optional[str] = Query(None, description="Search blocks by project id"),
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
"""
Count all blocks created by a user.
Count all blocks with optional filtering.
Supports the same filters as list_blocks for consistent querying.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.block_manager.size_async(actor=actor)
# If no filters are provided, use the simpler size_async method
if all(param is None or param is False for param in [label, templates_only, name, tags, project_id]):
return await server.block_manager.size_async(actor=actor)
return await server.block_manager.count_blocks_async(
actor=actor,
label=label,
is_template=templates_only,
template_name=name,
tags=tags,
match_all_tags=match_all_tags,
project_id=project_id,
)
@router.post("/", response_model=BlockResponse, operation_id="create_block")

View File

@@ -32,11 +32,34 @@ async def list_tags(
headers: HeaderParams = Depends(get_headers),
):
"""
Get the list of all agent tags that have been created.
Get the list of all tags (from agents and blocks) that have been created.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
text_filter = name or query_text
tags = await server.agent_manager.list_tags_async(
# Get tags from both agents and blocks
agent_tags = await server.agent_manager.list_tags_async(
actor=actor, before=before, after=after, limit=limit, query_text=text_filter, ascending=(order == "asc")
)
return tags
block_tags = await server.block_manager.list_tags_async(actor=actor, query_text=text_filter)
# Merge and deduplicate, then sort and apply pagination
all_tags = sorted(set(agent_tags) | set(block_tags), reverse=(order == "desc"))
# Apply cursor-based pagination on merged results
if after:
if order == "asc":
all_tags = [t for t in all_tags if t > after]
else:
all_tags = [t for t in all_tags if t < after]
if before:
if order == "asc":
all_tags = [t for t in all_tags if t < before]
else:
all_tags = [t for t in all_tags if t > before]
# Apply limit
if limit:
all_tags = all_tags[:limit]
return all_tags

View File

@@ -36,6 +36,7 @@ from letta.orm import (
ArchivalPassage,
Block as BlockModel,
BlocksAgents,
BlocksTags,
Group as GroupModel,
GroupsAgents,
IdentitiesAgents,
@@ -1929,7 +1930,10 @@ class AgentManager:
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
for block in agent.core_memory:
if block.label == block_label:
return block.to_pydantic()
pydantic_block = block.to_pydantic()
tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block.id))
pydantic_block.tags = [row[0] for row in tags_result.fetchall()]
return pydantic_block
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
@enforce_types
@@ -1941,7 +1945,7 @@ class AgentManager:
block_update: BlockUpdate,
actor: PydanticUser,
) -> PydanticBlock:
"""Gets a block attached to an agent by its label."""
"""Modifies a block attached to an agent by its label."""
async with db_registry.async_session() as session:
matched_block = None
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
@@ -1954,6 +1958,9 @@ class AgentManager:
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Extract tags from update data (it's not a column on the block table)
new_tags = update_data.pop("tags", None)
# Validate limit constraints before updating
validate_block_limit_constraint(update_data, matched_block)
@@ -1961,7 +1968,23 @@ class AgentManager:
setattr(matched_block, key, value)
await matched_block.update_async(session, actor=actor)
return matched_block.to_pydantic()
if new_tags is not None:
await BlockManager._replace_block_pivot_rows_async(
session,
BlocksTags.__table__,
matched_block.id,
[{"block_id": matched_block.id, "tag": tag} for tag in new_tags],
)
pydantic_block = matched_block.to_pydantic()
if new_tags is not None:
pydantic_block.tags = new_tags
else:
tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == matched_block.id))
pydantic_block.tags = [row[0] for row in tags_result.fetchall()]
return pydantic_block
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -3015,7 +3038,25 @@ class AgentManager:
result = await session.execute(query)
blocks = result.scalars().all()
return [block.to_pydantic() for block in blocks]
if not blocks:
return []
block_ids = [block.id for block in blocks]
tags_result = await session.execute(select(BlocksTags.block_id, BlocksTags.tag).where(BlocksTags.block_id.in_(block_ids)))
tags_by_block: Dict[str, List[str]] = {}
for row in tags_result.fetchall():
block_id, tag = row
if block_id not in tags_by_block:
tags_by_block[block_id] = []
tags_by_block[block_id].append(tag)
pydantic_blocks = []
for block in blocks:
pydantic_block = block.to_pydantic()
pydantic_block.tags = tags_by_block.get(block.id, [])
pydantic_blocks.append(pydantic_block)
return pydantic_blocks
@enforce_types
@trace_method

View File

@@ -2,8 +2,12 @@ import asyncio
from datetime import datetime
from typing import Dict, List, Optional
from sqlalchemy import and_, delete, func, or_, select
import sqlalchemy as sa
from sqlalchemy import and_, delete, exists, func, literal, or_, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from sqlalchemy.sql.expression import tuple_
from letta.errors import LettaInvalidArgumentError
from letta.log import get_logger
@@ -11,7 +15,9 @@ from letta.orm.agent import Agent as AgentModel
from letta.orm.block import Block as BlockModel
from letta.orm.block_history import BlockHistory
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.blocks_tags import BlocksTags
from letta.orm.errors import NoResultFound
from letta.orm.sqlalchemy_base import AccessType
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.block import Block as PydanticBlock, BlockUpdate
@@ -84,6 +90,68 @@ def validate_block_creation(block_data: dict) -> None:
class BlockManager:
"""Manager class to handle business logic related to Blocks."""
# ======================================================================================================================
# Helper methods for pivot tables
# ======================================================================================================================
@staticmethod
async def _bulk_insert_block_pivot_async(session, table, rows: list[dict]):
    """Bulk insert rows into a pivot table, ignoring conflicts.

    Args:
        session: Async SQLAlchemy session bound to the target engine.
        table: SQLAlchemy ``Table`` object for the pivot table.
        rows: Row dicts to insert; no-op when empty.
    """
    if not rows:
        return
    dialect = session.bind.dialect.name
    if dialect == "postgresql":
        # Native upsert: ON CONFLICT DO NOTHING skips rows that would violate
        # the table's primary-key/unique constraints.
        stmt = pg_insert(table).values(rows).on_conflict_do_nothing()
    elif dialect == "sqlite":
        # SQLite equivalent of conflict-ignoring insert: INSERT OR IGNORE.
        stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE")
    else:
        # fallback: filter out exact-duplicate dicts in Python
        # NOTE(review): this only dedupes within `rows`; a conflict with a row
        # already present in the table will still raise on these dialects.
        seen = set()
        filtered = []
        for row in rows:
            # Sorted item tuples make dicts hashable and order-insensitive.
            key = tuple(sorted(row.items()))
            if key not in seen:
                seen.add(key)
                filtered.append(row)
        stmt = sa.insert(table).values(filtered)
    await session.execute(stmt)
@staticmethod
async def _replace_block_pivot_rows_async(session, table, block_id: str, rows: list[dict]):
    """
    Replace all pivot rows for a block atomically using MERGE pattern.
    Only supports PostgreSQL (blocks_tags table not supported on SQLite).

    Args:
        session: Async SQLAlchemy session.
        table: Pivot ``Table`` whose rows are keyed (in part) by block_id.
        block_id: Block whose pivot rows are being replaced.
        rows: Desired final row set; an empty list deletes every row for the block.
    """
    dialect = session.bind.dialect.name
    if dialect == "postgresql":
        if rows:
            # separate upsert and delete operations
            stmt = pg_insert(table).values(rows)
            stmt = stmt.on_conflict_do_nothing()
            await session.execute(stmt)

            # delete rows not in new set
            # Match on full primary-key tuples so only exact survivors are kept.
            pk_names = [c.name for c in table.primary_key.columns]
            new_keys = [tuple(r[c] for c in pk_names) for r in rows]
            await session.execute(
                delete(table).where(table.c.block_id == block_id, ~tuple_(*[table.c[c] for c in pk_names]).in_(new_keys))
            )
        else:
            # if no rows to insert, just delete all
            await session.execute(delete(table).where(table.c.block_id == block_id))
    else:
        # fallback: use original DELETE + INSERT pattern
        # NOTE(review): non-atomic on these dialects — a concurrent reader may
        # briefly observe the block with zero pivot rows between the two statements.
        await session.execute(delete(table).where(table.c.block_id == block_id))
        if rows:
            await BlockManager._bulk_insert_block_pivot_async(session, table, rows)
# ======================================================================================================================
# Basic CRUD operations
# ======================================================================================================================
@enforce_types
@trace_method
async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
@@ -95,11 +163,23 @@ class BlockManager:
else:
async with db_registry.async_session() as session:
data = block.model_dump(to_orm=True, exclude_none=True)
# Extract tags before creating the ORM model (tags is not a column)
tags = data.pop("tags", None) or []
# Validate block creation constraints
validate_block_creation(data)
block = BlockModel(**data, organization_id=actor.organization_id)
await block.create_async(session, actor=actor, no_commit=True, no_refresh=True)
pydantic_block = block.to_pydantic()
block_model = BlockModel(**data, organization_id=actor.organization_id)
await block_model.create_async(session, actor=actor, no_commit=True, no_refresh=True)
if tags:
await self._bulk_insert_block_pivot_async(
session,
BlocksTags.__table__,
[{"block_id": block_model.id, "tag": tag, "organization_id": actor.organization_id} for tag in tags],
)
pydantic_block = block_model.to_pydantic()
pydantic_block.tags = tags
# context manager now handles commits
# await session.commit()
return pydantic_block
@@ -119,10 +199,13 @@ class BlockManager:
return []
async with db_registry.async_session() as session:
# Validate all blocks before creating any
validated_data = []
for block in blocks:
tags_by_index: Dict[int, List[str]] = {}
for i, block in enumerate(blocks):
block_data = block.model_dump(to_orm=True, exclude_none=True)
tags = block_data.pop("tags", None) or []
if tags:
tags_by_index[i] = tags
validate_block_creation(block_data)
validated_data.append(block_data)
@@ -130,9 +213,22 @@ class BlockManager:
created_models = await BlockModel.batch_create_async(
items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True
)
result = [m.to_pydantic() for m in created_models]
# context manager now handles commits
# await session.commit()
all_tag_rows = []
for i, model in enumerate(created_models):
if i in tags_by_index:
for tag in tags_by_index[i]:
all_tag_rows.append({"block_id": model.id, "tag": tag, "organization_id": actor.organization_id})
if all_tag_rows:
await self._bulk_insert_block_pivot_async(session, BlocksTags.__table__, all_tag_rows)
result = []
for i, model in enumerate(created_models):
pydantic_block = model.to_pydantic()
pydantic_block.tags = tags_by_index.get(i, [])
result.append(pydantic_block)
return result
@enforce_types
@@ -144,6 +240,9 @@ class BlockManager:
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Extract tags from update data (it's not a column on the block table)
new_tags = update_data.pop("tags", None)
# Validate limit constraints before updating
validate_block_limit_constraint(update_data, block)
@@ -151,7 +250,22 @@ class BlockManager:
setattr(block, key, value)
await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
if new_tags is not None:
await self._replace_block_pivot_rows_async(
session,
BlocksTags.__table__,
block_id,
[{"block_id": block_id, "tag": tag, "organization_id": block.organization_id} for tag in new_tags],
)
pydantic_block = block.to_pydantic()
if new_tags is not None:
pydantic_block.tags = new_tags
else:
result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id))
pydantic_block.tags = [row[0] for row in result.fetchall()]
# context manager now handles commits
# await session.commit()
return pydantic_block
@@ -164,6 +278,8 @@ class BlockManager:
async with db_registry.async_session() as session:
# First, delete all references in blocks_agents table
await session.execute(delete(BlocksAgents).where(BlocksAgents.block_id == block_id))
# Also delete all tags associated with this block
await session.execute(delete(BlocksTags).where(BlocksTags.block_id == block_id))
await session.flush()
# Then delete the block itself
@@ -192,19 +308,18 @@ class BlockManager:
connected_to_agents_count_eq: Optional[List[int]] = None,
ascending: bool = True,
show_hidden_blocks: Optional[bool] = None,
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
) -> List[PydanticBlock]:
"""Async version of get_blocks method. Retrieve blocks based on various optional filters."""
from sqlalchemy import select
from sqlalchemy.orm import noload
from letta.orm.sqlalchemy_base import AccessType
async with db_registry.async_session() as session:
# Start with a basic query
query = select(BlockModel)
# Explicitly avoid loading relationships
query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
query = query.options(
noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
)
# Apply access control
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
@@ -295,6 +410,20 @@ class BlockManager:
query = query.join(BlockModel.identities).filter(BlockModel.identities.property.mapper.class_.id == identity_id)
needs_distinct = True
if tags:
if match_all_tags:
# Must match ALL tags - use subquery with having count
tag_subquery = (
select(BlocksTags.block_id)
.where(BlocksTags.tag.in_(tags))
.group_by(BlocksTags.block_id)
.having(func.count(BlocksTags.tag) == literal(len(tags)))
)
query = query.where(BlockModel.id.in_(tag_subquery))
else:
# Must match ANY tag
query = query.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags))))
if after:
result = (await session.execute(select(BlockModel.created_at, BlockModel.id).where(BlockModel.id == after))).first()
if result:
@@ -353,17 +482,39 @@ class BlockManager:
result = await session.execute(query)
blocks = result.scalars().all()
return [block.to_pydantic() for block in blocks]
if not blocks:
return []
block_ids = [block.id for block in blocks]
tags_result = await session.execute(select(BlocksTags.block_id, BlocksTags.tag).where(BlocksTags.block_id.in_(block_ids)))
tags_by_block: Dict[str, List[str]] = {}
for row in tags_result.fetchall():
block_id, tag = row
if block_id not in tags_by_block:
tags_by_block[block_id] = []
tags_by_block[block_id].append(tag)
pydantic_blocks = []
for block in blocks:
pydantic_block = block.to_pydantic()
pydantic_block.tags = tags_by_block.get(block.id, [])
pydantic_blocks.append(pydantic_block)
return pydantic_blocks
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
"""Retrieve a block by its name."""
"""Retrieve a block by its ID, including tags."""
async with db_registry.async_session() as session:
try:
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
return block.to_pydantic()
pydantic_block = block.to_pydantic()
tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id))
pydantic_block.tags = [row[0] for row in tags_result.fetchall()]
return pydantic_block
except NoResultFound:
return None
@@ -371,11 +522,6 @@ class BlockManager:
@trace_method
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
from sqlalchemy import select
from sqlalchemy.orm import noload
from letta.orm.sqlalchemy_base import AccessType
if not block_ids:
return []
@@ -387,7 +533,9 @@ class BlockManager:
query = query.where(BlockModel.id.in_(block_ids))
# Explicitly avoid loading relationships
query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
query = query.options(
noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
)
# Apply access control if actor is provided
if actor:
@@ -522,6 +670,88 @@ class BlockManager:
async with db_registry.async_session() as session:
return await BlockModel.size_async(db_session=session, actor=actor)
@enforce_types
@trace_method
async def count_blocks_async(
    self,
    actor: PydanticUser,
    label: Optional[str] = None,
    is_template: Optional[bool] = None,
    template_name: Optional[str] = None,
    project_id: Optional[str] = None,
    tags: Optional[List[str]] = None,
    match_all_tags: bool = False,
) -> int:
    """
    Count blocks with optional filtering. Supports same filters as get_blocks_async.

    Args:
        actor: User performing the count; results are scoped to their organization.
        label: Only count blocks with this exact label.
        is_template: Only count template (True) or non-template (False) blocks.
        template_name: Only count blocks with this template name.
        project_id: Only count blocks in this project.
        tags: Only count blocks carrying these tags.
        match_all_tags: If True, a block must carry every tag in ``tags``;
            otherwise carrying any one of them suffices.

    Returns:
        int: Number of matching blocks (0 when none match).
    """
    async with db_registry.async_session() as session:
        query = select(func.count(BlockModel.id))

        # Apply access control
        query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
        # NOTE(review): this explicit org filter looks redundant with the access
        # predicate above — confirm before removing either.
        query = query.where(BlockModel.organization_id == actor.organization_id)

        # Apply filters
        if label:
            query = query.where(BlockModel.label == label)
        if is_template is not None:
            query = query.where(BlockModel.is_template == is_template)
        if template_name:
            query = query.where(BlockModel.template_name == template_name)
        if project_id:
            query = query.where(BlockModel.project_id == project_id)

        # Apply tag filtering
        if tags:
            if match_all_tags:
                # A block qualifies when it has one row per requested tag;
                # HAVING count == len(tags) relies on (block_id, tag) uniqueness.
                tag_subquery = (
                    select(BlocksTags.block_id)
                    .where(BlocksTags.tag.in_(tags))
                    .group_by(BlocksTags.block_id)
                    .having(func.count(BlocksTags.tag) == literal(len(tags)))
                )
                query = query.where(BlockModel.id.in_(tag_subquery))
            else:
                # Correlated EXISTS: block carries at least one requested tag.
                query = query.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags))))

        result = await session.execute(query)
        return result.scalar() or 0
@enforce_types
@trace_method
async def list_tags_async(
    self,
    actor: PydanticUser,
    query_text: Optional[str] = None,
) -> List[str]:
    """
    Get all unique block tags for the actor's organization.

    Args:
        actor: User performing the action.
        query_text: Filter tags by text search (case-insensitive substring).

    Returns:
        List[str]: List of unique block tags.
            NOTE(review): no ORDER BY is applied, so ordering is engine-dependent.
    """
    async with db_registry.async_session() as session:
        # Join through blocks so tags are scoped to the actor's organization.
        query = (
            select(BlocksTags.tag)
            .join(BlockModel, BlocksTags.block_id == BlockModel.id)
            .where(BlockModel.organization_id == actor.organization_id)
            .distinct()
        )
        if query_text:
            if settings.database_engine is DatabaseChoice.POSTGRES:
                # Postgres ILIKE gives native case-insensitive matching.
                query = query.where(BlocksTags.tag.ilike(f"%{query_text}%"))
            else:
                # Other engines: emulate ILIKE by lower-casing both sides.
                query = query.where(func.lower(BlocksTags.tag).like(func.lower(f"%{query_text}%")))
        result = await session.execute(query)
        return [row[0] for row in result.fetchall()]
# Block History Functions
@enforce_types

View File

@@ -1766,6 +1766,7 @@ async def test_agent_state_schema_unchanged(server: SyncServer):
"hidden",
"created_by_id",
"last_updated_by_id",
"tags",
}
actual_block_fields = set(block_fields.keys())
if actual_block_fields != expected_block_fields:

View File

@@ -1282,3 +1282,141 @@ async def test_redo_concurrency_stale(server: SyncServer, default_user):
# an out-of-date version of the block
with pytest.raises(StaleDataError):
await block_manager.redo_checkpoint_block(block_id=block.id, actor=default_user, use_preloaded_block=block_s2)
# ======================================================================================================================
# Block Tags Tests
# ======================================================================================================================
@pytest.mark.asyncio
async def test_block_tags_create_and_update(server: SyncServer, default_user):
    """Creating a block with tags, replacing the tag set, and clearing it all round-trip."""
    manager = BlockManager()

    initial_tags = ["tag1", "tag2", "important"]
    created = await manager.create_or_update_block_async(
        PydanticBlock(label="test_tags", value="Block with tags", tags=initial_tags),
        actor=default_user,
    )
    assert set(created.tags) == set(initial_tags)

    from letta.schemas.block import BlockUpdate

    # Replace the tag set entirely.
    replaced = await manager.update_block_async(
        block_id=created.id,
        block_update=BlockUpdate(tags=["tag1", "new_tag"]),
        actor=default_user,
    )
    assert set(replaced.tags) == {"tag1", "new_tag"}

    # An empty list clears every tag.
    cleared = await manager.update_block_async(
        block_id=created.id,
        block_update=BlockUpdate(tags=[]),
        actor=default_user,
    )
    assert cleared.tags == []
@pytest.mark.asyncio
async def test_block_tags_filter_any(server: SyncServer, default_user):
    """match_all_tags=False returns blocks carrying ANY of the requested tags."""
    manager = BlockManager()

    created = []
    for label, value, tag_list in (
        ("b1", "v1", ["alpha", "beta"]),
        ("b2", "v2", ["beta", "gamma"]),
        ("b3", "v3", ["delta"]),
    ):
        created.append(
            await manager.create_or_update_block_async(
                PydanticBlock(label=label, value=value, tags=tag_list),
                actor=default_user,
            )
        )
    block1, block2, block3 = created

    # "beta" matches block1 and block2 only.
    matched = {b.id for b in await manager.get_blocks_async(actor=default_user, tags=["beta"], match_all_tags=False)}
    assert block1.id in matched
    assert block2.id in matched
    assert block3.id not in matched

    # "alpha" or "delta" matches block1 and block3 only.
    matched = {
        b.id for b in await manager.get_blocks_async(actor=default_user, tags=["alpha", "delta"], match_all_tags=False)
    }
    assert block1.id in matched
    assert block2.id not in matched
    assert block3.id in matched
@pytest.mark.asyncio
async def test_block_tags_filter_all(server: SyncServer, default_user):
    """match_all_tags=True returns only blocks carrying EVERY requested tag."""
    manager = BlockManager()

    blocks = {}
    for label, value, tag_list in (
        ("b1", "v1", ["x", "y", "z"]),
        ("b2", "v2", ["x", "y"]),
        ("b3", "v3", ["x"]),
    ):
        blocks[label] = await manager.create_or_update_block_async(
            PydanticBlock(label=label, value=value, tags=tag_list),
            actor=default_user,
        )

    # "x" AND "y": b1 and b2 qualify; b3 lacks "y".
    matched = {b.id for b in await manager.get_blocks_async(actor=default_user, tags=["x", "y"], match_all_tags=True)}
    assert blocks["b1"].id in matched
    assert blocks["b2"].id in matched
    assert blocks["b3"].id not in matched

    # "x" AND "y" AND "z": only b1 carries all three.
    matched = {
        b.id for b in await manager.get_blocks_async(actor=default_user, tags=["x", "y", "z"], match_all_tags=True)
    }
    assert blocks["b1"].id in matched
    assert blocks["b2"].id not in matched
    assert blocks["b3"].id not in matched
@pytest.mark.asyncio
async def test_block_tags_count(server: SyncServer, default_user):
    """Test counting blocks with tag filters"""
    block_manager = BlockManager()

    # Create blocks with different tags
    await block_manager.create_or_update_block_async(
        PydanticBlock(label="c1", value="v1", tags=["count_test", "a"]),
        actor=default_user,
    )
    await block_manager.create_or_update_block_async(
        PydanticBlock(label="c2", value="v2", tags=["count_test", "b"]),
        actor=default_user,
    )
    # Control block that must not match either filter below.
    await block_manager.create_or_update_block_async(
        PydanticBlock(label="c3", value="v3", tags=["other"]),
        actor=default_user,
    )

    # Count blocks with tag "count_test"
    count = await block_manager.count_blocks_async(actor=default_user, tags=["count_test"], match_all_tags=False)
    assert count == 2

    # Count blocks with tags "count_test" AND "a"
    count = await block_manager.count_blocks_async(actor=default_user, tags=["count_test", "a"], match_all_tags=True)
    assert count == 1