fix: Scale up database (#2263)

This commit is contained in:
Matthew Zhou
2024-12-17 15:02:28 -08:00
committed by GitHub
parent 27ea364a32
commit e09bde67ef
5 changed files with 81 additions and 29 deletions

View File

@@ -12,3 +12,11 @@ class UniqueConstraintViolationError(ValueError):
class ForeignKeyConstraintViolationError(ValueError):
"""Custom exception for foreign key constraint violations."""
class DatabaseTimeoutError(Exception):
    """Error raised when a database operation times out.

    The instance keeps a reference to the lower-level exception that
    triggered it in ``original_exception`` (``None`` when not supplied).
    """

    def __init__(self, message: str = "Database operation timed out", original_exception=None) -> None:
        """Store the triggering exception and initialise the base ``Exception``.

        Args:
            message: Human-readable description; becomes ``str(self)``.
            original_exception: The underlying exception, if any.
        """
        # Independent of the base-class initialisation, so set it first.
        self.original_exception = original_exception
        super().__init__(message)

View File

@@ -1,14 +1,16 @@
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import TYPE_CHECKING, List, Literal, Optional
from sqlalchemy import String, desc, func, or_, select
from sqlalchemy.exc import DBAPIError, IntegrityError
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
from sqlalchemy.orm import Mapped, Session, mapped_column
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import (
DatabaseTimeoutError,
ForeignKeyConstraintViolationError,
NoResultFound,
UniqueConstraintViolationError,
@@ -23,6 +25,20 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def handle_db_timeout(func):
    """Decorator that wraps SQLAlchemy ``TimeoutError`` in ``DatabaseTimeoutError``.

    ``TimeoutError`` here is the name imported from ``sqlalchemy.exc`` at the
    top of this module, which shadows the builtin of the same name. Any other
    exception raised by ``func`` propagates unchanged.

    Note: inside this decorator the parameter ``func`` shadows the ``func``
    object imported from ``sqlalchemy``; the shadowing is local to this scope.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same signature and metadata as ``func`` (via
        ``functools.wraps``).

    Raises:
        DatabaseTimeoutError: When ``func`` raises ``TimeoutError``. The
            original exception is available both as ``original_exception``
            and as ``__cause__``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except TimeoutError as e:
            logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
            # Chain explicitly so tracebacks show the timeout as the direct
            # cause rather than as an error "during handling" of another.
            raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e) from e

    return wrapper
class AccessType(str, Enum):
ORGANIZATION = "organization"
USER = "user"
@@ -36,22 +52,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
id: Mapped[str] = mapped_column(String, primary_key=True)
@classmethod
def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
    """Get a record by ID.

    Args:
        db_session: SQLAlchemy session
        id: Record ID to retrieve

    Returns:
        Optional[SqlalchemyBase]: The record if found. None when no row
        matches, and also when the query raises ``DBAPIError``.
    """
    try:
        return db_session.query(cls).filter(cls.id == id).first()
    except DBAPIError as e:
        # Best-effort lookup: callers receive None on a database error, but
        # the failure is logged so it is not mistaken for "record not found".
        logger.warning(f"DBAPIError while fetching {cls.__name__} with ID {id}: {e}")
        return None
@classmethod
@handle_db_timeout
def list(
cls,
*,
@@ -180,6 +181,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
return list(session.execute(query).scalars())
@classmethod
@handle_db_timeout
def read(
cls,
db_session: "Session",
@@ -231,6 +233,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
@handle_db_timeout
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -245,6 +248,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
except (DBAPIError, IntegrityError) as e:
self._handle_dbapi_error(e)
@handle_db_timeout
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -254,6 +258,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
self.is_deleted = True
return self.update(db_session)
@handle_db_timeout
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
"""Permanently removes the record from the database."""
logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -269,6 +274,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
else:
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
@handle_db_timeout
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
if actor:
@@ -281,6 +287,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
return self
@classmethod
@handle_db_timeout
def size(
cls,
*,

View File

@@ -15,7 +15,12 @@ from letta.__init__ import __version__
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.errors import (
DatabaseTimeoutError,
ForeignKeyConstraintViolationError,
NoResultFound,
UniqueConstraintViolationError,
)
from letta.schemas.letta_response import LettaResponse
from letta.server.constants import REST_DEFAULT_PORT
@@ -175,7 +180,6 @@ def create_application() -> "FastAPI":
@app.exception_handler(NoResultFound)
async def no_result_found_handler(request: Request, exc: NoResultFound):
logger.error(f"NoResultFound request: {request}")
logger.error(f"NoResultFound: {exc}")
return JSONResponse(
@@ -183,6 +187,32 @@ def create_application() -> "FastAPI":
content={"detail": str(exc)},
)
@app.exception_handler(ForeignKeyConstraintViolationError)
async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError):
    """Log a foreign-key constraint violation and answer with HTTP 409.

    The response body is ``{"detail": str(exc)}``; ``request`` is unused.
    """
    logger.error(f"ForeignKeyConstraintViolationError: {exc}")
    body = {"detail": str(exc)}
    return JSONResponse(content=body, status_code=409)
@app.exception_handler(UniqueConstraintViolationError)
async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError):
    """Log a unique-constraint violation and answer with HTTP 409.

    The response body is ``{"detail": str(exc)}``; ``request`` is unused.
    """
    logger.error(f"UniqueConstraintViolationError: {exc}")
    body = {"detail": str(exc)}
    return JSONResponse(content=body, status_code=409)
@app.exception_handler(DatabaseTimeoutError)
async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
    """Log a database timeout and answer with HTTP 503.

    The client receives a fixed, generic message; the exception text and the
    wrapped original exception are only written to the log.
    """
    logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
    detail = "The database is temporarily unavailable. Please try again later."
    return JSONResponse(content={"detail": detail}, status_code=503)
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
    """Answer any ``ValueError`` with HTTP 400 and ``{"detail": str(exc)}``.

    NOTE(review): presumably exception types with their own registered
    handler are dispatched there instead of here - confirm against the
    framework's handler lookup.
    """
    payload = {"detail": str(exc)}
    return JSONResponse(content=payload, status_code=400)
@@ -235,11 +265,6 @@ def create_application() -> "FastAPI":
@app.on_event("startup")
def on_startup():
    """Startup hook: call ``generate_openapi_schema`` for this application."""
    generate_openapi_schema(app)
@app.on_event("shutdown")

View File

@@ -190,7 +190,14 @@ if settings.letta_pg_uri_no_default:
config.archival_storage_uri = settings.letta_pg_uri_no_default
# create engine
engine = create_engine(settings.letta_pg_uri)
engine = create_engine(
settings.letta_pg_uri,
pool_size=settings.pg_pool_size,
max_overflow=settings.pg_max_overflow,
pool_timeout=settings.pg_pool_timeout,
pool_recycle=settings.pg_pool_recycle,
echo=settings.pg_echo,
)
else:
# TODO: don't rely on config storage
engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db"))

View File

@@ -17,7 +17,7 @@ class ToolSettings(BaseSettings):
class ModelSettings(BaseSettings):
model_config = SettingsConfigDict(env_file='.env', extra='ignore')
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
# env_prefix='my_prefix_'
@@ -64,7 +64,7 @@ cors_origins = ["http://letta.localhost", "http://localhost:8283", "http://local
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="letta_", extra='ignore')
model_config = SettingsConfigDict(env_prefix="letta_", extra="ignore")
letta_dir: Optional[Path] = Field(Path.home() / ".letta", env="LETTA_DIR")
debug: Optional[bool] = False
@@ -76,7 +76,12 @@ class Settings(BaseSettings):
pg_password: Optional[str] = None
pg_host: Optional[str] = None
pg_port: Optional[int] = None
pg_uri: Optional[str] = None # option to specifiy full uri
pg_uri: Optional[str] = None # option to specify full uri
pg_pool_size: int = 20 # Concurrent connections
pg_max_overflow: int = 10 # Overflow limit
pg_pool_timeout: int = 30 # Seconds to wait for a connection
pg_pool_recycle: int = 1800 # When to recycle connections
pg_echo: bool = False # Logging
# tools configuration
load_default_external_tools: Optional[bool] = None
@@ -103,7 +108,7 @@ class Settings(BaseSettings):
class TestSettings(Settings):
model_config = SettingsConfigDict(env_prefix="letta_test_", extra='ignore')
model_config = SettingsConfigDict(env_prefix="letta_test_", extra="ignore")
letta_dir: Optional[Path] = Field(Path.home() / ".letta/test", env="LETTA_TEST_DIR")