feat: Add ability to redo to more recent checkpoint (#1496)

This commit is contained in:
Matthew Zhou
2025-03-31 17:35:51 -07:00
committed by GitHub
parent 00f8edaf97
commit 22dcd53bd1
3 changed files with 258 additions and 41 deletions

View File

@@ -15,7 +15,6 @@ from letta.orm.sqlite_functions import adapt_array
if TYPE_CHECKING:
from pydantic import BaseModel
from sqlalchemy.orm import Session
logger = get_logger(__name__)

View File

@@ -1,6 +1,8 @@
import os
from typing import List, Optional
from sqlalchemy.orm import Session
from letta.orm.block import Block as BlockModel
from letta.orm.block_history import BlockHistory
from letta.orm.enums import ActorType
@@ -210,64 +212,114 @@ class BlockManager:
return block.to_pydantic()
@enforce_types
def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
def _move_block_to_sequence(self, session: Session, block: BlockModel, target_seq: int, actor: PydanticUser) -> BlockModel:
    """
    Internal helper: point `block` at the checkpoint stored under `target_seq`.

    Steps:
        1) Look up the BlockHistory row matching (block.id, target_seq).
        2) Copy that row's description, label, value, limit and metadata_
           onto the block and record the row's id as current_history_entry_id.
        3) Call block.update(..., no_commit=True). The final commit is left
           to the caller.

    Args:
        session: Open session used for both the lookup and the update.
        block: The block to move; it must already have an id.
        target_seq: The sequence_number of the checkpoint to move to.
        actor: The user on whose behalf the update is performed.

    Returns:
        Whatever block.update(...) returns (the updated block).

    Raises:
        ValueError: If the block has no id.
        NoResultFound: If there is no BlockHistory row for (block.id, target_seq).
    """
    if not block.id:
        raise ValueError("Block is missing an ID. Cannot move sequence.")

    # Locate the single checkpoint row for this block at the requested sequence.
    history_query = session.query(BlockHistory).filter(
        BlockHistory.block_id == block.id,
        BlockHistory.sequence_number == target_seq,
    )
    checkpoint = history_query.one_or_none()
    if not checkpoint:
        raise NoResultFound(f"No BlockHistory row found for block_id={block.id} at sequence={target_seq}")

    # Restore the checkpointed state field by field, then repoint the block
    # at the checkpoint it now mirrors.
    for field_name in ("description", "label", "value", "limit", "metadata_"):
        setattr(block, field_name, getattr(checkpoint, field_name))
    block.current_history_entry_id = checkpoint.id  # type: ignore

    # no_commit=True: persist through the shared update path but leave the
    # transaction open so the caller can commit exactly once.
    return block.update(db_session=session, actor=actor, no_commit=True)
@enforce_types
def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
    """
    Move the block to the previous checkpoint (sequence_number - 1).

    Args:
        block_id: Identifier of the block to revert.
        actor: The user performing the undo.
        use_preloaded_block: Optional already-loaded block; when given it is
            merged into this method's session instead of being read by id.

    Returns:
        The reverted block converted to its Pydantic representation.

    Raises:
        ValueError: If the block has no current history entry, or if it is
            already at the first checkpoint (seq <= 1).
        NoResultFound: If the current history row, or the row for the
            previous sequence number, does not exist.
        StaleDataError: Presumably raised by the versioned update when another
            transaction modified the block after it was loaded.
    """
    with self.session_maker() as session:
        # 1) Load the current block exactly once.
        if use_preloaded_block is not None:
            block = session.merge(use_preloaded_block)
        else:
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

        # 2) Ensure there is a current checkpoint to undo from.
        if not block.current_history_entry_id:
            raise ValueError(f"Block {block_id} has no history entry - cannot undo.")

        current_entry = session.get(BlockHistory, block.current_history_entry_id)
        if not current_entry:
            raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}")

        current_seq = current_entry.sequence_number
        if current_seq <= 1:
            # Nothing earlier than the first checkpoint.
            raise ValueError(f"Block {block_id} is at the first checkpoint (seq=1). Cannot undo further.")

        # 3) Move to the previous sequence. The helper does the history lookup,
        #    the field copy and the (uncommitted) update, so none of that is
        #    repeated here.
        block = self._move_block_to_sequence(session, block, current_seq - 1, actor)

        # 4) Commit once at the end.
        session.commit()
        return block.to_pydantic()  # type: ignore
