chore: Change create_tool endpoint on v1 routes to error instead of upsert (#2102)

This commit is contained in:
Matthew Zhou
2024-11-25 10:46:15 -08:00
committed by GitHub
parent f237717ce4
commit 8711e1dc00
17 changed files with 271 additions and 259 deletions

View File

@@ -1,11 +1,16 @@
from typing import TYPE_CHECKING, List, Literal, Optional, Type
from sqlalchemy import String, select
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Mapped, mapped_column
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import NoResultFound
from letta.orm.errors import (
ForeignKeyConstraintViolationError,
NoResultFound,
UniqueConstraintViolationError,
)
if TYPE_CHECKING:
from pydantic import BaseModel
@@ -102,12 +107,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if actor:
self._set_created_and_updated_by_fields(actor.id)
with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self
try:
with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self
except DBAPIError as e:
self._handle_dbapi_error(e)
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -168,6 +175,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
raise ValueError(f"object {actor} has no organization accessor")
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
@classmethod
def _handle_dbapi_error(cls, e: DBAPIError):
"""Handle database errors and raise appropriate custom exceptions."""
orig = e.orig # Extract the original error from the DBAPIError
error_code = None
# For psycopg2
if hasattr(orig, "pgcode"):
error_code = orig.pgcode
# For pg8000
elif hasattr(orig, "args") and len(orig.args) > 0:
# The first argument contains the error details as a dictionary
err_dict = orig.args[0]
if isinstance(err_dict, dict):
error_code = err_dict.get("C") # 'C' is the error code field
logger.info(f"Extracted error_code: {error_code}")
# Handle unique constraint violations
if error_code == "23505":
raise UniqueConstraintViolationError(
f"A unique constraint was violated for {cls.__name__}. Check your input for duplicates: {e}"
) from e
# Handle foreign key violations
if error_code == "23503":
raise ForeignKeyConstraintViolationError(
f"A foreign key constraint was violated for {cls.__name__}. Check your input for missing or invalid references: {e}"
) from e
# Re-raise for other unhandled DBAPI errors
raise
@property
def __pydantic_model__(self) -> Type["BaseModel"]:
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")