# Files
# letta-server/tests/test_agent_serialization_v2.py
# 2025-07-16 17:11:09 -07:00
#
# 967 lines
# 37 KiB
# Python

import asyncio
from typing import List
import pytest
from letta.config import LettaConfig
from letta.errors import AgentFileExportError, AgentFileImportError
from letta.orm import Base
from letta.schemas.agent import CreateAgent
from letta.schemas.agent_file import (
AgentFileSchema,
AgentSchema,
BlockSchema,
FileSchema,
GroupSchema,
MessageSchema,
SourceSchema,
ToolSchema,
)
from letta.schemas.block import Block, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.organization import Organization
from letta.schemas.user import User
from letta.server.server import SyncServer
from letta.services.agent_file_manager import AgentFileManager
from tests.utils import create_tool_from_func
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def event_loop():
    """Provide one asyncio event loop shared by the tests in this module.

    The loop is created on first use and closed after the fixture is torn down.
    """
    shared_loop = asyncio.new_event_loop()
    yield shared_loop
    shared_loop.close()
def _clear_tables():
    """Delete all rows from every table registered on ``Base.metadata``.

    Tables are processed in the reverse of ``sorted_tables`` order so that
    foreign-key constraints are not violated while deleting, and the deletes
    are committed together at the end.
    """
    from letta.server.db import db_context

    with db_context() as db_session:
        deletion_order = reversed(Base.metadata.sorted_tables)
        for orm_table in deletion_order:
            # Row-level DELETE on the whole table (no WHERE clause).
            db_session.execute(orm_table.delete())
        db_session.commit()
@pytest.fixture(autouse=True)
def clear_tables():
    """Autouse fixture that empties the database tables via ``_clear_tables``."""
    _clear_tables()
@pytest.fixture
def server():
    """Yield a ``SyncServer`` with a default org/user and the base tools upserted.

    The Letta config is loaded and saved back before the server is constructed.
    """
    LettaConfig.load().save()
    letta_server = SyncServer(init_with_default_org_and_user=True)
    letta_server.tool_manager.upsert_base_tools(actor=letta_server.default_user)
    yield letta_server
@pytest.fixture
def default_organization(server: SyncServer):
    """Yield the organization returned by ``create_default_organization``."""
    yield server.organization_manager.create_default_organization()
@pytest.fixture
def default_user(server: SyncServer, default_organization):
    """Yield the default user created inside ``default_organization``."""
    owning_org_id = default_organization.id
    yield server.user_manager.create_default_user(org_id=owning_org_id)
@pytest.fixture
def other_organization(server: SyncServer):
    """Yield a second organization, named ``test_org``, distinct from the default one."""
    new_org = Organization(name="test_org")
    yield server.organization_manager.create_organization(pydantic_org=new_org)
@pytest.fixture
def other_user(server: SyncServer, other_organization):
    """Yield a user named ``test_user`` that belongs to ``other_organization``."""
    new_user = User(organization_id=other_organization.id, name="test_user")
    yield server.user_manager.create_user(pydantic_user=new_user)
@pytest.fixture
def weather_tool_func():
    """Return the plain function that the ``weather_tool`` fixture turns into a tool."""

    # NOTE(review): the inner function's name, signature and docstring are
    # presumably what create_tool_from_func uses to build the tool definition;
    # confirm before changing any of them.
    def get_weather(location: str) -> str:
        """Get the current weather for a given location.
        Args:
            location: The city and state, e.g. San Francisco, CA
        Returns:
            Weather description
        """
        return f"The weather in {location} is sunny and 72 degrees."

    return get_weather
@pytest.fixture
def print_tool_func():
    """Return the plain function that the ``print_tool`` fixture turns into a tool.

    Returns:
        The inner ``print_message`` callable (not this fixture function).
    """

    # NOTE(review): the inner function's name, signature and docstring are
    # presumably what create_tool_from_func uses to build the tool definition;
    # confirm before changing any of them.
    def print_message(message: str) -> str:
        """Print a message to the console.
        Args:
            message: The message to print
        Returns:
            Confirmation message
        """
        print(message)
        return f"Printed: {message}"

    # Hand back the tool function itself. Returning `print_tool_func` here
    # would hand the pytest fixture object to create_tool_from_func instead.
    return print_message
