Files
letta-server/letta/services/block_manager.py
Kian Jones 0099a95a43 fix(sec): first pass of ensuring actor id is required everywhere (#9126)
first pass of ensuring actor id is required
2026-01-29 12:44:04 -08:00

1036 lines
45 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
from datetime import datetime
from typing import Dict, List, Optional
import sqlalchemy as sa
from sqlalchemy import and_, delete, exists, func, literal, or_, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from sqlalchemy.sql.expression import tuple_
from letta.errors import LettaInvalidArgumentError
from letta.log import get_logger
from letta.orm.agent import Agent as AgentModel
from letta.orm.block import Block as BlockModel
from letta.orm.block_history import BlockHistory
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.blocks_tags import BlocksTags
from letta.orm.errors import NoResultFound
from letta.orm.sqlalchemy_base import AccessType
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.block import Block as PydanticBlock, BlockUpdate
from letta.schemas.enums import ActorType, PrimitiveType
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import DatabaseChoice, settings
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
def validate_block_limit_constraint(update_data: dict, existing_block: BlockModel) -> None:
    """
    Validates that block limit constraints are satisfied when updating a block.

    Rules:
    - If limit is being updated, it must be >= the length of the value (existing or new)
    - If value is being updated, its length must not exceed the limit (existing or new)

    Args:
        update_data: Dictionary of fields to update
        existing_block: The current block being updated

    Raises:
        LettaInvalidArgumentError: If validation fails
    """
    if "limit" in update_data:
        # The limit is changing: check it against whichever value will end up stored.
        effective_value = update_data.get("value", existing_block.value)
        new_limit = update_data["limit"]
        if effective_value and len(effective_value) > new_limit:
            raise LettaInvalidArgumentError(
                f"Limit ({new_limit}) cannot be less than current value length ({len(effective_value)} characters)",
                argument_name="limit",
            )
    elif "value" in update_data and existing_block.limit:
        # Only the value is changing: it must still fit under the existing limit.
        new_value = update_data["value"]
        if len(new_value) > existing_block.limit:
            raise LettaInvalidArgumentError(
                f"Value length ({len(new_value)} characters) exceeds block limit ({existing_block.limit} characters)",
                argument_name="value",
            )
def validate_block_creation(block_data: dict) -> None:
    """
    Validates that block limit constraints are satisfied when creating a block.

    Rules:
    - If both value and limit are provided, limit must be >= value length

    Args:
        block_data: Dictionary of block fields for creation

    Raises:
        LettaInvalidArgumentError: If validation fails
    """
    value, limit = block_data.get("value"), block_data.get("limit")
    # Nothing to validate unless both a (truthy) value and a limit were supplied.
    if not (value and limit):
        return
    if len(value) > limit:
        raise LettaInvalidArgumentError(
            f"Block limit ({limit}) must be greater than or equal to value length ({len(value)} characters)", argument_name="limit"
        )
def _cursor_filter(sort_col, id_col, ref_sort_val, ref_id, forward: bool):
    """
    Returns a SQLAlchemy filter expression for cursor-based pagination.
    If `forward` is True, returns records after the reference.
    If `forward` is False, returns records before the reference.
    """
    if forward:
        strictly_past = sort_col > ref_sort_val
        tie_break = id_col > ref_id
    else:
        strictly_past = sort_col < ref_sort_val
        tie_break = id_col < ref_id
    # Rows tied on the sort column fall back to the id column for a total order.
    return or_(strictly_past, and_(sort_col == ref_sort_val, tie_break))
class BlockManager:
"""Manager class to handle business logic related to Blocks."""
# ======================================================================================================================
# Helper methods for pivot tables
# ======================================================================================================================
@staticmethod
async def _bulk_insert_block_pivot_async(session, table, rows: list[dict]):
    """Bulk insert rows into a pivot table, ignoring conflicts."""
    if not rows:
        return
    dialect = session.bind.dialect.name
    if dialect == "postgresql":
        stmt = pg_insert(table).values(rows).on_conflict_do_nothing()
    elif dialect == "sqlite":
        stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE")
    else:
        # No native conflict-ignore support: drop exact-duplicate row dicts in Python first.
        unique_rows, seen_keys = [], set()
        for row in rows:
            row_key = tuple(sorted(row.items()))
            if row_key in seen_keys:
                continue
            seen_keys.add(row_key)
            unique_rows.append(row)
        stmt = sa.insert(table).values(unique_rows)
    await session.execute(stmt)
@staticmethod
async def _replace_block_pivot_rows_async(session, table, block_id: str, rows: list[dict]):
    """
    Replace all pivot rows for a block atomically using MERGE pattern.
    Only supports PostgreSQL (blocks_tags table not supported on SQLite).
    """
    if session.bind.dialect.name != "postgresql":
        # Fallback for other dialects: plain DELETE of old rows, then re-insert.
        await session.execute(delete(table).where(table.c.block_id == block_id))
        if rows:
            await BlockManager._bulk_insert_block_pivot_async(session, table, rows)
        return

    if not rows:
        # Empty replacement set: just wipe every pivot row for this block.
        await session.execute(delete(table).where(table.c.block_id == block_id))
        return

    # Upsert the desired rows first; conflicts mean the row already exists.
    await session.execute(pg_insert(table).values(rows).on_conflict_do_nothing())
    # Then prune any of the block's rows that are not in the new set.
    pk_cols = [col.name for col in table.primary_key.columns]
    keep_keys = [tuple(row[col] for col in pk_cols) for row in rows]
    await session.execute(
        delete(table).where(table.c.block_id == block_id, ~tuple_(*[table.c[col] for col in pk_cols]).in_(keep_keys))
    )
# ======================================================================================================================
# Basic CRUD operations
# ======================================================================================================================
@enforce_types
@trace_method
async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
    """Create a new block based on the Block schema.

    If a block with ``block.id`` already exists and is visible to ``actor``,
    it is updated in place via ``update_block_async``; otherwise a new row is
    created along with its tag pivot rows.

    Args:
        block: Block schema to persist; its ``id`` drives the exists check.
        actor: User performing the operation (org scoping and auditing).

    Returns:
        The created or updated block, with ``tags`` populated.
    """
    db_block = await self.get_block_by_id_async(block.id, actor)
    if db_block:
        # Existing block: convert the full schema into a partial update payload.
        update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True))
        return await self.update_block_async(block.id, update_data, actor)
    else:
        async with db_registry.async_session() as session:
            data = block.model_dump(to_orm=True, exclude_none=True)
            # Extract tags before creating the ORM model (tags is not a column)
            tags = data.pop("tags", None) or []
            # Validate block creation constraints
            validate_block_creation(data)
            block_model = BlockModel(**data, organization_id=actor.organization_id)
            await block_model.create_async(session, actor=actor, no_commit=True, no_refresh=True)
            if tags:
                # Tags live in the blocks_tags pivot table; insert them in one statement.
                await self._bulk_insert_block_pivot_async(
                    session,
                    BlocksTags.__table__,
                    [{"block_id": block_model.id, "tag": tag, "organization_id": actor.organization_id} for tag in tags],
                )
            pydantic_block = block_model.to_pydantic()
            pydantic_block.tags = tags
            # context manager now handles commits
            # await session.commit()
            return pydantic_block
