Files
letta-server/tests/managers/test_mcp_manager.py
Kian Jones 25d54dd896 chore: enable F821, F401, W293 (#9503)
* auto fixes

* auto fix pt2 and transitive deps and undefined var checking locals()

* manual fixes (ignored or letta-code fixed)

* fix circular import
2026-02-24 10:55:08 -08:00

1056 lines
43 KiB
Python

import uuid
from unittest.mock import AsyncMock, patch
import pytest
# Import shared fixtures and constants from conftest
from letta.constants import (
MCP_TOOL_TAG_NAME_PREFIX,
)
from letta.functions.mcp_client.types import MCPTool
from letta.schemas.enums import (
ToolType,
)
from letta.server.db import db_registry
from letta.settings import settings
# ======================================================================================================================
# MCPManager Tests
# ======================================================================================================================
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server(mock_get_client, server, default_user):
    """Create stdio and SSE MCP servers, then list, call, and register a tool via the SSE server.

    ``MCPManager.get_mcp_client`` is patched, so every server interaction goes through a mock
    client. The mock advertises a single ``get_simple_price`` tool and returns a canned
    ``(json_string, True)`` tuple from ``execute_tool``. The test returns early, asserting
    nothing, when ``tool_settings.mcp_read_from_config`` is enabled.
    """
    from letta.schemas.mcp import MCPServer, MCPServerType, StdioServerConfig
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # Mock client: one advertised tool plus a canned execution result.
    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock()
    mock_client.list_tools = AsyncMock(
        return_value=[
            MCPTool(
                name="get_simple_price",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "ids": {"type": "string"},
                        "vs_currencies": {"type": "string"},
                        "include_market_cap": {"type": "boolean"},
                        "include_24hr_vol": {"type": "boolean"},
                        "include_24hr_change": {"type": "boolean"},
                    },
                    "required": ["ids", "vs_currencies"],
                    "additionalProperties": False,
                },
            )
        ]
    )
    mock_client.execute_tool = AsyncMock(
        return_value=(
            '{"bitcoin": {"usd": 50000, "usd_market_cap": 900000000000, "usd_24h_vol": 30000000000, "usd_24h_change": 2.5}}',
            True,
        )
    )
    mock_get_client.return_value = mock_client

    # Test with a valid StdioServerConfig
    server_config = StdioServerConfig(
        server_name="test_server", type=MCPServerType.STDIO, command="echo 'test'", args=["arg1", "arg2"], env={"ENV1": "value1"}
    )
    mcp_server = MCPServer(server_name="test_server", server_type=MCPServerType.STDIO, stdio_config=server_config)
    created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
    print(created_server)
    assert created_server.server_name == server_config.server_name
    assert created_server.server_type == server_config.type

    # Test with an SSE server. Only the URL is needed to build the MCPServer record.
    mcp_server_name = "coingecko"
    server_url = "https://mcp.api.coingecko.com/sse"
    mcp_sse_server = MCPServer(server_name=mcp_server_name, server_type=MCPServerType.SSE, server_url=server_url)
    created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_sse_server, actor=default_user)
    print(created_server)
    assert created_server.server_name == mcp_server_name
    assert created_server.server_type == MCPServerType.SSE

    # list mcp servers
    servers = await server.mcp_manager.list_mcp_servers(actor=default_user)
    print(servers)
    assert len(servers) > 0, "No MCP servers found"

    # list tools from sse server
    tools = await server.mcp_manager.list_mcp_server_tools(created_server.server_name, actor=default_user)
    print(tools)

    # call a tool from the sse server
    tool_name = "get_simple_price"
    tool_args = {
        "ids": "bitcoin",
        "vs_currencies": "usd",
        "include_market_cap": True,
        "include_24hr_vol": True,
        "include_24hr_change": True,
    }
    result = await server.mcp_manager.execute_mcp_server_tool(
        created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user, environment_variables={}
    )
    print(result)

    # add a tool; it should be tagged with the owning MCP server's name
    tool = await server.mcp_manager.add_tool_from_mcp_server(created_server.server_name, tool_name, actor=default_user)
    print(tool)
    assert tool.name == tool_name
    expected_tag = f"mcp:{created_server.server_name}"
    assert expected_tag in tool.tags, f"Expected tag {expected_tag}, got {tool.tags}"
    print("TAGS", tool.tags)
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_tools(mock_get_client, server, default_user):
    """Test that creating an MCP server automatically syncs and persists its tools.

    The patched client reports four tools: two VALID, one WARNING and one INVALID. After
    ``create_mcp_server_with_tools`` the test expects exactly the VALID and WARNING tools to be
    persisted, with MCP metadata pointing at the new server. It also expects them to be gone once
    the server is deleted. The test returns early when ``tool_settings.mcp_read_from_config`` is
    enabled.
    """
    # Marked explicitly with @pytest.mark.asyncio like the other async tests in this module, so
    # the test also runs when pytest-asyncio is in strict mode.
    from letta.functions.mcp_client.types import MCPToolHealth
    from letta.schemas.mcp import MCPServer, MCPServerType
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # Create mock tools with different health statuses
    mock_tools = [
        MCPTool(
            name="valid_tool_1",
            description="A valid tool",
            inputSchema={
                "type": "object",
                "properties": {
                    "param1": {"type": "string"},
                },
                "required": ["param1"],
            },
            health=MCPToolHealth(status="VALID", reasons=[]),
        ),
        MCPTool(
            name="valid_tool_2",
            description="Another valid tool",
            inputSchema={
                "type": "object",
                "properties": {
                    "param2": {"type": "number"},
                },
            },
            health=MCPToolHealth(status="VALID", reasons=[]),
        ),
        MCPTool(
            name="invalid_tool",
            description="An invalid tool that should be skipped",
            inputSchema={
                "type": "invalid_type",  # Invalid schema
            },
            health=MCPToolHealth(status="INVALID", reasons=["Invalid schema type"]),
        ),
        MCPTool(
            name="warning_tool",
            description="A tool with warnings but should still be persisted",
            inputSchema={
                "type": "object",
                "properties": {},
            },
            health=MCPToolHealth(status="WARNING", reasons=["No properties defined"]),
        ),
    ]

    # Create mock client
    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock()
    mock_client.list_tools = AsyncMock(return_value=mock_tools)
    mock_client.cleanup = AsyncMock()
    mock_get_client.return_value = mock_client

    # Create MCP server config (unique name so runs do not collide)
    server_name = f"test_server_{uuid.uuid4().hex[:8]}"
    server_url = "https://test-with-tools.example.com/sse"
    mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)

    # Create server with tools using the new method
    created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)

    # Verify server was created
    assert created_server.server_name == server_name
    assert created_server.server_type == MCPServerType.SSE
    assert created_server.server_url == server_url

    # Verify tools were persisted (all except the invalid one).
    # Tool names are not unique across servers, so filter by the MCP metadata.
    all_tools = await server.tool_manager.list_tools_async(
        actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool", "invalid_tool"]
    )
    persisted_tools = [
        tool
        for tool in all_tools
        if tool.metadata_
        and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
        and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
    ]

    # Should have 3 tools (2 valid + 1 warning, but not the invalid one)
    assert len(persisted_tools) == 3, f"Expected 3 tools, got {len(persisted_tools)}"

    # Check tool names
    tool_names = {tool.name for tool in persisted_tools}
    assert "valid_tool_1" in tool_names
    assert "valid_tool_2" in tool_names
    assert "warning_tool" in tool_names
    assert "invalid_tool" not in tool_names  # Invalid tool should be filtered out

    # Verify each tool has correct metadata
    for tool in persisted_tools:
        assert tool.metadata_ is not None
        assert MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
        assert tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == server_name
        assert tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_id"] == created_server.id
        assert tool.tool_type == ToolType.EXTERNAL_MCP

    # Clean up - delete the server
    await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)

    # Verify tools were also deleted (cascade) by trying to get them again
    remaining_tools = await server.tool_manager.list_tools_async(actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool"])
    remaining_mcp_tools = [
        tool
        for tool in remaining_tools
        if tool.metadata_
        and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
        and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
    ]
    assert len(remaining_mcp_tools) == 0, "Tools should be deleted when server is deleted"
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_complex_schema_normalization(mock_get_client, server, default_user):
    """Test that complex MCP schemas with nested objects are normalized and accepted.

    The patched client reports three tools whose schemas use ``$defs``/``$ref`` and omit
    ``additionalProperties``, and each is marked INVALID. Each must still be accepted by
    ``add_tool_from_mcp_server`` and come back with a normalized schema.
    """
    from letta.functions.mcp_client.types import MCPToolHealth
    from letta.schemas.mcp import MCPServer, MCPServerType
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # Create mock tools with complex schemas that would normally be INVALID
    # These schemas have: nested $defs, $ref references, missing additionalProperties
    mock_tools = [
        # 1. Nested object with $ref (like create_person)
        MCPTool(
            name="create_person",
            description="Create a person with nested address",
            inputSchema={
                "$defs": {
                    "Address": {
                        "type": "object",
                        "properties": {
                            "street": {"type": "string"},
                            "city": {"type": "string"},
                            "zip_code": {"type": "string"},
                        },
                        "required": ["street", "city", "zip_code"],
                    },
                    "Person": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "age": {"type": "integer"},
                            "address": {"$ref": "#/$defs/Address"},
                        },
                        "required": ["name", "age"],
                    },
                },
                "type": "object",
                "properties": {"person": {"$ref": "#/$defs/Person"}},
                "required": ["person"],
            },
            health=MCPToolHealth(
                status="INVALID",
                reasons=["root: 'additionalProperties' not explicitly set", "root.properties.person: Missing 'type'"],
            ),
        ),
        # 2. List of objects (like manage_tasks)
        MCPTool(
            name="manage_tasks",
            description="Manage multiple tasks",
            inputSchema={
                "$defs": {
                    "TaskItem": {
                        "type": "object",
                        "properties": {
                            "title": {"type": "string"},
                            "priority": {"type": "integer", "default": 1},
                            "completed": {"type": "boolean", "default": False},
                            "tags": {"type": "array", "items": {"type": "string"}},
                        },
                        "required": ["title"],
                    }
                },
                "type": "object",
                "properties": {
                    "tasks": {
                        "type": "array",
                        "items": {"$ref": "#/$defs/TaskItem"},
                    }
                },
                "required": ["tasks"],
            },
            health=MCPToolHealth(
                status="INVALID",
                reasons=["root: 'additionalProperties' not explicitly set", "root.properties.tasks.items: Missing 'type'"],
            ),
        ),
        # 3. Complex filter object with optional fields
        MCPTool(
            name="search_with_filters",
            description="Search with complex filters",
            inputSchema={
                "$defs": {
                    "SearchFilter": {
                        "type": "object",
                        "properties": {
                            "keywords": {"type": "array", "items": {"type": "string"}},
                            "min_score": {"type": "number"},
                            "categories": {"type": "array", "items": {"type": "string"}},
                        },
                        "required": ["keywords"],
                    }
                },
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                    "filters": {"$ref": "#/$defs/SearchFilter"},
                },
                "required": ["query", "filters"],
            },
            health=MCPToolHealth(
                status="INVALID",
                reasons=["root: 'additionalProperties' not explicitly set", "root.properties.filters: Missing 'type'"],
            ),
        ),
    ]

    # Create mock client
    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock()
    mock_client.list_tools = AsyncMock(return_value=mock_tools)
    mock_client.cleanup = AsyncMock()
    mock_get_client.return_value = mock_client

    # Create MCP server
    server_name = f"test_complex_schema_{uuid.uuid4().hex[:8]}"
    server_url = "https://test-complex.example.com/sse"
    mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)

    # Create the server (this will auto-sync tools) BEFORE entering the try block. If creation
    # itself fails, there is nothing to clean up. Also, the `finally` below must never run with
    # `created_server` unbound, or it would hide the original error behind an UnboundLocalError.
    created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
    try:
        assert created_server.server_name == server_name

        # Now attempt to add each tool - they should be normalized from INVALID to acceptable
        # The normalization happens in add_tool_from_mcp_server

        # Test 1: create_person should normalize successfully
        person_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "create_person", actor=default_user)
        assert person_tool is not None
        assert person_tool.name == "create_person"
        # Verify the schema has additionalProperties set
        assert person_tool.json_schema["parameters"]["additionalProperties"] is False
        # Verify nested $defs have additionalProperties
        if "$defs" in person_tool.json_schema["parameters"]:
            for def_name, def_schema in person_tool.json_schema["parameters"]["$defs"].items():
                if def_schema.get("type") == "object":
                    assert "additionalProperties" in def_schema, f"$defs.{def_name} missing additionalProperties after normalization"

        # Test 2: manage_tasks should normalize successfully
        tasks_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "manage_tasks", actor=default_user)
        assert tasks_tool is not None
        assert tasks_tool.name == "manage_tasks"
        # Verify array items have explicit type
        tasks_prop = tasks_tool.json_schema["parameters"]["properties"]["tasks"]
        assert "items" in tasks_prop
        assert "type" in tasks_prop["items"], "Array items should have explicit type after normalization"

        # Test 3: search_with_filters should normalize successfully
        search_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "search_with_filters", actor=default_user)
        assert search_tool is not None
        assert search_tool.name == "search_with_filters"

        # Verify all tools were persisted
        all_tools = await server.tool_manager.list_tools_async(
            actor=default_user, names=["create_person", "manage_tasks", "search_with_filters"]
        )
        # Filter to tools from our MCP server
        mcp_tools = [
            tool
            for tool in all_tools
            if tool.metadata_
            and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
            and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
        ]

        # All 3 complex schema tools should have been normalized and persisted
        assert len(mcp_tools) == 3, f"Expected 3 normalized tools, got {len(mcp_tools)}"

        # Verify they all have the correct MCP metadata
        for tool in mcp_tools:
            assert tool.tool_type == ToolType.EXTERNAL_MCP
            assert f"mcp:{server_name}" in tool.tags
    finally:
        # Clean up
        await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_tools_connection_failure(mock_get_client, server, default_user):
    """Test that MCP server creation succeeds even when tool sync fails (optimistic approach).

    The patched client raises on ``connect_to_server``. The server record must still be created,
    and no tools tagged with it may be persisted. The server is deleted afterwards even if an
    assertion fails.
    """
    from letta.schemas.mcp import MCPServer, MCPServerType
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # Create mock client that fails to connect
    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock(side_effect=Exception("Connection failed"))
    mock_client.cleanup = AsyncMock()
    mock_get_client.return_value = mock_client

    # Create MCP server config
    server_name = f"test_server_fail_{uuid.uuid4().hex[:8]}"
    server_url = "https://test-fail.example.com/sse"
    mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)

    # Create server with tools - should succeed despite connection failure
    created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
    try:
        # Verify server was created successfully
        assert created_server.server_name == server_name
        assert created_server.server_type == MCPServerType.SSE
        assert created_server.server_url == server_url

        # Verify no tools were persisted (due to connection failure)
        # Try to get tools by the names we would have expected
        all_tools = await server.tool_manager.list_tools_async(
            actor=default_user,
            names=["tool1", "tool2", "tool3"],  # Generic names since we don't know what tools would have been listed
        )

        # Filter to see if any belong to our server (there shouldn't be any)
        persisted_tools = [
            tool
            for tool in all_tools
            if tool.metadata_
            and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
            and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
        ]
        assert len(persisted_tools) == 0, "No tools should be persisted when connection fails"
    finally:
        # Clean up, even when an assertion above failed
        await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@pytest.mark.asyncio
