feat: support mcp server export/import with agent files

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-07-24 18:17:08 -07:00
committed by GitHub
parent 7781ca55bc
commit bc471c6055
8 changed files with 573 additions and 35 deletions

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
@@ -8,6 +8,7 @@ from letta.schemas.block import Block, CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
from letta.schemas.group import GroupCreate
from letta.schemas.mcp import MCPServer
from letta.schemas.message import Message, MessageCreate
from letta.schemas.source import Source, SourceCreate
from letta.schemas.tool import Tool
@@ -264,9 +265,34 @@ class ToolSchema(Tool):
return cls(**tool.model_dump())
# class MCPServerSchema(RegisterMCPServer):
# """MCP Server with human-readable ID for agent file"""
# id: str = Field(..., description="Human-readable identifier for this MCP server in the file")
class MCPServerSchema(BaseModel):
"""MCP server schema for agent files with remapped ID."""
__id_prefix__ = "mcp_server"
id: str = Field(..., description="Human-readable MCP server ID")
server_type: str
server_name: str
server_url: Optional[str] = None
stdio_config: Optional[Dict[str, Any]] = None
metadata_: Optional[Dict[str, Any]] = None
@classmethod
def from_mcp_server(cls, mcp_server: MCPServer) -> "MCPServerSchema":
"""Convert MCPServer to MCPServerSchema (excluding auth fields)."""
return cls(
id=mcp_server.id, # remapped by serialization manager
server_type=mcp_server.server_type,
server_name=mcp_server.server_name,
server_url=mcp_server.server_url,
# exclude token, custom_headers, and the env field in stdio_config that may contain authentication credentials
stdio_config=cls.strip_env_from_stdio_config(mcp_server.stdio_config.model_dump()) if mcp_server.stdio_config else None,
metadata_=mcp_server.metadata_,
)
def strip_env_from_stdio_config(stdio_config: Dict[str, Any]) -> Dict[str, Any]:
"""Strip out the env field from the stdio config."""
return {k: v for k, v in stdio_config.items() if k != "env"}
class AgentFileSchema(BaseModel):
@@ -278,7 +304,7 @@ class AgentFileSchema(BaseModel):
files: List[FileSchema] = Field(..., description="List of files in this agent file")
sources: List[SourceSchema] = Field(..., description="List of sources in this agent file")
tools: List[ToolSchema] = Field(..., description="List of tools in this agent file")
# mcp_servers: List[MCPServerSchema] = Field(..., description="List of MCP servers in this agent file")
mcp_servers: List[MCPServerSchema] = Field(..., description="List of MCP servers in this agent file")
metadata: Dict[str, str] = Field(
default_factory=dict, description="Metadata for this agent file, including revision_id and other export information."
)

View File

@@ -475,7 +475,11 @@ async def add_mcp_tool(
)
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
return await server.tool_manager.create_mcp_tool_async(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor)
# For config-based servers, use the server name as ID since they don't have database IDs
mcp_server_id = mcp_server_name
return await server.tool_manager.create_mcp_tool_async(
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
)
else:
return await server.mcp_manager.add_tool_from_mcp_server(mcp_server_name=mcp_server_name, mcp_tool_name=mcp_tool_name, actor=actor)

View File