@enforce_types
@trace_method
async def batch_create_blocks_async(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]:
    """
    Batch-create multiple Blocks in one transaction for better performance.

    Args:
        blocks: List of PydanticBlock schemas to create
        actor: The user performing the operation

    Returns:
        List of created PydanticBlock instances (with IDs, timestamps, etc.)
    """
    if not blocks:
        return []
    async with db_registry.async_session() as session:
        # Validate each block and set aside its tags (tags is not a block column).
        payloads: List[dict] = []
        tags_for_index: Dict[int, List[str]] = {}
        for idx, pydantic_block in enumerate(blocks):
            payload = pydantic_block.model_dump(to_orm=True, exclude_none=True)
            block_tags = payload.pop("tags", None) or []
            if block_tags:
                tags_for_index[idx] = block_tags
            validate_block_creation(payload)
            payloads.append(payload)

        orm_blocks = [BlockModel(**payload, organization_id=actor.organization_id) for payload in payloads]
        persisted = await BlockModel.batch_create_async(
            items=orm_blocks, db_session=session, actor=actor, no_commit=True, no_refresh=True
        )

        # One bulk insert covering every (block, tag) pair across the whole batch.
        tag_rows = [
            {"block_id": model.id, "tag": tag, "organization_id": actor.organization_id}
            for idx, model in enumerate(persisted)
            for tag in tags_for_index.get(idx, [])
        ]
        if tag_rows:
            await self._bulk_insert_block_pivot_async(session, BlocksTags.__table__, tag_rows)

        # Hydrate the response schemas, re-attaching each block's tags by position.
        hydrated: List[PydanticBlock] = []
        for idx, model in enumerate(persisted):
            schema = model.to_pydantic()
            schema.tags = tags_for_index.get(idx, [])
            hydrated.append(schema)
        return hydrated
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
    """Update a block by its ID with the given BlockUpdate object.

    Only fields explicitly set on ``block_update`` are applied. When tags are
    provided (even an empty list) they fully replace the block's tag set;
    omitting tags leaves the existing tags untouched.

    Args:
        block_id: ID of the block to update.
        block_update: Partial update payload (unset/None fields are ignored).
        actor: User performing the operation (access control and auditing).

    Returns:
        The updated block with its current tags populated.

    Raises:
        LettaInvalidArgumentError: If the new value/limit violate the limit constraint.
    """
    async with db_registry.async_session() as session:
        block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
        update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        # Extract tags from update data (it's not a column on the block table)
        new_tags = update_data.pop("tags", None)
        # Validate limit constraints before updating
        validate_block_limit_constraint(update_data, block)
        for key, value in update_data.items():
            setattr(block, key, value)
        await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
        if new_tags is not None:
            # Full replacement: the provided list becomes the block's complete tag set.
            await self._replace_block_pivot_rows_async(
                session,
                BlocksTags.__table__,
                block_id,
                [{"block_id": block_id, "tag": tag, "organization_id": block.organization_id} for tag in new_tags],
            )
        pydantic_block = block.to_pydantic()
        if new_tags is not None:
            pydantic_block.tags = new_tags
        else:
            # Tags were not part of the update: read them back so the response is hydrated.
            result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id))
            pydantic_block.tags = [row[0] for row in result.fetchall()]
        # context manager now handles commits
        # await session.commit()
        return pydantic_block
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def delete_block_async(self, block_id: str, actor: PydanticUser) -> None:
    """Delete a block by its ID, including its agent and tag pivot rows."""
    async with db_registry.async_session() as session:
        # Clear pivot-table references first so no dangling rows survive the delete.
        for pivot in (BlocksAgents, BlocksTags):
            await session.execute(delete(pivot).where(pivot.block_id == block_id))
        await session.flush()
        # Then remove the block row itself; read_async enforces actor access.
        block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
        await block.hard_delete_async(db_session=session, actor=actor)
