feat: Add ability to redo to more recent checkpoint (#1496)
This commit is contained in:
@@ -15,7 +15,6 @@ from letta.orm.sqlite_functions import adapt_array
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.block import Block as BlockModel
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.enums import ActorType
|
||||
@@ -210,64 +212,114 @@ class BlockManager:
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
|
||||
def _move_block_to_sequence(self, session: Session, block: BlockModel, target_seq: int, actor: PydanticUser) -> BlockModel:
|
||||
"""
|
||||
Move the block to the previous checkpoint by copying fields
|
||||
from the immediately previous BlockHistory entry (sequence_number - 1).
|
||||
|
||||
1) Load the current block (either by merging a preloaded block or reading from DB).
|
||||
2) Identify its current history entry. If none, there's nothing to undo.
|
||||
3) Determine the previous checkpoint's sequence_number. If seq=1, we can't go earlier.
|
||||
4) Copy state from that previous checkpoint into the block.
|
||||
5) Commit transaction (optimistic lock check).
|
||||
6) Return the updated block as Pydantic.
|
||||
Internal helper that moves the 'block' to the specified 'target_seq' within BlockHistory.
|
||||
1) Find the BlockHistory row at sequence_number=target_seq
|
||||
2) Copy fields into the block
|
||||
3) Update and flush (no_commit=True) - the caller is responsible for final commit
|
||||
|
||||
Raises:
|
||||
ValueError: If no previous checkpoint exists or if we can't find the matching row.
|
||||
NoResultFound: If the block or block history row do not exist.
|
||||
StaleDataError: If another transaction updated the block concurrently (optimistic locking).
|
||||
NoResultFound: if no BlockHistory row for (block_id, target_seq)
|
||||
"""
|
||||
if not block.id:
|
||||
raise ValueError("Block is missing an ID. Cannot move sequence.")
|
||||
|
||||
target_entry = (
|
||||
session.query(BlockHistory)
|
||||
.filter(
|
||||
BlockHistory.block_id == block.id,
|
||||
BlockHistory.sequence_number == target_seq,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
if not target_entry:
|
||||
raise NoResultFound(f"No BlockHistory row found for block_id={block.id} at sequence={target_seq}")
|
||||
|
||||
# Copy fields from target_entry to block
|
||||
block.description = target_entry.description # type: ignore
|
||||
block.label = target_entry.label # type: ignore
|
||||
block.value = target_entry.value # type: ignore
|
||||
block.limit = target_entry.limit # type: ignore
|
||||
block.metadata_ = target_entry.metadata_ # type: ignore
|
||||
block.current_history_entry_id = target_entry.id # type: ignore
|
||||
|
||||
# Update in DB (optimistic locking).
|
||||
# We'll do a flush now; the caller does final commit.
|
||||
updated_block = block.update(db_session=session, actor=actor, no_commit=True)
|
||||
return updated_block
|
||||
|
||||
@enforce_types
|
||||
def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
|
||||
"""
|
||||
Move the block to the previous checkpoint (sequence_number - 1).
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# 1) Load the block
|
||||
if use_preloaded_block is not None:
|
||||
block = session.merge(use_preloaded_block)
|
||||
else:
|
||||
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
# 1) Load current block
|
||||
block = (
|
||||
session.merge(use_preloaded_block)
|
||||
if use_preloaded_block
|
||||
else BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
)
|
||||
|
||||
# 2) Ensure there's a current checkpoint to undo from
|
||||
if not block.current_history_entry_id:
|
||||
# There's no known history entry to revert from
|
||||
raise ValueError(f"Block {block_id} has no history entry - cannot undo.")
|
||||
|
||||
# 2) Fetch the current history entry
|
||||
current_entry = session.get(BlockHistory, block.current_history_entry_id)
|
||||
if not current_entry:
|
||||
raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}")
|
||||
|
||||
current_seq = current_entry.sequence_number
|
||||
if current_seq <= 1:
|
||||
# This means there's no previous checkpoint
|
||||
raise ValueError(f"Block {block_id} is at the first checkpoint (seq=1). Cannot undo further.")
|
||||
|
||||
# 3) The previous checkpoint is current_seq - 1
|
||||
# 3) Move to the previous sequence
|
||||
previous_seq = current_seq - 1
|
||||
prev_entry = (
|
||||
session.query(BlockHistory)
|
||||
.filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number == previous_seq)
|
||||
.one_or_none()
|
||||
)
|
||||
if not prev_entry:
|
||||
raise NoResultFound(f"No BlockHistory row for block_id={block.id} at sequence_number={previous_seq}")
|
||||
block = self._move_block_to_sequence(session, block, previous_seq, actor)
|
||||
|
||||
# 4) Copy fields from the prev_entry back to the block
|
||||
block.description = prev_entry.description
|
||||
block.label = prev_entry.label
|
||||
block.value = prev_entry.value
|
||||
block.limit = prev_entry.limit
|
||||
block.metadata_ = prev_entry.metadata_
|
||||
block.current_history_entry_id = prev_entry.id
|
||||
|
||||
# 5) Commit with optimistic locking. We do a single commit at the end.
|
||||
block = block.update(db_session=session, actor=actor, no_commit=True)
|
||||
# 4) Commit once at the end
|
||||
session.commit()
|
||||
return block.to_pydantic() # type: ignore
|
||||
|
||||
return block.to_pydantic()
|
||||
@enforce_types
|
||||
def redo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
|
||||
"""
|
||||
Move the block to the next checkpoint (sequence_number + 1).
|
||||
|
||||
Raises:
|
||||
ValueError: if the block is not pointing to a known checkpoint,
|
||||
or if there's no higher sequence_number to move to.
|
||||
NoResultFound: if the relevant BlockHistory row doesn't exist.
|
||||
StaleDataError: on concurrency conflicts.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# 1) Load current block
|
||||
block = (
|
||||
session.merge(use_preloaded_block)
|
||||
if use_preloaded_block
|
||||
else BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
)
|
||||
|
||||
# 2) If no current_history_entry_id, can't redo
|
||||
if not block.current_history_entry_id:
|
||||
raise ValueError(f"Block {block_id} has no history entry - cannot redo.")
|
||||
|
||||
current_entry = session.get(BlockHistory, block.current_history_entry_id)
|
||||
if not current_entry:
|
||||
raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}")
|
||||
|
||||
current_seq = current_entry.sequence_number
|
||||
|
||||
# We'll move to next_seq = current_seq + 1
|
||||
next_seq = current_seq + 1
|
||||
|
||||
# 3) Move to the next sequence using our helper
|
||||
try:
|
||||
block = self._move_block_to_sequence(session, block, next_seq, actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Block {block_id} is at the highest checkpoint (seq={current_seq}). Cannot redo further.")
|
||||
|
||||
# 4) Commit once
|
||||
session.commit()
|
||||
return block.to_pydantic() # type: ignore
|
||||
|
||||
@@ -2980,6 +2980,172 @@ def test_undo_concurrency_stale(server: SyncServer, default_user):
|
||||
block_manager.undo_checkpoint_block(block_id=block_v1.id, actor=default_user, use_preloaded_block=block_s2) # also seq=2 in memory
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block Manager Tests - Redo
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_redo_checkpoint_block(server: SyncServer, default_user):
|
||||
"""
|
||||
1) Create a block with value v1 -> checkpoint => seq=1
|
||||
2) Update to v2 -> checkpoint => seq=2
|
||||
3) Update to v3 -> checkpoint => seq=3
|
||||
4) Undo once (seq=3 -> seq=2)
|
||||
5) Redo once (seq=2 -> seq=3)
|
||||
"""
|
||||
|
||||
block_manager = BlockManager()
|
||||
|
||||
# 1) Create block, set value='v1'; checkpoint => seq=1
|
||||
block_v1 = block_manager.create_or_update_block(PydanticBlock(label="redo_test", value="v1"), actor=default_user)
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
# 2) Update to 'v2'; checkpoint => seq=2
|
||||
block_v2 = PydanticBlock(**block_v1.dict())
|
||||
block_v2.value = "v2"
|
||||
block_manager.create_or_update_block(block_v2, actor=default_user)
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
# 3) Update to 'v3'; checkpoint => seq=3
|
||||
block_v3 = PydanticBlock(**block_v1.dict())
|
||||
block_v3.value = "v3"
|
||||
block_manager.create_or_update_block(block_v3, actor=default_user)
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
# Undo from seq=3 -> seq=2
|
||||
undone_block = block_manager.undo_checkpoint_block(block_v1.id, actor=default_user)
|
||||
assert undone_block.value == "v2", "After undo, block should revert to v2"
|
||||
|
||||
# Redo from seq=2 -> seq=3
|
||||
redone_block = block_manager.redo_checkpoint_block(block_v1.id, actor=default_user)
|
||||
assert redone_block.value == "v3", "After redo, block should go back to v3"
|
||||
|
||||
|
||||
def test_redo_no_history(server: SyncServer, default_user):
|
||||
"""
|
||||
If a block has no current_history_entry_id (never checkpointed),
|
||||
then redo_checkpoint_block should raise ValueError.
|
||||
"""
|
||||
block_manager = BlockManager()
|
||||
|
||||
# Create block with no checkpoint
|
||||
block = block_manager.create_or_update_block(PydanticBlock(label="redo_no_history", value="v0"), actor=default_user)
|
||||
|
||||
# Attempt to redo => expect ValueError
|
||||
with pytest.raises(ValueError, match="no history entry - cannot redo"):
|
||||
block_manager.redo_checkpoint_block(block.id, actor=default_user)
|
||||
|
||||
|
||||
def test_redo_at_highest_checkpoint(server: SyncServer, default_user):
|
||||
"""
|
||||
If the block is at the maximum sequence number, there's no higher checkpoint to move to.
|
||||
redo_checkpoint_block should raise ValueError.
|
||||
"""
|
||||
block_manager = BlockManager()
|
||||
|
||||
# 1) Create block => checkpoint => seq=1
|
||||
b_init = block_manager.create_or_update_block(PydanticBlock(label="redo_highest", value="v1"), actor=default_user)
|
||||
block_manager.checkpoint_block(b_init.id, actor=default_user)
|
||||
|
||||
# 2) Another edit => seq=2
|
||||
b_next = PydanticBlock(**b_init.dict())
|
||||
b_next.value = "v2"
|
||||
block_manager.create_or_update_block(b_next, actor=default_user)
|
||||
block_manager.checkpoint_block(b_init.id, actor=default_user)
|
||||
|
||||
# We are at seq=2, which is the highest checkpoint.
|
||||
# Attempt redo => there's no seq=3
|
||||
with pytest.raises(ValueError, match="Cannot redo further"):
|
||||
block_manager.redo_checkpoint_block(b_init.id, actor=default_user)
|
||||
|
||||
|
||||
def test_redo_after_multiple_undo(server: SyncServer, default_user):
|
||||
"""
|
||||
1) Create and checkpoint versions: v1 -> seq=1, v2 -> seq=2, v3 -> seq=3, v4 -> seq=4
|
||||
2) Undo thrice => from seq=4 to seq=1
|
||||
3) Redo thrice => from seq=1 back to seq=4
|
||||
"""
|
||||
block_manager = BlockManager()
|
||||
|
||||
# Step 1: create initial block => seq=1
|
||||
b_init = block_manager.create_or_update_block(PydanticBlock(label="redo_multi", value="v1"), actor=default_user)
|
||||
block_manager.checkpoint_block(b_init.id, actor=default_user)
|
||||
|
||||
# seq=2
|
||||
b_v2 = PydanticBlock(**b_init.dict())
|
||||
b_v2.value = "v2"
|
||||
block_manager.create_or_update_block(b_v2, actor=default_user)
|
||||
block_manager.checkpoint_block(b_init.id, actor=default_user)
|
||||
|
||||
# seq=3
|
||||
b_v3 = PydanticBlock(**b_init.dict())
|
||||
b_v3.value = "v3"
|
||||
block_manager.create_or_update_block(b_v3, actor=default_user)
|
||||
block_manager.checkpoint_block(b_init.id, actor=default_user)
|
||||
|
||||
# seq=4
|
||||
b_v4 = PydanticBlock(**b_init.dict())
|
||||
b_v4.value = "v4"
|
||||
block_manager.create_or_update_block(b_v4, actor=default_user)
|
||||
block_manager.checkpoint_block(b_init.id, actor=default_user)
|
||||
|
||||
# We have 4 checkpoints: v1...v4. Current is seq=4.
|
||||
|
||||
# 2) Undo thrice => from seq=4 -> seq=1
|
||||
for expected_value in ["v3", "v2", "v1"]:
|
||||
undone_block = block_manager.undo_checkpoint_block(b_init.id, actor=default_user)
|
||||
assert undone_block.value == expected_value, f"Undo should get us back to {expected_value}"
|
||||
|
||||
# 3) Redo thrice => from seq=1 -> seq=4
|
||||
for expected_value in ["v2", "v3", "v4"]:
|
||||
redone_block = block_manager.redo_checkpoint_block(b_init.id, actor=default_user)
|
||||
assert redone_block.value == expected_value, f"Redo should get us forward to {expected_value}"
|
||||
|
||||
|
||||
def test_redo_concurrency_stale(server: SyncServer, default_user):
|
||||
block_manager = BlockManager()
|
||||
|
||||
# 1) Create block => checkpoint => seq=1
|
||||
block = block_manager.create_or_update_block(PydanticBlock(label="redo_concurrency", value="v1"), actor=default_user)
|
||||
block_manager.checkpoint_block(block.id, actor=default_user)
|
||||
|
||||
# 2) Another edit => checkpoint => seq=2
|
||||
block_v2 = PydanticBlock(**block.dict())
|
||||
block_v2.value = "v2"
|
||||
block_manager.create_or_update_block(block_v2, actor=default_user)
|
||||
block_manager.checkpoint_block(block.id, actor=default_user)
|
||||
|
||||
# 3) Another edit => checkpoint => seq=3
|
||||
block_v3 = PydanticBlock(**block.dict())
|
||||
block_v3.value = "v3"
|
||||
block_manager.create_or_update_block(block_v3, actor=default_user)
|
||||
block_manager.checkpoint_block(block.id, actor=default_user)
|
||||
# Now the block is at seq=3 in the DB
|
||||
|
||||
# 4) Undo from seq=3 -> seq=2 so that we have a known future state at seq=3
|
||||
undone_block = block_manager.undo_checkpoint_block(block.id, actor=default_user)
|
||||
assert undone_block.value == "v2"
|
||||
|
||||
# At this point the block is physically at seq=2 in DB,
|
||||
# but there's a valid row for seq=3 in block_history (the 'v3' state).
|
||||
|
||||
# 5) Simulate concurrency: two sessions each read the block at seq=2
|
||||
with db_context() as s1:
|
||||
block_s1 = s1.get(Block, block.id)
|
||||
with db_context() as s2:
|
||||
block_s2 = s2.get(Block, block.id)
|
||||
|
||||
# 6) Session1 redoes to seq=3 first -> success
|
||||
block_manager.redo_checkpoint_block(block_id=block.id, actor=default_user, use_preloaded_block=block_s1)
|
||||
# commits => block is now seq=3 in DB, version increments
|
||||
|
||||
# 7) Session2 tries to do the same from stale version
|
||||
# => we expect StaleDataError, because the second session is using
|
||||
# an out-of-date version of the block
|
||||
with pytest.raises(StaleDataError):
|
||||
block_manager.redo_checkpoint_block(block_id=block.id, actor=default_user, use_preloaded_block=block_s2)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Identity Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user