Files
letta-server/tests/test_agent_serialization_v2.py
2025-07-24 18:17:08 -07:00

1660 lines
67 KiB
Python

import asyncio
from typing import List, Optional
import pytest
from letta.config import LettaConfig
from letta.errors import AgentFileExportError, AgentFileImportError
from letta.helpers.pinecone_utils import should_use_pinecone
from letta.orm import Base
from letta.schemas.agent import CreateAgent
from letta.schemas.agent_file import (
AgentFileSchema,
AgentSchema,
BlockSchema,
FileSchema,
GroupSchema,
MessageSchema,
SourceSchema,
ToolSchema,
)
from letta.schemas.block import Block, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.organization import Organization
from letta.schemas.source import Source
from letta.schemas.user import User
from letta.server.server import SyncServer
from letta.services.agent_serialization_manager import AgentSerializationManager
from tests.utils import create_tool_from_func
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def event_loop():
"""Use a single asyncio loop for the entire test session."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
def _clear_tables():
from letta.server.db import db_context
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
session.execute(table.delete()) # Truncate table
session.commit()
@pytest.fixture(autouse=True)
def clear_tables():
_clear_tables()
@pytest.fixture
def server():
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=True)
server.tool_manager.upsert_base_tools(actor=server.default_user)
yield server
@pytest.fixture
def default_organization(server: SyncServer):
"""Fixture to create and return the default organization."""
org = server.organization_manager.create_default_organization()
yield org
@pytest.fixture
def default_user(server: SyncServer, default_organization):
"""Fixture to create and return the default user within the default organization."""
user = server.user_manager.create_default_user(org_id=default_organization.id)
yield user
@pytest.fixture
def other_organization(server: SyncServer):
"""Fixture to create and return another organization."""
org = server.organization_manager.create_organization(pydantic_org=Organization(name="test_org"))
yield org
@pytest.fixture
def other_user(server: SyncServer, other_organization):
"""Fixture to create and return another user within the other organization."""
user = server.user_manager.create_user(pydantic_user=User(organization_id=other_organization.id, name="test_user"))
yield user
@pytest.fixture
def weather_tool_func():
def get_weather(location: str) -> str:
"""Get the current weather for a given location.
Args:
location: The city and state, e.g. San Francisco, CA
Returns:
Weather description
"""
return f"The weather in {location} is sunny and 72 degrees."
return get_weather
@pytest.fixture
def print_tool_func():
def print_message(message: str) -> str:
"""Print a message to the console.
Args:
message: The message to print
Returns:
Confirmation message
"""
print(message)
return f"Printed: {message}"
return print_tool_func
@pytest.fixture
def weather_tool(server, weather_tool_func, default_user):
weather_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=weather_tool_func), actor=default_user)
yield weather_tool
@pytest.fixture
def print_tool(server, print_tool_func, default_user):
print_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=print_tool_func), actor=default_user)
yield print_tool
@pytest.fixture
def test_block(server: SyncServer, default_user):
"""Fixture to create and return a test block."""
block_data = Block(
label="test_block",
value="Test Block Content",
description="A test block for agent file tests",
limit=1000,
metadata={"type": "test", "category": "demo"},
)
block = server.block_manager.create_or_update_block(block_data, actor=default_user)
yield block
@pytest.fixture
def agent_serialization_manager(server, default_user):
"""Fixture to create AgentSerializationManager with all required services including file processing."""
manager = AgentSerializationManager(
agent_manager=server.agent_manager,
tool_manager=server.tool_manager,
source_manager=server.source_manager,
block_manager=server.block_manager,
group_manager=server.group_manager,
mcp_manager=server.mcp_manager,
file_manager=server.file_manager,
file_agent_manager=server.file_agent_manager,
message_manager=server.message_manager,
)
yield manager
@pytest.fixture
def test_agent(server: SyncServer, default_user, default_organization, test_block, weather_tool):
"""Fixture to create and return a test agent with messages."""
memory_blocks = [
CreateBlock(label="human", value="User is a test user"),
CreateBlock(label="persona", value="I am a helpful test assistant"),
]
create_agent_request = CreateAgent(
name="test_agent_v2",
system="You are a helpful assistant for testing agent file export/import.",
memory_blocks=memory_blocks,
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
block_ids=[test_block.id],
tool_ids=[weather_tool.id],
tags=["test", "v2", "export"],
description="Test agent for agent file v2 testing",
metadata={"test_key": "test_value", "version": "v2"},
initial_message_sequence=[
MessageCreate(role=MessageRole.system, content="You are a helpful assistant."),
MessageCreate(role=MessageRole.user, content="Hello!"),
MessageCreate(role=MessageRole.assistant, content="Hello! How can I help you today?"),
],
tool_exec_environment_variables={"TEST_VAR": "test_value"},
message_buffer_autoclear=False,
)
agent_state = server.agent_manager.create_agent(
agent_create=create_agent_request,
actor=default_user,
)
server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like?")],
)
agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_state.id, actor=default_user)
yield agent_state
@pytest.fixture
async def test_source(server: SyncServer, default_user):
"""Fixture to create and return a test source."""
source_data = Source(
name="test_source",
description="Test source for file export tests",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
source = await server.source_manager.create_source(source_data, default_user)
yield source
@pytest.fixture
async def test_file(server: SyncServer, default_user, test_source):
"""Fixture to create and return a test file attached to test_source."""
from letta.schemas.file import FileMetadata
file_data = FileMetadata(
source_id=test_source.id,
file_name="test.txt",
original_file_name="test.txt",
file_type="text/plain",
file_size=46,
)
file_metadata = await server.file_manager.create_file(file_data, default_user, text="This is a test file for export testing.")
yield file_metadata
@pytest.fixture
async def agent_with_files(server: SyncServer, default_user, test_block, weather_tool, test_source, test_file):
"""Fixture to create and return an agent with attached files."""
memory_blocks = [
CreateBlock(label="human", value="User is a test user"),
CreateBlock(label="persona", value="I am a helpful test assistant"),
]
create_agent_request = CreateAgent(
name="test_agent_v2",
system="You are a helpful assistant for testing agent file export/import.",
memory_blocks=memory_blocks,
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
block_ids=[test_block.id],
tool_ids=[weather_tool.id],
tags=["test", "v2", "export"],
description="Test agent for agent file v2 testing",
metadata={"test_key": "test_value", "version": "v2"},
initial_message_sequence=[
MessageCreate(role=MessageRole.system, content="You are a helpful assistant."),
MessageCreate(role=MessageRole.user, content="Hello!"),
MessageCreate(role=MessageRole.assistant, content="Hello! How can I help you today?"),
],
tool_exec_environment_variables={"TEST_VAR": "test_value"},
message_buffer_autoclear=False,
source_ids=[test_source.id],
)
agent_state = await server.agent_manager.create_agent_async(
agent_create=create_agent_request,
actor=default_user,
)
await server.agent_manager.insert_files_into_context_window(
agent_state=agent_state, file_metadata_with_content=[test_file], actor=default_user
)
return (agent_state.id, test_source.id, test_file.id)
@pytest.fixture
async def test_mcp_server(server: SyncServer, default_user):
"""Fixture to create and return a test MCP server."""
from letta.schemas.mcp import MCPServer, MCPServerType
mcp_server_data = MCPServer(
server_name="test_mcp_server",
server_type=MCPServerType.SSE,
server_url="http://test-mcp-server.com",
token="test-token-12345", # This should be excluded during export
custom_headers={"X-API-Key": "secret-key"}, # This should be excluded during export
)
mcp_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data, default_user)
yield mcp_server
@pytest.fixture
async def mcp_tool(server: SyncServer, default_user, test_mcp_server):
"""Fixture to create and return an MCP tool."""
from letta.schemas.tool import MCPTool, ToolCreate
# Create a mock MCP tool
mcp_tool_data = MCPTool(
name="test_mcp_tool",
description="Test MCP tool for serialization",
inputSchema={"type": "object", "properties": {"input": {"type": "string"}}},
)
tool_create = ToolCreate.from_mcp(test_mcp_server.server_name, mcp_tool_data)
# Create tool with MCP metadata
mcp_tool = await server.tool_manager.create_mcp_tool_async(tool_create, test_mcp_server.server_name, test_mcp_server.id, default_user)
yield mcp_tool
@pytest.fixture
async def agent_with_mcp_tools(server: SyncServer, default_user, test_block, mcp_tool, test_mcp_server):
"""Fixture to create and return an agent with MCP tools."""
memory_blocks = [
CreateBlock(label="human", value="User is a test user"),
CreateBlock(label="persona", value="I am a helpful test assistant"),
]
create_agent_request = CreateAgent(
name="test_agent_mcp",
system="You are a helpful assistant with MCP tools.",
memory_blocks=memory_blocks,
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
block_ids=[test_block.id],
tool_ids=[mcp_tool.id],
tags=["test", "mcp", "export"],
description="Test agent with MCP tools for serialization testing",
)
agent_state = await server.agent_manager.create_agent_async(
agent_create=create_agent_request,
actor=default_user,
)
return agent_state
# ------------------------------
# Helper Functions
# ------------------------------
async def create_test_source(server: SyncServer, name: str, user: User):
"""Helper function to create a test source using server."""
source_data = Source(
name=name,
description=f"Test source {name}",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
return await server.source_manager.create_source(source_data, user)
async def create_test_file(server: SyncServer, filename: str, source_id: str, user: User, content: Optional[str] = None):
"""Helper function to create a test file using server."""
from letta.schemas.file import FileMetadata
content = content or f"Content of {filename}"
file_data = FileMetadata(
source_id=source_id,
file_name=filename,
original_file_name=filename,
file_type="text/plain",
file_size=len(content),
)
return await server.file_manager.create_file(file_data, user, text=content)
async def create_test_agent_with_files(server: SyncServer, name: str, user: User, file_relationships: List[tuple]):
"""Helper function to create agent with attached files using server.
Args:
server: SyncServer instance
name: Agent name
user: User creating the agent
file_relationships: List of (source_id, file_id) tuples
"""
memory_blocks = [
CreateBlock(label="human", value="User is a test user"),
CreateBlock(label="persona", value="I am a helpful test assistant"),
]
create_agent_request = CreateAgent(
name=name,
system="You are a helpful assistant for testing file export.",
memory_blocks=memory_blocks,
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
tags=["test", "files"],
description="Test agent with files",
)
agent_state = await server.agent_manager.create_agent_async(
agent_create=create_agent_request,
actor=user,
)
for source_id, file_id in file_relationships:
file_metadata = await server.file_manager.get_file_by_id(file_id, user)
await server.agent_manager.insert_files_into_context_window(
agent_state=agent_state, file_metadata_with_content=[file_metadata], actor=user
)
return agent_state
def compare_agent_files(original: AgentFileSchema, imported: AgentFileSchema) -> bool:
"""Compare two AgentFileSchema objects for logical equivalence."""
errors = []
if len(original.agents) != len(imported.agents):
errors.append(f"Agent count mismatch: {len(original.agents)} vs {len(imported.agents)}")
if len(original.tools) != len(imported.tools):
errors.append(f"Tool count mismatch: {len(original.tools)} vs {len(imported.tools)}")
if len(original.blocks) != len(imported.blocks):
errors.append(f"Block count mismatch: {len(original.blocks)} vs {len(imported.blocks)}")
if len(original.groups) != len(imported.groups):
errors.append(f"Group count mismatch: {len(original.groups)} vs {len(imported.groups)}")
if len(original.files) != len(imported.files):
errors.append(f"File count mismatch: {len(original.files)} vs {len(imported.files)}")
if len(original.sources) != len(imported.sources):
errors.append(f"Source count mismatch: {len(original.sources)} vs {len(imported.sources)}")
for i, (orig_agent, imp_agent) in enumerate(zip(original.agents, imported.agents)):
agent_errors = _compare_agents(orig_agent, imp_agent, i)
errors.extend(agent_errors)
orig_tools_sorted = sorted(original.tools, key=lambda x: x.name)
imp_tools_sorted = sorted(imported.tools, key=lambda x: x.name)
for i, (orig_tool, imp_tool) in enumerate(zip(orig_tools_sorted, imp_tools_sorted)):
tool_errors = _compare_tools(orig_tool, imp_tool, i)
errors.extend(tool_errors)
orig_blocks_sorted = sorted(original.blocks, key=lambda x: x.label)
imp_blocks_sorted = sorted(imported.blocks, key=lambda x: x.label)
for i, (orig_block, imp_block) in enumerate(zip(orig_blocks_sorted, imp_blocks_sorted)):
block_errors = _compare_blocks(orig_block, imp_block, i)
errors.extend(block_errors)
for i, (orig_group, imp_group) in enumerate(zip(original.groups, imported.groups)):
group_errors = _compare_groups(orig_group, imp_group, i)
errors.extend(group_errors)
for i, (orig_file, imp_file) in enumerate(zip(original.files, imported.files)):
file_errors = _compare_files(orig_file, imp_file, i)
errors.extend(file_errors)
for i, (orig_source, imp_source) in enumerate(zip(original.sources, imported.sources)):
source_errors = _compare_sources(orig_source, imp_source, i)
errors.extend(source_errors)
if errors:
print("Agent file comparison errors:")
for error in errors:
print(f" - {error}")
return False
return True
def _compare_agents(orig: AgentSchema, imp: AgentSchema, index: int) -> List[str]:
"""Compare two AgentSchema objects for logical equivalence."""
errors = []
if orig.name != imp.name:
errors.append(f"Agent {index}: name mismatch: '{orig.name}' vs '{imp.name}'")
if orig.system != imp.system:
errors.append(f"Agent {index}: system mismatch")
if orig.description != imp.description:
errors.append(f"Agent {index}: description mismatch")
if orig.agent_type != imp.agent_type:
errors.append(f"Agent {index}: agent_type mismatch: '{orig.agent_type}' vs '{imp.agent_type}'")
if sorted(orig.tags or []) != sorted(imp.tags or []):
errors.append(f"Agent {index}: tags mismatch: {orig.tags} vs {imp.tags}")
if orig.metadata != imp.metadata:
errors.append(f"Agent {index}: metadata mismatch")
if orig.llm_config != imp.llm_config:
errors.append(f"Agent {index}: llm_config mismatch")
if orig.embedding_config != imp.embedding_config:
errors.append(f"Agent {index}: embedding_config mismatch")
# Tool rules
if orig.tool_rules != imp.tool_rules:
errors.append(f"Agent {index}: tool_rules mismatch")
# Environment variables
if orig.tool_exec_environment_variables != imp.tool_exec_environment_variables:
errors.append(f"Agent {index}: tool_exec_environment_variables mismatch")
# Messages
if len(orig.messages) != len(imp.messages):
errors.append(f"Agent {index}: message count mismatch: {len(orig.messages)} vs {len(imp.messages)}")
else:
for j, (orig_msg, imp_msg) in enumerate(zip(orig.messages, imp.messages)):
msg_errors = _compare_messages(orig_msg, imp_msg, index, j)
errors.extend(msg_errors)
# In-context messages
if len(orig.in_context_message_ids) != len(imp.in_context_message_ids):
errors.append(
f"Agent {index}: in-context message count mismatch: {len(orig.in_context_message_ids)} vs {len(imp.in_context_message_ids)}"
)
# Relationship IDs (lengths should match)
if len(orig.tool_ids or []) != len(imp.tool_ids or []):
errors.append(f"Agent {index}: tool_ids count mismatch: {len(orig.tool_ids or [])} vs {len(imp.tool_ids or [])}")
if len(orig.block_ids or []) != len(imp.block_ids or []):
errors.append(f"Agent {index}: block_ids count mismatch: {len(orig.block_ids or [])} vs {len(imp.block_ids or [])}")
if len(orig.source_ids or []) != len(imp.source_ids or []):
errors.append(f"Agent {index}: source_ids count mismatch: {len(orig.source_ids or [])} vs {len(imp.source_ids or [])}")
return errors
def _compare_messages(orig: MessageSchema, imp: MessageSchema, agent_index: int, msg_index: int) -> List[str]:
"""Compare two MessageSchema objects for logical equivalence."""
errors = []
if orig.role != imp.role:
errors.append(f"Agent {agent_index}, Message {msg_index}: role mismatch: '{orig.role}' vs '{imp.role}'")
if orig.content != imp.content:
errors.append(f"Agent {agent_index}, Message {msg_index}: content mismatch")
if orig.name != imp.name:
errors.append(f"Agent {agent_index}, Message {msg_index}: name mismatch: '{orig.name}' vs '{imp.name}'")
if orig.model != imp.model:
errors.append(f"Agent {agent_index}, Message {msg_index}: model mismatch: '{orig.model}' vs '{imp.model}'")
# Skip agent_id comparison - expected to be different between original and imported
return errors
def _compare_tools(orig: ToolSchema, imp: ToolSchema, index: int) -> List[str]:
"""Compare two ToolSchema objects for logical equivalence."""
errors = []
if orig.name != imp.name:
errors.append(f"Tool {index}: name mismatch: '{orig.name}' vs '{imp.name}'")
if orig.description != imp.description:
errors.append(f"Tool {index}: description mismatch")
if orig.source_code != imp.source_code:
errors.append(f"Tool {index}: source_code mismatch")
if orig.json_schema != imp.json_schema:
errors.append(f"Tool {index}: json_schema mismatch")
if sorted(orig.tags or []) != sorted(imp.tags or []):
errors.append(f"Tool {index}: tags mismatch: {orig.tags} vs {imp.tags}")
if orig.metadata_ != imp.metadata_:
errors.append(f"Tool {index}: metadata mismatch")
# Skip organization_id comparison - expected to be different between orgs
return errors
def _compare_blocks(orig: BlockSchema, imp: BlockSchema, index: int) -> List[str]:
"""Compare two BlockSchema objects for logical equivalence."""
errors = []
if orig.label != imp.label:
errors.append(f"Block {index}: label mismatch: '{orig.label}' vs '{imp.label}'")
if orig.value != imp.value:
errors.append(f"Block {index}: value mismatch")
if orig.limit != imp.limit:
errors.append(f"Block {index}: limit mismatch: {orig.limit} vs {imp.limit}")
if orig.description != imp.description:
errors.append(f"Block {index}: description mismatch")
if orig.metadata != imp.metadata:
errors.append(f"Block {index}: metadata mismatch")
if orig.template_name != imp.template_name:
errors.append(f"Block {index}: template_name mismatch: '{orig.template_name}' vs '{imp.template_name}'")
if orig.is_template != imp.is_template:
errors.append(f"Block {index}: is_template mismatch: {orig.is_template} vs {imp.is_template}")
return errors
def _compare_groups(orig: GroupSchema, imp: GroupSchema, index: int) -> List[str]:
"""Compare two GroupSchema objects for logical equivalence."""
errors = []
if orig.name != imp.name:
errors.append(f"Group {index}: name mismatch: '{orig.name}' vs '{imp.name}'")
if orig.description != imp.description:
errors.append(f"Group {index}: description mismatch")
if orig.metadata != imp.metadata:
errors.append(f"Group {index}: metadata mismatch")
return errors
def _compare_files(orig: FileSchema, imp: FileSchema, index: int) -> List[str]:
"""Compare two FileSchema objects for logical equivalence."""
errors = []
if orig.file_name != imp.file_name:
errors.append(f"File {index}: file_name mismatch: '{orig.file_name}' vs '{imp.file_name}'")
if orig.original_file_name != imp.original_file_name:
errors.append(f"File {index}: original_file_name mismatch: '{orig.original_file_name}' vs '{imp.original_file_name}'")
if orig.file_size != imp.file_size:
errors.append(f"File {index}: file_size mismatch: {orig.file_size} vs {imp.file_size}")
if orig.file_type != imp.file_type:
errors.append(f"File {index}: file_type mismatch: '{orig.file_type}' vs '{imp.file_type}'")
if orig.processing_status != imp.processing_status:
errors.append(f"File {index}: processing_status mismatch: '{orig.processing_status}' vs '{imp.processing_status}'")
if orig.metadata != imp.metadata:
errors.append(f"File {index}: metadata mismatch")
# Check source_id reference format (should be remapped)
if not imp.source_id.startswith("source-"):
errors.append(f"File {index}: source_id not properly remapped: {imp.source_id}")
return errors
def _compare_sources(orig: SourceSchema, imp: SourceSchema, index: int) -> List[str]:
"""Compare two SourceSchema objects for logical equivalence."""
errors = []
if orig.name != imp.name:
errors.append(f"Source {index}: name mismatch: '{orig.name}' vs '{imp.name}'")
if orig.description != imp.description:
errors.append(f"Source {index}: description mismatch")
if orig.instructions != imp.instructions:
errors.append(f"Source {index}: instructions mismatch")
if orig.metadata != imp.metadata:
errors.append(f"Source {index}: metadata mismatch")
if orig.embedding_config != imp.embedding_config:
errors.append(f"Source {index}: embedding_config mismatch")
return errors
def _validate_entity_id(entity_id: str, expected_prefix: str) -> bool:
"""Helper function to validate that an ID follows the expected format (prefix-N)."""
if not entity_id.startswith(f"{expected_prefix}-"):
print(f"Invalid {expected_prefix} ID format: {entity_id} should start with '{expected_prefix}-'")
return False
try:
suffix = entity_id[len(expected_prefix) + 1 :]
int(suffix)
return True
except ValueError:
print(f"Invalid {expected_prefix} ID format: {entity_id} should have integer suffix")
return False
def validate_id_format(schema: AgentFileSchema) -> bool:
"""Validate that all IDs follow the expected format (entity-N)."""
for agent in schema.agents:
if not _validate_entity_id(agent.id, "agent"):
return False
for message in agent.messages:
if not _validate_entity_id(message.id, "message"):
return False
for msg_id in agent.in_context_message_ids:
if not _validate_entity_id(msg_id, "message"):
return False
for tool in schema.tools:
if not _validate_entity_id(tool.id, "tool"):
return False
for block in schema.blocks:
if not _validate_entity_id(block.id, "block"):
return False
for file in schema.files:
if not _validate_entity_id(file.id, "file"):
return False
for source in schema.sources:
if not _validate_entity_id(source.id, "source"):
return False
return True
# ------------------------------
# Tests
# ------------------------------
class TestFileExport:
"""Test file export functionality with comprehensive validation"""
async def test_basic_file_export(self, default_user, agent_serialization_manager, agent_with_files):
"""Test basic file export functionality"""
agent_id, source_id, file_id = agent_with_files
exported = await agent_serialization_manager.export([agent_id], actor=default_user)
assert len(exported.agents) == 1
assert len(exported.sources) == 1
assert len(exported.files) == 1
agent = exported.agents[0]
assert len(agent.files_agents) == 1
assert _validate_entity_id(agent.id, "agent")
assert _validate_entity_id(exported.sources[0].id, "source")
assert _validate_entity_id(exported.files[0].id, "file")
file_agent = agent.files_agents[0]
assert file_agent.agent_id == agent.id
assert file_agent.file_id == exported.files[0].id
assert file_agent.source_id == exported.sources[0].id
async def test_multiple_files_per_source(self, server, default_user, agent_serialization_manager):
"""Test export with multiple files from the same source"""
source = await create_test_source(server, "multi-file-source", default_user)
file1 = await create_test_file(server, "file1.txt", source.id, default_user)
file2 = await create_test_file(server, "file2.txt", source.id, default_user)
agent = await create_test_agent_with_files(server, "multi-file-agent", default_user, [(source.id, file1.id), (source.id, file2.id)])
exported = await agent_serialization_manager.export([agent.id], actor=default_user)
assert len(exported.agents) == 1
assert len(exported.sources) == 1
assert len(exported.files) == 2
agent = exported.agents[0]
assert len(agent.files_agents) == 2
source_id = exported.sources[0].id
for file_schema in exported.files:
assert file_schema.source_id == source_id
file_ids = {f.id for f in exported.files}
for file_agent in agent.files_agents:
assert file_agent.file_id in file_ids
assert file_agent.source_id == source_id
async def test_multiple_sources_export(self, server, default_user, agent_serialization_manager):
"""Test export with files from multiple sources"""
source1 = await create_test_source(server, "source-1", default_user)
source2 = await create_test_source(server, "source-2", default_user)
file1 = await create_test_file(server, "file1.txt", source1.id, default_user)
file2 = await create_test_file(server, "file2.txt", source2.id, default_user)
agent = await create_test_agent_with_files(
server, "multi-source-agent", default_user, [(source1.id, file1.id), (source2.id, file2.id)]
)
exported = await agent_serialization_manager.export([agent.id], actor=default_user)
assert len(exported.agents) == 1
assert len(exported.sources) == 2
assert len(exported.files) == 2
source_ids = {s.id for s in exported.sources}
for file_schema in exported.files:
assert file_schema.source_id in source_ids
async def test_cross_agent_file_deduplication(self, server, default_user, agent_serialization_manager):
"""Test that files shared across agents are deduplicated in export"""
source = await create_test_source(server, "shared-source", default_user)
shared_file = await create_test_file(server, "shared.txt", source.id, default_user)
agent1 = await create_test_agent_with_files(server, "agent-1", default_user, [(source.id, shared_file.id)])
agent2 = await create_test_agent_with_files(server, "agent-2", default_user, [(source.id, shared_file.id)])
exported = await agent_serialization_manager.export([agent1.id, agent2.id], actor=default_user)
assert len(exported.agents) == 2
assert len(exported.sources) == 1
assert len(exported.files) == 1
file_id = exported.files[0].id
source_id = exported.sources[0].id
for agent in exported.agents:
assert len(agent.files_agents) == 1
file_agent = agent.files_agents[0]
assert file_agent.file_id == file_id
assert file_agent.source_id == source_id
async def test_file_agent_relationship_preservation(self, server, default_user, agent_serialization_manager):
"""Test that file-agent relationship details are preserved"""
source = await create_test_source(server, "test-source", default_user)
file = await create_test_file(server, "test.txt", source.id, default_user)
agent = await create_test_agent_with_files(server, "test-agent", default_user, [(source.id, file.id)])
exported = await agent_serialization_manager.export([agent.id], actor=default_user)
agent = exported.agents[0]
file_agent = agent.files_agents[0]
assert file_agent.file_name == file.file_name
assert file_agent.is_open is True
assert hasattr(file_agent, "last_accessed_at")
async def test_id_remapping_consistency(self, server, default_user, agent_serialization_manager):
"""Test that ID remapping is consistent across all references"""
source = await create_test_source(server, "consistency-source", default_user)
file = await create_test_file(server, "consistency.txt", source.id, default_user)
agent = await create_test_agent_with_files(server, "consistency-agent", default_user, [(source.id, file.id)])
exported = await agent_serialization_manager.export([agent.id], actor=default_user)
agent_schema = exported.agents[0]
source_schema = exported.sources[0]
file_schema = exported.files[0]
file_agent = agent_schema.files_agents[0]
assert file_schema.source_id == source_schema.id
assert file_agent.agent_id == agent_schema.id
assert file_agent.file_id == file_schema.id
assert file_agent.source_id == source_schema.id
async def test_empty_file_relationships(self, server, default_user, agent_serialization_manager):
"""Test export of agent with no file relationships"""
agent_create = CreateAgent(
name="no-files-agent",
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
agent = await server.agent_manager.create_agent_async(agent_create, actor=default_user)
exported = await agent_serialization_manager.export([agent.id], actor=default_user)
assert len(exported.agents) == 1
assert len(exported.sources) == 0
assert len(exported.files) == 0
agent_schema = exported.agents[0]
assert len(agent_schema.files_agents) == 0
async def test_file_content_inclusion_in_export(self, default_user, agent_serialization_manager, agent_with_files):
"""Test that file content is included in export"""
agent_id, source_id, file_id = agent_with_files
exported = await agent_serialization_manager.export([agent_id], actor=default_user)
file_schema = exported.files[0]
assert hasattr(file_schema, "content") or file_schema.content is not None
class TestAgentFileExport:
"""Tests for agent file export functionality."""
async def test_basic_export(self, agent_serialization_manager, test_agent, default_user):
"""Test basic agent export functionality."""
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
assert isinstance(agent_file, AgentFileSchema)
assert len(agent_file.agents) == 1
assert len(agent_file.tools) > 0 # Should include base tools + weather tool
assert len(agent_file.blocks) > 0 # Should include memory blocks + test block
assert agent_file.metadata.get("revision_id") is not None
assert agent_file.metadata.get("revision_id") != "unknown"
assert len(agent_file.metadata.get("revision_id")) > 0
assert validate_id_format(agent_file)
exported_agent = agent_file.agents[0]
assert exported_agent.name == test_agent.name
assert exported_agent.system == test_agent.system
assert len(exported_agent.messages) > 0
assert len(exported_agent.in_context_message_ids) > 0
async def test_export_multiple_agents(self, server, agent_serialization_manager, test_agent, default_user, weather_tool):
"""Test exporting multiple agents."""
create_agent_request = CreateAgent(
name="second_test_agent",
system="Second test agent",
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
tool_ids=[weather_tool.id],
initial_message_sequence=[
MessageCreate(role=MessageRole.user, content="Second agent message"),
],
)
second_agent = server.agent_manager.create_agent(
agent_create=create_agent_request,
actor=default_user,
)
agent_file = await agent_serialization_manager.export([test_agent.id, second_agent.id], default_user)
assert len(agent_file.agents) == 2
assert validate_id_format(agent_file)
agent_ids = {agent.id for agent in agent_file.agents}
assert len(agent_ids) == 2
async def test_export_id_remapping(self, agent_serialization_manager, test_agent, default_user):
"""Test that IDs are properly remapped during export."""
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
exported_agent = agent_file.agents[0]
assert exported_agent.id == "agent-0"
assert exported_agent.id != test_agent.id
if exported_agent.tool_ids:
for tool_id in exported_agent.tool_ids:
assert tool_id.startswith("tool-")
if exported_agent.block_ids:
for block_id in exported_agent.block_ids:
assert block_id.startswith("block-")
message_ids = {msg.id for msg in exported_agent.messages}
for in_context_id in exported_agent.in_context_message_ids:
assert in_context_id in message_ids, f"In-context message ID {in_context_id} not found in messages"
async def test_message_agent_id_remapping(self, agent_serialization_manager, test_agent, default_user):
"""Test that message.agent_id is properly remapped during export."""
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
exported_agent = agent_file.agents[0]
for message in exported_agent.messages:
assert (
message.agent_id == exported_agent.id
), f"Message {message.id} has agent_id {message.agent_id}, expected {exported_agent.id}"
assert exported_agent.id == "agent-0"
assert exported_agent.id != test_agent.id
async def test_export_empty_agent_list(self, agent_serialization_manager, default_user):
"""Test exporting empty agent list."""
agent_file = await agent_serialization_manager.export([], default_user)
assert len(agent_file.agents) == 0
assert len(agent_file.tools) == 0
assert len(agent_file.blocks) == 0
async def test_export_nonexistent_agent(self, agent_serialization_manager, default_user):
"""Test exporting non-existent agent raises error."""
with pytest.raises(AgentFileExportError): # Should raise AgentFileExportError for non-existent agent
await agent_serialization_manager.export(["non-existent-id"], default_user)
async def test_revision_id_automatic_setting(self, agent_serialization_manager, test_agent, default_user):
"""Test that revision_id is automatically set to the latest alembic revision."""
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
from letta.utils import get_latest_alembic_revision
expected_revision = await get_latest_alembic_revision()
assert agent_file.metadata.get("revision_id") == expected_revision
assert agent_file.metadata.get("revision_id") != "unknown"
assert len(agent_file.metadata.get("revision_id")) == 12
assert all(c in "0123456789abcdef" for c in agent_file.metadata.get("revision_id"))
class TestAgentFileImport:
"""Tests for agent file import functionality."""
async def test_basic_import(self, agent_serialization_manager, test_agent, default_user, other_user):
"""Test basic agent import functionality."""
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
result = await agent_serialization_manager.import_file(agent_file, other_user)
assert result.success
assert result.imported_count > 0
assert len(result.id_mappings) > 0
for file_id, db_id in result.id_mappings.items():
if file_id.startswith("agent-"):
assert db_id != test_agent.id # New agent should have different ID
async def test_import_preserves_data(self, server, agent_serialization_manager, test_agent, default_user, other_user):
"""Test that import preserves all important data."""
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
result = await agent_serialization_manager.import_file(agent_file, other_user)
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
imported_agent = server.agent_manager.get_agent_by_id(imported_agent_id, other_user)
assert imported_agent.name == test_agent.name
assert imported_agent.system == test_agent.system
assert imported_agent.description == test_agent.description
assert imported_agent.metadata == test_agent.metadata
assert imported_agent.tags == test_agent.tags
assert len(imported_agent.tools) == len(test_agent.tools)
assert len(imported_agent.memory.blocks) == len(test_agent.memory.blocks)
original_messages = server.message_manager.list_messages_for_agent(test_agent.id, default_user)
imported_messages = server.message_manager.list_messages_for_agent(imported_agent_id, other_user)
assert len(imported_messages) == len(original_messages)
for orig_msg, imp_msg in zip(original_messages, imported_messages):
assert orig_msg.role == imp_msg.role
assert orig_msg.content == imp_msg.content
assert imp_msg.agent_id == imported_agent_id # Should be remapped to new agent
async def test_import_message_context_preservation(self, server, agent_serialization_manager, test_agent, default_user, other_user):
"""Test that in-context message references are preserved during import."""
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
result = await agent_serialization_manager.import_file(agent_file, other_user)
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
imported_agent = server.agent_manager.get_agent_by_id(imported_agent_id, other_user)
assert len(imported_agent.message_ids) == len(test_agent.message_ids)
imported_messages = server.message_manager.list_messages_for_agent(imported_agent_id, other_user)
imported_message_ids = {msg.id for msg in imported_messages}
for in_context_id in imported_agent.message_ids:
assert in_context_id in imported_message_ids
async def test_dry_run_import(self, agent_serialization_manager, test_agent, default_user, other_user):
"""Test dry run import validation."""
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
result = await agent_serialization_manager.import_file(agent_file, other_user, dry_run=True)
assert result.success
assert result.imported_count == 0 # No actual imports in dry run
assert len(result.id_mappings) == 0
assert "dry run" in result.message.lower()
async def test_import_validation_errors(self, agent_serialization_manager, other_user):
"""Test import validation catches errors."""
from letta.utils import get_latest_alembic_revision
current_revision = await get_latest_alembic_revision()
invalid_agent_file = AgentFileSchema(
metadata={"revision_id": current_revision},
agents=[
AgentSchema(id="agent-0", name="agent1"),
AgentSchema(id="agent-0", name="agent2"), # Duplicate ID
],
groups=[],
blocks=[],
files=[],
sources=[],
tools=[],
)
with pytest.raises(AgentFileImportError):
await agent_serialization_manager.import_file(invalid_agent_file, other_user)
class TestAgentFileImportWithProcessing:
"""Tests for agent file import with file processing (chunking/embedding)."""
async def test_import_with_file_processing(self, server, agent_serialization_manager, default_user, other_user):
"""Test that import processes files for chunking and embedding."""
source = await create_test_source(server, "processing-source", default_user)
file_content = "This is test content for processing. It should be chunked and embedded during import."
file_metadata = await create_test_file(server, "process.txt", source.id, default_user, content=file_content)
agent = await create_test_agent_with_files(server, "processing-agent", default_user, [(source.id, file_metadata.id)])
exported = await agent_serialization_manager.export([agent.id], default_user)
result = await agent_serialization_manager.import_file(exported, other_user)
assert result.success
assert result.imported_count > 0
imported_file_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("file-"))
imported_file = await server.file_manager.get_file_by_id(imported_file_id, other_user)
# When using Pinecone, status stays at embedding until chunks are confirmed uploaded
if should_use_pinecone():
assert imported_file.processing_status.value == "embedding"
else:
assert imported_file.processing_status.value == "completed"
async def test_import_passage_creation(self, server, agent_serialization_manager, default_user, other_user):
"""Test that import creates passages for file content."""
source = await create_test_source(server, "passage-source", default_user)
file_content = "This content should create passages. Each sentence should be chunked separately."
file_metadata = await create_test_file(server, "passages.txt", source.id, default_user, content=file_content)
agent = await create_test_agent_with_files(server, "passage-agent", default_user, [(source.id, file_metadata.id)])
exported = await agent_serialization_manager.export([agent.id], default_user)
result = await agent_serialization_manager.import_file(exported, other_user)
imported_file_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("file-"))
passages = await server.passage_manager.list_passages_by_file_id_async(imported_file_id, other_user)
if should_use_pinecone():
# With Pinecone, passages are stored in Pinecone, not locally
assert len(passages) == 0
else:
# Without Pinecone, passages are stored locally
assert len(passages) > 0
for passage in passages:
assert passage.embedding is not None
assert len(passage.embedding) > 0
async def test_import_file_status_updates(self, server, agent_serialization_manager, default_user, other_user):
"""Test that file processing status is updated correctly during import."""
source = await create_test_source(server, "status-source", default_user)
file_metadata = await create_test_file(server, "status.txt", source.id, default_user)
agent = await create_test_agent_with_files(server, "status-agent", default_user, [(source.id, file_metadata.id)])
exported = await agent_serialization_manager.export([agent.id], default_user)
result = await agent_serialization_manager.import_file(exported, other_user)
imported_file_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("file-"))
imported_file = await server.file_manager.get_file_by_id(imported_file_id, other_user)
# When using Pinecone, status stays at embedding until chunks are confirmed uploaded
if should_use_pinecone():
assert imported_file.processing_status.value == "embedding"
assert imported_file.total_chunks == 1 # Pinecone tracks chunk counts
assert imported_file.chunks_embedded == 0
else:
assert imported_file.processing_status.value == "completed"
assert imported_file.total_chunks is None
assert imported_file.chunks_embedded is None
async def test_import_multiple_files_processing(self, server, agent_serialization_manager, default_user, other_user):
"""Test import processes multiple files efficiently."""
source = await create_test_source(server, "multi-source", default_user)
file1 = await create_test_file(server, "file1.txt", source.id, default_user)
file2 = await create_test_file(server, "file2.txt", source.id, default_user)
agent = await create_test_agent_with_files(server, "multi-agent", default_user, [(source.id, file1.id), (source.id, file2.id)])
exported = await agent_serialization_manager.export([agent.id], default_user)
result = await agent_serialization_manager.import_file(exported, other_user)
imported_file_ids = [db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("file-")]
assert len(imported_file_ids) == 2
for file_id in imported_file_ids:
imported_file = await server.file_manager.get_file_by_id(file_id, other_user)
# When using Pinecone, status stays at embedding until chunks are confirmed uploaded
if should_use_pinecone():
assert imported_file.processing_status.value == "embedding"
else:
assert imported_file.processing_status.value == "completed"
class TestAgentFileRoundTrip:
"""Tests for complete export -> import -> export cycles."""
async def test_roundtrip_consistency(self, server, agent_serialization_manager, test_agent, default_user, other_user):
"""Test that export -> import -> export produces consistent results."""
original_export = await agent_serialization_manager.export([test_agent.id], default_user)
result = await agent_serialization_manager.import_file(original_export, other_user)
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
second_export = await agent_serialization_manager.export([imported_agent_id], other_user)
assert compare_agent_files(original_export, second_export)
async def test_multiple_roundtrips(self, server, agent_serialization_manager, test_agent, default_user, other_user):
"""Test multiple rounds of export/import maintain consistency."""
current_agent_id = test_agent.id
current_user = default_user
for i in range(3):
agent_file = await agent_serialization_manager.export([current_agent_id], current_user)
target_user = other_user if current_user == default_user else default_user
result = await agent_serialization_manager.import_file(agent_file, target_user)
current_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
current_user = target_user
imported_agent = server.agent_manager.get_agent_by_id(current_agent_id, current_user)
assert imported_agent.name == test_agent.name
class TestAgentFileEdgeCases:
"""Tests for edge cases and error conditions."""
async def test_agent_with_no_messages(self, server, agent_serialization_manager, default_user, other_user):
"""Test exporting/importing agent with no messages."""
# Create agent with no initial messages
create_agent_request = CreateAgent(
name="no_messages_agent",
system="Agent with no messages",
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
initial_message_sequence=[],
)
agent_state = await server.agent_manager.create_agent_async(
agent_create=create_agent_request,
actor=default_user,
_init_with_no_messages=True, # Create with truly no messages
)
# Export
agent_file = await agent_serialization_manager.export([agent_state.id], default_user)
# Import
result = await agent_serialization_manager.import_file(agent_file, other_user)
# Verify
assert result.success
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
imported_agent = server.agent_manager.get_agent_by_id(imported_agent_id, other_user)
assert len(imported_agent.message_ids) == 0
async def test_large_agent_file(self, server, agent_serialization_manager, default_user, other_user, weather_tool):
"""Test handling of larger agent files with many messages."""
# Create agent
create_agent_request = CreateAgent(
name="large_agent",
system="Agent with many messages",
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
tool_ids=[weather_tool.id],
)
agent_state = server.agent_manager.create_agent(
agent_create=create_agent_request,
actor=default_user,
)
# Add many messages
for i in range(10):
server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role=MessageRole.user, content=f"Message {i}")],
)
# Export
agent_file = await agent_serialization_manager.export([agent_state.id], default_user)
# Verify large file
exported_agent = agent_file.agents[0]
assert len(exported_agent.messages) >= 10
# Import
result = await agent_serialization_manager.import_file(agent_file, other_user)
# Verify all messages imported correctly
assert result.success
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
imported_messages = server.message_manager.list_messages_for_agent(imported_agent_id, other_user)
assert len(imported_messages) >= 10
class TestAgentFileValidation:
"""Tests for agent file validation and schema compliance."""
def test_agent_file_schema_validation(self, test_agent):
"""Test AgentFileSchema validation."""
# Use a dummy revision for this test since we can't await in sync test
current_revision = "495f3f474131" # Use a known valid revision format
# Valid schema
valid_schema = AgentFileSchema(
metadata={"revision_id": current_revision},
agents=[AgentSchema(id="agent-0", name="test")],
groups=[],
blocks=[],
files=[],
sources=[],
tools=[],
mcp_servers=[],
)
# Should not raise
assert valid_schema.agents[0].id == "agent-0"
assert valid_schema.metadata.get("revision_id") == current_revision
def test_message_schema_conversion(self, test_agent, server, default_user):
"""Test MessageSchema.from_message conversion."""
# Get a message from the test agent
messages = server.message_manager.list_messages_for_agent(test_agent.id, default_user)
if messages:
original_message = messages[0]
# Convert to MessageSchema
message_schema = MessageSchema.from_message(original_message)
# Verify conversion
assert message_schema.role == original_message.role
assert message_schema.content == original_message.content
assert message_schema.model == original_message.model
assert message_schema.agent_id == original_message.agent_id
def test_id_format_validation(self):
"""Test ID format validation helper."""
# Use a dummy revision for this test since we can't await in sync test
current_revision = "495f3f474131" # Use a known valid revision format
# Valid schema
valid_schema = AgentFileSchema(
metadata={"revision_id": current_revision},
agents=[AgentSchema(id="agent-0", name="test")],
groups=[],
blocks=[BlockSchema(id="block-0", label="test", value="test")],
files=[],
sources=[],
tools=[
ToolSchema(
id="tool-0",
name="test_tool",
source_code="test",
json_schema={"name": "test_tool", "parameters": {"type": "object", "properties": {}}},
)
],
mcp_servers=[],
)
assert validate_id_format(valid_schema)
# Invalid schema
invalid_schema = AgentFileSchema(
metadata={"revision_id": current_revision},
agents=[AgentSchema(id="invalid-id-format", name="test")],
groups=[],
blocks=[],
files=[],
sources=[],
tools=[],
mcp_servers=[],
)
assert not validate_id_format(invalid_schema)
class TestMCPServerSerialization:
"""Tests for MCP server export/import functionality."""
async def test_mcp_server_export(self, agent_serialization_manager, agent_with_mcp_tools, default_user):
"""Test that MCP servers are exported correctly."""
agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user)
# Verify MCP server is included
assert len(agent_file.mcp_servers) == 1
mcp_server = agent_file.mcp_servers[0]
# Verify server details
assert mcp_server.server_name == "test_mcp_server"
assert mcp_server.server_url == "http://test-mcp-server.com"
assert mcp_server.server_type == "sse"
# Verify auth fields are excluded
assert not hasattr(mcp_server, "token")
assert not hasattr(mcp_server, "custom_headers")
# Verify ID format
assert _validate_entity_id(mcp_server.id, "mcp_server")
async def test_mcp_server_auth_scrubbing(self, server, agent_serialization_manager, default_user):
"""Test that authentication information is scrubbed during export."""
from letta.schemas.mcp import MCPServer, MCPServerType
# Create MCP server with auth info
mcp_server_data_stdio = MCPServer(
server_name="auth_test_server",
server_type=MCPServerType.STDIO,
# token="super-secret-token",
# custom_headers={"Authorization": "Bearer secret-key", "X-Custom": "custom-value"},
stdio_config={
"server_name": "auth_test_server",
"command": "test-command",
"args": ["arg1", "arg2"],
"env": {"ENV_VAR": "value"},
},
)
mcp_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data_stdio, default_user)
mcp_server_data_http = MCPServer(
server_name="auth_test_server_http",
server_type=MCPServerType.STREAMABLE_HTTP,
server_url="http://auth_test_server_http.com",
token="super-secret-token",
custom_headers={"X-Custom": "custom-value"},
)
mcp_server_http = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data_http, default_user)
# Create tool from MCP server
from letta.schemas.tool import MCPTool, ToolCreate
mcp_tool_data = MCPTool(
name="auth_test_tool_stdio",
description="Tool with auth",
inputSchema={"type": "object", "properties": {}},
)
tool_create_stdio = ToolCreate.from_mcp(mcp_server.server_name, mcp_tool_data)
mcp_tool_data_http = MCPTool(
name="auth_test_tool_http",
description="Tool with auth",
inputSchema={"type": "object", "properties": {}},
)
tool_create_http = ToolCreate.from_mcp(mcp_server_http.server_name, mcp_tool_data_http)
mcp_tool = await server.tool_manager.create_mcp_tool_async(tool_create_stdio, mcp_server.server_name, mcp_server.id, default_user)
mcp_tool_http = await server.tool_manager.create_mcp_tool_async(
tool_create_http, mcp_server_http.server_name, mcp_server_http.id, default_user
)
# Create agent with the tool
from letta.schemas.agent import CreateAgent
create_agent_request = CreateAgent(
name="auth_test_agent",
tool_ids=[mcp_tool.id, mcp_tool_http.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
agent = await server.agent_manager.create_agent_async(create_agent_request, default_user)
# Export
agent_file = await agent_serialization_manager.export([agent.id], default_user)
for server in agent_file.mcp_servers:
if server.server_name == "auth_test_server":
exported_server_stdio = server
elif server.server_name == "auth_test_server_http":
exported_server_http = server
# Verify env variables in stdio server are excluded (typically used for auth)
assert exported_server_stdio.id != mcp_server.id
assert exported_server_stdio.server_name == "auth_test_server"
assert exported_server_stdio.stdio_config == {
"server_name": "auth_test_server",
"type": "stdio",
"command": "test-command",
"args": ["arg1", "arg2"],
} # Non-auth config preserved
assert exported_server_stdio.server_type == "stdio"
# Verify token and custom headers are excluded from export for http server
assert exported_server_http.id != mcp_server_http.id
assert exported_server_http.server_name == "auth_test_server_http"
assert exported_server_http.server_type == "streamable_http"
assert exported_server_http.server_url == "http://auth_test_server_http.com"
assert not hasattr(exported_server_http, "token")
assert not hasattr(exported_server_http, "custom_headers")
async def test_mcp_tool_metadata_with_server_id(self, agent_serialization_manager, agent_with_mcp_tools, default_user):
"""Test that MCP tools have server_id in metadata."""
agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user)
# Find the MCP tool
mcp_tool = next((t for t in agent_file.tools if t.name == "test_mcp_tool"), None)
assert mcp_tool is not None
# Verify metadata contains server info
assert mcp_tool.metadata_ is not None
assert "mcp" in mcp_tool.metadata_
assert "server_name" in mcp_tool.metadata_["mcp"]
assert "server_id" in mcp_tool.metadata_["mcp"]
assert mcp_tool.metadata_["mcp"]["server_name"] == "test_mcp_server"
# Verify tag format
assert any(tag.startswith("mcp:") for tag in mcp_tool.tags)
async def test_mcp_server_import(self, agent_serialization_manager, agent_with_mcp_tools, default_user, other_user):
"""Test importing agents with MCP servers."""
# Export from default user
agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user)
# Import to other user
result = await agent_serialization_manager.import_file(agent_file, other_user)
assert result.success
# Verify MCP server was imported
mcp_server_id = next((db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("mcp_server-")), None)
assert mcp_server_id is not None
async def test_multiple_mcp_servers_export(self, server, agent_serialization_manager, default_user):
"""Test exporting multiple MCP servers from different agents."""
from letta.schemas.mcp import MCPServer, MCPServerType
# Create two MCP servers
mcp_server1 = await server.mcp_manager.create_or_update_mcp_server(
MCPServer(
server_name="mcp1",
server_type=MCPServerType.STREAMABLE_HTTP,
server_url="http://mcp1.com",
token="super-secret-token",
custom_headers={"X-Custom": "custom-value"},
),
default_user,
)
mcp_server2 = await server.mcp_manager.create_or_update_mcp_server(
MCPServer(
server_name="mcp2",
server_type=MCPServerType.STDIO,
stdio_config={
"server_name": "mcp2",
"command": "mcp2-cmd",
"args": ["arg1", "arg2"],
},
),
default_user,
)
# Create tools from each server
from letta.schemas.tool import MCPTool, ToolCreate
tool1 = await server.tool_manager.create_mcp_tool_async(
ToolCreate.from_mcp(
"mcp1",
MCPTool(name="tool1", description="Tool 1", inputSchema={"type": "object", "properties": {}}),
),
"mcp1",
mcp_server1.id,
default_user,
)
tool2 = await server.tool_manager.create_mcp_tool_async(
ToolCreate.from_mcp(
"mcp2",
MCPTool(name="tool2", description="Tool 2", inputSchema={"type": "object", "properties": {}}),
),
"mcp2",
mcp_server2.id,
default_user,
)
# Create agents with different MCP tools
from letta.schemas.agent import CreateAgent
agent1 = await server.agent_manager.create_agent_async(
CreateAgent(
name="agent1",
tool_ids=[tool1.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
default_user,
)
agent2 = await server.agent_manager.create_agent_async(
CreateAgent(
name="agent2",
tool_ids=[tool2.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
default_user,
)
# Export both agents
agent_file = await agent_serialization_manager.export([agent1.id, agent2.id], default_user)
# Verify both MCP servers are included
assert len(agent_file.mcp_servers) == 2
# Verify server types
streamable_http_server = next(s for s in agent_file.mcp_servers if s.server_name == "mcp1")
stdio_server = next(s for s in agent_file.mcp_servers if s.server_name == "mcp2")
assert streamable_http_server.server_name == "mcp1"
assert streamable_http_server.server_type == "streamable_http"
assert streamable_http_server.server_url == "http://mcp1.com"
assert stdio_server.server_name == "mcp2"
assert stdio_server.server_type == "stdio"
assert stdio_server.stdio_config == {
"server_name": "mcp2",
"type": "stdio",
"command": "mcp2-cmd",
"args": ["arg1", "arg2"],
}
async def test_mcp_server_deduplication(self, server, agent_serialization_manager, default_user, test_mcp_server, mcp_tool):
"""Test that shared MCP servers are deduplicated during export."""
# Create two agents using the same MCP tool
from letta.schemas.agent import CreateAgent
agent1 = await server.agent_manager.create_agent_async(
CreateAgent(
name="agent_dup1",
tool_ids=[mcp_tool.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
default_user,
)
agent2 = await server.agent_manager.create_agent_async(
CreateAgent(
name="agent_dup2",
tool_ids=[mcp_tool.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
default_user,
)
# Export both agents
agent_file = await agent_serialization_manager.export([agent1.id, agent2.id], default_user)
# Verify only one MCP server is exported
assert len(agent_file.mcp_servers) == 1
assert agent_file.mcp_servers[0].server_name == "test_mcp_server"
if __name__ == "__main__":
pytest.main([__file__])