chore: cleanup database detection
This commit is contained in:
@@ -6,7 +6,7 @@ from sqlalchemy import engine_from_config, pool
|
||||
from alembic import context
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm import Base
|
||||
from letta.settings import settings
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
letta_config = LettaConfig.load()
|
||||
|
||||
@@ -14,7 +14,7 @@ letta_config = LettaConfig.load()
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
config.set_main_option("sqlalchemy.url", settings.letta_pg_uri)
|
||||
print("Using database: ", settings.letta_pg_uri)
|
||||
else:
|
||||
|
||||
@@ -11,7 +11,7 @@ from letta.schemas.letta_message_content import MessageContent
|
||||
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import ToolReturn
|
||||
from letta.settings import settings
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
|
||||
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
@@ -92,7 +92,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
@event.listens_for(Session, "before_flush")
|
||||
def set_sequence_id_for_sqlite_bulk(session, flush_context, instances):
|
||||
# Handle bulk inserts for SQLite
|
||||
if not settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.SQLITE:
|
||||
# Find all new Message objects that need sequence IDs
|
||||
new_messages = [obj for obj in session.new if isinstance(obj, Message) and obj.sequence_id is None]
|
||||
|
||||
@@ -165,8 +165,7 @@ def set_sequence_id_for_sqlite_bulk(session, flush_context, instances):
|
||||
|
||||
@event.listens_for(Message, "before_insert")
|
||||
def set_sequence_id_for_sqlite(mapper, connection, target):
|
||||
# TODO: Kind of hacky, used to detect if we are using sqlite or not
|
||||
if not settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.SQLITE:
|
||||
# For SQLite, we need to generate sequence_id manually
|
||||
# Use a database-level atomic operation to avoid race conditions
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn
|
||||
from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.settings import settings
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
config = LettaConfig()
|
||||
|
||||
@@ -29,7 +29,7 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
|
||||
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
|
||||
|
||||
# Vector embedding field based on database type
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
||||
@@ -56,7 +56,7 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||
@declared_attr
|
||||
def __table_args__(cls):
|
||||
# TODO (cliandy): investigate if this is necessary, may be for SQLite compatability or do we need to add as well?
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
return (
|
||||
Index("source_passages_org_idx", "organization_id"),
|
||||
Index("source_passages_created_at_id_idx", "created_at", "id"),
|
||||
@@ -81,7 +81,7 @@ class AgentPassage(BasePassage, AgentMixin):
|
||||
|
||||
@declared_attr
|
||||
def __table_args__(cls):
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
return (
|
||||
Index("agent_passages_org_idx", "organization_id"),
|
||||
Index("ix_agent_passages_org_agent", "organization_id", "agent_id"),
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.settings import DatabaseChoice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
@@ -395,7 +396,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# PostgreSQL with pgvector
|
||||
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
||||
else:
|
||||
|
||||
@@ -14,7 +14,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from letta.config import LettaConfig
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.settings import settings
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -90,7 +90,7 @@ class DatabaseRegistry:
|
||||
return
|
||||
|
||||
# Postgres engine
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
self.logger.info("Creating postgres engine")
|
||||
self.config.recall_storage_type = "postgres"
|
||||
self.config.recall_storage_uri = settings.letta_pg_uri_no_default
|
||||
@@ -128,7 +128,7 @@ class DatabaseRegistry:
|
||||
if self._initialized.get("async") and not force:
|
||||
return
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
self.logger.info("Creating async postgres engine")
|
||||
|
||||
# Create async engine - convert URI to async format
|
||||
|
||||
@@ -104,7 +104,7 @@ from letta.services.telemetry_manager import TelemetryManager
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings, settings, tool_settings
|
||||
from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
from letta.utils import get_friendly_error_msg, get_persona_text, make_key
|
||||
|
||||
@@ -196,7 +196,7 @@ class SyncServer(Server):
|
||||
|
||||
# Initialize the metadata store
|
||||
config = LettaConfig.load()
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
config.recall_storage_type = "postgres"
|
||||
config.recall_storage_uri = settings.letta_pg_uri_no_default
|
||||
config.archival_storage_type = "postgres"
|
||||
|
||||
@@ -91,7 +91,7 @@ from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types, united_diff
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -2732,7 +2732,7 @@ class AgentManager:
|
||||
)
|
||||
|
||||
if query_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# PostgreSQL: Use ILIKE for case-insensitive search
|
||||
query = query.filter(AgentsTags.tag.ilike(f"%{query_text}%"))
|
||||
else:
|
||||
@@ -2773,7 +2773,7 @@ class AgentManager:
|
||||
)
|
||||
|
||||
if query_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# PostgreSQL: Use ILIKE for case-insensitive search
|
||||
query = query.where(AgentsTags.tag.ilike(f"%{query_text}%"))
|
||||
else:
|
||||
|
||||
@@ -37,7 +37,7 @@ from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.system import get_initial_boot_messages, get_login_event, package_function_response
|
||||
|
||||
|
||||
@@ -552,7 +552,7 @@ async def _apply_pagination_async(
|
||||
if result:
|
||||
after_sort_value, after_id = result
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if not settings.letta_pg_uri_no_default and isinstance(after_sort_value, datetime):
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime):
|
||||
after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
query = query.where(
|
||||
_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last)
|
||||
@@ -563,7 +563,7 @@ async def _apply_pagination_async(
|
||||
if result:
|
||||
before_sort_value, before_id = result
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if not settings.letta_pg_uri_no_default and isinstance(before_sort_value, datetime):
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime):
|
||||
before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
query = query.where(
|
||||
_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last)
|
||||
@@ -655,7 +655,7 @@ def _apply_filters(
|
||||
query = query.where(AgentModel.name == name)
|
||||
# Apply a case-insensitive partial match for the agent's name.
|
||||
if query_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# PostgreSQL: Use ILIKE for case-insensitive search
|
||||
query = query.where(AgentModel.name.ilike(f"%{query_text}%"))
|
||||
else:
|
||||
@@ -801,7 +801,7 @@ def build_passage_query(
|
||||
|
||||
# Vector search
|
||||
if embedded_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# PostgreSQL with pgvector
|
||||
main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
|
||||
else:
|
||||
@@ -928,7 +928,7 @@ def build_source_passage_query(
|
||||
|
||||
# Handle text search or vector search
|
||||
if embedded_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# PostgreSQL with pgvector
|
||||
query = query.order_by(SourcePassage.embedding.cosine_distance(embedded_text).asc())
|
||||
else:
|
||||
@@ -1015,7 +1015,7 @@ def build_agent_passage_query(
|
||||
|
||||
# Handle text search or vector search
|
||||
if embedded_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# PostgreSQL with pgvector
|
||||
query = query.order_by(AgentPassage.embedding.cosine_distance(embedded_text).asc())
|
||||
else:
|
||||
|
||||
@@ -13,7 +13,7 @@ from letta.schemas.identity import Identity as PydanticIdentity
|
||||
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityType, IdentityUpdate, IdentityUpsert
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import settings
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class IdentityManager:
|
||||
|
||||
# For SQLite compatibility: check for unique constraint violation manually
|
||||
# since SQLite doesn't support postgresql_nulls_not_distinct=True
|
||||
if not settings.letta_pg_uri_no_default: # Using SQLite
|
||||
if settings.database_engine is DatabaseChoice.SQLITE:
|
||||
# Check if an identity with the same identifier_key, project_id, and organization_id exists
|
||||
query = select(IdentityModel).where(
|
||||
IdentityModel.identifier_key == new_identity.identifier_key,
|
||||
|
||||
@@ -28,7 +28,7 @@ from letta.schemas.step import Step as PydanticStep
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import settings
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -336,7 +336,7 @@ class JobManager:
|
||||
|
||||
before_timestamp = before_obj.created_at
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if not settings.letta_pg_uri_no_default and isinstance(before_timestamp, datetime):
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_timestamp, datetime):
|
||||
before_timestamp = before_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
conditions.append(
|
||||
@@ -350,7 +350,7 @@ class JobManager:
|
||||
# records after this cursor (newer)
|
||||
after_timestamp = after_obj.created_at
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if not settings.letta_pg_uri_no_default and isinstance(after_timestamp, datetime):
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_timestamp, datetime):
|
||||
after_timestamp = after_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
conditions.append(
|
||||
|
||||
@@ -18,7 +18,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.file_manager import FileManager
|
||||
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
|
||||
from letta.settings import settings
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -457,7 +457,7 @@ class MessageManager:
|
||||
|
||||
# If query_text is provided, filter messages using database-specific JSON search.
|
||||
if query_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# PostgreSQL: Use json_array_elements and ILIKE
|
||||
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
|
||||
query = query.filter(
|
||||
@@ -565,7 +565,7 @@ class MessageManager:
|
||||
|
||||
# If query_text is provided, filter messages using database-specific JSON search.
|
||||
if query_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
# PostgreSQL: Use json_array_elements and ILIKE
|
||||
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
|
||||
query = query.where(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -182,6 +183,11 @@ if "--use-file-pg-uri" in sys.argv:
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseChoice(str, Enum):
|
||||
POSTGRES = "postgres"
|
||||
SQLITE = "sqlite"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix="letta_", extra="ignore")
|
||||
|
||||
@@ -291,6 +297,10 @@ class Settings(BaseSettings):
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def database_engine(self) -> DatabaseChoice:
|
||||
return DatabaseChoice.POSTGRES if self.letta_pg_uri_no_default else DatabaseChoice.SQLITE
|
||||
|
||||
@property
|
||||
def plugin_register_dict(self) -> dict:
|
||||
plugins = {}
|
||||
|
||||
Reference in New Issue
Block a user