From a7092c9794d0d2aa803946af19d18631aa5dddc7 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Sat, 21 Jun 2025 21:06:37 -0700 Subject: [PATCH] fix: patches sqlite support (#2876) Co-authored-by: Jin Peng Co-authored-by: jnjpng --- letta/orm/sqlalchemy_base.py | 11 +++++++++-- letta/services/user_manager.py | 7 +++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index ae27908d..8c33b4b5 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -49,6 +49,11 @@ def handle_db_timeout(func): return async_wrapper +def is_postgresql_session(session: Session) -> bool: + """Check if the database session is PostgreSQL instead of SQLite for setting query options.""" + return session.bind.dialect.name == "postgresql" + + class AccessType(str, Enum): ORGANIZATION = "organization" USER = "user" @@ -494,12 +499,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, check_is_deleted, **kwargs) if query is None: raise NoResultFound(f"{cls.__name__} not found with identifier {identifier}") - await db_session.execute(text("SET LOCAL enable_seqscan = OFF")) + if is_postgresql_session(db_session): + await db_session.execute(text("SET LOCAL enable_seqscan = OFF")) try: result = await db_session.execute(query) item = result.scalar_one_or_none() finally: - await db_session.execute(text("SET LOCAL enable_seqscan = ON")) + if is_postgresql_session(db_session): + await db_session.execute(text("SET LOCAL enable_seqscan = ON")) if item is None: raise NoResultFound(f"{cls.__name__} not found with {', '.join(query_conditions if query_conditions else ['no conditions'])}") diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index 306ec7f0..0f05e41f 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -8,6 +8,7 @@ from letta.helpers.decorators import async_redis_cache from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel +from letta.orm.sqlalchemy_base import is_postgresql_session from letta.orm.user import User as UserModel from letta.otel.tracing import trace_method from letta.schemas.user import User as PydanticUser @@ -157,13 +158,15 @@ class UserManager: """Fetch a user by ID asynchronously.""" async with db_registry.async_session() as session: # Turn off seqscan to force use pk index - await session.execute(text("SET LOCAL enable_seqscan = OFF")) + if is_postgresql_session(session): + await session.execute(text("SET LOCAL enable_seqscan = OFF")) try: stmt = select(UserModel).where(UserModel.id == actor_id) result = await session.execute(stmt) user = result.scalar_one_or_none() finally: - await session.execute(text("SET LOCAL enable_seqscan = ON")) + if is_postgresql_session(session): + await session.execute(text("SET LOCAL enable_seqscan = ON")) if not user: raise NoResultFound(f"User not found with id={actor_id}")