feat: Add basic test serialization tests (#4076)

This commit is contained in:
Matthew Zhou
2025-08-21 11:08:06 -07:00
committed by GitHub
parent ea26680674
commit dcb689cbfa
5 changed files with 975 additions and 3 deletions

View File

@@ -1,5 +1,7 @@
import io
import json
import os
import textwrap
import threading
import time
import uuid
@@ -60,6 +62,117 @@ def agent(client: LettaSDKClient):
client.agents.delete(agent_id=agent_state.id)
@pytest.fixture(scope="function")
def fibonacci_tool(client: LettaSDKClient):
    """Fixture providing Fibonacci calculation tool.

    Registers the nested ``calculate_fibonacci`` function through
    ``client.tools.upsert_from_function`` with the tags ``math`` and
    ``utility``, yields the resulting tool to the test, and deletes that tool
    by id once the test has finished.
    """

    # NOTE(review): this function object is passed to the client as the tool
    # definition; presumably its source text and docstring are what the server
    # stores, so edit them only deliberately -- confirm against the SDK.
    def calculate_fibonacci(n: int) -> int:
        """Calculate the nth Fibonacci number.
        Args:
            n: The position in the Fibonacci sequence to calculate.
        Returns:
            The nth Fibonacci number.
        """
        if n <= 0:
            return 0
        elif n == 1:
            return 1
        else:
            a, b = 0, 1
            for _ in range(2, n + 1):
                a, b = b, a + b
            return b

    tool = client.tools.upsert_from_function(func=calculate_fibonacci, tags=["math", "utility"])
    yield tool
    # Teardown: remove the tool registered above.
    client.tools.delete(tool.id)
@pytest.fixture(scope="function")
def preferences_tool(client: LettaSDKClient):
    """Fixture providing user preferences tool.

    Registers the nested ``get_user_preferences`` function through
    ``client.tools.upsert_from_function`` with the tags ``user`` and
    ``preferences``, yields the resulting tool to the test, and deletes that
    tool by id once the test has finished.
    """

    # NOTE(review): this function object is passed to the client as the tool
    # definition; presumably its source text and docstring are what the server
    # stores, so edit them only deliberately -- confirm against the SDK.
    def get_user_preferences(category: str) -> str:
        """Get user preferences for a specific category.
        Args:
            category: The preference category to retrieve (notification, theme, language).
        Returns:
            The user's preference for the specified category, or "not specified" if unknown.
        """
        preferences = {"notification": "email only", "theme": "dark mode", "language": "english"}
        return preferences.get(category, "not specified")

    tool = client.tools.upsert_from_function(func=get_user_preferences, tags=["user", "preferences"])
    yield tool
    # Teardown: remove the tool registered above.
    client.tools.delete(tool.id)
@pytest.fixture(scope="function")
def data_analysis_tool(client: LettaSDKClient):
    """Fixture providing data analysis tool.

    Registers the nested ``analyze_data`` function through
    ``client.tools.upsert_from_function`` with the tags ``analysis`` and
    ``data``, yields the resulting tool to the test, and deletes that tool by
    id once the test has finished.
    """

    # NOTE(review): this function object is passed to the client as the tool
    # definition; presumably its source text and docstring are what the server
    # stores, so edit them only deliberately -- confirm against the SDK.
    def analyze_data(data_type: str, values: List[float]) -> str:
        """Analyze data and provide insights.
        Args:
            data_type: Type of data to analyze.
            values: Numerical values to analyze.
        Returns:
            Analysis results including average, max, and min values.
        """
        if not values:
            return "No data provided"
        avg = sum(values) / len(values)
        max_val = max(values)
        min_val = min(values)
        return f"Analysis of {data_type}: avg={avg:.2f}, max={max_val}, min={min_val}"

    tool = client.tools.upsert_from_function(func=analyze_data, tags=["analysis", "data"])
    yield tool
    # Teardown: remove the tool registered above.
    client.tools.delete(tool.id)
@pytest.fixture(scope="function")
def persona_block(client: LettaSDKClient):
    """Create the "persona" memory block for one test and delete it afterwards."""
    persona_text = "You are Alex, a data analyst and mathematician who helps users with calculations and insights. You have extensive experience in statistical analysis and prefer to provide clear, accurate results."
    created = client.blocks.create(label="persona", value=persona_text, limit=8000)

    yield created

    # Teardown: drop the block created for this test.
    client.blocks.delete(created.id)
@pytest.fixture(scope="function")
def human_block(client: LettaSDKClient):
    """Create the "human" memory block for one test and delete it afterwards."""
    human_text = "username: sarah_researcher\noccupation: data scientist\ninterests: machine learning, statistics, fibonacci sequences\npreferred_communication: detailed explanations with examples"
    created = client.blocks.create(label="human", value=human_text, limit=4000)

    yield created

    # Teardown: drop the block created for this test.
    client.blocks.delete(created.id)
@pytest.fixture(scope="function")
def context_block(client: LettaSDKClient):
    """Create the "project_context" memory block for one test and delete it afterwards."""
    context_text = "Current project: Building predictive models for financial markets. Sarah is working on sequence analysis and pattern recognition. Recently interested in mathematical sequences like Fibonacci for trend analysis."
    created = client.blocks.create(label="project_context", value=context_text, limit=6000)

    yield created

    # Teardown: drop the block created for this test.
    client.blocks.delete(created.id)
def test_shared_blocks(client: LettaSDKClient):
# create a block
block = client.blocks.create(
@@ -1465,7 +1578,6 @@ def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
"""Test that passing both new JSON schema AND source code still renames the tool based on source code"""
import textwrap
# Create initial tool
def initial_tool(x: int) -> int:
@@ -1543,3 +1655,217 @@ def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
finally:
# Clean up
client.tools.delete(tool_id=tool.id)
def test_import_agent_file_from_disk(
    client: LettaSDKClient, fibonacci_tool, preferences_tool, data_analysis_tool, persona_block, human_block, context_block
):
    """Test exporting an agent to file and importing it back from disk.

    Builds an agent with three custom tools, three memory blocks, two archival
    passages and one user message, exports it in the non-legacy format, writes
    the export to ``test_agent_files/`` next to this module, imports that file
    again, and runs basic sanity checks on the imported agent.

    Every agent created along the way (the source agent and whatever the import
    produced) is deleted in a ``finally`` block so a failing assertion does not
    leave agents behind on the server.
    """
    # Ids of all agents this test is responsible for removing.
    created_agent_ids = []
    try:
        # Create a comprehensive agent (similar to test_agent_serialization_v2)
        name = f"test_export_import_{uuid.uuid4()}"
        temp_agent = client.agents.create(
            name=name,
            memory_blocks=[persona_block, human_block, context_block],
            model="openai/gpt-4.1-mini",
            embedding="openai/text-embedding-3-small",
            tool_ids=[fibonacci_tool.id, preferences_tool.id, data_analysis_tool.id],
            include_base_tools=True,
            tags=["test", "export", "import"],
            system="You are a helpful assistant specializing in data analysis and mathematical computations.",
        )
        created_agent_ids.append(temp_agent.id)

        # Add archival memory
        archival_passages = ["Test archival passage for export/import testing.", "Another passage with data about testing procedures."]
        for passage_text in archival_passages:
            client.agents.passages.create(agent_id=temp_agent.id, text=passage_text)

        # Send a test message
        client.agents.messages.create(
            agent_id=temp_agent.id,
            messages=[
                MessageCreate(
                    role="user",
                    content="Test message for export",
                ),
            ],
        )

        # Export the agent
        serialized_v2 = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)

        # Save to file
        # NOTE(review): this writes into the test tree on every run and the file
        # is left in place -- presumably it doubles as a checked-in sample; confirm.
        file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_basic_agent_with_blocks_tools_messages_v2.af")
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(serialized_v2, f, indent=2)

        # Now import from the file
        with open(file_path, "rb") as f:
            import_result = client.agents.import_file(
                file=f, append_copy_suffix=True, override_existing_tools=True  # Use suffix to avoid name conflict
            )

        # Basic verification
        assert import_result is not None, "Import result should not be None"
        # Register the imported agents for cleanup before any further assertion can fail.
        created_agent_ids.extend(import_result.agent_ids)
        assert len(import_result.agent_ids) > 0, "Should have imported at least one agent"

        # Get the imported agent
        imported_agent_id = import_result.agent_ids[0]
        imported_agent = client.agents.retrieve(agent_id=imported_agent_id)

        # Basic checks
        assert imported_agent is not None, "Should be able to retrieve imported agent"
        assert imported_agent.name is not None, "Imported agent should have a name"
        assert imported_agent.memory is not None, "Agent should have memory"
        assert len(imported_agent.tools) > 0, "Agent should have tools"
        assert imported_agent.system is not None, "Agent should have a system prompt"
    finally:
        # Clean up
        for agent_id in created_agent_ids:
            client.agents.delete(agent_id=agent_id)
def test_agent_serialization_v2(
    client: LettaSDKClient, fibonacci_tool, preferences_tool, data_analysis_tool, persona_block, human_block, context_block
):
    """Test agent serialization with comprehensive setup including custom tools, blocks, messages, and archival memory.

    Creates an agent, exports it in the non-legacy format, re-imports the
    export from an in-memory buffer, and compares the imported agent with the
    original: basic properties, memory blocks, tools, and the test user message.

    Both the original agent and the imported agent(s) are deleted in a
    ``finally`` block so a failing assertion does not leave agents behind on
    the server.
    """
    # Ids of all agents this test is responsible for removing.
    created_agent_ids = []
    try:
        name = f"comprehensive_test_agent_{uuid.uuid4()}"
        temp_agent = client.agents.create(
            name=name,
            memory_blocks=[persona_block, human_block, context_block],
            model="openai/gpt-4.1-mini",
            embedding="openai/text-embedding-3-small",
            tool_ids=[fibonacci_tool.id, preferences_tool.id, data_analysis_tool.id],
            include_base_tools=True,
            tags=["test", "comprehensive", "serialization"],
            system="You are a helpful assistant specializing in data analysis and mathematical computations.",
        )
        created_agent_ids.append(temp_agent.id)

        # Add archival memory
        archival_passages = [
            "Project background: Sarah is working on a financial prediction model that uses Fibonacci retracements for technical analysis.",
            "Research notes: Golden ratio (1.618) derived from Fibonacci sequence is often used in financial markets for support/resistance levels.",
        ]
        for passage_text in archival_passages:
            client.agents.passages.create(agent_id=temp_agent.id, text=passage_text)

        # Send some messages
        client.agents.messages.create(
            agent_id=temp_agent.id,
            messages=[
                MessageCreate(
                    role="user",
                    content="Test message",
                ),
            ],
        )

        # Serialize using v2
        serialized_v2 = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)

        # Convert dict to JSON bytes for import
        json_str = json.dumps(serialized_v2)
        file_obj = io.BytesIO(json_str.encode("utf-8"))

        # Import again
        import_result = client.agents.import_file(file=file_obj, append_copy_suffix=False, override_existing_tools=True)
        # Register the imported agents for cleanup before any assertion can fail.
        created_agent_ids.extend(import_result.agent_ids)

        # Verify import was successful
        assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
        imported_agent_id = import_result.agent_ids[0]
        imported_agent = client.agents.retrieve(agent_id=imported_agent_id)

        # ========== BASIC AGENT PROPERTIES ==========
        # append_copy_suffix=False was passed above, so the name must match exactly
        assert imported_agent.name == name, f"Agent name mismatch: {imported_agent.name} != {name}"

        # LLM and embedding configs should be preserved
        assert (
            imported_agent.llm_config.model == temp_agent.llm_config.model
        ), f"LLM model mismatch: {imported_agent.llm_config.model} != {temp_agent.llm_config.model}"
        assert imported_agent.embedding_config.embedding_model == temp_agent.embedding_config.embedding_model, "Embedding model mismatch"

        # System prompt should be preserved
        assert imported_agent.system == temp_agent.system, "System prompt was not preserved"

        # Tags should be preserved
        assert set(imported_agent.tags) == set(temp_agent.tags), f"Tags mismatch: {imported_agent.tags} != {temp_agent.tags}"

        # Agent type should be preserved
        assert (
            imported_agent.agent_type == temp_agent.agent_type
        ), f"Agent type mismatch: {imported_agent.agent_type} != {temp_agent.agent_type}"

        # ========== MEMORY BLOCKS ==========
        # Compare memory blocks directly from AgentState objects
        original_blocks = temp_agent.memory.blocks
        imported_blocks = imported_agent.memory.blocks

        # Should have same number of blocks
        assert len(imported_blocks) == len(original_blocks), f"Block count mismatch: {len(imported_blocks)} != {len(original_blocks)}"

        # Verify each block by label
        original_blocks_by_label = {block.label: block for block in original_blocks}
        imported_blocks_by_label = {block.label: block for block in imported_blocks}

        # Check persona block
        assert "persona" in imported_blocks_by_label, "Persona block missing in imported agent"
        assert "Alex" in imported_blocks_by_label["persona"].value, "Persona block content not preserved"
        assert imported_blocks_by_label["persona"].limit == original_blocks_by_label["persona"].limit, "Persona block limit mismatch"

        # Check human block
        assert "human" in imported_blocks_by_label, "Human block missing in imported agent"
        assert "sarah_researcher" in imported_blocks_by_label["human"].value, "Human block content not preserved"
        assert imported_blocks_by_label["human"].limit == original_blocks_by_label["human"].limit, "Human block limit mismatch"

        # Check context block
        assert "project_context" in imported_blocks_by_label, "Context block missing in imported agent"
        assert "financial markets" in imported_blocks_by_label["project_context"].value, "Context block content not preserved"
        assert (
            imported_blocks_by_label["project_context"].limit == original_blocks_by_label["project_context"].limit
        ), "Context block limit mismatch"

        # ========== TOOLS ==========
        # Compare tools directly from AgentState objects
        original_tools = temp_agent.tools
        imported_tools = imported_agent.tools

        # Should have same number of tools
        assert len(imported_tools) == len(original_tools), f"Tool count mismatch: {len(imported_tools)} != {len(original_tools)}"

        original_tool_names = {tool.name for tool in original_tools}
        imported_tool_names = {tool.name for tool in imported_tools}

        # Check custom tools are present
        assert "calculate_fibonacci" in imported_tool_names, "Fibonacci tool missing in imported agent"
        assert "get_user_preferences" in imported_tool_names, "Preferences tool missing in imported agent"
        assert "analyze_data" in imported_tool_names, "Data analysis tool missing in imported agent"

        # Check for base tools (since we set include_base_tools=True when creating the agent)
        # Base tools should also be present (at least some core ones)
        base_tool_names = {"send_message", "conversation_search"}
        missing_base_tools = base_tool_names - imported_tool_names
        assert not missing_base_tools, f"Missing base tools: {missing_base_tools}"

        # Verify tool names match exactly
        assert original_tool_names == imported_tool_names, f"Tool names don't match: {original_tool_names} != {imported_tool_names}"

        # ========== MESSAGE HISTORY ==========
        # Get messages for both agents
        original_messages = client.agents.messages.list(agent_id=temp_agent.id, limit=100)
        imported_messages = client.agents.messages.list(agent_id=imported_agent_id, limit=100)

        # Imported agent should have at least one message (total counts are not compared)
        assert len(imported_messages) >= 1, "Imported agent should have messages"

        # Filter for user messages (excluding system-generated login messages)
        original_user_msgs = [msg for msg in original_messages if msg.message_type == "user_message" and "Test message" in msg.content]
        imported_user_msgs = [msg for msg in imported_messages if msg.message_type == "user_message" and "Test message" in msg.content]

        # Should have the same number of test messages
        assert len(imported_user_msgs) == len(
            original_user_msgs
        ), f"User message count mismatch: {len(imported_user_msgs)} != {len(original_user_msgs)}"

        # Verify test message content is preserved (only checked when both sides have a match)
        if original_user_msgs and imported_user_msgs:
            assert imported_user_msgs[0].content == original_user_msgs[0].content, "User message content not preserved"
            assert "Test message" in imported_user_msgs[0].content, "Test message content not found"
    finally:
        # Clean up
        for agent_id in created_agent_ids:
            client.agents.delete(agent_id=agent_id)