feat: Add block history tables (#1489)
This commit is contained in:
65
alembic/versions/bff040379479_add_block_history_tables.py
Normal file
65
alembic/versions/bff040379479_add_block_history_tables.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Add block history tables
|
||||
|
||||
Revision ID: bff040379479
|
||||
Revises: a66510f83fc2
|
||||
Create Date: 2025-03-31 14:49:30.449052
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "bff040379479"
|
||||
down_revision: Union[str, None] = "a66510f83fc2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``block_history`` table and add history-tracking columns to ``block``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "block_history",
        # Snapshot of the Block's state at checkpoint time
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("label", sa.String(), nullable=False),
        sa.Column("value", sa.Text(), nullable=False),
        sa.Column("limit", sa.BigInteger(), nullable=False),
        sa.Column("metadata_", sa.JSON(), nullable=True),
        # Who made the checkpoint; deliberately not FKs (actors may later be deleted)
        sa.Column("actor_type", sa.String(), nullable=True),
        sa.Column("actor_id", sa.String(), nullable=True),
        sa.Column("block_id", sa.String(), nullable=False),
        sa.Column("sequence_number", sa.Integer(), nullable=False),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.Column("id", sa.String(), nullable=False),
        # Common audit columns
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        # History rows are removed automatically when their Block is deleted
        sa.ForeignKeyConstraint(["block_id"], ["block.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # Unique: at most one history entry per (block_id, sequence_number); also the primary lookup index
    op.create_index("ix_block_history_block_id_sequence", "block_history", ["block_id", "sequence_number"], unique=True)
    # Pointer from a block to its current history entry, plus an optimistic-locking version counter
    op.add_column("block", sa.Column("current_history_entry_id", sa.String(), nullable=True))
    op.add_column("block", sa.Column("version", sa.Integer(), server_default="1", nullable=False))
    op.create_index(op.f("ix_block_current_history_entry_id"), "block", ["current_history_entry_id"], unique=False)
    # use_alter=True breaks the block <-> block_history FK cycle at CREATE time
    op.create_foreign_key("fk_block_current_history_entry", "block", "block_history", ["current_history_entry_id"], ["id"], use_alter=True)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Reverse the upgrade: drop the block columns/index/FK, then the block_history table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop the FK first so the block_history table can be removed afterwards
    op.drop_constraint("fk_block_current_history_entry", "block", type_="foreignkey")
    op.drop_index(op.f("ix_block_current_history_entry_id"), table_name="block")
    op.drop_column("block", "version")
    op.drop_column("block", "current_history_entry_id")
    op.drop_index("ix_block_history_block_id_sequence", table_name="block_history")
    op.drop_table("block_history")
    # ### end Alembic commands ###
|
||||
@@ -2,6 +2,7 @@ from letta.orm.agent import Agent
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.group import Group
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import TYPE_CHECKING, List, Optional, Type
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, Index, Integer, UniqueConstraint, event
|
||||
from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship
|
||||
from sqlalchemy import JSON, BigInteger, ForeignKey, Index, Integer, String, UniqueConstraint, event
|
||||
from sqlalchemy.orm import Mapped, attributes, declared_attr, mapped_column, relationship
|
||||
|
||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
@@ -38,6 +39,17 @@ class Block(OrganizationMixin, SqlalchemyBase):
|
||||
limit: Mapped[BigInteger] = mapped_column(Integer, default=CORE_MEMORY_BLOCK_CHAR_LIMIT, doc="Character limit of the block.")
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default={}, doc="arbitrary information related to the block.")
|
||||
|
||||
# history pointers / locking mechanisms
|
||||
current_history_entry_id: Mapped[Optional[str]] = mapped_column(
|
||||
String, ForeignKey("block_history.id", name="fk_block_current_history_entry", use_alter=True), nullable=True, index=True
|
||||
)
|
||||
version: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=1, server_default="1", doc="Optimistic locking version counter, incremented on each state change."
|
||||
)
|
||||
# NOTE: This takes advantage of built-in optimistic locking functionality by SqlAlchemy
|
||||
# https://docs.sqlalchemy.org/en/20/orm/versioning.html
|
||||
__mapper_args__ = {"version_id_col": version}
|
||||
|
||||
# relationships
|
||||
organization: Mapped[Optional["Organization"]] = relationship("Organization")
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
@@ -68,6 +80,17 @@ class Block(OrganizationMixin, SqlalchemyBase):
|
||||
model_dict["metadata"] = self.metadata_
|
||||
return Schema.model_validate(model_dict)
|
||||
|
||||
@declared_attr
def current_history_entry(cls) -> Mapped[Optional["BlockHistory"]]:
    """Relationship to easily load the specific history entry that is current."""
    return relationship(
        "BlockHistory",
        # Join on the pointer column rather than a conventional FK-derived join
        primaryjoin=lambda: cls.current_history_entry_id == BlockHistory.id,
        foreign_keys=[cls.current_history_entry_id],
        lazy="joined",  # Typically want current history details readily available
        post_update=True,  # Helps manage potential FK cycles (block <-> block_history)
    )
|
||||
|
||||
|
||||
@event.listens_for(Block, "after_update") # Changed from 'before_update'
|
||||
def block_before_update(mapper, connection, target):
|
||||
|
||||
46
letta/orm/block_history.py
Normal file
46
letta/orm/block_history.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, ForeignKey, Index, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.enums import ActorType
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
|
||||
|
||||
class BlockHistory(OrganizationMixin, SqlalchemyBase):
    """Stores a single historical state of a Block for undo/redo functionality."""

    __tablename__ = "block_history"

    __table_args__ = (
        # PRIMARY lookup index for finding specific history entries & ordering;
        # unique so a block can never have two entries with the same sequence number
        Index("ix_block_history_block_id_sequence", "block_id", "sequence_number", unique=True),
    )

    # agent generates its own id
    # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
    # TODO: Some still rely on the Pydantic object to do this
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"block_hist-{uuid.uuid4()}")

    # Snapshot State Fields (Copied from Block)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    label: Mapped[str] = mapped_column(String, nullable=False)
    value: Mapped[str] = mapped_column(Text, nullable=False)
    limit: Mapped[int] = mapped_column(BigInteger, nullable=False)  # the Python value is an int; BigInteger is the column type
    metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)

    # Editor info
    # These are not made to be FKs because these may not always exist (e.g. a User be deleted after they made a checkpoint)
    actor_type: Mapped[Optional[ActorType]] = mapped_column(String, nullable=True)
    actor_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)

    # Relationships
    block_id: Mapped[str] = mapped_column(
        String, ForeignKey("block.id", ondelete="CASCADE"), nullable=False  # History deleted if Block is deleted
    )

    sequence_number: Mapped[int] = mapped_column(
        Integer, nullable=False, doc="Monotonically increasing sequence number for the history of a specific block_id, starting from 1."
    )