async def test_get_mcp_servers_by_ids(server, default_user):
    """``get_mcp_servers_by_ids`` returns exactly the requested servers that exist for the actor.

    Covers four lookups:
    - all IDs;
    - a subset of IDs;
    - an empty list;
    - valid IDs mixed with an unknown ID.

    It also checks that every server returned for the actor's full listing belongs to the
    actor's organization. Servers created here are deleted afterwards, even on failure. The test
    returns early when ``tool_settings.mcp_read_from_config`` is enabled.
    """
    from letta.schemas.mcp import MCPServer, MCPServerType, StdioServerConfig
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # One stdio server plus two URL-based servers (SSE and streamable HTTP). The URL-based
    # servers only need their URL to build the MCPServer record.
    servers_data = [
        {
            "name": "test_server_1",
            "type": MCPServerType.STDIO,
            "config": StdioServerConfig(
                server_name="test_server_1", type=MCPServerType.STDIO, command="echo 'test1'", args=["arg1"], env={"ENV1": "value1"}
            ),
        },
        {"name": "test_server_2", "type": MCPServerType.SSE, "url": "https://test2.example.com/sse"},
        {"name": "test_server_3", "type": MCPServerType.STREAMABLE_HTTP, "url": "https://test3.example.com/mcp"},
    ]

    created_servers = []
    try:
        for server_data in servers_data:
            if server_data["type"] == MCPServerType.STDIO:
                mcp_server = MCPServer(server_name=server_data["name"], server_type=server_data["type"], stdio_config=server_data["config"])
            else:
                mcp_server = MCPServer(server_name=server_data["name"], server_type=server_data["type"], server_url=server_data["url"])
            created = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
            created_servers.append(created)

        # Test fetching multiple servers by IDs
        server_ids = [s.id for s in created_servers]
        fetched_servers = await server.mcp_manager.get_mcp_servers_by_ids(server_ids, actor=default_user)
        assert len(fetched_servers) == len(created_servers)
        fetched_ids = {s.id for s in fetched_servers}
        expected_ids = {s.id for s in created_servers}
        assert fetched_ids == expected_ids

        # Test fetching subset of servers
        subset_ids = server_ids[:2]
        subset_servers = await server.mcp_manager.get_mcp_servers_by_ids(subset_ids, actor=default_user)
        assert len(subset_servers) == 2
        assert all(s.id in subset_ids for s in subset_servers)

        # Test fetching with empty list
        empty_result = await server.mcp_manager.get_mcp_servers_by_ids([], actor=default_user)
        assert empty_result == []

        # Test fetching with non-existent ID mixed with valid IDs
        mixed_ids = [server_ids[0], "non-existent-id", server_ids[1]]
        mixed_result = await server.mcp_manager.get_mcp_servers_by_ids(mixed_ids, actor=default_user)
        # Should only return the existing servers
        assert len(mixed_result) == 2
        assert all(s.id in server_ids for s in mixed_result)

        # Test that servers from different organizations are not returned
        # This would require creating another user/org, but for now we'll just verify
        # that the function respects the actor's organization
        all_servers = await server.mcp_manager.list_mcp_servers(actor=default_user)
        all_server_ids = [s.id for s in all_servers]
        bulk_fetched = await server.mcp_manager.get_mcp_servers_by_ids(all_server_ids, actor=default_user)
        # All fetched servers should belong to the same organization
        assert all(s.organization_id == default_user.organization_id for s in bulk_fetched)
    finally:
        # Remove the servers created above so they do not leak into other tests.
        for created in created_servers:
            await server.mcp_manager.delete_mcp_server_by_id(created.id, actor=default_user)