@pytest.fixture
def weather_tool(server, weather_tool_func, default_user):
    """Yield the weather tool after creating or updating it as ``default_user``."""
    tool_definition = create_tool_from_func(func=weather_tool_func)
    yield server.tool_manager.create_or_update_tool(tool_definition, actor=default_user)
@pytest.fixture
def print_tool(server, print_tool_func, default_user):
    """Yield the print tool after creating or updating it as ``default_user``."""
    tool_definition = create_tool_from_func(func=print_tool_func)
    yield server.tool_manager.create_or_update_tool(tool_definition, actor=default_user)
@pytest.fixture
def test_block(server: SyncServer, default_user):
    """Yield a block labelled ``test_block`` created or updated as ``default_user``."""
    unsaved_block = Block(
        label="test_block",
        value="Test Block Content",
        description="A test block for agent file tests",
        limit=1000,
        metadata={"type": "test", "category": "demo"},
    )
    yield server.block_manager.create_or_update_block(unsaved_block, actor=default_user)
@pytest.fixture
def agent_file_manager(server, default_user):
    """Yield an ``AgentFileManager`` wired to the managers exposed by ``server``."""
    # Each keyword is fed from the server attribute of the same name.
    services = {
        "agent_manager": server.agent_manager,
        "tool_manager": server.tool_manager,
        "source_manager": server.source_manager,
        "block_manager": server.block_manager,
        "group_manager": server.group_manager,
        "mcp_manager": server.mcp_manager,
        "file_manager": server.file_manager,
        "file_agent_manager": server.file_agent_manager,
        "message_manager": server.message_manager,
    }
    yield AgentFileManager(**services)
