fix: Add listener for sqlite (#1616)
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from sqlalchemy import BigInteger, ForeignKey, Index, Sequence
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy import BigInteger, ForeignKey, Index, Sequence, event, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
@@ -11,6 +11,7 @@ from letta.schemas.letta_message_content import MessageContent
|
||||
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import ToolReturn
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
@@ -42,9 +43,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in")
|
||||
|
||||
# Monotonically increasing sequence for efficient/correct listing
|
||||
sequence_id: Mapped[int] = mapped_column(
|
||||
BigInteger, Sequence("message_seq_id"), unique=True, nullable=False, doc="Global monotonically increasing ID"
|
||||
)
|
||||
sequence_id = mapped_column(BigInteger, Sequence("message_seq_id"), unique=True, nullable=False)
|
||||
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||
@@ -67,3 +66,21 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
if self.text and not model.content:
|
||||
model.content = [PydanticTextContent(text=self.text)]
|
||||
return model
|
||||
|
||||
|
||||
# listener
|
||||
|
||||
|
||||
@event.listens_for(Message, "before_insert")
|
||||
def set_sequence_id_for_sqlite(mapper, connection, target):
|
||||
# TODO: Kind of hacky, used to detect if we are using sqlite or not
|
||||
if not settings.pg_uri:
|
||||
session = Session.object_session(target)
|
||||
|
||||
if not hasattr(session, "_sequence_id_counter"):
|
||||
# Initialize counter for this flush
|
||||
max_seq = connection.scalar(text("SELECT MAX(sequence_id) FROM messages"))
|
||||
session._sequence_id_counter = max_seq or 0
|
||||
|
||||
session._sequence_id_counter += 1
|
||||
target.sequence_id = session._sequence_id_counter
|
||||
|
||||
@@ -482,7 +482,7 @@ def test_create_agents_telemetry(client: Letta):
|
||||
print(f"[telemetry] Deleted {len(workers)} existing worker agents in {end_delete - start_delete:.2f}s")
|
||||
|
||||
# create worker agents
|
||||
num_workers = 100
|
||||
num_workers = 1
|
||||
agent_times = []
|
||||
for idx in range(num_workers):
|
||||
start = time.perf_counter()
|
||||
|
||||
Reference in New Issue
Block a user