@enforce_types
@trace_method
async def get_blocks_async(
    self,
    actor: PydanticUser,
    label: Optional[str] = None,
    is_template: Optional[bool] = None,
    template_name: Optional[str] = None,
    identity_id: Optional[str] = None,
    identifier_keys: Optional[List[str]] = None,
    project_id: Optional[str] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
    label_search: Optional[str] = None,
    description_search: Optional[str] = None,
    value_search: Optional[str] = None,
    connected_to_agents_count_gt: Optional[int] = None,
    connected_to_agents_count_lt: Optional[int] = None,
    connected_to_agents_count_eq: Optional[List[int]] = None,
    ascending: bool = True,
    show_hidden_blocks: Optional[bool] = None,
    tags: Optional[List[str]] = None,
    match_all_tags: bool = False,
) -> List[PydanticBlock]:
    """Async version of get_blocks method. Retrieve blocks based on various optional filters.

    Args:
        actor: User performing the operation; results are org-scoped to this user.
        label: Exact label match; when set, ``label_search`` is ignored.
        is_template / template_name / project_id: Exact-match column filters.
        identity_id / identifier_keys: Restrict to blocks joined to the given identities.
        before / after: Block-id cursors for keyset pagination on (created_at, id).
        limit: Maximum number of rows returned.
        label_search / description_search / value_search: Case-insensitive substring filters.
        connected_to_agents_count_gt/lt/eq: Filter by number of agents attached to each block.
        ascending: Sort direction.
        show_hidden_blocks: When falsy, blocks flagged ``hidden`` are excluded.
        tags: Tag filter; ``match_all_tags`` chooses AND vs OR semantics.

    Returns:
        List of matching blocks, each with its tags populated.
    """
    async with db_registry.async_session() as session:
        # Start with a basic query
        query = select(BlockModel)
        # Explicitly avoid loading relationships
        query = query.options(
            noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
        )
        # Apply access control
        query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
        # Add filters
        query = query.where(BlockModel.organization_id == actor.organization_id)
        if label:
            query = query.where(BlockModel.label == label)
        if is_template is not None:
            query = query.where(BlockModel.is_template == is_template)
        if template_name:
            query = query.where(BlockModel.template_name == template_name)
        if project_id:
            query = query.where(BlockModel.project_id == project_id)
        if label_search and not label:
            query = query.where(BlockModel.label.ilike(f"%{label_search}%"))
        if description_search:
            query = query.where(BlockModel.description.ilike(f"%{description_search}%"))
        if value_search:
            query = query.where(BlockModel.value.ilike(f"%{value_search}%"))
        # Apply hidden filter
        if not show_hidden_blocks:
            query = query.where((BlockModel.hidden.is_(None)) | (BlockModel.hidden == False))
        # Joins below can duplicate block rows; track whether we must deduplicate.
        needs_distinct = False
        needs_agent_count_join = any(
            condition is not None
            for condition in [connected_to_agents_count_gt, connected_to_agents_count_lt, connected_to_agents_count_eq]
        )
        # If any agent count filters are specified, create a single subquery and apply all filters
        if needs_agent_count_join:
            # Create a subquery to count agents per block
            agent_count_subquery = (
                select(BlocksAgents.block_id, func.count(BlocksAgents.agent_id).label("agent_count"))
                .group_by(BlocksAgents.block_id)
                .subquery()
            )
            # Determine if we need a left join (for cases involving 0 counts)
            needs_left_join = (connected_to_agents_count_lt is not None) or (
                connected_to_agents_count_eq is not None and 0 in connected_to_agents_count_eq
            )
            if needs_left_join:
                # Left join to include blocks with no agents
                query = query.outerjoin(agent_count_subquery, BlockModel.id == agent_count_subquery.c.block_id)
                # Use coalesce to treat NULL as 0 for blocks with no agents
                agent_count_expr = func.coalesce(agent_count_subquery.c.agent_count, 0)
            else:
                # Inner join since we don't need blocks with no agents
                query = query.join(agent_count_subquery, BlockModel.id == agent_count_subquery.c.block_id)
                agent_count_expr = agent_count_subquery.c.agent_count
            # Build the combined filter conditions
            conditions = []
            if connected_to_agents_count_gt is not None:
                conditions.append(agent_count_expr > connected_to_agents_count_gt)
            if connected_to_agents_count_lt is not None:
                conditions.append(agent_count_expr < connected_to_agents_count_lt)
            if connected_to_agents_count_eq is not None:
                conditions.append(agent_count_expr.in_(connected_to_agents_count_eq))
            # Apply all conditions with AND logic
            if conditions:
                query = query.where(and_(*conditions))
            needs_distinct = True
        if identifier_keys:
            query = query.join(BlockModel.identities).filter(
                BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys)
            )
            needs_distinct = True
        if identity_id:
            query = query.join(BlockModel.identities).filter(BlockModel.identities.property.mapper.class_.id == identity_id)
            needs_distinct = True
        if tags:
            if match_all_tags:
                # Must match ALL tags - use subquery with having count
                tag_subquery = (
                    select(BlocksTags.block_id)
                    .where(BlocksTags.tag.in_(tags))
                    .group_by(BlocksTags.block_id)
                    .having(func.count(BlocksTags.tag) == literal(len(tags)))
                )
                query = query.where(BlockModel.id.in_(tag_subquery))
            else:
                # Must match ANY tag
                query = query.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags))))
        if after:
            # Resolve the cursor block's (created_at, id) pair for keyset pagination.
            result = (await session.execute(select(BlockModel.created_at, BlockModel.id).where(BlockModel.id == after))).first()
            if result:
                after_sort_value, after_id = result
                # SQLite does not support as granular timestamping, so we need to round the timestamp
                if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime):
                    after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S")
                query = query.where(_cursor_filter(BlockModel.created_at, BlockModel.id, after_sort_value, after_id, forward=ascending))
        if before:
            result = (await session.execute(select(BlockModel.created_at, BlockModel.id).where(BlockModel.id == before))).first()
            if result:
                before_sort_value, before_id = result
                # SQLite does not support as granular timestamping, so we need to round the timestamp
                if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime):
                    before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S")
                query = query.where(
                    _cursor_filter(BlockModel.created_at, BlockModel.id, before_sort_value, before_id, forward=not ascending)
                )
        # Apply ordering and handle distinct if needed
        # Note: PostgreSQL's DISTINCT ON requires ORDER BY to start with the DISTINCT ON column
        # NOTE(review): the distinct branch therefore orders primarily by id, not created_at —
        # confirm this interacts correctly with the created_at-based cursors above.
        if needs_distinct:
            if ascending:
                query = query.distinct(BlockModel.id).order_by(BlockModel.id.asc(), BlockModel.created_at.asc())
            else:
                query = query.distinct(BlockModel.id).order_by(BlockModel.id.desc(), BlockModel.created_at.desc())
        else:
            if ascending:
                query = query.order_by(BlockModel.created_at.asc(), BlockModel.id.asc())
            else:
                query = query.order_by(BlockModel.created_at.desc(), BlockModel.id.desc())
        # Add limit
        if limit:
            query = query.limit(limit)
        # Execute the query
        result = await session.execute(query)
        blocks = result.scalars().all()
        if not blocks:
            return []
        # Fetch tags for all returned blocks in a single query, then group by block id.
        block_ids = [block.id for block in blocks]
        tags_result = await session.execute(select(BlocksTags.block_id, BlocksTags.tag).where(BlocksTags.block_id.in_(block_ids)))
        tags_by_block: Dict[str, List[str]] = {}
        for row in tags_result.fetchall():
            block_id, tag = row
            if block_id not in tags_by_block:
                tags_by_block[block_id] = []
            tags_by_block[block_id].append(tag)
        pydantic_blocks = []
        for block in blocks:
            pydantic_block = block.to_pydantic()
            pydantic_block.tags = tags_by_block.get(block.id, [])
            pydantic_blocks.append(pydantic_block)
        return pydantic_blocks
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def get_block_by_id_async(self, block_id: str, actor: PydanticUser) -> Optional[PydanticBlock]:
    """Retrieve a block by its ID, including tags; returns None when not found."""
    async with db_registry.async_session() as session:
        try:
            orm_block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
        except NoResultFound:
            # Missing or inaccessible blocks are reported as None rather than raised.
            return None
        schema = orm_block.to_pydantic()
        tag_rows = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id))
        schema.tags = [tag for (tag,) in tag_rows.fetchall()]
        return schema
