From e964307f6ae521243ebea4988e6755cf143b981b Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 9 Jan 2026 11:02:53 -0800 Subject: [PATCH] feat: add lazy=raise for passage-org relationship (#8482) --- letta/orm/passage.py | 8 +++---- .../services/helpers/agent_manager_helper.py | 23 +++++++++++-------- letta/services/passage_manager.py | 6 ++++- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/letta/orm/passage.py b/letta/orm/passage.py index caca2764..404830f2 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -40,8 +40,8 @@ class BasePassage(SqlalchemyBase, OrganizationMixin): @declared_attr def organization(cls) -> Mapped["Organization"]: - """Relationship to organization""" - return relationship("Organization", back_populates="passages", lazy="selectin") + """Relationship to organization - use lazy='raise' to prevent accidental blocking in async contexts""" + return relationship("Organization", back_populates="passages", lazy="raise") class SourcePassage(BasePassage, FileMixin, SourceMixin): @@ -53,7 +53,7 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin): @declared_attr def organization(cls) -> Mapped["Organization"]: - return relationship("Organization", back_populates="source_passages", lazy="selectin") + return relationship("Organization", back_populates="source_passages", lazy="raise") @declared_attr def __table_args__(cls): @@ -84,7 +84,7 @@ class ArchivalPassage(BasePassage, ArchiveMixin): @declared_attr def organization(cls) -> Mapped["Organization"]: - return relationship("Organization", back_populates="archival_passages", lazy="selectin") + return relationship("Organization", back_populates="archival_passages", lazy="raise") @declared_attr def __table_args__(cls): diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 63fb7590..542b6501 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -1099,8 +1099,8 @@ async def build_source_passage_query( embedded_text = np.array(embeddings[0]) embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() - # Base query for source passages - query = select(SourcePassage).where(SourcePassage.organization_id == actor.organization_id) + # Base query for source passages - use noload to prevent lazy loading which can block the event loop + query = select(SourcePassage).options(noload(SourcePassage.organization)).where(SourcePassage.organization_id == actor.organization_id) # If agent_id is specified, join with SourcesAgents to get only passages linked to that agent if agent_id is not None: @@ -1208,23 +1208,26 @@ async def build_agent_passage_query( embedded_text = np.array(embeddings[0]) embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() - # Base query for passages + # Base query for passages - use noload to prevent lazy loading which can block the event loop if agent_id: - # Query for agent passages - join through archives_agents - # Agent_id takes precedence if both agent_id and archive_id are provided query = ( select(ArchivalPassage) + .options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags)) .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id) .where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id) ) elif archive_id: - # Query for archive passages directly - query = select(ArchivalPassage).where( - ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id + query = ( + select(ArchivalPassage) + .options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags)) + .where(ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id) ) else: - # Org-wide search - all passages in organization - query = select(ArchivalPassage).where(ArchivalPassage.organization_id == actor.organization_id) + query = ( + select(ArchivalPassage) + .options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags)) + .where(ArchivalPassage.organization_id == actor.organization_id) + ) # Apply filters if start_date: diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 8f8f2994..cbd30549 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional from openai import AsyncOpenAI, OpenAI from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import noload from letta.constants import MAX_EMBEDDING_DIM from letta.errors import EmbeddingConfigRequiredError @@ -955,7 +956,10 @@ class PassageManager: """ async with db_registry.async_session() as session: result = await session.execute( - select(SourcePassage).where(SourcePassage.file_id == file_id).where(SourcePassage.organization_id == actor.organization_id) + select(SourcePassage) + .options(noload(SourcePassage.organization)) + .where(SourcePassage.file_id == file_id) + .where(SourcePassage.organization_id == actor.organization_id) ) passages = result.scalars().all() return [p.to_pydantic() for p in passages]