@pytest.fixture
def test_agent(server: SyncServer, default_user, default_organization, test_block, weather_tool):
    """Yield the state of an agent that already has tools, blocks and messages.

    The agent is created with two memory blocks, the shared ``test_block``,
    the weather tool and a three-message initial sequence. One extra user
    message is then sent, and the agent state is re-fetched so the yielded
    object reflects that message.
    """
    request = CreateAgent(
        name="test_agent_v2",
        system="You are a helpful assistant for testing agent file export/import.",
        memory_blocks=[
            CreateBlock(label="human", value="User is a test user"),
            CreateBlock(label="persona", value="I am a helpful test assistant"),
        ],
        llm_config=LLMConfig.default_config("gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        block_ids=[test_block.id],
        tool_ids=[weather_tool.id],
        tags=["test", "v2", "export"],
        description="Test agent for agent file v2 testing",
        metadata={"test_key": "test_value", "version": "v2"},
        initial_message_sequence=[
            MessageCreate(role=MessageRole.system, content="You are a helpful assistant."),
            MessageCreate(role=MessageRole.user, content="Hello!"),
            MessageCreate(role=MessageRole.assistant, content="Hello! How can I help you today?"),
        ],
        tool_exec_environment_variables={"TEST_VAR": "test_value"},
        message_buffer_autoclear=False,
    )
    created = server.agent_manager.create_agent(agent_create=request, actor=default_user)

    # Extend the conversation beyond the initial sequence.
    follow_up = MessageCreate(role=MessageRole.user, content="What's the weather like?")
    server.send_messages(actor=default_user, agent_id=created.id, input_messages=[follow_up])

    # Re-read the agent so the yielded state includes the new message.
    yield server.agent_manager.get_agent_by_id(agent_id=created.id, actor=default_user)
# ------------------------------
# Helper Functions
# ------------------------------
def compare_agent_files(original: AgentFileSchema, imported: AgentFileSchema) -> bool:
    """Report whether two agent files are logically equivalent.

    Collection sizes are compared first, then entities are compared pairwise.
    Tools are paired after sorting by ``name`` and blocks after sorting by
    ``label``; agents, groups, files and sources are paired positionally.
    Every discrepancy found is printed.

    Returns:
        True when no discrepancy was found, False otherwise.
    """
    problems = []

    # Pass 1: collection sizes, in a fixed order.
    collections = (
        ("Agent", "agents"),
        ("Tool", "tools"),
        ("Block", "blocks"),
        ("Group", "groups"),
        ("File", "files"),
        ("Source", "sources"),
    )
    for label, attribute in collections:
        before = len(getattr(original, attribute))
        after = len(getattr(imported, attribute))
        if before != after:
            problems.append(f"{label} count mismatch: {before} vs {after}")

    # Pass 2: pairwise comparison. zip() stops at the shorter side, so extra
    # entities are only reported through the count checks above.
    for position, pair in enumerate(zip(original.agents, imported.agents)):
        problems.extend(_compare_agents(*pair, position))

    tool_pairs = zip(
        sorted(original.tools, key=lambda tool: tool.name),
        sorted(imported.tools, key=lambda tool: tool.name),
    )
    for position, pair in enumerate(tool_pairs):
        problems.extend(_compare_tools(*pair, position))

    block_pairs = zip(
        sorted(original.blocks, key=lambda block: block.label),
        sorted(imported.blocks, key=lambda block: block.label),
    )
    for position, pair in enumerate(block_pairs):
        problems.extend(_compare_blocks(*pair, position))

    for position, pair in enumerate(zip(original.groups, imported.groups)):
        problems.extend(_compare_groups(*pair, position))

    for position, pair in enumerate(zip(original.files, imported.files)):
        problems.extend(_compare_files(*pair, position))

    for position, pair in enumerate(zip(original.sources, imported.sources)):
        problems.extend(_compare_sources(*pair, position))

    if not problems:
        return True

    print("Agent file comparison errors:")
    for problem in problems:
        print(f" - {problem}")
    return False
def _compare_agents(orig: AgentSchema, imp: AgentSchema, index: int) -> List[str]:
    """Collect the differences between an original and an imported agent.

    Scalar and config fields are compared by value. Messages are compared one
    by one only when both agents hold the same number of them. In-context
    message IDs and relationship ID lists are compared by length only.

    Args:
        orig: Agent from the first file.
        imp: Agent from the second file.
        index: Position of the agent, used to prefix each error message.

    Returns:
        Human-readable mismatch descriptions; empty when the agents match.
    """
    prefix = f"Agent {index}"
    problems = []

    def differs(field):
        return getattr(orig, field) != getattr(imp, field)

    if differs("name"):
        problems.append(f"{prefix}: name mismatch: '{orig.name}' vs '{imp.name}'")
    for field in ("system", "description"):
        if differs(field):
            problems.append(f"{prefix}: {field} mismatch")
    if differs("agent_type"):
        problems.append(f"{prefix}: agent_type mismatch: '{orig.agent_type}' vs '{imp.agent_type}'")
    if differs("tags"):
        problems.append(f"{prefix}: tags mismatch: {orig.tags} vs {imp.tags}")
    # Config, tool rules and environment variables are reported without values.
    for field in (
        "metadata",
        "llm_config",
        "embedding_config",
        "tool_rules",
        "tool_exec_environment_variables",
    ):
        if differs(field):
            problems.append(f"{prefix}: {field} mismatch")

    # Messages: only descend into per-message checks when the counts agree.
    orig_total, imp_total = len(orig.messages), len(imp.messages)
    if orig_total == imp_total:
        for position, pair in enumerate(zip(orig.messages, imp.messages)):
            problems.extend(_compare_messages(*pair, index, position))
    else:
        problems.append(f"{prefix}: message count mismatch: {orig_total} vs {imp_total}")

    # In-context messages are compared by count only.
    orig_ctx, imp_ctx = len(orig.in_context_message_ids), len(imp.in_context_message_ids)
    if orig_ctx != imp_ctx:
        problems.append(f"{prefix}: in-context message count mismatch: {orig_ctx} vs {imp_ctx}")

    # Relationship IDs: only the lengths need to match; None counts as empty.
    for relation in ("tool_ids", "block_ids", "source_ids"):
        orig_n = len(getattr(orig, relation) or [])
        imp_n = len(getattr(imp, relation) or [])
        if orig_n != imp_n:
            problems.append(f"{prefix}: {relation} count mismatch: {orig_n} vs {imp_n}")

    return problems
def _compare_messages(orig: MessageSchema, imp: MessageSchema, agent_index: int, msg_index: int) -> List[str]:
    """Collect the differences between an original and an imported message.

    ``role``, ``content``, ``name`` and ``model`` are compared. ``agent_id``
    is deliberately ignored because it is expected to differ after an import.

    Returns:
        Human-readable mismatch descriptions; empty when the messages match.
    """
    where = f"Agent {agent_index}, Message {msg_index}"
    # (attribute, whether both values are echoed in the error text)
    checks = (("role", True), ("content", False), ("name", True), ("model", True))

    findings = []
    for attribute, show_values in checks:
        before, after = getattr(orig, attribute), getattr(imp, attribute)
        if before != after:
            detail = f": '{before}' vs '{after}'" if show_values else ""
            findings.append(f"{where}: {attribute} mismatch{detail}")
    return findings
def _compare_tools(orig: ToolSchema, imp: ToolSchema, index: int) -> List[str]:
    """Collect the differences between an original and an imported tool.

    ``organization_id`` is deliberately not compared because it is expected
    to differ between organizations.

    Returns:
        Human-readable mismatch descriptions; empty when the tools match.
    """
    # (attribute, message template); {0}/{1} are the original/imported values.
    checks = (
        ("name", "name mismatch: '{0}' vs '{1}'"),
        ("description", "description mismatch"),
        ("source_code", "source_code mismatch"),
        ("json_schema", "json_schema mismatch"),
        ("tags", "tags mismatch: {0} vs {1}"),
        ("metadata_", "metadata mismatch"),
    )
    mismatches = []
    for attribute, template in checks:
        before, after = getattr(orig, attribute), getattr(imp, attribute)
        if before != after:
            mismatches.append(f"Tool {index}: " + template.format(before, after))
    return mismatches
def _compare_blocks(orig: BlockSchema, imp: BlockSchema, index: int) -> List[str]:
    """Collect the differences between an original and an imported block.

    Returns:
        Human-readable mismatch descriptions; empty when the blocks match.
    """
    # (attribute, message template); {0}/{1} are the original/imported values.
    checks = (
        ("label", "label mismatch: '{0}' vs '{1}'"),
        ("value", "value mismatch"),
        ("limit", "limit mismatch: {0} vs {1}"),
        ("description", "description mismatch"),
        ("metadata", "metadata mismatch"),
        ("template_name", "template_name mismatch: '{0}' vs '{1}'"),
        ("is_template", "is_template mismatch: {0} vs {1}"),
    )
    mismatches = []
    for attribute, template in checks:
        before, after = getattr(orig, attribute), getattr(imp, attribute)
        if before != after:
            mismatches.append(f"Block {index}: " + template.format(before, after))
    return mismatches
def _compare_groups(orig: GroupSchema, imp: GroupSchema, index: int) -> List[str]:
    """Collect the differences between an original and an imported group.

    Returns:
        Human-readable mismatch descriptions; empty when the groups match.
    """
    findings = []
    if orig.name != imp.name:
        findings.append(f"Group {index}: name mismatch: '{orig.name}' vs '{imp.name}'")
    # These fields are reported without echoing their values.
    findings.extend(
        f"Group {index}: {field} mismatch"
        for field in ("description", "metadata")
        if getattr(orig, field) != getattr(imp, field)
    )
    return findings
def _compare_files(orig: FileSchema, imp: FileSchema, index: int) -> List[str]:
    """Collect the differences between an original and an imported file entry.

    Returns:
        Human-readable mismatch descriptions; empty when the entries match.
    """
    # (attribute, message template); {0}/{1} are the original/imported values.
    checks = (
        ("file_name", "file_name mismatch: '{0}' vs '{1}'"),
        ("file_size", "file_size mismatch: {0} vs {1}"),
        ("file_type", "file_type mismatch: '{0}' vs '{1}'"),
        ("description", "description mismatch"),
        ("metadata", "metadata mismatch"),
    )
    mismatches = []
    for attribute, template in checks:
        before, after = getattr(orig, attribute), getattr(imp, attribute)
        if before != after:
            mismatches.append(f"File {index}: " + template.format(before, after))
    return mismatches
def _compare_sources(orig: SourceSchema, imp: SourceSchema, index: int) -> List[str]:
    """Collect the differences between an original and an imported source.

    Returns:
        Human-readable mismatch descriptions; empty when the sources match.
    """
    findings = []
    if orig.name != imp.name:
        findings.append(f"Source {index}: name mismatch: '{orig.name}' vs '{imp.name}'")
    # These fields are reported without echoing their values.
    findings.extend(
        f"Source {index}: {field} mismatch"
        for field in ("description", "metadata")
        if getattr(orig, field) != getattr(imp, field)
    )
    return findings
def validate_id_format(schema: AgentFileSchema) -> bool:
    """Check that the IDs in an agent file carry the expected entity prefix.

    Only the prefix is inspected (``agent-``, ``message-``, ``tool-``,
    ``block-``); the part after the prefix is not validated. Agents are
    checked first, each followed by its messages and in-context message
    references, then tools, then blocks. The first offending ID is printed
    and checking stops there.

    Returns:
        True when every inspected ID has the right prefix, False otherwise.
    """

    def candidates():
        # Lazily yield (label, id, required prefix) in checking order.
        for agent in schema.agents:
            yield "agent", agent.id, "agent-"
            for message in agent.messages:
                yield "message", message.id, "message-"
            for reference in agent.in_context_message_ids:
                yield "in-context message", reference, "message-"
        for tool in schema.tools:
            yield "tool", tool.id, "tool-"
        for block in schema.blocks:
            yield "block", block.id, "block-"

    for label, entity_id, required_prefix in candidates():
        if not entity_id.startswith(required_prefix):
            print(f"Invalid {label} ID format: {entity_id}")
            return False
    return True
# ------------------------------
# Tests
# ------------------------------
class TestAgentFileExport:
    """Export-side behaviour of ``AgentFileManager``."""

    async def test_basic_export(self, agent_file_manager, test_agent, default_user):
        """A single agent exports into a well-formed ``AgentFileSchema``."""
        exported = await agent_file_manager.export([test_agent.id], default_user)

        # Overall shape of the export.
        assert isinstance(exported, AgentFileSchema)
        assert len(exported.agents) == 1
        assert len(exported.tools) > 0  # Should include base tools + weather tool
        assert len(exported.blocks) > 0  # Should include memory blocks + test block

        # revision_id must be set automatically and must not be the fallback value.
        revision_id = exported.metadata.get("revision_id")
        assert revision_id is not None
        assert revision_id != "unknown"
        assert len(revision_id) > 0

        assert validate_id_format(exported)

        # The exported agent mirrors the source agent and carries its messages.
        agent_entry = exported.agents[0]
        assert agent_entry.name == test_agent.name
        assert agent_entry.system == test_agent.system
        assert len(agent_entry.messages) > 0
        assert len(agent_entry.in_context_message_ids) > 0

    async def test_export_multiple_agents(self, server, agent_file_manager, test_agent, default_user, weather_tool):
        """Two agents can be exported into one file and keep distinct IDs."""
        second_agent = server.agent_manager.create_agent(
            agent_create=CreateAgent(
                name="second_test_agent",
                system="Second test agent",
                llm_config=LLMConfig.default_config("gpt-4o-mini"),
                embedding_config=EmbeddingConfig.default_config(provider="openai"),
                tool_ids=[weather_tool.id],
                initial_message_sequence=[
                    MessageCreate(role=MessageRole.user, content="Second agent message"),
                ],
            ),
            actor=default_user,
        )

        exported = await agent_file_manager.export([test_agent.id, second_agent.id], default_user)

        assert len(exported.agents) == 2
        assert validate_id_format(exported)
        # Each exported agent must receive its own file-level ID.
        distinct_agent_ids = {entry.id for entry in exported.agents}
        assert len(distinct_agent_ids) == 2

    async def test_export_id_remapping(self, agent_file_manager, test_agent, default_user):
        """Database IDs are replaced by file-local IDs on export."""
        exported = await agent_file_manager.export([test_agent.id], default_user)
        agent_entry = exported.agents[0]

        # Agent ID is remapped.
        assert agent_entry.id == "agent-0"
        assert agent_entry.id != test_agent.id

        # Tool and block references are remapped (skipped when absent or empty).
        for tool_id in agent_entry.tool_ids or []:
            assert tool_id.startswith("tool-")
        for block_id in agent_entry.block_ids or []:
            assert block_id.startswith("block-")

        # Every in-context reference must point at an exported message.
        known_message_ids = {message.id for message in agent_entry.messages}
        for in_context_id in agent_entry.in_context_message_ids:
            assert in_context_id in known_message_ids, f"In-context message ID {in_context_id} not found in messages"

    async def test_message_agent_id_remapping(self, agent_file_manager, test_agent, default_user):
        """Each exported message points at the remapped agent ID."""
        exported = await agent_file_manager.export([test_agent.id], default_user)
        agent_entry = exported.agents[0]

        for message in agent_entry.messages:
            assert (
                message.agent_id == agent_entry.id
            ), f"Message {message.id} has agent_id {message.agent_id}, expected {agent_entry.id}"

        # The agent ID itself uses the file-local format.
        assert agent_entry.id == "agent-0"
        assert agent_entry.id != test_agent.id

    async def test_export_empty_agent_list(self, agent_file_manager, default_user):
        """Exporting no agents yields an empty file."""
        exported = await agent_file_manager.export([], default_user)
        assert len(exported.agents) == 0
        assert len(exported.tools) == 0
        assert len(exported.blocks) == 0

    async def test_export_nonexistent_agent(self, agent_file_manager, default_user):
        """Exporting an unknown agent ID raises ``AgentFileExportError``."""
        with pytest.raises(AgentFileExportError):
            await agent_file_manager.export(["non-existent-id"], default_user)

    async def test_revision_id_automatic_setting(self, agent_file_manager, test_agent, default_user):
        """The export's revision_id equals the latest alembic revision."""
        exported = await agent_file_manager.export([test_agent.id], default_user)

        from letta.utils import get_latest_alembic_revision

        expected_revision = await get_latest_alembic_revision()
        revision_id = exported.metadata.get("revision_id")

        assert revision_id == expected_revision
        # Must not be the fallback value.
        assert revision_id != "unknown"
        # Must look like a revision ID: 12 lowercase hex characters.
        assert len(revision_id) == 12
        assert all(char in "0123456789abcdef" for char in revision_id)
class TestAgentFileImport:
    """Import-side behaviour of ``AgentFileManager``."""

    async def test_basic_import(self, agent_file_manager, test_agent, default_user, other_user):
        """An exported agent can be imported for a different user."""
        exported = await agent_file_manager.export([test_agent.id], default_user)
        outcome = await agent_file_manager.import_file(exported, other_user)

        assert outcome.success
        assert outcome.imported_count > 0
        assert len(outcome.id_mappings) > 0

        # Imported agents must be new rows, not the original agent reused.
        for file_id, db_id in outcome.id_mappings.items():
            if file_id.startswith("agent-"):
                assert db_id != test_agent.id

    async def test_import_preserves_data(self, server, agent_file_manager, test_agent, default_user, other_user):
        """Key agent fields, tools, blocks and messages survive an import."""
        exported = await agent_file_manager.export([test_agent.id], default_user)
        outcome = await agent_file_manager.import_file(exported, other_user)

        new_agent_id = next(db_id for file_id, db_id in outcome.id_mappings.items() if file_id == "agent-0")
        new_agent = server.agent_manager.get_agent_by_id(new_agent_id, other_user)

        # Scalar fields.
        assert new_agent.name == test_agent.name
        assert new_agent.system == test_agent.system
        assert new_agent.description == test_agent.description
        assert new_agent.metadata == test_agent.metadata
        assert new_agent.tags == test_agent.tags

        # Attached tools and memory blocks.
        assert len(new_agent.tools) == len(test_agent.tools)
        assert len(new_agent.memory.blocks) == len(test_agent.memory.blocks)

        # Message history.
        source_messages = server.message_manager.list_messages_for_agent(test_agent.id, default_user)
        copied_messages = server.message_manager.list_messages_for_agent(new_agent_id, other_user)
        assert len(copied_messages) == len(source_messages)

        for source_message, copied_message in zip(source_messages, copied_messages):
            assert source_message.role == copied_message.role
            assert source_message.content == copied_message.content
            assert copied_message.agent_id == new_agent_id  # Should be remapped to new agent

    async def test_import_message_context_preservation(self, server, agent_file_manager, test_agent, default_user, other_user):
        """In-context message references stay valid after an import."""
        exported = await agent_file_manager.export([test_agent.id], default_user)
        outcome = await agent_file_manager.import_file(exported, other_user)

        new_agent_id = next(db_id for file_id, db_id in outcome.id_mappings.items() if file_id == "agent-0")
        new_agent = server.agent_manager.get_agent_by_id(new_agent_id, other_user)

        # Same number of in-context messages as the source agent.
        assert len(new_agent.message_ids) == len(test_agent.message_ids)

        # Every in-context ID must refer to a message that was imported.
        copied_messages = server.message_manager.list_messages_for_agent(new_agent_id, other_user)
        copied_message_ids = {message.id for message in copied_messages}
        for in_context_id in new_agent.message_ids:
            assert in_context_id in copied_message_ids

    async def test_dry_run_import(self, agent_file_manager, test_agent, default_user, other_user):
        """A dry run succeeds without importing anything."""
        exported = await agent_file_manager.export([test_agent.id], default_user)
        outcome = await agent_file_manager.import_file(exported, other_user, dry_run=True)

        assert outcome.success
        assert outcome.imported_count == 0  # No actual imports in dry run
        assert len(outcome.id_mappings) == 0
        assert "dry run" in outcome.message.lower()

    async def test_import_validation_errors(self, agent_file_manager, other_user):
        """A file containing duplicate agent IDs is rejected on import."""
        from letta.utils import get_latest_alembic_revision

        current_revision = await get_latest_alembic_revision()

        # Two agents sharing "agent-0" make the file invalid.
        duplicated_agents = [
            AgentSchema(id="agent-0", name="agent1"),
            AgentSchema(id="agent-0", name="agent2"),  # Duplicate ID
        ]
        invalid_agent_file = AgentFileSchema(
            metadata={"revision_id": current_revision},
            agents=duplicated_agents,
            groups=[],
            blocks=[],
            files=[],
            sources=[],
            tools=[],
        )

        with pytest.raises(AgentFileImportError):
            await agent_file_manager.import_file(invalid_agent_file, other_user)
class TestAgentFileRoundTrip:
    """Full export -> import -> export cycles."""

    async def test_roundtrip_consistency(self, server, agent_file_manager, test_agent, default_user, other_user):
        """Re-exporting an imported agent gives a logically equivalent file."""
        first_export = await agent_file_manager.export([test_agent.id], default_user)
        outcome = await agent_file_manager.import_file(first_export, other_user)

        new_agent_id = next(db_id for file_id, db_id in outcome.id_mappings.items() if file_id == "agent-0")
        second_export = await agent_file_manager.export([new_agent_id], other_user)

        assert compare_agent_files(first_export, second_export)

    async def test_multiple_roundtrips(self, server, agent_file_manager, test_agent, default_user, other_user):
        """Three export/import hops, alternating users, keep the agent intact."""
        agent_id = test_agent.id
        owner = default_user

        for _ in range(3):
            exported = await agent_file_manager.export([agent_id], owner)

            # Alternate the importing user on every hop.
            recipient = other_user if owner == default_user else default_user
            outcome = await agent_file_manager.import_file(exported, recipient)

            agent_id = next(db_id for file_id, db_id in outcome.id_mappings.items() if file_id == "agent-0")
            owner = recipient

            # The freshly imported agent must be retrievable and keep its name.
            hop_agent = server.agent_manager.get_agent_by_id(agent_id, owner)
            assert hop_agent.name == test_agent.name
class TestAgentFileEdgeCases:
    """Edge cases: agents without messages and agents with many messages."""

    async def test_agent_with_no_messages(self, server, agent_file_manager, default_user, other_user):
        """An agent created without any messages exports and imports cleanly."""
        empty_agent = await server.agent_manager.create_agent_async(
            agent_create=CreateAgent(
                name="no_messages_agent",
                system="Agent with no messages",
                llm_config=LLMConfig.default_config("gpt-4o-mini"),
                embedding_config=EmbeddingConfig.default_config(provider="openai"),
                initial_message_sequence=[],
            ),
            actor=default_user,
            _init_with_no_messages=True,  # Create with truly no messages
        )

        exported = await agent_file_manager.export([empty_agent.id], default_user)
        outcome = await agent_file_manager.import_file(exported, other_user)

        assert outcome.success
        new_agent_id = next(db_id for file_id, db_id in outcome.id_mappings.items() if file_id == "agent-0")
        new_agent = server.agent_manager.get_agent_by_id(new_agent_id, other_user)
        assert len(new_agent.message_ids) == 0

    async def test_large_agent_file(self, server, agent_file_manager, default_user, other_user, weather_tool):
        """An agent with many messages round-trips through export and import."""
        busy_agent = server.agent_manager.create_agent(
            agent_create=CreateAgent(
                name="large_agent",
                system="Agent with many messages",
                llm_config=LLMConfig.default_config("gpt-4o-mini"),
                embedding_config=EmbeddingConfig.default_config(provider="openai"),
                tool_ids=[weather_tool.id],
            ),
            actor=default_user,
        )

        # Send ten separate user messages to grow the history.
        for sequence_number in range(10):
            user_message = MessageCreate(role=MessageRole.user, content=f"Message {sequence_number}")
            server.send_messages(
                actor=default_user,
                agent_id=busy_agent.id,
                input_messages=[user_message],
            )

        exported = await agent_file_manager.export([busy_agent.id], default_user)
        assert len(exported.agents[0].messages) >= 10

        outcome = await agent_file_manager.import_file(exported, other_user)
        assert outcome.success

        new_agent_id = next(db_id for file_id, db_id in outcome.id_mappings.items() if file_id == "agent-0")
        copied_messages = server.message_manager.list_messages_for_agent(new_agent_id, other_user)
        assert len(copied_messages) >= 10
class TestAgentFileValidation:
    """Schema construction, message conversion and the ID-format helper."""

    def test_agent_file_schema_validation(self, test_agent):
        """A minimal ``AgentFileSchema`` can be built and read back."""
        # Use a dummy revision for this test since we can't await in sync test
        revision = "495f3f474131"  # Use a known valid revision format

        minimal_schema = AgentFileSchema(
            metadata={"revision_id": revision},
            agents=[AgentSchema(id="agent-0", name="test")],
            groups=[],
            blocks=[],
            files=[],
            sources=[],
            tools=[],
            # mcp_servers=[],
        )

        # Construction above must not raise; spot-check the stored values.
        assert minimal_schema.agents[0].id == "agent-0"
        assert minimal_schema.metadata.get("revision_id") == revision

    def test_message_schema_conversion(self, test_agent, server, default_user):
        """``MessageSchema.from_message`` keeps the core message fields."""
        stored_messages = server.message_manager.list_messages_for_agent(test_agent.id, default_user)
        if not stored_messages:
            # Nothing to convert; the checks below only apply when a message exists.
            return

        source_message = stored_messages[0]
        converted = MessageSchema.from_message(source_message)

        assert converted.role == source_message.role
        assert converted.content == source_message.content
        assert converted.model == source_message.model
        assert converted.agent_id == source_message.agent_id

    def test_id_format_validation(self):
        """``validate_id_format`` accepts prefixed IDs and rejects others."""
        # Use a dummy revision for this test since we can't await in sync test
        revision = "495f3f474131"  # Use a known valid revision format

        sample_tool = ToolSchema(
            id="tool-0",
            name="test_tool",
            source_code="test",
            json_schema={"name": "test_tool", "parameters": {"type": "object", "properties": {}}},
        )
        well_formed = AgentFileSchema(
            metadata={"revision_id": revision},
            agents=[AgentSchema(id="agent-0", name="test")],
            groups=[],
            blocks=[BlockSchema(id="block-0", label="test", value="test")],
            files=[],
            sources=[],
            tools=[sample_tool],
            # mcp_servers=[],
        )
        assert validate_id_format(well_formed)

        # An agent ID without the "agent-" prefix must be rejected.
        malformed = AgentFileSchema(
            metadata={"revision_id": revision},
            agents=[AgentSchema(id="invalid-id-format", name="test")],
            groups=[],
            blocks=[],
            files=[],
            sources=[],
            tools=[],
            # mcp_servers=[],
        )
        assert not validate_id_format(malformed)
# Allow running this module directly as a script: hand this file to pytest.
if __name__ == "__main__":
    pytest.main([__file__])