@enforce_types
@trace_method
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: PydanticUser) -> List[PydanticBlock]:
    """Retrieve blocks by their ids without loading unnecessary relationships. Async implementation.

    Args:
        block_ids: IDs of the blocks to fetch.
        actor: User performing the operation (used for org-scoped access control).

    Returns:
        The visible blocks. For backward compatibility, when some ids are not
        found the result is aligned with ``block_ids`` and contains ``None`` at
        the missing positions; otherwise blocks are returned in query order.
    """
    if not block_ids:
        return []
    async with db_registry.async_session() as session:
        # Start with a basic query
        query = select(BlockModel)
        # Add ID filter
        query = query.where(BlockModel.id.in_(block_ids))
        # Explicitly avoid loading relationships
        query = query.options(
            noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
        )
        # Apply access control - actor is required for org-scoping
        query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
        # TODO: Add soft delete filter if applicable
        # if hasattr(BlockModel, "is_deleted"):
        #     query = query.where(BlockModel.is_deleted == False)
        # Execute the query
        result = await session.execute(query)
        blocks = result.scalars().all()
        # Convert to Pydantic models
        pydantic_blocks = [block.to_pydantic() for block in blocks]
        # For backward compatibility, add None for missing blocks.
        # (Previously this built a set that was immediately discarded and did an
        # O(n^2) linear scan per id; a dict lookup gives the same result in O(n).)
        if len(pydantic_blocks) < len(block_ids):
            blocks_by_id = {block.id: block for block in pydantic_blocks}
            return [blocks_by_id.get(block_id) for block_id in block_ids]
        return pydantic_blocks
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def get_agents_for_block_async(
    self,
    block_id: str,
    actor: PydanticUser,
    include_relationships: Optional[List[str]] = None,
    include: Optional[List[str]] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
    ascending: bool = True,
) -> List[PydanticAgentState]:
    """
    Retrieve all agents associated with a given block with pagination support.

    Args:
        block_id: ID of the block to get agents for
        actor: User performing the operation
        include_relationships: List of relationships to include in the response
        include: Extra fields to include when serializing each agent (defaults to none;
            previously declared with a mutable ``[]`` default)
        before: Cursor for pagination (get items before this ID)
        after: Cursor for pagination (get items after this ID)
        limit: Maximum number of items to return
        ascending: Sort order (True for ascending, False for descending)

    Returns:
        List of agent states associated with the block
    """
    # Avoid the shared-mutable-default pitfall: normalize None to a fresh list per call.
    include = include if include is not None else []
    async with db_registry.async_session() as session:
        # Agents linked to this block via the blocks_agents pivot, scoped to the actor's org.
        query = (
            select(AgentModel)
            .where(AgentModel.id.in_(select(BlocksAgents.agent_id).where(BlocksAgents.block_id == block_id)))
            .where(AgentModel.organization_id == actor.organization_id)
        )
        # Apply pagination using cursor-based approach
        if after:
            result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after))).first()
            if result:
                after_sort_value, after_id = result
                # SQLite does not support as granular timestamping, so we need to round the timestamp
                if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime):
                    after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S")
                query = query.where(_cursor_filter(AgentModel.created_at, AgentModel.id, after_sort_value, after_id, forward=ascending))
        if before:
            result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before))).first()
            if result:
                before_sort_value, before_id = result
                # SQLite does not support as granular timestamping, so we need to round the timestamp
                if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime):
                    before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S")
                query = query.where(
                    _cursor_filter(AgentModel.created_at, AgentModel.id, before_sort_value, before_id, forward=not ascending)
                )
        # Apply sorting
        if ascending:
            query = query.order_by(AgentModel.created_at.asc(), AgentModel.id.asc())
        else:
            query = query.order_by(AgentModel.created_at.desc(), AgentModel.id.desc())
        # Apply limit
        if limit:
            query = query.limit(limit)
        # Execute the query
        result = await session.execute(query)
        agents_orm = result.scalars().all()
        # Convert without decrypting to release DB connection before PBKDF2
        # NOTE(review): the include_relationships parameter is accepted but not
        # forwarded here — relationships are always [] in this call. Confirm intended.
        agents_encrypted = await bounded_gather(
            [agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm]
        )
        # Decrypt secrets outside session
        return await decrypt_agent_secrets(agents_encrypted)
