diff --git a/alembic/versions/e991d2e3b428_add_monotonically_increasing_ids_to_.py b/alembic/versions/e991d2e3b428_add_monotonically_increasing_ids_to_.py
new file mode 100644
index 00000000..3f028c71
--- /dev/null
+++ b/alembic/versions/e991d2e3b428_add_monotonically_increasing_ids_to_.py
@@ -0,0 +1,144 @@
+"""Add monotonically increasing IDs to messages table
+
+Revision ID: e991d2e3b428
+Revises: 74f2ede29317
+Create Date: 2025-04-01 17:02:59.820272
+
+"""
+
+import sys
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "e991d2e3b428"
+down_revision: Union[str, None] = "74f2ede29317"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+# --- Configuration (table, column, sequence, index and constraint names shared by upgrade() and downgrade()) ---
+TABLE_NAME = "messages"
+COLUMN_NAME = "sequence_id"
+SEQUENCE_NAME = "message_seq_id"
+INDEX_NAME = "ix_messages_agent_sequence"
+UNIQUE_CONSTRAINT_NAME = f"uq_{TABLE_NAME}_{COLUMN_NAME}"
+
+# Columns that determine the order in which existing rows are numbered during the backfill (Step 3 of upgrade())
+ORDERING_COLUMNS = ["created_at", "id"]
+
+
+def print_flush(message):
+    """Print `message`, then flush stdout so the text is emitted immediately instead of staying buffered."""
+    print(message)
+    sys.stdout.flush()
+
+
+def upgrade() -> None:
+    """Add sequence_id to messages: nullable column, sequence, backfill, NOT NULL, unique constraint, server default, index."""
+    print_flush(f"\n--- Starting upgrade for revision {revision} ---")
+
+    # Step 1: Add the sequence_id column to the table, initially allowing NULL values.
+    # This allows us to add and backfill data without immediately enforcing NOT NULL.
+    print_flush(f"Step 1: Adding nullable column '{COLUMN_NAME}' to table '{TABLE_NAME}'...")
+    op.add_column(TABLE_NAME, sa.Column(COLUMN_NAME, sa.BigInteger(), nullable=True))
+
+    # Step 2: Create a new PostgreSQL sequence.
+    # This sequence will later be used as the server-side default for generating new sequence_id values.
+    print_flush(f"Step 2: Creating sequence '{SEQUENCE_NAME}'...")
+    op.execute(f"CREATE SEQUENCE {SEQUENCE_NAME} START 1;")
+
+    # Step 3: Backfill the sequence_id for existing rows based on a defined ordering.
+    # The SQL query does the following:
+    # - Uses a Common Table Expression named 'numbered_rows' to compute a row number for each row.
+    # - The ROW_NUMBER() window function assigns a sequential number (rn) to each row, ordered by the columns specified
+    #   in ORDERING_COLUMNS (e.g., created_at, id) in ascending order.
+    # - The UPDATE statement then sets each row's sequence_id to its corresponding row number (rn)
+    #   by joining the original table with the CTE on the id column.
+    print_flush(f"Step 3: Backfilling '{COLUMN_NAME}' based on order: {', '.join(ORDERING_COLUMNS)}...")
+    print_flush(" (This may take a while on large tables)")
+    try:
+        op.execute(
+            f"""
+            WITH numbered_rows AS (
+                SELECT
+                    id,
+                    ROW_NUMBER() OVER (ORDER BY {', '.join(ORDERING_COLUMNS)} ASC) as rn
+                FROM {TABLE_NAME}
+            )
+            UPDATE {TABLE_NAME}
+            SET {COLUMN_NAME} = numbered_rows.rn
+            FROM numbered_rows
+            WHERE {TABLE_NAME}.id = numbered_rows.id;
+            """
+        )
+        print_flush(" Backfill successful.")
+    except Exception as e:
+        print_flush(f"!!! ERROR during backfill: {e}")
+        print_flush("!!! Migration failed. Manual intervention might be needed.")
+        raise
+
+    # Step 4: Set the sequence's next value to be one more than the current maximum sequence_id.
+    # The query works as follows:
+    # - It calculates the maximum value in the sequence_id column using MAX({COLUMN_NAME}).
+    # - COALESCE is used to default to 0 if there are no rows (i.e., the table is empty).
+    # - It then adds 1 to ensure that the next call to nextval() returns a number higher than any existing value.
+    # - The 'false' argument tells PostgreSQL that the next nextval() should return the value as-is, without pre-incrementing.
+    print_flush(f"Step 4: Setting sequence '{SEQUENCE_NAME}' to next value after backfill...")
+    op.execute(
+        f"""
+        SELECT setval('{SEQUENCE_NAME}', COALESCE(MAX({COLUMN_NAME}), 0) + 1, false)
+        FROM {TABLE_NAME};
+        """
+    )
+
+    # Step 5: Now that every row has a sequence_id, alter the column to be NOT NULL.
+    # This enforces that all rows must have a valid sequence_id.
+    print_flush(f"Step 5: Altering column '{COLUMN_NAME}' to NOT NULL...")
+    op.alter_column(TABLE_NAME, COLUMN_NAME, existing_type=sa.BigInteger(), nullable=False)
+
+    # Step 6: Add a UNIQUE constraint on sequence_id to ensure its values remain distinct.
+    # This mirrors the model definition where sequence_id is defined as unique.
+    print_flush(f"Step 6: Creating unique constraint '{UNIQUE_CONSTRAINT_NAME}' on '{COLUMN_NAME}'...")
+    op.create_unique_constraint(UNIQUE_CONSTRAINT_NAME, TABLE_NAME, [COLUMN_NAME])
+
+    # Step 7: Set the server-side default for sequence_id so that future inserts automatically use the sequence.
+    # The server default calls nextval() on the sequence, and the "::regclass" cast helps PostgreSQL resolve the sequence name correctly.
+    print_flush(f"Step 7: Setting server default for '{COLUMN_NAME}' to use sequence '{SEQUENCE_NAME}'...")
+    op.alter_column(TABLE_NAME, COLUMN_NAME, existing_type=sa.BigInteger(), server_default=sa.text(f"nextval('{SEQUENCE_NAME}'::regclass)"))
+
+    # Step 8: Create a non-unique index on (agent_id, sequence_id) to improve performance of queries filtering on these columns.
+    print_flush(f"Step 8: Creating index '{INDEX_NAME}' on (agent_id, {COLUMN_NAME})...")
+    op.create_index(INDEX_NAME, TABLE_NAME, ["agent_id", COLUMN_NAME], unique=False)
+
+    print_flush(f"--- Upgrade for revision {revision} complete ---")
+
+
+def downgrade() -> None:
+    """Undo upgrade(): drop the index, the server default, the unique constraint, the column, and finally the sequence."""
+    print_flush(f"\n--- Starting downgrade from revision {revision} ---")
+
+    # 1. Drop the (agent_id, sequence_id) index created in upgrade Step 8
+    print_flush(f"Step 1: Dropping index '{INDEX_NAME}'...")
+    op.drop_index(INDEX_NAME, table_name=TABLE_NAME)
+
+    # 2. Remove the server-side default, so the column no longer references the sequence
+    print_flush(f"Step 2: Removing server default from '{COLUMN_NAME}'...")
+    op.alter_column(TABLE_NAME, COLUMN_NAME, existing_type=sa.BigInteger(), server_default=None)
+
+    # 3. Drop the unique constraint (using the explicit name)
+    print_flush(f"Step 3: Dropping unique constraint '{UNIQUE_CONSTRAINT_NAME}'...")
+    op.drop_constraint(UNIQUE_CONSTRAINT_NAME, TABLE_NAME, type_="unique")
+
+    # 4. Drop the column (this implicitly removes the NOT NULL constraint)
+    print_flush(f"Step 4: Dropping column '{COLUMN_NAME}'...")
+    op.drop_column(TABLE_NAME, COLUMN_NAME)
+
+    # 5. Drop the sequence last, after the default and the column that used it are gone
+    print_flush(f"Step 5: Dropping sequence '{SEQUENCE_NAME}'...")
+    op.execute(f"DROP SEQUENCE IF EXISTS {SEQUENCE_NAME};")  # Use IF EXISTS for safety
+
+    print_flush(f"--- Downgrade from revision {revision} complete ---")
diff --git a/letta/orm/message.py b/letta/orm/message.py
index d8ee5692..753fa657 100644
--- a/letta/orm/message.py
+++ b/letta/orm/message.py
@@ -1,7 +1,7 @@
 from typing import List, Optional
 
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
-from sqlalchemy import ForeignKey, Index
+from sqlalchemy import BigInteger, ForeignKey, Index, Sequence
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
@@ -20,6 +20,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
     __table_args__ = (
         Index("ix_messages_agent_created_at", "agent_id", "created_at"),
         Index("ix_messages_created_at", "created_at", "id"),
+        Index("ix_messages_agent_sequence", "agent_id", "sequence_id"),
     )
     __pydantic_model__ = PydanticMessage
 