# Additional MCPManager OAuth session tests
@pytest.mark.asyncio
async def test_mcp_server_deletion_cascades_oauth_sessions(server, default_organization, default_user):
    """Removing an MCP server also removes the OAuth sessions recorded for the same user and URL.

    Three sessions are created first with no server attached. Then a server with the matching
    URL is created and deleted; afterwards none of the sessions should be retrievable.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    shared_url = "https://test.example.com/mcp"
    manager = server.mcp_manager

    # Sessions pointing at shared_url that are not yet tied to any server.
    session_ids: list[str] = []
    for idx in range(3):
        session_request = MCPOAuthSessionCreate(
            server_url=shared_url,
            server_name=f"test_mcp_server_{idx}",
            user_id=default_user.id,
            organization_id=default_organization.id,
        )
        new_session = await manager.create_oauth_session(session_request, actor=default_user)
        session_ids.append(new_session.id)

    # A uniquely named server that shares the sessions' URL.
    unique_suffix = str(uuid.uuid4().hex[:8])
    server_request = PydanticMCPServer(
        server_name=f"test_mcp_server_{unique_suffix}",
        server_type=MCPServerType.SSE,
        server_url=shared_url,
        organization_id=default_organization.id,
    )
    doomed_server = await manager.create_mcp_server(server_request, actor=default_user)

    await manager.delete_mcp_server_by_id(doomed_server.id, actor=default_user)

    # Every session for that user + URL must be gone now.
    for session_id in session_ids:
        lookup = await manager.get_oauth_session_by_id(session_id, actor=default_user)
        assert lookup is None, f"OAuth session {session_id} should be deleted"
@pytest.mark.asyncio
async def test_oauth_sessions_with_different_url_persist(server, default_organization, default_user):
    """An OAuth session for an unrelated URL survives deletion of an MCP server at another URL."""
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    deleted_url = "https://test.example.com/mcp"
    unrelated_url = "https://other.example.com/mcp"
    manager = server.mcp_manager

    # This session is bound to the unrelated URL and must outlive the deletion below.
    survivor = await manager.create_oauth_session(
        MCPOAuthSessionCreate(
            server_url=unrelated_url,
            server_name="standalone_oauth",
            user_id=default_user.id,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )

    # Create and immediately delete a server living at deleted_url.
    doomed_server = await manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_mcp_server_{str(uuid.uuid4().hex[:8])}",
            server_type=MCPServerType.SSE,
            server_url=deleted_url,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )
    await manager.delete_mcp_server_by_id(doomed_server.id, actor=default_user)

    still_there = await manager.get_oauth_session_by_id(survivor.id, actor=default_user)
    assert still_there is not None, "OAuth session with different URL should persist"
@pytest.mark.asyncio
async def test_mcp_server_creation_links_orphaned_sessions(server, default_organization, default_user):
    """Creating a server should link any existing orphaned sessions (same user + URL).

    The link is checked indirectly. Sessions stay readable after the server is created, and all
    of them disappear once that server is deleted.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    target_url = "https://test-atomic-create.example.com/mcp"
    manager = server.mcp_manager

    # Orphaned sessions: same user + URL, no server_id yet.
    pending_ids: list[str] = []
    for n in range(3):
        orphan = await manager.create_oauth_session(
            MCPOAuthSessionCreate(
                server_url=target_url,
                server_name=f"atomic_session_{n}",
                user_id=default_user.id,
                organization_id=default_organization.id,
            ),
            actor=default_user,
        )
        pending_ids.append(orphan.id)

    new_server = await manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_atomic_server_{str(uuid.uuid4().hex[:8])}",
            server_type=MCPServerType.SSE,
            server_url=target_url,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )

    # Creating the server must not remove the sessions.
    for pending_id in pending_ids:
        found = await manager.get_oauth_session_by_id(pending_id, actor=default_user)
        assert found is not None

    # Deleting the server removes every session for that URL + user.
    await manager.delete_mcp_server_by_id(new_server.id, actor=default_user)
    for pending_id in pending_ids:
        assert await manager.get_oauth_session_by_id(pending_id, actor=default_user) is None