@enforce_types
@trace_method
async def size_async(self, actor: PydanticUser) -> int:
    """Return the total number of blocks visible to the given actor."""
    async with db_registry.async_session() as db_session:
        return await BlockModel.size_async(db_session=db_session, actor=actor)
@enforce_types
@trace_method
async def count_blocks_async(
    self,
    actor: PydanticUser,
    label: Optional[str] = None,
    is_template: Optional[bool] = None,
    template_name: Optional[str] = None,
    project_id: Optional[str] = None,
    tags: Optional[List[str]] = None,
    match_all_tags: bool = False,
) -> int:
    """
    Count blocks with optional filtering. Supports same filters as get_blocks_async.
    """
    async with db_registry.async_session() as session:
        stmt = select(func.count(BlockModel.id))
        # Org scoping: access predicate plus an explicit organization filter.
        stmt = BlockModel.apply_access_predicate(stmt, actor, ["read"], AccessType.ORGANIZATION)
        stmt = stmt.where(BlockModel.organization_id == actor.organization_id)

        # Exact-match filters, applied only when provided.
        if label:
            stmt = stmt.where(BlockModel.label == label)
        if is_template is not None:
            stmt = stmt.where(BlockModel.is_template == is_template)
        if template_name:
            stmt = stmt.where(BlockModel.template_name == template_name)
        if project_id:
            stmt = stmt.where(BlockModel.project_id == project_id)

        if tags:
            if match_all_tags:
                # Blocks whose tag rows cover every requested tag.
                matching_blocks = (
                    select(BlocksTags.block_id)
                    .where(BlocksTags.tag.in_(tags))
                    .group_by(BlocksTags.block_id)
                    .having(func.count(BlocksTags.tag) == literal(len(tags)))
                )
                stmt = stmt.where(BlockModel.id.in_(matching_blocks))
            else:
                # Blocks with at least one of the requested tags.
                stmt = stmt.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags))))

        result = await session.execute(stmt)
        return result.scalar() or 0