@@ -40,6 +41,11 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
     )
     group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in")
 
+    # Unique, monotonically increasing ID tied to the "message_seq_id" sequence; used for efficient/correct listing
+    sequence_id: Mapped[int] = mapped_column(
+        BigInteger, Sequence("message_seq_id"), unique=True, nullable=False, doc="Global monotonically increasing ID"
+    )
+
     # Relationships
     agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
     organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py
index 364e76e6..c9a5d536 100644
--- a/letta/services/message_manager.py
+++ b/letta/services/message_manager.py
@@ -1,7 +1,7 @@
 import json
 from typing import List, Optional, Sequence
 
-from sqlalchemy import and_, exists, func, or_, select, text
+from sqlalchemy import exists, func, select, text
 
 from letta.log import get_logger
 from letta.orm.agent import Agent as AgentModel
@@ -270,19 +270,20 @@ class MessageManager:
         Most performant query to list messages for an agent by directly querying the Message table.
 
         This function filters by the agent_id (leveraging the index on messages.agent_id)
-        and applies efficient pagination using (created_at, id) as the cursor.
+        and applies pagination using sequence_id as the cursor.
         If query_text is provided, it will filter messages whose text content partially matches the query.
         If role is provided, it will filter messages by the specified role.
 
         Args:
             agent_id: The ID of the agent whose messages are queried.
             actor: The user performing the action (used for permission checks).
