From 0780cec6b9d63d7bbf836b1680a9caef381ad450 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 15 Apr 2025 19:25:46 -0700 Subject: [PATCH] feat: Set SqlAlchemy to fetch sequence_id from DB, not generate it itself (#1726) --- letta/orm/message.py | 10 ++++-- letta/orm/sqlalchemy_base.py | 67 ++++++++++-------------------------- 2 files changed, 27 insertions(+), 50 deletions(-) diff --git a/letta/orm/message.py b/letta/orm/message.py index cba74000..6e08f2d7 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -1,7 +1,7 @@ from typing import List, Optional from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall -from sqlalchemy import BigInteger, ForeignKey, Index, Sequence, event, text +from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, Sequence, event, text from sqlalchemy.orm import Mapped, Session, mapped_column, relationship from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn @@ -46,7 +46,13 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): ) # Monotonically increasing sequence for efficient/correct listing - sequence_id = mapped_column(BigInteger, Sequence("message_seq_id"), unique=True, nullable=False) + sequence_id: Mapped[int] = mapped_column( + BigInteger, + Sequence("message_seq_id"), + server_default=FetchedValue(), + unique=True, + nullable=False, + ) # Relationships agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin") diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 92bb6965..ca2e19b4 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -390,73 +390,44 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): @classmethod @handle_db_timeout - def batch_create( - cls, - items: List["SqlalchemyBase"], - db_session: "Session", - actor: Optional["User"] = None, - batch_size: int = 1000, # TODO: Make this a configurable setting - requery: bool = True, - ) -> List["SqlalchemyBase"]: + def batch_create(cls, items: List["SqlalchemyBase"], db_session: "Session", actor: Optional["User"] = None) -> List["SqlalchemyBase"]: """ Create multiple records in a single transaction for better performance. - Args: items: List of model instances to create db_session: SQLAlchemy session actor: Optional user performing the action - batch_size: Maximum number of items to process in a single batch - requery: Whether to requery the objects after creation - Returns: List of created model instances """ logger.debug(f"Batch creating {len(items)} {cls.__name__} items with actor={actor}") - if not items: return [] - result_items = [] + # Set created/updated by fields if actor is provided + if actor: + for item in items: + item._set_created_and_updated_by_fields(actor.id) - # Process in batches to avoid memory issues with very large sets - for i in range(0, len(items), batch_size): - batch = items[i : i + batch_size] + try: + with db_session as session: + session.add_all(items) + session.flush() # Flush to generate IDs but don't commit yet - # Set created/updated by fields if actor is provided - if actor: - for item in batch: - item._set_created_and_updated_by_fields(actor.id) + # Collect IDs to fetch the complete objects after commit + item_ids = [item.id for item in items] - try: - with db_session as session: - session.add_all(batch) - session.flush() # Flush to generate IDs but don't commit yet + session.commit() - # Collect IDs to fetch the complete objects after commit - item_ids = [item.id for item in batch] + # Re-query the objects to get them with relationships loaded + query = select(cls).where(cls.id.in_(item_ids)) + if hasattr(cls, "created_at"): + query = query.order_by(cls.created_at) - session.commit() + return list(session.execute(query).scalars()) - if requery: - # Re-query the objects to get them with relationships loaded - query = select(cls).where(cls.id.in_(item_ids)) - if hasattr(cls, "created_at"): - query = query.order_by(cls.created_at) - - batch_result = list(session.execute(query).scalars()) - else: - # Use the objects we already have in memory - batch_result = batch - - result_items.extend(batch_result) - - except (DBAPIError, IntegrityError) as e: - logger.error(f"Database error during batch creation: {e}") - # Log which items we were processing when the error occurred - logger.error(f"Failed batch starting at index {i} of {len(items)}") - cls._handle_dbapi_error(e) - - return result_items + except (DBAPIError, IntegrityError) as e: + cls._handle_dbapi_error(e) @handle_db_timeout def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":