feat: force read_async to use indices (#2530)

This commit is contained in:
cthomas
2025-05-29 15:44:06 -07:00
committed by GitHub
parent b634939990
commit ca77d16a57

View File

@@ -4,7 +4,7 @@ from functools import wraps
from pprint import pformat
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union
from sqlalchemy import String, and_, delete, func, or_, select
from sqlalchemy import String, and_, delete, func, or_, select, text
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, Session, mapped_column
@@ -469,21 +469,17 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
Raises:
NoResultFound: if the object is not found
"""
# this is ok because read_multiple will check if the
identifiers = [] if identifier is None else [identifier]
query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, **kwargs)
result = await db_session.execute(query)
item = result.scalar_one_or_none()
await db_session.execute(text("SET LOCAL enable_seqscan = OFF"))
try:
result = await db_session.execute(query)
item = result.scalar_one_or_none()
finally:
await db_session.execute(text("SET LOCAL enable_seqscan = ON"))
if item is None:
# for backwards compatibility.
conditions = []
if identifier:
conditions.append(f"id={identifier}")
if actor:
conditions.append(f"access level in {access} for {actor}")
if hasattr(cls, "is_deleted"):
conditions.append("is_deleted=False")
raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}")
raise NoResultFound(f"{cls.__name__} not found with {', '.join(query_conditions if query_conditions else ['no conditions'])}")
return item
@classmethod
@@ -948,12 +944,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
org_id = getattr(actor, "organization_id", None)
if not org_id:
raise ValueError(f"object {actor} has no organization accessor")
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
return query.where(cls.organization_id == org_id)
elif access_type == AccessType.USER:
user_id = getattr(actor, "id", None)
if not user_id:
raise ValueError(f"object {actor} has no user accessor")
return query.where(cls.user_id == user_id, cls.is_deleted == False)
return query.where(cls.user_id == user_id)
else:
raise ValueError(f"unknown access_type: {access_type}")