feat: Add monotonic ids on messages (#1522)

This commit is contained in:
Matthew Zhou
2025-04-01 18:23:34 -07:00
committed by GitHub
parent 227b76fe0e
commit af97837c99
4 changed files with 172 additions and 36 deletions

View File

@@ -0,0 +1,144 @@
"""Add monotonically increasing IDs to messages table
Revision ID: e991d2e3b428
Revises: 74f2ede29317
Create Date: 2025-04-01 17:02:59.820272
"""
import sys
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "e991d2e3b428"
down_revision: Union[str, None] = "74f2ede29317"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
# --- Configuration ---
# All object names live here as module-level constants so that upgrade() and
# downgrade() always create/drop exactly the same database objects.
TABLE_NAME: str = "messages"
COLUMN_NAME: str = "sequence_id"
SEQUENCE_NAME: str = "message_seq_id"
INDEX_NAME: str = "ix_messages_agent_sequence"
# Derived name keeps the constraint consistent with the table/column pair.
UNIQUE_CONSTRAINT_NAME: str = f"uq_{TABLE_NAME}_{COLUMN_NAME}"
# Columns to determine the order for back-filling existing data
# (ties on created_at are broken deterministically by id).
ORDERING_COLUMNS: list[str] = ["created_at", "id"]
def print_flush(message):
    """Write *message* to stdout and force the buffer to flush immediately.

    Migrations can run for a long time; flushing after every line keeps the
    progress output visible in real time instead of sitting in the buffer.
    """
    print(message, flush=True)
def upgrade() -> None:
    """Adds sequence_id, backfills data, adds constraints and index.

    The steps are strictly ordered — each depends on the previous one:
      1. add a nullable BigInteger ``sequence_id`` column to ``messages``;
      2. create the PostgreSQL sequence that will feed new rows;
      3. backfill existing rows with ROW_NUMBER() ordered by
         ``ORDERING_COLUMNS`` (created_at, then id as tiebreaker);
      4. advance the sequence past the backfilled maximum via ``setval``;
      5. tighten the column to NOT NULL (safe now that every row has a value);
      6. add a UNIQUE constraint mirroring the ORM model definition;
      7. attach ``nextval(...)`` as the server-side default for future inserts;
      8. create the (agent_id, sequence_id) index used by paginated listing.

    NOTE(review): the raw SQL is PostgreSQL-specific (sequences, setval,
    ``::regclass``) — this migration assumes a PostgreSQL backend.
    """
    print_flush(f"\n--- Starting upgrade for revision {revision} ---")

    # Step 1: Add the sequence_id column to the table, initially allowing NULL values.
    # This allows us to add and backfill data without immediately enforcing NOT NULL.
    print_flush(f"Step 1: Adding nullable column '{COLUMN_NAME}' to table '{TABLE_NAME}'...")
    op.add_column(TABLE_NAME, sa.Column(COLUMN_NAME, sa.BigInteger(), nullable=True))

    # Step 2: Create a new PostgreSQL sequence.
    # This sequence will later be used as the server-side default for generating new sequence_id values.
    print_flush(f"Step 2: Creating sequence '{SEQUENCE_NAME}'...")
    op.execute(f"CREATE SEQUENCE {SEQUENCE_NAME} START 1;")

    # Step 3: Backfill the sequence_id for existing rows based on a defined ordering.
    # The SQL query does the following:
    #   - Uses a Common Table Expression named 'numbered_rows' to compute a row number for each row.
    #   - The ROW_NUMBER() window function assigns a sequential number (rn) to each row, ordered by the columns specified
    #     in ORDERING_COLUMNS (e.g., created_at, id) in ascending order.
    #   - The UPDATE statement then sets each row's sequence_id to its corresponding row number (rn)
    #     by joining the original table with the CTE on the id column.
    print_flush(f"Step 3: Backfilling '{COLUMN_NAME}' based on order: {', '.join(ORDERING_COLUMNS)}...")
    print_flush("          (This may take a while on large tables)")
    try:
        op.execute(
            f"""
            WITH numbered_rows AS (
                SELECT
                    id,
                    ROW_NUMBER() OVER (ORDER BY {', '.join(ORDERING_COLUMNS)} ASC) as rn
                FROM {TABLE_NAME}
            )
            UPDATE {TABLE_NAME}
            SET {COLUMN_NAME} = numbered_rows.rn
            FROM numbered_rows
            WHERE {TABLE_NAME}.id = numbered_rows.id;
            """
        )
        print_flush("          Backfill successful.")
    except Exception as e:
        # Re-raise so Alembic marks the migration as failed; on PostgreSQL the
        # transactional DDL above rolls back together with this statement.
        print_flush(f"!!! ERROR during backfill: {e}")
        print_flush("!!! Migration failed. Manual intervention might be needed.")
        raise

    # Step 4: Set the sequence's next value to be one more than the current maximum sequence_id.
    # The query works as follows:
    #   - It calculates the maximum value in the sequence_id column using MAX({COLUMN_NAME}).
    #   - COALESCE is used to default to 0 if there are no rows (i.e., the table is empty).
    #   - It then adds 1 to ensure that the next call to nextval() returns a number higher than any existing value.
    #   - The 'false' argument tells PostgreSQL that the next nextval() should return the value as-is, without pre-incrementing.
    print_flush(f"Step 4: Setting sequence '{SEQUENCE_NAME}' to next value after backfill...")
    op.execute(
        f"""
        SELECT setval('{SEQUENCE_NAME}', COALESCE(MAX({COLUMN_NAME}), 0) + 1, false)
        FROM {TABLE_NAME};
        """
    )

    # Step 5: Now that every row has a sequence_id, alter the column to be NOT NULL.
    # This enforces that all rows must have a valid sequence_id.
    print_flush(f"Step 5: Altering column '{COLUMN_NAME}' to NOT NULL...")
    op.alter_column(TABLE_NAME, COLUMN_NAME, existing_type=sa.BigInteger(), nullable=False)

    # Step 6: Add a UNIQUE constraint on sequence_id to ensure its values remain distinct.
    # This mirrors the model definition where sequence_id is defined as unique.
    print_flush(f"Step 6: Creating unique constraint '{UNIQUE_CONSTRAINT_NAME}' on '{COLUMN_NAME}'...")
    op.create_unique_constraint(UNIQUE_CONSTRAINT_NAME, TABLE_NAME, [COLUMN_NAME])

    # Step 7: Set the server-side default for sequence_id so that future inserts automatically use the sequence.
    # The server default calls nextval() on the sequence, and the "::regclass" cast helps PostgreSQL resolve the sequence name correctly.
    print_flush(f"Step 7: Setting server default for '{COLUMN_NAME}' to use sequence '{SEQUENCE_NAME}'...")
    op.alter_column(TABLE_NAME, COLUMN_NAME, existing_type=sa.BigInteger(), server_default=sa.text(f"nextval('{SEQUENCE_NAME}'::regclass)"))

    # Step 8: Create an index on (agent_id, sequence_id) to improve performance of queries filtering on these columns.
    print_flush(f"Step 8: Creating index '{INDEX_NAME}' on (agent_id, {COLUMN_NAME})...")
    op.create_index(INDEX_NAME, TABLE_NAME, ["agent_id", COLUMN_NAME], unique=False)

    print_flush(f"--- Upgrade for revision {revision} complete ---")
def downgrade() -> None:
    """Reverses the changes made in the upgrade function.

    Objects are dropped in the reverse order of their creation so that no
    drop fails on a dependency: index, server default, unique constraint,
    the column itself (which also discards NOT NULL), and finally the
    sequence.
    """
    print_flush(f"\n--- Starting downgrade from revision {revision} ---")

    # 1. Drop the index
    print_flush(f"Step 1: Dropping index '{INDEX_NAME}'...")
    op.drop_index(INDEX_NAME, table_name=TABLE_NAME)

    # 2. Remove the server-side default
    print_flush(f"Step 2: Removing server default from '{COLUMN_NAME}'...")
    op.alter_column(TABLE_NAME, COLUMN_NAME, existing_type=sa.BigInteger(), server_default=None)

    # 3. Drop the unique constraint (using the explicit name)
    print_flush(f"Step 3: Dropping unique constraint '{UNIQUE_CONSTRAINT_NAME}'...")
    op.drop_constraint(UNIQUE_CONSTRAINT_NAME, TABLE_NAME, type_="unique")

    # 4. Drop the column (this implicitly removes the NOT NULL constraint)
    print_flush(f"Step 4: Dropping column '{COLUMN_NAME}'...")
    op.drop_column(TABLE_NAME, COLUMN_NAME)

    # 5. Drop the sequence
    print_flush(f"Step 5: Dropping sequence '{SEQUENCE_NAME}'...")
    op.execute(f"DROP SEQUENCE IF EXISTS {SEQUENCE_NAME};")  # Use IF EXISTS for safety

    print_flush(f"--- Downgrade from revision {revision} complete ---")

View File

@@ -1,7 +1,7 @@
from typing import List, Optional
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from sqlalchemy import ForeignKey, Index
from sqlalchemy import BigInteger, ForeignKey, Index, Sequence
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
@@ -20,6 +20,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
__table_args__ = (
Index("ix_messages_agent_created_at", "agent_id", "created_at"),
Index("ix_messages_created_at", "created_at", "id"),
Index("ix_messages_agent_sequence", "agent_id", "sequence_id"),
)
__pydantic_model__ = PydanticMessage
@@ -40,6 +41,11 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
)
group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in")
# Monotonically increasing sequence for efficient/correct listing
sequence_id: Mapped[int] = mapped_column(
BigInteger, Sequence("message_seq_id"), unique=True, nullable=False, doc="Global monotonically increasing ID"
)
# Relationships
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")

