feat: Add block history tables (#1489)

This commit is contained in:
Matthew Zhou
2025-03-31 16:39:23 -07:00
committed by GitHub
parent a0ebfa0cd1
commit aa05df68a7
8 changed files with 380 additions and 24 deletions

View File

@@ -0,0 +1,65 @@
"""Add block history tables
Revision ID: bff040379479
Revises: a66510f83fc2
Create Date: 2025-03-31 14:49:30.449052
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "bff040379479"
down_revision: Union[str, None] = "a66510f83fc2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the block_history table and add history-pointer/version columns to block."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "block_history",
        # Snapshot of the block's state at checkpoint time
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("label", sa.String(), nullable=False),
        sa.Column("value", sa.Text(), nullable=False),
        sa.Column("limit", sa.BigInteger(), nullable=False),
        sa.Column("metadata_", sa.JSON(), nullable=True),
        # Who made the checkpoint; intentionally not FKs (the actor may later be deleted)
        sa.Column("actor_type", sa.String(), nullable=True),
        sa.Column("actor_id", sa.String(), nullable=True),
        sa.Column("block_id", sa.String(), nullable=False),
        sa.Column("sequence_number", sa.Integer(), nullable=False),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        # History rows are removed together with their parent block
        sa.ForeignKeyConstraint(["block_id"], ["block.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # Unique (block_id, sequence_number): at most one history row per sequence step
    op.create_index("ix_block_history_block_id_sequence", "block_history", ["block_id", "sequence_number"], unique=True)
    # Pointer from block to its current history entry, plus an optimistic-lock version counter
    op.add_column("block", sa.Column("current_history_entry_id", sa.String(), nullable=True))
    op.add_column("block", sa.Column("version", sa.Integer(), server_default="1", nullable=False))
    op.create_index(op.f("ix_block_current_history_entry_id"), "block", ["current_history_entry_id"], unique=False)
    # use_alter=True: FK is added via ALTER TABLE to break the block <-> block_history dependency cycle
    op.create_foreign_key("fk_block_current_history_entry", "block", "block_history", ["current_history_entry_id"], ["id"], use_alter=True)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Reverse upgrade(): drop the block history columns/constraints and the block_history table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop in reverse order of creation so FK dependencies are satisfied
    op.drop_constraint("fk_block_current_history_entry", "block", type_="foreignkey")
    op.drop_index(op.f("ix_block_current_history_entry_id"), table_name="block")
    op.drop_column("block", "version")
    op.drop_column("block", "current_history_entry_id")
    op.drop_index("ix_block_history_block_id_sequence", table_name="block_history")
    op.drop_table("block_history")
    # ### end Alembic commands ###

View File

@@ -2,6 +2,7 @@ from letta.orm.agent import Agent
from letta.orm.agents_tags import AgentsTags
from letta.orm.base import Base
from letta.orm.block import Block
from letta.orm.block_history import BlockHistory
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.group import Group

View File

@@ -1,9 +1,10 @@
from typing import TYPE_CHECKING, List, Optional, Type
from sqlalchemy import JSON, BigInteger, Index, Integer, UniqueConstraint, event
from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship
from sqlalchemy import JSON, BigInteger, ForeignKey, Index, Integer, String, UniqueConstraint, event
from sqlalchemy.orm import Mapped, attributes, declared_attr, mapped_column, relationship
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
from letta.orm.block_history import BlockHistory
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
@@ -38,6 +39,17 @@ class Block(OrganizationMixin, SqlalchemyBase):
limit: Mapped[BigInteger] = mapped_column(Integer, default=CORE_MEMORY_BLOCK_CHAR_LIMIT, doc="Character limit of the block.")
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default={}, doc="arbitrary information related to the block.")
# history pointers / locking mechanisms
current_history_entry_id: Mapped[Optional[str]] = mapped_column(
String, ForeignKey("block_history.id", name="fk_block_current_history_entry", use_alter=True), nullable=True, index=True
)
version: Mapped[int] = mapped_column(
Integer, nullable=False, default=1, server_default="1", doc="Optimistic locking version counter, incremented on each state change."
)
# NOTE: This takes advantage of built-in optimistic locking functionality by SqlAlchemy
# https://docs.sqlalchemy.org/en/20/orm/versioning.html
__mapper_args__ = {"version_id_col": version}
# relationships
organization: Mapped[Optional["Organization"]] = relationship("Organization")
agents: Mapped[List["Agent"]] = relationship(
@@ -68,6 +80,17 @@ class Block(OrganizationMixin, SqlalchemyBase):
model_dict["metadata"] = self.metadata_
return Schema.model_validate(model_dict)
@declared_attr
def current_history_entry(cls) -> Mapped[Optional["BlockHistory"]]:
    """Relationship to the BlockHistory row this block currently points at."""
    # Relationship to easily load the specific history entry that is current
    return relationship(
        "BlockHistory",
        primaryjoin=lambda: cls.current_history_entry_id == BlockHistory.id,
        foreign_keys=[cls.current_history_entry_id],
        lazy="joined",  # Typically want current history details readily available
        post_update=True,  # Helps manage potential FK cycles (Block <-> BlockHistory)
    )
@event.listens_for(Block, "after_update") # Changed from 'before_update'
def block_before_update(mapper, connection, target):

View File

@@ -0,0 +1,46 @@
import uuid
from typing import Optional
from sqlalchemy import JSON, BigInteger, ForeignKey, Index, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.enums import ActorType
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
class BlockHistory(OrganizationMixin, SqlalchemyBase):
    """Stores a single historical state of a Block for undo/redo functionality."""

    __tablename__ = "block_history"
    __table_args__ = (
        # PRIMARY lookup index for finding specific history entries & ordering
        Index("ix_block_history_block_id_sequence", "block_id", "sequence_number", unique=True),
    )

    # agent generates its own id
    # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
    # TODO: Some still rely on the Pydantic object to do this
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"block_hist-{uuid.uuid4()}")

    # Snapshot State Fields (Copied from Block)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    label: Mapped[str] = mapped_column(String, nullable=False)
    value: Mapped[str] = mapped_column(Text, nullable=False)
    # Annotation is the Python type (int); BigInteger is the column type
    limit: Mapped[int] = mapped_column(BigInteger, nullable=False)
    metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)

    # Editor info
    # These are not made to be FKs because these may not always exist (e.g. a User be deleted after they made a checkpoint)
    # NOTE(review): column type is String; ActorType is assumed to be a str-valued enum so
    # stored values compare equal to enum members — confirm against letta.orm.enums
    actor_type: Mapped[Optional[ActorType]] = mapped_column(String, nullable=True)
    actor_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)

    # Relationships
    block_id: Mapped[str] = mapped_column(
        String, ForeignKey("block.id", ondelete="CASCADE"), nullable=False  # History deleted if Block is deleted
    )
    sequence_number: Mapped[int] = mapped_column(
        Integer, nullable=False, doc="Monotonically increasing sequence number for the history of a specific block_id, starting from 1."
    )

View File

@@ -22,3 +22,9 @@ class ToolSourceType(str, Enum):
python = "python"
json = "json"
class ActorType(str, Enum):
    """Kinds of actors that can perform an action (e.g. create a block checkpoint)."""

    LETTA_USER = "letta_user"
    LETTA_AGENT = "letta_agent"
    LETTA_SYSTEM = "letta_system"

View File

@@ -370,17 +370,19 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
return []
@handle_db_timeout
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
def create(self, db_session: "Session", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
if actor:
self._set_created_and_updated_by_fields(actor.id)
try:
with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self
db_session.add(self)
if no_commit:
db_session.flush() # no commit, just flush to get PK
else:
db_session.commit()
db_session.refresh(self)
return self
except (DBAPIError, IntegrityError) as e:
self._handle_dbapi_error(e)
@@ -455,18 +457,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
@handle_db_timeout
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
def update(self, db_session: Session, actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
logger.debug(...)
if actor:
self._set_created_and_updated_by_fields(actor.id)
self.set_updated_at()
with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self
# remove the context manager:
db_session.add(self)
if no_commit:
db_session.flush() # no commit, just flush to get PK
else:
db_session.commit()
db_session.refresh(self)
return self
@classmethod
@handle_db_timeout

View File

@@ -1,10 +1,13 @@
import os
from typing import List, Optional
from sqlalchemy import func
from letta.orm.block import Block as BlockModel
from letta.orm.block_history import BlockHistory
from letta.orm.enums import ActorType
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.block import Block
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, Human, Persona
from letta.schemas.user import User as PydanticUser
@@ -21,7 +24,7 @@ class BlockManager:
self.session_maker = db_context
@enforce_types
def create_or_update_block(self, block: Block, actor: PydanticUser) -> PydanticBlock:
def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
"""Create a new block based on the Block schema."""
db_block = self.get_block_by_id(block.id, actor)
if db_block:
@@ -140,3 +143,59 @@ class BlockManager:
agents_pydantic = [agent.to_pydantic() for agent in agents_orm]
return agents_pydantic
# Block History Functions
@enforce_types
def checkpoint_block(
    self,
    block_id: str,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    use_preloaded_block: Optional[BlockModel] = None,  # TODO: Useful for testing concurrency
) -> PydanticBlock:
    """Create a new checkpoint for the given Block.

    Copies the block's current state into BlockHistory and points the block's
    ``current_history_entry_id`` at the new row, relying on SQLAlchemy's
    built-in ``version_id_col`` for optimistic concurrency checks.
    If ``use_preloaded_block`` is given, that in-memory instance is merged
    into the session instead of re-reading the block from the DB.

    Note: We only have a single commit at the end, to avoid weird intermediate
    states, e.g. a BlockHistory row created but the block update failed.

    Args:
        block_id: ID of the block to checkpoint.
        actor: User performing the checkpoint (or on whose behalf it runs).
        agent_id: If set, the checkpoint is attributed to this agent.
        use_preloaded_block: Optional already-loaded ORM Block to reuse.

    Returns:
        The block's new state as a Pydantic model.
    """
    with self.session_maker() as session:
        # 1) Load the block via the ORM (or adopt the caller-supplied instance)
        if use_preloaded_block is not None:
            block = session.merge(use_preloaded_block)
        else:
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

        # 2) Create a new sequence number for BlockHistory (starts at 1)
        current_max_seq = session.query(func.max(BlockHistory.sequence_number)).filter(BlockHistory.block_id == block_id).scalar()
        next_seq = (current_max_seq or 0) + 1

        # 3) Create a snapshot of the block's current state in BlockHistory
        history_entry = BlockHistory(
            organization_id=actor.organization_id,
            block_id=block.id,
            sequence_number=next_seq,
            description=block.description,
            label=block.label,
            value=block.value,
            limit=block.limit,
            metadata_=block.metadata_,
            actor_type=ActorType.LETTA_AGENT if agent_id else ActorType.LETTA_USER,
            actor_id=agent_id if agent_id else actor.id,
        )
        history_entry.create(session, actor=actor, no_commit=True)  # flush only, so the PK is assigned

        # 4) Update the block's pointer to its new history entry
        block.current_history_entry_id = history_entry.id

        # 5) Flush the pointer update (no commit yet); version_id_col bumps block.version
        #    here and raises StaleDataError if another transaction got there first
        block = block.update(db_session=session, actor=actor, no_commit=True)

        # Single commit covers both the history row and the pointer update
        session.commit()

        # Return the block's new state
        return block.to_pydantic()

View File

@@ -3,19 +3,22 @@ import random
import string
import time
from datetime import datetime, timedelta
from typing import List
import pytest
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import StaleDataError
from letta.config import LettaConfig
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_TOOL_EXECUTION_DIR, MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS
from letta.embeddings import embedding_model
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.functions.mcp_client.types import MCPTool
from letta.orm import Base
from letta.orm.enums import JobType, ToolType
from letta.orm import Base, Block
from letta.orm.block_history import BlockHistory
from letta.orm.enums import ActorType, JobType, ToolType
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
@@ -46,6 +49,7 @@ from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.schemas.tool_rule import InitToolRule
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.db import db_context
from letta.server.server import SyncServer
from letta.services.block_manager import BlockManager
from letta.services.organization_manager import OrganizationManager
@@ -67,11 +71,12 @@ USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
@pytest.fixture(autouse=True)
def clear_tables():
from letta.server.db import db_context
def _clear_tables():
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
# If this is the block_history table, skip it
if table.name == "block_history":
continue
session.execute(table.delete()) # Truncate table
session.commit()
@@ -2366,7 +2371,7 @@ def test_message_listing_text_search(server: SyncServer, hello_world_message_fix
# ======================================================================================================================
# Block Manager Tests
# Block Manager Tests - Basic
# ======================================================================================================================
@@ -2575,6 +2580,153 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
assert charles_agent.id in agent_state_ids
# ======================================================================================================================
# Block Manager Tests - History (Undo/Redo/Checkpoint)
# ======================================================================================================================
def test_checkpoint_creates_history(server: SyncServer, default_user):
    """
    Ensures that calling checkpoint_block creates a BlockHistory row and updates
    the block's current_history_entry_id appropriately.
    """
    block_manager = BlockManager()

    # Create a block
    initial_value = "Initial block content"
    created_block = block_manager.create_or_update_block(PydanticBlock(label="test_checkpoint", value=initial_value), actor=default_user)

    # Act: checkpoint it
    block_manager.checkpoint_block(block_id=created_block.id, actor=default_user)

    with db_context() as session:
        # Get BlockHistory entries for this block
        history_entries: List[BlockHistory] = session.query(BlockHistory).filter(BlockHistory.block_id == created_block.id).all()
        assert len(history_entries) == 1, "Exactly one history entry should be created"
        hist = history_entries[0]

        # Fetch ORM block for internal checks
        db_block = session.get(Block, created_block.id)

        # First checkpoint starts the sequence at 1 and snapshots the current value
        assert hist.sequence_number == 1
        assert hist.value == initial_value
        # No agent_id was passed, so the checkpoint is attributed to the user
        assert hist.actor_type == ActorType.LETTA_USER
        assert hist.actor_id == default_user.id
        # The block now points at its newly created history entry
        assert db_block.current_history_entry_id == hist.id
def test_multiple_checkpoints(server: SyncServer, default_user):
    """Two checkpoints yield two ordered history rows; the block points at the newest."""
    mgr = BlockManager()

    # Start from a block whose value is "v1"
    blk = mgr.create_or_update_block(PydanticBlock(label="test_multi_checkpoint", value="v1"), actor=default_user)

    # Checkpoint #1 snapshots "v1"
    mgr.checkpoint_block(block_id=blk.id, actor=default_user)

    # Edit the block to "v2", then checkpoint #2
    edited = PydanticBlock(**blk.dict())
    edited.value = "v2"
    mgr.create_or_update_block(edited, actor=default_user)
    mgr.checkpoint_block(block_id=blk.id, actor=default_user)

    with db_context() as session:
        entries = (
            session.query(BlockHistory).filter(BlockHistory.block_id == blk.id).order_by(BlockHistory.sequence_number.asc()).all()
        )
        assert len(entries) == 2, "Should have two history entries"

        # Ordered snapshots: (seq=1, "v1") then (seq=2, "v2")
        assert entries[0].sequence_number == 1
        assert entries[0].value == "v1"
        assert entries[1].sequence_number == 2
        assert entries[1].value == "v2"

        # The block's pointer tracks the most recent entry
        db_block = session.get(Block, blk.id)
        assert db_block.current_history_entry_id == entries[1].id
def test_checkpoint_with_agent_id(server: SyncServer, default_user, sarah_agent):
    """
    Passing agent_id to checkpoint_block attributes the history entry to the
    agent: actor_type=LETTA_AGENT, actor_id=<agent.id>.
    """
    mgr = BlockManager()

    # Set up a block to checkpoint
    blk = mgr.create_or_update_block(PydanticBlock(label="test_agent_checkpoint", value="Agent content"), actor=default_user)

    # Checkpoint on behalf of an agent rather than the user
    mgr.checkpoint_block(block_id=blk.id, actor=default_user, agent_id=sarah_agent.id)

    # The single history row should carry the agent attribution
    with db_context() as session:
        entry = session.query(BlockHistory).filter(BlockHistory.block_id == blk.id).one()
        assert entry.actor_type == ActorType.LETTA_AGENT
        assert entry.actor_id == sarah_agent.id
def test_checkpoint_with_no_state_change(server: SyncServer, default_user):
    """
    Calling checkpoint_block twice without any edits still creates a new
    history entry per call (current policy: every checkpoint snapshots, even
    if the block state is unchanged), so two calls yield two rows.
    """
    block_manager = BlockManager()

    # Create block
    block = block_manager.create_or_update_block(PydanticBlock(label="test_no_change", value="original"), actor=default_user)

    # 1) checkpoint
    block_manager.checkpoint_block(block_id=block.id, actor=default_user)
    # 2) checkpoint again (no changes)
    block_manager.checkpoint_block(block_id=block.id, actor=default_user)

    with db_context() as session:
        all_hist = session.query(BlockHistory).filter(BlockHistory.block_id == block.id).all()
        # Both checkpoints produced a row, even though the value never changed
        assert len(all_hist) == 2
def test_checkpoint_concurrency_stale(server: SyncServer, default_user):
    """
    Two sessions load the same block at version=1; after the first checkpoints
    (bumping the optimistic-lock version to 2), the second's checkpoint must
    fail with StaleDataError via SQLAlchemy's version_id_col check.
    """
    block_manager = BlockManager()

    # create block
    block = block_manager.create_or_update_block(PydanticBlock(label="test_stale_checkpoint", value="hello"), actor=default_user)

    # session1 loads
    with db_context() as s1:
        block_s1 = s1.get(Block, block.id)  # version=1

    # session2 loads
    with db_context() as s2:
        block_s2 = s2.get(Block, block.id)  # also version=1

    # session1 checkpoint => version=2
    with db_context() as s1:
        # re-attach the now-detached instance to the fresh session
        block_s1 = s1.merge(block_s1)
        block_manager.checkpoint_block(
            block_id=block_s1.id,
            actor=default_user,
            use_preloaded_block=block_s1,  # let manager use the object in memory
        )
        # commits inside checkpoint_block => version goes to 2

    # session2 tries to checkpoint => sees old version=1 => stale error
    with pytest.raises(StaleDataError):
        with db_context() as s2:
            block_s2 = s2.merge(block_s2)
            block_manager.checkpoint_block(
                block_id=block_s2.id,
                actor=default_user,
                use_preloaded_block=block_s2,
            )
# ======================================================================================================================
# Identity Manager Tests
# ======================================================================================================================