import asyncio  # NOTE(review): asyncio appears unused in this module — confirm before removing
from datetime import datetime
from typing import Dict, List, Optional

# Third-party: SQLAlchemy core + async primitives used to build every query below.
import sqlalchemy as sa
from sqlalchemy import and_, delete, exists, func, literal, or_, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from sqlalchemy.sql.expression import tuple_

# Project-local: ORM models, Pydantic schemas, and infrastructure helpers.
from letta.errors import LettaInvalidArgumentError
from letta.log import get_logger
from letta.orm.agent import Agent as AgentModel
from letta.orm.block import Block as BlockModel
from letta.orm.block_history import BlockHistory
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.blocks_tags import BlocksTags
from letta.orm.errors import NoResultFound
from letta.orm.sqlalchemy_base import AccessType
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.block import Block as PydanticBlock, BlockUpdate
from letta.schemas.enums import ActorType, PrimitiveType
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import DatabaseChoice, settings
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
from letta.validators import raise_on_invalid_id

logger = get_logger(__name__)


def validate_block_limit_constraint(update_data: dict, existing_block: BlockModel) -> None:
    """
    Validates that block limit constraints are satisfied when updating a block.
Rules: - If limit is being updated, it must be >= the length of the value (existing or new) - If value is being updated, its length must not exceed the limit (existing or new) Args: update_data: Dictionary of fields to update existing_block: The current block being updated Raises: LettaInvalidArgumentError: If validation fails """ # If limit is being updated, ensure it's >= current value length if "limit" in update_data: # Get the value that will be used (either from update_data or existing) value_to_check = update_data.get("value", existing_block.value) limit_to_check = update_data["limit"] if value_to_check and limit_to_check < len(value_to_check): raise LettaInvalidArgumentError( f"Limit ({limit_to_check}) cannot be less than current value length ({len(value_to_check)} characters)", argument_name="limit", ) # If value is being updated and there's an existing limit, ensure value doesn't exceed limit elif "value" in update_data and existing_block.limit: if len(update_data["value"]) > existing_block.limit: raise LettaInvalidArgumentError( f"Value length ({len(update_data['value'])} characters) exceeds block limit ({existing_block.limit} characters)", argument_name="value", ) def validate_block_creation(block_data: dict) -> None: """ Validates that block limit constraints are satisfied when creating a block. Rules: - If both value and limit are provided, limit must be >= value length Args: block_data: Dictionary of block fields for creation Raises: LettaInvalidArgumentError: If validation fails """ value = block_data.get("value") limit = block_data.get("limit") if value and limit and len(value) > limit: raise LettaInvalidArgumentError( f"Block limit ({limit}) must be greater than or equal to value length ({len(value)} characters)", argument_name="limit" ) def _cursor_filter(sort_col, id_col, ref_sort_val, ref_id, forward: bool): """ Returns a SQLAlchemy filter expression for cursor-based pagination. If `forward` is True, returns records after the reference. 
If `forward` is False, returns records before the reference. """ if forward: return or_( sort_col > ref_sort_val, and_(sort_col == ref_sort_val, id_col > ref_id), ) else: return or_( sort_col < ref_sort_val, and_(sort_col == ref_sort_val, id_col < ref_id), ) class BlockManager: """Manager class to handle business logic related to Blocks.""" # ====================================================================================================================== # Helper methods for pivot tables # ====================================================================================================================== @staticmethod async def _bulk_insert_block_pivot_async(session, table, rows: list[dict]): """Bulk insert rows into a pivot table, ignoring conflicts.""" if not rows: return dialect = session.bind.dialect.name if dialect == "postgresql": stmt = pg_insert(table).values(rows).on_conflict_do_nothing() elif dialect == "sqlite": stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE") else: # fallback: filter out exact-duplicate dicts in Python seen = set() filtered = [] for row in rows: key = tuple(sorted(row.items())) if key not in seen: seen.add(key) filtered.append(row) stmt = sa.insert(table).values(filtered) await session.execute(stmt) @staticmethod async def _replace_block_pivot_rows_async(session, table, block_id: str, rows: list[dict]): """ Replace all pivot rows for a block atomically using MERGE pattern. Only supports PostgreSQL (blocks_tags table not supported on SQLite). 
""" dialect = session.bind.dialect.name if dialect == "postgresql": if rows: # separate upsert and delete operations stmt = pg_insert(table).values(rows) stmt = stmt.on_conflict_do_nothing() await session.execute(stmt) # delete rows not in new set pk_names = [c.name for c in table.primary_key.columns] new_keys = [tuple(r[c] for c in pk_names) for r in rows] await session.execute( delete(table).where(table.c.block_id == block_id, ~tuple_(*[table.c[c] for c in pk_names]).in_(new_keys)) ) else: # if no rows to insert, just delete all await session.execute(delete(table).where(table.c.block_id == block_id)) else: # fallback: use original DELETE + INSERT pattern await session.execute(delete(table).where(table.c.block_id == block_id)) if rows: await BlockManager._bulk_insert_block_pivot_async(session, table, rows) # ====================================================================================================================== # Basic CRUD operations # ====================================================================================================================== @enforce_types @trace_method async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: """Create a new block based on the Block schema.""" db_block = await self.get_block_by_id_async(block.id, actor) if db_block: update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True)) return await self.update_block_async(block.id, update_data, actor) else: async with db_registry.async_session() as session: data = block.model_dump(to_orm=True, exclude_none=True) # Extract tags before creating the ORM model (tags is not a column) tags = data.pop("tags", None) or [] # Validate block creation constraints validate_block_creation(data) block_model = BlockModel(**data, organization_id=actor.organization_id) await block_model.create_async(session, actor=actor, no_commit=True, no_refresh=True) if tags: await self._bulk_insert_block_pivot_async( session, 
BlocksTags.__table__, [{"block_id": block_model.id, "tag": tag, "organization_id": actor.organization_id} for tag in tags], ) pydantic_block = block_model.to_pydantic() pydantic_block.tags = tags # context manager now handles commits # await session.commit() return pydantic_block @enforce_types @trace_method async def batch_create_blocks_async(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]: """ Batch-create multiple Blocks in one transaction for better performance. Args: blocks: List of PydanticBlock schemas to create actor: The user performing the operation Returns: List of created PydanticBlock instances (with IDs, timestamps, etc.) """ if not blocks: return [] async with db_registry.async_session() as session: validated_data = [] tags_by_index: Dict[int, List[str]] = {} for i, block in enumerate(blocks): block_data = block.model_dump(to_orm=True, exclude_none=True) tags = block_data.pop("tags", None) or [] if tags: tags_by_index[i] = tags validate_block_creation(block_data) validated_data.append(block_data) block_models = [BlockModel(**data, organization_id=actor.organization_id) for data in validated_data] created_models = await BlockModel.batch_create_async( items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True ) all_tag_rows = [] for i, model in enumerate(created_models): if i in tags_by_index: for tag in tags_by_index[i]: all_tag_rows.append({"block_id": model.id, "tag": tag, "organization_id": actor.organization_id}) if all_tag_rows: await self._bulk_insert_block_pivot_async(session, BlocksTags.__table__, all_tag_rows) result = [] for i, model in enumerate(created_models): pydantic_block = model.to_pydantic() pydantic_block.tags = tags_by_index.get(i, []) result.append(pydantic_block) return result @enforce_types @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) @trace_method async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: 
PydanticUser) -> PydanticBlock: """Update a block by its ID with the given BlockUpdate object.""" async with db_registry.async_session() as session: block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) # Extract tags from update data (it's not a column on the block table) new_tags = update_data.pop("tags", None) # Validate limit constraints before updating validate_block_limit_constraint(update_data, block) for key, value in update_data.items(): setattr(block, key, value) await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) if new_tags is not None: await self._replace_block_pivot_rows_async( session, BlocksTags.__table__, block_id, [{"block_id": block_id, "tag": tag, "organization_id": block.organization_id} for tag in new_tags], ) pydantic_block = block.to_pydantic() if new_tags is not None: pydantic_block.tags = new_tags else: result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id)) pydantic_block.tags = [row[0] for row in result.fetchall()] # context manager now handles commits # await session.commit() return pydantic_block @enforce_types @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) @trace_method async def delete_block_async(self, block_id: str, actor: PydanticUser) -> None: """Delete a block by its ID.""" async with db_registry.async_session() as session: # First, delete all references in blocks_agents table await session.execute(delete(BlocksAgents).where(BlocksAgents.block_id == block_id)) # Also delete all tags associated with this block await session.execute(delete(BlocksTags).where(BlocksTags.block_id == block_id)) await session.flush() # Then delete the block itself block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) await block.hard_delete_async(db_session=session, actor=actor) 
    @enforce_types
    @trace_method
    async def get_blocks_async(
        self,
        actor: PydanticUser,
        label: Optional[str] = None,
        is_template: Optional[bool] = None,
        template_name: Optional[str] = None,
        identity_id: Optional[str] = None,
        identifier_keys: Optional[List[str]] = None,
        project_id: Optional[str] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        label_search: Optional[str] = None,
        description_search: Optional[str] = None,
        value_search: Optional[str] = None,
        connected_to_agents_count_gt: Optional[int] = None,
        connected_to_agents_count_lt: Optional[int] = None,
        connected_to_agents_count_eq: Optional[List[int]] = None,
        ascending: bool = True,
        show_hidden_blocks: Optional[bool] = None,
        tags: Optional[List[str]] = None,
        match_all_tags: bool = False,
    ) -> List[PydanticBlock]:
        """Async version of get_blocks method. Retrieve blocks based on various optional filters.

        Supports exact filters (label, template, project), substring searches
        (label/description/value via ILIKE), attached-agent-count filters,
        tag filters (any-of or all-of), and cursor-based pagination keyed on
        (created_at, id) via `before`/`after`.
        """
        async with db_registry.async_session() as session:
            # Start with a basic query
            query = select(BlockModel)
            # Explicitly avoid loading relationships
            query = query.options(
                noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
            )
            # Apply access control
            query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
            # Add filters
            query = query.where(BlockModel.organization_id == actor.organization_id)
            if label:
                query = query.where(BlockModel.label == label)
            if is_template is not None:
                query = query.where(BlockModel.is_template == is_template)
            if template_name:
                query = query.where(BlockModel.template_name == template_name)
            if project_id:
                query = query.where(BlockModel.project_id == project_id)
            # label_search is ignored when an exact label filter is already in effect
            if label_search and not label:
                query = query.where(BlockModel.label.ilike(f"%{label_search}%"))
            if description_search:
                query = query.where(BlockModel.description.ilike(f"%{description_search}%"))
            if value_search:
                query = query.where(BlockModel.value.ilike(f"%{value_search}%"))
            # Apply hidden filter (NULL hidden counts as "not hidden")
            if not show_hidden_blocks:
                query = query.where((BlockModel.hidden.is_(None)) | (BlockModel.hidden == False))
            needs_distinct = False
            needs_agent_count_join = any(
                condition is not None
                for condition in [connected_to_agents_count_gt, connected_to_agents_count_lt, connected_to_agents_count_eq]
            )
            # If any agent count filters are specified, create a single subquery and apply all filters
            if needs_agent_count_join:
                # Create a subquery to count agents per block
                agent_count_subquery = (
                    select(BlocksAgents.block_id, func.count(BlocksAgents.agent_id).label("agent_count"))
                    .group_by(BlocksAgents.block_id)
                    .subquery()
                )
                # Determine if we need a left join (for cases involving 0 counts)
                needs_left_join = (connected_to_agents_count_lt is not None) or (
                    connected_to_agents_count_eq is not None and 0 in connected_to_agents_count_eq
                )
                if needs_left_join:
                    # Left join to include blocks with no agents
                    query = query.outerjoin(agent_count_subquery, BlockModel.id == agent_count_subquery.c.block_id)
                    # Use coalesce to treat NULL as 0 for blocks with no agents
                    agent_count_expr = func.coalesce(agent_count_subquery.c.agent_count, 0)
                else:
                    # Inner join since we don't need blocks with no agents
                    query = query.join(agent_count_subquery, BlockModel.id == agent_count_subquery.c.block_id)
                    agent_count_expr = agent_count_subquery.c.agent_count
                # Build the combined filter conditions
                conditions = []
                if connected_to_agents_count_gt is not None:
                    conditions.append(agent_count_expr > connected_to_agents_count_gt)
                if connected_to_agents_count_lt is not None:
                    conditions.append(agent_count_expr < connected_to_agents_count_lt)
                if connected_to_agents_count_eq is not None:
                    conditions.append(agent_count_expr.in_(connected_to_agents_count_eq))
                # Apply all conditions with AND logic
                if conditions:
                    query = query.where(and_(*conditions))
                needs_distinct = True
            if identifier_keys:
                query = query.join(BlockModel.identities).filter(
                    BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys)
                )
                needs_distinct = True
            if identity_id:
                query = query.join(BlockModel.identities).filter(BlockModel.identities.property.mapper.class_.id == identity_id)
                needs_distinct = True
            if tags:
                if match_all_tags:
                    # Must match ALL tags - use subquery with having count
                    tag_subquery = (
                        select(BlocksTags.block_id)
                        .where(BlocksTags.tag.in_(tags))
                        .group_by(BlocksTags.block_id)
                        .having(func.count(BlocksTags.tag) == literal(len(tags)))
                    )
                    query = query.where(BlockModel.id.in_(tag_subquery))
                else:
                    # Must match ANY tag
                    query = query.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags))))
            if after:
                # Resolve the cursor row, then keep only records strictly after it.
                result = (await session.execute(select(BlockModel.created_at, BlockModel.id).where(BlockModel.id == after))).first()
                if result:
                    after_sort_value, after_id = result
                    # SQLite does not support as granular timestamping, so we need to round the timestamp
                    if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime):
                        after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S")
                    query = query.where(_cursor_filter(BlockModel.created_at, BlockModel.id, after_sort_value, after_id, forward=ascending))
            if before:
                result = (await session.execute(select(BlockModel.created_at, BlockModel.id).where(BlockModel.id == before))).first()
                if result:
                    before_sort_value, before_id = result
                    # SQLite does not support as granular timestamping, so we need to round the timestamp
                    if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime):
                        before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S")
                    query = query.where(
                        _cursor_filter(BlockModel.created_at, BlockModel.id, before_sort_value, before_id, forward=not ascending)
                    )
            # Apply ordering and handle distinct if needed
            # Note: PostgreSQL's DISTINCT ON requires ORDER BY to start with the DISTINCT ON column
            if needs_distinct:
                if ascending:
                    query = query.distinct(BlockModel.id).order_by(BlockModel.id.asc(), BlockModel.created_at.asc())
                else:
                    query = query.distinct(BlockModel.id).order_by(BlockModel.id.desc(), BlockModel.created_at.desc())
            else:
                if ascending:
                    query = query.order_by(BlockModel.created_at.asc(), BlockModel.id.asc())
                else:
                    query = query.order_by(BlockModel.created_at.desc(), BlockModel.id.desc())
            # Add limit
            if limit:
                query = query.limit(limit)
            # Execute the query
            result = await session.execute(query)
            blocks = result.scalars().all()
            if not blocks:
                return []
            # Hydrate tags for all returned blocks with a single pivot-table query.
            block_ids = [block.id for block in blocks]
            tags_result = await session.execute(select(BlocksTags.block_id, BlocksTags.tag).where(BlocksTags.block_id.in_(block_ids)))
            tags_by_block: Dict[str, List[str]] = {}
            for row in tags_result.fetchall():
                block_id, tag = row
                if block_id not in tags_by_block:
                    tags_by_block[block_id] = []
                tags_by_block[block_id].append(tag)
            pydantic_blocks = []
            for block in blocks:
                pydantic_block = block.to_pydantic()
                pydantic_block.tags = tags_by_block.get(block.id, [])
                pydantic_blocks.append(pydantic_block)
            return pydantic_blocks

    @enforce_types
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @trace_method
    async def get_block_by_id_async(self, block_id: str, actor: PydanticUser) -> Optional[PydanticBlock]:
        """Retrieve a block by its ID, including tags. Returns None if not found."""
        async with db_registry.async_session() as session:
            try:
                block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
                pydantic_block = block.to_pydantic()
                tags_result = await session.execute(select(BlocksTags.tag).where(BlocksTags.block_id == block_id))
                pydantic_block.tags = [row[0] for row in tags_result.fetchall()]
                return pydantic_block
            except NoResultFound:
                return None

    @enforce_types
    @trace_method
    async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: PydanticUser) -> List[PydanticBlock]:
        """Retrieve blocks by their ids without loading unnecessary relationships.