View File

@@ -1,7 +1,7 @@
import json
from typing import List, Optional, Sequence
from sqlalchemy import and_, exists, func, or_, select, text
from sqlalchemy import exists, func, select, text
from letta.log import get_logger
from letta.orm.agent import Agent as AgentModel
@@ -270,19 +270,20 @@ class MessageManager:
Most performant query to list messages for an agent by directly querying the Message table.
This function filters by the agent_id (leveraging the index on messages.agent_id)
and applies efficient pagination using (created_at, id) as the cursor.
and applies pagination using sequence_id as the cursor.
If query_text is provided, it will filter messages whose text content partially matches the query.
If role is provided, it will filter messages by the specified role.
Args:
agent_id: The ID of the agent whose messages are queried.
actor: The user performing the action (used for permission checks).
after: A message ID; if provided, only messages *after* this message (per sort order) are returned.
before: A message ID; if provided, only messages *before* this message are returned.
after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned.
before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned.
query_text: Optional string to partially match the message text content.
roles: Optional MessageRole to filter messages by role.
limit: Maximum number of messages to return.
ascending: If True, sort by (created_at, id) ascending; if False, sort descending.
ascending: If True, sort by sequence_id ascending; if False, sort descending.
group_id: Optional group ID to filter messages by group_id.
Returns:
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
@@ -290,6 +291,7 @@ class MessageManager:
Raises:
NoResultFound: If the provided after/before message IDs do not exist.
"""
with self.session_maker() as session:
# Permission check: raise if the agent doesn't exist or actor is not allowed.
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@@ -301,7 +303,7 @@ class MessageManager:
if group_id:
query = query.filter(MessageModel.group_id == group_id)
# If query_text is provided, filter messages using subquery.
# If query_text is provided, filter messages using subquery + json_array_elements.
if query_text:
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
query = query.filter(
@@ -313,48 +315,32 @@ class MessageManager:
)
)
# If role is provided, filter messages by role.
# If role(s) are provided, filter messages by those roles.
if roles:
role_values = [r.value for r in roles]
query = query.filter(MessageModel.role.in_(role_values))
# Apply 'after' pagination if specified.
if after:
after_ref = session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == after).limit(1).one_or_none()
after_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == after).one_or_none()
if not after_ref:
raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
query = query.filter(
or_(
MessageModel.created_at > after_ref.created_at,
and_(
MessageModel.created_at == after_ref.created_at,
MessageModel.id > after_ref.id,
),
)
)
# Filter out any messages with a sequence_id <= after_ref.sequence_id
query = query.filter(MessageModel.sequence_id > after_ref.sequence_id)
# Apply 'before' pagination if specified.
if before:
before_ref = (
session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == before).limit(1).one_or_none()
)
before_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == before).one_or_none()
if not before_ref:
raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
query = query.filter(
or_(
MessageModel.created_at < before_ref.created_at,
and_(
MessageModel.created_at == before_ref.created_at,
MessageModel.id < before_ref.id,
),
)
)
# Filter out any messages with a sequence_id >= before_ref.sequence_id
query = query.filter(MessageModel.sequence_id < before_ref.sequence_id)
# Apply ordering based on the ascending flag.
if ascending:
query = query.order_by(MessageModel.created_at.asc(), MessageModel.id.asc())
query = query.order_by(MessageModel.sequence_id.asc())
else:
query = query.order_by(MessageModel.created_at.desc(), MessageModel.id.desc())
query = query.order_by(MessageModel.sequence_id.desc())
# Limit the number of results.
query = query.limit(limit)

View File

@@ -42,7 +42,7 @@ def server_url() -> str:
return url
@pytest.fixture(scope="module")
@pytest.fixture
def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
@@ -51,7 +51,7 @@ def client(server_url: str) -> Letta:
yield client_instance
@pytest.fixture(scope="module")
@pytest.fixture
def async_client(server_url: str) -> AsyncLetta:
"""
Creates and returns an asynchronous Letta REST client for testing.
@@ -60,7 +60,7 @@ def async_client(server_url: str) -> AsyncLetta:
yield async_client_instance
@pytest.fixture(scope="module")
@pytest.fixture
def roll_dice_tool(client: Letta) -> Tool:
"""
Registers a simple roll dice tool with the provided client.
@@ -82,7 +82,7 @@ def roll_dice_tool(client: Letta) -> Tool:
yield tool
@pytest.fixture(scope="module")
@pytest.fixture
def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.