diff --git a/letta/schemas/letta_base.py b/letta/schemas/letta_base.py index abd87d5d..0d63fecf 100644 --- a/letta/schemas/letta_base.py +++ b/letta/schemas/letta_base.py @@ -1,10 +1,10 @@ import uuid from datetime import datetime from logging import getLogger -from typing import Optional +from typing import Any, Optional from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator # from: https://gist.github.com/norton120/22242eadb80bf2cf1dd54a961b151c61 @@ -12,7 +12,21 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator logger = getLogger(__name__) -class LettaBase(BaseModel): +class LoggingBaseModel(BaseModel): + """Base model with global validation error logging for all pydantic models.""" + + @model_validator(mode="wrap") + @classmethod + def _log_validation_errors(cls, values: Any, handler, info): + """Global validator to log validation errors with the full data that failed validation.""" + try: + return handler(values) + except ValidationError as e: + logger.error(f"Pydantic validation error in {cls.__name__}. Input data: {values}. Error: {e}") + raise + + +class LettaBase(LoggingBaseModel): """Base schema for Letta schemas (does not include model provider schemas, e.g. OpenAI)""" model_config = ConfigDict( diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index 1f5b5155..278979cb 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -5,6 +5,7 @@ from typing import Annotated, List, Literal, Optional, Union from pydantic import BaseModel, Field, field_serializer, field_validator +from letta.schemas.letta_base import LoggingBaseModel from letta.schemas.letta_message_content import ( LettaAssistantMessageContentUnion, LettaUserMessageContentUnion, @@ -22,7 +23,7 @@ class MessageReturnType(str, Enum): tool = "tool" -class MessageReturn(BaseModel): +class MessageReturn(LoggingBaseModel): type: MessageReturnType = Field(..., description="The message type to be created.")