diff --git a/letta/__init__.py b/letta/__init__.py index bb6eb8a4..b52f7c04 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -10,8 +10,16 @@ except PackageNotFoundError: if os.environ.get("LETTA_VERSION"): __version__ = os.environ["LETTA_VERSION"] -# Import sqlite_functions early to ensure event handlers are registered -from letta.orm import sqlite_functions +# Import sqlite_functions early to ensure event handlers are registered (only for SQLite) +# This is only needed for the server, not for client usage +try: + from letta.settings import DatabaseChoice, settings + + if settings.database_engine == DatabaseChoice.SQLITE: + from letta.orm import sqlite_functions +except ImportError: + # If sqlite_vec is not installed, it's fine for client usage + pass # # imports for easier access from letta.schemas.agent import AgentState diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index d2fc323a..522ee031 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -44,8 +44,14 @@ from letta.schemas.tool_rule import ( ) from letta.settings import DatabaseChoice, settings -if settings.database_engine == DatabaseChoice.SQLITE: - import sqlite_vec +# Only import sqlite_vec if we're actually using SQLite database +# This is a runtime dependency only needed for SQLite vector operations +try: + if settings.database_engine == DatabaseChoice.SQLITE: + import sqlite_vec +except ImportError: + # If sqlite_vec is not installed, it's fine for client usage + pass # -------------------------- # LLMConfig Serialization # -------------------------- diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index f9ee7452..c047839f 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -14,7 +14,6 @@ from sqlalchemy.orm.interfaces import ORMOption from letta.log import get_logger from letta.orm.base import Base, CommonSqlalchemyMetaMixins from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError -from letta.orm.sqlite_functions import adapt_array from letta.settings import DatabaseChoice if TYPE_CHECKING: @@ -401,6 +400,8 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc()) else: # SQLite with custom vector type + from letta.orm.sqlite_functions import adapt_array + query_embedding_binary = adapt_array(query_embedding) query = query.order_by( func.cosine_distance(cls.embedding, query_embedding_binary).asc(), diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index d3880a6c..f0998108 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -29,7 +29,6 @@ from letta.orm.errors import NoResultFound from letta.orm.identity import Identity from letta.orm.passage import ArchivalPassage, SourcePassage from letta.orm.sources_agents import SourcesAgents -from letta.orm.sqlite_functions import adapt_array from letta.otel.tracing import trace_method from letta.prompts import gpt_system from letta.prompts.prompt_generator import PromptGenerator @@ -921,6 +920,8 @@ async def build_passage_query( main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc()) else: # SQLite with custom vector type + from letta.orm.sqlite_functions import adapt_array + query_embedding_binary = adapt_array(embedded_text) main_query = main_query.order_by( func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(), @@ -1054,6 +1055,8 @@ async def build_source_passage_query( query = query.order_by(SourcePassage.embedding.cosine_distance(embedded_text).asc()) else: # SQLite with custom vector type + from letta.orm.sqlite_functions import adapt_array + query_embedding_binary = adapt_array(embedded_text) query = query.order_by( func.cosine_distance(SourcePassage.embedding, query_embedding_binary).asc(), @@ -1151,6 +1154,8 @@ async def build_agent_passage_query( query = query.order_by(ArchivalPassage.embedding.cosine_distance(embedded_text).asc()) else: # SQLite with custom vector type + from letta.orm.sqlite_functions import adapt_array + query_embedding_binary = adapt_array(embedded_text) query = query.order_by( func.cosine_distance(ArchivalPassage.embedding, query_embedding_binary).asc(), diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py index cf729e5d..884cad61 100644 --- a/tests/integration_test_async_tool_sandbox.py +++ b/tests/integration_test_async_tool_sandbox.py @@ -36,14 +36,6 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user")) # Set environment variable immediately to prevent pooling issues os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true" -# Recreate settings instance to pick up the environment variable -import letta.settings - -# Force settings reload after setting environment variable -from letta.settings import Settings - -letta.settings.settings = Settings() - # Disable SQLAlchemy connection pooling for tests to prevent event loop issues @pytest.fixture(scope="session", autouse=True)