diff --git a/alembic/versions/bff040379479_add_block_history_tables.py b/alembic/versions/bff040379479_add_block_history_tables.py new file mode 100644 index 00000000..80c6cb33 --- /dev/null +++ b/alembic/versions/bff040379479_add_block_history_tables.py @@ -0,0 +1,65 @@ +"""Add block history tables + +Revision ID: bff040379479 +Revises: a66510f83fc2 +Create Date: 2025-03-31 14:49:30.449052 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "bff040379479" +down_revision: Union[str, None] = "a66510f83fc2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "block_history", + sa.Column("description", sa.Text(), nullable=True), + sa.Column("label", sa.String(), nullable=False), + sa.Column("value", sa.Text(), nullable=False), + sa.Column("limit", sa.BigInteger(), nullable=False), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.Column("actor_type", sa.String(), nullable=True), + sa.Column("actor_id", sa.String(), nullable=True), + sa.Column("block_id", sa.String(), nullable=False), + sa.Column("sequence_number", sa.Integer(), nullable=False), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint(["block_id"], ["block.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + 
["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_block_history_block_id_sequence", "block_history", ["block_id", "sequence_number"], unique=True) + op.add_column("block", sa.Column("current_history_entry_id", sa.String(), nullable=True)) + op.add_column("block", sa.Column("version", sa.Integer(), server_default="1", nullable=False)) + op.create_index(op.f("ix_block_current_history_entry_id"), "block", ["current_history_entry_id"], unique=False) + op.create_foreign_key("fk_block_current_history_entry", "block", "block_history", ["current_history_entry_id"], ["id"], use_alter=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("fk_block_current_history_entry", "block", type_="foreignkey") + op.drop_index(op.f("ix_block_current_history_entry_id"), table_name="block") + op.drop_column("block", "version") + op.drop_column("block", "current_history_entry_id") + op.drop_index("ix_block_history_block_id_sequence", table_name="block_history") + op.drop_table("block_history") + # ### end Alembic commands ### diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index a43f6e0b..0dad525c 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -2,6 +2,7 @@ from letta.orm.agent import Agent from letta.orm.agents_tags import AgentsTags from letta.orm.base import Base from letta.orm.block import Block +from letta.orm.block_history import BlockHistory from letta.orm.blocks_agents import BlocksAgents from letta.orm.file import FileMetadata from letta.orm.group import Group diff --git a/letta/orm/block.py b/letta/orm/block.py index 940a8ec9..edd56d0c 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -1,9 +1,10 @@ from typing import TYPE_CHECKING, List, Optional, Type -from sqlalchemy import JSON, BigInteger, Index, Integer, UniqueConstraint, event -from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship 
+from sqlalchemy import JSON, BigInteger, ForeignKey, Index, Integer, String, UniqueConstraint, event +from sqlalchemy.orm import Mapped, attributes, declared_attr, mapped_column, relationship from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT +from letta.orm.block_history import BlockHistory from letta.orm.blocks_agents import BlocksAgents from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase @@ -38,6 +39,17 @@ class Block(OrganizationMixin, SqlalchemyBase): limit: Mapped[BigInteger] = mapped_column(Integer, default=CORE_MEMORY_BLOCK_CHAR_LIMIT, doc="Character limit of the block.") metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default={}, doc="arbitrary information related to the block.") + # history pointers / locking mechanisms + current_history_entry_id: Mapped[Optional[str]] = mapped_column( + String, ForeignKey("block_history.id", name="fk_block_current_history_entry", use_alter=True), nullable=True, index=True + ) + version: Mapped[int] = mapped_column( + Integer, nullable=False, default=1, server_default="1", doc="Optimistic locking version counter, incremented on each state change." 
+ ) + # NOTE: This takes advantage of built-in optimistic locking functionality by SqlAlchemy + # https://docs.sqlalchemy.org/en/20/orm/versioning.html + __mapper_args__ = {"version_id_col": version} + # relationships organization: Mapped[Optional["Organization"]] = relationship("Organization") agents: Mapped[List["Agent"]] = relationship( @@ -68,6 +80,17 @@ class Block(OrganizationMixin, SqlalchemyBase): model_dict["metadata"] = self.metadata_ return Schema.model_validate(model_dict) + @declared_attr + def current_history_entry(cls) -> Mapped[Optional["BlockHistory"]]: + # Relationship to easily load the specific history entry that is current + return relationship( + "BlockHistory", + primaryjoin=lambda: cls.current_history_entry_id == BlockHistory.id, + foreign_keys=[cls.current_history_entry_id], + lazy="joined", # Typically want current history details readily available + post_update=True, + ) # Helps manage potential FK cycles + @event.listens_for(Block, "after_update") # Changed from 'before_update' def block_before_update(mapper, connection, target): diff --git a/letta/orm/block_history.py b/letta/orm/block_history.py new file mode 100644 index 00000000..de1b0c5d --- /dev/null +++ b/letta/orm/block_history.py @@ -0,0 +1,46 @@ +import uuid +from typing import Optional + +from sqlalchemy import JSON, BigInteger, ForeignKey, Index, Integer, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.enums import ActorType +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase + + +class BlockHistory(OrganizationMixin, SqlalchemyBase): + """Stores a single historical state of a Block for undo/redo functionality.""" + + __tablename__ = "block_history" + + __table_args__ = ( + # PRIMARY lookup index for finding specific history entries & ordering + Index("ix_block_history_block_id_sequence", "block_id", "sequence_number", unique=True), + ) + + # agent generates its own id + # TODO: We want to 
migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase + # TODO: Some still rely on the Pydantic object to do this + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"block_hist-{uuid.uuid4()}") + + # Snapshot State Fields (Copied from Block) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + label: Mapped[str] = mapped_column(String, nullable=False) + value: Mapped[str] = mapped_column(Text, nullable=False) + limit: Mapped[BigInteger] = mapped_column(BigInteger, nullable=False) + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) + + # Editor info + # These are not made to be FKs because these may not always exist (e.g. a User be deleted after they made a checkpoint) + actor_type: Mapped[Optional[ActorType]] = mapped_column(String, nullable=True) + actor_id: Mapped[Optional[str]] = mapped_column(String, nullable=True) + + # Relationships + block_id: Mapped[str] = mapped_column( + String, ForeignKey("block.id", ondelete="CASCADE"), nullable=False # History deleted if Block is deleted + ) + + sequence_number: Mapped[int] = mapped_column( + Integer, nullable=False, doc="Monotonically increasing sequence number for the history of a specific block_id, starting from 1." 
+ ) diff --git a/letta/orm/enums.py b/letta/orm/enums.py index 9f014162..5bf0648c 100644 --- a/letta/orm/enums.py +++ b/letta/orm/enums.py @@ -22,3 +22,9 @@ class ToolSourceType(str, Enum): python = "python" json = "json" + + +class ActorType(str, Enum): + LETTA_USER = "letta_user" + LETTA_AGENT = "letta_agent" + LETTA_SYSTEM = "letta_system" diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 85ccfd85..8a06b1c6 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -370,17 +370,19 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): return [] @handle_db_timeout - def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": + def create(self, db_session: "Session", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase": logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}") if actor: self._set_created_and_updated_by_fields(actor.id) try: - with db_session as session: - session.add(self) - session.commit() - session.refresh(self) - return self + db_session.add(self) + if no_commit: + db_session.flush() # no commit, just flush to get PK + else: + db_session.commit() + db_session.refresh(self) + return self except (DBAPIError, IntegrityError) as e: self._handle_dbapi_error(e) @@ -455,18 +457,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted") @handle_db_timeout - def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": - logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}") + def update(self, db_session: "Session", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase": + logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
if actor: self._set_created_and_updated_by_fields(actor.id) - self.set_updated_at() - with db_session as session: - session.add(self) - session.commit() - session.refresh(self) - return self + # Use the caller's session directly (no context manager) so a surrounding transaction can span multiple operations. + db_session.add(self) + if no_commit: + db_session.flush() # no commit, just flush to get PK + else: + db_session.commit() + db_session.refresh(self) + return self @classmethod @handle_db_timeout diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index cc09c0b1..3bf6996d 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -1,10 +1,13 @@ import os from typing import List, Optional +from sqlalchemy import func + from letta.orm.block import Block as BlockModel +from letta.orm.block_history import BlockHistory +from letta.orm.enums import ActorType from letta.orm.errors import NoResultFound from letta.schemas.agent import AgentState as PydanticAgentState -from letta.schemas.block import Block from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate, Human, Persona from letta.schemas.user import User as PydanticUser @@ -21,7 +24,7 @@ class BlockManager: self.session_maker = db_context @enforce_types - def create_or_update_block(self, block: Block, actor: PydanticUser) -> PydanticBlock: + def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: """Create a new block based on the Block schema.""" db_block = self.get_block_by_id(block.id, actor) if db_block: @@ -140,3 +143,59 @@ class BlockManager: agents_pydantic = [agent.to_pydantic() for agent in agents_orm] return agents_pydantic + + # Block History Functions + + @enforce_types + def checkpoint_block( + self, + block_id: str, + actor: PydanticUser, + agent_id: Optional[str] = None, + use_preloaded_block: Optional[BlockModel] = None, # TODO: Useful for testing concurrency + ) -> PydanticBlock: + """ + Create a new checkpoint for the given Block by copying its + 
current state into BlockHistory, using SQLAlchemy's built-in + version_id_col for concurrency checks. + + Note: We only have a single commit at the end, to avoid weird intermediate states. + e.g. created a BlockHistory, but the block update failed + """ + # If `use_preloaded_block` is given, skip re-reading the block from the DB. + with self.session_maker() as session: + # 1) Load the block via the ORM + if use_preloaded_block is not None: + block = session.merge(use_preloaded_block) + else: + block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) + + # 2) Create a new sequence number for BlockHistory + current_max_seq = session.query(func.max(BlockHistory.sequence_number)).filter(BlockHistory.block_id == block_id).scalar() + next_seq = (current_max_seq or 0) + 1 + + # 3) Create a snapshot in BlockHistory + history_entry = BlockHistory( + organization_id=actor.organization_id, + block_id=block.id, + sequence_number=next_seq, + description=block.description, + label=block.label, + value=block.value, + limit=block.limit, + metadata_=block.metadata_, + actor_type=ActorType.LETTA_AGENT if agent_id else ActorType.LETTA_USER, + actor_id=agent_id if agent_id else actor.id, + ) + history_entry.create(session, actor=actor, no_commit=True) + + # 4) Update the block’s pointer + block.current_history_entry_id = history_entry.id + + # 5) Flush the pointer update; SQLAlchemy's version_id_col bumps Block.version and raises StaleDataError on concurrent modification. + block = block.update(db_session=session, actor=actor, no_commit=True) + + session.commit() + + # Return the block’s new state + return block.to_pydantic() diff --git a/tests/test_managers.py b/tests/test_managers.py index 9837a0b3..938be4c7 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3,19 +3,22 @@ import random import string import time from datetime import datetime, timedelta +from typing import List import pytest from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall from 
openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm.exc import StaleDataError from letta.config import LettaConfig from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_TOOL_EXECUTION_DIR, MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS from letta.embeddings import embedding_model from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.functions.mcp_client.types import MCPTool -from letta.orm import Base -from letta.orm.enums import JobType, ToolType +from letta.orm import Base, Block +from letta.orm.block_history import BlockHistory +from letta.orm.enums import ActorType, JobType, ToolType from letta.orm.errors import NoResultFound, UniqueConstraintViolationError from letta.schemas.agent import CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock @@ -46,6 +49,7 @@ from letta.schemas.tool import ToolCreate, ToolUpdate from letta.schemas.tool_rule import InitToolRule from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserUpdate +from letta.server.db import db_context from letta.server.server import SyncServer from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager @@ -67,11 +71,12 @@ USING_SQLITE = not bool(os.getenv("LETTA_PG_URI")) @pytest.fixture(autouse=True) -def clear_tables(): - from letta.server.db import db_context - +def _clear_tables(): with db_context() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues + # If this is the block_history table, skip it + if table.name == "block_history": + continue session.execute(table.delete()) # Truncate table session.commit() @@ -2366,7 +2371,7 @@ def test_message_listing_text_search(server: SyncServer, hello_world_message_fix # 
====================================================================================================================== -# Block Manager Tests +# Block Manager Tests - Basic # ====================================================================================================================== @@ -2575,6 +2580,153 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de assert charles_agent.id in agent_state_ids +# ====================================================================================================================== +# Block Manager Tests - History (Undo/Redo/Checkpoint) +# ====================================================================================================================== + + +def test_checkpoint_creates_history(server: SyncServer, default_user): + """ + Ensures that calling checkpoint_block creates a BlockHistory row and updates + the block's current_history_entry_id appropriately. + """ + + block_manager = BlockManager() + + # Create a block + initial_value = "Initial block content" + created_block = block_manager.create_or_update_block(PydanticBlock(label="test_checkpoint", value=initial_value), actor=default_user) + + # Act: checkpoint it + block_manager.checkpoint_block(block_id=created_block.id, actor=default_user) + + with db_context() as session: + # Get BlockHistory entries for this block + history_entries: List[BlockHistory] = session.query(BlockHistory).filter(BlockHistory.block_id == created_block.id).all() + assert len(history_entries) == 1, "Exactly one history entry should be created" + hist = history_entries[0] + + # Fetch ORM block for internal checks + db_block = session.get(Block, created_block.id) + + assert hist.sequence_number == 1 + assert hist.value == initial_value + assert hist.actor_type == ActorType.LETTA_USER + assert hist.actor_id == default_user.id + assert db_block.current_history_entry_id == hist.id + + +def test_multiple_checkpoints(server: SyncServer, default_user): + 
block_manager = BlockManager() + + # Create a block + block = block_manager.create_or_update_block(PydanticBlock(label="test_multi_checkpoint", value="v1"), actor=default_user) + + # 1) First checkpoint + block_manager.checkpoint_block(block_id=block.id, actor=default_user) + + # 2) Update block content + updated_block_data = PydanticBlock(**block.dict()) + updated_block_data.value = "v2" + block_manager.create_or_update_block(updated_block_data, actor=default_user) + + # 3) Second checkpoint + block_manager.checkpoint_block(block_id=block.id, actor=default_user) + + with db_context() as session: + history_entries = ( + session.query(BlockHistory).filter(BlockHistory.block_id == block.id).order_by(BlockHistory.sequence_number.asc()).all() + ) + assert len(history_entries) == 2, "Should have two history entries" + + # First is seq=1, value='v1' + assert history_entries[0].sequence_number == 1 + assert history_entries[0].value == "v1" + + # Second is seq=2, value='v2' + assert history_entries[1].sequence_number == 2 + assert history_entries[1].value == "v2" + + # The block should now point to the second entry + db_block = session.get(Block, block.id) + assert db_block.current_history_entry_id == history_entries[1].id + + +def test_checkpoint_with_agent_id(server: SyncServer, default_user, sarah_agent): + """ + Ensures that if we pass agent_id to checkpoint_block, we get + actor_type=LETTA_AGENT, actor_id= in BlockHistory. 
+ """ + block_manager = BlockManager() + + # Create a block + block = block_manager.create_or_update_block(PydanticBlock(label="test_agent_checkpoint", value="Agent content"), actor=default_user) + + # Checkpoint with agent_id + block_manager.checkpoint_block(block_id=block.id, actor=default_user, agent_id=sarah_agent.id) + + # Verify + with db_context() as session: + hist_entry = session.query(BlockHistory).filter(BlockHistory.block_id == block.id).one() + assert hist_entry.actor_type == ActorType.LETTA_AGENT + assert hist_entry.actor_id == sarah_agent.id + + +def test_checkpoint_with_no_state_change(server: SyncServer, default_user): + """ + If we call checkpoint_block twice without any edits, + we expect two entries or only one, depending on your policy. + """ + block_manager = BlockManager() + + # Create block + block = block_manager.create_or_update_block(PydanticBlock(label="test_no_change", value="original"), actor=default_user) + + # 1) checkpoint + block_manager.checkpoint_block(block_id=block.id, actor=default_user) + # 2) checkpoint again (no changes) + block_manager.checkpoint_block(block_id=block.id, actor=default_user) + + with db_context() as session: + all_hist = session.query(BlockHistory).filter(BlockHistory.block_id == block.id).all() + assert len(all_hist) == 2 + + +def test_checkpoint_concurrency_stale(server: SyncServer, default_user): + block_manager = BlockManager() + + # create block + block = block_manager.create_or_update_block(PydanticBlock(label="test_stale_checkpoint", value="hello"), actor=default_user) + + # session1 loads + with db_context() as s1: + block_s1 = s1.get(Block, block.id) # version=1 + + # session2 loads + with db_context() as s2: + block_s2 = s2.get(Block, block.id) # also version=1 + + # session1 checkpoint => version=2 + with db_context() as s1: + block_s1 = s1.merge(block_s1) + block_manager.checkpoint_block( + block_id=block_s1.id, + actor=default_user, + use_preloaded_block=block_s1, # let manager use the object 
in memory + ) + # commits inside checkpoint_block => version goes to 2 + + # session2 tries to checkpoint => sees old version=1 => stale error + with pytest.raises(StaleDataError): + with db_context() as s2: + block_s2 = s2.merge(block_s2) + block_manager.checkpoint_block( + block_id=block_s2.id, + actor=default_user, + use_preloaded_block=block_s2, + ) + + # ====================================================================================================================== # Identity Manager Tests # ======================================================================================================================