feat: add lazy=raise for passage-org relationship (#8482)
This commit is contained in:
@@ -40,8 +40,8 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
"""Relationship to organization"""
|
||||
return relationship("Organization", back_populates="passages", lazy="selectin")
|
||||
"""Relationship to organization - use lazy='raise' to prevent accidental blocking in async contexts"""
|
||||
return relationship("Organization", back_populates="passages", lazy="raise")
|
||||
|
||||
|
||||
class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||
@@ -53,7 +53,7 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
return relationship("Organization", back_populates="source_passages", lazy="selectin")
|
||||
return relationship("Organization", back_populates="source_passages", lazy="raise")
|
||||
|
||||
@declared_attr
|
||||
def __table_args__(cls):
|
||||
@@ -84,7 +84,7 @@ class ArchivalPassage(BasePassage, ArchiveMixin):
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
return relationship("Organization", back_populates="archival_passages", lazy="selectin")
|
||||
return relationship("Organization", back_populates="archival_passages", lazy="raise")
|
||||
|
||||
@declared_attr
|
||||
def __table_args__(cls):
|
||||
|
||||
@@ -1099,8 +1099,8 @@ async def build_source_passage_query(
|
||||
embedded_text = np.array(embeddings[0])
|
||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||
|
||||
# Base query for source passages
|
||||
query = select(SourcePassage).where(SourcePassage.organization_id == actor.organization_id)
|
||||
# Base query for source passages - use noload to prevent lazy loading which can block the event loop
|
||||
query = select(SourcePassage).options(noload(SourcePassage.organization)).where(SourcePassage.organization_id == actor.organization_id)
|
||||
|
||||
# If agent_id is specified, join with SourcesAgents to get only passages linked to that agent
|
||||
if agent_id is not None:
|
||||
@@ -1208,23 +1208,26 @@ async def build_agent_passage_query(
|
||||
embedded_text = np.array(embeddings[0])
|
||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||
|
||||
# Base query for passages
|
||||
# Base query for passages - use noload to prevent lazy loading which can block the event loop
|
||||
if agent_id:
|
||||
# Query for agent passages - join through archives_agents
|
||||
# Agent_id takes precedence if both agent_id and archive_id are provided
|
||||
query = (
|
||||
select(ArchivalPassage)
|
||||
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
|
||||
.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
|
||||
.where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id)
|
||||
)
|
||||
elif archive_id:
|
||||
# Query for archive passages directly
|
||||
query = select(ArchivalPassage).where(
|
||||
ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id
|
||||
query = (
|
||||
select(ArchivalPassage)
|
||||
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
|
||||
.where(ArchivalPassage.archive_id == archive_id, ArchivalPassage.organization_id == actor.organization_id)
|
||||
)
|
||||
else:
|
||||
# Org-wide search - all passages in organization
|
||||
query = select(ArchivalPassage).where(ArchivalPassage.organization_id == actor.organization_id)
|
||||
query = (
|
||||
select(ArchivalPassage)
|
||||
.options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
|
||||
.where(ArchivalPassage.organization_id == actor.organization_id)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if start_date:
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.errors import EmbeddingConfigRequiredError
|
||||
@@ -955,7 +956,10 @@ class PassageManager:
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(SourcePassage).where(SourcePassage.file_id == file_id).where(SourcePassage.organization_id == actor.organization_id)
|
||||
select(SourcePassage)
|
||||
.options(noload(SourcePassage.organization))
|
||||
.where(SourcePassage.file_id == file_id)
|
||||
.where(SourcePassage.organization_id == actor.organization_id)
|
||||
)
|
||||
passages = result.scalars().all()
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
Reference in New Issue
Block a user