@enforce_types
@trace_method
async def list_tags_async(
    self,
    actor: PydanticUser,
    query_text: Optional[str] = None,
) -> List[str]:
    """
    Get all unique block tags for the actor's organization.

    Args:
        actor: User performing the action.
        query_text: Filter tags by text search.

    Returns:
        List[str]: List of unique block tags.
    """
    async with db_registry.async_session() as session:
        stmt = (
            select(BlocksTags.tag)
            .join(BlockModel, BlocksTags.block_id == BlockModel.id)
            .where(BlockModel.organization_id == actor.organization_id)
            .distinct()
        )
        if query_text:
            # Postgres has a native case-insensitive ILIKE; emulate it elsewhere with lower().
            if settings.database_engine is DatabaseChoice.POSTGRES:
                stmt = stmt.where(BlocksTags.tag.ilike(f"%{query_text}%"))
            else:
                stmt = stmt.where(func.lower(BlocksTags.tag).like(func.lower(f"%{query_text}%")))
        rows = await session.execute(stmt)
        return [tag for (tag,) in rows.fetchall()]
# Block History Functions
@enforce_types
async def _move_block_to_sequence(self, session: AsyncSession, block: BlockModel, target_seq: int, actor: PydanticUser) -> BlockModel:
    """
    Internal helper that moves the 'block' to the specified 'target_seq' within BlockHistory.
    1) Find the BlockHistory row at sequence_number=target_seq
    2) Copy fields into the block
    3) Update and flush (no_commit=True) - the caller is responsible for final commit

    NOTE(review): this method is defined twice in this class with identical bodies;
    Python keeps the later definition, so this earlier copy is shadowed dead code —
    consider removing one of the two.

    Raises:
        NoResultFound: if no BlockHistory row for (block_id, target_seq)
    """
    if not block.id:
        raise ValueError("Block is missing an ID. Cannot move sequence.")
    stmt = select(BlockHistory).filter(
        BlockHistory.block_id == block.id,
        BlockHistory.sequence_number == target_seq,
    )
    result = await session.execute(stmt)
    target_entry = result.scalar_one_or_none()
    if not target_entry:
        raise NoResultFound(f"No BlockHistory row found for block_id={block.id} at sequence={target_seq}")
    # Copy fields from target_entry to block
    block.description = target_entry.description  # type: ignore
    block.label = target_entry.label  # type: ignore
    block.value = target_entry.value  # type: ignore
    block.limit = target_entry.limit  # type: ignore
    block.metadata_ = target_entry.metadata_  # type: ignore
    block.current_history_entry_id = target_entry.id  # type: ignore
    # Update in DB (optimistic locking).
    # We'll do a flush now; the caller does final commit.
    updated_block = await block.update_async(db_session=session, actor=actor, no_commit=True)
    return updated_block