@pytest.mark.asyncio
async def test_mcp_server_delete_removes_all_sessions_for_url_and_user(server, default_organization, default_user):
    """Deleting a server removes both linked and orphaned sessions for same user+URL."""
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    cleanup_url = "https://test-atomic-cleanup.example.com/mcp"
    manager = server.mcp_manager

    # A session created before any server exists for cleanup_url.
    orphan_request = MCPOAuthSessionCreate(
        server_url=cleanup_url,
        server_name="orphaned",
        user_id=default_user.id,
        organization_id=default_organization.id,
    )
    orphan = await manager.create_oauth_session(orphan_request, actor=default_user)

    server_request = PydanticMCPServer(
        server_name=f"cleanup_server_{str(uuid.uuid4().hex[:8])}",
        server_type=MCPServerType.SSE,
        server_url=cleanup_url,
        organization_id=default_organization.id,
    )
    target_server = await manager.create_mcp_server(server_request, actor=default_user)

    await manager.delete_mcp_server_by_id(target_server.id, actor=default_user)

    # Whether or not creation linked it, the session must be gone after the delete.
    assert await manager.get_oauth_session_by_id(orphan.id, actor=default_user) is None
@pytest.mark.asyncio
async def test_mcp_server_resync_tools(server, default_user, default_organization):
    """Test that resyncing MCP server tools correctly handles added, deleted, and updated tools.

    Two tools (``tool1``, ``tool2``) are persisted up front. The server's tool listing is then
    patched to report an updated ``tool1`` and a new ``tool3``. After the resync:
    - ``tool2`` must be gone;
    - ``tool1`` must carry the new schema;
    - ``tool3`` must exist.
    """
    from letta.functions.mcp_client.types import MCPToolHealth
    from letta.schemas.mcp import MCPServer as PydanticMCPServer, MCPServerType
    from letta.schemas.tool import ToolCreate

    # Create MCP server
    mcp_server = await server.mcp_manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_resync_{uuid.uuid4().hex[:8]}",
            server_type=MCPServerType.SSE,
            server_url="https://test-resync.example.com/mcp",
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )
    mcp_server_id = mcp_server.id

    try:
        # Persist the initial tools directly through the tool manager, simulating tools that
        # were added before the resync.
        tool1_create = ToolCreate.from_mcp(
            mcp_server_name=mcp_server.server_name,
            mcp_tool=MCPTool(
                name="tool1",
                description="Tool 1",
                inputSchema={"type": "object", "properties": {"param1": {"type": "string"}}},
            ),
        )
        tool1 = await server.tool_manager.create_or_update_mcp_tool_async(
            tool_create=tool1_create,
            mcp_server_name=mcp_server.server_name,
            mcp_server_id=mcp_server_id,
            actor=default_user,
        )

        tool2_create = ToolCreate.from_mcp(
            mcp_server_name=mcp_server.server_name,
            mcp_tool=MCPTool(
                name="tool2",
                description="Tool 2 to be deleted",
                inputSchema={"type": "object", "properties": {"param2": {"type": "number"}}},
            ),
        )
        tool2 = await server.tool_manager.create_or_update_mcp_tool_async(
            tool_create=tool2_create,
            mcp_server_name=mcp_server.server_name,
            mcp_server_id=mcp_server_id,
            actor=default_user,
        )

        # Mock the list_mcp_server_tools to return updated tools from server
        # tool1 is updated, tool2 is deleted, tool3 is added
        updated_tools = [
            MCPTool(
                name="tool1",
                description="Tool 1 Updated",
                inputSchema={"type": "object", "properties": {"param1": {"type": "string"}, "param1b": {"type": "boolean"}}},
                health=MCPToolHealth(status="VALID", reasons=[]),
            ),
            MCPTool(
                name="tool3",
                description="Tool 3 New",
                inputSchema={"type": "object", "properties": {"param3": {"type": "array"}}},
                health=MCPToolHealth(status="VALID", reasons=[]),
            ),
        ]

        with patch.object(server.mcp_manager, "list_mcp_server_tools", new_callable=AsyncMock) as mock_list_tools:
            mock_list_tools.return_value = updated_tools

            # Run resync
            result = await server.mcp_manager.resync_mcp_server_tools(
                mcp_server_name=mcp_server.server_name,
                actor=default_user,
            )

        # Verify the resync result
        assert len(result.deleted) == 1
        assert "tool2" in result.deleted
        assert len(result.updated) == 1
        assert "tool1" in result.updated
        assert len(result.added) == 1
        assert "tool3" in result.added

        # Verify tool2 was actually deleted: the lookup must raise. `pytest.fail` sits in the
        # `else` branch so the failure cannot be swallowed by `except Exception`. An
        # `assert False` inside the `try` would raise AssertionError, which that clause catches.
        try:
            await server.tool_manager.get_tool_by_id_async(tool_id=tool2.id, actor=default_user)
        except Exception:
            pass  # Expected - tool should be deleted
        else:
            pytest.fail("Tool2 should have been deleted")

        # Verify tool1 was updated with new schema
        updated_tool1 = await server.tool_manager.get_tool_by_id_async(tool_id=tool1.id, actor=default_user)
        assert "param1b" in updated_tool1.json_schema["parameters"]["properties"]

        # Verify tool3 was added
        tools = await server.tool_manager.list_tools_async(actor=default_user, names=["tool3"])
        assert len(tools) == 1
        assert tools[0].name == "tool3"
    finally:
        # Clean up
        await server.mcp_manager.delete_mcp_server_by_id(mcp_server_id, actor=default_user)
