feat: add tags support to blocks (#8474)
* feat: add tags support to blocks * fix: add timestamps and org scoping to blocks_tags Addresses PR feedback: 1. Migration: Added timestamps (created_at, updated_at), soft delete (is_deleted), audit fields (_created_by_id, _last_updated_by_id), and organization_id to blocks_tags table for filtering support. Follows SQLite baseline pattern (composite PK of block_id+tag, no separate id column) to avoid insert failures. 2. ORM: Relationship already correct with lazy="raise" to prevent implicit joins and passive_deletes=True for efficient CASCADE deletes. 3. Schema: Changed normalize_tags() from Any to dict for type safety. 4. SQLite: Added blocks_tags to SQLite baseline schema to prevent table-not-found errors. 5. Code: Updated all tag row inserts to include organization_id. 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * fix: add ORM columns and update SQLite baseline for blocks_tags Fixes test failures (CompileError: Unconsumed column names: organization_id): 1. ORM: Added organization_id, timestamps, audit fields to BlocksTags ORM model to match database schema from migrations. 2. SQLite baseline: Added full column set to blocks_tags (organization_id, timestamps, audit fields) to match PostgreSQL schema. 3. Test: Added 'tags' to expected Block schema fields. This ensures SQLite and PostgreSQL have matching schemas and the ORM can consume all columns that the code inserts. 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * revert change to existing alembic migration * fix: remove passive_deletes and SQLite support for blocks_tags 1. Removed passive_deletes=True from Block.tags relationship to match AgentsTags pattern (neither have ondelete CASCADE in DB schema). 2. Removed SQLite branch from _replace_block_pivot_rows_async since blocks_tags table is PostgreSQL-only (migration skips SQLite). 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * api sync --------- Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
58
alembic/versions/cf3c4d025dbc_add_blocks_tags_table.py
Normal file
58
alembic/versions/cf3c4d025dbc_add_blocks_tags_table.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Add blocks tags table
|
||||
|
||||
Revision ID: cf3c4d025dbc
|
||||
Revises: 27de0f58e076
|
||||
Create Date: 2026-01-08 23:36:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from letta.settings import settings
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "cf3c4d025dbc"
|
||||
down_revision: Union[str, None] = "27de0f58e076"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Skip this migration for SQLite
|
||||
if not settings.letta_pg_uri_no_default:
|
||||
return
|
||||
|
||||
# Create blocks_tags table with timestamps and org scoping for filtering
|
||||
# Note: Matches agents_tags structure but follows SQLite baseline pattern (no separate id column)
|
||||
op.create_table(
|
||||
"blocks_tags",
|
||||
sa.Column("block_id", sa.String(), nullable=False),
|
||||
sa.Column("tag", sa.String(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["block_id"],
|
||||
["block.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("block_id", "tag"),
|
||||
sa.UniqueConstraint("block_id", "tag", name="unique_block_tag"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Skip this migration for SQLite
|
||||
if not settings.letta_pg_uri_no_default:
|
||||
return
|
||||
|
||||
op.drop_table("blocks_tags")
|
||||
@@ -6,6 +6,7 @@ from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.blocks_tags import BlocksTags
|
||||
from letta.orm.conversation import Conversation
|
||||
from letta.orm.conversation_messages import ConversationMessage
|
||||
from letta.orm.file import FileMetadata
|
||||
|
||||
@@ -12,6 +12,7 @@ from letta.schemas.block import Block as PydanticBlock, Human, Persona
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm import Organization
|
||||
from letta.orm.blocks_tags import BlocksTags
|
||||
from letta.orm.identity import Identity
|
||||
|
||||
|
||||
@@ -82,6 +83,12 @@ class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin, TemplateEntityMixin
|
||||
back_populates="shared_blocks",
|
||||
passive_deletes=True,
|
||||
)
|
||||
tags: Mapped[List["BlocksTags"]] = relationship(
|
||||
"BlocksTags",
|
||||
back_populates="block",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
def to_pydantic(self) -> Type:
|
||||
match self.label:
|
||||
|
||||
37
letta/orm/blocks_tags.py
Normal file
37
letta/orm/blocks_tags.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, UniqueConstraint, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class BlocksTags(Base):
|
||||
__tablename__ = "blocks_tags"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("block_id", "tag", name="unique_block_tag"),
|
||||
Index("ix_blocks_tags_block_id_tag", "block_id", "tag"),
|
||||
Index("ix_blocks_tags_tag_block_id", "tag", "block_id"),
|
||||
)
|
||||
|
||||
# Primary key columns
|
||||
block_id: Mapped[String] = mapped_column(String, ForeignKey("block.id"), primary_key=True)
|
||||
tag: Mapped[str] = mapped_column(String, doc="The name of the tag associated with the block.", primary_key=True)
|
||||
|
||||
# Organization scoping for filtering
|
||||
organization_id: Mapped[str] = mapped_column(String, ForeignKey("organizations.id"), nullable=False)
|
||||
|
||||
# Timestamps for filtering by date
|
||||
created_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Soft delete support
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE"))
|
||||
|
||||
# Audit fields
|
||||
_created_by_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
_last_updated_by_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
block: Mapped["Block"] = relationship("Block", back_populates="tags")
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
@@ -88,6 +88,17 @@ class Block(BaseBlock):
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Block.")
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that last updated this Block.")
|
||||
|
||||
# tags - using Optional with default [] to allow None input to become empty list
|
||||
tags: Optional[List[str]] = Field(default=[], description="The tags associated with the block.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def normalize_tags(cls, data: dict) -> dict:
|
||||
"""Convert None tags to empty list."""
|
||||
if isinstance(data, dict) and data.get("tags") is None:
|
||||
data["tags"] = []
|
||||
return data
|
||||
|
||||
|
||||
class BlockResponse(Block):
|
||||
id: str = Field(
|
||||
@@ -142,6 +153,9 @@ class BlockUpdate(BaseBlock):
|
||||
value: Optional[str] = Field(None, description="Value of the block.")
|
||||
project_id: Optional[str] = Field(None, description="The associated project id.")
|
||||
|
||||
# tags
|
||||
tags: Optional[List[str]] = Field(None, description="The tags to associate with the block.")
|
||||
|
||||
model_config = ConfigDict(extra="ignore") # Ignores extra fields
|
||||
|
||||
|
||||
@@ -157,6 +171,9 @@ class CreateBlock(BaseBlock):
|
||||
is_template: bool = False
|
||||
template_name: Optional[str] = Field(None, description="Name of the block if it is a template.")
|
||||
|
||||
# tags
|
||||
tags: Optional[List[str]] = Field(None, description="The tags to associate with the block.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def ensure_value_is_string(cls, data):
|
||||
|
||||
@@ -33,6 +33,11 @@ async def list_blocks(
|
||||
identity_id: IdentityIdQuery = None,
|
||||
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
|
||||
project_id: Optional[str] = Query(None, description="Search blocks by project id"),
|
||||
tags: Optional[List[str]] = Query(None, description="List of tags to filter blocks by"),
|
||||
match_all_tags: bool = Query(
|
||||
False,
|
||||
description="If True, only returns blocks that match ALL given tags. Otherwise, return blocks that have ANY of the passed-in tags.",
|
||||
),
|
||||
limit: Optional[int] = Query(50, description="Number of blocks to return"),
|
||||
before: Optional[str] = Query(
|
||||
None,
|
||||
@@ -98,19 +103,44 @@ async def list_blocks(
|
||||
after=after,
|
||||
ascending=(order == "asc"),
|
||||
show_hidden_blocks=show_hidden_blocks,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/count", response_model=int, operation_id="count_blocks")
|
||||
async def count_blocks(
|
||||
label: BlockLabelQuery = None,
|
||||
templates_only: bool = Query(False, description="Whether to include only templates"),
|
||||
name: BlockNameQuery = None,
|
||||
tags: Optional[List[str]] = Query(None, description="List of tags to filter blocks by"),
|
||||
match_all_tags: bool = Query(
|
||||
False,
|
||||
description="If True, only counts blocks that match ALL given tags. Otherwise, counts blocks that have ANY of the passed-in tags.",
|
||||
),
|
||||
project_id: Optional[str] = Query(None, description="Search blocks by project id"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""
|
||||
Count all blocks created by a user.
|
||||
Count all blocks with optional filtering.
|
||||
Supports the same filters as list_blocks for consistent querying.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.block_manager.size_async(actor=actor)
|
||||
|
||||
# If no filters are provided, use the simpler size_async method
|
||||
if all(param is None or param is False for param in [label, templates_only, name, tags, project_id]):
|
||||
return await server.block_manager.size_async(actor=actor)
|
||||
|
||||
return await server.block_manager.count_blocks_async(
|
||||
actor=actor,
|
||||
label=label,
|
||||
is_template=templates_only,
|
||||
template_name=name,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
project_id=project_id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=BlockResponse, operation_id="create_block")
|
||||
|
||||
@@ -32,11 +32,34 @@ async def list_tags(
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""
|
||||
Get the list of all agent tags that have been created.
|
||||
Get the list of all tags (from agents and blocks) that have been created.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
text_filter = name or query_text
|
||||
tags = await server.agent_manager.list_tags_async(
|
||||
|
||||
# Get tags from both agents and blocks
|
||||
agent_tags = await server.agent_manager.list_tags_async(
|
||||
actor=actor, before=before, after=after, limit=limit, query_text=text_filter, ascending=(order == "asc")
|
||||
)
|
||||
return tags
|
||||
block_tags = await server.block_manager.list_tags_async(actor=actor, query_text=text_filter)
|
||||
|
||||
# Merge and deduplicate, then sort and apply pagination
|
||||
all_tags = sorted(set(agent_tags) | set(block_tags), reverse=(order == "desc"))
|
||||
|
||||
# Apply cursor-based pagination on merged results
|
||||
if after:
|
||||
if order == "asc":
|
||||
all_tags = [t for t in all_tags if t > after]
|
||||
else:
|
||||
all_tags = [t for t in all_tags if t < after]
|
||||
if before:
|
||||
if order == "asc":
|
||||
all_tags = [t for t in all_tags if t < before]
|
||||
else:
|
||||
all_tags = [t for t in all_tags if t > before]
|
||||
|
||||
# Apply limit
|
||||
if limit:
|
||||
all_tags = all_tags[:limit]
|
||||
|
||||
return all_tags
|
||||
|
||||
@@ -36,6 +36,7 @@ from letta.orm import (
|
||||
ArchivalPassage,
|
||||
Block as BlockModel,
|
||||
BlocksAgents,
|
||||
BlocksTags,
|
||||
Group as GroupModel,
|
||||
GroupsAgents,
|
||||
IdentitiesAgents,
|
||||
@@ -1929,7 +1930,10 @@ class AgentManager:
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
for block in agent.core_memory:
|
||||
if block.label == block_label:
|
||||
return block.to_pydantic()
|
||||
pydantic_block = block.to_pydantic()
|
||||
tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block.id))
|
||||
pydantic_block.tags = [row[0] for row in tags_result.fetchall()]
|
||||
return pydantic_block
|
||||
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
||||
|
||||
@enforce_types
|
||||
@@ -1941,7 +1945,7 @@ class AgentManager:
|
||||
block_update: BlockUpdate,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticBlock:
|
||||
"""Gets a block attached to an agent by its label."""
|
||||
"""Modifies a block attached to an agent by its label."""
|
||||
async with db_registry.async_session() as session:
|
||||
matched_block = None
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
@@ -1954,6 +1958,9 @@ class AgentManager:
|
||||
|
||||
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
# Extract tags from update data (it's not a column on the block table)
|
||||
new_tags = update_data.pop("tags", None)
|
||||
|
||||
# Validate limit constraints before updating
|
||||
validate_block_limit_constraint(update_data, matched_block)
|
||||
|
||||
@@ -1961,7 +1968,23 @@ class AgentManager:
|
||||
setattr(matched_block, key, value)
|
||||
|
||||
await matched_block.update_async(session, actor=actor)
|
||||
return matched_block.to_pydantic()
|
||||
|
||||
if new_tags is not None:
|
||||
await BlockManager._replace_block_pivot_rows_async(
|
||||
session,
|
||||
BlocksTags.__table__,
|
||||
matched_block.id,
|
||||
[{"block_id": matched_block.id, "tag": tag} for tag in new_tags],
|
||||
)
|
||||
|
||||
pydantic_block = matched_block.to_pydantic()
|
||||
if new_tags is not None:
|
||||
pydantic_block.tags = new_tags
|
||||
else:
|
||||
tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == matched_block.id))
|
||||
pydantic_block.tags = [row[0] for row in tags_result.fetchall()]
|
||||
|
||||
return pydantic_block
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -3015,7 +3038,25 @@ class AgentManager:
|
||||
result = await session.execute(query)
|
||||
blocks = result.scalars().all()
|
||||
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
if not blocks:
|
||||
return []
|
||||
|
||||
block_ids = [block.id for block in blocks]
|
||||
tags_result = await session.execute(select(BlocksTags.block_id, BlocksTags.tag).where(BlocksTags.block_id.in_(block_ids)))
|
||||
tags_by_block: Dict[str, List[str]] = {}
|
||||
for row in tags_result.fetchall():
|
||||
block_id, tag = row
|
||||
if block_id not in tags_by_block:
|
||||
tags_by_block[block_id] = []
|
||||
tags_by_block[block_id].append(tag)
|
||||
|
||||
pydantic_blocks = []
|
||||
for block in blocks:
|
||||
pydantic_block = block.to_pydantic()
|
||||
pydantic_block.tags = tags_by_block.get(block.id, [])
|
||||
pydantic_blocks.append(pydantic_block)
|
||||
|
||||
return pydantic_blocks
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -2,8 +2,12 @@ import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import and_, delete, func, or_, select
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import and_, delete, exists, func, literal, or_, select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import noload
|
||||
from sqlalchemy.sql.expression import tuple_
|
||||
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
from letta.log import get_logger
|
||||
@@ -11,7 +15,9 @@ from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.block import Block as BlockModel
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.blocks_tags import BlocksTags
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.block import Block as PydanticBlock, BlockUpdate
|
||||
@@ -84,6 +90,68 @@ def validate_block_creation(block_data: dict) -> None:
|
||||
class BlockManager:
|
||||
"""Manager class to handle business logic related to Blocks."""
|
||||
|
||||
# ======================================================================================================================
|
||||
# Helper methods for pivot tables
|
||||
# ======================================================================================================================
|
||||
|
||||
@staticmethod
|
||||
async def _bulk_insert_block_pivot_async(session, table, rows: list[dict]):
|
||||
"""Bulk insert rows into a pivot table, ignoring conflicts."""
|
||||
if not rows:
|
||||
return
|
||||
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect == "postgresql":
|
||||
stmt = pg_insert(table).values(rows).on_conflict_do_nothing()
|
||||
elif dialect == "sqlite":
|
||||
stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE")
|
||||
else:
|
||||
# fallback: filter out exact-duplicate dicts in Python
|
||||
seen = set()
|
||||
filtered = []
|
||||
for row in rows:
|
||||
key = tuple(sorted(row.items()))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
filtered.append(row)
|
||||
stmt = sa.insert(table).values(filtered)
|
||||
|
||||
await session.execute(stmt)
|
||||
|
||||
@staticmethod
|
||||
async def _replace_block_pivot_rows_async(session, table, block_id: str, rows: list[dict]):
|
||||
"""
|
||||
Replace all pivot rows for a block atomically using MERGE pattern.
|
||||
Only supports PostgreSQL (blocks_tags table not supported on SQLite).
|
||||
"""
|
||||
dialect = session.bind.dialect.name
|
||||
|
||||
if dialect == "postgresql":
|
||||
if rows:
|
||||
# separate upsert and delete operations
|
||||
stmt = pg_insert(table).values(rows)
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
await session.execute(stmt)
|
||||
|
||||
# delete rows not in new set
|
||||
pk_names = [c.name for c in table.primary_key.columns]
|
||||
new_keys = [tuple(r[c] for c in pk_names) for r in rows]
|
||||
await session.execute(
|
||||
delete(table).where(table.c.block_id == block_id, ~tuple_(*[table.c[c] for c in pk_names]).in_(new_keys))
|
||||
)
|
||||
else:
|
||||
# if no rows to insert, just delete all
|
||||
await session.execute(delete(table).where(table.c.block_id == block_id))
|
||||
else:
|
||||
# fallback: use original DELETE + INSERT pattern
|
||||
await session.execute(delete(table).where(table.c.block_id == block_id))
|
||||
if rows:
|
||||
await BlockManager._bulk_insert_block_pivot_async(session, table, rows)
|
||||
|
||||
# ======================================================================================================================
|
||||
# Basic CRUD operations
|
||||
# ======================================================================================================================
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
|
||||
@@ -95,11 +163,23 @@ class BlockManager:
|
||||
else:
|
||||
async with db_registry.async_session() as session:
|
||||
data = block.model_dump(to_orm=True, exclude_none=True)
|
||||
# Extract tags before creating the ORM model (tags is not a column)
|
||||
tags = data.pop("tags", None) or []
|
||||
|
||||
# Validate block creation constraints
|
||||
validate_block_creation(data)
|
||||
block = BlockModel(**data, organization_id=actor.organization_id)
|
||||
await block.create_async(session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_block = block.to_pydantic()
|
||||
block_model = BlockModel(**data, organization_id=actor.organization_id)
|
||||
await block_model.create_async(session, actor=actor, no_commit=True, no_refresh=True)
|
||||
|
||||
if tags:
|
||||
await self._bulk_insert_block_pivot_async(
|
||||
session,
|
||||
BlocksTags.__table__,
|
||||
[{"block_id": block_model.id, "tag": tag, "organization_id": actor.organization_id} for tag in tags],
|
||||
)
|
||||
|
||||
pydantic_block = block_model.to_pydantic()
|
||||
pydantic_block.tags = tags
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_block
|
||||
@@ -119,10 +199,13 @@ class BlockManager:
|
||||
return []
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Validate all blocks before creating any
|
||||
validated_data = []
|
||||
for block in blocks:
|
||||
tags_by_index: Dict[int, List[str]] = {}
|
||||
for i, block in enumerate(blocks):
|
||||
block_data = block.model_dump(to_orm=True, exclude_none=True)
|
||||
tags = block_data.pop("tags", None) or []
|
||||
if tags:
|
||||
tags_by_index[i] = tags
|
||||
validate_block_creation(block_data)
|
||||
validated_data.append(block_data)
|
||||
|
||||
@@ -130,9 +213,22 @@ class BlockManager:
|
||||
created_models = await BlockModel.batch_create_async(
|
||||
items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True
|
||||
)
|
||||
result = [m.to_pydantic() for m in created_models]
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
all_tag_rows = []
|
||||
for i, model in enumerate(created_models):
|
||||
if i in tags_by_index:
|
||||
for tag in tags_by_index[i]:
|
||||
all_tag_rows.append({"block_id": model.id, "tag": tag, "organization_id": actor.organization_id})
|
||||
|
||||
if all_tag_rows:
|
||||
await self._bulk_insert_block_pivot_async(session, BlocksTags.__table__, all_tag_rows)
|
||||
|
||||
result = []
|
||||
for i, model in enumerate(created_models):
|
||||
pydantic_block = model.to_pydantic()
|
||||
pydantic_block.tags = tags_by_index.get(i, [])
|
||||
result.append(pydantic_block)
|
||||
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@@ -144,6 +240,9 @@ class BlockManager:
|
||||
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
|
||||
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
# Extract tags from update data (it's not a column on the block table)
|
||||
new_tags = update_data.pop("tags", None)
|
||||
|
||||
# Validate limit constraints before updating
|
||||
validate_block_limit_constraint(update_data, block)
|
||||
|
||||
@@ -151,7 +250,22 @@ class BlockManager:
|
||||
setattr(block, key, value)
|
||||
|
||||
await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
|
||||
if new_tags is not None:
|
||||
await self._replace_block_pivot_rows_async(
|
||||
session,
|
||||
BlocksTags.__table__,
|
||||
block_id,
|
||||
[{"block_id": block_id, "tag": tag, "organization_id": block.organization_id} for tag in new_tags],
|
||||
)
|
||||
|
||||
pydantic_block = block.to_pydantic()
|
||||
if new_tags is not None:
|
||||
pydantic_block.tags = new_tags
|
||||
else:
|
||||
result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id))
|
||||
pydantic_block.tags = [row[0] for row in result.fetchall()]
|
||||
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_block
|
||||
@@ -164,6 +278,8 @@ class BlockManager:
|
||||
async with db_registry.async_session() as session:
|
||||
# First, delete all references in blocks_agents table
|
||||
await session.execute(delete(BlocksAgents).where(BlocksAgents.block_id == block_id))
|
||||
# Also delete all tags associated with this block
|
||||
await session.execute(delete(BlocksTags).where(BlocksTags.block_id == block_id))
|
||||
await session.flush()
|
||||
|
||||
# Then delete the block itself
|
||||
@@ -192,19 +308,18 @@ class BlockManager:
|
||||
connected_to_agents_count_eq: Optional[List[int]] = None,
|
||||
ascending: bool = True,
|
||||
show_hidden_blocks: Optional[bool] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
) -> List[PydanticBlock]:
|
||||
"""Async version of get_blocks method. Retrieve blocks based on various optional filters."""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Start with a basic query
|
||||
query = select(BlockModel)
|
||||
|
||||
# Explicitly avoid loading relationships
|
||||
query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
|
||||
query = query.options(
|
||||
noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
|
||||
)
|
||||
|
||||
# Apply access control
|
||||
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
@@ -295,6 +410,20 @@ class BlockManager:
|
||||
query = query.join(BlockModel.identities).filter(BlockModel.identities.property.mapper.class_.id == identity_id)
|
||||
needs_distinct = True
|
||||
|
||||
if tags:
|
||||
if match_all_tags:
|
||||
# Must match ALL tags - use subquery with having count
|
||||
tag_subquery = (
|
||||
select(BlocksTags.block_id)
|
||||
.where(BlocksTags.tag.in_(tags))
|
||||
.group_by(BlocksTags.block_id)
|
||||
.having(func.count(BlocksTags.tag) == literal(len(tags)))
|
||||
)
|
||||
query = query.where(BlockModel.id.in_(tag_subquery))
|
||||
else:
|
||||
# Must match ANY tag
|
||||
query = query.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags))))
|
||||
|
||||
if after:
|
||||
result = (await session.execute(select(BlockModel.created_at, BlockModel.id).where(BlockModel.id == after))).first()
|
||||
if result:
|
||||
@@ -353,17 +482,39 @@ class BlockManager:
|
||||
result = await session.execute(query)
|
||||
blocks = result.scalars().all()
|
||||
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
if not blocks:
|
||||
return []
|
||||
|
||||
block_ids = [block.id for block in blocks]
|
||||
tags_result = await session.execute(select(BlocksTags.block_id, BlocksTags.tag).where(BlocksTags.block_id.in_(block_ids)))
|
||||
tags_by_block: Dict[str, List[str]] = {}
|
||||
for row in tags_result.fetchall():
|
||||
block_id, tag = row
|
||||
if block_id not in tags_by_block:
|
||||
tags_by_block[block_id] = []
|
||||
tags_by_block[block_id].append(tag)
|
||||
|
||||
pydantic_blocks = []
|
||||
for block in blocks:
|
||||
pydantic_block = block.to_pydantic()
|
||||
pydantic_block.tags = tags_by_block.get(block.id, [])
|
||||
pydantic_blocks.append(pydantic_block)
|
||||
|
||||
return pydantic_blocks
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
|
||||
"""Retrieve a block by its name."""
|
||||
"""Retrieve a block by its ID, including tags."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
|
||||
return block.to_pydantic()
|
||||
pydantic_block = block.to_pydantic()
|
||||
tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id))
|
||||
pydantic_block.tags = [row[0] for row in tags_result.fetchall()]
|
||||
|
||||
return pydantic_block
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@@ -371,11 +522,6 @@ class BlockManager:
|
||||
@trace_method
|
||||
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
|
||||
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
|
||||
if not block_ids:
|
||||
return []
|
||||
|
||||
@@ -387,7 +533,9 @@ class BlockManager:
|
||||
query = query.where(BlockModel.id.in_(block_ids))
|
||||
|
||||
# Explicitly avoid loading relationships
|
||||
query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
|
||||
query = query.options(
|
||||
noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
|
||||
)
|
||||
|
||||
# Apply access control if actor is provided
|
||||
if actor:
|
||||
@@ -522,6 +670,88 @@ class BlockManager:
|
||||
async with db_registry.async_session() as session:
|
||||
return await BlockModel.size_async(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def count_blocks_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
label: Optional[str] = None,
|
||||
is_template: Optional[bool] = None,
|
||||
template_name: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Count blocks with optional filtering. Supports same filters as get_blocks_async.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(func.count(BlockModel.id))
|
||||
|
||||
# Apply access control
|
||||
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
query = query.where(BlockModel.organization_id == actor.organization_id)
|
||||
|
||||
# Apply filters
|
||||
if label:
|
||||
query = query.where(BlockModel.label == label)
|
||||
if is_template is not None:
|
||||
query = query.where(BlockModel.is_template == is_template)
|
||||
if template_name:
|
||||
query = query.where(BlockModel.template_name == template_name)
|
||||
if project_id:
|
||||
query = query.where(BlockModel.project_id == project_id)
|
||||
|
||||
# Apply tag filtering
|
||||
if tags:
|
||||
if match_all_tags:
|
||||
tag_subquery = (
|
||||
select(BlocksTags.block_id)
|
||||
.where(BlocksTags.tag.in_(tags))
|
||||
.group_by(BlocksTags.block_id)
|
||||
.having(func.count(BlocksTags.tag) == literal(len(tags)))
|
||||
)
|
||||
query = query.where(BlockModel.id.in_(tag_subquery))
|
||||
else:
|
||||
query = query.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags))))
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_tags_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
query_text: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get all unique block tags for the actor's organization.
|
||||
|
||||
Args:
|
||||
actor: User performing the action.
|
||||
query_text: Filter tags by text search.
|
||||
|
||||
Returns:
|
||||
List[str]: List of unique block tags.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
query = (
|
||||
select(BlocksTags.tag)
|
||||
.join(BlockModel, BlocksTags.block_id == BlockModel.id)
|
||||
.where(BlockModel.organization_id == actor.organization_id)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
if query_text:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
query = query.where(BlocksTags.tag.ilike(f"%{query_text}%"))
|
||||
else:
|
||||
query = query.where(func.lower(BlocksTags.tag).like(func.lower(f"%{query_text}%")))
|
||||
|
||||
result = await session.execute(query)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
|
||||
# Block History Functions
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -1766,6 +1766,7 @@ async def test_agent_state_schema_unchanged(server: SyncServer):
|
||||
"hidden",
|
||||
"created_by_id",
|
||||
"last_updated_by_id",
|
||||
"tags",
|
||||
}
|
||||
actual_block_fields = set(block_fields.keys())
|
||||
if actual_block_fields != expected_block_fields:
|
||||
|
||||
@@ -1282,3 +1282,141 @@ async def test_redo_concurrency_stale(server: SyncServer, default_user):
|
||||
# an out-of-date version of the block
|
||||
with pytest.raises(StaleDataError):
|
||||
await block_manager.redo_checkpoint_block(block_id=block.id, actor=default_user, use_preloaded_block=block_s2)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block Tags Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_tags_create_and_update(server: SyncServer, default_user):
|
||||
"""Test creating a block with tags and updating tags"""
|
||||
block_manager = BlockManager()
|
||||
|
||||
# Create a block with tags
|
||||
block = PydanticBlock(
|
||||
label="test_tags",
|
||||
value="Block with tags",
|
||||
tags=["tag1", "tag2", "important"],
|
||||
)
|
||||
created_block = await block_manager.create_or_update_block_async(block, actor=default_user)
|
||||
|
||||
# Verify tags were saved
|
||||
assert set(created_block.tags) == {"tag1", "tag2", "important"}
|
||||
|
||||
# Update the block with new tags
|
||||
from letta.schemas.block import BlockUpdate
|
||||
|
||||
updated_block = await block_manager.update_block_async(
|
||||
block_id=created_block.id,
|
||||
block_update=BlockUpdate(tags=["tag1", "new_tag"]),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify tags were updated
|
||||
assert set(updated_block.tags) == {"tag1", "new_tag"}
|
||||
|
||||
# Clear all tags
|
||||
cleared_block = await block_manager.update_block_async(
|
||||
block_id=created_block.id,
|
||||
block_update=BlockUpdate(tags=[]),
|
||||
actor=default_user,
|
||||
)
|
||||
assert cleared_block.tags == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_tags_filter_any(server: SyncServer, default_user):
|
||||
"""Test filtering blocks by tags (match ANY)"""
|
||||
block_manager = BlockManager()
|
||||
|
||||
# Create blocks with different tags
|
||||
block1 = await block_manager.create_or_update_block_async(
|
||||
PydanticBlock(label="b1", value="v1", tags=["alpha", "beta"]),
|
||||
actor=default_user,
|
||||
)
|
||||
block2 = await block_manager.create_or_update_block_async(
|
||||
PydanticBlock(label="b2", value="v2", tags=["beta", "gamma"]),
|
||||
actor=default_user,
|
||||
)
|
||||
block3 = await block_manager.create_or_update_block_async(
|
||||
PydanticBlock(label="b3", value="v3", tags=["delta"]),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Filter by tag "beta" (match ANY)
|
||||
results = await block_manager.get_blocks_async(actor=default_user, tags=["beta"], match_all_tags=False)
|
||||
result_ids = {b.id for b in results}
|
||||
assert block1.id in result_ids
|
||||
assert block2.id in result_ids
|
||||
assert block3.id not in result_ids
|
||||
|
||||
# Filter by tag "alpha" or "delta" (match ANY)
|
||||
results = await block_manager.get_blocks_async(actor=default_user, tags=["alpha", "delta"], match_all_tags=False)
|
||||
result_ids = {b.id for b in results}
|
||||
assert block1.id in result_ids
|
||||
assert block2.id not in result_ids
|
||||
assert block3.id in result_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_tags_filter_all(server: SyncServer, default_user):
|
||||
"""Test filtering blocks by tags (match ALL)"""
|
||||
block_manager = BlockManager()
|
||||
|
||||
# Create blocks with different tags
|
||||
block1 = await block_manager.create_or_update_block_async(
|
||||
PydanticBlock(label="b1", value="v1", tags=["x", "y", "z"]),
|
||||
actor=default_user,
|
||||
)
|
||||
block2 = await block_manager.create_or_update_block_async(
|
||||
PydanticBlock(label="b2", value="v2", tags=["x", "y"]),
|
||||
actor=default_user,
|
||||
)
|
||||
block3 = await block_manager.create_or_update_block_async(
|
||||
PydanticBlock(label="b3", value="v3", tags=["x"]),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Filter by tags "x" AND "y" (match ALL)
|
||||
results = await block_manager.get_blocks_async(actor=default_user, tags=["x", "y"], match_all_tags=True)
|
||||
result_ids = {b.id for b in results}
|
||||
assert block1.id in result_ids
|
||||
assert block2.id in result_ids
|
||||
assert block3.id not in result_ids
|
||||
|
||||
# Filter by tags "x", "y", AND "z" (match ALL)
|
||||
results = await block_manager.get_blocks_async(actor=default_user, tags=["x", "y", "z"], match_all_tags=True)
|
||||
result_ids = {b.id for b in results}
|
||||
assert block1.id in result_ids
|
||||
assert block2.id not in result_ids
|
||||
assert block3.id not in result_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_tags_count(server: SyncServer, default_user):
|
||||
"""Test counting blocks with tag filters"""
|
||||
block_manager = BlockManager()
|
||||
|
||||
# Create blocks with different tags
|
||||
await block_manager.create_or_update_block_async(
|
||||
PydanticBlock(label="c1", value="v1", tags=["count_test", "a"]),
|
||||
actor=default_user,
|
||||
)
|
||||
await block_manager.create_or_update_block_async(
|
||||
PydanticBlock(label="c2", value="v2", tags=["count_test", "b"]),
|
||||
actor=default_user,
|
||||
)
|
||||
await block_manager.create_or_update_block_async(
|
||||
PydanticBlock(label="c3", value="v3", tags=["other"]),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Count blocks with tag "count_test"
|
||||
count = await block_manager.count_blocks_async(actor=default_user, tags=["count_test"], match_all_tags=False)
|
||||
assert count == 2
|
||||
|
||||
# Count blocks with tags "count_test" AND "a"
|
||||
count = await block_manager.count_blocks_async(actor=default_user, tags=["count_test", "a"], match_all_tags=True)
|
||||
assert count == 1
|
||||
|
||||
Reference in New Issue
Block a user