@enforce_types
@trace_method
async def bulk_update_block_values_async(
    self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False
) -> Optional[List[PydanticBlock]]:
    """
    Bulk-update the `value` field for multiple blocks in one transaction.

    Args:
        updates: mapping of block_id -> new value
        actor: the user performing the update (for org scoping, permissions, audit)
        return_hydrated: whether to return the pydantic Block objects that were updated
            (currently a no-op for the async path — see TODO below; this method
            always returns None today)

    Returns:
        the updated Block objects as Pydantic schemas

    Raises:
        NoResultFound if any block_id doesn't exist or isn't visible to this actor
        ValueError if any new value exceeds its block's limit
    """
    async with db_registry.async_session() as session:
        query = select(BlockModel).where(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id)
        result = await session.execute(query)
        blocks = result.scalars().all()
        # Missing/inaccessible ids are skipped with a warning rather than failing the batch.
        found_ids = {b.id for b in blocks}
        missing = set(updates.keys()) - found_ids
        if missing:
            logger.warning(f"Block IDs not found or inaccessible, skipping during bulk update: {missing!r}")
        for block in blocks:
            new_val = updates[block.id]
            # Oversized values are truncated to the block's limit instead of raising.
            if len(new_val) > block.limit:
                logger.warning(f"Value length ({len(new_val)}) exceeds limit ({block.limit}) for block {block.id!r}, truncating...")
                new_val = new_val[: block.limit]
            block.value = new_val
        # context manager now handles commits
        # await session.commit()
        if return_hydrated:
            # TODO: implement for async
            pass
        return None
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@trace_method
async def checkpoint_block_async(
    self,
    block_id: str,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    use_preloaded_block: Optional[BlockModel] = None,  # For concurrency tests
) -> PydanticBlock:
    """
    Create a new checkpoint for the given Block by copying its
    current state into BlockHistory, using SQLAlchemy's built-in
    version_id_col for concurrency checks.
    - If the block was undone to an earlier checkpoint, we remove
      any "future" checkpoints beyond the current state to keep a
      strictly linear history.
    - A single commit at the end ensures atomicity.

    Args:
        block_id: ID of the block to checkpoint.
        actor: User performing the operation (scoping and audit attribution).
        agent_id: When set, the checkpoint is attributed to this agent instead of the user.
        use_preloaded_block: Pre-fetched ORM block, merged into the session (concurrency tests).

    Returns:
        The checkpointed block as a Pydantic schema.
    """
    async with db_registry.async_session() as session:
        # 1) Load the Block
        if use_preloaded_block is not None:
            block = await session.merge(use_preloaded_block)
        else:
            block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
        # 2) Identify the block's current checkpoint (if any)
        current_entry = None
        if block.current_history_entry_id:
            current_entry = await session.get(BlockHistory, block.current_history_entry_id)
        # The current sequence, or 0 if no checkpoints exist
        current_seq = current_entry.sequence_number if current_entry else 0
        # 3) Truncate any future checkpoints
        # If we are at seq=2, but there's a seq=3 or higher from a prior "redo chain",
        # remove those, so we maintain a strictly linear undo/redo stack.
        stmt = select(BlockHistory).filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq)
        result = await session.execute(stmt)
        for entry in result.scalars():
            # BUG FIX: AsyncSession.delete() is a coroutine and must be awaited;
            # without the await the coroutine was discarded and the stale "future"
            # checkpoints were never actually deleted.
            await session.delete(entry)
        # Flush the deletes to ensure they're executed before we create a new entry
        await session.flush()
        # 4) Determine the next sequence number
        next_seq = current_seq + 1
        # 5) Create a new BlockHistory row reflecting the block's current state
        history_entry = BlockHistory(
            organization_id=actor.organization_id,
            block_id=block.id,
            sequence_number=next_seq,
            description=block.description,
            label=block.label,
            value=block.value,
            limit=block.limit,
            metadata_=block.metadata_,
            actor_type=ActorType.LETTA_AGENT if agent_id else ActorType.LETTA_USER,
            actor_id=agent_id if agent_id else actor.id,
        )
        await history_entry.create_async(session, actor=actor, no_commit=True)
        # 6) Update the blocks pointer to the new checkpoint
        block.current_history_entry_id = history_entry.id
        # 7) Flush changes, then commit once
        block = await block.update_async(db_session=session, actor=actor, no_commit=True)
        # context manager now handles commits
        # await session.commit()
        return block.to_pydantic()