# ======================================================================================================================
# MCPManager Tests - Encryption
# ======================================================================================================================
@pytest.fixture
def encryption_key():
    """Yield an encryption key, installing a placeholder if ``settings.encryption_key`` is unset.

    Whatever value was configured before the test (possibly empty) is put back after it.
    """
    saved_key = settings.encryption_key
    if not saved_key:
        # NOTE(review): despite its wording, this placeholder is 30 characters, not 32 bytes.
        # Presumably the encryption layer does not require an exact key length; confirm.
        settings.encryption_key = "test-encryption-key-32-bytes!!"
    yield settings.encryption_key
    # Teardown: restore the pre-test value.
    settings.encryption_key = saved_key
@pytest.mark.asyncio
async def test_mcp_server_token_encryption_on_create(server, default_user, encryption_key):
    """A token supplied at creation is stored only in encrypted form.

    Two things are checked:
    - The returned schema object has the plaintext ``token`` unset, and ``token_enc`` is a
      ``Secret`` that decrypts back to the original value.
    - The raw ORM row's ``token_enc`` differs from the plaintext and is longer than it.
    """
    from letta.functions.mcp_client.types import MCPServerType
    from letta.orm.mcp_server import MCPServer as MCPServerModel
    from letta.schemas.mcp import MCPServer
    from letta.schemas.secret import Secret

    plaintext_token = "sk-test-secret-token-12345"
    request = MCPServer(
        server_name="test-encrypted-server",
        server_type=MCPServerType.STREAMABLE_HTTP,
        server_url="https://api.example.com/mcp",
        token=plaintext_token,
    )
    created = await server.mcp_manager.create_mcp_server(request, actor=default_user)

    try:
        assert created is not None
        assert created.server_name == "test-encrypted-server"

        # No dual-write: the plaintext field stays empty.
        assert created.token is None

        # The encrypted field is a Secret that round-trips to the original token.
        assert created.token_enc is not None
        assert isinstance(created.token_enc, Secret)
        assert created.token_enc.get_plaintext() == plaintext_token

        # Inspect the persisted row directly.
        async with db_registry.async_session() as db_session:
            row = await MCPServerModel.read_async(
                db_session=db_session,
                identifier=created.id,
                actor=default_user,
            )
            assert row.token_enc is not None
            assert row.token_enc != plaintext_token
            assert len(row.token_enc) > len(plaintext_token)
    finally:
        await server.mcp_manager.delete_mcp_server_by_id(created.id, actor=default_user)
