diff --git a/alembic/env.py b/alembic/env.py index a566f92c..dac40ea4 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -6,7 +6,7 @@ from sqlalchemy import engine_from_config, pool from alembic import context from letta.config import LettaConfig from letta.orm import Base -from letta.settings import settings +from letta.settings import DatabaseChoice, settings letta_config = LettaConfig.load() @@ -14,7 +14,7 @@ letta_config = LettaConfig.load() # access to the values within the .ini file in use. config = context.config -if settings.letta_pg_uri_no_default: +if settings.database_engine is DatabaseChoice.POSTGRES: config.set_main_option("sqlalchemy.url", settings.letta_pg_uri) print("Using database: ", settings.letta_pg_uri) else: diff --git a/letta/orm/message.py b/letta/orm/message.py index ba3acd82..100e8bff 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -11,7 +11,7 @@ from letta.schemas.letta_message_content import MessageContent from letta.schemas.letta_message_content import TextContent as PydanticTextContent from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import ToolReturn -from letta.settings import settings +from letta.settings import DatabaseChoice, settings class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): @@ -92,7 +92,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): @event.listens_for(Session, "before_flush") def set_sequence_id_for_sqlite_bulk(session, flush_context, instances): # Handle bulk inserts for SQLite - if not settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.SQLITE: # Find all new Message objects that need sequence IDs new_messages = [obj for obj in session.new if isinstance(obj, Message) and obj.sequence_id is None] @@ -165,8 +165,7 @@ def set_sequence_id_for_sqlite_bulk(session, flush_context, instances): @event.listens_for(Message, "before_insert") def set_sequence_id_for_sqlite(mapper, connection, target): - # TODO: Kind of hacky, used to detect if we are using sqlite or not - if not settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.SQLITE: # For SQLite, we need to generate sequence_id manually # Use a database-level atomic operation to avoid race conditions diff --git a/letta/orm/passage.py b/letta/orm/passage.py index 868f8a67..1a6c48a2 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -9,7 +9,7 @@ from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.passage import Passage as PydanticPassage -from letta.settings import settings +from letta.settings import DatabaseChoice, settings config = LettaConfig() @@ -29,7 +29,7 @@ class BasePassage(SqlalchemyBase, OrganizationMixin): metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata") # Vector embedding field based on database type - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: from pgvector.sqlalchemy import Vector embedding = mapped_column(Vector(MAX_EMBEDDING_DIM)) @@ -56,7 +56,7 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin): @declared_attr def __table_args__(cls): # TODO (cliandy): investigate if this is necessary, may be for SQLite compatability or do we need to add as well? - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: return ( Index("source_passages_org_idx", "organization_id"), Index("source_passages_created_at_id_idx", "created_at", "id"), @@ -81,7 +81,7 @@ class AgentPassage(BasePassage, AgentMixin): @declared_attr def __table_args__(cls): - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: return ( Index("agent_passages_org_idx", "organization_id"), Index("ix_agent_passages_org_agent", "organization_id", "agent_id"), diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index d1a79eb6..3c60ab25 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -15,6 +15,7 @@ from letta.log import get_logger from letta.orm.base import Base, CommonSqlalchemyMetaMixins from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError from letta.orm.sqlite_functions import adapt_array +from letta.settings import DatabaseChoice if TYPE_CHECKING: from pydantic import BaseModel @@ -395,7 +396,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): from letta.settings import settings - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL with pgvector query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc()) else: diff --git a/letta/server/db.py b/letta/server/db.py index 089b94a6..30e6a723 100644 --- a/letta/server/db.py +++ b/letta/server/db.py @@ -14,7 +14,7 @@ from sqlalchemy.orm import sessionmaker from letta.config import LettaConfig from letta.log import get_logger from letta.otel.tracing import trace_method -from letta.settings import settings +from letta.settings import DatabaseChoice, settings logger = get_logger(__name__) @@ -90,7 +90,7 @@ class DatabaseRegistry: return # Postgres engine - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: self.logger.info("Creating postgres engine") self.config.recall_storage_type = "postgres" self.config.recall_storage_uri = settings.letta_pg_uri_no_default @@ -128,7 +128,7 @@ class DatabaseRegistry: if self._initialized.get("async") and not force: return - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: self.logger.info("Creating async postgres engine") # Create async engine - convert URI to async format diff --git a/letta/server/server.py b/letta/server/server.py index 1afeccaf..3ebe2db6 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -104,7 +104,7 @@ from letta.services.telemetry_manager import TelemetryManager from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager -from letta.settings import model_settings, settings, tool_settings +from letta.settings import DatabaseChoice, model_settings, settings, tool_settings from letta.streaming_interface import AgentChunkStreamingInterface from letta.utils import get_friendly_error_msg, get_persona_text, make_key @@ -196,7 +196,7 @@ class SyncServer(Server): # Initialize the metadata store config = LettaConfig.load() - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: config.recall_storage_type = "postgres" config.recall_storage_uri = settings.letta_pg_uri_no_default config.archival_storage_type = "postgres" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 87bc7d97..0b872a97 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -91,7 +91,7 @@ from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager -from letta.settings import settings +from letta.settings import DatabaseChoice, settings from letta.utils import enforce_types, united_diff logger = get_logger(__name__) @@ -2732,7 +2732,7 @@ class AgentManager: ) if query_text: - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL: Use ILIKE for case-insensitive search query = query.filter(AgentsTags.tag.ilike(f"%{query_text}%")) else: @@ -2773,7 +2773,7 @@ class AgentManager: ) if query_text: - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL: Use ILIKE for case-insensitive search query = query.where(AgentsTags.tag.ilike(f"%{query_text}%")) else: diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 62b85657..3d7a1ef6 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -37,7 +37,7 @@ from letta.schemas.memory import Memory from letta.schemas.message import Message, MessageCreate from letta.schemas.tool_rule import ToolRule from letta.schemas.user import User -from letta.settings import settings +from letta.settings import DatabaseChoice, settings from letta.system import get_initial_boot_messages, get_login_event, package_function_response @@ -552,7 +552,7 @@ async def _apply_pagination_async( if result: after_sort_value, after_id = result # SQLite does not support as granular timestamping, so we need to round the timestamp - if not settings.letta_pg_uri_no_default and isinstance(after_sort_value, datetime): + if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime): after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S") query = query.where( _cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last) @@ -563,7 +563,7 @@ async def _apply_pagination_async( if result: before_sort_value, before_id = result # SQLite does not support as granular timestamping, so we need to round the timestamp - if not settings.letta_pg_uri_no_default and isinstance(before_sort_value, datetime): + if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime): before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S") query = query.where( _cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last) @@ -655,7 +655,7 @@ def _apply_filters( query = query.where(AgentModel.name == name) # Apply a case-insensitive partial match for the agent's name. if query_text: - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL: Use ILIKE for case-insensitive search query = query.where(AgentModel.name.ilike(f"%{query_text}%")) else: @@ -801,7 +801,7 @@ def build_passage_query( # Vector search if embedded_text: - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL with pgvector main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc()) else: @@ -928,7 +928,7 @@ def build_source_passage_query( # Handle text search or vector search if embedded_text: - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL with pgvector query = query.order_by(SourcePassage.embedding.cosine_distance(embedded_text).asc()) else: @@ -1015,7 +1015,7 @@ def build_agent_passage_query( # Handle text search or vector search if embedded_text: - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL with pgvector query = query.order_by(AgentPassage.embedding.cosine_distance(embedded_text).asc()) else: diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 8bd82297..da04e907 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -13,7 +13,7 @@ from letta.schemas.identity import Identity as PydanticIdentity from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityType, IdentityUpdate, IdentityUpsert from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry -from letta.settings import settings +from letta.settings import DatabaseChoice, settings from letta.utils import enforce_types @@ -69,7 +69,7 @@ class IdentityManager: # For SQLite compatibility: check for unique constraint violation manually # since SQLite doesn't support postgresql_nulls_not_distinct=True - if not settings.letta_pg_uri_no_default: # Using SQLite + if settings.database_engine is DatabaseChoice.SQLITE: # Check if an identity with the same identifier_key, project_id, and organization_id exists query = select(IdentityModel).where( IdentityModel.identifier_key == new_identity.identifier_key, diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 281268d2..aa257fea 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -28,7 +28,7 @@ from letta.schemas.step import Step as PydanticStep from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry -from letta.settings import settings +from letta.settings import DatabaseChoice, settings from letta.utils import enforce_types logger = get_logger(__name__) @@ -336,7 +336,7 @@ class JobManager: before_timestamp = before_obj.created_at # SQLite does not support as granular timestamping, so we need to round the timestamp - if not settings.letta_pg_uri_no_default and isinstance(before_timestamp, datetime): + if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_timestamp, datetime): before_timestamp = before_timestamp.strftime("%Y-%m-%d %H:%M:%S") conditions.append( @@ -350,7 +350,7 @@ class JobManager: # records after this cursor (newer) after_timestamp = after_obj.created_at # SQLite does not support as granular timestamping, so we need to round the timestamp - if not settings.letta_pg_uri_no_default and isinstance(after_timestamp, datetime): + if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_timestamp, datetime): after_timestamp = after_timestamp.strftime("%Y-%m-%d %H:%M:%S") conditions.append( diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 6770157f..46e04d2c 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -18,7 +18,7 @@ from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.file_manager import FileManager from letta.services.helpers.agent_manager_helper import validate_agent_exists_async -from letta.settings import settings +from letta.settings import DatabaseChoice, settings from letta.utils import enforce_types logger = get_logger(__name__) @@ -457,7 +457,7 @@ class MessageManager: # If query_text is provided, filter messages using database-specific JSON search. if query_text: - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL: Use json_array_elements and ILIKE content_element = func.json_array_elements(MessageModel.content).alias("content_element") query = query.filter( @@ -565,7 +565,7 @@ class MessageManager: # If query_text is provided, filter messages using database-specific JSON search. if query_text: - if settings.letta_pg_uri_no_default: + if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL: Use json_array_elements and ILIKE content_element = func.json_array_elements(MessageModel.content).alias("content_element") query = query.where( diff --git a/letta/settings.py b/letta/settings.py index b38071ca..39fb754a 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -1,4 +1,5 @@ import os +from enum import Enum from pathlib import Path from typing import Optional @@ -182,6 +183,11 @@ if "--use-file-pg-uri" in sys.argv: pass +class DatabaseChoice(str, Enum): + POSTGRES = "postgres" + SQLITE = "sqlite" + + class Settings(BaseSettings): model_config = SettingsConfigDict(env_prefix="letta_", extra="ignore") @@ -291,6 +297,10 @@ class Settings(BaseSettings): else: return None + @property + def database_engine(self) -> DatabaseChoice: + return DatabaseChoice.POSTGRES if self.letta_pg_uri_no_default else DatabaseChoice.SQLITE + @property def plugin_register_dict(self) -> dict: plugins = {}