feat: Serialize agent state simple fields and messages (#1012)

This commit is contained in:
Matthew Zhou
2025-02-18 11:01:10 -08:00
committed by GitHub
parent 3dc1767f46
commit b5e09536ae
28 changed files with 451 additions and 179 deletions

View File

@@ -0,0 +1 @@
from letta.serialize_schemas.agent import SerializedAgentSchema

View File

@@ -0,0 +1,36 @@
from marshmallow import fields
from letta.orm import Agent
from letta.serialize_schemas.base import BaseSchema
from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
from letta.serialize_schemas.message import SerializedMessageSchema
class SerializedAgentSchema(BaseSchema):
"""
Marshmallow schema for serializing/deserializing Agent objects.
Excludes relational fields.
"""
llm_config = LLMConfigField()
embedding_config = EmbeddingConfigField()
tool_rules = ToolRulesField()
messages = fields.List(fields.Nested(SerializedMessageSchema))
def __init__(self, *args, session=None, **kwargs):
super().__init__(*args, **kwargs)
if session:
self.session = session
# propagate session to nested schemas
for field_name, field_obj in self.fields.items():
if isinstance(field_obj, fields.List) and hasattr(field_obj.inner, "schema"):
field_obj.inner.schema.session = session
elif hasattr(field_obj, "schema"):
field_obj.schema.session = session
class Meta(BaseSchema.Meta):
model = Agent
# TODO: Serialize these as well...
exclude = ("tools", "sources", "core_memory", "tags", "source_passages", "agent_passages", "organization")

View File

@@ -0,0 +1,12 @@
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
class BaseSchema(SQLAlchemyAutoSchema):
"""
Base schema for all SQLAlchemy models.
This ensures all schemas share the same session.
"""
class Meta:
include_relationships = True
load_instance = True

View File

@@ -0,0 +1,69 @@
from marshmallow import fields
from letta.helpers.converters import (
deserialize_embedding_config,
deserialize_llm_config,
deserialize_tool_calls,
deserialize_tool_rules,
serialize_embedding_config,
serialize_llm_config,
serialize_tool_calls,
serialize_tool_rules,
)
class PydanticField(fields.Field):
"""Generic Marshmallow field for handling Pydantic models."""
def __init__(self, pydantic_class, **kwargs):
self.pydantic_class = pydantic_class
super().__init__(**kwargs)
def _serialize(self, value, attr, obj, **kwargs):
return value.model_dump() if value else None
def _deserialize(self, value, attr, data, **kwargs):
return self.pydantic_class(**value) if value else None
class LLMConfigField(fields.Field):
"""Marshmallow field for handling LLMConfig serialization."""
def _serialize(self, value, attr, obj, **kwargs):
return serialize_llm_config(value)
def _deserialize(self, value, attr, data, **kwargs):
return deserialize_llm_config(value)
class EmbeddingConfigField(fields.Field):
"""Marshmallow field for handling EmbeddingConfig serialization."""
def _serialize(self, value, attr, obj, **kwargs):
return serialize_embedding_config(value)
def _deserialize(self, value, attr, data, **kwargs):
return deserialize_embedding_config(value)
class ToolRulesField(fields.List):
"""Custom Marshmallow field to handle a list of ToolRules."""
def __init__(self, **kwargs):
super().__init__(fields.Dict(), **kwargs)
def _serialize(self, value, attr, obj, **kwargs):
return serialize_tool_rules(value)
def _deserialize(self, value, attr, data, **kwargs):
return deserialize_tool_rules(value)
class ToolCallField(fields.Field):
"""Marshmallow field for handling a list of OpenAI ToolCall objects."""
def _serialize(self, value, attr, obj, **kwargs):
return serialize_tool_calls(value)
def _deserialize(self, value, attr, data, **kwargs):
return deserialize_tool_calls(value)

View File

@@ -0,0 +1,15 @@
from letta.orm.message import Message
from letta.serialize_schemas.base import BaseSchema
from letta.serialize_schemas.custom_fields import ToolCallField
class SerializedMessageSchema(BaseSchema):
"""
Marshmallow schema for serializing/deserializing Message objects.
"""
tool_calls = ToolCallField()
class Meta(BaseSchema.Meta):
model = Message
exclude = ("step", "job_message")