@enforce_types
def redo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
    """
    Move the block to the next checkpoint (sequence_number + 1).

    Args:
        block_id: Identifier of the block to move forward.
        actor: The user performing the redo.
        use_preloaded_block: Optional already-loaded block; when given it is
            merged into this method's session instead of being read by id.

    Returns:
        The block, now at the next checkpoint, converted to its Pydantic
        representation.

    Raises:
        ValueError: If the block is not pointing to a known checkpoint, or if
            there is no higher sequence_number to move to.
        NoResultFound: If the block's current BlockHistory row does not exist.
        StaleDataError: Presumably raised by the versioned update when another
            transaction modified the block after it was loaded.
    """
    with self.session_maker() as session:
        # 1) Load the current block.
        if use_preloaded_block is not None:
            block = session.merge(use_preloaded_block)
        else:
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

        # 2) Without a current history entry there is nothing to redo from.
        if not block.current_history_entry_id:
            raise ValueError(f"Block {block_id} has no history entry - cannot redo.")

        current_entry = session.get(BlockHistory, block.current_history_entry_id)
        if not current_entry:
            raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}")

        current_seq = current_entry.sequence_number
        next_seq = current_seq + 1

        # 3) Move to the next sequence using the shared helper. A missing row
        #    at next_seq means the block is already at its newest checkpoint;
        #    report that as ValueError and keep the lookup failure as the cause.
        try:
            block = self._move_block_to_sequence(session, block, next_seq, actor)
        except NoResultFound as exc:
            raise ValueError(f"Block {block_id} is at the highest checkpoint (seq={current_seq}). Cannot redo further.") from exc

        # 4) Commit once.
        session.commit()
        return block.to_pydantic()  # type: ignore

View File

@@ -2980,6 +2980,172 @@ def test_undo_concurrency_stale(server: SyncServer, default_user):
block_manager.undo_checkpoint_block(block_id=block_v1.id, actor=default_user, use_preloaded_block=block_s2) # also seq=2 in memory
# ======================================================================================================================
# Block Manager Tests - Redo
# ======================================================================================================================
def test_redo_checkpoint_block(server: SyncServer, default_user):
    """
    A single undo followed by a single redo returns the block to its latest value.

    1) Create a block with value v1 and checkpoint it (expected seq=1)
    2) Update to v2 and checkpoint (expected seq=2)
    3) Update to v3 and checkpoint (expected seq=3)
    4) Undo once: v3 -> v2
    5) Redo once: v2 -> v3
    """
    manager = BlockManager()

    # Initial version 'v1' plus its checkpoint.
    created = manager.create_or_update_block(PydanticBlock(label="redo_test", value="v1"), actor=default_user)
    manager.checkpoint_block(block_id=created.id, actor=default_user)

    # Each later version is a copy of the created block with a new value,
    # saved and then checkpointed.
    for new_value in ("v2", "v3"):
        edited = PydanticBlock(**created.dict())
        edited.value = new_value
        manager.create_or_update_block(edited, actor=default_user)
        manager.checkpoint_block(block_id=created.id, actor=default_user)

    # Step back one checkpoint.
    after_undo = manager.undo_checkpoint_block(created.id, actor=default_user)
    assert after_undo.value == "v2", "After undo, block should revert to v2"

    # Step forward again.
    after_redo = manager.redo_checkpoint_block(created.id, actor=default_user)
    assert after_redo.value == "v3", "After redo, block should go back to v3"
def test_redo_no_history(server: SyncServer, default_user):
    """
    Redo on a block that was never checkpointed must fail.

    Such a block has no current_history_entry_id, so
    redo_checkpoint_block is expected to raise ValueError.
    """
    manager = BlockManager()

    # Persist the block but deliberately skip checkpoint_block.
    never_checkpointed = PydanticBlock(label="redo_no_history", value="v0")
    stored = manager.create_or_update_block(never_checkpointed, actor=default_user)

    with pytest.raises(ValueError, match="no history entry - cannot redo"):
        manager.redo_checkpoint_block(stored.id, actor=default_user)