-            after: A message ID; if provided, only messages *after* this message (per sort order) are returned.
-            before: A message ID; if provided, only messages *before* this message are returned.
+            after: A message ID; if provided, only messages with a sequence_id strictly greater than this message's are returned.
+            before: A message ID; if provided, only messages with a sequence_id strictly less than this message's are returned.
             query_text: Optional string to partially match the message text content.
             roles: Optional MessageRole to filter messages by role.
             limit: Maximum number of messages to return.
-            ascending: If True, sort by (created_at, id) ascending; if False, sort descending.
+            ascending: If True, sort by sequence_id ascending; if False, sort descending.
+            group_id: Optional group ID; if provided, only messages whose group_id equals it are returned.
 
         Returns:
             List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
@@ -290,6 +291,7 @@ class MessageManager:
         Raises:
             NoResultFound: If the provided after/before message IDs do not exist.
         """
+
         with self.session_maker() as session:
             # Permission check: raise if the agent doesn't exist or actor is not allowed.
             AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@@ -301,7 +303,7 @@ class MessageManager:
             if group_id:
                 query = query.filter(MessageModel.group_id == group_id)
 
-            # If query_text is provided, filter messages using subquery.
+            # If query_text is provided, filter messages using subquery + json_array_elements.
             if query_text:
                 content_element = func.json_array_elements(MessageModel.content).alias("content_element")
                 query = query.filter(
@@ -313,48 +315,32 @@ class MessageManager:
                     )
                 )
 
-            # If role is provided, filter messages by role.
+            # If role(s) are provided, filter messages by those roles.
             if roles:
                 role_values = [r.value for r in roles]
                 query = query.filter(MessageModel.role.in_(role_values))
 
             # Apply 'after' pagination if specified.
             if after:
-                after_ref = session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == after).limit(1).one_or_none()
+                after_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == after).one_or_none()
                 if not after_ref:
                     raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
-                query = query.filter(
-                    or_(
-                        MessageModel.created_at > after_ref.created_at,
-                        and_(
-                            MessageModel.created_at == after_ref.created_at,
-                            MessageModel.id > after_ref.id,
-                        ),
-                    )
-                )
+                # Keep only messages strictly after the cursor (drops any with sequence_id <= after_ref.sequence_id)
+                query = query.filter(MessageModel.sequence_id > after_ref.sequence_id)
 
             # Apply 'before' pagination if specified.
             if before:
-                before_ref = (
-                    session.query(MessageModel.created_at, MessageModel.id).filter(MessageModel.id == before).limit(1).one_or_none()
-                )
+                before_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == before).one_or_none()
                 if not before_ref:
                     raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
-                query = query.filter(
-                    or_(
-                        MessageModel.created_at < before_ref.created_at,
-                        and_(
-                            MessageModel.created_at == before_ref.created_at,
-                            MessageModel.id < before_ref.id,
-                        ),
-                    )
-                )
+                # Keep only messages strictly before the cursor (drops any with sequence_id >= before_ref.sequence_id)
+                query = query.filter(MessageModel.sequence_id < before_ref.sequence_id)
 
             # Apply ordering based on the ascending flag.
             if ascending:
-                query = query.order_by(MessageModel.created_at.asc(), MessageModel.id.asc())
+                query = query.order_by(MessageModel.sequence_id.asc())
             else:
-                query = query.order_by(MessageModel.created_at.desc(), MessageModel.id.desc())
+                query = query.order_by(MessageModel.sequence_id.desc())
 
             # Limit the number of results.
             query = query.limit(limit)
diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py
index d5d69969..23f7af4a 100644
--- a/tests/integration_test_send_message.py
+++ b/tests/integration_test_send_message.py
@@ -42,7 +42,7 @@ def server_url() -> str:
     return url
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture
 def client(server_url: str) -> Letta:
     """
     Creates and returns a synchronous Letta REST client for testing.
@@ -51,7 +51,7 @@ def client(server_url: str) -> Letta:
     yield client_instance
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture
 def async_client(server_url: str) -> AsyncLetta:
     """
     Creates and returns an asynchronous Letta REST client for testing.
@@ -60,7 +60,7 @@ def async_client(server_url: str) -> AsyncLetta:
     yield async_client_instance
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture
 def roll_dice_tool(client: Letta) -> Tool:
     """
     Registers a simple roll dice tool with the provided client.
@@ -82,7 +82,7 @@ def roll_dice_tool(client: Letta) -> Tool:
     yield tool
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture
 def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
     """
     Creates and returns an agent state for testing with a pre-configured agent.