diff --git a/letta/schemas/agent_file.py b/letta/schemas/agent_file.py index 26b45cb0..c0e4fa47 100644 --- a/letta/schemas/agent_file.py +++ b/letta/schemas/agent_file.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field @@ -8,6 +8,7 @@ from letta.schemas.block import Block, CreateBlock from letta.schemas.enums import MessageRole from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase from letta.schemas.group import GroupCreate +from letta.schemas.mcp import MCPServer from letta.schemas.message import Message, MessageCreate from letta.schemas.source import Source, SourceCreate from letta.schemas.tool import Tool @@ -264,9 +265,34 @@ class ToolSchema(Tool): return cls(**tool.model_dump()) -# class MCPServerSchema(RegisterMCPServer): -# """MCP Server with human-readable ID for agent file""" -# id: str = Field(..., description="Human-readable identifier for this MCP server in the file") +class MCPServerSchema(BaseModel): + """MCP server schema for agent files with remapped ID.""" + + __id_prefix__ = "mcp_server" + + id: str = Field(..., description="Human-readable MCP server ID") + server_type: str + server_name: str + server_url: Optional[str] = None + stdio_config: Optional[Dict[str, Any]] = None + metadata_: Optional[Dict[str, Any]] = None + + @classmethod + def from_mcp_server(cls, mcp_server: MCPServer) -> "MCPServerSchema": + """Convert MCPServer to MCPServerSchema (excluding auth fields).""" + return cls( + id=mcp_server.id, # remapped by serialization manager + server_type=mcp_server.server_type, + server_name=mcp_server.server_name, + server_url=mcp_server.server_url, + # exclude token, custom_headers, and the env field in stdio_config that may contain authentication credentials + stdio_config=cls.strip_env_from_stdio_config(mcp_server.stdio_config.model_dump()) if mcp_server.stdio_config else None, + metadata_=mcp_server.metadata_, + ) + + def strip_env_from_stdio_config(stdio_config: Dict[str, Any]) -> Dict[str, Any]: + """Strip out the env field from the stdio config.""" + return {k: v for k, v in stdio_config.items() if k != "env"} class AgentFileSchema(BaseModel): @@ -278,7 +304,7 @@ class AgentFileSchema(BaseModel): files: List[FileSchema] = Field(..., description="List of files in this agent file") sources: List[SourceSchema] = Field(..., description="List of sources in this agent file") tools: List[ToolSchema] = Field(..., description="List of tools in this agent file") - # mcp_servers: List[MCPServerSchema] = Field(..., description="List of MCP servers in this agent file") + mcp_servers: List[MCPServerSchema] = Field(..., description="List of MCP servers in this agent file") metadata: Dict[str, str] = Field( default_factory=dict, description="Metadata for this agent file, including revision_id and other export information." ) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 7e840a8c..ad171613 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -475,7 +475,11 @@ async def add_mcp_tool( ) tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) - return await server.tool_manager.create_mcp_tool_async(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor) + # For config-based servers, use the server name as ID since they don't have database IDs + mcp_server_id = mcp_server_name + return await server.tool_manager.create_mcp_tool_async( + tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor + ) else: return await server.mcp_manager.add_tool_from_mcp_server(mcp_server_name=mcp_server_name, mcp_tool_name=mcp_tool_name, actor=actor) diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index 632fd77b..166d1608 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -13,6 +13,7 @@ from letta.schemas.agent_file import ( FileSchema, GroupSchema, ImportResult, + MCPServerSchema, MessageSchema, SourceSchema, ToolSchema, @@ -20,6 +21,7 @@ from letta.schemas.agent_file import ( from letta.schemas.block import Block from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata +from letta.schemas.mcp import MCPServer from letta.schemas.message import Message from letta.schemas.source import Source from letta.schemas.tool import Tool @@ -92,7 +94,7 @@ class AgentSerializationManager: ToolSchema.__id_prefix__: 0, MessageSchema.__id_prefix__: 0, FileAgentSchema.__id_prefix__: 0, - # MCPServerSchema.__id_prefix__: 0, + MCPServerSchema.__id_prefix__: 0, } def _reset_state(self): @@ -258,6 +260,58 @@ class AgentSerializationManager: file_schema.source_id = self._map_db_to_file_id(file_metadata.source_id, SourceSchema.__id_prefix__, allow_new=False) return file_schema + async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List: + """Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name.""" + from letta.constants import MCP_TOOL_TAG_NAME_PREFIX + + mcp_server_ids = set() + mcp_server_names = set() + for tool in tools: + # Check if tool has MCP metadata + if tool.metadata_ and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_: + mcp_metadata = tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX] + # TODO: @jnjpng clean this up once we fully migrate to server_id being the main identifier + if "server_id" in mcp_metadata: + mcp_server_ids.add(mcp_metadata["server_id"]) + elif "server_name" in mcp_metadata: + mcp_server_names.add(mcp_metadata["server_name"]) + + # Fetch MCP servers by ID + mcp_servers = [] + fetched_server_ids = set() + if mcp_server_ids: + for server_id in mcp_server_ids: + try: + mcp_server = await self.mcp_manager.get_mcp_server_by_id_async(server_id, actor) + if mcp_server: + mcp_servers.append(mcp_server) + fetched_server_ids.add(server_id) + except Exception as e: + logger.warning(f"Failed to fetch MCP server {server_id}: {e}") + + # Fetch MCP servers by name if not already fetched by ID + if mcp_server_names: + for server_name in mcp_server_names: + try: + mcp_server = await self.mcp_manager.get_mcp_server(server_name, actor) + if mcp_server and mcp_server.id not in fetched_server_ids: + mcp_servers.append(mcp_server) + except Exception as e: + logger.warning(f"Failed to fetch MCP server by name {server_name}: {e}") + + return mcp_servers + + def _convert_mcp_server_to_schema(self, mcp_server: MCPServer) -> MCPServerSchema: + """Convert MCPServer to MCPServerSchema with ID remapping and auth scrubbing""" + try: + mcp_file_id = self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__, allow_new=False) + mcp_schema = MCPServerSchema.from_mcp_server(mcp_server) + mcp_schema.id = mcp_file_id + return mcp_schema + except Exception as e: + logger.error(f"Failed to convert MCP server {mcp_server.id}: {e}") + raise + async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema: """ Export agents and their related entities to AgentFileSchema format. @@ -289,6 +343,13 @@ class AgentSerializationManager: tool_set = self._extract_unique_tools(agent_states) block_set = self._extract_unique_blocks(agent_states) + # Extract MCP servers from tools BEFORE conversion (must be done before ID mapping) + mcp_server_set = await self._extract_unique_mcp_servers(tool_set, actor) + + # Map MCP server IDs before converting schemas + for mcp_server in mcp_server_set: + self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__) + # Extract sources and files from agent states BEFORE conversion (with caching) source_set, file_set = await self._extract_unique_sources_and_files_from_agents(agent_states, actor, files_agents_cache) @@ -301,6 +362,7 @@ class AgentSerializationManager: block_schemas = [self._convert_block_to_schema(block) for block in block_set] source_schemas = [self._convert_source_to_schema(source) for source in source_set] file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set] + mcp_server_schemas = [self._convert_mcp_server_to_schema(mcp_server) for mcp_server in mcp_server_set] logger.info(f"Exporting {len(agent_ids)} agents to agent file format") @@ -312,7 +374,7 @@ class AgentSerializationManager: files=file_schemas, sources=source_schemas, tools=tool_schemas, - # mcp_servers=[], # TODO: Extract and convert MCP servers + mcp_servers=mcp_server_schemas, metadata={"revision_id": await get_latest_alembic_revision()}, created_at=datetime.now(timezone.utc), ) @@ -359,7 +421,20 @@ class AgentSerializationManager: # in-memory cache for file metadata to avoid repeated db calls file_metadata_cache = {} # Maps database file ID to FileMetadata - # 1. Create tools first (no dependencies) - using bulk upsert for efficiency + # 1. Create MCP servers first (tools depend on them) + if schema.mcp_servers: + for mcp_server_schema in schema.mcp_servers: + server_data = mcp_server_schema.model_dump(exclude={"id"}) + filtered_server_data = self._filter_dict_for_model(server_data, MCPServer) + create_schema = MCPServer(**filtered_server_data) + + # Note: We don't have auth info from export, so the user will need to re-configure auth. + # TODO: @jnjpng store metadata about obfuscated metadata to surface to the user + created_mcp_server = await self.mcp_manager.create_or_update_mcp_server(create_schema, actor) + file_to_db_ids[mcp_server_schema.id] = created_mcp_server.id + imported_count += 1 + + # 2. Create tools (may depend on MCP servers) - using bulk upsert for efficiency if schema.tools: # convert tool schemas to pydantic tools pydantic_tools = [] @@ -559,6 +634,7 @@ class AgentSerializationManager: (schema.files, FileSchema.__id_prefix__), (schema.sources, SourceSchema.__id_prefix__), (schema.tools, ToolSchema.__id_prefix__), + (schema.mcp_servers, MCPServerSchema.__id_prefix__), ] for entities, expected_prefix in entity_checks: @@ -601,6 +677,7 @@ class AgentSerializationManager: ("files", schema.files), ("sources", schema.sources), ("tools", schema.tools), + ("mcp_servers", schema.mcp_servers), ] for entity_type, entities in entity_collections: @@ -705,3 +782,11 @@ class AgentSerializationManager: raise AgentFileImportError(f"Schema validation failed: {'; '.join(errors)}") logger.info("Schema validation passed") + + def _filter_dict_for_model(self, data: dict, model_cls): + """Filter a dictionary to only include keys that are in the model fields""" + try: + allowed = model_cls.model_fields.keys() # Pydantic v2 + except AttributeError: + allowed = model_cls.__fields__.keys() # Pydantic v1 + return {k: v for k, v in data.items() if k in allowed} diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index b542e22b..77cd2a9a 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -98,12 +98,19 @@ class MCPManager: @enforce_types async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool: """Add a tool from an MCP server to the Letta tool registry.""" + # get the MCP server ID, we should migrate to use the server_id instead of the name + mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor) + if not mcp_server_id: + raise ValueError(f"MCP server '{mcp_server_name}' not found") + mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor) for mcp_tool in mcp_tools: if mcp_tool.name == mcp_tool_name: tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) - return await self.tool_manager.create_mcp_tool_async(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor) + return await self.tool_manager.create_mcp_tool_async( + tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor + ) # failed to add - handle error? return None @@ -194,14 +201,7 @@ class MCPManager: """Update an MCP server by its name.""" mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) if not mcp_server_id: - raise HTTPException( - status_code=404, - detail={ - "code": "MCPServerNotFoundError", - "message": f"MCP server {mcp_server_name} not found", - "mcp_server_name": mcp_server_name, - }, - ) + raise ValueError(f"MCP server {mcp_server_name} not found") return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor) @enforce_types @@ -223,6 +223,18 @@ class MCPManager: # Convert the SQLAlchemy Tool object to PydanticTool return mcp_server.to_pydantic() + @enforce_types + async def get_mcp_servers_by_ids(self, mcp_server_ids: List[str], actor: PydanticUser) -> List[MCPServer]: + """Fetch multiple MCP servers by their IDs in a single query.""" + if not mcp_server_ids: + return [] + + async with db_registry.async_session() as session: + mcp_servers = await MCPServerModel.list_async( + db_session=session, organization_id=actor.organization_id, id=mcp_server_ids # This will use the IN operator + ) + return [mcp_server.to_pydantic() for mcp_server in mcp_servers] + @enforce_types async def get_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> PydanticTool: """Get a tool by name.""" @@ -230,14 +242,7 @@ class MCPManager: mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) if not mcp_server: - raise HTTPException( - status_code=404, # Not Found - detail={ - "code": "MCPServerNotFoundError", - "message": f"MCP server {mcp_server_name} not found", - "mcp_server_name": mcp_server_name, - }, - ) + raise ValueError(f"MCP server {mcp_server_name} not found") return mcp_server.to_pydantic() # @enforce_types diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 1885fa24..29d2723a 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -106,8 +106,10 @@ class ToolManager: @enforce_types @trace_method - def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool: - metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}} + def create_or_update_mcp_tool( + self, tool_create: ToolCreate, mcp_server_name: str, mcp_server_id: str, actor: PydanticUser + ) -> PydanticTool: + metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name, "server_id": mcp_server_id}} return self.create_or_update_tool( PydanticTool( tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], metadata_=metadata, **tool_create.model_dump() @@ -116,8 +118,10 @@ class ToolManager: ) @enforce_types - async def create_mcp_tool_async(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool: - metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}} + async def create_mcp_tool_async( + self, tool_create: ToolCreate, mcp_server_name: str, mcp_server_id: str, actor: PydanticUser + ) -> PydanticTool: + metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name, "server_id": mcp_server_id}} return await self.create_or_update_tool_async( PydanticTool( tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], metadata_=metadata, **tool_create.model_dump() diff --git a/tests/mcp/mcp_config.json b/tests/mcp/mcp_config.json index 0967ef42..9e26dfee 100644 --- a/tests/mcp/mcp_config.json +++ b/tests/mcp/mcp_config.json @@ -1 +1 @@ -{} +{} \ No newline at end of file diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 5d0b8570..10f5480c 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -283,6 +283,68 @@ async def agent_with_files(server: SyncServer, default_user, test_block, weather return (agent_state.id, test_source.id, test_file.id) +@pytest.fixture +async def test_mcp_server(server: SyncServer, default_user): + """Fixture to create and return a test MCP server.""" + from letta.schemas.mcp import MCPServer, MCPServerType + + mcp_server_data = MCPServer( + server_name="test_mcp_server", + server_type=MCPServerType.SSE, + server_url="http://test-mcp-server.com", + token="test-token-12345", # This should be excluded during export + custom_headers={"X-API-Key": "secret-key"}, # This should be excluded during export + ) + mcp_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data, default_user) + yield mcp_server + + +@pytest.fixture +async def mcp_tool(server: SyncServer, default_user, test_mcp_server): + """Fixture to create and return an MCP tool.""" + from letta.schemas.tool import MCPTool, ToolCreate + + # Create a mock MCP tool + mcp_tool_data = MCPTool( + name="test_mcp_tool", + description="Test MCP tool for serialization", + inputSchema={"type": "object", "properties": {"input": {"type": "string"}}}, + ) + tool_create = ToolCreate.from_mcp(test_mcp_server.server_name, mcp_tool_data) + + # Create tool with MCP metadata + mcp_tool = await server.tool_manager.create_mcp_tool_async(tool_create, test_mcp_server.server_name, test_mcp_server.id, default_user) + yield mcp_tool + + +@pytest.fixture +async def agent_with_mcp_tools(server: SyncServer, default_user, test_block, mcp_tool, test_mcp_server): + """Fixture to create and return an agent with MCP tools.""" + memory_blocks = [ + CreateBlock(label="human", value="User is a test user"), + CreateBlock(label="persona", value="I am a helpful test assistant"), + ] + + create_agent_request = CreateAgent( + name="test_agent_mcp", + system="You are a helpful assistant with MCP tools.", + memory_blocks=memory_blocks, + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + block_ids=[test_block.id], + tool_ids=[mcp_tool.id], + tags=["test", "mcp", "export"], + description="Test agent with MCP tools for serialization testing", + ) + + agent_state = await server.agent_manager.create_agent_async( + agent_create=create_agent_request, + actor=default_user, + ) + + return agent_state + + # ------------------------------ # Helper Functions # ------------------------------ @@ -1258,7 +1320,7 @@ class TestAgentFileValidation: files=[], sources=[], tools=[], - # mcp_servers=[], + mcp_servers=[], ) # Should not raise @@ -1302,7 +1364,7 @@ class TestAgentFileValidation: json_schema={"name": "test_tool", "parameters": {"type": "object", "properties": {}}}, ) ], - # mcp_servers=[], + mcp_servers=[], ) assert validate_id_format(valid_schema) @@ -1316,11 +1378,282 @@ class TestAgentFileValidation: files=[], sources=[], tools=[], - # mcp_servers=[], + mcp_servers=[], ) assert not validate_id_format(invalid_schema) +class TestMCPServerSerialization: + """Tests for MCP server export/import functionality.""" + + async def test_mcp_server_export(self, agent_serialization_manager, agent_with_mcp_tools, default_user): + """Test that MCP servers are exported correctly.""" + agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user) + + # Verify MCP server is included + assert len(agent_file.mcp_servers) == 1 + mcp_server = agent_file.mcp_servers[0] + + # Verify server details + assert mcp_server.server_name == "test_mcp_server" + assert mcp_server.server_url == "http://test-mcp-server.com" + assert mcp_server.server_type == "sse" + + # Verify auth fields are excluded + assert not hasattr(mcp_server, "token") + assert not hasattr(mcp_server, "custom_headers") + + # Verify ID format + assert _validate_entity_id(mcp_server.id, "mcp_server") + + async def test_mcp_server_auth_scrubbing(self, server, agent_serialization_manager, default_user): + """Test that authentication information is scrubbed during export.""" + from letta.schemas.mcp import MCPServer, MCPServerType + + # Create MCP server with auth info + mcp_server_data_stdio = MCPServer( + server_name="auth_test_server", + server_type=MCPServerType.STDIO, + # token="super-secret-token", + # custom_headers={"Authorization": "Bearer secret-key", "X-Custom": "custom-value"}, + stdio_config={ + "server_name": "auth_test_server", + "command": "test-command", + "args": ["arg1", "arg2"], + "env": {"ENV_VAR": "value"}, + }, + ) + mcp_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data_stdio, default_user) + + mcp_server_data_http = MCPServer( + server_name="auth_test_server_http", + server_type=MCPServerType.STREAMABLE_HTTP, + server_url="http://auth_test_server_http.com", + token="super-secret-token", + custom_headers={"X-Custom": "custom-value"}, + ) + mcp_server_http = await server.mcp_manager.create_or_update_mcp_server(mcp_server_data_http, default_user) + # Create tool from MCP server + from letta.schemas.tool import MCPTool, ToolCreate + + mcp_tool_data = MCPTool( + name="auth_test_tool_stdio", + description="Tool with auth", + inputSchema={"type": "object", "properties": {}}, + ) + tool_create_stdio = ToolCreate.from_mcp(mcp_server.server_name, mcp_tool_data) + + mcp_tool_data_http = MCPTool( + name="auth_test_tool_http", + description="Tool with auth", + inputSchema={"type": "object", "properties": {}}, + ) + + tool_create_http = ToolCreate.from_mcp(mcp_server_http.server_name, mcp_tool_data_http) + + mcp_tool = await server.tool_manager.create_mcp_tool_async(tool_create_stdio, mcp_server.server_name, mcp_server.id, default_user) + mcp_tool_http = await server.tool_manager.create_mcp_tool_async( + tool_create_http, mcp_server_http.server_name, mcp_server_http.id, default_user + ) + + # Create agent with the tool + from letta.schemas.agent import CreateAgent + + create_agent_request = CreateAgent( + name="auth_test_agent", + tool_ids=[mcp_tool.id, mcp_tool_http.id], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ) + agent = await server.agent_manager.create_agent_async(create_agent_request, default_user) + + # Export + agent_file = await agent_serialization_manager.export([agent.id], default_user) + + for server in agent_file.mcp_servers: + if server.server_name == "auth_test_server": + exported_server_stdio = server + elif server.server_name == "auth_test_server_http": + exported_server_http = server + + # Verify env variables in stdio server are excluded (typically used for auth) + assert exported_server_stdio.id != mcp_server.id + assert exported_server_stdio.server_name == "auth_test_server" + assert exported_server_stdio.stdio_config == { + "server_name": "auth_test_server", + "type": "stdio", + "command": "test-command", + "args": ["arg1", "arg2"], + } # Non-auth config preserved + assert exported_server_stdio.server_type == "stdio" + + # Verify token and custom headers are excluded from export for http server + assert exported_server_http.id != mcp_server_http.id + assert exported_server_http.server_name == "auth_test_server_http" + assert exported_server_http.server_type == "streamable_http" + assert exported_server_http.server_url == "http://auth_test_server_http.com" + assert not hasattr(exported_server_http, "token") + assert not hasattr(exported_server_http, "custom_headers") + + async def test_mcp_tool_metadata_with_server_id(self, agent_serialization_manager, agent_with_mcp_tools, default_user): + """Test that MCP tools have server_id in metadata.""" + agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user) + + # Find the MCP tool + mcp_tool = next((t for t in agent_file.tools if t.name == "test_mcp_tool"), None) + assert mcp_tool is not None + + # Verify metadata contains server info + assert mcp_tool.metadata_ is not None + assert "mcp" in mcp_tool.metadata_ + assert "server_name" in mcp_tool.metadata_["mcp"] + assert "server_id" in mcp_tool.metadata_["mcp"] + assert mcp_tool.metadata_["mcp"]["server_name"] == "test_mcp_server" + + # Verify tag format + assert any(tag.startswith("mcp:") for tag in mcp_tool.tags) + + async def test_mcp_server_import(self, agent_serialization_manager, agent_with_mcp_tools, default_user, other_user): + """Test importing agents with MCP servers.""" + # Export from default user + agent_file = await agent_serialization_manager.export([agent_with_mcp_tools.id], default_user) + + # Import to other user + result = await agent_serialization_manager.import_file(agent_file, other_user) + + assert result.success + + # Verify MCP server was imported + mcp_server_id = next((db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("mcp_server-")), None) + assert mcp_server_id is not None + + async def test_multiple_mcp_servers_export(self, server, agent_serialization_manager, default_user): + """Test exporting multiple MCP servers from different agents.""" + from letta.schemas.mcp import MCPServer, MCPServerType + + # Create two MCP servers + mcp_server1 = await server.mcp_manager.create_or_update_mcp_server( + MCPServer( + server_name="mcp1", + server_type=MCPServerType.STREAMABLE_HTTP, + server_url="http://mcp1.com", + token="super-secret-token", + custom_headers={"X-Custom": "custom-value"}, + ), + default_user, + ) + mcp_server2 = await server.mcp_manager.create_or_update_mcp_server( + MCPServer( + server_name="mcp2", + server_type=MCPServerType.STDIO, + stdio_config={ + "server_name": "mcp2", + "command": "mcp2-cmd", + "args": ["arg1", "arg2"], + }, + ), + default_user, + ) + + # Create tools from each server + from letta.schemas.tool import MCPTool, ToolCreate + + tool1 = await server.tool_manager.create_mcp_tool_async( + ToolCreate.from_mcp( + "mcp1", + MCPTool(name="tool1", description="Tool 1", inputSchema={"type": "object", "properties": {}}), + ), + "mcp1", + mcp_server1.id, + default_user, + ) + tool2 = await server.tool_manager.create_mcp_tool_async( + ToolCreate.from_mcp( + "mcp2", + MCPTool(name="tool2", description="Tool 2", inputSchema={"type": "object", "properties": {}}), + ), + "mcp2", + mcp_server2.id, + default_user, + ) + + # Create agents with different MCP tools + from letta.schemas.agent import CreateAgent + + agent1 = await server.agent_manager.create_agent_async( + CreateAgent( + name="agent1", + tool_ids=[tool1.id], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + default_user, + ) + agent2 = await server.agent_manager.create_agent_async( + CreateAgent( + name="agent2", + tool_ids=[tool2.id], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + default_user, + ) + + # Export both agents + agent_file = await agent_serialization_manager.export([agent1.id, agent2.id], default_user) + + # Verify both MCP servers are included + assert len(agent_file.mcp_servers) == 2 + + # Verify server types + streamable_http_server = next(s for s in agent_file.mcp_servers if s.server_name == "mcp1") + stdio_server = next(s for s in agent_file.mcp_servers if s.server_name == "mcp2") + + assert streamable_http_server.server_name == "mcp1" + assert streamable_http_server.server_type == "streamable_http" + assert streamable_http_server.server_url == "http://mcp1.com" + + assert stdio_server.server_name == "mcp2" + assert stdio_server.server_type == "stdio" + assert stdio_server.stdio_config == { + "server_name": "mcp2", + "type": "stdio", + "command": "mcp2-cmd", + "args": ["arg1", "arg2"], + } + + async def test_mcp_server_deduplication(self, server, agent_serialization_manager, default_user, test_mcp_server, mcp_tool): + """Test that shared MCP servers are deduplicated during export.""" + # Create two agents using the same MCP tool + from letta.schemas.agent import CreateAgent + + agent1 = await server.agent_manager.create_agent_async( + CreateAgent( + name="agent_dup1", + tool_ids=[mcp_tool.id], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + default_user, + ) + agent2 = await server.agent_manager.create_agent_async( + CreateAgent( + name="agent_dup2", + tool_ids=[mcp_tool.id], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + default_user, + ) + + # Export both agents + agent_file = await agent_serialization_manager.export([agent1.id, agent2.id], default_user) + + # Verify only one MCP server is exported + assert len(agent_file.mcp_servers) == 1 + assert agent_file.mcp_servers[0].server_name == "test_mcp_server" + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_managers.py b/tests/test_managers.py index c75b6167..395f3ebc 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -269,8 +269,11 @@ def mcp_tool(server, default_user): }, ) mcp_server_name = "test" + mcp_server_id = "test-server-id" # Mock server ID for testing tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) - tool = server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user) + tool = server.tool_manager.create_or_update_mcp_tool( + tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=default_user + ) yield tool @@ -3474,6 +3477,7 @@ def test_create_mcp_tool(server: SyncServer, mcp_tool, default_user, default_org assert mcp_tool.created_by_id == default_user.id assert mcp_tool.tool_type == ToolType.EXTERNAL_MCP assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == "test" + assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_id"] == "test-server-id" # Test should work with both SQLite and PostgreSQL @@ -8552,6 +8556,83 @@ async def test_create_mcp_server(server, default_user, event_loop): print("TAGS", tool.tags) +async def test_get_mcp_servers_by_ids(server, default_user, event_loop): + from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig + from letta.settings import tool_settings + + if tool_settings.mcp_read_from_config: + return + + # Create multiple MCP servers for testing + servers_data = [ + { + "name": "test_server_1", + "config": StdioServerConfig( + server_name="test_server_1", type=MCPServerType.STDIO, command="echo 'test1'", args=["arg1"], env={"ENV1": "value1"} + ), + "type": MCPServerType.STDIO, + }, + { + "name": "test_server_2", + "config": SSEServerConfig(server_name="test_server_2", server_url="https://test2.example.com/sse"), + "type": MCPServerType.SSE, + }, + { + "name": "test_server_3", + "config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/sse"), + "type": MCPServerType.SSE, + }, + ] + + created_servers = [] + for server_data in servers_data: + if server_data["type"] == MCPServerType.STDIO: + mcp_server = MCPServer(server_name=server_data["name"], server_type=server_data["type"], stdio_config=server_data["config"]) + else: + mcp_server = MCPServer( + server_name=server_data["name"], server_type=server_data["type"], server_url=server_data["config"].server_url + ) + + created = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user) + created_servers.append(created) + + # Test fetching multiple servers by IDs + server_ids = [s.id for s in created_servers] + fetched_servers = await server.mcp_manager.get_mcp_servers_by_ids(server_ids, actor=default_user) + + assert len(fetched_servers) == len(created_servers) + fetched_ids = {s.id for s in fetched_servers} + expected_ids = {s.id for s in created_servers} + assert fetched_ids == expected_ids + + # Test fetching subset of servers + subset_ids = server_ids[:2] + subset_servers = await server.mcp_manager.get_mcp_servers_by_ids(subset_ids, actor=default_user) + assert len(subset_servers) == 2 + assert all(s.id in subset_ids for s in subset_servers) + + # Test fetching with empty list + empty_result = await server.mcp_manager.get_mcp_servers_by_ids([], actor=default_user) + assert empty_result == [] + + # Test fetching with non-existent ID mixed with valid IDs + mixed_ids = [server_ids[0], "non-existent-id", server_ids[1]] + mixed_result = await server.mcp_manager.get_mcp_servers_by_ids(mixed_ids, actor=default_user) + # Should only return the existing servers + assert len(mixed_result) == 2 + assert all(s.id in server_ids for s in mixed_result) + + # Test that servers from different organizations are not returned + # This would require creating another user/org, but for now we'll just verify + # that the function respects the actor's organization + all_servers = await server.mcp_manager.list_mcp_servers(actor=default_user) + all_server_ids = [s.id for s in all_servers] + bulk_fetched = await server.mcp_manager.get_mcp_servers_by_ids(all_server_ids, actor=default_user) + + # All fetched servers should belong to the same organization + assert all(s.organization_id == default_user.organization_id for s in bulk_fetched) + + # ====================================================================================================================== # FileAgent Tests # ======================================================================================================================