@enforce_types
async def _move_block_to_sequence(self, session: AsyncSession, block: BlockModel, target_seq: int, actor: PydanticUser) -> BlockModel:
    """
    Internal helper that moves the 'block' to the specified 'target_seq' within BlockHistory.
    1) Find the BlockHistory row at sequence_number=target_seq
    2) Copy fields into the block
    3) Update and flush (no_commit=True) - the caller is responsible for final commit

    NOTE(review): an identical definition of this method appears earlier in this class;
    Python keeps this later one, so the earlier copy is dead code — consider removing it.

    Raises:
        NoResultFound: if no BlockHistory row for (block_id, target_seq)
    """
    if not block.id:
        raise ValueError("Block is missing an ID. Cannot move sequence.")
    stmt = select(BlockHistory).filter(
        BlockHistory.block_id == block.id,
        BlockHistory.sequence_number == target_seq,
    )
    result = await session.execute(stmt)
    target_entry = result.scalar_one_or_none()
    if not target_entry:
        raise NoResultFound(f"No BlockHistory row found for block_id={block.id} at sequence={target_seq}")
    # Copy fields from target_entry to block
    block.description = target_entry.description  # type: ignore
    block.label = target_entry.label  # type: ignore
    block.value = target_entry.value  # type: ignore
    block.limit = target_entry.limit  # type: ignore
    block.metadata_ = target_entry.metadata_  # type: ignore
    block.current_history_entry_id = target_entry.id  # type: ignore
    # Update in DB (optimistic locking).
    # We'll do a flush now; the caller does final commit.
    updated_block = await block.update_async(db_session=session, actor=actor, no_commit=True)
    return updated_block
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def undo_checkpoint_block(
    self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
) -> PydanticBlock:
    """
    Step the block back to its previous checkpoint in BlockHistory.

    If older sequences have been pruned, lands on the largest remaining
    sequence number strictly below the current one.

    Raises:
        LettaInvalidArgumentError: if the block has no checkpoint history, or is
            already at the earliest available checkpoint.
        NoResultFound: if the pointed-to BlockHistory row is missing.
    """
    async with db_registry.async_session() as session:
        # Load the block, reusing the caller-supplied instance when provided.
        if use_preloaded_block:
            block = await session.merge(use_preloaded_block)
        else:
            block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)

        if not block.current_history_entry_id:
            raise LettaInvalidArgumentError(f"Block {block_id} has no history entry - cannot undo.", argument_name="block_id")

        checkpoint = await session.get(BlockHistory, block.current_history_entry_id)
        if checkpoint is None:
            raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}")
        current_seq = checkpoint.sequence_number

        # Locate the closest surviving checkpoint strictly before the current one.
        query = (
            select(BlockHistory)
            .filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number < current_seq)
            .order_by(BlockHistory.sequence_number.desc())
            .limit(1)
        )
        prior = (await session.execute(query)).scalar_one_or_none()
        if prior is None:
            # Nothing earlier to rewind to.
            raise LettaInvalidArgumentError(
                f"Block {block_id} is already at the earliest checkpoint (seq={current_seq}). Cannot undo further.",
                argument_name="block_id",
            )

        # Rewind; the session context manager commits on exit.
        block = await self._move_block_to_sequence(session, block, prior.sequence_number, actor)
        return block.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def redo_checkpoint_block(
    self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
) -> PydanticBlock:
    """
    Step the block forward to its next checkpoint in BlockHistory.

    If intermediate sequences have been pruned, lands on the smallest remaining
    sequence number strictly above the current one.

    Raises:
        LettaInvalidArgumentError: if the block has no checkpoint history, the
            pointed-to history row is missing, or there is no later checkpoint.
    """
    async with db_registry.async_session() as session:
        # Load the block, reusing the caller-supplied instance when provided.
        if use_preloaded_block:
            block = await session.merge(use_preloaded_block)
        else:
            block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)

        if not block.current_history_entry_id:
            raise LettaInvalidArgumentError(f"Block {block_id} has no history entry - cannot redo.", argument_name="block_id")

        checkpoint = await session.get(BlockHistory, block.current_history_entry_id)
        if checkpoint is None:
            raise LettaInvalidArgumentError(
                f"BlockHistory row not found for id={block.current_history_entry_id}", argument_name="block_id"
            )
        current_seq = checkpoint.sequence_number

        # Locate the closest surviving checkpoint strictly after the current one.
        query = (
            select(BlockHistory)
            .filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq)
            .order_by(BlockHistory.sequence_number.asc())
            .limit(1)
        )
        following = (await session.execute(query)).scalar_one_or_none()
        if following is None:
            # Nothing later to advance to.
            raise LettaInvalidArgumentError(
                f"Block {block_id} is at the highest checkpoint (seq={current_seq}). Cannot redo further.", argument_name="block_id"
            )

        # Advance; the session context manager commits on exit.
        block = await self._move_block_to_sequence(session, block, following.sequence_number, actor)
        return block.to_pydantic()