chore: import sqlite-vec even more conditionally (#2964)
This commit is contained in:
@@ -10,8 +10,16 @@ except PackageNotFoundError:
|
||||
if os.environ.get("LETTA_VERSION"):
|
||||
__version__ = os.environ["LETTA_VERSION"]
|
||||
|
||||
# Import sqlite_functions early to ensure event handlers are registered
|
||||
from letta.orm import sqlite_functions
|
||||
# Import sqlite_functions early to ensure event handlers are registered (only for SQLite)
|
||||
# This is only needed for the server, not for client usage
|
||||
try:
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
if settings.database_engine == DatabaseChoice.SQLITE:
|
||||
from letta.orm import sqlite_functions
|
||||
except ImportError:
|
||||
# If sqlite_vec is not installed, it's fine for client usage
|
||||
pass
|
||||
|
||||
# # imports for easier access
|
||||
from letta.schemas.agent import AgentState
|
||||
|
||||
@@ -44,8 +44,14 @@ from letta.schemas.tool_rule import (
|
||||
)
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
if settings.database_engine == DatabaseChoice.SQLITE:
|
||||
import sqlite_vec
|
||||
# Only import sqlite_vec if we're actually using SQLite database
|
||||
# This is a runtime dependency only needed for SQLite vector operations
|
||||
try:
|
||||
if settings.database_engine == DatabaseChoice.SQLITE:
|
||||
import sqlite_vec
|
||||
except ImportError:
|
||||
# If sqlite_vec is not installed, it's fine for client usage
|
||||
pass
|
||||
# --------------------------
|
||||
# LLMConfig Serialization
|
||||
# --------------------------
|
||||
|
||||
@@ -14,7 +14,6 @@ from sqlalchemy.orm.interfaces import ORMOption
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.settings import DatabaseChoice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -401,6 +400,8 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
||||
else:
|
||||
# SQLite with custom vector type
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
|
||||
query_embedding_binary = adapt_array(query_embedding)
|
||||
query = query.order_by(
|
||||
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
|
||||
|
||||
@@ -29,7 +29,6 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.passage import ArchivalPassage, SourcePassage
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.prompts import gpt_system
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
@@ -921,6 +920,8 @@ async def build_passage_query(
|
||||
main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
|
||||
else:
|
||||
# SQLite with custom vector type
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
|
||||
query_embedding_binary = adapt_array(embedded_text)
|
||||
main_query = main_query.order_by(
|
||||
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
||||
@@ -1054,6 +1055,8 @@ async def build_source_passage_query(
|
||||
query = query.order_by(SourcePassage.embedding.cosine_distance(embedded_text).asc())
|
||||
else:
|
||||
# SQLite with custom vector type
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
|
||||
query_embedding_binary = adapt_array(embedded_text)
|
||||
query = query.order_by(
|
||||
func.cosine_distance(SourcePassage.embedding, query_embedding_binary).asc(),
|
||||
@@ -1151,6 +1154,8 @@ async def build_agent_passage_query(
|
||||
query = query.order_by(ArchivalPassage.embedding.cosine_distance(embedded_text).asc())
|
||||
else:
|
||||
# SQLite with custom vector type
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
|
||||
query_embedding_binary = adapt_array(embedded_text)
|
||||
query = query.order_by(
|
||||
func.cosine_distance(ArchivalPassage.embedding, query_embedding_binary).asc(),
|
||||
|
||||
@@ -36,14 +36,6 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user"))
|
||||
# Set environment variable immediately to prevent pooling issues
|
||||
os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true"
|
||||
|
||||
# Recreate settings instance to pick up the environment variable
|
||||
import letta.settings
|
||||
|
||||
# Force settings reload after setting environment variable
|
||||
from letta.settings import Settings
|
||||
|
||||
letta.settings.settings = Settings()
|
||||
|
||||
|
||||
# Disable SQLAlchemy connection pooling for tests to prevent event loop issues
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
|
||||
Reference in New Issue
Block a user