Async implementation.""" if not block_ids: return [] async with db_registry.async_session() as session: # Start with a basic query query = select(BlockModel) # Add ID filter query = query.where(BlockModel.id.in_(block_ids)) # Explicitly avoid loading relationships query = query.options( noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags) ) # Apply access control - actor is required for org-scoping query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) # TODO: Add soft delete filter if applicable # if hasattr(BlockModel, "is_deleted"): # query = query.where(BlockModel.is_deleted == False) # Execute the query result = await session.execute(query) blocks = result.scalars().all() # Convert to Pydantic models and preserve caller-provided ID order pydantic_blocks = [block.to_pydantic() for block in blocks] blocks_by_id = {b.id: b for b in pydantic_blocks} ordered_blocks = [blocks_by_id.get(block_id) for block_id in block_ids] # For backward compatibility, include None for missing blocks if len(pydantic_blocks) < len(block_ids): return ordered_blocks return ordered_blocks @enforce_types @trace_method async def get_blocks_by_agent_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticBlock]: """Retrieve all blocks attached to a specific agent.""" async with db_registry.async_session() as session: query = ( select(BlockModel) .join(BlocksAgents, BlockModel.id == BlocksAgents.block_id) .where( BlocksAgents.agent_id == agent_id, BlockModel.organization_id == actor.organization_id, ) .options( noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags), ) ) result = await session.execute(query) blocks = result.scalars().all() return [block.to_pydantic() for block in blocks] @enforce_types @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) @trace_method async def get_agents_for_block_async( self, 
        block_id: str,
        actor: PydanticUser,
        include_relationships: Optional[List[str]] = None,
        include: List[str] = [],  # NOTE(review): mutable default argument — safe only if never mutated; confirm and prefer None
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        ascending: bool = True,
    ) -> List[PydanticAgentState]:
        """
        Retrieve all agents associated with a given block with pagination support.

        Args:
            block_id: ID of the block to get agents for
            actor: User performing the operation
            include_relationships: List of relationships to include in the response
            before: Cursor for pagination (get items before this ID)
            after: Cursor for pagination (get items after this ID)
            limit: Maximum number of items to return
            ascending: Sort order (True for ascending, False for descending)

        Returns:
            List of agent states associated with the block
        """
        async with db_registry.async_session() as session:
            # Start with a basic query
            query = (
                select(AgentModel)
                .where(AgentModel.id.in_(select(BlocksAgents.agent_id).where(BlocksAgents.block_id == block_id)))
                .where(AgentModel.organization_id == actor.organization_id)
            )
            # Apply pagination using cursor-based approach
            if after:
                result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after))).first()
                if result:
                    after_sort_value, after_id = result
                    # SQLite does not support as granular timestamping, so we need to round the timestamp
                    if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime):
                        after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S")
                    query = query.where(_cursor_filter(AgentModel.created_at, AgentModel.id, after_sort_value, after_id, forward=ascending))
            if before:
                result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before))).first()
                if result:
                    before_sort_value, before_id = result
                    # SQLite does not support as granular timestamping, so we need to round the timestamp
                    if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime):
                        before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S")
                    query = query.where(
                        _cursor_filter(AgentModel.created_at, AgentModel.id, before_sort_value, before_id, forward=not ascending)
                    )
            # Apply sorting
            if ascending:
                query = query.order_by(AgentModel.created_at.asc(), AgentModel.id.asc())
            else:
                query = query.order_by(AgentModel.created_at.desc(), AgentModel.id.desc())
            # Apply limit
            if limit:
                query = query.limit(limit)
            # Execute the query
            result = await session.execute(query)
            agents_orm = result.scalars().all()
            # Convert without decrypting to release DB connection before PBKDF2
            # NOTE(review): include_relationships is accepted above but the call
            # below hard-codes include_relationships=[] — confirm this is intended.
            agents_encrypted = await bounded_gather(
                [agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm]
            )
        # Decrypt secrets outside session
        return await decrypt_agent_secrets(agents_encrypted)

    @enforce_types
    @trace_method
    async def size_async(self, actor: PydanticUser) -> int:
        """
        Get the total count of blocks for the given user.
        """
        async with db_registry.async_session() as session:
            return await BlockModel.size_async(db_session=session, actor=actor)

    @enforce_types
    @trace_method
    async def count_blocks_async(
        self,
        actor: PydanticUser,
        label: Optional[str] = None,
        is_template: Optional[bool] = None,
        template_name: Optional[str] = None,
        project_id: Optional[str] = None,
        tags: Optional[List[str]] = None,
        match_all_tags: bool = False,
    ) -> int:
        """
        Count blocks with optional filtering.

        Supports same filters as get_blocks_async.
        """
        async with db_registry.async_session() as session:
            query = select(func.count(BlockModel.id))
            # Apply access control
            query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
            query = query.where(BlockModel.organization_id == actor.organization_id)
            # Apply filters
            if label:
                query = query.where(BlockModel.label == label)
            if is_template is not None:
                query = query.where(BlockModel.is_template == is_template)
            if template_name:
                query = query.where(BlockModel.template_name == template_name)
            if project_id:
                query = query.where(BlockModel.project_id == project_id)
            # Apply tag filtering (mirrors get_blocks_async semantics)
            if tags:
                if match_all_tags:
                    tag_subquery = (
                        select(BlocksTags.block_id)
                        .where(BlocksTags.tag.in_(tags))
                        .group_by(BlocksTags.block_id)
                        .having(func.count(BlocksTags.tag) == literal(len(tags)))
                    )
                    query = query.where(BlockModel.id.in_(tag_subquery))
                else:
                    query = query.where(exists().where((BlocksTags.block_id == BlockModel.id) & (BlocksTags.tag.in_(tags))))
            result = await session.execute(query)
            return result.scalar() or 0

    @enforce_types
    @trace_method
    async def list_tags_async(
        self,
        actor: PydanticUser,
        query_text: Optional[str] = None,
    ) -> List[str]:
        """
        Get all unique block tags for the actor's organization.

        Args:
            actor: User performing the action.
            query_text: Filter tags by text search.

        Returns:
            List[str]: List of unique block tags.
        """
        async with db_registry.async_session() as session:
            query = (
                select(BlocksTags.tag)
                .join(BlockModel, BlocksTags.block_id == BlockModel.id)
                .where(BlockModel.organization_id == actor.organization_id)
                .distinct()
            )
            if query_text:
                # Postgres ILIKE is case-insensitive; emulate on other backends with lower().
                if settings.database_engine is DatabaseChoice.POSTGRES:
                    query = query.where(BlocksTags.tag.ilike(f"%{query_text}%"))
                else:
                    query = query.where(func.lower(BlocksTags.tag).like(func.lower(f"%{query_text}%")))
            result = await session.execute(query)
            return [row[0] for row in result.fetchall()]

    # Block History Functions

    @enforce_types
    async def _move_block_to_sequence(self, session: AsyncSession, block: BlockModel, target_seq: int, actor: PydanticUser) -> BlockModel:
        """
        Internal helper that moves the 'block' to the specified 'target_seq' within BlockHistory.

        1) Find the BlockHistory row at sequence_number=target_seq
        2) Copy fields into the block
        3) Update and flush (no_commit=True) - the caller is responsible for final commit

        Raises:
            NoResultFound: if no BlockHistory row for (block_id, target_seq)
        """
        if not block.id:
            raise ValueError("Block is missing an ID. Cannot move sequence.")
        stmt = select(BlockHistory).filter(
            BlockHistory.block_id == block.id,
            BlockHistory.sequence_number == target_seq,
        )
        result = await session.execute(stmt)
        target_entry = result.scalar_one_or_none()
        if not target_entry:
            raise NoResultFound(f"No BlockHistory row found for block_id={block.id} at sequence={target_seq}")
        # Copy fields from target_entry to block
        block.description = target_entry.description  # type: ignore
        block.label = target_entry.label  # type: ignore
        block.value = target_entry.value  # type: ignore
        block.limit = target_entry.limit  # type: ignore
        block.metadata_ = target_entry.metadata_  # type: ignore
        block.current_history_entry_id = target_entry.id  # type: ignore
        # Update in DB (optimistic locking).
        # We'll do a flush now; the caller does final commit.
        updated_block = await block.update_async(db_session=session, actor=actor, no_commit=True)
        return updated_block

    @enforce_types
    @trace_method
    async def bulk_update_block_values_async(
        self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False
    ) -> Optional[List[PydanticBlock]]:
        """
        Bulk-update the `value` field for multiple blocks in one transaction.

        Args:
            updates: mapping of block_id -> new value
            actor: the user performing the update (for org scoping, permissions, audit)
            return_hydrated: whether to return the pydantic Block objects that were updated

        Returns:
            the updated Block objects as Pydantic schemas

        Raises:
            NoResultFound if any block_id doesn't exist or isn't visible to this actor
            ValueError if any new value exceeds its block's limit
        """
        async with db_registry.async_session() as session:
            query = select(BlockModel).where(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id)
            result = await session.execute(query)
            blocks = result.scalars().all()
            # Missing/inaccessible ids are logged and skipped rather than raised.
            found_ids = {b.id for b in blocks}
            missing = set(updates.keys()) - found_ids
            if missing:
                logger.warning(f"Block IDs not found or inaccessible, skipping during bulk update: {missing!r}")
            for block in blocks:
                new_val = updates[block.id]
                # Over-limit values are truncated (with a warning), not rejected.
                if len(new_val) > block.limit:
                    logger.warning(f"Value length ({len(new_val)}) exceeds limit ({block.limit}) for block {block.id!r}, truncating...")
                    new_val = new_val[: block.limit]
                block.value = new_val
            # context manager now handles commits
            # await session.commit()
            if return_hydrated:
                # TODO: implement for async
                pass
            return None

    @enforce_types
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
    @trace_method
    async def checkpoint_block_async(
        self,
        block_id: str,
        actor: PydanticUser,
        agent_id: Optional[str] = None,
        use_preloaded_block: Optional[BlockModel] = None,  # For concurrency tests
    ) -> PydanticBlock:
        """
        Create a new checkpoint for the given Block by copying its current state into BlockHistory,
        using SQLAlchemy's built-in version_id_col for concurrency checks.

        - If the block was undone to an earlier checkpoint, we remove any "future" checkpoints
          beyond the current state to keep a strictly linear history.
        - A single commit at the end ensures atomicity.
        """
        async with db_registry.async_session() as session:
            # 1) Load the Block
            if use_preloaded_block is not None:
                block = await session.merge(use_preloaded_block)
            else:
                block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
            # 2) Identify the block's current checkpoint (if any)
            current_entry = None
            if block.current_history_entry_id:
                current_entry = await session.get(BlockHistory, block.current_history_entry_id)
            # The current sequence, or 0 if no checkpoints exist
            current_seq = current_entry.sequence_number if current_entry else 0
            # 3) Truncate any future checkpoints
            # If we are at seq=2, but there's a seq=3 or higher from a prior "redo chain",
            # remove those, so we maintain a strictly linear undo/redo stack.
stmt = select(BlockHistory).filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq) result = await session.execute(stmt) for entry in result.scalars(): session.delete(entry) # Flush the deletes to ensure they're executed before we create a new entry await session.flush() # 4) Determine the next sequence number next_seq = current_seq + 1 # 5) Create a new BlockHistory row reflecting the block's current state history_entry = BlockHistory( organization_id=actor.organization_id, block_id=block.id, sequence_number=next_seq, description=block.description, label=block.label, value=block.value, limit=block.limit, metadata_=block.metadata_, actor_type=ActorType.LETTA_AGENT if agent_id else ActorType.LETTA_USER, actor_id=agent_id if agent_id else actor.id, ) await history_entry.create_async(session, actor=actor, no_commit=True) # 6) Update the block’s pointer to the new checkpoint block.current_history_entry_id = history_entry.id # 7) Flush changes, then commit once block = await block.update_async(db_session=session, actor=actor, no_commit=True) # context manager now handles commits # await session.commit() return block.to_pydantic() @enforce_types async def _move_block_to_sequence(self, session: AsyncSession, block: BlockModel, target_seq: int, actor: PydanticUser) -> BlockModel: """ Internal helper that moves the 'block' to the specified 'target_seq' within BlockHistory. 1) Find the BlockHistory row at sequence_number=target_seq 2) Copy fields into the block 3) Update and flush (no_commit=True) - the caller is responsible for final commit Raises: NoResultFound: if no BlockHistory row for (block_id, target_seq) """ if not block.id: raise ValueError("Block is missing an ID. 
Cannot move sequence.") stmt = select(BlockHistory).filter( BlockHistory.block_id == block.id, BlockHistory.sequence_number == target_seq, ) result = await session.execute(stmt) target_entry = result.scalar_one_or_none() if not target_entry: raise NoResultFound(f"No BlockHistory row found for block_id={block.id} at sequence={target_seq}") # Copy fields from target_entry to block block.description = target_entry.description # type: ignore block.label = target_entry.label # type: ignore block.value = target_entry.value # type: ignore block.limit = target_entry.limit # type: ignore block.metadata_ = target_entry.metadata_ # type: ignore block.current_history_entry_id = target_entry.id # type: ignore # Update in DB (optimistic locking). # We'll do a flush now; the caller does final commit. updated_block = await block.update_async(db_session=session, actor=actor, no_commit=True) return updated_block @enforce_types @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) @trace_method async def undo_checkpoint_block( self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None ) -> PydanticBlock: """ Move the block to the immediately previous checkpoint in BlockHistory. If older sequences have been pruned, we jump to the largest sequence number that is still < current_seq. 
""" async with db_registry.async_session() as session: # 1) Load the current block block = ( await session.merge(use_preloaded_block) if use_preloaded_block else await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) ) if not block.current_history_entry_id: raise LettaInvalidArgumentError(f"Block {block_id} has no history entry - cannot undo.", argument_name="block_id") current_entry = await session.get(BlockHistory, block.current_history_entry_id) if not current_entry: raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}") current_seq = current_entry.sequence_number # 2) Find the largest sequence < current_seq stmt = ( select(BlockHistory) .filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number < current_seq) .order_by(BlockHistory.sequence_number.desc()) .limit(1) ) result = await session.execute(stmt) previous_entry = result.scalar_one_or_none() if not previous_entry: # No earlier checkpoint available raise LettaInvalidArgumentError( f"Block {block_id} is already at the earliest checkpoint (seq={current_seq}). Cannot undo further.", argument_name="block_id", ) # 3) Move to that sequence block = await self._move_block_to_sequence(session, block, previous_entry.sequence_number, actor) # 4) Commit # context manager now handles commits # await session.commit() return block.to_pydantic() @enforce_types @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) @trace_method async def redo_checkpoint_block( self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None ) -> PydanticBlock: """ Move the block to the next checkpoint if it exists. If some middle checkpoints have been pruned, we jump to the smallest sequence > current_seq that remains. 
""" async with db_registry.async_session() as session: block = ( await session.merge(use_preloaded_block) if use_preloaded_block else await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) ) if not block.current_history_entry_id: raise LettaInvalidArgumentError(f"Block {block_id} has no history entry - cannot redo.", argument_name="block_id") current_entry = await session.get(BlockHistory, block.current_history_entry_id) if not current_entry: raise LettaInvalidArgumentError( f"BlockHistory row not found for id={block.current_history_entry_id}", argument_name="block_id" ) current_seq = current_entry.sequence_number # Find the smallest sequence that is > current_seq stmt = ( select(BlockHistory) .filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq) .order_by(BlockHistory.sequence_number.asc()) .limit(1) ) result = await session.execute(stmt) next_entry = result.scalar_one_or_none() if not next_entry: raise LettaInvalidArgumentError( f"Block {block_id} is at the highest checkpoint (seq={current_seq}). Cannot redo further.", argument_name="block_id" ) block = await self._move_block_to_sequence(session, block, next_entry.sequence_number, actor) # context manager now handles commits # await session.commit() return block.to_pydantic()