feat: Add basic test serialization tests (#4076)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import textwrap
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@@ -60,6 +62,117 @@ def agent(client: LettaSDKClient):
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def fibonacci_tool(client: LettaSDKClient):
|
||||
"""Fixture providing Fibonacci calculation tool."""
|
||||
|
||||
def calculate_fibonacci(n: int) -> int:
|
||||
"""Calculate the nth Fibonacci number.
|
||||
|
||||
Args:
|
||||
n: The position in the Fibonacci sequence to calculate.
|
||||
|
||||
Returns:
|
||||
The nth Fibonacci number.
|
||||
"""
|
||||
if n <= 0:
|
||||
return 0
|
||||
elif n == 1:
|
||||
return 1
|
||||
else:
|
||||
a, b = 0, 1
|
||||
for _ in range(2, n + 1):
|
||||
a, b = b, a + b
|
||||
return b
|
||||
|
||||
tool = client.tools.upsert_from_function(func=calculate_fibonacci, tags=["math", "utility"])
|
||||
yield tool
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def preferences_tool(client: LettaSDKClient):
|
||||
"""Fixture providing user preferences tool."""
|
||||
|
||||
def get_user_preferences(category: str) -> str:
|
||||
"""Get user preferences for a specific category.
|
||||
|
||||
Args:
|
||||
category: The preference category to retrieve (notification, theme, language).
|
||||
|
||||
Returns:
|
||||
The user's preference for the specified category, or "not specified" if unknown.
|
||||
"""
|
||||
preferences = {"notification": "email only", "theme": "dark mode", "language": "english"}
|
||||
return preferences.get(category, "not specified")
|
||||
|
||||
tool = client.tools.upsert_from_function(func=get_user_preferences, tags=["user", "preferences"])
|
||||
yield tool
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def data_analysis_tool(client: LettaSDKClient):
|
||||
"""Fixture providing data analysis tool."""
|
||||
|
||||
def analyze_data(data_type: str, values: List[float]) -> str:
|
||||
"""Analyze data and provide insights.
|
||||
|
||||
Args:
|
||||
data_type: Type of data to analyze.
|
||||
values: Numerical values to analyze.
|
||||
|
||||
Returns:
|
||||
Analysis results including average, max, and min values.
|
||||
"""
|
||||
if not values:
|
||||
return "No data provided"
|
||||
avg = sum(values) / len(values)
|
||||
max_val = max(values)
|
||||
min_val = min(values)
|
||||
return f"Analysis of {data_type}: avg={avg:.2f}, max={max_val}, min={min_val}"
|
||||
|
||||
tool = client.tools.upsert_from_function(func=analyze_data, tags=["analysis", "data"])
|
||||
yield tool
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def persona_block(client: LettaSDKClient):
|
||||
"""Fixture providing persona memory block."""
|
||||
block = client.blocks.create(
|
||||
label="persona",
|
||||
value="You are Alex, a data analyst and mathematician who helps users with calculations and insights. You have extensive experience in statistical analysis and prefer to provide clear, accurate results.",
|
||||
limit=8000,
|
||||
)
|
||||
yield block
|
||||
client.blocks.delete(block.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def human_block(client: LettaSDKClient):
|
||||
"""Fixture providing human memory block."""
|
||||
block = client.blocks.create(
|
||||
label="human",
|
||||
value="username: sarah_researcher\noccupation: data scientist\ninterests: machine learning, statistics, fibonacci sequences\npreferred_communication: detailed explanations with examples",
|
||||
limit=4000,
|
||||
)
|
||||
yield block
|
||||
client.blocks.delete(block.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def context_block(client: LettaSDKClient):
|
||||
"""Fixture providing project context memory block."""
|
||||
block = client.blocks.create(
|
||||
label="project_context",
|
||||
value="Current project: Building predictive models for financial markets. Sarah is working on sequence analysis and pattern recognition. Recently interested in mathematical sequences like Fibonacci for trend analysis.",
|
||||
limit=6000,
|
||||
)
|
||||
yield block
|
||||
client.blocks.delete(block.id)
|
||||
|
||||
|
||||
def test_shared_blocks(client: LettaSDKClient):
|
||||
# create a block
|
||||
block = client.blocks.create(
|
||||
@@ -1465,7 +1578,6 @@ def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
|
||||
|
||||
def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
|
||||
"""Test that passing both new JSON schema AND source code still renames the tool based on source code"""
|
||||
import textwrap
|
||||
|
||||
# Create initial tool
|
||||
def initial_tool(x: int) -> int:
|
||||
@@ -1543,3 +1655,217 @@ def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
|
||||
finally:
|
||||
# Clean up
|
||||
client.tools.delete(tool_id=tool.id)
|
||||
|
||||
|
||||
def test_import_agent_file_from_disk(
|
||||
client: LettaSDKClient, fibonacci_tool, preferences_tool, data_analysis_tool, persona_block, human_block, context_block
|
||||
):
|
||||
"""Test exporting an agent to file and importing it back from disk."""
|
||||
# Create a comprehensive agent (similar to test_agent_serialization_v2)
|
||||
name = f"test_export_import_{str(uuid.uuid4())}"
|
||||
temp_agent = client.agents.create(
|
||||
name=name,
|
||||
memory_blocks=[persona_block, human_block, context_block],
|
||||
model="openai/gpt-4.1-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
tool_ids=[fibonacci_tool.id, preferences_tool.id, data_analysis_tool.id],
|
||||
include_base_tools=True,
|
||||
tags=["test", "export", "import"],
|
||||
system="You are a helpful assistant specializing in data analysis and mathematical computations.",
|
||||
)
|
||||
|
||||
# Add archival memory
|
||||
archival_passages = ["Test archival passage for export/import testing.", "Another passage with data about testing procedures."]
|
||||
|
||||
for passage_text in archival_passages:
|
||||
client.agents.passages.create(agent_id=temp_agent.id, text=passage_text)
|
||||
|
||||
# Send a test message
|
||||
client.agents.messages.create(
|
||||
agent_id=temp_agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="Test message for export",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Export the agent
|
||||
serialized_v2 = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)
|
||||
|
||||
# Save to file
|
||||
file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_basic_agent_with_blocks_tools_messages_v2.af")
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(serialized_v2, f, indent=2)
|
||||
|
||||
# Now import from the file
|
||||
with open(file_path, "rb") as f:
|
||||
import_result = client.agents.import_file(
|
||||
file=f, append_copy_suffix=True, override_existing_tools=True # Use suffix to avoid name conflict
|
||||
)
|
||||
|
||||
# Basic verification
|
||||
assert import_result is not None, "Import result should not be None"
|
||||
assert len(import_result.agent_ids) > 0, "Should have imported at least one agent"
|
||||
|
||||
# Get the imported agent
|
||||
imported_agent_id = import_result.agent_ids[0]
|
||||
imported_agent = client.agents.retrieve(agent_id=imported_agent_id)
|
||||
|
||||
# Basic checks
|
||||
assert imported_agent is not None, "Should be able to retrieve imported agent"
|
||||
assert imported_agent.name is not None, "Imported agent should have a name"
|
||||
assert imported_agent.memory is not None, "Agent should have memory"
|
||||
assert len(imported_agent.tools) > 0, "Agent should have tools"
|
||||
assert imported_agent.system is not None, "Agent should have a system prompt"
|
||||
|
||||
|
||||
def test_agent_serialization_v2(
|
||||
client: LettaSDKClient, fibonacci_tool, preferences_tool, data_analysis_tool, persona_block, human_block, context_block
|
||||
):
|
||||
"""Test agent serialization with comprehensive setup including custom tools, blocks, messages, and archival memory."""
|
||||
name = f"comprehensive_test_agent_{str(uuid.uuid4())}"
|
||||
temp_agent = client.agents.create(
|
||||
name=name,
|
||||
memory_blocks=[persona_block, human_block, context_block],
|
||||
model="openai/gpt-4.1-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
tool_ids=[fibonacci_tool.id, preferences_tool.id, data_analysis_tool.id],
|
||||
include_base_tools=True,
|
||||
tags=["test", "comprehensive", "serialization"],
|
||||
system="You are a helpful assistant specializing in data analysis and mathematical computations.",
|
||||
)
|
||||
|
||||
# Add archival memory
|
||||
archival_passages = [
|
||||
"Project background: Sarah is working on a financial prediction model that uses Fibonacci retracements for technical analysis.",
|
||||
"Research notes: Golden ratio (1.618) derived from Fibonacci sequence is often used in financial markets for support/resistance levels.",
|
||||
]
|
||||
|
||||
for passage_text in archival_passages:
|
||||
client.agents.passages.create(agent_id=temp_agent.id, text=passage_text)
|
||||
|
||||
# Send some messages
|
||||
client.agents.messages.create(
|
||||
agent_id=temp_agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="Test message",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Serialize using v2
|
||||
serialized_v2 = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)
|
||||
# Convert dict to JSON bytes for import
|
||||
json_str = json.dumps(serialized_v2)
|
||||
file_obj = io.BytesIO(json_str.encode("utf-8"))
|
||||
|
||||
# Import again
|
||||
import_result = client.agents.import_file(file=file_obj, append_copy_suffix=False, override_existing_tools=True)
|
||||
|
||||
# Verify import was successful
|
||||
assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
|
||||
imported_agent_id = import_result.agent_ids[0]
|
||||
imported_agent = client.agents.retrieve(agent_id=imported_agent_id)
|
||||
|
||||
# ========== BASIC AGENT PROPERTIES ==========
|
||||
# Name should be the same (if append_copy_suffix=False) or have suffix
|
||||
assert imported_agent.name == name, f"Agent name mismatch: {imported_agent.name} != {name}"
|
||||
|
||||
# LLM and embedding configs should be preserved
|
||||
assert (
|
||||
imported_agent.llm_config.model == temp_agent.llm_config.model
|
||||
), f"LLM model mismatch: {imported_agent.llm_config.model} != {temp_agent.llm_config.model}"
|
||||
assert imported_agent.embedding_config.embedding_model == temp_agent.embedding_config.embedding_model, "Embedding model mismatch"
|
||||
|
||||
# System prompt should be preserved
|
||||
assert imported_agent.system == temp_agent.system, "System prompt was not preserved"
|
||||
|
||||
# Tags should be preserved
|
||||
assert set(imported_agent.tags) == set(temp_agent.tags), f"Tags mismatch: {imported_agent.tags} != {temp_agent.tags}"
|
||||
|
||||
# Agent type should be preserved
|
||||
assert (
|
||||
imported_agent.agent_type == temp_agent.agent_type
|
||||
), f"Agent type mismatch: {imported_agent.agent_type} != {temp_agent.agent_type}"
|
||||
|
||||
# ========== MEMORY BLOCKS ==========
|
||||
# Compare memory blocks directly from AgentState objects
|
||||
original_blocks = temp_agent.memory.blocks
|
||||
imported_blocks = imported_agent.memory.blocks
|
||||
|
||||
# Should have same number of blocks
|
||||
assert len(imported_blocks) == len(original_blocks), f"Block count mismatch: {len(imported_blocks)} != {len(original_blocks)}"
|
||||
|
||||
# Verify each block by label
|
||||
original_blocks_by_label = {block.label: block for block in original_blocks}
|
||||
imported_blocks_by_label = {block.label: block for block in imported_blocks}
|
||||
|
||||
# Check persona block
|
||||
assert "persona" in imported_blocks_by_label, "Persona block missing in imported agent"
|
||||
assert "Alex" in imported_blocks_by_label["persona"].value, "Persona block content not preserved"
|
||||
assert imported_blocks_by_label["persona"].limit == original_blocks_by_label["persona"].limit, "Persona block limit mismatch"
|
||||
|
||||
# Check human block
|
||||
assert "human" in imported_blocks_by_label, "Human block missing in imported agent"
|
||||
assert "sarah_researcher" in imported_blocks_by_label["human"].value, "Human block content not preserved"
|
||||
assert imported_blocks_by_label["human"].limit == original_blocks_by_label["human"].limit, "Human block limit mismatch"
|
||||
|
||||
# Check context block
|
||||
assert "project_context" in imported_blocks_by_label, "Context block missing in imported agent"
|
||||
assert "financial markets" in imported_blocks_by_label["project_context"].value, "Context block content not preserved"
|
||||
assert (
|
||||
imported_blocks_by_label["project_context"].limit == original_blocks_by_label["project_context"].limit
|
||||
), "Context block limit mismatch"
|
||||
|
||||
# ========== TOOLS ==========
|
||||
# Compare tools directly from AgentState objects
|
||||
original_tools = temp_agent.tools
|
||||
imported_tools = imported_agent.tools
|
||||
|
||||
# Should have same number of tools
|
||||
assert len(imported_tools) == len(original_tools), f"Tool count mismatch: {len(imported_tools)} != {len(original_tools)}"
|
||||
|
||||
original_tool_names = {tool.name for tool in original_tools}
|
||||
imported_tool_names = {tool.name for tool in imported_tools}
|
||||
|
||||
# Check custom tools are present
|
||||
assert "calculate_fibonacci" in imported_tool_names, "Fibonacci tool missing in imported agent"
|
||||
assert "get_user_preferences" in imported_tool_names, "Preferences tool missing in imported agent"
|
||||
assert "analyze_data" in imported_tool_names, "Data analysis tool missing in imported agent"
|
||||
|
||||
# Check for base tools (since we set include_base_tools=True when creating the agent)
|
||||
# Base tools should also be present (at least some core ones)
|
||||
base_tool_names = {"send_message", "conversation_search"}
|
||||
missing_base_tools = base_tool_names - imported_tool_names
|
||||
assert len(missing_base_tools) == 0, f"Missing base tools: {missing_base_tools}"
|
||||
|
||||
# Verify tool names match exactly
|
||||
assert original_tool_names == imported_tool_names, f"Tool names don't match: {original_tool_names} != {imported_tool_names}"
|
||||
|
||||
# ========== MESSAGE HISTORY ==========
|
||||
# Get messages for both agents
|
||||
original_messages = client.agents.messages.list(agent_id=temp_agent.id, limit=100)
|
||||
imported_messages = client.agents.messages.list(agent_id=imported_agent_id, limit=100)
|
||||
|
||||
# Should have same number of messages
|
||||
assert len(imported_messages) >= 1, "Imported agent should have messages"
|
||||
|
||||
# Filter for user messages (excluding system-generated login messages)
|
||||
original_user_msgs = [msg for msg in original_messages if msg.message_type == "user_message" and "Test message" in msg.content]
|
||||
imported_user_msgs = [msg for msg in imported_messages if msg.message_type == "user_message" and "Test message" in msg.content]
|
||||
|
||||
# Should have the same number of test messages
|
||||
assert len(imported_user_msgs) == len(
|
||||
original_user_msgs
|
||||
), f"User message count mismatch: {len(imported_user_msgs)} != {len(original_user_msgs)}"
|
||||
|
||||
# Verify test message content is preserved
|
||||
if len(original_user_msgs) > 0 and len(imported_user_msgs) > 0:
|
||||
assert imported_user_msgs[0].content == original_user_msgs[0].content, "User message content not preserved"
|
||||
assert "Test message" in imported_user_msgs[0].content, "Test message content not found"
|
||||
|
||||
Reference in New Issue
Block a user