feat: support mcp server export/import with agent files
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -8,6 +8,7 @@ from letta.schemas.block import Block, CreateBlock
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
|
||||
from letta.schemas.group import GroupCreate
|
||||
from letta.schemas.mcp import MCPServer
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.source import Source, SourceCreate
|
||||
from letta.schemas.tool import Tool
|
||||
@@ -264,9 +265,34 @@ class ToolSchema(Tool):
|
||||
return cls(**tool.model_dump())
|
||||
|
||||
|
||||
# class MCPServerSchema(RegisterMCPServer):
|
||||
# """MCP Server with human-readable ID for agent file"""
|
||||
# id: str = Field(..., description="Human-readable identifier for this MCP server in the file")
|
||||
class MCPServerSchema(BaseModel):
|
||||
"""MCP server schema for agent files with remapped ID."""
|
||||
|
||||
__id_prefix__ = "mcp_server"
|
||||
|
||||
id: str = Field(..., description="Human-readable MCP server ID")
|
||||
server_type: str
|
||||
server_name: str
|
||||
server_url: Optional[str] = None
|
||||
stdio_config: Optional[Dict[str, Any]] = None
|
||||
metadata_: Optional[Dict[str, Any]] = None
|
||||
|
||||
@classmethod
|
||||
def from_mcp_server(cls, mcp_server: MCPServer) -> "MCPServerSchema":
|
||||
"""Convert MCPServer to MCPServerSchema (excluding auth fields)."""
|
||||
return cls(
|
||||
id=mcp_server.id, # remapped by serialization manager
|
||||
server_type=mcp_server.server_type,
|
||||
server_name=mcp_server.server_name,
|
||||
server_url=mcp_server.server_url,
|
||||
# exclude token, custom_headers, and the env field in stdio_config that may contain authentication credentials
|
||||
stdio_config=cls.strip_env_from_stdio_config(mcp_server.stdio_config.model_dump()) if mcp_server.stdio_config else None,
|
||||
metadata_=mcp_server.metadata_,
|
||||
)
|
||||
|
||||
def strip_env_from_stdio_config(stdio_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Strip out the env field from the stdio config."""
|
||||
return {k: v for k, v in stdio_config.items() if k != "env"}
|
||||
|
||||
|
||||
class AgentFileSchema(BaseModel):
|
||||
@@ -278,7 +304,7 @@ class AgentFileSchema(BaseModel):
|
||||
files: List[FileSchema] = Field(..., description="List of files in this agent file")
|
||||
sources: List[SourceSchema] = Field(..., description="List of sources in this agent file")
|
||||
tools: List[ToolSchema] = Field(..., description="List of tools in this agent file")
|
||||
# mcp_servers: List[MCPServerSchema] = Field(..., description="List of MCP servers in this agent file")
|
||||
mcp_servers: List[MCPServerSchema] = Field(..., description="List of MCP servers in this agent file")
|
||||
metadata: Dict[str, str] = Field(
|
||||
default_factory=dict, description="Metadata for this agent file, including revision_id and other export information."
|
||||
)
|
||||
|
||||
@@ -475,7 +475,11 @@ async def add_mcp_tool(
|
||||
)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return await server.tool_manager.create_mcp_tool_async(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor)
|
||||
# For config-based servers, use the server name as ID since they don't have database IDs
|
||||
mcp_server_id = mcp_server_name
|
||||
return await server.tool_manager.create_mcp_tool_async(
|
||||
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
|
||||
)
|
||||
|
||||
else:
|
||||
return await server.mcp_manager.add_tool_from_mcp_server(mcp_server_name=mcp_server_name, mcp_tool_name=mcp_tool_name, actor=actor)
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.schemas.agent_file import (
|
||||
FileSchema,
|
||||
GroupSchema,
|
||||
ImportResult,
|
||||
MCPServerSchema,
|
||||
MessageSchema,
|
||||
SourceSchema,
|
||||
ToolSchema,
|
||||
@@ -20,6 +21,7 @@ from letta.schemas.agent_file import (
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.mcp import MCPServer
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
@@ -92,7 +94,7 @@ class AgentSerializationManager:
|
||||
ToolSchema.__id_prefix__: 0,
|
||||
MessageSchema.__id_prefix__: 0,
|
||||
FileAgentSchema.__id_prefix__: 0,
|
||||
# MCPServerSchema.__id_prefix__: 0,
|
||||
MCPServerSchema.__id_prefix__: 0,
|
||||
}
|
||||
|
||||
def _reset_state(self):
|
||||
@@ -258,6 +260,58 @@ class AgentSerializationManager:
|
||||
file_schema.source_id = self._map_db_to_file_id(file_metadata.source_id, SourceSchema.__id_prefix__, allow_new=False)
|
||||
return file_schema
|
||||
|
||||
async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List:
|
||||
"""Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name."""
|
||||
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
|
||||
|
||||
mcp_server_ids = set()
|
||||
mcp_server_names = set()
|
||||
for tool in tools:
|
||||
# Check if tool has MCP metadata
|
||||
if tool.metadata_ and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_:
|
||||
mcp_metadata = tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]
|
||||
# TODO: @jnjpng clean this up once we fully migrate to server_id being the main identifier
|
||||
if "server_id" in mcp_metadata:
|
||||
mcp_server_ids.add(mcp_metadata["server_id"])
|
||||
elif "server_name" in mcp_metadata:
|
||||
mcp_server_names.add(mcp_metadata["server_name"])
|
||||
|
||||
# Fetch MCP servers by ID
|
||||
mcp_servers = []
|
||||
fetched_server_ids = set()
|
||||
if mcp_server_ids:
|
||||
for server_id in mcp_server_ids:
|
||||
try:
|
||||
mcp_server = await self.mcp_manager.get_mcp_server_by_id_async(server_id, actor)
|
||||
if mcp_server:
|
||||
mcp_servers.append(mcp_server)
|
||||
fetched_server_ids.add(server_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch MCP server {server_id}: {e}")
|
||||
|
||||
# Fetch MCP servers by name if not already fetched by ID
|
||||
if mcp_server_names:
|
||||
for server_name in mcp_server_names:
|
||||
try:
|
||||
mcp_server = await self.mcp_manager.get_mcp_server(server_name, actor)
|
||||
if mcp_server and mcp_server.id not in fetched_server_ids:
|
||||
mcp_servers.append(mcp_server)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch MCP server by name {server_name}: {e}")
|
||||
|
||||
return mcp_servers
|
||||
|
||||
def _convert_mcp_server_to_schema(self, mcp_server: MCPServer) -> MCPServerSchema:
|
||||
"""Convert MCPServer to MCPServerSchema with ID remapping and auth scrubbing"""
|
||||
try:
|
||||
mcp_file_id = self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__, allow_new=False)
|
||||
mcp_schema = MCPServerSchema.from_mcp_server(mcp_server)
|
||||
mcp_schema.id = mcp_file_id
|
||||
return mcp_schema
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert MCP server {mcp_server.id}: {e}")
|
||||
raise
|
||||
|
||||
async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema:
|
||||
"""
|
||||
Export agents and their related entities to AgentFileSchema format.
|
||||
@@ -289,6 +343,13 @@ class AgentSerializationManager:
|
||||
tool_set = self._extract_unique_tools(agent_states)
|
||||
block_set = self._extract_unique_blocks(agent_states)
|
||||
|
||||
# Extract MCP servers from tools BEFORE conversion (must be done before ID mapping)
|
||||
mcp_server_set = await self._extract_unique_mcp_servers(tool_set, actor)
|
||||
|
||||
# Map MCP server IDs before converting schemas
|
||||
for mcp_server in mcp_server_set:
|
||||
self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__)
|
||||
|
||||
# Extract sources and files from agent states BEFORE conversion (with caching)
|
||||
source_set, file_set = await self._extract_unique_sources_and_files_from_agents(agent_states, actor, files_agents_cache)
|
||||
|
||||
@@ -301,6 +362,7 @@ class AgentSerializationManager:
|
||||
block_schemas = [self._convert_block_to_schema(block) for block in block_set]
|
||||
source_schemas = [self._convert_source_to_schema(source) for source in source_set]
|
||||
file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set]
|
||||
mcp_server_schemas = [self._convert_mcp_server_to_schema(mcp_server) for mcp_server in mcp_server_set]
|
||||
|
||||
logger.info(f"Exporting {len(agent_ids)} agents to agent file format")
|
||||
|
||||
@@ -312,7 +374,7 @@ class AgentSerializationManager:
|
||||
files=file_schemas,
|
||||
sources=source_schemas,
|
||||
tools=tool_schemas,
|
||||
# mcp_servers=[], # TODO: Extract and convert MCP servers
|
||||
mcp_servers=mcp_server_schemas,
|
||||
metadata={"revision_id": await get_latest_alembic_revision()},
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -359,7 +421,20 @@ class AgentSerializationManager:
|
||||
# in-memory cache for file metadata to avoid repeated db calls
|
||||
file_metadata_cache = {} # Maps database file ID to FileMetadata
|
||||
|
||||
# 1. Create tools first (no dependencies) - using bulk upsert for efficiency
|
||||
# 1. Create MCP servers first (tools depend on them)
|
||||
if schema.mcp_servers:
|
||||
for mcp_server_schema in schema.mcp_servers:
|
||||
server_data = mcp_server_schema.model_dump(exclude={"id"})
|
||||
filtered_server_data = self._filter_dict_for_model(server_data, MCPServer)
|
||||
create_schema = MCPServer(**filtered_server_data)
|
||||
|
||||
# Note: We don't have auth info from export, so the user will need to re-configure auth.
|
||||
# TODO: @jnjpng store metadata about obfuscated metadata to surface to the user
|
||||
created_mcp_server = await self.mcp_manager.create_or_update_mcp_server(create_schema, actor)
|
||||
file_to_db_ids[mcp_server_schema.id] = created_mcp_server.id
|
||||
imported_count += 1
|
||||
|
||||
# 2. Create tools (may depend on MCP servers) - using bulk upsert for efficiency
|
||||
if schema.tools:
|
||||
# convert tool schemas to pydantic tools
|
||||
pydantic_tools = []
|
||||
@@ -559,6 +634,7 @@ class AgentSerializationManager:
|
||||
(schema.files, FileSchema.__id_prefix__),
|
||||
(schema.sources, SourceSchema.__id_prefix__),
|
||||
(schema.tools, ToolSchema.__id_prefix__),
|
||||
(schema.mcp_servers, MCPServerSchema.__id_prefix__),
|
||||
]
|
||||
|
||||
for entities, expected_prefix in entity_checks:
|
||||
@@ -601,6 +677,7 @@ class AgentSerializationManager:
|
||||
("files", schema.files),
|
||||
("sources", schema.sources),
|
||||
("tools", schema.tools),
|
||||
("mcp_servers", schema.mcp_servers),
|
||||
]
|
||||
|
||||
for entity_type, entities in entity_collections:
|
||||
@@ -705,3 +782,11 @@ class AgentSerializationManager:
|
||||
raise AgentFileImportError(f"Schema validation failed: {'; '.join(errors)}")
|
||||
|
||||
logger.info("Schema validation passed")
|
||||
|
||||
def _filter_dict_for_model(self, data: dict, model_cls):
|
||||
"""Filter a dictionary to only include keys that are in the model fields"""
|
||||
try:
|
||||
allowed = model_cls.model_fields.keys() # Pydantic v2
|
||||
except AttributeError:
|
||||
allowed = model_cls.__fields__.keys() # Pydantic v1
|
||||
return {k: v for k, v in data.items() if k in allowed}
|
||||
|
||||
@@ -98,12 +98,19 @@ class MCPManager:
|
||||
@enforce_types
|
||||
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
"""Add a tool from an MCP server to the Letta tool registry."""
|
||||
# get the MCP server ID, we should migrate to use the server_id instead of the name
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
if not mcp_server_id:
|
||||
raise ValueError(f"MCP server '{mcp_server_name}' not found")
|
||||
|
||||
mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor)
|
||||
|
||||
for mcp_tool in mcp_tools:
|
||||
if mcp_tool.name == mcp_tool_name:
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return await self.tool_manager.create_mcp_tool_async(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor)
|
||||
return await self.tool_manager.create_mcp_tool_async(
|
||||
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
|
||||
)
|
||||
|
||||
# failed to add - handle error?
|
||||
return None
|
||||
@@ -194,14 +201,7 @@ class MCPManager:
|
||||
"""Update an MCP server by its name."""
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
|
||||
if not mcp_server_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": f"MCP server {mcp_server_name} not found",
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
raise ValueError(f"MCP server {mcp_server_name} not found")
|
||||
return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor)
|
||||
|
||||
@enforce_types
|
||||
@@ -223,6 +223,18 @@ class MCPManager:
|
||||
# Convert the SQLAlchemy Tool object to PydanticTool
|
||||
return mcp_server.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def get_mcp_servers_by_ids(self, mcp_server_ids: List[str], actor: PydanticUser) -> List[MCPServer]:
|
||||
"""Fetch multiple MCP servers by their IDs in a single query."""
|
||||
if not mcp_server_ids:
|
||||
return []
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
mcp_servers = await MCPServerModel.list_async(
|
||||
db_session=session, organization_id=actor.organization_id, id=mcp_server_ids # This will use the IN operator
|
||||
)
|
||||
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
|
||||
|
||||
@enforce_types
|
||||
async def get_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
"""Get a tool by name."""
|
||||
@@ -230,14 +242,7 @@ class MCPManager:
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
|
||||
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
if not mcp_server:
|
||||
raise HTTPException(
|
||||
status_code=404, # Not Found
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": f"MCP server {mcp_server_name} not found",
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
raise ValueError(f"MCP server {mcp_server_name} not found")
|
||||
return mcp_server.to_pydantic()
|
||||
|
||||
# @enforce_types
|
||||
|
||||
@@ -106,8 +106,10 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}}
|
||||
def create_or_update_mcp_tool(
|
||||
self, tool_create: ToolCreate, mcp_server_name: str, mcp_server_id: str, actor: PydanticUser
|
||||
) -> PydanticTool:
|
||||
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name, "server_id": mcp_server_id}}
|
||||
return self.create_or_update_tool(
|
||||
PydanticTool(
|
||||
tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], metadata_=metadata, **tool_create.model_dump()
|
||||
@@ -116,8 +118,10 @@ class ToolManager:
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_tool_async(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}}
|
||||
async def create_mcp_tool_async(
|
||||
self, tool_create: ToolCreate, mcp_server_name: str, mcp_server_id: str, actor: PydanticUser
|
||||
) -> PydanticTool:
|
||||
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name, "server_id": mcp_server_id}}
|
||||
return await self.create_or_update_tool_async(
|
||||
PydanticTool(
|
||||
tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], metadata_=metadata, **tool_create.model_dump()
|
||||
|
||||
@@ -1 +1 @@
|
||||
{}
|
||||
{}
|
||||
@@ -283,6 +283,68 @@ async def agent_with_files(server: SyncServer, default_user, test_block, weather
|
||||
return (agent_state.id, test_source.id, test_file.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_mcp_server(server: SyncServer, default_user):
|
||||
"""Fixture to create and return a test MCP server."""
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType
|
||||
|
||||
mcp_server_data = MCPServer(
|
||||
server_name="test_mcp_server",
|
||||
server_type=MCPServerType.SSE,
|
||||
server_url="http://test-mcp-server.com",
|
||||
token="test-token-12345", # This should be excluded during export
|
||||
custom_headers={"X-API-Key": "secret-key"}, # This should be excluded during export
|
||||
)
|
||||
mcp_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data, default_user)
|
||||
yield mcp_server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_tool(server: SyncServer, default_user, test_mcp_server):
|
||||
"""Fixture to create and return an MCP tool."""
|
||||
from letta.schemas.tool import MCPTool, ToolCreate
|
||||
|
||||
# Create a mock MCP tool
|
||||
mcp_tool_data = MCPTool(
|
||||
name="test_mcp_tool",
|
||||
description="Test MCP tool for serialization",
|
||||
inputSchema={"type": "object", "properties": {"input": {"type": "string"}}},
|
||||
)
|
||||
tool_create = ToolCreate.from_mcp(test_mcp_server.server_name, mcp_tool_data)
|
||||
|
||||
# Create tool with MCP metadata
|
||||
mcp_tool = await server.tool_manager.create_mcp_tool_async(tool_create, test_mcp_server.server_name, test_mcp_server.id, default_user)
|
||||
yield mcp_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def agent_with_mcp_tools(server: SyncServer, default_user, test_block, mcp_tool, test_mcp_server):
|
||||
"""Fixture to create and return an agent with MCP tools."""
|
||||
memory_blocks = [
|
||||
CreateBlock(label="human", value="User is a test user"),
|
||||
CreateBlock(label="persona", value="I am a helpful test assistant"),
|
||||
]
|
||||
|
||||
create_agent_request = CreateAgent(
|
||||
name="test_agent_mcp",
|
||||
system="You are a helpful assistant with MCP tools.",
|
||||
memory_blocks=memory_blocks,
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
block_ids=[test_block.id],
|
||||
tool_ids=[mcp_tool.id],
|
||||
tags=["test", "mcp", "export"],
|
||||
description="Test agent with MCP tools for serialization testing",
|
||||
)
|
||||
|
||||
agent_state = await server.agent_manager.create_agent_async(
|
||||
agent_create=create_agent_request,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
return agent_state
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Helper Functions
|
||||
# ------------------------------
|
||||
@@ -1258,7 +1320,7 @@ class TestAgentFileValidation:
|
||||
files=[],
|
||||
sources=[],
|
||||
tools=[],
|
||||
# mcp_servers=[],
|
||||
mcp_servers=[],
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
@@ -1302,7 +1364,7 @@ class TestAgentFileValidation:
|
||||
json_schema={"name": "test_tool", "parameters": {"type": "object", "properties": {}}},
|
||||
)
|
||||
],
|
||||
# mcp_servers=[],
|
||||
mcp_servers=[],
|
||||
)
|
||||
|
||||
assert validate_id_format(valid_schema)
|
||||
@@ -1316,11 +1378,282 @@ class TestAgentFileValidation:
|
||||
files=[],
|
||||
sources=[],
|
||||
tools=[],
|
||||
# mcp_servers=[],
|
||||
mcp_servers=[],
|
||||
)
|
||||
|
||||
assert not validate_id_format(invalid_schema)
|
||||
|
||||
|
||||
class TestMCPServerSerialization:
|
||||
"""Tests for MCP server export/import functionality."""
|
||||
|
||||
async def test_mcp_server_export(self, agent_serialization_manager, agent_with_mcp_tools, default_user):
|
||||
"""Test that MCP servers are exported correctly."""
|
||||
agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user)
|
||||
|
||||
# Verify MCP server is included
|
||||
assert len(agent_file.mcp_servers) == 1
|
||||
mcp_server = agent_file.mcp_servers[0]
|
||||
|
||||
# Verify server details
|
||||
assert mcp_server.server_name == "test_mcp_server"
|
||||
assert mcp_server.server_url == "http://test-mcp-server.com"
|
||||
assert mcp_server.server_type == "sse"
|
||||
|
||||
# Verify auth fields are excluded
|
||||
assert not hasattr(mcp_server, "token")
|
||||
assert not hasattr(mcp_server, "custom_headers")
|
||||
|
||||
# Verify ID format
|
||||
assert _validate_entity_id(mcp_server.id, "mcp_server")
|
||||
|
||||
async def test_mcp_server_auth_scrubbing(self, server, agent_serialization_manager, default_user):
|
||||
"""Test that authentication information is scrubbed during export."""
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType
|
||||
|
||||
# Create MCP server with auth info
|
||||
mcp_server_data_stdio = MCPServer(
|
||||
server_name="auth_test_server",
|
||||
server_type=MCPServerType.STDIO,
|
||||
# token="super-secret-token",
|
||||
# custom_headers={"Authorization": "Bearer secret-key", "X-Custom": "custom-value"},
|
||||
stdio_config={
|
||||
"server_name": "auth_test_server",
|
||||
"command": "test-command",
|
||||
"args": ["arg1", "arg2"],
|
||||
"env": {"ENV_VAR": "value"},
|
||||
},
|
||||
)
|
||||
mcp_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data_stdio, default_user)
|
||||
|
||||
mcp_server_data_http = MCPServer(
|
||||
server_name="auth_test_server_http",
|
||||
server_type=MCPServerType.STREAMABLE_HTTP,
|
||||
server_url="http://auth_test_server_http.com",
|
||||
token="super-secret-token",
|
||||
custom_headers={"X-Custom": "custom-value"},
|
||||
)
|
||||
mcp_server_http = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data_http, default_user)
|
||||
# Create tool from MCP server
|
||||
from letta.schemas.tool import MCPTool, ToolCreate
|
||||
|
||||
mcp_tool_data = MCPTool(
|
||||
name="auth_test_tool_stdio",
|
||||
description="Tool with auth",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
tool_create_stdio = ToolCreate.from_mcp(mcp_server.server_name, mcp_tool_data)
|
||||
|
||||
mcp_tool_data_http = MCPTool(
|
||||
name="auth_test_tool_http",
|
||||
description="Tool with auth",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
|
||||
tool_create_http = ToolCreate.from_mcp(mcp_server_http.server_name, mcp_tool_data_http)
|
||||
|
||||
mcp_tool = await server.tool_manager.create_mcp_tool_async(tool_create_stdio, mcp_server.server_name, mcp_server.id, default_user)
|
||||
mcp_tool_http = await server.tool_manager.create_mcp_tool_async(
|
||||
tool_create_http, mcp_server_http.server_name, mcp_server_http.id, default_user
|
||||
)
|
||||
|
||||
# Create agent with the tool
|
||||
from letta.schemas.agent import CreateAgent
|
||||
|
||||
create_agent_request = CreateAgent(
|
||||
name="auth_test_agent",
|
||||
tool_ids=[mcp_tool.id, mcp_tool_http.id],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
)
|
||||
agent = await server.agent_manager.create_agent_async(create_agent_request, default_user)
|
||||
|
||||
# Export
|
||||
agent_file = await agent_serialization_manager.export([agent.id], default_user)
|
||||
|
||||
for server in agent_file.mcp_servers:
|
||||
if server.server_name == "auth_test_server":
|
||||
exported_server_stdio = server
|
||||
elif server.server_name == "auth_test_server_http":
|
||||
exported_server_http = server
|
||||
|
||||
# Verify env variables in stdio server are excluded (typically used for auth)
|
||||
assert exported_server_stdio.id != mcp_server.id
|
||||
assert exported_server_stdio.server_name == "auth_test_server"
|
||||
assert exported_server_stdio.stdio_config == {
|
||||
"server_name": "auth_test_server",
|
||||
"type": "stdio",
|
||||
"command": "test-command",
|
||||
"args": ["arg1", "arg2"],
|
||||
} # Non-auth config preserved
|
||||
assert exported_server_stdio.server_type == "stdio"
|
||||
|
||||
# Verify token and custom headers are excluded from export for http server
|
||||
assert exported_server_http.id != mcp_server_http.id
|
||||
assert exported_server_http.server_name == "auth_test_server_http"
|
||||
assert exported_server_http.server_type == "streamable_http"
|
||||
assert exported_server_http.server_url == "http://auth_test_server_http.com"
|
||||
assert not hasattr(exported_server_http, "token")
|
||||
assert not hasattr(exported_server_http, "custom_headers")
|
||||
|
||||
async def test_mcp_tool_metadata_with_server_id(self, agent_serialization_manager, agent_with_mcp_tools, default_user):
|
||||
"""Test that MCP tools have server_id in metadata."""
|
||||
agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user)
|
||||
|
||||
# Find the MCP tool
|
||||
mcp_tool = next((t for t in agent_file.tools if t.name == "test_mcp_tool"), None)
|
||||
assert mcp_tool is not None
|
||||
|
||||
# Verify metadata contains server info
|
||||
assert mcp_tool.metadata_ is not None
|
||||
assert "mcp" in mcp_tool.metadata_
|
||||
assert "server_name" in mcp_tool.metadata_["mcp"]
|
||||
assert "server_id" in mcp_tool.metadata_["mcp"]
|
||||
assert mcp_tool.metadata_["mcp"]["server_name"] == "test_mcp_server"
|
||||
|
||||
# Verify tag format
|
||||
assert any(tag.startswith("mcp:") for tag in mcp_tool.tags)
|
||||
|
||||
async def test_mcp_server_import(self, agent_serialization_manager, agent_with_mcp_tools, default_user, other_user):
|
||||
"""Test importing agents with MCP servers."""
|
||||
# Export from default user
|
||||
agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user)
|
||||
|
||||
# Import to other user
|
||||
result = await agent_serialization_manager.import_file(agent_file, other_user)
|
||||
|
||||
assert result.success
|
||||
|
||||
# Verify MCP server was imported
|
||||
mcp_server_id = next((db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("mcp_server-")), None)
|
||||
assert mcp_server_id is not None
|
||||
|
||||
async def test_multiple_mcp_servers_export(self, server, agent_serialization_manager, default_user):
|
||||
"""Test exporting multiple MCP servers from different agents."""
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType
|
||||
|
||||
# Create two MCP servers
|
||||
mcp_server1 = await server.mcp_manager.create_or_update_mcp_server(
|
||||
MCPServer(
|
||||
server_name="mcp1",
|
||||
server_type=MCPServerType.STREAMABLE_HTTP,
|
||||
server_url="http://mcp1.com",
|
||||
token="super-secret-token",
|
||||
custom_headers={"X-Custom": "custom-value"},
|
||||
),
|
||||
default_user,
|
||||
)
|
||||
mcp_server2 = await server.mcp_manager.create_or_update_mcp_server(
|
||||
MCPServer(
|
||||
server_name="mcp2",
|
||||
server_type=MCPServerType.STDIO,
|
||||
stdio_config={
|
||||
"server_name": "mcp2",
|
||||
"command": "mcp2-cmd",
|
||||
"args": ["arg1", "arg2"],
|
||||
},
|
||||
),
|
||||
default_user,
|
||||
)
|
||||
|
||||
# Create tools from each server
|
||||
from letta.schemas.tool import MCPTool, ToolCreate
|
||||
|
||||
tool1 = await server.tool_manager.create_mcp_tool_async(
|
||||
ToolCreate.from_mcp(
|
||||
"mcp1",
|
||||
MCPTool(name="tool1", description="Tool 1", inputSchema={"type": "object", "properties": {}}),
|
||||
),
|
||||
"mcp1",
|
||||
mcp_server1.id,
|
||||
default_user,
|
||||
)
|
||||
tool2 = await server.tool_manager.create_mcp_tool_async(
|
||||
ToolCreate.from_mcp(
|
||||
"mcp2",
|
||||
MCPTool(name="tool2", description="Tool 2", inputSchema={"type": "object", "properties": {}}),
|
||||
),
|
||||
"mcp2",
|
||||
mcp_server2.id,
|
||||
default_user,
|
||||
)
|
||||
|
||||
# Create agents with different MCP tools
|
||||
from letta.schemas.agent import CreateAgent
|
||||
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
CreateAgent(
|
||||
name="agent1",
|
||||
tool_ids=[tool1.id],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
default_user,
|
||||
)
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
CreateAgent(
|
||||
name="agent2",
|
||||
tool_ids=[tool2.id],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
default_user,
|
||||
)
|
||||
|
||||
# Export both agents
|
||||
agent_file = await agent_serialization_manager.export([agent1.id, agent2.id], default_user)
|
||||
|
||||
# Verify both MCP servers are included
|
||||
assert len(agent_file.mcp_servers) == 2
|
||||
|
||||
# Verify server types
|
||||
streamable_http_server = next(s for s in agent_file.mcp_servers if s.server_name == "mcp1")
|
||||
stdio_server = next(s for s in agent_file.mcp_servers if s.server_name == "mcp2")
|
||||
|
||||
assert streamable_http_server.server_name == "mcp1"
|
||||
assert streamable_http_server.server_type == "streamable_http"
|
||||
assert streamable_http_server.server_url == "http://mcp1.com"
|
||||
|
||||
assert stdio_server.server_name == "mcp2"
|
||||
assert stdio_server.server_type == "stdio"
|
||||
assert stdio_server.stdio_config == {
|
||||
"server_name": "mcp2",
|
||||
"type": "stdio",
|
||||
"command": "mcp2-cmd",
|
||||
"args": ["arg1", "arg2"],
|
||||
}
|
||||
|
||||
async def test_mcp_server_deduplication(self, server, agent_serialization_manager, default_user, test_mcp_server, mcp_tool):
|
||||
"""Test that shared MCP servers are deduplicated during export."""
|
||||
# Create two agents using the same MCP tool
|
||||
from letta.schemas.agent import CreateAgent
|
||||
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
CreateAgent(
|
||||
name="agent_dup1",
|
||||
tool_ids=[mcp_tool.id],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
default_user,
|
||||
)
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
CreateAgent(
|
||||
name="agent_dup2",
|
||||
tool_ids=[mcp_tool.id],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
default_user,
|
||||
)
|
||||
|
||||
# Export both agents
|
||||
agent_file = await agent_serialization_manager.export([agent1.id, agent2.id], default_user)
|
||||
|
||||
# Verify only one MCP server is exported
|
||||
assert len(agent_file.mcp_servers) == 1
|
||||
assert agent_file.mcp_servers[0].server_name == "test_mcp_server"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
@@ -269,8 +269,11 @@ def mcp_tool(server, default_user):
|
||||
},
|
||||
)
|
||||
mcp_server_name = "test"
|
||||
mcp_server_id = "test-server-id" # Mock server ID for testing
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
tool = server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user)
|
||||
tool = server.tool_manager.create_or_update_mcp_tool(
|
||||
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=default_user
|
||||
)
|
||||
yield tool
|
||||
|
||||
|
||||
@@ -3474,6 +3477,7 @@ def test_create_mcp_tool(server: SyncServer, mcp_tool, default_user, default_org
|
||||
assert mcp_tool.created_by_id == default_user.id
|
||||
assert mcp_tool.tool_type == ToolType.EXTERNAL_MCP
|
||||
assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == "test"
|
||||
assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_id"] == "test-server-id"
|
||||
|
||||
|
||||
# Test should work with both SQLite and PostgreSQL
|
||||
@@ -8552,6 +8556,83 @@ async def test_create_mcp_server(server, default_user, event_loop):
|
||||
print("TAGS", tool.tags)
|
||||
|
||||
|
||||
async def test_get_mcp_servers_by_ids(server, default_user, event_loop):
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
return
|
||||
|
||||
# Create multiple MCP servers for testing
|
||||
servers_data = [
|
||||
{
|
||||
"name": "test_server_1",
|
||||
"config": StdioServerConfig(
|
||||
server_name="test_server_1", type=MCPServerType.STDIO, command="echo 'test1'", args=["arg1"], env={"ENV1": "value1"}
|
||||
),
|
||||
"type": MCPServerType.STDIO,
|
||||
},
|
||||
{
|
||||
"name": "test_server_2",
|
||||
"config": SSEServerConfig(server_name="test_server_2", server_url="https://test2.example.com/sse"),
|
||||
"type": MCPServerType.SSE,
|
||||
},
|
||||
{
|
||||
"name": "test_server_3",
|
||||
"config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/sse"),
|
||||
"type": MCPServerType.SSE,
|
||||
},
|
||||
]
|
||||
|
||||
created_servers = []
|
||||
for server_data in servers_data:
|
||||
if server_data["type"] == MCPServerType.STDIO:
|
||||
mcp_server = MCPServer(server_name=server_data["name"], server_type=server_data["type"], stdio_config=server_data["config"])
|
||||
else:
|
||||
mcp_server = MCPServer(
|
||||
server_name=server_data["name"], server_type=server_data["type"], server_url=server_data["config"].server_url
|
||||
)
|
||||
|
||||
created = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
|
||||
created_servers.append(created)
|
||||
|
||||
# Test fetching multiple servers by IDs
|
||||
server_ids = [s.id for s in created_servers]
|
||||
fetched_servers = await server.mcp_manager.get_mcp_servers_by_ids(server_ids, actor=default_user)
|
||||
|
||||
assert len(fetched_servers) == len(created_servers)
|
||||
fetched_ids = {s.id for s in fetched_servers}
|
||||
expected_ids = {s.id for s in created_servers}
|
||||
assert fetched_ids == expected_ids
|
||||
|
||||
# Test fetching subset of servers
|
||||
subset_ids = server_ids[:2]
|
||||
subset_servers = await server.mcp_manager.get_mcp_servers_by_ids(subset_ids, actor=default_user)
|
||||
assert len(subset_servers) == 2
|
||||
assert all(s.id in subset_ids for s in subset_servers)
|
||||
|
||||
# Test fetching with empty list
|
||||
empty_result = await server.mcp_manager.get_mcp_servers_by_ids([], actor=default_user)
|
||||
assert empty_result == []
|
||||
|
||||
# Test fetching with non-existent ID mixed with valid IDs
|
||||
mixed_ids = [server_ids[0], "non-existent-id", server_ids[1]]
|
||||
mixed_result = await server.mcp_manager.get_mcp_servers_by_ids(mixed_ids, actor=default_user)
|
||||
# Should only return the existing servers
|
||||
assert len(mixed_result) == 2
|
||||
assert all(s.id in server_ids for s in mixed_result)
|
||||
|
||||
# Test that servers from different organizations are not returned
|
||||
# This would require creating another user/org, but for now we'll just verify
|
||||
# that the function respects the actor's organization
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=default_user)
|
||||
all_server_ids = [s.id for s in all_servers]
|
||||
bulk_fetched = await server.mcp_manager.get_mcp_servers_by_ids(all_server_ids, actor=default_user)
|
||||
|
||||
# All fetched servers should belong to the same organization
|
||||
assert all(s.organization_id == default_user.organization_id for s in bulk_fetched)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# FileAgent Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user