def test_redo_at_highest_checkpoint(server: SyncServer, default_user):
    """
    Redo from the newest checkpoint must fail.

    When the block already sits at the maximum sequence number there is no
    higher checkpoint to move to, so redo_checkpoint_block should raise
    ValueError.
    """
    manager = BlockManager()

    # First version and checkpoint (expected seq=1).
    first = manager.create_or_update_block(PydanticBlock(label="redo_highest", value="v1"), actor=default_user)
    manager.checkpoint_block(first.id, actor=default_user)

    # Second version and checkpoint (expected seq=2).
    second = PydanticBlock(**first.dict())
    second.value = "v2"
    manager.create_or_update_block(second, actor=default_user)
    manager.checkpoint_block(first.id, actor=default_user)

    # No undo has happened, so the block is at its latest checkpoint and
    # there is nothing after it to redo to.
    with pytest.raises(ValueError, match="Cannot redo further"):
        manager.redo_checkpoint_block(first.id, actor=default_user)
def test_redo_after_multiple_undo(server: SyncServer, default_user):
    """
    Several undos can be fully replayed by the same number of redos.

    1) Create and checkpoint four versions: v1, v2, v3, v4 (expected seq=1..4)
    2) Undo three times: v4 -> v3 -> v2 -> v1
    3) Redo three times: v1 -> v2 -> v3 -> v4
    """
    manager = BlockManager()

    # Step 1: initial block and its checkpoint.
    base = manager.create_or_update_block(PydanticBlock(label="redo_multi", value="v1"), actor=default_user)
    manager.checkpoint_block(base.id, actor=default_user)

    # Remaining versions: copy the initial block, change the value, save, checkpoint.
    for later_value in ("v2", "v3", "v4"):
        revision = PydanticBlock(**base.dict())
        revision.value = later_value
        manager.create_or_update_block(revision, actor=default_user)
        manager.checkpoint_block(base.id, actor=default_user)

    # Step 2: walk backwards through the history, one checkpoint per undo.
    for expected_value in ["v3", "v2", "v1"]:
        stepped_back = manager.undo_checkpoint_block(base.id, actor=default_user)
        assert stepped_back.value == expected_value, f"Undo should get us back to {expected_value}"

    # Step 3: walk forwards again, one checkpoint per redo.
    for expected_value in ["v2", "v3", "v4"]:
        stepped_forward = manager.redo_checkpoint_block(base.id, actor=default_user)
        assert stepped_forward.value == expected_value, f"Redo should get us forward to {expected_value}"
def test_redo_concurrency_stale(server: SyncServer, default_user):
    """
    Two redos started from the same loaded state: the first succeeds, the second is stale.

    1) Create and checkpoint v1, v2, v3
    2) Undo once so a later checkpoint (v3) exists to redo to
    3) Load the block in two separate sessions
    4) Redo with the first preloaded copy -> succeeds
    5) Redo with the second, now outdated, copy -> StaleDataError expected
    """
    manager = BlockManager()

    # Initial version and checkpoint.
    created = manager.create_or_update_block(PydanticBlock(label="redo_concurrency", value="v1"), actor=default_user)
    manager.checkpoint_block(created.id, actor=default_user)

    # Two further edits, each followed by its own checkpoint.
    for new_value in ("v2", "v3"):
        edited = PydanticBlock(**created.dict())
        edited.value = new_value
        manager.create_or_update_block(edited, actor=default_user)
        manager.checkpoint_block(created.id, actor=default_user)

    # Step back once so that the 'v3' checkpoint becomes a redo target.
    after_undo = manager.undo_checkpoint_block(created.id, actor=default_user)
    assert after_undo.value == "v2"

    # Simulate concurrency: two independent sessions each read the block
    # while it is still at the 'v2' state.
    with db_context() as s1:
        block_s1 = s1.get(Block, created.id)
    with db_context() as s2:
        block_s2 = s2.get(Block, created.id)

    # The first redo goes through and commits.
    manager.redo_checkpoint_block(block_id=created.id, actor=default_user, use_preloaded_block=block_s1)

    # The second redo starts from the copy loaded before that commit, so it
    # is expected to be rejected as stale.
    with pytest.raises(StaleDataError):
        manager.redo_checkpoint_block(block_id=created.id, actor=default_user, use_preloaded_block=block_s2)
# ======================================================================================================================
# Identity Manager Tests
# ======================================================================================================================