feat: properly scrub ids from serialized schemas and add version (#1258)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import fields, post_dump
|
||||
from marshmallow import fields, post_dump, pre_load
|
||||
|
||||
import letta
|
||||
from letta.orm import Agent
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.user import User
|
||||
@@ -23,6 +24,12 @@ class SerializedAgentSchema(BaseSchema):
|
||||
|
||||
__pydantic_model__ = PydanticAgentState
|
||||
|
||||
FIELD_VERSION = "version"
|
||||
FIELD_MESSAGES = "messages"
|
||||
FIELD_MESSAGE_IDS = "message_ids"
|
||||
FIELD_IN_CONTEXT = "in_context"
|
||||
FIELD_ID = "id"
|
||||
|
||||
llm_config = LLMConfigField()
|
||||
embedding_config = EmbeddingConfigField()
|
||||
tool_rules = ToolRulesField()
|
||||
@@ -48,25 +55,47 @@ class SerializedAgentSchema(BaseSchema):
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
"""
|
||||
- Removes `message_ids`
|
||||
- Adds versioning
|
||||
- Marks messages as in-context
|
||||
- Removes individual message `id` fields
|
||||
"""
|
||||
data = super().sanitize_ids(data, **kwargs)
|
||||
data[self.FIELD_VERSION] = letta.__version__
|
||||
|
||||
# Remap IDs of messages
|
||||
# Need to do this in post, so we can correctly map the in-context message IDs
|
||||
# TODO: Remap message_ids to reference objects, not just be a list
|
||||
id_remapping = dict()
|
||||
for message in data.get("messages"):
|
||||
message_id = message.get("id")
|
||||
if message_id not in id_remapping:
|
||||
id_remapping[message_id] = SerializedMessageSchema.__pydantic_model__.generate_id()
|
||||
message["id"] = id_remapping[message_id]
|
||||
else:
|
||||
raise ValueError(f"Duplicate message IDs in agent.messages: {message_id}")
|
||||
message_ids = set(data.pop(self.FIELD_MESSAGE_IDS, [])) # Store and remove message_ids
|
||||
|
||||
# Remap in context message ids
|
||||
data["message_ids"] = [id_remapping[message_id] for message_id in data.get("message_ids")]
|
||||
for message in data.get(self.FIELD_MESSAGES, []):
|
||||
message[self.FIELD_IN_CONTEXT] = message[self.FIELD_ID] in message_ids # Mark messages as in-context
|
||||
message.pop(self.FIELD_ID, None) # Remove the id field
|
||||
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def check_version(self, data, **kwargs):
|
||||
"""Check version and remove it from the schema"""
|
||||
version = data[self.FIELD_VERSION]
|
||||
if version != letta.__version__:
|
||||
print(f"Version mismatch: expected {letta.__version__}, got {version}")
|
||||
del data[self.FIELD_VERSION]
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def remap_in_context_messages(self, data, **kwargs):
|
||||
"""
|
||||
Restores `message_ids` by collecting message IDs where `in_context` is True,
|
||||
generates new IDs for all messages, and removes `in_context` from all messages.
|
||||
"""
|
||||
message_ids = []
|
||||
for msg in data.get(self.FIELD_MESSAGES, []):
|
||||
msg[self.FIELD_ID] = SerializedMessageSchema.generate_id() # Generate new ID
|
||||
if msg.pop(self.FIELD_IN_CONTEXT, False): # If it was in-context, track its new ID
|
||||
message_ids.append(msg[self.FIELD_ID])
|
||||
|
||||
data[self.FIELD_MESSAGE_IDS] = message_ids
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
# TODO: Serialize these as well...
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Dict, Optional
|
||||
|
||||
from marshmallow import post_dump, pre_load
|
||||
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
|
||||
from sqlalchemy.inspection import inspect
|
||||
|
||||
from letta.schemas.user import User
|
||||
|
||||
@@ -14,46 +13,35 @@ class BaseSchema(SQLAlchemyAutoSchema):
|
||||
"""
|
||||
|
||||
__pydantic_model__ = None
|
||||
sensitive_ids = {"_created_by_id", "_last_updated_by_id"}
|
||||
sensitive_relationships = {"organization"}
|
||||
id_scramble_placeholder = "xxx"
|
||||
|
||||
def __init__(self, *args, actor: Optional[User] = None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.actor = actor
|
||||
|
||||
def generate_id(self) -> Optional[str]:
|
||||
if self.__pydantic_model__:
|
||||
return self.__pydantic_model__.generate_id()
|
||||
@classmethod
|
||||
def generate_id(cls) -> Optional[str]:
|
||||
if cls.__pydantic_model__:
|
||||
return cls.__pydantic_model__.generate_id()
|
||||
|
||||
return None
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
if self.Meta.model:
|
||||
mapper = inspect(self.Meta.model)
|
||||
if "id" in mapper.columns:
|
||||
generated_id = self.generate_id()
|
||||
if generated_id:
|
||||
data["id"] = generated_id
|
||||
|
||||
for sensitive_id in BaseSchema.sensitive_ids.union(BaseSchema.sensitive_relationships):
|
||||
if sensitive_id in data:
|
||||
data[sensitive_id] = BaseSchema.id_scramble_placeholder
|
||||
# delete id
|
||||
del data["id"]
|
||||
del data["_created_by_id"]
|
||||
del data["_last_updated_by_id"]
|
||||
del data["organization"]
|
||||
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
if self.Meta.model:
|
||||
mapper = inspect(self.Meta.model)
|
||||
for sensitive_id in BaseSchema.sensitive_ids:
|
||||
if sensitive_id in mapper.columns:
|
||||
data[sensitive_id] = self.actor.id
|
||||
|
||||
for relationship in BaseSchema.sensitive_relationships:
|
||||
if relationship in mapper.relationships:
|
||||
data[relationship] = self.actor.organization_id
|
||||
data["id"] = self.generate_id()
|
||||
data["_created_by_id"] = self.actor.id
|
||||
data["_last_updated_by_id"] = self.actor.id
|
||||
data["organization"] = self.actor.organization_id
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import post_dump
|
||||
from marshmallow import post_dump, pre_load
|
||||
|
||||
from letta.orm.message import Message
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
@@ -18,12 +18,25 @@ class SerializedMessageSchema(BaseSchema):
|
||||
tool_calls = ToolCallField()
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
# We don't want to remap here
|
||||
# Because of the way that message_ids is just a JSON field on agents
|
||||
# We need to wait for the agent dumps, and then keep track of all the message IDs we remapped
|
||||
def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
# keep id for remapping later on agent dump
|
||||
# agent dump will then get rid of message ids
|
||||
del data["_created_by_id"]
|
||||
del data["_last_updated_by_id"]
|
||||
del data["organization"]
|
||||
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
if self.Meta.model:
|
||||
# Skip regenerating ID, as agent dump will do it
|
||||
data["_created_by_id"] = self.actor.id
|
||||
data["_last_updated_by_id"] = self.actor.id
|
||||
data["organization"] = self.actor.organization_id
|
||||
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Message
|
||||
exclude = BaseSchema.Meta.exclude + ("step", "job_message", "agent")
|
||||
exclude = BaseSchema.Meta.exclude + ("step", "job_message", "agent", "otid", "is_deleted")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from marshmallow import fields
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import fields, post_dump, pre_load
|
||||
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
@@ -13,6 +15,14 @@ class SerializedAgentTagSchema(BaseSchema):
|
||||
|
||||
tag = fields.String(required=True)
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = AgentsTags
|
||||
exclude = BaseSchema.Meta.exclude + ("agent",)
|
||||
|
||||
@@ -394,22 +394,6 @@ def test_deserialize_override_existing_tools(
|
||||
assert existing_tool.source_code == weather_tool.source_code, f"Tool {tool_name} should NOT be overridden"
|
||||
|
||||
|
||||
def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Check remapping on message_ids and messages is consistent
|
||||
assert sorted([m["id"] for m in result["messages"]]) == sorted(result["message_ids"])
|
||||
|
||||
# Deserialize the agent
|
||||
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user)
|
||||
|
||||
# Make sure all the messages are able to be retrieved
|
||||
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user)
|
||||
assert len(in_context_messages) == len(result["message_ids"])
|
||||
assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"])
|
||||
|
||||
|
||||
def test_agent_serialize_with_user_messages(local_client, server, serialize_test_agent, default_user, other_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
append_copy_suffix = False
|
||||
@@ -473,6 +457,19 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
|
||||
assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0
|
||||
|
||||
|
||||
def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Deserialize the agent
|
||||
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user)
|
||||
|
||||
# Make sure all the messages are able to be retrieved
|
||||
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user)
|
||||
assert len(in_context_messages) == len(result["message_ids"])
|
||||
assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"])
|
||||
|
||||
|
||||
# FastAPI endpoint tests
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user