From ca77d16a57946a3b6f31f271bcfe0616fa7dd45a Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 29 May 2025 15:44:06 -0700 Subject: [PATCH] feat: force read_async to use indices (#2530) --- letta/orm/sqlalchemy_base.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index ead2a52f..0118f19c 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -4,7 +4,7 @@ from functools import wraps from pprint import pformat from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union -from sqlalchemy import String, and_, delete, func, or_, select +from sqlalchemy import String, and_, delete, func, or_, select, text from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, Session, mapped_column @@ -469,21 +469,17 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): Raises: NoResultFound: if the object is not found """ - # this is ok because read_multiple will check if the identifiers = [] if identifier is None else [identifier] query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, **kwargs) - result = await db_session.execute(query) - item = result.scalar_one_or_none() + await db_session.execute(text("SET LOCAL enable_seqscan = OFF")) + try: + result = await db_session.execute(query) + item = result.scalar_one_or_none() + finally: + await db_session.execute(text("SET LOCAL enable_seqscan = ON")) + if item is None: - # for backwards compatibility. - conditions = [] - if identifier: - conditions.append(f"id={identifier}") - if actor: - conditions.append(f"access level in {access} for {actor}") - if hasattr(cls, "is_deleted"): - conditions.append("is_deleted=False") - raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}") + raise NoResultFound(f"{cls.__name__} not found with {', '.join(query_conditions if query_conditions else ['no conditions'])}") return item @classmethod @@ -948,12 +944,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): org_id = getattr(actor, "organization_id", None) if not org_id: raise ValueError(f"object {actor} has no organization accessor") - return query.where(cls.organization_id == org_id, cls.is_deleted == False) + return query.where(cls.organization_id == org_id) elif access_type == AccessType.USER: user_id = getattr(actor, "id", None) if not user_id: raise ValueError(f"object {actor} has no user accessor") - return query.where(cls.user_id == user_id, cls.is_deleted == False) + return query.where(cls.user_id == user_id) else: raise ValueError(f"unknown access_type: {access_type}")