From 5b2e7d3356bd1f4044c2c0d23cc26c2571d30603 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 4 Mar 2025 15:31:50 -0800 Subject: [PATCH] feat: Add blocks and tools to agent serialization (#1187) --- letta/log.py | 2 +- letta/schemas/letta_base.py | 6 +- letta/schemas/providers.py | 2 +- letta/serialize_schemas/agent.py | 57 ++++- letta/serialize_schemas/base.py | 41 ++++ letta/serialize_schemas/block.py | 15 ++ letta/serialize_schemas/message.py | 16 +- letta/serialize_schemas/tool.py | 15 ++ letta/services/agent_manager.py | 25 ++- letta/services/message_manager.py | 2 +- tests/test_agent_serialization.py | 345 ++++++++++++++++++++++++----- 11 files changed, 439 insertions(+), 87 deletions(-) create mode 100644 letta/serialize_schemas/block.py create mode 100644 letta/serialize_schemas/tool.py diff --git a/letta/log.py b/letta/log.py index 0d4ad8e1..8a2506ac 100644 --- a/letta/log.py +++ b/letta/log.py @@ -54,7 +54,7 @@ DEVELOPMENT_LOGGING = { "propagate": True, # Let logs bubble up to root }, "uvicorn": { - "level": "DEBUG", + "level": "INFO", "handlers": ["console"], "propagate": True, }, diff --git a/letta/schemas/letta_base.py b/letta/schemas/letta_base.py index d6850933..5d2a3da3 100644 --- a/letta/schemas/letta_base.py +++ b/letta/schemas/letta_base.py @@ -38,15 +38,15 @@ class LettaBase(BaseModel): description=cls._id_description(prefix), pattern=cls._id_regex_pattern(prefix), examples=[cls._id_example(prefix)], - default_factory=cls._generate_id, + default_factory=cls.generate_id, ) @classmethod - def _generate_id(cls, prefix: Optional[str] = None) -> str: + def generate_id(cls, prefix: Optional[str] = None) -> str: prefix = prefix or cls.__id_prefix__ return f"{prefix}-{uuid.uuid4()}" - # def _generate_id(self) -> str: + # def generate_id(self) -> str: # return f"{self.__id_prefix__}-{uuid.uuid4()}" @classmethod diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index b7917038..9084c729 100644 --- a/letta/schemas/providers.py +++ 
b/letta/schemas/providers.py @@ -27,7 +27,7 @@ class Provider(ProviderBase): def resolve_identifier(self): if not self.id: - self.id = ProviderBase._generate_id(prefix=ProviderBase.__id_prefix__) + self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) def list_llm_models(self) -> List[LLMConfig]: return [] diff --git a/letta/serialize_schemas/agent.py b/letta/serialize_schemas/agent.py index 036adf44..dfb3ff0e 100644 --- a/letta/serialize_schemas/agent.py +++ b/letta/serialize_schemas/agent.py @@ -1,9 +1,16 @@ -from marshmallow import fields +from typing import Dict + +from marshmallow import fields, post_dump from letta.orm import Agent +from letta.schemas.agent import AgentState as PydanticAgentState +from letta.schemas.user import User from letta.serialize_schemas.base import BaseSchema +from letta.serialize_schemas.block import SerializedBlockSchema from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField from letta.serialize_schemas.message import SerializedMessageSchema +from letta.serialize_schemas.tool import SerializedToolSchema +from letta.server.db import SessionLocal class SerializedAgentSchema(BaseSchema): @@ -12,25 +19,51 @@ class SerializedAgentSchema(BaseSchema): Excludes relational fields. 
""" + __pydantic_model__ = PydanticAgentState + llm_config = LLMConfigField() embedding_config = EmbeddingConfigField() tool_rules = ToolRulesField() messages = fields.List(fields.Nested(SerializedMessageSchema)) + core_memory = fields.List(fields.Nested(SerializedBlockSchema)) + tools = fields.List(fields.Nested(SerializedToolSchema)) - def __init__(self, *args, session=None, **kwargs): - super().__init__(*args, **kwargs) - if session: - self.session = session + def __init__(self, *args, session: SessionLocal, actor: User, **kwargs): + super().__init__(*args, actor=actor, **kwargs) + self.session = session - # propagate session to nested schemas - for field_name, field_obj in self.fields.items(): - if isinstance(field_obj, fields.List) and hasattr(field_obj.inner, "schema"): - field_obj.inner.schema.session = session - elif hasattr(field_obj, "schema"): - field_obj.schema.session = session + # Propagate session and actor to nested schemas automatically + for field in self.fields.values(): + if isinstance(field, fields.List) and isinstance(field.inner, fields.Nested): + field.inner.schema.session = session + field.inner.schema.actor = actor + elif isinstance(field, fields.Nested): + field.schema.session = session + field.schema.actor = actor + + @post_dump + def sanitize_ids(self, data: Dict, **kwargs): + data = super().sanitize_ids(data, **kwargs) + + # Remap IDs of messages + # Need to do this in post, so we can correctly map the in-context message IDs + # TODO: Remap message_ids to reference objects, not just be a list + id_remapping = dict() + for message in data.get("messages"): + message_id = message.get("id") + if message_id not in id_remapping: + id_remapping[message_id] = SerializedMessageSchema.__pydantic_model__.generate_id() + message["id"] = id_remapping[message_id] + else: + raise ValueError(f"Duplicate message IDs in agent.messages: {message_id}") + + # Remap in context message ids + data["message_ids"] = [id_remapping[message_id] for message_id in 
data.get("message_ids")] + + return data class Meta(BaseSchema.Meta): model = Agent # TODO: Serialize these as well... - exclude = ("tools", "sources", "core_memory", "tags", "source_passages", "agent_passages", "organization") + exclude = BaseSchema.Meta.exclude + ("sources", "tags", "source_passages", "agent_passages") diff --git a/letta/serialize_schemas/base.py b/letta/serialize_schemas/base.py index b64e76e2..e4051408 100644 --- a/letta/serialize_schemas/base.py +++ b/letta/serialize_schemas/base.py @@ -1,4 +1,10 @@ +from typing import Dict, Optional + +from marshmallow import post_dump, pre_load from marshmallow_sqlalchemy import SQLAlchemyAutoSchema +from sqlalchemy.inspection import inspect + +from letta.schemas.user import User class BaseSchema(SQLAlchemyAutoSchema): @@ -7,6 +13,41 @@ class BaseSchema(SQLAlchemyAutoSchema): This ensures all schemas share the same session. """ + __pydantic_model__ = None + sensitive_ids = {"_created_by_id", "_last_updated_by_id"} + sensitive_relationships = {"organization"} + id_scramble_placeholder = "xxx" + + def __init__(self, *args, actor: Optional[User] = None, **kwargs): + super().__init__(*args, **kwargs) + self.actor = actor + + @post_dump + def sanitize_ids(self, data: Dict, **kwargs): + data["id"] = self.__pydantic_model__.generate_id() + + for sensitive_id in BaseSchema.sensitive_ids.union(BaseSchema.sensitive_relationships): + if sensitive_id in data: + data[sensitive_id] = BaseSchema.id_scramble_placeholder + + return data + + @pre_load + def regenerate_ids(self, data: Dict, **kwargs): + if self.Meta.model: + mapper = inspect(self.Meta.model) + for sensitive_id in BaseSchema.sensitive_ids: + if sensitive_id in mapper.columns: + data[sensitive_id] = self.actor.id + + for relationship in BaseSchema.sensitive_relationships: + if relationship in mapper.relationships: + data[relationship] = self.actor.organization_id + + return data + class Meta: + model = None include_relationships = True load_instance = True + 
exclude = () diff --git a/letta/serialize_schemas/block.py b/letta/serialize_schemas/block.py new file mode 100644 index 00000000..41139121 --- /dev/null +++ b/letta/serialize_schemas/block.py @@ -0,0 +1,15 @@ +from letta.orm.block import Block +from letta.schemas.block import Block as PydanticBlock +from letta.serialize_schemas.base import BaseSchema + + +class SerializedBlockSchema(BaseSchema): + """ + Marshmallow schema for serializing/deserializing Block objects. + """ + + __pydantic_model__ = PydanticBlock + + class Meta(BaseSchema.Meta): + model = Block + exclude = BaseSchema.Meta.exclude + ("agents",) diff --git a/letta/serialize_schemas/message.py b/letta/serialize_schemas/message.py index 58d055d6..f1300d24 100644 --- a/letta/serialize_schemas/message.py +++ b/letta/serialize_schemas/message.py @@ -1,4 +1,9 @@ +from typing import Dict + +from marshmallow import post_dump + from letta.orm.message import Message +from letta.schemas.message import Message as PydanticMessage from letta.serialize_schemas.base import BaseSchema from letta.serialize_schemas.custom_fields import ToolCallField @@ -8,8 +13,17 @@ class SerializedMessageSchema(BaseSchema): Marshmallow schema for serializing/deserializing Message objects. 
""" + __pydantic_model__ = PydanticMessage + tool_calls = ToolCallField() + @post_dump + def sanitize_ids(self, data: Dict, **kwargs): + # We don't want to remap here + # Because of the way that message_ids is just a JSON field on agents + # We need to wait for the agent dumps, and then keep track of all the message IDs we remapped + return data + class Meta(BaseSchema.Meta): model = Message - exclude = ("step", "job_message") + exclude = BaseSchema.Meta.exclude + ("step", "job_message", "agent") diff --git a/letta/serialize_schemas/tool.py b/letta/serialize_schemas/tool.py new file mode 100644 index 00000000..fe2debe8 --- /dev/null +++ b/letta/serialize_schemas/tool.py @@ -0,0 +1,15 @@ +from letta.orm import Tool +from letta.schemas.tool import Tool as PydanticTool +from letta.serialize_schemas.base import BaseSchema + + +class SerializedToolSchema(BaseSchema): + """ + Marshmallow schema for serializing/deserializing Tool objects. + """ + + __pydantic_model__ = PydanticTool + + class Meta(BaseSchema.Meta): + model = Tool + exclude = BaseSchema.Meta.exclude diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 204a36be..06e0ef22 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -35,6 +35,7 @@ from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule from letta.schemas.tool_rule import ToolRule as PydanticToolRule from letta.schemas.user import User as PydanticUser from letta.serialize_schemas import SerializedAgentSchema +from letta.serialize_schemas.tool import SerializedToolSchema from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import ( _process_relationship, @@ -394,18 +395,28 @@ class AgentManager: with self.session_maker() as session: # Retrieve the agent agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - schema = SerializedAgentSchema(session=session) + schema = 
SerializedAgentSchema(session=session, actor=actor) return schema.dump(agent) @enforce_types - def deserialize(self, serialized_agent: dict, actor: PydanticUser) -> PydanticAgentState: - # TODO: Use actor to override fields + def deserialize(self, serialized_agent: dict, actor: PydanticUser, mark_as_copy: bool = True) -> PydanticAgentState: + tool_data_list = serialized_agent.pop("tools", []) + with self.session_maker() as session: - schema = SerializedAgentSchema(session=session) + schema = SerializedAgentSchema(session=session, actor=actor) agent = schema.load(serialized_agent, session=session) - agent.organization_id = actor.organization_id - agent = agent.create(session, actor=actor) - return agent.to_pydantic() + if mark_as_copy: + agent.name += "_copy" + agent.create(session, actor=actor) + pydantic_agent = agent.to_pydantic() + + # Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle + for tool_data in tool_data_list: + pydantic_tool = SerializedToolSchema(actor=actor).load(tool_data, transient=True).to_pydantic() + pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor) + pydantic_agent = self.attach_tool(agent_id=pydantic_agent.id, tool_id=pydantic_tool.id, actor=actor) + + return pydantic_agent # ====================================================================================================================== # Per Agent Environment Variable Management diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 26f0bee5..26f7c27b 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -46,7 +46,7 @@ class MessageManager: # Sort results directly based on message_ids result_dict = {msg.id: msg.to_pydantic() for msg in results} - return [result_dict[msg_id] for msg_id in message_ids] + return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids])) @enforce_types def create_message(self, 
pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage: diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index ade9f2f4..2651c08c 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -1,15 +1,27 @@ +import difflib import json +from datetime import datetime, timezone +from typing import Any, Dict, List, Mapping import pytest +from rich.console import Console +from rich.syntax import Syntax from letta import create_client from letta.config import LettaConfig from letta.orm import Base -from letta.schemas.agent import CreateAgent +from letta.schemas.agent import AgentState, CreateAgent +from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import MessageRole from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate +from letta.schemas.organization import Organization +from letta.schemas.user import User from letta.server.server import SyncServer +console = Console() + def _clear_tables(): from letta.server.db import db_context @@ -58,80 +70,291 @@ def default_user(server: SyncServer, default_organization): @pytest.fixture -def sarah_agent(server: SyncServer, default_user, default_organization): +def other_organization(server: SyncServer): + """Fixture to create and return the default organization.""" + org = server.organization_manager.create_organization(pydantic_org=Organization(name="letta")) + yield org + + +@pytest.fixture +def other_user(server: SyncServer, other_organization): + """Fixture to create and return the default user within the default organization.""" + user = server.user_manager.create_user(pydantic_user=User(organization_id=other_organization.id, name="sarah")) + yield user + + +@pytest.fixture +def serialize_test_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default 
organization.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + agent_name = f"serialize_test_agent_{timestamp}" + + server.tool_manager.upsert_base_tools(actor=default_user) + agent_state = server.agent_manager.create_agent( agent_create=CreateAgent( - name="sarah_agent", - memory_blocks=[], + name=agent_name, + memory_blocks=[ + CreateBlock( + value="Name: Caren", + label="human", + ), + ], llm_config=LLMConfig.default_config("gpt-4o-mini"), embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=True, ), actor=default_user, ) yield agent_state -def test_agent_serialization(server, sarah_agent, default_user): - """Test serializing an Agent instance to JSON.""" - result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user) - - # Assert that the result is a dictionary (JSON object) - assert isinstance(result, dict), "Expected a dictionary result" - - # Assert that the 'id' field is present and matches the agent's ID - assert "id" in result, "Agent 'id' is missing in the serialized result" - assert result["id"] == sarah_agent.id, f"Expected agent 'id' to be {sarah_agent.id}, but got {result['id']}" - - # Assert that the 'llm_config' and 'embedding_config' fields exist - assert "llm_config" in result, "'llm_config' is missing in the serialized result" - assert "embedding_config" in result, "'embedding_config' is missing in the serialized result" - - # Assert that 'messages' is a list - assert isinstance(result.get("messages", []), list), "'messages' should be a list" - - # Assert that the 'tool_exec_environment_variables' field is a list (empty or populated) - assert isinstance(result.get("tool_exec_environment_variables", []), list), "'tool_exec_environment_variables' should be a list" - - # Assert that the 'agent_type' is a valid string - assert isinstance(result.get("agent_type"), str), "'agent_type' should be a string" - - # Assert that the 'tool_rules' field is a list (even if empty) - assert 
isinstance(result.get("tool_rules", []), list), "'tool_rules' should be a list" - - # Check that all necessary fields are present in the 'messages' section, focusing on core elements - if "messages" in result: - for message in result["messages"]: - assert "id" in message, "Message 'id' is missing" - assert "text" in message, "Message 'text' is missing" - assert "role" in message, "Message 'role' is missing" - assert "created_at" in message, "Message 'created_at' is missing" - assert "updated_at" in message, "Message 'updated_at' is missing" - - # Optionally check that 'created_at' and 'updated_at' are in ISO 8601 format - assert isinstance(result["created_at"], str), "Expected 'created_at' to be a string" - assert isinstance(result["updated_at"], str), "Expected 'updated_at' to be a string" - - # Optionally check for presence of any required metadata or ensure it is null if expected - assert "metadata_" in result, "'metadata_' field is missing" - assert result["metadata_"] is None, "'metadata_' should be null" - - # Assert that the agent name is as expected (if defined) - assert result.get("name") == sarah_agent.name, "Expected agent 'name' to not be None, but found something else" - - print(json.dumps(result, indent=4)) +# Helper functions below -def test_agent_deserialization_basic(local_client, server, sarah_agent, default_user): +def dict_to_pretty_json(d: Dict[str, Any]) -> str: + """Convert a dictionary to a pretty JSON string with sorted keys, handling datetime objects.""" + return json.dumps(d, indent=2, sort_keys=True, default=_json_serializable) + + +def _json_serializable(obj: Any) -> Any: + """Convert non-serializable objects (like datetime) to a JSON-friendly format.""" + if isinstance(obj, datetime): + return obj.isoformat() # Convert datetime to ISO 8601 format + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + + +def print_dict_diff(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> None: + """Prints a detailed 
colorized diff between two dictionaries.""" + json1 = dict_to_pretty_json(dict1).splitlines() + json2 = dict_to_pretty_json(dict2).splitlines() + + diff = list(difflib.unified_diff(json1, json2, fromfile="Expected", tofile="Actual", lineterm="")) + + if diff: + console.print("\n🔍 [bold red]Dictionary Diff:[/bold red]") + diff_text = "\n".join(diff) + syntax = Syntax(diff_text, "diff", theme="monokai", line_numbers=False) + console.print(syntax) + else: + console.print("\n✅ [bold green]No differences found in dictionaries.[/bold green]") + + +def has_same_prefix(value1: Any, value2: Any) -> bool: + """Check if two string values have the same major prefix (before the second hyphen).""" + if not isinstance(value1, str) or not isinstance(value2, str): + return False + + prefix1 = value1.split("-")[0] + prefix2 = value2.split("-")[0] + + return prefix1 == prefix2 + + +def compare_lists(list1: List[Any], list2: List[Any]) -> bool: + """Compare lists while handling unordered dictionaries inside.""" + if len(list1) != len(list2): + return False + + if all(isinstance(item, Mapping) for item in list1) and all(isinstance(item, Mapping) for item in list2): + return all(any(_compare_agent_state_model_dump(i1, i2, log=False) for i2 in list2) for i1 in list1) + + return sorted(list1) == sorted(list2) + + +def strip_datetime_fields(d: Dict[str, Any]) -> Dict[str, Any]: + """Remove datetime fields from a dictionary before comparison.""" + return {k: v for k, v in d.items() if not isinstance(v, datetime)} + + +def _log_mismatch(key: str, expected: Any, actual: Any, log: bool) -> None: + """Log detailed information about a mismatch.""" + if log: + print(f"\n🔴 Mismatch Found in Key: '{key}'") + print(f"Expected: {expected}") + print(f"Actual: {actual}") + + if isinstance(expected, str) and isinstance(actual, str): + print("\n🔍 String Diff:") + diff = difflib.ndiff(expected.splitlines(), actual.splitlines()) + print("\n".join(diff)) + + +def 
_compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log: bool = True) -> bool: + """ + Compare two dictionaries with special handling: + - Keys in `ignore_prefix_fields` should match only by prefix. + - 'message_ids' lists should match in length only. + - Datetime fields are ignored. + - Order-independent comparison for lists of dicts. + """ + ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id"} + + # Remove datetime fields upfront + d1 = strip_datetime_fields(d1) + d2 = strip_datetime_fields(d2) + + if d1.keys() != d2.keys(): + _log_mismatch("dict_keys", set(d1.keys()), set(d2.keys())) + return False + + for key, v1 in d1.items(): + v2 = d2[key] + + if key in ignore_prefix_fields: + if v1 and v2 and not has_same_prefix(v1, v2): + _log_mismatch(key, v1, v2, log) + return False + elif key == "message_ids": + if not isinstance(v1, list) or not isinstance(v2, list) or len(v1) != len(v2): + _log_mismatch(key, v1, v2, log) + return False + elif isinstance(v1, Dict) and isinstance(v2, Dict): + if not _compare_agent_state_model_dump(v1, v2): + _log_mismatch(key, v1, v2, log) + return False + elif isinstance(v1, list) and isinstance(v2, list): + if not compare_lists(v1, v2): + _log_mismatch(key, v1, v2, log) + return False + elif v1 != v2: + _log_mismatch(key, v1, v2, log) + return False + + return True + + +def compare_agent_state(original: AgentState, copy: AgentState, mark_as_copy: bool) -> bool: + """Wrapper function that provides a default set of ignored prefix fields.""" + if not mark_as_copy: + assert original.name == copy.name + + return _compare_agent_state_model_dump(original.model_dump(exclude="name"), copy.model_dump(exclude="name")) + + +# Sanity tests for our agent model_dump verifier helpers + + +def test_sanity_identical_dicts(): + d1 = {"name": "Alice", "age": 30, "details": {"city": "New York"}} + d2 = {"name": "Alice", "age": 30, "details": {"city": "New York"}} + assert 
_compare_agent_state_model_dump(d1, d2) + + +def test_sanity_different_dicts(): + d1 = {"name": "Alice", "age": 30} + d2 = {"name": "Bob", "age": 30} + assert not _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_ignored_id_fields(): + d1 = {"id": "user-abc123", "name": "Alice"} + d2 = {"id": "user-xyz789", "name": "Alice"} # Different ID, same prefix + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_different_id_prefix_fails(): + d1 = {"id": "user-abc123"} + d2 = {"id": "admin-xyz789"} # Different prefix + assert not _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_nested_dicts(): + d1 = {"user": {"id": "user-123", "name": "Alice"}} + d2 = {"user": {"id": "user-456", "name": "Alice"}} # ID changes, but prefix matches + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_list_handling(): + d1 = {"items": [1, 2, 3]} + d2 = {"items": [1, 2, 3]} + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_list_mismatch(): + d1 = {"items": [1, 2, 3]} + d2 = {"items": [1, 2, 4]} + assert not _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_message_ids_length_check(): + d1 = {"message_ids": ["msg-123", "msg-456", "msg-789"]} + d2 = {"message_ids": ["msg-abc", "msg-def", "msg-ghi"]} # Same length, different values + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_message_ids_different_length(): + d1 = {"message_ids": ["msg-123", "msg-456"]} + d2 = {"message_ids": ["msg-123"]} + assert not _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_datetime_fields(): + d1 = {"created_at": datetime(2025, 3, 4, 18, 25, 37, tzinfo=timezone.utc)} + d2 = {"created_at": datetime(2025, 3, 4, 18, 25, 37, tzinfo=timezone.utc)} + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_datetime_mismatch(): + d1 = {"created_at": datetime(2025, 3, 4, 18, 25, 37, tzinfo=timezone.utc)} + d2 = {"created_at": datetime(2025, 3, 4, 18, 25, 38, tzinfo=timezone.utc)} # One second 
difference + assert _compare_agent_state_model_dump(d1, d2) # Should ignore + + +# Agent serialize/deserialize tests + + +@pytest.mark.parametrize("mark_as_copy", [True, False]) +def test_mark_as_copy_simple(local_client, server, serialize_test_agent, default_user, other_user, mark_as_copy): """Test deserializing JSON into an Agent instance.""" - # Send a message first - sarah_agent = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user) - result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user) + result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) - # Delete the agent - server.agent_manager.delete_agent(sarah_agent.id, actor=default_user) + # Deserialize the agent + agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user, mark_as_copy=mark_as_copy) - agent_state = server.agent_manager.deserialize(serialized_agent=result, actor=default_user) + # Compare serialized representations to check for exact match + print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json())) + assert compare_agent_state(agent_copy, serialize_test_agent, mark_as_copy=mark_as_copy) - assert agent_state.name == sarah_agent.name - assert len(agent_state.message_ids) == len(sarah_agent.message_ids) + +def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user): + """Test deserializing JSON into an Agent instance.""" + result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) + + # Check remapping on message_ids and messages is consistent + assert sorted([m["id"] for m in result["messages"]]) == sorted(result["message_ids"]) + + # Deserialize the agent + agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user) + + # Make sure all the messages are able to be retrieved + in_context_messages = 
server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user) + assert len(in_context_messages) == len(result["message_ids"]) + assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"]) + + +def test_agent_serialize_with_user_messages(local_client, server, serialize_test_agent, default_user, other_user): + """Test deserializing JSON into an Agent instance.""" + mark_as_copy = False + server.send_messages( + actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="hello")] + ) + result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) + + # Deserialize the agent + agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user, mark_as_copy=mark_as_copy) + + # Get most recent original agent instance + serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user) + + # Compare serialized representations to check for exact match + print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json())) + assert compare_agent_state(agent_copy, serialize_test_agent, mark_as_copy=mark_as_copy) + + # Make sure both agents can receive messages after + server.send_messages( + actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")] + ) + server.send_messages( + actor=other_user, agent_id=agent_copy.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")] + )