feat: add print for sqlite error (#2221)

This commit is contained in:
Sarah Wooders
2024-12-10 19:23:48 -08:00
committed by GitHub
parent 25980e05cd
commit 36b97d00b2
3 changed files with 72 additions and 6 deletions

View File

@@ -156,6 +156,11 @@ class Server(object):
raise NotImplementedError
from contextlib import contextmanager
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
@@ -166,6 +171,37 @@ from letta.settings import model_settings, settings, tool_settings
config = LettaConfig.load()
def print_sqlite_schema_error():
"""Print a formatted error message for SQLite schema issues"""
console = Console()
error_text = Text()
error_text.append("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", style="bold red")
error_text.append("To have migrations supported between Letta versions, please run Letta with Docker (", style="white")
error_text.append("https://docs.letta.com/server/docker", style="blue underline")
error_text.append(") or use Postgres by setting ", style="white")
error_text.append("LETTA_PG_URI", style="yellow")
error_text.append(".\n\n", style="white")
error_text.append("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", style="white")
error_text.append("rm ~/.letta/sqlite.db", style="yellow")
error_text.append(" or downgrade to your previous version of Letta.", style="white")
console.print(Panel(error_text, border_style="red"))
@contextmanager
def db_error_handler():
"""Context manager for handling database errors"""
try:
yield
except Exception as e:
# Handle other SQLAlchemy errors
print(e)
print_sqlite_schema_error()
# raise ValueError(f"SQLite DB error: {str(e)}")
exit(1)
if settings.letta_pg_uri_no_default:
config.recall_storage_type = "postgres"
config.recall_storage_uri = settings.letta_pg_uri_no_default
@@ -178,6 +214,30 @@ else:
# TODO: don't rely on config storage
engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db"))
# Store the original connect method
original_connect = engine.connect
def wrapped_connect(*args, **kwargs):
with db_error_handler():
# Get the connection
connection = original_connect(*args, **kwargs)
# Store the original execution method
original_execute = connection.execute
# Wrap the execute method of the connection
def wrapped_execute(*args, **kwargs):
with db_error_handler():
return original_execute(*args, **kwargs)
# Replace the connection's execute method
connection.execute = wrapped_execute
return connection
# Replace the engine's connect method
engine.connect = wrapped_connect
Base.metadata.create_all(bind=engine)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@@ -379,7 +439,9 @@ class SyncServer(Server):
if agent_state.agent_type == AgentType.memgpt_agent:
agent = Agent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence)
elif agent_state.agent_type == AgentType.offline_memory_agent:
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence)
agent = OfflineMemoryAgent(
agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence
)
else:
assert initial_message_sequence is None, f"Initial message sequence is not supported for O1Agents"
agent = O1Agent(agent_state=agent_state, interface=interface, user=actor)
@@ -500,8 +562,8 @@ class SyncServer(Server):
letta_agent.attach_source(
user=self.user_manager.get_user_by_id(user_id=user_id),
source_id=data_source,
source_manager=letta_agent.source_manager,
ms=self.ms
source_manager=letta_agent.source_manager,
ms=self.ms,
)
elif command.lower() == "dump" or command.lower().startswith("dump "):
@@ -1267,7 +1329,10 @@ class SyncServer(Server):
# iterate over records
records = letta_agent.passage_manager.list_passages(
actor=self.default_user, agent_id=agent_id, cursor=cursor, limit=limit,
actor=self.default_user,
agent_id=agent_id,
cursor=cursor,
limit=limit,
)
return records
@@ -1914,7 +1979,7 @@ class SyncServer(Server):
date=get_utc_time(),
status="error",
function_return=error_msg,
stdout=[''],
stdout=[""],
stderr=[traceback.format_exc()],
)