diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 8a06b1c6..b1301be9 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -15,7 +15,6 @@ from letta.orm.sqlite_functions import adapt_array if TYPE_CHECKING: from pydantic import BaseModel - from sqlalchemy.orm import Session logger = get_logger(__name__) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index d3982186..01b23d3b 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -1,6 +1,8 @@ import os from typing import List, Optional +from sqlalchemy.orm import Session + from letta.orm.block import Block as BlockModel from letta.orm.block_history import BlockHistory from letta.orm.enums import ActorType @@ -210,64 +212,114 @@ class BlockManager: return block.to_pydantic() @enforce_types - def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock: + def _move_block_to_sequence(self, session: Session, block: BlockModel, target_seq: int, actor: PydanticUser) -> BlockModel: """ - Move the block to the previous checkpoint by copying fields - from the immediately previous BlockHistory entry (sequence_number - 1). - - 1) Load the current block (either by merging a preloaded block or reading from DB). - 2) Identify its current history entry. If none, there's nothing to undo. - 3) Determine the previous checkpoint's sequence_number. If seq=1, we can't go earlier. - 4) Copy state from that previous checkpoint into the block. - 5) Commit transaction (optimistic lock check). - 6) Return the updated block as Pydantic. + Internal helper that moves the 'block' to the specified 'target_seq' within BlockHistory. + 1) Find the BlockHistory row at sequence_number=target_seq + 2) Copy fields into the block + 3) Update and flush (no_commit=True) - the caller is responsible for final commit Raises: - ValueError: If no previous checkpoint exists or if we can't find the matching row. - NoResultFound: If the block or block history row do not exist. - StaleDataError: If another transaction updated the block concurrently (optimistic locking). + NoResultFound: if no BlockHistory row for (block_id, target_seq) + """ + if not block.id: + raise ValueError("Block is missing an ID. Cannot move sequence.") + + target_entry = ( + session.query(BlockHistory) + .filter( + BlockHistory.block_id == block.id, + BlockHistory.sequence_number == target_seq, + ) + .one_or_none() + ) + if not target_entry: + raise NoResultFound(f"No BlockHistory row found for block_id={block.id} at sequence={target_seq}") + + # Copy fields from target_entry to block + block.description = target_entry.description # type: ignore + block.label = target_entry.label # type: ignore + block.value = target_entry.value # type: ignore + block.limit = target_entry.limit # type: ignore + block.metadata_ = target_entry.metadata_ # type: ignore + block.current_history_entry_id = target_entry.id # type: ignore + + # Update in DB (optimistic locking). + # We'll do a flush now; the caller does final commit. + updated_block = block.update(db_session=session, actor=actor, no_commit=True) + return updated_block + + @enforce_types + def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock: + """ + Move the block to the previous checkpoint (sequence_number - 1). """ with self.session_maker() as session: - # 1) Load the block - if use_preloaded_block is not None: - block = session.merge(use_preloaded_block) - else: - block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) + # 1) Load current block + block = ( + session.merge(use_preloaded_block) + if use_preloaded_block + else BlockModel.read(db_session=session, identifier=block_id, actor=actor) + ) + # 2) Ensure there's a current checkpoint to undo from if not block.current_history_entry_id: - # There's no known history entry to revert from raise ValueError(f"Block {block_id} has no history entry - cannot undo.") - # 2) Fetch the current history entry current_entry = session.get(BlockHistory, block.current_history_entry_id) if not current_entry: raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}") current_seq = current_entry.sequence_number if current_seq <= 1: - # This means there's no previous checkpoint raise ValueError(f"Block {block_id} is at the first checkpoint (seq=1). Cannot undo further.") - # 3) The previous checkpoint is current_seq - 1 + # 3) Move to the previous sequence previous_seq = current_seq - 1 - prev_entry = ( - session.query(BlockHistory) - .filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number == previous_seq) - .one_or_none() - ) - if not prev_entry: - raise NoResultFound(f"No BlockHistory row for block_id={block.id} at sequence_number={previous_seq}") + block = self._move_block_to_sequence(session, block, previous_seq, actor) - # 4) Copy fields from the prev_entry back to the block - block.description = prev_entry.description - block.label = prev_entry.label - block.value = prev_entry.value - block.limit = prev_entry.limit - block.metadata_ = prev_entry.metadata_ - block.current_history_entry_id = prev_entry.id - - # 5) Commit with optimistic locking. We do a single commit at the end. - block = block.update(db_session=session, actor=actor, no_commit=True) + # 4) Commit once at the end session.commit() + return block.to_pydantic() # type: ignore - return block.to_pydantic() + @enforce_types + def redo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock: + """ + Move the block to the next checkpoint (sequence_number + 1). + + Raises: + ValueError: if the block is not pointing to a known checkpoint, + or if there's no higher sequence_number to move to. + NoResultFound: if the relevant BlockHistory row doesn't exist. + StaleDataError: on concurrency conflicts. + """ + with self.session_maker() as session: + # 1) Load current block + block = ( + session.merge(use_preloaded_block) + if use_preloaded_block + else BlockModel.read(db_session=session, identifier=block_id, actor=actor) + ) + + # 2) If no current_history_entry_id, can't redo + if not block.current_history_entry_id: + raise ValueError(f"Block {block_id} has no history entry - cannot redo.") + + current_entry = session.get(BlockHistory, block.current_history_entry_id) + if not current_entry: + raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}") + + current_seq = current_entry.sequence_number + + # We'll move to next_seq = current_seq + 1 + next_seq = current_seq + 1 + + # 3) Move to the next sequence using our helper + try: + block = self._move_block_to_sequence(session, block, next_seq, actor) + except NoResultFound: + raise ValueError(f"Block {block_id} is at the highest checkpoint (seq={current_seq}). Cannot redo further.") + + # 4) Commit once + session.commit() + return block.to_pydantic() # type: ignore diff --git a/tests/test_managers.py b/tests/test_managers.py index 7d6664fb..d41a7942 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2980,6 +2980,172 @@ def test_undo_concurrency_stale(server: SyncServer, default_user): block_manager.undo_checkpoint_block(block_id=block_v1.id, actor=default_user, use_preloaded_block=block_s2) # also seq=2 in memory +# ====================================================================================================================== +# Block Manager Tests - Redo +# ====================================================================================================================== + + +def test_redo_checkpoint_block(server: SyncServer, default_user): + """ + 1) Create a block with value v1 -> checkpoint => seq=1 + 2) Update to v2 -> checkpoint => seq=2 + 3) Update to v3 -> checkpoint => seq=3 + 4) Undo once (seq=3 -> seq=2) + 5) Redo once (seq=2 -> seq=3) + """ + + block_manager = BlockManager() + + # 1) Create block, set value='v1'; checkpoint => seq=1 + block_v1 = block_manager.create_or_update_block(PydanticBlock(label="redo_test", value="v1"), actor=default_user) + block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) + + # 2) Update to 'v2'; checkpoint => seq=2 + block_v2 = PydanticBlock(**block_v1.dict()) + block_v2.value = "v2" + block_manager.create_or_update_block(block_v2, actor=default_user) + block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) + + # 3) Update to 'v3'; checkpoint => seq=3 + block_v3 = PydanticBlock(**block_v1.dict()) + block_v3.value = "v3" + block_manager.create_or_update_block(block_v3, actor=default_user) + block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) + + # Undo from seq=3 -> seq=2 + undone_block = block_manager.undo_checkpoint_block(block_v1.id, actor=default_user) + assert undone_block.value == "v2", "After undo, block should revert to v2" + + # Redo from seq=2 -> seq=3 + redone_block = block_manager.redo_checkpoint_block(block_v1.id, actor=default_user) + assert redone_block.value == "v3", "After redo, block should go back to v3" + + +def test_redo_no_history(server: SyncServer, default_user): + """ + If a block has no current_history_entry_id (never checkpointed), + then redo_checkpoint_block should raise ValueError. + """ + block_manager = BlockManager() + + # Create block with no checkpoint + block = block_manager.create_or_update_block(PydanticBlock(label="redo_no_history", value="v0"), actor=default_user) + + # Attempt to redo => expect ValueError + with pytest.raises(ValueError, match="no history entry - cannot redo"): + block_manager.redo_checkpoint_block(block.id, actor=default_user) + + +def test_redo_at_highest_checkpoint(server: SyncServer, default_user): + """ + If the block is at the maximum sequence number, there's no higher checkpoint to move to. + redo_checkpoint_block should raise ValueError. + """ + block_manager = BlockManager() + + # 1) Create block => checkpoint => seq=1 + b_init = block_manager.create_or_update_block(PydanticBlock(label="redo_highest", value="v1"), actor=default_user) + block_manager.checkpoint_block(b_init.id, actor=default_user) + + # 2) Another edit => seq=2 + b_next = PydanticBlock(**b_init.dict()) + b_next.value = "v2" + block_manager.create_or_update_block(b_next, actor=default_user) + block_manager.checkpoint_block(b_init.id, actor=default_user) + + # We are at seq=2, which is the highest checkpoint. + # Attempt redo => there's no seq=3 + with pytest.raises(ValueError, match="Cannot redo further"): + block_manager.redo_checkpoint_block(b_init.id, actor=default_user) + + +def test_redo_after_multiple_undo(server: SyncServer, default_user): + """ + 1) Create and checkpoint versions: v1 -> seq=1, v2 -> seq=2, v3 -> seq=3, v4 -> seq=4 + 2) Undo thrice => from seq=4 to seq=1 + 3) Redo thrice => from seq=1 back to seq=4 + """ + block_manager = BlockManager() + + # Step 1: create initial block => seq=1 + b_init = block_manager.create_or_update_block(PydanticBlock(label="redo_multi", value="v1"), actor=default_user) + block_manager.checkpoint_block(b_init.id, actor=default_user) + + # seq=2 + b_v2 = PydanticBlock(**b_init.dict()) + b_v2.value = "v2" + block_manager.create_or_update_block(b_v2, actor=default_user) + block_manager.checkpoint_block(b_init.id, actor=default_user) + + # seq=3 + b_v3 = PydanticBlock(**b_init.dict()) + b_v3.value = "v3" + block_manager.create_or_update_block(b_v3, actor=default_user) + block_manager.checkpoint_block(b_init.id, actor=default_user) + + # seq=4 + b_v4 = PydanticBlock(**b_init.dict()) + b_v4.value = "v4" + block_manager.create_or_update_block(b_v4, actor=default_user) + block_manager.checkpoint_block(b_init.id, actor=default_user) + + # We have 4 checkpoints: v1...v4. Current is seq=4. + + # 2) Undo thrice => from seq=4 -> seq=1 + for expected_value in ["v3", "v2", "v1"]: + undone_block = block_manager.undo_checkpoint_block(b_init.id, actor=default_user) + assert undone_block.value == expected_value, f"Undo should get us back to {expected_value}" + + # 3) Redo thrice => from seq=1 -> seq=4 + for expected_value in ["v2", "v3", "v4"]: + redone_block = block_manager.redo_checkpoint_block(b_init.id, actor=default_user) + assert redone_block.value == expected_value, f"Redo should get us forward to {expected_value}" + + +def test_redo_concurrency_stale(server: SyncServer, default_user): + block_manager = BlockManager() + + # 1) Create block => checkpoint => seq=1 + block = block_manager.create_or_update_block(PydanticBlock(label="redo_concurrency", value="v1"), actor=default_user) + block_manager.checkpoint_block(block.id, actor=default_user) + + # 2) Another edit => checkpoint => seq=2 + block_v2 = PydanticBlock(**block.dict()) + block_v2.value = "v2" + block_manager.create_or_update_block(block_v2, actor=default_user) + block_manager.checkpoint_block(block.id, actor=default_user) + + # 3) Another edit => checkpoint => seq=3 + block_v3 = PydanticBlock(**block.dict()) + block_v3.value = "v3" + block_manager.create_or_update_block(block_v3, actor=default_user) + block_manager.checkpoint_block(block.id, actor=default_user) + # Now the block is at seq=3 in the DB + + # 4) Undo from seq=3 -> seq=2 so that we have a known future state at seq=3 + undone_block = block_manager.undo_checkpoint_block(block.id, actor=default_user) + assert undone_block.value == "v2" + + # At this point the block is physically at seq=2 in DB, + # but there's a valid row for seq=3 in block_history (the 'v3' state). + + # 5) Simulate concurrency: two sessions each read the block at seq=2 + with db_context() as s1: + block_s1 = s1.get(Block, block.id) + with db_context() as s2: + block_s2 = s2.get(Block, block.id) + + # 6) Session1 redoes to seq=3 first -> success + block_manager.redo_checkpoint_block(block_id=block.id, actor=default_user, use_preloaded_block=block_s1) + # commits => block is now seq=3 in DB, version increments + + # 7) Session2 tries to do the same from stale version + # => we expect StaleDataError, because the second session is using + # an out-of-date version of the block + with pytest.raises(StaleDataError): + block_manager.redo_checkpoint_block(block_id=block.id, actor=default_user, use_preloaded_block=block_s2) + + # ====================================================================================================================== # Identity Manager Tests # ======================================================================================================================