Files
letta-server/letta/serialize_schemas/marshmallow_custom_fields.py
Kian Jones b8e9a80d93 merge this (#4759)
* wait I forgot to comit locally

* cp the entire core directory and then rm the .git subdir
2025-09-17 15:47:40 -07:00

82 lines
2.4 KiB
Python

from marshmallow import fields
from letta.helpers.converters import (
deserialize_embedding_config,
deserialize_llm_config,
deserialize_message_content,
deserialize_tool_calls,
deserialize_tool_rules,
serialize_embedding_config,
serialize_llm_config,
serialize_message_content,
serialize_tool_calls,
serialize_tool_rules,
)
class PydanticField(fields.Field):
    """Generic Marshmallow field for handling Pydantic models.

    Serializes a model instance with ``model_dump()`` and deserializes a
    mapping by passing it as keyword arguments to ``pydantic_class``.
    ``None`` passes through unchanged in both directions.
    """

    def __init__(self, pydantic_class, **kwargs):
        """Create the field.

        Args:
            pydantic_class: Pydantic model class used to rebuild values on
                deserialization.
            **kwargs: Forwarded unchanged to ``marshmallow.fields.Field``.
        """
        self.pydantic_class = pydantic_class
        super().__init__(**kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        """Return ``value.model_dump()``, or ``None`` when ``value`` is ``None``."""
        # Check for None explicitly rather than relying on truthiness, so only
        # a genuinely absent value maps to None.
        if value is None:
            return None
        return value.model_dump()

    def _deserialize(self, value, attr, data, **kwargs):
        """Rebuild a ``pydantic_class`` instance from ``value``.

        Returns ``None`` when ``value`` is ``None``; otherwise returns
        ``pydantic_class(**value)``. An empty mapping therefore produces a
        model built from its defaults, so a model whose dump is ``{}``
        round-trips instead of collapsing to ``None``.

        Raises:
            TypeError: if ``value`` is not a mapping (from ``**value``).
        """
        # Explicit None check: a truthiness test would also discard ``{}``
        # and other falsy inputs without any error.
        if value is None:
            return None
        return self.pydantic_class(**value)
class LLMConfigField(fields.Field):
    """Marshmallow field that converts an ``LLMConfig`` to and from its stored form.

    Serialization delegates to ``serialize_llm_config``. Deserialization
    delegates to ``deserialize_llm_config``.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """Convert the attribute value via ``serialize_llm_config``."""
        primitive = serialize_llm_config(value)
        return primitive

    def _deserialize(self, value, attr, data, **kwargs):
        """Rebuild the config from raw input via ``deserialize_llm_config``."""
        restored = deserialize_llm_config(value)
        return restored
class EmbeddingConfigField(fields.Field):
    """Marshmallow field that converts an ``EmbeddingConfig`` to and from its stored form.

    Both directions are delegated to the ``serialize_embedding_config`` /
    ``deserialize_embedding_config`` converters.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """Turn the embedding config into its serialized form."""
        dumped = serialize_embedding_config(value)
        return dumped

    def _deserialize(self, value, attr, data, **kwargs):
        """Turn serialized input back into an embedding config."""
        loaded = deserialize_embedding_config(value)
        return loaded
class ToolRulesField(fields.List):
    """Marshmallow list field holding tool rules.

    The list's item type is declared as a plain ``fields.Dict``. Both
    ``_serialize`` and ``_deserialize`` are overridden here and do not call
    the ``fields.List`` implementations. Conversion is handled entirely by
    ``serialize_tool_rules`` / ``deserialize_tool_rules``.
    """

    def __init__(self, **kwargs):
        """Create the field with ``fields.Dict`` items; ``kwargs`` go to ``fields.List``."""
        item_field = fields.Dict()
        super().__init__(item_field, **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        """Serialize the whole rule list via ``serialize_tool_rules``."""
        rules_out = serialize_tool_rules(value)
        return rules_out

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize the whole rule list via ``deserialize_tool_rules``."""
        rules_in = deserialize_tool_rules(value)
        return rules_in
class ToolCallField(fields.Field):
    """Marshmallow field for a list of OpenAI ``ToolCall`` objects.

    Delegates to ``serialize_tool_calls`` when dumping and to
    ``deserialize_tool_calls`` when loading.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """Dump the tool-call list with ``serialize_tool_calls``."""
        calls_out = serialize_tool_calls(value)
        return calls_out

    def _deserialize(self, value, attr, data, **kwargs):
        """Load the tool-call list with ``deserialize_tool_calls``."""
        calls_in = deserialize_tool_calls(value)
        return calls_in
class MessageContentField(fields.Field):
    """Marshmallow field for a list of message content-part objects.

    Conversion in both directions is handed off to
    ``serialize_message_content`` and ``deserialize_message_content``.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """Dump the content parts with ``serialize_message_content``."""
        parts_out = serialize_message_content(value)
        return parts_out

    def _deserialize(self, value, attr, data, **kwargs):
        """Load the content parts with ``deserialize_message_content``."""
        parts_in = deserialize_message_content(value)
        return parts_in