feat: Serialize agent state simple fields and messages (#1012)
This commit is contained in:
1
letta/serialize_schemas/__init__.py
Normal file
1
letta/serialize_schemas/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from letta.serialize_schemas.agent import SerializedAgentSchema
|
||||
36
letta/serialize_schemas/agent.py
Normal file
36
letta/serialize_schemas/agent.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from marshmallow import fields
|
||||
|
||||
from letta.orm import Agent
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
|
||||
from letta.serialize_schemas.message import SerializedMessageSchema
|
||||
|
||||
|
||||
class SerializedAgentSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Agent objects.
|
||||
Excludes relational fields.
|
||||
"""
|
||||
|
||||
llm_config = LLMConfigField()
|
||||
embedding_config = EmbeddingConfigField()
|
||||
tool_rules = ToolRulesField()
|
||||
|
||||
messages = fields.List(fields.Nested(SerializedMessageSchema))
|
||||
|
||||
def __init__(self, *args, session=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if session:
|
||||
self.session = session
|
||||
|
||||
# propagate session to nested schemas
|
||||
for field_name, field_obj in self.fields.items():
|
||||
if isinstance(field_obj, fields.List) and hasattr(field_obj.inner, "schema"):
|
||||
field_obj.inner.schema.session = session
|
||||
elif hasattr(field_obj, "schema"):
|
||||
field_obj.schema.session = session
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
# TODO: Serialize these as well...
|
||||
exclude = ("tools", "sources", "core_memory", "tags", "source_passages", "agent_passages", "organization")
|
||||
12
letta/serialize_schemas/base.py
Normal file
12
letta/serialize_schemas/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
|
||||
|
||||
|
||||
class BaseSchema(SQLAlchemyAutoSchema):
|
||||
"""
|
||||
Base schema for all SQLAlchemy models.
|
||||
This ensures all schemas share the same session.
|
||||
"""
|
||||
|
||||
class Meta:
|
||||
include_relationships = True
|
||||
load_instance = True
|
||||
69
letta/serialize_schemas/custom_fields.py
Normal file
69
letta/serialize_schemas/custom_fields.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from marshmallow import fields
|
||||
|
||||
from letta.helpers.converters import (
|
||||
deserialize_embedding_config,
|
||||
deserialize_llm_config,
|
||||
deserialize_tool_calls,
|
||||
deserialize_tool_rules,
|
||||
serialize_embedding_config,
|
||||
serialize_llm_config,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_rules,
|
||||
)
|
||||
|
||||
|
||||
class PydanticField(fields.Field):
|
||||
"""Generic Marshmallow field for handling Pydantic models."""
|
||||
|
||||
def __init__(self, pydantic_class, **kwargs):
|
||||
self.pydantic_class = pydantic_class
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return value.model_dump() if value else None
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return self.pydantic_class(**value) if value else None
|
||||
|
||||
|
||||
class LLMConfigField(fields.Field):
|
||||
"""Marshmallow field for handling LLMConfig serialization."""
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_llm_config(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_llm_config(value)
|
||||
|
||||
|
||||
class EmbeddingConfigField(fields.Field):
|
||||
"""Marshmallow field for handling EmbeddingConfig serialization."""
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_embedding_config(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_embedding_config(value)
|
||||
|
||||
|
||||
class ToolRulesField(fields.List):
|
||||
"""Custom Marshmallow field to handle a list of ToolRules."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(fields.Dict(), **kwargs)
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_tool_rules(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_tool_rules(value)
|
||||
|
||||
|
||||
class ToolCallField(fields.Field):
|
||||
"""Marshmallow field for handling a list of OpenAI ToolCall objects."""
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_tool_calls(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_tool_calls(value)
|
||||
15
letta/serialize_schemas/message.py
Normal file
15
letta/serialize_schemas/message.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from letta.orm.message import Message
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.custom_fields import ToolCallField
|
||||
|
||||
|
||||
class SerializedMessageSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Message objects.
|
||||
"""
|
||||
|
||||
tool_calls = ToolCallField()
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Message
|
||||
exclude = ("step", "job_message")
|
||||
Reference in New Issue
Block a user