fix: Scale up database (#2263)
This commit is contained in:
@@ -12,3 +12,11 @@ class UniqueConstraintViolationError(ValueError):
|
||||
|
||||
class ForeignKeyConstraintViolationError(ValueError):
|
||||
"""Custom exception for foreign key constraint violations."""
|
||||
|
||||
|
||||
class DatabaseTimeoutError(Exception):
|
||||
"""Custom exception for database timeout issues."""
|
||||
|
||||
def __init__(self, message="Database operation timed out", original_exception=None):
|
||||
super().__init__(message)
|
||||
self.original_exception = original_exception
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from sqlalchemy import String, desc, func, or_, select
|
||||
from sqlalchemy.exc import DBAPIError, IntegrityError
|
||||
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.errors import (
|
||||
DatabaseTimeoutError,
|
||||
ForeignKeyConstraintViolationError,
|
||||
NoResultFound,
|
||||
UniqueConstraintViolationError,
|
||||
@@ -23,6 +25,20 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def handle_db_timeout(func):
|
||||
"""Decorator to handle SQLAlchemy TimeoutError and wrap it in a custom exception."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except TimeoutError as e:
|
||||
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
|
||||
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AccessType(str, Enum):
|
||||
ORGANIZATION = "organization"
|
||||
USER = "user"
|
||||
@@ -36,22 +52,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
|
||||
@classmethod
|
||||
def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
|
||||
"""Get a record by ID.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
id: Record ID to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[SqlalchemyBase]: The record if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
return db_session.query(cls).filter(cls.id == id).first()
|
||||
except DBAPIError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
def list(
|
||||
cls,
|
||||
*,
|
||||
@@ -180,6 +181,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
def read(
|
||||
cls,
|
||||
db_session: "Session",
|
||||
@@ -231,6 +233,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
||||
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
|
||||
|
||||
@handle_db_timeout
|
||||
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
|
||||
@@ -245,6 +248,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
except (DBAPIError, IntegrityError) as e:
|
||||
self._handle_dbapi_error(e)
|
||||
|
||||
@handle_db_timeout
|
||||
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
|
||||
@@ -254,6 +258,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
self.is_deleted = True
|
||||
return self.update(db_session)
|
||||
|
||||
@handle_db_timeout
|
||||
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
|
||||
"""Permanently removes the record from the database."""
|
||||
logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
@@ -269,6 +274,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
else:
|
||||
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
||||
|
||||
@handle_db_timeout
|
||||
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
if actor:
|
||||
@@ -281,6 +287,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
def size(
|
||||
cls,
|
||||
*,
|
||||
|
||||
@@ -15,7 +15,12 @@ from letta.__init__ import __version__
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.errors import (
|
||||
DatabaseTimeoutError,
|
||||
ForeignKeyConstraintViolationError,
|
||||
NoResultFound,
|
||||
UniqueConstraintViolationError,
|
||||
)
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
|
||||
@@ -175,7 +180,6 @@ def create_application() -> "FastAPI":
|
||||
|
||||
@app.exception_handler(NoResultFound)
|
||||
async def no_result_found_handler(request: Request, exc: NoResultFound):
|
||||
logger.error(f"NoResultFound request: {request}")
|
||||
logger.error(f"NoResultFound: {exc}")
|
||||
|
||||
return JSONResponse(
|
||||
@@ -183,6 +187,32 @@ def create_application() -> "FastAPI":
|
||||
content={"detail": str(exc)},
|
||||
)
|
||||
|
||||
@app.exception_handler(ForeignKeyConstraintViolationError)
|
||||
async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError):
|
||||
logger.error(f"ForeignKeyConstraintViolationError: {exc}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content={"detail": str(exc)},
|
||||
)
|
||||
|
||||
@app.exception_handler(UniqueConstraintViolationError)
|
||||
async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError):
|
||||
logger.error(f"UniqueConstraintViolationError: {exc}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content={"detail": str(exc)},
|
||||
)
|
||||
|
||||
@app.exception_handler(DatabaseTimeoutError)
|
||||
async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
|
||||
logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={"detail": "The database is temporarily unavailable. Please try again later."},
|
||||
)
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def value_error_handler(request: Request, exc: ValueError):
|
||||
return JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
@@ -235,11 +265,6 @@ def create_application() -> "FastAPI":
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
# load the default tools
|
||||
# from letta.orm.tool import Tool
|
||||
|
||||
# Tool.load_default_tools(get_db_session())
|
||||
|
||||
generate_openapi_schema(app)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
|
||||
@@ -190,7 +190,14 @@ if settings.letta_pg_uri_no_default:
|
||||
config.archival_storage_uri = settings.letta_pg_uri_no_default
|
||||
|
||||
# create engine
|
||||
engine = create_engine(settings.letta_pg_uri)
|
||||
engine = create_engine(
|
||||
settings.letta_pg_uri,
|
||||
pool_size=settings.pg_pool_size,
|
||||
max_overflow=settings.pg_max_overflow,
|
||||
pool_timeout=settings.pg_pool_timeout,
|
||||
pool_recycle=settings.pg_pool_recycle,
|
||||
echo=settings.pg_echo,
|
||||
)
|
||||
else:
|
||||
# TODO: don't rely on config storage
|
||||
engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db"))
|
||||
|
||||
@@ -17,7 +17,7 @@ class ToolSettings(BaseSettings):
|
||||
|
||||
class ModelSettings(BaseSettings):
|
||||
|
||||
model_config = SettingsConfigDict(env_file='.env', extra='ignore')
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
||||
|
||||
# env_prefix='my_prefix_'
|
||||
|
||||
@@ -64,7 +64,7 @@ cors_origins = ["http://letta.localhost", "http://localhost:8283", "http://local
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix="letta_", extra='ignore')
|
||||
model_config = SettingsConfigDict(env_prefix="letta_", extra="ignore")
|
||||
|
||||
letta_dir: Optional[Path] = Field(Path.home() / ".letta", env="LETTA_DIR")
|
||||
debug: Optional[bool] = False
|
||||
@@ -76,7 +76,12 @@ class Settings(BaseSettings):
|
||||
pg_password: Optional[str] = None
|
||||
pg_host: Optional[str] = None
|
||||
pg_port: Optional[int] = None
|
||||
pg_uri: Optional[str] = None # option to specifiy full uri
|
||||
pg_uri: Optional[str] = None # option to specify full uri
|
||||
pg_pool_size: int = 20 # Concurrent connections
|
||||
pg_max_overflow: int = 10 # Overflow limit
|
||||
pg_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
pg_pool_recycle: int = 1800 # When to recycle connections
|
||||
pg_echo: bool = False # Logging
|
||||
|
||||
# tools configuration
|
||||
load_default_external_tools: Optional[bool] = None
|
||||
@@ -103,7 +108,7 @@ class Settings(BaseSettings):
|
||||
|
||||
|
||||
class TestSettings(Settings):
|
||||
model_config = SettingsConfigDict(env_prefix="letta_test_", extra='ignore')
|
||||
model_config = SettingsConfigDict(env_prefix="letta_test_", extra="ignore")
|
||||
|
||||
letta_dir: Optional[Path] = Field(Path.home() / ".letta/test", env="LETTA_TEST_DIR")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user