@pytest.mark.asyncio
async def test_mcp_server_token_decryption_on_read(server, default_user, encryption_key):
    """Re-reading a stored MCP server yields a token that decrypts to the original value.

    The decrypted token is checked via both ``token_enc`` and ``get_token_secret()``. The
    plaintext ``token`` field must remain unset.
    """
    from letta.functions.mcp_client.types import MCPServerType
    from letta.schemas.mcp import MCPServer
    from letta.schemas.secret import Secret

    expected_token = "sk-test-decrypt-token-67890"
    stored = await server.mcp_manager.create_mcp_server(
        MCPServer(
            server_name="test-decrypt-server",
            server_type=MCPServerType.STREAMABLE_HTTP,
            server_url="https://api.example.com/mcp",
            token=expected_token,
        ),
        actor=default_user,
    )
    stored_id = stored.id

    try:
        reloaded = await server.mcp_manager.get_mcp_server_by_id_async(stored_id, actor=default_user)

        # No dual-write: only the encrypted field carries the token.
        assert reloaded.token is None
        assert reloaded.token_enc is not None
        assert reloaded.token_enc.get_plaintext() == expected_token

        # The secret accessor exposes the same decrypted value.
        secret = reloaded.get_token_secret()
        assert isinstance(secret, Secret)
        assert secret.get_plaintext() == expected_token
    finally:
        await server.mcp_manager.delete_mcp_server_by_id(stored_id, actor=default_user)