@@ -13,6 +13,7 @@ from letta.schemas.agent_file import (
FileSchema,
GroupSchema,
ImportResult,
MCPServerSchema,
MessageSchema,
SourceSchema,
ToolSchema,
@@ -20,6 +21,7 @@ from letta.schemas.agent_file import (
from letta.schemas.block import Block
from letta.schemas.enums import FileProcessingStatus
from letta.schemas.file import FileMetadata
from letta.schemas.mcp import MCPServer
from letta.schemas.message import Message
from letta.schemas.source import Source
from letta.schemas.tool import Tool
@@ -92,7 +94,7 @@ class AgentSerializationManager:
ToolSchema.__id_prefix__: 0,
MessageSchema.__id_prefix__: 0,
FileAgentSchema.__id_prefix__: 0,
# MCPServerSchema.__id_prefix__: 0,
MCPServerSchema.__id_prefix__: 0,
}
def _reset_state(self):
@@ -258,6 +260,58 @@ class AgentSerializationManager:
file_schema.source_id = self._map_db_to_file_id(file_metadata.source_id, SourceSchema.__id_prefix__, allow_new=False)
return file_schema
async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List:
"""Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name."""
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
mcp_server_ids = set()
mcp_server_names = set()
for tool in tools:
# Check if tool has MCP metadata
if tool.metadata_ and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_:
mcp_metadata = tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]
# TODO: @jnjpng clean this up once we fully migrate to server_id being the main identifier
if "server_id" in mcp_metadata:
mcp_server_ids.add(mcp_metadata["server_id"])
elif "server_name" in mcp_metadata:
mcp_server_names.add(mcp_metadata["server_name"])
# Fetch MCP servers by ID
mcp_servers = []
fetched_server_ids = set()
if mcp_server_ids:
for server_id in mcp_server_ids:
try:
mcp_server = await self.mcp_manager.get_mcp_server_by_id_async(server_id, actor)
if mcp_server:
mcp_servers.append(mcp_server)
fetched_server_ids.add(server_id)
except Exception as e:
logger.warning(f"Failed to fetch MCP server {server_id}: {e}")
# Fetch MCP servers by name if not already fetched by ID
if mcp_server_names:
for server_name in mcp_server_names:
try:
mcp_server = await self.mcp_manager.get_mcp_server(server_name, actor)
if mcp_server and mcp_server.id not in fetched_server_ids:
mcp_servers.append(mcp_server)
except Exception as e:
logger.warning(f"Failed to fetch MCP server by name {server_name}: {e}")
return mcp_servers
def _convert_mcp_server_to_schema(self, mcp_server: MCPServer) -> MCPServerSchema:
"""Convert MCPServer to MCPServerSchema with ID remapping and auth scrubbing"""
try:
mcp_file_id = self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__, allow_new=False)
mcp_schema = MCPServerSchema.from_mcp_server(mcp_server)
mcp_schema.id = mcp_file_id
return mcp_schema
except Exception as e:
logger.error(f"Failed to convert MCP server {mcp_server.id}: {e}")
raise
async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema:
"""
Export agents and their related entities to AgentFileSchema format.
@@ -289,6 +343,13 @@ class AgentSerializationManager:
tool_set = self._extract_unique_tools(agent_states)
block_set = self._extract_unique_blocks(agent_states)
# Extract MCP servers from tools BEFORE conversion (must be done before ID mapping)
mcp_server_set = await self._extract_unique_mcp_servers(tool_set, actor)
# Map MCP server IDs before converting schemas
for mcp_server in mcp_server_set:
self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__)
# Extract sources and files from agent states BEFORE conversion (with caching)
source_set, file_set = await self._extract_unique_sources_and_files_from_agents(agent_states, actor, files_agents_cache)
@@ -301,6 +362,7 @@ class AgentSerializationManager:
block_schemas = [self._convert_block_to_schema(block) for block in block_set]
source_schemas = [self._convert_source_to_schema(source) for source in source_set]
file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set]
mcp_server_schemas = [self._convert_mcp_server_to_schema(mcp_server) for mcp_server in mcp_server_set]
logger.info(f"Exporting {len(agent_ids)} agents to agent file format")
@@ -312,7 +374,7 @@ class AgentSerializationManager:
files=file_schemas,
sources=source_schemas,
tools=tool_schemas,
# mcp_servers=[], # TODO: Extract and convert MCP servers
mcp_servers=mcp_server_schemas,
metadata={"revision_id": await get_latest_alembic_revision()},
created_at=datetime.now(timezone.utc),
)
@@ -359,7 +421,20 @@ class AgentSerializationManager:
# in-memory cache for file metadata to avoid repeated db calls
file_metadata_cache = {} # Maps database file ID to FileMetadata
# 1. Create tools first (no dependencies) - using bulk upsert for efficiency
# 1. Create MCP servers first (tools depend on them)
if schema.mcp_servers:
for mcp_server_schema in schema.mcp_servers:
server_data = mcp_server_schema.model_dump(exclude={"id"})
filtered_server_data = self._filter_dict_for_model(server_data, MCPServer)
create_schema = MCPServer(**filtered_server_data)
# Note: We don't have auth info from export, so the user will need to re-configure auth.
# TODO: @jnjpng store metadata about obfuscated metadata to surface to the user
created_mcp_server = await self.mcp_manager.create_or_update_mcp_server(create_schema, actor)
file_to_db_ids[mcp_server_schema.id] = created_mcp_server.id
imported_count += 1
# 2. Create tools (may depend on MCP servers) - using bulk upsert for efficiency
if schema.tools:
# convert tool schemas to pydantic tools
pydantic_tools = []
@@ -559,6 +634,7 @@ class AgentSerializationManager:
(schema.files, FileSchema.__id_prefix__),
(schema.sources, SourceSchema.__id_prefix__),
(schema.tools, ToolSchema.__id_prefix__),
(schema.mcp_servers, MCPServerSchema.__id_prefix__),
]
for entities, expected_prefix in entity_checks:
@@ -601,6 +677,7 @@ class AgentSerializationManager:
("files", schema.files),
("sources", schema.sources),
("tools", schema.tools),
("mcp_servers", schema.mcp_servers),
]
for entity_type, entities in entity_collections:
@@ -705,3 +782,11 @@ class AgentSerializationManager:
raise AgentFileImportError(f"Schema validation failed: {'; '.join(errors)}")
logger.info("Schema validation passed")
def _filter_dict_for_model(self, data: dict, model_cls):
"""Filter a dictionary to only include keys that are in the model fields"""
try:
allowed = model_cls.model_fields.keys() # Pydantic v2
except AttributeError:
allowed = model_cls.__fields__.keys() # Pydantic v1
return {k: v for k, v in data.items() if k in allowed}

View File

@@ -98,12 +98,19 @@ class MCPManager:
@enforce_types
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
"""Add a tool from an MCP server to the Letta tool registry."""
# get the MCP server ID, we should migrate to use the server_id instead of the name
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
if not mcp_server_id:
raise ValueError(f"MCP server '{mcp_server_name}' not found")
mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor)
for mcp_tool in mcp_tools:
if mcp_tool.name == mcp_tool_name:
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
return await self.tool_manager.create_mcp_tool_async(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor)
return await self.tool_manager.create_mcp_tool_async(
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
)
# failed to add - handle error?
return None
@@ -194,14 +201,7 @@ class MCPManager:
"""Update an MCP server by its name."""
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
if not mcp_server_id:
raise HTTPException(
status_code=404,
detail={
"code": "MCPServerNotFoundError",
"message": f"MCP server {mcp_server_name} not found",
"mcp_server_name": mcp_server_name,
},
)
raise ValueError(f"MCP server {mcp_server_name} not found")
return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor)
@enforce_types
@@ -223,6 +223,18 @@ class MCPManager:
# Convert the SQLAlchemy Tool object to PydanticTool
return mcp_server.to_pydantic()
@enforce_types
async def get_mcp_servers_by_ids(self, mcp_server_ids: List[str], actor: PydanticUser) -> List[MCPServer]:
"""Fetch multiple MCP servers by their IDs in a single query."""
if not mcp_server_ids:
return []
async with db_registry.async_session() as session:
mcp_servers = await MCPServerModel.list_async(
db_session=session, organization_id=actor.organization_id, id=mcp_server_ids # This will use the IN operator
)
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
@enforce_types
async def get_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
"""Get a tool by name."""
@@ -230,14 +242,7 @@ class MCPManager:
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
if not mcp_server:
raise HTTPException(
status_code=404, # Not Found
detail={
"code": "MCPServerNotFoundError",
"message": f"MCP server {mcp_server_name} not found",
"mcp_server_name": mcp_server_name,
},
)
raise ValueError(f"MCP server {mcp_server_name} not found")
return mcp_server.to_pydantic()
# @enforce_types

View File

@@ -106,8 +106,10 @@ class ToolManager:
@enforce_types
@trace_method
def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}}
def create_or_update_mcp_tool(
self, tool_create: ToolCreate, mcp_server_name: str, mcp_server_id: str, actor: PydanticUser
) -> PydanticTool:
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name, "server_id": mcp_server_id}}
return self.create_or_update_tool(
PydanticTool(
tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], metadata_=metadata, **tool_create.model_dump()
@@ -116,8 +118,10 @@ class ToolManager:
)
@enforce_types
async def create_mcp_tool_async(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}}
async def create_mcp_tool_async(
self, tool_create: ToolCreate, mcp_server_name: str, mcp_server_id: str, actor: PydanticUser
) -> PydanticTool:
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name, "server_id": mcp_server_id}}
return await self.create_or_update_tool_async(
PydanticTool(
tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], metadata_=metadata, **tool_create.model_dump()

View File

@@ -1 +1 @@
{}
{}

View File

@@ -283,6 +283,68 @@ async def agent_with_files(server: SyncServer, default_user, test_block, weather
return (agent_state.id, test_source.id, test_file.id)
@pytest.fixture
async def test_mcp_server(server: SyncServer, default_user):
"""Fixture to create and return a test MCP server."""
from letta.schemas.mcp import MCPServer, MCPServerType
mcp_server_data = MCPServer(
server_name="test_mcp_server",
server_type=MCPServerType.SSE,
server_url="http://test-mcp-server.com",
token="test-token-12345", # This should be excluded during export
custom_headers={"X-API-Key": "secret-key"}, # This should be excluded during export
)
mcp_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data, default_user)
yield mcp_server
@pytest.fixture
async def mcp_tool(server: SyncServer, default_user, test_mcp_server):
"""Fixture to create and return an MCP tool."""
from letta.schemas.tool import MCPTool, ToolCreate
# Create a mock MCP tool
mcp_tool_data = MCPTool(
name="test_mcp_tool",
description="Test MCP tool for serialization",
inputSchema={"type": "object", "properties": {"input": {"type": "string"}}},
)
tool_create = ToolCreate.from_mcp(test_mcp_server.server_name, mcp_tool_data)
# Create tool with MCP metadata
mcp_tool = await server.tool_manager.create_mcp_tool_async(tool_create, test_mcp_server.server_name, test_mcp_server.id, default_user)
yield mcp_tool
@pytest.fixture
async def agent_with_mcp_tools(server: SyncServer, default_user, test_block, mcp_tool, test_mcp_server):
"""Fixture to create and return an agent with MCP tools."""
memory_blocks = [
CreateBlock(label="human", value="User is a test user"),
CreateBlock(label="persona", value="I am a helpful test assistant"),
]
create_agent_request = CreateAgent(
name="test_agent_mcp",
system="You are a helpful assistant with MCP tools.",
memory_blocks=memory_blocks,
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
block_ids=[test_block.id],
tool_ids=[mcp_tool.id],
tags=["test", "mcp", "export"],
description="Test agent with MCP tools for serialization testing",
)
agent_state = await server.agent_manager.create_agent_async(
agent_create=create_agent_request,
actor=default_user,
)
return agent_state
# ------------------------------
# Helper Functions
# ------------------------------
@@ -1258,7 +1320,7 @@ class TestAgentFileValidation:
files=[],
sources=[],
tools=[],
# mcp_servers=[],
mcp_servers=[],
)
# Should not raise
@@ -1302,7 +1364,7 @@ class TestAgentFileValidation:
json_schema={"name": "test_tool", "parameters": {"type": "object", "properties": {}}},
)
],
# mcp_servers=[],
mcp_servers=[],
)
assert validate_id_format(valid_schema)
@@ -1316,11 +1378,282 @@ class TestAgentFileValidation:
files=[],
sources=[],
tools=[],
# mcp_servers=[],
mcp_servers=[],
)
assert not validate_id_format(invalid_schema)
class TestMCPServerSerialization:
"""Tests for MCP server export/import functionality."""
async def test_mcp_server_export(self, agent_serialization_manager, agent_with_mcp_tools, default_user):
"""Test that MCP servers are exported correctly."""
agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user)
# Verify MCP server is included
assert len(agent_file.mcp_servers) == 1
mcp_server = agent_file.mcp_servers[0]
# Verify server details
assert mcp_server.server_name == "test_mcp_server"
assert mcp_server.server_url == "http://test-mcp-server.com"
assert mcp_server.server_type == "sse"
# Verify auth fields are excluded
assert not hasattr(mcp_server, "token")
assert not hasattr(mcp_server, "custom_headers")
# Verify ID format
assert _validate_entity_id(mcp_server.id, "mcp_server")
async def test_mcp_server_auth_scrubbing(self, server, agent_serialization_manager, default_user):
"""Test that authentication information is scrubbed during export."""
from letta.schemas.mcp import MCPServer, MCPServerType
# Create MCP server with auth info
mcp_server_data_stdio = MCPServer(
server_name="auth_test_server",
server_type=MCPServerType.STDIO,
# token="super-secret-token",
# custom_headers={"Authorization": "Bearer secret-key", "X-Custom": "custom-value"},
stdio_config={
"server_name": "auth_test_server",
"command": "test-command",
"args": ["arg1", "arg2"],
"env": {"ENV_VAR": "value"},
},
)
mcp_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data_stdio, default_user)
mcp_server_data_http = MCPServer(
server_name="auth_test_server_http",
server_type=MCPServerType.STREAMABLE_HTTP,
server_url="http://auth_test_server_http.com",
token="super-secret-token",
custom_headers={"X-Custom": "custom-value"},
)
mcp_server_http = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data_http, default_user)
# Create tool from MCP server
from letta.schemas.tool import MCPTool, ToolCreate
mcp_tool_data = MCPTool(
name="auth_test_tool_stdio",
description="Tool with auth",
inputSchema={"type": "object", "properties": {}},
)
tool_create_stdio = ToolCreate.from_mcp(mcp_server.server_name, mcp_tool_data)
mcp_tool_data_http = MCPTool(
name="auth_test_tool_http",
description="Tool with auth",
inputSchema={"type": "object", "properties": {}},
)
tool_create_http = ToolCreate.from_mcp(mcp_server_http.server_name, mcp_tool_data_http)
mcp_tool = await server.tool_manager.create_mcp_tool_async(tool_create_stdio, mcp_server.server_name, mcp_server.id, default_user)
mcp_tool_http = await server.tool_manager.create_mcp_tool_async(
tool_create_http, mcp_server_http.server_name, mcp_server_http.id, default_user
)
# Create agent with the tool
from letta.schemas.agent import CreateAgent
create_agent_request = CreateAgent(
name="auth_test_agent",
tool_ids=[mcp_tool.id, mcp_tool_http.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
agent = await server.agent_manager.create_agent_async(create_agent_request, default_user)
# Export
agent_file = await agent_serialization_manager.export([agent.id], default_user)
for server in agent_file.mcp_servers:
if server.server_name == "auth_test_server":
exported_server_stdio = server
elif server.server_name == "auth_test_server_http":
exported_server_http = server
# Verify env variables in stdio server are excluded (typically used for auth)
assert exported_server_stdio.id != mcp_server.id
assert exported_server_stdio.server_name == "auth_test_server"
assert exported_server_stdio.stdio_config == {
"server_name": "auth_test_server",
"type": "stdio",
"command": "test-command",
"args": ["arg1", "arg2"],
} # Non-auth config preserved
assert exported_server_stdio.server_type == "stdio"
# Verify token and custom headers are excluded from export for http server
assert exported_server_http.id != mcp_server_http.id
assert exported_server_http.server_name == "auth_test_server_http"
assert exported_server_http.server_type == "streamable_http"
assert exported_server_http.server_url == "http://auth_test_server_http.com"
assert not hasattr(exported_server_http, "token")
assert not hasattr(exported_server_http, "custom_headers")
async def test_mcp_tool_metadata_with_server_id(self, agent_serialization_manager, agent_with_mcp_tools, default_user):
"""Test that MCP tools have server_id in metadata."""
agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user)
# Find the MCP tool
mcp_tool = next((t for t in agent_file.tools if t.name == "test_mcp_tool"), None)
assert mcp_tool is not None
# Verify metadata contains server info
assert mcp_tool.metadata_ is not None
assert "mcp" in mcp_tool.metadata_
assert "server_name" in mcp_tool.metadata_["mcp"]
assert "server_id" in mcp_tool.metadata_["mcp"]
assert mcp_tool.metadata_["mcp"]["server_name"] == "test_mcp_server"
# Verify tag format
assert any(tag.startswith("mcp:") for tag in mcp_tool.tags)
async def test_mcp_server_import(self, agent_serialization_manager, agent_with_mcp_tools, default_user, other_user):
"""Test importing agents with MCP servers."""
# Export from default user
agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user)
# Import to other user
result = await agent_serialization_manager.import_file(agent_file, other_user)
assert result.success
# Verify MCP server was imported
mcp_server_id = next((db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("mcp_server-")), None)
assert mcp_server_id is not None
async def test_multiple_mcp_servers_export(self, server, agent_serialization_manager, default_user):
"""Test exporting multiple MCP servers from different agents."""
from letta.schemas.mcp import MCPServer, MCPServerType
# Create two MCP servers
mcp_server1 = await server.mcp_manager.create_or_update_mcp_server(
MCPServer(
server_name="mcp1",
server_type=MCPServerType.STREAMABLE_HTTP,
server_url="http://mcp1.com",
token="super-secret-token",
custom_headers={"X-Custom": "custom-value"},
),
default_user,
)
mcp_server2 = await server.mcp_manager.create_or_update_mcp_server(
MCPServer(
server_name="mcp2",
server_type=MCPServerType.STDIO,
stdio_config={
"server_name": "mcp2",
"command": "mcp2-cmd",
"args": ["arg1", "arg2"],
},
),
default_user,
)
# Create tools from each server
from letta.schemas.tool import MCPTool, ToolCreate
tool1 = await server.tool_manager.create_mcp_tool_async(
ToolCreate.from_mcp(
"mcp1",
MCPTool(name="tool1", description="Tool 1", inputSchema={"type": "object", "properties": {}}),
),
"mcp1",
mcp_server1.id,
default_user,
)
tool2 = await server.tool_manager.create_mcp_tool_async(
ToolCreate.from_mcp(
"mcp2",
MCPTool(name="tool2", description="Tool 2", inputSchema={"type": "object", "properties": {}}),
),
"mcp2",
mcp_server2.id,
default_user,
)
# Create agents with different MCP tools
from letta.schemas.agent import CreateAgent
agent1 = await server.agent_manager.create_agent_async(
CreateAgent(
name="agent1",
tool_ids=[tool1.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
default_user,
)
agent2 = await server.agent_manager.create_agent_async(
CreateAgent(
name="agent2",
tool_ids=[tool2.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
default_user,
)
# Export both agents
agent_file = await agent_serialization_manager.export([agent1.id, agent2.id], default_user)
# Verify both MCP servers are included
assert len(agent_file.mcp_servers) == 2
# Verify server types
streamable_http_server = next(s for s in agent_file.mcp_servers if s.server_name == "mcp1")
stdio_server = next(s for s in agent_file.mcp_servers if s.server_name == "mcp2")
assert streamable_http_server.server_name == "mcp1"
assert streamable_http_server.server_type == "streamable_http"
assert streamable_http_server.server_url == "http://mcp1.com"
assert stdio_server.server_name == "mcp2"
assert stdio_server.server_type == "stdio"
assert stdio_server.stdio_config == {
"server_name": "mcp2",
"type": "stdio",
"command": "mcp2-cmd",
"args": ["arg1", "arg2"],
}
async def test_mcp_server_deduplication(self, server, agent_serialization_manager, default_user, test_mcp_server, mcp_tool):
"""Test that shared MCP servers are deduplicated during export."""
# Create two agents using the same MCP tool
from letta.schemas.agent import CreateAgent
agent1 = await server.agent_manager.create_agent_async(
CreateAgent(
name="agent_dup1",
tool_ids=[mcp_tool.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
default_user,
)
agent2 = await server.agent_manager.create_agent_async(
CreateAgent(
name="agent_dup2",
tool_ids=[mcp_tool.id],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
default_user,
)
# Export both agents
agent_file = await agent_serialization_manager.export([agent1.id, agent2.id], default_user)
# Verify only one MCP server is exported
assert len(agent_file.mcp_servers) == 1
assert agent_file.mcp_servers[0].server_name == "test_mcp_server"
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -269,8 +269,11 @@ def mcp_tool(server, default_user):
},
)
mcp_server_name = "test"
mcp_server_id = "test-server-id" # Mock server ID for testing
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
tool = server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user)
tool = server.tool_manager.create_or_update_mcp_tool(
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=default_user
)
yield tool
@@ -3474,6 +3477,7 @@ def test_create_mcp_tool(server: SyncServer, mcp_tool, default_user, default_org
assert mcp_tool.created_by_id == default_user.id
assert mcp_tool.tool_type == ToolType.EXTERNAL_MCP
assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == "test"
assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_id"] == "test-server-id"
# Test should work with both SQLite and PostgreSQL
@@ -8552,6 +8556,83 @@ async def test_create_mcp_server(server, default_user, event_loop):
print("TAGS", tool.tags)
async def test_get_mcp_servers_by_ids(server, default_user, event_loop):
from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig
from letta.settings import tool_settings
if tool_settings.mcp_read_from_config:
return
# Create multiple MCP servers for testing
servers_data = [
{
"name": "test_server_1",
"config": StdioServerConfig(
server_name="test_server_1", type=MCPServerType.STDIO, command="echo 'test1'", args=["arg1"], env={"ENV1": "value1"}
),
"type": MCPServerType.STDIO,
},
{
"name": "test_server_2",
"config": SSEServerConfig(server_name="test_server_2", server_url="https://test2.example.com/sse"),
"type": MCPServerType.SSE,
},
{
"name": "test_server_3",
"config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/sse"),
"type": MCPServerType.SSE,
},
]
created_servers = []
for server_data in servers_data:
if server_data["type"] == MCPServerType.STDIO:
mcp_server = MCPServer(server_name=server_data["name"], server_type=server_data["type"], stdio_config=server_data["config"])
else:
mcp_server = MCPServer(
server_name=server_data["name"], server_type=server_data["type"], server_url=server_data["config"].server_url
)
created = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
created_servers.append(created)
# Test fetching multiple servers by IDs
server_ids = [s.id for s in created_servers]
fetched_servers = await server.mcp_manager.get_mcp_servers_by_ids(server_ids, actor=default_user)
assert len(fetched_servers) == len(created_servers)
fetched_ids = {s.id for s in fetched_servers}
expected_ids = {s.id for s in created_servers}
assert fetched_ids == expected_ids
# Test fetching subset of servers
subset_ids = server_ids[:2]
subset_servers = await server.mcp_manager.get_mcp_servers_by_ids(subset_ids, actor=default_user)
assert len(subset_servers) == 2
assert all(s.id in subset_ids for s in subset_servers)
# Test fetching with empty list
empty_result = await server.mcp_manager.get_mcp_servers_by_ids([], actor=default_user)
assert empty_result == []
# Test fetching with non-existent ID mixed with valid IDs
mixed_ids = [server_ids[0], "non-existent-id", server_ids[1]]
mixed_result = await server.mcp_manager.get_mcp_servers_by_ids(mixed_ids, actor=default_user)
# Should only return the existing servers
assert len(mixed_result) == 2
assert all(s.id in server_ids for s in mixed_result)
# Test that servers from different organizations are not returned
# This would require creating another user/org, but for now we'll just verify
# that the function respects the actor's organization
all_servers = await server.mcp_manager.list_mcp_servers(actor=default_user)
all_server_ids = [s.id for s in all_servers]
bulk_fetched = await server.mcp_manager.get_mcp_servers_by_ids(all_server_ids, actor=default_user)
# All fetched servers should belong to the same organization
assert all(s.organization_id == default_user.organization_id for s in bulk_fetched)
# ======================================================================================================================
# FileAgent Tests
# ======================================================================================================================