diff --git a/letta/orm/message.py b/letta/orm/message.py index 753fa657..dacadd6c 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -1,8 +1,8 @@ from typing import List, Optional from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall -from sqlalchemy import BigInteger, ForeignKey, Index, Sequence -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy import BigInteger, ForeignKey, Index, Sequence, event, text +from sqlalchemy.orm import Mapped, Session, mapped_column, relationship from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn from letta.orm.mixins import AgentMixin, OrganizationMixin @@ -11,6 +11,7 @@ from letta.schemas.letta_message_content import MessageContent from letta.schemas.letta_message_content import TextContent as PydanticTextContent from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import ToolReturn +from letta.settings import settings class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): @@ -42,9 +43,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in") # Monotonically increasing sequence for efficient/correct listing - sequence_id: Mapped[int] = mapped_column( - BigInteger, Sequence("message_seq_id"), unique=True, nullable=False, doc="Global monotonically increasing ID" - ) + sequence_id = mapped_column(BigInteger, Sequence("message_seq_id"), unique=True, nullable=False) # Relationships agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin") @@ -67,3 +66,21 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): if self.text and not model.content: model.content = [PydanticTextContent(text=self.text)] return model + + +# listener + + +@event.listens_for(Message, "before_insert") +def set_sequence_id_for_sqlite(mapper, connection, target): + # TODO: Kind of hacky, used to detect if we are using sqlite or not + if not settings.pg_uri: + session = Session.object_session(target) + + if not hasattr(session, "_sequence_id_counter"): + # Initialize counter for this flush + max_seq = connection.scalar(text("SELECT MAX(sequence_id) FROM messages")) + session._sequence_id_counter = max_seq or 0 + + session._sequence_id_counter += 1 + target.sequence_id = session._sequence_id_counter diff --git a/tests/integration_test_experimental.py b/tests/integration_test_experimental.py index 1ca45ec1..71268b21 100644 --- a/tests/integration_test_experimental.py +++ b/tests/integration_test_experimental.py @@ -482,7 +482,7 @@ def test_create_agents_telemetry(client: Letta): print(f"[telemetry] Deleted {len(workers)} existing worker agents in {end_delete - start_delete:.2f}s") # create worker agents - num_workers = 100 + num_workers = 1 agent_times = [] for idx in range(num_workers): start = time.perf_counter()