@pytest.mark.asyncio
async def test_mcp_server_custom_headers_encryption(server, default_user, encryption_key):
    """Test that custom headers are encrypted as JSON strings.

    Custom headers are serialized to JSON and kept only in the encrypted
    ``custom_headers_enc`` field (no plaintext dual-write). The test first checks
    the schema accessors. It then reads the ORM row directly to confirm two
    things: the stored column does not contain the header values verbatim, and
    it decrypts back to the original mapping.
    """
    import json

    from letta.functions.mcp_client.types import MCPServerType
    from letta.orm.mcp_server import MCPServer as MCPServerModel
    from letta.schemas.mcp import MCPServer
    from letta.schemas.secret import Secret

    # Create MCP server with custom headers
    custom_headers = {"Authorization": "Bearer token123", "X-API-Key": "secret-key-456"}
    mcp_server = MCPServer(
        server_name="test-headers-server",
        server_type=MCPServerType.STREAMABLE_HTTP,
        server_url="https://api.example.com/mcp",
        custom_headers=custom_headers,
    )
    created_server = await server.mcp_manager.create_mcp_server(mcp_server, actor=default_user)
    try:
        # Verify plaintext custom_headers field is NOT set (no dual-write)
        assert created_server.custom_headers is None

        # Verify custom_headers_enc is a Secret object (stores JSON string)
        assert created_server.custom_headers_enc is not None
        assert isinstance(created_server.custom_headers_enc, Secret)

        # The secret getter returns a Secret whose plaintext is the JSON-encoded headers
        headers_secret = created_server.get_custom_headers_secret()
        assert isinstance(headers_secret, Secret)
        json_str = headers_secret.get_plaintext()
        assert json_str is not None
        assert json.loads(json_str) == custom_headers

        # The convenience method parses that JSON back into a dict
        assert created_server.get_custom_headers_dict() == custom_headers

        # Read from DB to verify encryption at rest
        async with db_registry.async_session() as session:
            server_orm = await MCPServerModel.read_async(
                db_session=session,
                identifier=created_server.id,
                actor=default_user,
            )
            stored = server_orm.custom_headers_enc
            assert stored is not None
            # The raw column must not leak any header value in the clear
            for header_value in custom_headers.values():
                assert header_value not in stored
            # Decrypt and verify it's valid JSON matching the original headers;
            # guard against None so a missing value fails with a clear assertion.
            decrypted_json = Secret.from_encrypted(stored).get_plaintext()
            assert decrypted_json is not None
            assert json.loads(decrypted_json) == custom_headers
    finally:
        # Clean up
        await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@pytest.mark.asyncio
