From 976f90139f1386337bf38b66e559d98ef584475a Mon Sep 17 00:00:00 2001
From: Sarah Wooders
Date: Thu, 13 Mar 2025 16:18:50 -0700
Subject: [PATCH] feat: properly scrub ids from serialized schemas and add version (#1258)

Co-authored-by: Matt Zhou
---
 letta/serialize_schemas/agent.py   | 57 ++++++++++++++++++++++--------
 letta/serialize_schemas/base.py    | 37 +++++++-------------
 letta/serialize_schemas/message.py | 24 +++++++++----
 letta/serialize_schemas/tag.py     | 12 ++++++-
 tests/test_agent_serialization.py  | 31 ++++++++--------
 5 files changed, 99 insertions(+), 62 deletions(-)

diff --git a/letta/serialize_schemas/agent.py b/letta/serialize_schemas/agent.py
index 7cef7d15..9a8b9e5f 100644
--- a/letta/serialize_schemas/agent.py
+++ b/letta/serialize_schemas/agent.py
@@ -1,7 +1,8 @@
 from typing import Dict
 
-from marshmallow import fields, post_dump
+from marshmallow import fields, post_dump, pre_load
 
+import letta
 from letta.orm import Agent
 from letta.schemas.agent import AgentState as PydanticAgentState
 from letta.schemas.user import User
@@ -23,6 +24,12 @@ class SerializedAgentSchema(BaseSchema):
 
     __pydantic_model__ = PydanticAgentState
 
+    FIELD_VERSION = "version"
+    FIELD_MESSAGES = "messages"
+    FIELD_MESSAGE_IDS = "message_ids"
+    FIELD_IN_CONTEXT = "in_context"
+    FIELD_ID = "id"
+
     llm_config = LLMConfigField()
     embedding_config = EmbeddingConfigField()
     tool_rules = ToolRulesField()
@@ -48,25 +55,47 @@ class SerializedAgentSchema(BaseSchema):
 
     @post_dump
     def sanitize_ids(self, data: Dict, **kwargs):
+        """
+        - Removes `message_ids`
+        - Adds versioning
+        - Marks messages as in-context
+        - Removes individual message `id` fields
+        """
         data = super().sanitize_ids(data, **kwargs)
+        data[self.FIELD_VERSION] = letta.__version__
 
-        # Remap IDs of messages
-        # Need to do this in post, so we can correctly map the in-context message IDs
-        # TODO: Remap message_ids to reference objects, not just be a list
-        id_remapping = dict()
-        for message in data.get("messages"):
-            message_id = message.get("id")
-            if message_id not in id_remapping:
-                id_remapping[message_id] = SerializedMessageSchema.__pydantic_model__.generate_id()
-                message["id"] = id_remapping[message_id]
-            else:
-                raise ValueError(f"Duplicate message IDs in agent.messages: {message_id}")
+        message_ids = set(data.pop(self.FIELD_MESSAGE_IDS, None) or [])  # Store and remove message_ids (may be None)
 
-        # Remap in context message ids
-        data["message_ids"] = [id_remapping[message_id] for message_id in data.get("message_ids")]
+        for message in data.get(self.FIELD_MESSAGES, []):
+            # Pop the id once: it leaves the payload and only decides the in-context flag
+            message[self.FIELD_IN_CONTEXT] = message.pop(self.FIELD_ID, None) in message_ids
 
         return data
 
+    @pre_load
+    def check_version(self, data, **kwargs):
+        """Pop the `version` marker from the payload and report a mismatch with the running letta version"""
+        # pop() with a default: a payload without the key (older export, or a dict this hook already consumed) must not raise
+        version = data.pop(self.FIELD_VERSION, None)
+        if version != letta.__version__:
+            print(f"Version mismatch: expected {letta.__version__}, got {version}")
+        return data
+
+    @pre_load
+    def remap_in_context_messages(self, data, **kwargs):
+        """
+        Restores `message_ids` by collecting message IDs where `in_context` is True,
+        generates new IDs for all messages, and removes `in_context` from all messages.
+        """
+        message_ids = []
+        for msg in data.get(self.FIELD_MESSAGES, []):
+            msg[self.FIELD_ID] = SerializedMessageSchema.generate_id()  # Generate new ID
+            if msg.pop(self.FIELD_IN_CONTEXT, False):  # If it was in-context, track its new ID
+                message_ids.append(msg[self.FIELD_ID])
+
+        data[self.FIELD_MESSAGE_IDS] = message_ids
+        return data
+
     class Meta(BaseSchema.Meta):
         model = Agent
         # TODO: Serialize these as well...
diff --git a/letta/serialize_schemas/base.py b/letta/serialize_schemas/base.py
index de142467..50e53fd6 100644
--- a/letta/serialize_schemas/base.py
+++ b/letta/serialize_schemas/base.py
@@ -2,7 +2,6 @@ from typing import Dict, Optional
 
 from marshmallow import post_dump, pre_load
 from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
-from sqlalchemy.inspection import inspect
 
 from letta.schemas.user import User
 
@@ -14,46 +13,34 @@ class BaseSchema(SQLAlchemyAutoSchema):
     """
 
     __pydantic_model__ = None
-    sensitive_ids = {"_created_by_id", "_last_updated_by_id"}
-    sensitive_relationships = {"organization"}
-    id_scramble_placeholder = "xxx"
 
     def __init__(self, *args, actor: Optional[User] = None, **kwargs):
         super().__init__(*args, **kwargs)
         self.actor = actor
 
-    def generate_id(self) -> Optional[str]:
-        if self.__pydantic_model__:
-            return self.__pydantic_model__.generate_id()
+    @classmethod
+    def generate_id(cls) -> Optional[str]:
+        if cls.__pydantic_model__:
+            return cls.__pydantic_model__.generate_id()
 
         return None
 
     @post_dump
     def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
-        if self.Meta.model:
-            mapper = inspect(self.Meta.model)
-            if "id" in mapper.columns:
-                generated_id = self.generate_id()
-                if generated_id:
-                    data["id"] = generated_id
-
-        for sensitive_id in BaseSchema.sensitive_ids.union(BaseSchema.sensitive_relationships):
-            if sensitive_id in data:
-                data[sensitive_id] = BaseSchema.id_scramble_placeholder
+        """Scrub the record id and the creator/updater/organization fields from the dumped payload."""
+        # pop() instead of del so a dump that lacks one of these keys does not raise KeyError
+        for key in ("id", "_created_by_id", "_last_updated_by_id", "organization"):
+            data.pop(key, None)
 
         return data
 
     @pre_load
     def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
         if self.Meta.model:
-            mapper = inspect(self.Meta.model)
-            for sensitive_id in BaseSchema.sensitive_ids:
-                if sensitive_id in mapper.columns:
-                    data[sensitive_id] = self.actor.id
-
-            for relationship in BaseSchema.sensitive_relationships:
-                if relationship in mapper.relationships:
-                    data[relationship] = self.actor.organization_id
+            data["id"] = self.generate_id()
+            data["_created_by_id"] = self.actor.id
+            data["_last_updated_by_id"] = self.actor.id
+            data["organization"] = self.actor.organization_id
 
         return data
 
diff --git a/letta/serialize_schemas/message.py b/letta/serialize_schemas/message.py
index f1300d24..187b8f88 100644
--- a/letta/serialize_schemas/message.py
+++ b/letta/serialize_schemas/message.py
@@ -1,6 +1,6 @@
 from typing import Dict
 
-from marshmallow import post_dump
+from marshmallow import post_dump, pre_load
 
 from letta.orm.message import Message
 from letta.schemas.message import Message as PydanticMessage
@@ -18,12 +18,24 @@ class SerializedMessageSchema(BaseSchema):
     tool_calls = ToolCallField()
 
     @post_dump
-    def sanitize_ids(self, data: Dict, **kwargs):
-        # We don't want to remap here
-        # Because of the way that message_ids is just a JSON field on agents
-        # We need to wait for the agent dumps, and then keep track of all the message IDs we remapped
+    def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
+        # keep id for remapping later on agent dump
+        # agent dump will then get rid of message ids
+        for key in ("_created_by_id", "_last_updated_by_id", "organization"):
+            data.pop(key, None)  # pop() so a dump lacking one of these keys does not raise KeyError
+
+        return data
+
+    @pre_load
+    def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
+        if self.Meta.model:
+            # Skip regenerating ID, as agent dump will do it
+            data["_created_by_id"] = self.actor.id
+            data["_last_updated_by_id"] = self.actor.id
+            data["organization"] = self.actor.organization_id
+
         return data
 
     class Meta(BaseSchema.Meta):
         model = Message
-        exclude = BaseSchema.Meta.exclude + ("step", "job_message", "agent")
+        exclude = BaseSchema.Meta.exclude + ("step", "job_message", "agent", "otid", "is_deleted")
diff --git a/letta/serialize_schemas/tag.py b/letta/serialize_schemas/tag.py
index e83f0e3e..38c5e97c 100644
--- a/letta/serialize_schemas/tag.py
+++ b/letta/serialize_schemas/tag.py
@@ -1,4 +1,6 @@
-from marshmallow import fields
+from typing import Dict
+
+from marshmallow import fields, post_dump, pre_load
 
 from letta.orm.agents_tags import AgentsTags
 from letta.serialize_schemas.base import BaseSchema
@@ -13,6 +15,14 @@ class SerializedAgentTagSchema(BaseSchema):
 
     tag = fields.String(required=True)
 
+    @post_dump
+    def sanitize_ids(self, data: Dict, **kwargs):
+        return data
+
+    @pre_load
+    def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
+        return data
+
     class Meta(BaseSchema.Meta):
         model = AgentsTags
         exclude = BaseSchema.Meta.exclude + ("agent",)
diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py
index b7f6cc7d..9ed1a4bc 100644
--- a/tests/test_agent_serialization.py
+++ b/tests/test_agent_serialization.py
@@ -394,22 +394,6 @@ def test_deserialize_override_existing_tools(
         assert existing_tool.source_code == weather_tool.source_code, f"Tool {tool_name} should NOT be overridden"
 
 
-def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user):
-    """Test deserializing JSON into an Agent instance."""
-    result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
-
-    # Check remapping on message_ids and messages is consistent
-    assert sorted([m["id"] for m in result["messages"]]) == sorted(result["message_ids"])
-
-    # Deserialize the agent
-    agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user)
-
-    # Make sure all the messages are able to be retrieved
-    in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user)
-    assert len(in_context_messages) == len(result["message_ids"])
-    assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"])
-
-
 def test_agent_serialize_with_user_messages(local_client, server, serialize_test_agent, default_user, other_user):
     """Test deserializing JSON into an Agent instance."""
     append_copy_suffix = False
@@ -473,4 +457,19 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
     assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0
 
 
+def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user):
+    """Test deserializing JSON into an Agent instance."""
+    result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
+
+    # Deserialize the agent
+    agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user)
+
+    # Make sure all the messages are able to be retrieved
+    in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user)
+    # NOTE(review): the agent dump pops `message_ids`, so the serialized payload presumably lacks it; these
+    # asserts then rely on the load hooks writing it back into `result` in place - confirm against deserialize().
+    assert len(in_context_messages) == len(result["message_ids"])
+    assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"])
+
+
 # FastAPI endpoint tests