feat: Add monotonic ids on messages (#1522)
This commit is contained in:
@@ -0,0 +1,144 @@
|
||||
"""Add monotonically increasing IDs to messages table
|
||||
|
||||
Revision ID: e991d2e3b428
|
||||
Revises: 74f2ede29317
|
||||
Create Date: 2025-04-01 17:02:59.820272
|
||||
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "e991d2e3b428"
|
||||
down_revision: Union[str, None] = "74f2ede29317"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
# --- Configuration ---
|
||||
TABLE_NAME = "messages"
|
||||
COLUMN_NAME = "sequence_id"
|
||||
SEQUENCE_NAME = "message_seq_id"
|
||||
INDEX_NAME = "ix_messages_agent_sequence"
|
||||
UNIQUE_CONSTRAINT_NAME = f"uq_{TABLE_NAME}_{COLUMN_NAME}"
|
||||
|
||||
# Columns to determine the order for back-filling existing data
|
||||
ORDERING_COLUMNS = ["created_at", "id"]
|
||||
|
||||
|
||||
def print_flush(message):
    """Print *message* and flush stdout immediately.

    Migration output is often piped/buffered by the process driving Alembic;
    flushing after every line keeps per-step progress visible in real time.

    Args:
        message: Any printable object; forwarded to ``print``.
    """
    # print(..., flush=True) is the idiomatic one-call form of
    # print(message) followed by sys.stdout.flush().
    print(message, flush=True)
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Adds sequence_id, backfills data, adds constraints and index.

    Eight strictly ordered steps (PostgreSQL-specific — uses raw SQL for the
    sequence creation, backfill, and setval work):

      1. Add a nullable BIGINT ``sequence_id`` column.
      2. Create the backing sequence.
      3. Backfill ``sequence_id`` for existing rows ordered by ORDERING_COLUMNS.
      4. Advance the sequence past the backfilled maximum.
      5. Tighten the column to NOT NULL.
      6. Add a UNIQUE constraint on ``sequence_id``.
      7. Install the sequence as the column's server-side default.
      8. Index (agent_id, sequence_id) for per-agent ordered listing.

    NOTE(review): the raw SQL interpolates TABLE_NAME/COLUMN_NAME/SEQUENCE_NAME
    directly into the statements — safe only because they are trusted
    module-level constants, not user input.

    NOTE(review): the sequence is NOT declared ``OWNED BY`` the column, so
    dropping the column alone would not drop it; ``downgrade`` drops it
    explicitly instead — presumably intentional, verify against downgrade.
    """
    print_flush(f"\n--- Starting upgrade for revision {revision} ---")

    # Step 1: Add the sequence_id column to the table, initially allowing NULL values.
    # This allows us to add and backfill data without immediately enforcing NOT NULL.
    print_flush(f"Step 1: Adding nullable column '{COLUMN_NAME}' to table '{TABLE_NAME}'...")
    op.add_column(TABLE_NAME, sa.Column(COLUMN_NAME, sa.BigInteger(), nullable=True))

    # Step 2: Create a new PostgreSQL sequence.
    # This sequence will later be used as the server-side default for generating new sequence_id values.
    print_flush(f"Step 2: Creating sequence '{SEQUENCE_NAME}'...")
    op.execute(f"CREATE SEQUENCE {SEQUENCE_NAME} START 1;")

    # Step 3: Backfill the sequence_id for existing rows based on a defined ordering.
    # The SQL query does the following:
    #   - Uses a Common Table Expression named 'numbered_rows' to compute a row number for each row.
    #   - The ROW_NUMBER() window function assigns a sequential number (rn) to each row, ordered by the columns specified
    #     in ORDERING_COLUMNS (e.g., created_at, id) in ascending order.
    #   - The UPDATE statement then sets each row's sequence_id to its corresponding row number (rn)
    #     by joining the original table with the CTE on the id column.
    print_flush(f"Step 3: Backfilling '{COLUMN_NAME}' based on order: {', '.join(ORDERING_COLUMNS)}...")
    print_flush("    (This may take a while on large tables)")
    try:
        op.execute(
            f"""
            WITH numbered_rows AS (
                SELECT
                    id,
                    ROW_NUMBER() OVER (ORDER BY {', '.join(ORDERING_COLUMNS)} ASC) as rn
                FROM {TABLE_NAME}
            )
            UPDATE {TABLE_NAME}
            SET {COLUMN_NAME} = numbered_rows.rn
            FROM numbered_rows
            WHERE {TABLE_NAME}.id = numbered_rows.id;
            """
        )
        print_flush("    Backfill successful.")
    except Exception as e:
        # Surface the failure loudly before re-raising so Alembic aborts the migration.
        print_flush(f"!!! ERROR during backfill: {e}")
        print_flush("!!! Migration failed. Manual intervention might be needed.")
        raise

    # Step 4: Set the sequence's next value to be one more than the current maximum sequence_id.
    # The query works as follows:
    #   - It calculates the maximum value in the sequence_id column using MAX({COLUMN_NAME}).
    #   - COALESCE is used to default to 0 if there are no rows (i.e., the table is empty).
    #   - It then adds 1 to ensure that the next call to nextval() returns a number higher than any existing value.
    #   - The 'false' argument tells PostgreSQL that the next nextval() should return the value as-is, without pre-incrementing.
    print_flush(f"Step 4: Setting sequence '{SEQUENCE_NAME}' to next value after backfill...")
    op.execute(
        f"""
        SELECT setval('{SEQUENCE_NAME}', COALESCE(MAX({COLUMN_NAME}), 0) + 1, false)
        FROM {TABLE_NAME};
        """
    )

    # Step 5: Now that every row has a sequence_id, alter the column to be NOT NULL.
    # This enforces that all rows must have a valid sequence_id.
    print_flush(f"Step 5: Altering column '{COLUMN_NAME}' to NOT NULL...")
    op.alter_column(TABLE_NAME, COLUMN_NAME, existing_type=sa.BigInteger(), nullable=False)

    # Step 6: Add a UNIQUE constraint on sequence_id to ensure its values remain distinct.
    # This mirrors the model definition where sequence_id is defined as unique.
    print_flush(f"Step 6: Creating unique constraint '{UNIQUE_CONSTRAINT_NAME}' on '{COLUMN_NAME}'...")
    op.create_unique_constraint(UNIQUE_CONSTRAINT_NAME, TABLE_NAME, [COLUMN_NAME])

    # Step 7: Set the server-side default for sequence_id so that future inserts automatically use the sequence.
    # The server default calls nextval() on the sequence, and the "::regclass" cast helps PostgreSQL resolve the sequence name correctly.
    print_flush(f"Step 7: Setting server default for '{COLUMN_NAME}' to use sequence '{SEQUENCE_NAME}'...")
    op.alter_column(TABLE_NAME, COLUMN_NAME, existing_type=sa.BigInteger(), server_default=sa.text(f"nextval('{SEQUENCE_NAME}'::regclass)"))

    # Step 8: Create an index on (agent_id, sequence_id) to improve performance of queries filtering on these columns.
    print_flush(f"Step 8: Creating index '{INDEX_NAME}' on (agent_id, {COLUMN_NAME})...")
    op.create_index(INDEX_NAME, TABLE_NAME, ["agent_id", COLUMN_NAME], unique=False)

    print_flush(f"--- Upgrade for revision {revision} complete ---")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Reverses the changes made in the upgrade function.

    Drops artifacts in the reverse order of their creation so no object is
    removed while something else still references it:

      1. Drop the (agent_id, sequence_id) index.
      2. Remove the server-side nextval() default from the column.
      3. Drop the unique constraint on sequence_id.
      4. Drop the sequence_id column (NOT NULL goes with it).
      5. Drop the backing sequence.
    """
    print_flush(f"\n--- Starting downgrade from revision {revision} ---")

    # 1. Drop the index
    print_flush(f"Step 1: Dropping index '{INDEX_NAME}'...")
    op.drop_index(INDEX_NAME, table_name=TABLE_NAME)

    # 2. Remove the server-side default
    # Must happen before dropping the sequence, otherwise the column default
    # would reference a missing object.
    print_flush(f"Step 2: Removing server default from '{COLUMN_NAME}'...")
    op.alter_column(TABLE_NAME, COLUMN_NAME, existing_type=sa.BigInteger(), server_default=None)

    # 3. Drop the unique constraint (using the explicit name)
    print_flush(f"Step 3: Dropping unique constraint '{UNIQUE_CONSTRAINT_NAME}'...")
    op.drop_constraint(UNIQUE_CONSTRAINT_NAME, TABLE_NAME, type_="unique")

    # 4. Drop the column (this implicitly removes the NOT NULL constraint)
    print_flush(f"Step 4: Dropping column '{COLUMN_NAME}'...")
    op.drop_column(TABLE_NAME, COLUMN_NAME)

    # 5. Drop the sequence
    # The sequence is not OWNED BY the column, so dropping the column does not
    # remove it automatically; it must be dropped explicitly here.
    print_flush(f"Step 5: Dropping sequence '{SEQUENCE_NAME}'...")
    op.execute(f"DROP SEQUENCE IF EXISTS {SEQUENCE_NAME};")  # Use IF EXISTS for safety

    print_flush(f"--- Downgrade from revision {revision} complete ---")
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from sqlalchemy import ForeignKey, Index
|
||||
from sqlalchemy import BigInteger, ForeignKey, Index, Sequence
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||
@@ -20,6 +20,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
__table_args__ = (
|
||||
Index("ix_messages_agent_created_at", "agent_id", "created_at"),
|
||||
Index("ix_messages_created_at", "created_at", "id"),
|
||||
Index("ix_messages_agent_sequence", "agent_id", "sequence_id"),
|
||||
)
|
||||
__pydantic_model__ = PydanticMessage
|
||||
|
||||
@@ -40,6 +41,11 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
)
|
||||
group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in")
|
||||
|
||||
# Monotonically increasing sequence for efficient/correct listing
|
||||
sequence_id: Mapped[int] = mapped_column(
|
||||
BigInteger, Sequence("message_seq_id"), unique=True, nullable=False, doc="Global monotonically increasing ID"
|
||||
)
|
||||
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from sqlalchemy import and_, exists, func, or_, select, text
|
||||
from sqlalchemy import exists, func, select, text
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
@@ -270,19 +270,20 @@ class MessageManager:
|
||||
Most performant query to list messages for an agent by directly querying the Message table.
|
||||
|
||||
This function filters by the agent_id (leveraging the index on messages.agent_id)
|
||||
and applies efficient pagination using (created_at, id) as the cursor.
|
||||
and applies pagination using sequence_id as the cursor.
|
||||
If query_text is provided, it will filter messages whose text content partially matches the query.
|
||||
If role is provided, it will filter messages by the specified role.
|
||||
|
||||
Args:
|
||||
agent_id: The ID of the agent whose messages are queried.
|
||||
actor: The user performing the action (used for permission checks).
|
||||
after: A message ID; if provided, only messages *after* this message (per sort order) are returned.
|
||||
before: A message ID; if provided, only messages *before* this message are returned.
|
||||
after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned.
|
||||
before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned.
|
||||
query_text: Optional string to partially match the message text content.
|
||||
roles: Optional MessageRole to filter messages by role.
|
||||
limit: Maximum number of messages to return.
|
||||
ascending: If True, sort by (created_at, id) ascending; if False, sort descending.
|
||||
ascending: If True, sort by sequence_id ascending; if False, sort descending.
|
||||
group_id: Optional group ID to filter messages by group_id.
|
||||
|
||||
Returns:
|
||||
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
|
||||
@@ -290,6 +291,7 @@ class MessageManager:
|
||||
Raises:
|
||||
NoResultFound: If the provided after/before message IDs do not exist.
|
||||
"""
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Permission check: raise if the agent doesn't exist or actor is not allowed.
|
||||
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
@@ -301,7 +303,7 @@ class MessageManager:
|
||||
if group_id:
|
||||
query = query.filter(MessageModel.group_id == group_id)
|
||||
|
||||
# If query_text is provided, filter messages using subquery.
|
||||
# If query_text is provided, filter messages using subquery + json_array_elements.
|
||||
if query_text:
|
||||
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
|
||||
query = query.filter(
|
||||
@@ -313,48 +315,32 @@ class MessageManager:
|
||||
)
|
||||
)
|
||||
|
||||
# If role is provided, filter messages by role.
|
||||
# If role(s) are provided, filter messages by those roles.
|
||||
if roles:
|
||||
role_values = [r.value for r in roles]
|
||||
query = query.filter(MessageModel.role.in_(role_values))
|
||||
|
||||
# Apply 'after' pagination if specified.
|
||||
if after:
|
||||
after_ref = session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == after).limit(1).one_or_none()
|
||||
after_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == after).one_or_none()
|
||||
if not after_ref:
|
||||
raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
|
||||
query = query.filter(
|
||||
or_(
|
||||
MessageModel.created_at > after_ref.created_at,
|
||||
and_(
|
||||
MessageModel.created_at == after_ref.created_at,
|
||||
MessageModel.id > after_ref.id,
|
||||
),
|
||||
)
|
||||
)
|
||||
# Filter out any messages with a sequence_id <= after_ref.sequence_id
|
||||
query = query.filter(MessageModel.sequence_id > after_ref.sequence_id)
|
||||
|
||||
# Apply 'before' pagination if specified.
|
||||
if before:
|
||||
before_ref = (
|
||||
session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == before).limit(1).one_or_none()
|
||||
)
|
||||
before_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == before).one_or_none()
|
||||
if not before_ref:
|
||||
raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
|
||||
query = query.filter(
|
||||
or_(
|
||||
MessageModel.created_at < before_ref.created_at,
|
||||
and_(
|
||||
MessageModel.created_at == before_ref.created_at,
|
||||
MessageModel.id < before_ref.id,
|
||||
),
|
||||
)
|
||||
)
|
||||
# Filter out any messages with a sequence_id >= before_ref.sequence_id
|
||||
query = query.filter(MessageModel.sequence_id < before_ref.sequence_id)
|
||||
|
||||
# Apply ordering based on the ascending flag.
|
||||
if ascending:
|
||||
query = query.order_by(MessageModel.created_at.asc(), MessageModel.id.asc())
|
||||
query = query.order_by(MessageModel.sequence_id.asc())
|
||||
else:
|
||||
query = query.order_by(MessageModel.created_at.desc(), MessageModel.id.desc())
|
||||
query = query.order_by(MessageModel.sequence_id.desc())
|
||||
|
||||
# Limit the number of results.
|
||||
query = query.limit(limit)
|
||||
|
||||
@@ -42,7 +42,7 @@ def server_url() -> str:
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture
|
||||
def client(server_url: str) -> Letta:
|
||||
"""
|
||||
Creates and returns a synchronous Letta REST client for testing.
|
||||
@@ -51,7 +51,7 @@ def client(server_url: str) -> Letta:
|
||||
yield client_instance
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture
|
||||
def async_client(server_url: str) -> AsyncLetta:
|
||||
"""
|
||||
Creates and returns an asynchronous Letta REST client for testing.
|
||||
@@ -60,7 +60,7 @@ def async_client(server_url: str) -> AsyncLetta:
|
||||
yield async_client_instance
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture
|
||||
def roll_dice_tool(client: Letta) -> Tool:
|
||||
"""
|
||||
Registers a simple roll dice tool with the provided client.
|
||||
@@ -82,7 +82,7 @@ def roll_dice_tool(client: Letta) -> Tool:
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture
|
||||
def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
|
||||
"""
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
|
||||
Reference in New Issue
Block a user