|
||||
@@ -22,3 +22,9 @@ class ToolSourceType(str, Enum):
|
||||
|
||||
python = "python"
|
||||
json = "json"
|
||||
|
||||
|
||||
class ActorType(str, Enum):
    """Kind of actor that performed a recorded action (e.g. a BlockHistory checkpoint)."""

    LETTA_USER = "letta_user"
    LETTA_AGENT = "letta_agent"
    LETTA_SYSTEM = "letta_system"
|
||||
|
||||
@@ -370,17 +370,19 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
return []
|
||||
|
||||
@handle_db_timeout
def create(self, db_session: "Session", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
    """Persist this instance to the database.

    Args:
        db_session: Active session used to add/flush/commit this row. The
            caller owns the session lifecycle (no context manager here, so an
            enclosing transaction stays open when ``no_commit=True``).
        actor: If given, stamps the created-by/updated-by audit fields with ``actor.id``.
        no_commit: When True, only flush (assigns the PK without committing),
            letting the caller commit several operations atomically later.

    Returns:
        This instance, refreshed from the database.
    """
    logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")

    if actor:
        self._set_created_and_updated_by_fields(actor.id)
    try:
        db_session.add(self)
        if no_commit:
            db_session.flush()  # no commit, just flush to get PK
        else:
            db_session.commit()
        db_session.refresh(self)
        return self
    except (DBAPIError, IntegrityError) as e:
        # Translates driver-level errors into the project's error types
        self._handle_dbapi_error(e)
|
||||
|
||||
@@ -455,18 +457,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
||||
|
||||
@handle_db_timeout
def update(self, db_session: Session, actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
    """Persist in-place changes to this instance.

    Args:
        db_session: Active session; used directly (no context manager) so the
            caller's transaction stays open when ``no_commit=True``.
        actor: If given, stamps the updated-by audit field with ``actor.id``.
        no_commit: When True, only flush instead of committing, so the caller
            can commit several operations atomically later.

    Returns:
        This instance, refreshed from the database.
    """
    logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
    if actor:
        self._set_created_and_updated_by_fields(actor.id)

    self.set_updated_at()

    db_session.add(self)
    if no_commit:
        db_session.flush()  # no commit, just flush to get PK
    else:
        db_session.commit()
    db_session.refresh(self)
    return self
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
from letta.orm.block import Block as BlockModel
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.enums import ActorType
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate, Human, Persona
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@@ -21,7 +24,7 @@ class BlockManager:
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def create_or_update_block(self, block: Block, actor: PydanticUser) -> PydanticBlock:
|
||||
def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Create a new block based on the Block schema."""
|
||||
db_block = self.get_block_by_id(block.id, actor)
|
||||
if db_block:
|
||||
@@ -140,3 +143,59 @@ class BlockManager:
|
||||
agents_pydantic = [agent.to_pydantic() for agent in agents_orm]
|
||||
|
||||
return agents_pydantic
|
||||
|
||||
# Block History Functions
|
||||
|
||||
@enforce_types
def checkpoint_block(
    self,
    block_id: str,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    use_preloaded_block: Optional[BlockModel] = None,  # TODO: Useful for testing concurrency
) -> PydanticBlock:
    """
    Create a new checkpoint for the given Block by copying its
    current state into BlockHistory, using SQLAlchemy's built-in
    version_id_col for concurrency checks.

    If `use_preloaded_block` is given, skip re-reading from DB.

    Note: We only have a single commit at the end, to avoid weird intermediate states.
    e.g. created a BlockHistory, but the block update failed

    Args:
        block_id: ID of the block to checkpoint.
        actor: User performing the checkpoint (also supplies the organization).
        agent_id: If set, the checkpoint is attributed to this agent
            (actor_type=LETTA_AGENT) instead of the user.
        use_preloaded_block: Pre-loaded ORM Block merged into the session
            instead of re-reading it, preserving its in-memory version counter.

    Returns:
        The block's new state as a Pydantic schema.
    """
    # NOTE: the original had a second, dead string literal after the docstring
    # ("If `use_preloaded_block` is given..."); it is merged into the docstring above.
    with self.session_maker() as session:
        # 1) Load the block via the ORM
        if use_preloaded_block is not None:
            block = session.merge(use_preloaded_block)
        else:
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

        # 2) Create a new sequence number for BlockHistory (starts at 1)
        current_max_seq = session.query(func.max(BlockHistory.sequence_number)).filter(BlockHistory.block_id == block_id).scalar()
        next_seq = (current_max_seq or 0) + 1

        # 3) Create a snapshot in BlockHistory (flushed, not committed, so the
        #    whole checkpoint stays inside one transaction)
        history_entry = BlockHistory(
            organization_id=actor.organization_id,
            block_id=block.id,
            sequence_number=next_seq,
            description=block.description,
            label=block.label,
            value=block.value,
            limit=block.limit,
            metadata_=block.metadata_,
            actor_type=ActorType.LETTA_AGENT if agent_id else ActorType.LETTA_USER,
            actor_id=agent_id if agent_id else actor.id,
        )
        history_entry.create(session, actor=actor, no_commit=True)

        # 4) Update the block's pointer to the new history entry
        block.current_history_entry_id = history_entry.id

        # 5) Flush the pointer update; version_id_col bumps Block.version and a
        #    concurrent writer holding a stale version triggers StaleDataError
        block = block.update(db_session=session, actor=actor, no_commit=True)

        # Single commit covering both the history insert and the pointer update
        session.commit()

        # Return the block's new state
        return block.to_pydantic()
|
||||
|
||||
@@ -3,19 +3,22 @@ import random
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_TOOL_EXECUTION_DIR, MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.orm import Base
|
||||
from letta.orm.enums import JobType, ToolType
|
||||
from letta.orm import Base, Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.enums import ActorType, JobType, ToolType
|
||||
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.agent import CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
@@ -46,6 +49,7 @@ from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.schemas.tool_rule import InitToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.server.db import db_context
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
@@ -67,11 +71,12 @@ USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_tables():
    """Empty every table before each test so tests start from a clean DB."""
    from letta.server.db import db_context

    with db_context() as session:
        for table in reversed(Base.metadata.sorted_tables):  # Reverse to avoid FK issues
            # If this is the block_history table, skip it
            # NOTE(review): presumably block_history rows are removed via the
            # block FK cascade instead — confirm intent
            if table.name == "block_history":
                continue
            session.execute(table.delete())  # Truncate table
        session.commit()
|
||||
|
||||
@@ -2366,7 +2371,7 @@ def test_message_listing_text_search(server: SyncServer, hello_world_message_fix
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block Manager Tests
|
||||
# Block Manager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@@ -2575,6 +2580,153 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
|
||||
assert charles_agent.id in agent_state_ids
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block Manager Tests - History (Undo/Redo/Checkpoint)
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_checkpoint_creates_history(server: SyncServer, default_user):
    """
    Ensures that calling checkpoint_block creates a BlockHistory row and updates
    the block's current_history_entry_id appropriately.
    """
    mgr = BlockManager()

    # Make a block, then checkpoint it once
    content = "Initial block content"
    pydantic_block = mgr.create_or_update_block(PydanticBlock(label="test_checkpoint", value=content), actor=default_user)
    mgr.checkpoint_block(block_id=pydantic_block.id, actor=default_user)

    with db_context() as session:
        # Exactly one BlockHistory row should exist for this block
        entries: List[BlockHistory] = session.query(BlockHistory).filter(BlockHistory.block_id == pydantic_block.id).all()
        assert len(entries) == 1, "Exactly one history entry should be created"
        entry = entries[0]

        # Fetch ORM block for internal checks
        orm_block = session.get(Block, pydantic_block.id)

        assert entry.sequence_number == 1
        assert entry.value == content
        assert entry.actor_type == ActorType.LETTA_USER
        assert entry.actor_id == default_user.id
        assert orm_block.current_history_entry_id == entry.id
|
||||
|
||||
|
||||
def test_multiple_checkpoints(server: SyncServer, default_user):
    """Two checkpoints around an edit produce two ordered history rows, pointer on the latest."""
    mgr = BlockManager()

    blk = mgr.create_or_update_block(PydanticBlock(label="test_multi_checkpoint", value="v1"), actor=default_user)

    # Checkpoint #1 captures 'v1'
    mgr.checkpoint_block(block_id=blk.id, actor=default_user)

    # Edit the block content to 'v2'
    edited = PydanticBlock(**blk.dict())
    edited.value = "v2"
    mgr.create_or_update_block(edited, actor=default_user)

    # Checkpoint #2 captures 'v2'
    mgr.checkpoint_block(block_id=blk.id, actor=default_user)

    with db_context() as session:
        rows = (
            session.query(BlockHistory).filter(BlockHistory.block_id == blk.id).order_by(BlockHistory.sequence_number.asc()).all()
        )
        assert len(rows) == 2, "Should have two history entries"
        first, second = rows

        # First is seq=1, value='v1'; second is seq=2, value='v2'
        assert (first.sequence_number, first.value) == (1, "v1")
        assert (second.sequence_number, second.value) == (2, "v2")

        # The block should now point to the second entry
        db_block = session.get(Block, blk.id)
        assert db_block.current_history_entry_id == second.id
|
||||
|
||||
|
||||
def test_checkpoint_with_agent_id(server: SyncServer, default_user, sarah_agent):
    """
    Ensures that if we pass agent_id to checkpoint_block, we get
    actor_type=LETTA_AGENT, actor_id=<agent.id> in BlockHistory.
    """
    mgr = BlockManager()

    # Make a block, then checkpoint on behalf of an agent
    blk = mgr.create_or_update_block(PydanticBlock(label="test_agent_checkpoint", value="Agent content"), actor=default_user)
    mgr.checkpoint_block(block_id=blk.id, actor=default_user, agent_id=sarah_agent.id)

    # The single history row must be attributed to the agent, not the user
    with db_context() as session:
        row = session.query(BlockHistory).filter(BlockHistory.block_id == blk.id).one()
        assert row.actor_type == ActorType.LETTA_AGENT
        assert row.actor_id == sarah_agent.id
|
||||
|
||||
|
||||
def test_checkpoint_with_no_state_change(server: SyncServer, default_user):
    """
    If we call checkpoint_block twice without any edits,
    we expect two entries or only one, depending on your policy.
    (Current behavior: one entry per call — no deduplication.)
    """
    mgr = BlockManager()

    blk = mgr.create_or_update_block(PydanticBlock(label="test_no_change", value="original"), actor=default_user)

    # Checkpoint twice with no edits in between
    for _ in range(2):
        mgr.checkpoint_block(block_id=blk.id, actor=default_user)

    with db_context() as session:
        rows = session.query(BlockHistory).filter(BlockHistory.block_id == blk.id).all()
        assert len(rows) == 2
|
||||
|
||||
|
||||
def test_checkpoint_concurrency_stale(server: SyncServer, default_user):
    """Optimistic locking: a checkpoint using a stale Block version raises StaleDataError."""
    block_manager = BlockManager()

    # create block
    block = block_manager.create_or_update_block(PydanticBlock(label="test_stale_checkpoint", value="hello"), actor=default_user)

    # session1 loads
    with db_context() as s1:
        block_s1 = s1.get(Block, block.id)  # version=1

    # session2 loads
    with db_context() as s2:
        block_s2 = s2.get(Block, block.id)  # also version=1

    # session1 checkpoint => version=2
    with db_context() as s1:
        block_s1 = s1.merge(block_s1)  # re-attach the detached instance to the fresh session
        block_manager.checkpoint_block(
            block_id=block_s1.id,
            actor=default_user,
            use_preloaded_block=block_s1,  # let manager use the object in memory
        )
        # commits inside checkpoint_block => version goes to 2

    # session2 tries to checkpoint => sees old version=1 => stale error
    with pytest.raises(StaleDataError):
        with db_context() as s2:
            block_s2 = s2.merge(block_s2)  # still carries version=1
            block_manager.checkpoint_block(
                block_id=block_s2.id,
                actor=default_user,
                use_preloaded_block=block_s2,
            )
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Identity Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user