feat: Add blocks and tools to agent serialization (#1187)
This commit is contained in:
@@ -54,7 +54,7 @@ DEVELOPMENT_LOGGING = {
|
||||
"propagate": True, # Let logs bubble up to root
|
||||
},
|
||||
"uvicorn": {
|
||||
"level": "DEBUG",
|
||||
"level": "INFO",
|
||||
"handlers": ["console"],
|
||||
"propagate": True,
|
||||
},
|
||||
|
||||
@@ -38,15 +38,15 @@ class LettaBase(BaseModel):
|
||||
description=cls._id_description(prefix),
|
||||
pattern=cls._id_regex_pattern(prefix),
|
||||
examples=[cls._id_example(prefix)],
|
||||
default_factory=cls._generate_id,
|
||||
default_factory=cls.generate_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _generate_id(cls, prefix: Optional[str] = None) -> str:
|
||||
def generate_id(cls, prefix: Optional[str] = None) -> str:
|
||||
prefix = prefix or cls.__id_prefix__
|
||||
return f"{prefix}-{uuid.uuid4()}"
|
||||
|
||||
# def _generate_id(self) -> str:
|
||||
# def generate_id(self) -> str:
|
||||
# return f"{self.__id_prefix__}-{uuid.uuid4()}"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -27,7 +27,7 @@ class Provider(ProviderBase):
|
||||
|
||||
def resolve_identifier(self):
|
||||
if not self.id:
|
||||
self.id = ProviderBase._generate_id(prefix=ProviderBase.__id_prefix__)
|
||||
self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
return []
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
from marshmallow import fields
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import fields, post_dump
|
||||
|
||||
from letta.orm import Agent
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.block import SerializedBlockSchema
|
||||
from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
|
||||
from letta.serialize_schemas.message import SerializedMessageSchema
|
||||
from letta.serialize_schemas.tool import SerializedToolSchema
|
||||
from letta.server.db import SessionLocal
|
||||
|
||||
|
||||
class SerializedAgentSchema(BaseSchema):
|
||||
@@ -12,25 +19,51 @@ class SerializedAgentSchema(BaseSchema):
|
||||
Excludes relational fields.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = PydanticAgentState
|
||||
|
||||
llm_config = LLMConfigField()
|
||||
embedding_config = EmbeddingConfigField()
|
||||
tool_rules = ToolRulesField()
|
||||
|
||||
messages = fields.List(fields.Nested(SerializedMessageSchema))
|
||||
core_memory = fields.List(fields.Nested(SerializedBlockSchema))
|
||||
tools = fields.List(fields.Nested(SerializedToolSchema))
|
||||
|
||||
def __init__(self, *args, session=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if session:
|
||||
self.session = session
|
||||
def __init__(self, *args, session: SessionLocal, actor: User, **kwargs):
|
||||
super().__init__(*args, actor=actor, **kwargs)
|
||||
self.session = session
|
||||
|
||||
# propagate session to nested schemas
|
||||
for field_name, field_obj in self.fields.items():
|
||||
if isinstance(field_obj, fields.List) and hasattr(field_obj.inner, "schema"):
|
||||
field_obj.inner.schema.session = session
|
||||
elif hasattr(field_obj, "schema"):
|
||||
field_obj.schema.session = session
|
||||
# Propagate session and actor to nested schemas automatically
|
||||
for field in self.fields.values():
|
||||
if isinstance(field, fields.List) and isinstance(field.inner, fields.Nested):
|
||||
field.inner.schema.session = session
|
||||
field.inner.schema.actor = actor
|
||||
elif isinstance(field, fields.Nested):
|
||||
field.schema.session = session
|
||||
field.schema.actor = actor
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
data = super().sanitize_ids(data, **kwargs)
|
||||
|
||||
# Remap IDs of messages
|
||||
# Need to do this in post, so we can correctly map the in-context message IDs
|
||||
# TODO: Remap message_ids to reference objects, not just be a list
|
||||
id_remapping = dict()
|
||||
for message in data.get("messages"):
|
||||
message_id = message.get("id")
|
||||
if message_id not in id_remapping:
|
||||
id_remapping[message_id] = SerializedMessageSchema.__pydantic_model__.generate_id()
|
||||
message["id"] = id_remapping[message_id]
|
||||
else:
|
||||
raise ValueError(f"Duplicate message IDs in agent.messages: {message_id}")
|
||||
|
||||
# Remap in context message ids
|
||||
data["message_ids"] = [id_remapping[message_id] for message_id in data.get("message_ids")]
|
||||
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
# TODO: Serialize these as well...
|
||||
exclude = ("tools", "sources", "core_memory", "tags", "source_passages", "agent_passages", "organization")
|
||||
exclude = BaseSchema.Meta.exclude + ("sources", "tags", "source_passages", "agent_passages")
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from marshmallow import post_dump, pre_load
|
||||
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
|
||||
from sqlalchemy.inspection import inspect
|
||||
|
||||
from letta.schemas.user import User
|
||||
|
||||
|
||||
class BaseSchema(SQLAlchemyAutoSchema):
|
||||
@@ -7,6 +13,41 @@ class BaseSchema(SQLAlchemyAutoSchema):
|
||||
This ensures all schemas share the same session.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = None
|
||||
sensitive_ids = {"_created_by_id", "_last_updated_by_id"}
|
||||
sensitive_relationships = {"organization"}
|
||||
id_scramble_placeholder = "xxx"
|
||||
|
||||
def __init__(self, *args, actor: Optional[User] = None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.actor = actor
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
data["id"] = self.__pydantic_model__.generate_id()
|
||||
|
||||
for sensitive_id in BaseSchema.sensitive_ids.union(BaseSchema.sensitive_relationships):
|
||||
if sensitive_id in data:
|
||||
data[sensitive_id] = BaseSchema.id_scramble_placeholder
|
||||
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def regenerate_ids(self, data: Dict, **kwargs):
|
||||
if self.Meta.model:
|
||||
mapper = inspect(self.Meta.model)
|
||||
for sensitive_id in BaseSchema.sensitive_ids:
|
||||
if sensitive_id in mapper.columns:
|
||||
data[sensitive_id] = self.actor.id
|
||||
|
||||
for relationship in BaseSchema.sensitive_relationships:
|
||||
if relationship in mapper.relationships:
|
||||
data[relationship] = self.actor.organization_id
|
||||
|
||||
return data
|
||||
|
||||
class Meta:
|
||||
model = None
|
||||
include_relationships = True
|
||||
load_instance = True
|
||||
exclude = ()
|
||||
|
||||
15
letta/serialize_schemas/block.py
Normal file
15
letta/serialize_schemas/block.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from letta.orm.block import Block
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
|
||||
|
||||
class SerializedBlockSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Block objects.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = PydanticBlock
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Block
|
||||
exclude = BaseSchema.Meta.exclude + ("agents",)
|
||||
@@ -1,4 +1,9 @@
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import post_dump
|
||||
|
||||
from letta.orm.message import Message
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.custom_fields import ToolCallField
|
||||
|
||||
@@ -8,8 +13,17 @@ class SerializedMessageSchema(BaseSchema):
|
||||
Marshmallow schema for serializing/deserializing Message objects.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = PydanticMessage
|
||||
|
||||
tool_calls = ToolCallField()
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
# We don't want to remap here
|
||||
# Because of the way that message_ids is just a JSON field on agents
|
||||
# We need to wait for the agent dumps, and then keep track of all the message IDs we remapped
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Message
|
||||
exclude = ("step", "job_message")
|
||||
exclude = BaseSchema.Meta.exclude + ("step", "job_message", "agent")
|
||||
|
||||
15
letta/serialize_schemas/tool.py
Normal file
15
letta/serialize_schemas/tool.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from letta.orm import Tool
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
|
||||
|
||||
class SerializedToolSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Tool objects.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = PydanticTool
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Tool
|
||||
exclude = BaseSchema.Meta.exclude
|
||||
@@ -35,6 +35,7 @@ from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.serialize_schemas import SerializedAgentSchema
|
||||
from letta.serialize_schemas.tool import SerializedToolSchema
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_process_relationship,
|
||||
@@ -394,18 +395,28 @@ class AgentManager:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
schema = SerializedAgentSchema(session=session)
|
||||
schema = SerializedAgentSchema(session=session, actor=actor)
|
||||
return schema.dump(agent)
|
||||
|
||||
@enforce_types
|
||||
def deserialize(self, serialized_agent: dict, actor: PydanticUser) -> PydanticAgentState:
|
||||
# TODO: Use actor to override fields
|
||||
def deserialize(self, serialized_agent: dict, actor: PydanticUser, mark_as_copy: bool = True) -> PydanticAgentState:
|
||||
tool_data_list = serialized_agent.pop("tools", [])
|
||||
|
||||
with self.session_maker() as session:
|
||||
schema = SerializedAgentSchema(session=session)
|
||||
schema = SerializedAgentSchema(session=session, actor=actor)
|
||||
agent = schema.load(serialized_agent, session=session)
|
||||
agent.organization_id = actor.organization_id
|
||||
agent = agent.create(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
if mark_as_copy:
|
||||
agent.name += "_copy"
|
||||
agent.create(session, actor=actor)
|
||||
pydantic_agent = agent.to_pydantic()
|
||||
|
||||
# Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle
|
||||
for tool_data in tool_data_list:
|
||||
pydantic_tool = SerializedToolSchema(actor=actor).load(tool_data, transient=True).to_pydantic()
|
||||
pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor)
|
||||
pydantic_agent = self.attach_tool(agent_id=pydantic_agent.id, tool_id=pydantic_tool.id, actor=actor)
|
||||
|
||||
return pydantic_agent
|
||||
|
||||
# ======================================================================================================================
|
||||
# Per Agent Environment Variable Management
|
||||
|
||||
@@ -46,7 +46,7 @@ class MessageManager:
|
||||
|
||||
# Sort results directly based on message_ids
|
||||
result_dict = {msg.id: msg.to_pydantic() for msg in results}
|
||||
return [result_dict[msg_id] for msg_id in message_ids]
|
||||
return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids]))
|
||||
|
||||
@enforce_types
|
||||
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
|
||||
|
||||
@@ -1,15 +1,27 @@
|
||||
import difflib
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Mapping
|
||||
|
||||
import pytest
|
||||
from rich.console import Console
|
||||
from rich.syntax import Syntax
|
||||
|
||||
from letta import create_client
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm import Base
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.user import User
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def _clear_tables():
|
||||
from letta.server.db import db_context
|
||||
@@ -58,80 +70,291 @@ def default_user(server: SyncServer, default_organization):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
def other_organization(server: SyncServer):
|
||||
"""Fixture to create and return the default organization."""
|
||||
org = server.organization_manager.create_organization(pydantic_org=Organization(name="letta"))
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_user(server: SyncServer, other_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_user(pydantic_user=User(organization_id=other_organization.id, name="sarah"))
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serialize_test_agent(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create and return a sample agent within the default organization."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
agent_name = f"serialize_test_agent_{timestamp}"
|
||||
|
||||
server.tool_manager.upsert_base_tools(actor=default_user)
|
||||
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="sarah_agent",
|
||||
memory_blocks=[],
|
||||
name=agent_name,
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
value="Name: Caren",
|
||||
label="human",
|
||||
),
|
||||
],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=True,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
yield agent_state
|
||||
|
||||
|
||||
def test_agent_serialization(server, sarah_agent, default_user):
|
||||
"""Test serializing an Agent instance to JSON."""
|
||||
result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user)
|
||||
|
||||
# Assert that the result is a dictionary (JSON object)
|
||||
assert isinstance(result, dict), "Expected a dictionary result"
|
||||
|
||||
# Assert that the 'id' field is present and matches the agent's ID
|
||||
assert "id" in result, "Agent 'id' is missing in the serialized result"
|
||||
assert result["id"] == sarah_agent.id, f"Expected agent 'id' to be {sarah_agent.id}, but got {result['id']}"
|
||||
|
||||
# Assert that the 'llm_config' and 'embedding_config' fields exist
|
||||
assert "llm_config" in result, "'llm_config' is missing in the serialized result"
|
||||
assert "embedding_config" in result, "'embedding_config' is missing in the serialized result"
|
||||
|
||||
# Assert that 'messages' is a list
|
||||
assert isinstance(result.get("messages", []), list), "'messages' should be a list"
|
||||
|
||||
# Assert that the 'tool_exec_environment_variables' field is a list (empty or populated)
|
||||
assert isinstance(result.get("tool_exec_environment_variables", []), list), "'tool_exec_environment_variables' should be a list"
|
||||
|
||||
# Assert that the 'agent_type' is a valid string
|
||||
assert isinstance(result.get("agent_type"), str), "'agent_type' should be a string"
|
||||
|
||||
# Assert that the 'tool_rules' field is a list (even if empty)
|
||||
assert isinstance(result.get("tool_rules", []), list), "'tool_rules' should be a list"
|
||||
|
||||
# Check that all necessary fields are present in the 'messages' section, focusing on core elements
|
||||
if "messages" in result:
|
||||
for message in result["messages"]:
|
||||
assert "id" in message, "Message 'id' is missing"
|
||||
assert "text" in message, "Message 'text' is missing"
|
||||
assert "role" in message, "Message 'role' is missing"
|
||||
assert "created_at" in message, "Message 'created_at' is missing"
|
||||
assert "updated_at" in message, "Message 'updated_at' is missing"
|
||||
|
||||
# Optionally check that 'created_at' and 'updated_at' are in ISO 8601 format
|
||||
assert isinstance(result["created_at"], str), "Expected 'created_at' to be a string"
|
||||
assert isinstance(result["updated_at"], str), "Expected 'updated_at' to be a string"
|
||||
|
||||
# Optionally check for presence of any required metadata or ensure it is null if expected
|
||||
assert "metadata_" in result, "'metadata_' field is missing"
|
||||
assert result["metadata_"] is None, "'metadata_' should be null"
|
||||
|
||||
# Assert that the agent name is as expected (if defined)
|
||||
assert result.get("name") == sarah_agent.name, "Expected agent 'name' to not be None, but found something else"
|
||||
|
||||
print(json.dumps(result, indent=4))
|
||||
# Helper functions below
|
||||
|
||||
|
||||
def test_agent_deserialization_basic(local_client, server, sarah_agent, default_user):
|
||||
def dict_to_pretty_json(d: Dict[str, Any]) -> str:
|
||||
"""Convert a dictionary to a pretty JSON string with sorted keys, handling datetime objects."""
|
||||
return json.dumps(d, indent=2, sort_keys=True, default=_json_serializable)
|
||||
|
||||
|
||||
def _json_serializable(obj: Any) -> Any:
|
||||
"""Convert non-serializable objects (like datetime) to a JSON-friendly format."""
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat() # Convert datetime to ISO 8601 format
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
|
||||
def print_dict_diff(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> None:
|
||||
"""Prints a detailed colorized diff between two dictionaries."""
|
||||
json1 = dict_to_pretty_json(dict1).splitlines()
|
||||
json2 = dict_to_pretty_json(dict2).splitlines()
|
||||
|
||||
diff = list(difflib.unified_diff(json1, json2, fromfile="Expected", tofile="Actual", lineterm=""))
|
||||
|
||||
if diff:
|
||||
console.print("\n🔍 [bold red]Dictionary Diff:[/bold red]")
|
||||
diff_text = "\n".join(diff)
|
||||
syntax = Syntax(diff_text, "diff", theme="monokai", line_numbers=False)
|
||||
console.print(syntax)
|
||||
else:
|
||||
console.print("\n✅ [bold green]No differences found in dictionaries.[/bold green]")
|
||||
|
||||
|
||||
def has_same_prefix(value1: Any, value2: Any) -> bool:
|
||||
"""Check if two string values have the same major prefix (before the second hyphen)."""
|
||||
if not isinstance(value1, str) or not isinstance(value2, str):
|
||||
return False
|
||||
|
||||
prefix1 = value1.split("-")[0]
|
||||
prefix2 = value2.split("-")[0]
|
||||
|
||||
return prefix1 == prefix2
|
||||
|
||||
|
||||
def compare_lists(list1: List[Any], list2: List[Any]) -> bool:
|
||||
"""Compare lists while handling unordered dictionaries inside."""
|
||||
if len(list1) != len(list2):
|
||||
return False
|
||||
|
||||
if all(isinstance(item, Mapping) for item in list1) and all(isinstance(item, Mapping) for item in list2):
|
||||
return all(any(_compare_agent_state_model_dump(i1, i2, log=False) for i2 in list2) for i1 in list1)
|
||||
|
||||
return sorted(list1) == sorted(list2)
|
||||
|
||||
|
||||
def strip_datetime_fields(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Remove datetime fields from a dictionary before comparison."""
|
||||
return {k: v for k, v in d.items() if not isinstance(v, datetime)}
|
||||
|
||||
|
||||
def _log_mismatch(key: str, expected: Any, actual: Any, log: bool) -> None:
|
||||
"""Log detailed information about a mismatch."""
|
||||
if log:
|
||||
print(f"\n🔴 Mismatch Found in Key: '{key}'")
|
||||
print(f"Expected: {expected}")
|
||||
print(f"Actual: {actual}")
|
||||
|
||||
if isinstance(expected, str) and isinstance(actual, str):
|
||||
print("\n🔍 String Diff:")
|
||||
diff = difflib.ndiff(expected.splitlines(), actual.splitlines())
|
||||
print("\n".join(diff))
|
||||
|
||||
|
||||
def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log: bool = True) -> bool:
|
||||
"""
|
||||
Compare two dictionaries with special handling:
|
||||
- Keys in `ignore_prefix_fields` should match only by prefix.
|
||||
- 'message_ids' lists should match in length only.
|
||||
- Datetime fields are ignored.
|
||||
- Order-independent comparison for lists of dicts.
|
||||
"""
|
||||
ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id"}
|
||||
|
||||
# Remove datetime fields upfront
|
||||
d1 = strip_datetime_fields(d1)
|
||||
d2 = strip_datetime_fields(d2)
|
||||
|
||||
if d1.keys() != d2.keys():
|
||||
_log_mismatch("dict_keys", set(d1.keys()), set(d2.keys()))
|
||||
return False
|
||||
|
||||
for key, v1 in d1.items():
|
||||
v2 = d2[key]
|
||||
|
||||
if key in ignore_prefix_fields:
|
||||
if v1 and v2 and not has_same_prefix(v1, v2):
|
||||
_log_mismatch(key, v1, v2, log)
|
||||
return False
|
||||
elif key == "message_ids":
|
||||
if not isinstance(v1, list) or not isinstance(v2, list) or len(v1) != len(v2):
|
||||
_log_mismatch(key, v1, v2, log)
|
||||
return False
|
||||
elif isinstance(v1, Dict) and isinstance(v2, Dict):
|
||||
if not _compare_agent_state_model_dump(v1, v2):
|
||||
_log_mismatch(key, v1, v2, log)
|
||||
return False
|
||||
elif isinstance(v1, list) and isinstance(v2, list):
|
||||
if not compare_lists(v1, v2):
|
||||
_log_mismatch(key, v1, v2, log)
|
||||
return False
|
||||
elif v1 != v2:
|
||||
_log_mismatch(key, v1, v2, log)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def compare_agent_state(original: AgentState, copy: AgentState, mark_as_copy: bool) -> bool:
|
||||
"""Wrapper function that provides a default set of ignored prefix fields."""
|
||||
if not mark_as_copy:
|
||||
assert original.name == copy.name
|
||||
|
||||
return _compare_agent_state_model_dump(original.model_dump(exclude="name"), copy.model_dump(exclude="name"))
|
||||
|
||||
|
||||
# Sanity tests for our agent model_dump verifier helpers
|
||||
|
||||
|
||||
def test_sanity_identical_dicts():
|
||||
d1 = {"name": "Alice", "age": 30, "details": {"city": "New York"}}
|
||||
d2 = {"name": "Alice", "age": 30, "details": {"city": "New York"}}
|
||||
assert _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_different_dicts():
|
||||
d1 = {"name": "Alice", "age": 30}
|
||||
d2 = {"name": "Bob", "age": 30}
|
||||
assert not _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_ignored_id_fields():
|
||||
d1 = {"id": "user-abc123", "name": "Alice"}
|
||||
d2 = {"id": "user-xyz789", "name": "Alice"} # Different ID, same prefix
|
||||
assert _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_different_id_prefix_fails():
|
||||
d1 = {"id": "user-abc123"}
|
||||
d2 = {"id": "admin-xyz789"} # Different prefix
|
||||
assert not _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_nested_dicts():
|
||||
d1 = {"user": {"id": "user-123", "name": "Alice"}}
|
||||
d2 = {"user": {"id": "user-456", "name": "Alice"}} # ID changes, but prefix matches
|
||||
assert _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_list_handling():
|
||||
d1 = {"items": [1, 2, 3]}
|
||||
d2 = {"items": [1, 2, 3]}
|
||||
assert _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_list_mismatch():
|
||||
d1 = {"items": [1, 2, 3]}
|
||||
d2 = {"items": [1, 2, 4]}
|
||||
assert not _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_message_ids_length_check():
|
||||
d1 = {"message_ids": ["msg-123", "msg-456", "msg-789"]}
|
||||
d2 = {"message_ids": ["msg-abc", "msg-def", "msg-ghi"]} # Same length, different values
|
||||
assert _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_message_ids_different_length():
|
||||
d1 = {"message_ids": ["msg-123", "msg-456"]}
|
||||
d2 = {"message_ids": ["msg-123"]}
|
||||
assert not _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_datetime_fields():
|
||||
d1 = {"created_at": datetime(2025, 3, 4, 18, 25, 37, tzinfo=timezone.utc)}
|
||||
d2 = {"created_at": datetime(2025, 3, 4, 18, 25, 37, tzinfo=timezone.utc)}
|
||||
assert _compare_agent_state_model_dump(d1, d2)
|
||||
|
||||
|
||||
def test_sanity_datetime_mismatch():
|
||||
d1 = {"created_at": datetime(2025, 3, 4, 18, 25, 37, tzinfo=timezone.utc)}
|
||||
d2 = {"created_at": datetime(2025, 3, 4, 18, 25, 38, tzinfo=timezone.utc)} # One second difference
|
||||
assert _compare_agent_state_model_dump(d1, d2) # Should ignore
|
||||
|
||||
|
||||
# Agent serialize/deserialize tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mark_as_copy", [True, False])
|
||||
def test_mark_as_copy_simple(local_client, server, serialize_test_agent, default_user, other_user, mark_as_copy):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
# Send a message first
|
||||
sarah_agent = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user)
|
||||
result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user)
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Delete the agent
|
||||
server.agent_manager.delete_agent(sarah_agent.id, actor=default_user)
|
||||
# Deserialize the agent
|
||||
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user, mark_as_copy=mark_as_copy)
|
||||
|
||||
agent_state = server.agent_manager.deserialize(serialized_agent=result, actor=default_user)
|
||||
# Compare serialized representations to check for exact match
|
||||
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
|
||||
assert compare_agent_state(agent_copy, serialize_test_agent, mark_as_copy=mark_as_copy)
|
||||
|
||||
assert agent_state.name == sarah_agent.name
|
||||
assert len(agent_state.message_ids) == len(sarah_agent.message_ids)
|
||||
|
||||
def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Check remapping on message_ids and messages is consistent
|
||||
assert sorted([m["id"] for m in result["messages"]]) == sorted(result["message_ids"])
|
||||
|
||||
# Deserialize the agent
|
||||
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user)
|
||||
|
||||
# Make sure all the messages are able to be retrieved
|
||||
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user)
|
||||
assert len(in_context_messages) == len(result["message_ids"])
|
||||
assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"])
|
||||
|
||||
|
||||
def test_agent_serialize_with_user_messages(local_client, server, serialize_test_agent, default_user, other_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
mark_as_copy = False
|
||||
server.send_messages(
|
||||
actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="hello")]
|
||||
)
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Deserialize the agent
|
||||
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user, mark_as_copy=mark_as_copy)
|
||||
|
||||
# Get most recent original agent instance
|
||||
serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Compare serialized representations to check for exact match
|
||||
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
|
||||
assert compare_agent_state(agent_copy, serialize_test_agent, mark_as_copy=mark_as_copy)
|
||||
|
||||
# Make sure both agents can receive messages after
|
||||
server.send_messages(
|
||||
actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")]
|
||||
)
|
||||
server.send_messages(
|
||||
actor=other_user, agent_id=agent_copy.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user