feat: add lazy=raise for passage-org relationship (#8482)

This commit is contained in:
cthomas
2026-01-09 11:02:53 -08:00
committed by Caren Thomas
parent 0cbdf452fa
commit e964307f6a
3 changed files with 22 additions and 15 deletions

View File

@@ -40,8 +40,8 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
@declared_attr
def organization(cls) -> Mapped["Organization"]:
"""Relationship to organization"""
return relationship("Organization", back_populates="passages", lazy="selectin")
"""Relationship to organization - use lazy='raise' to prevent accidental blocking in async contexts"""
return relationship("Organization", back_populates="passages", lazy="raise")
class SourcePassage(BasePassage, FileMixin, SourceMixin):
@@ -53,7 +53,7 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="source_passages", lazy="selectin")
return relationship("Organization", back_populates="source_passages", lazy="raise")
@declared_attr
def __table_args__(cls):
@@ -84,7 +84,7 @@ class ArchivalPassage(BasePassage, ArchiveMixin):
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="archival_passages", lazy="selectin")
return relationship("Organization", back_populates="archival_passages", lazy="raise")
@declared_attr
def __table_args__(cls):

View File

@@ -1099,8 +1099,8 @@ async def build_source_passage_query(
embedded_text = np.array(embeddings[0])
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
# Base query for source passages
query = select(SourcePassage).where(SourcePassage.organization_id == actor.organization_id)
# Base query for source passages - use noload to prevent lazy loading which can block the event loop
query = select(SourcePassage).options(noload(SourcePassage.organization)).where(SourcePassage.organization_id == actor.organization_id)
# If agent_id is specified, join with SourcesAgents to get only passages linked to that agent
if agent_id is not None:
@@ -1208,23 +1208,26 @@ async def build_agent_passage_query(
embedded_text = np.array(embeddings[0])
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
# Base query for passages
# Base query for passages - use noload to prevent lazy loading which can block the event loop
if agent_id:
# Query for agent passages - join through archives_agents
# Agent_id takes precedence if both agent_id and archive_id are provided
query = (
select(ArchivalPassage)
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
.where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id)
)
elif archive_id:
# Query for archive passages directly
query = select(ArchivalPassage).where(
ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id
query = (
select(ArchivalPassage)
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
.where(ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id)
)
else:
# Org-wide search - all passages in organization
query = select(ArchivalPassage).where(ArchivalPassage.organization_id == actor.organization_id)
query = (
select(ArchivalPassage)
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
.where(ArchivalPassage.organization_id == actor.organization_id)
)
# Apply filters
if start_date:

View File

@@ -6,6 +6,7 @@ from typing import Dict, List, Optional
from openai import AsyncOpenAI, OpenAI
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from letta.constants import MAX_EMBEDDING_DIM
from letta.errors import EmbeddingConfigRequiredError
@@ -955,7 +956,10 @@ class PassageManager:
"""
async with db_registry.async_session() as session:
result = await session.execute(
select(SourcePassage).where(SourcePassage.file_id == file_id).where(SourcePassage.organization_id == actor.organization_id)
select(SourcePassage)
.options(noload(SourcePassage.organization))
.where(SourcePassage.file_id == file_id)
.where(SourcePassage.organization_id == actor.organization_id)
)
passages = result.scalars().all()
return [p.to_pydantic() for p in passages]