async def test_oauth_session_tokens_encryption(server, default_user, encryption_key):
    """Store OAuth tokens on a session and confirm each one is encrypted at rest.

    Each token is checked in three ways:
    - the schema exposes the plaintext value;
    - the matching ``*_enc`` schema field is a ``Secret``;
    - the raw ORM ``*_enc`` column is populated and ``Secret.from_encrypted``
      decrypts it back to the original value.
    """
    from letta.orm.mcp_oauth import MCPOAuth as MCPOAuthModel
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPOAuthSessionUpdate
    from letta.schemas.secret import Secret

    # Session field name -> plaintext value written via the update call.
    tokens = {
        "access_token": "access-token-abc123",
        "refresh_token": "refresh-token-xyz789",
        "client_secret": "client-secret-def456",
        "authorization_code": "auth-code-ghi012",
    }

    oauth_session = await server.mcp_manager.create_oauth_session(
        MCPOAuthSessionCreate(
            server_url="https://oauth.example.com",
            server_name="test-oauth-server",
            organization_id=default_user.organization_id,
            user_id=default_user.id,
        ),
        actor=default_user,
    )
    oauth_session_id = oauth_session.id

    try:
        patched = await server.mcp_manager.update_oauth_session(
            oauth_session_id,
            MCPOAuthSessionUpdate(**tokens),
            actor=default_user,
        )

        # Schema level: every plaintext value is readable...
        for field_name, plaintext in tokens.items():
            assert getattr(patched, field_name) == plaintext
        # ...and every companion *_enc field is wrapped in a Secret.
        for field_name in tokens:
            assert isinstance(getattr(patched, f"{field_name}_enc"), Secret)

        # Storage level: inspect the raw ORM row.
        async with db_registry.async_session() as db_session:
            row = await MCPOAuthModel.read_async(
                db_session=db_session,
                identifier=oauth_session_id,
                actor=default_user,
            )
            for field_name in tokens:
                assert getattr(row, f"{field_name}_enc") is not None
            for field_name, plaintext in tokens.items():
                assert Secret.from_encrypted(getattr(row, f"{field_name}_enc")).get_plaintext() == plaintext
    finally:
        # Remove the session even if an assertion above failed.
        await server.mcp_manager.delete_oauth_session(oauth_session_id, actor=default_user)