* base * use secret field * fix * auth code * stage publish * decouple backfill * revert uncomment * providers and agent vars * mcp * mcp * stage and publish * fix oauth * double encrypt * sandbox --------- Co-authored-by: Letta Bot <noreply@letta.com>
454 lines
19 KiB
Python
454 lines
19 KiB
Python
"""
|
|
Integration tests for MCP server and OAuth session encryption.
|
|
Tests the end-to-end encryption functionality in the MCP manager.
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
|
|
from letta.config import LettaConfig
|
|
from letta.helpers.crypto_utils import CryptoUtils
|
|
from letta.orm import MCPOAuth, MCPServer as ORMMCPServer
|
|
from letta.schemas.mcp import (
|
|
MCPOAuthSessionCreate,
|
|
MCPOAuthSessionUpdate,
|
|
MCPServer as PydanticMCPServer,
|
|
MCPServerType,
|
|
SSEServerConfig,
|
|
StdioServerConfig,
|
|
)
|
|
from letta.schemas.secret import Secret
|
|
from letta.server.db import db_registry
|
|
from letta.server.server import SyncServer
|
|
from letta.services.mcp_manager import MCPManager
|
|
from letta.settings import settings
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def server():
|
|
"""Fixture to create and return a SyncServer instance with MCP manager."""
|
|
config = LettaConfig.load()
|
|
config.save()
|
|
server = SyncServer(init_with_default_org_and_user=False)
|
|
return server
|
|
|
|
|
|
class TestMCPServerEncryption:
|
|
"""Test MCP server encryption functionality."""
|
|
|
|
MOCK_ENCRYPTION_KEY = "test-mcp-encryption-key-123456"
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
|
async def test_create_mcp_server_with_token_encryption(self, mock_get_client, server, default_user):
|
|
"""Test that MCP server tokens are encrypted when stored."""
|
|
# Set encryption key directly on settings
|
|
original_key = settings.encryption_key
|
|
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
|
|
|
|
try:
|
|
# Mock the MCP client
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools.return_value = []
|
|
mock_get_client.return_value = mock_client
|
|
|
|
# Create MCP server with token
|
|
server_name = f"test_encrypted_server_{uuid4().hex[:8]}"
|
|
token = "super-secret-api-token-12345"
|
|
server_url = "https://api.example.com/mcp"
|
|
|
|
mcp_server = PydanticMCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url, token=token)
|
|
|
|
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
|
|
|
|
# Verify server was created
|
|
assert created_server.server_name == server_name
|
|
assert created_server.server_type == MCPServerType.SSE
|
|
|
|
# Check database directly to verify encryption
|
|
async with db_registry.async_session() as session:
|
|
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
|
|
db_server = result.scalar_one()
|
|
|
|
# Token should be encrypted in database
|
|
assert db_server.token_enc is not None
|
|
assert db_server.token_enc != token # Should not be plaintext
|
|
|
|
# Decrypt to verify correctness
|
|
decrypted_token = CryptoUtils.decrypt(db_server.token_enc)
|
|
assert decrypted_token == token
|
|
|
|
# Legacy plaintext column should be None (or empty for dual-write)
|
|
# During migration phase, might store both
|
|
if db_server.token:
|
|
assert db_server.token == token # Dual-write phase
|
|
|
|
# Clean up
|
|
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
|
|
|
finally:
|
|
# Restore original encryption key
|
|
settings.encryption_key = original_key
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
|
async def test_create_mcp_server_with_custom_headers_encryption(self, mock_get_client, server, default_user):
|
|
"""Test that MCP server custom headers are encrypted when stored."""
|
|
# Set encryption key directly on settings
|
|
original_key = settings.encryption_key
|
|
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
|
|
|
|
try:
|
|
# Mock the MCP client
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools.return_value = []
|
|
mock_get_client.return_value = mock_client
|
|
|
|
server_name = f"test_headers_server_{uuid4().hex[:8]}"
|
|
custom_headers = {"Authorization": "Bearer secret-token-xyz", "X-API-Key": "api-key-123456", "X-Custom-Header": "custom-value"}
|
|
server_url = "https://api.example.com/mcp"
|
|
|
|
mcp_server = PydanticMCPServer(
|
|
server_name=server_name, server_type=MCPServerType.STREAMABLE_HTTP, server_url=server_url, custom_headers=custom_headers
|
|
)
|
|
|
|
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
|
|
|
|
# Check database directly
|
|
async with db_registry.async_session() as session:
|
|
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
|
|
db_server = result.scalar_one()
|
|
|
|
# Custom headers should be encrypted as JSON
|
|
assert db_server.custom_headers_enc is not None
|
|
|
|
# Decrypt and parse JSON
|
|
decrypted_json = CryptoUtils.decrypt(db_server.custom_headers_enc)
|
|
decrypted_headers = json.loads(decrypted_json)
|
|
assert decrypted_headers == custom_headers
|
|
|
|
# Clean up
|
|
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
|
|
|
finally:
|
|
# Restore original encryption key
|
|
settings.encryption_key = original_key
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_retrieve_mcp_server_decrypts_values(self, server, default_user):
|
|
"""Test that retrieving MCP server decrypts encrypted values."""
|
|
# Set encryption key directly on settings
|
|
original_key = settings.encryption_key
|
|
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
|
|
|
|
try:
|
|
# Manually insert encrypted server into database
|
|
server_id = f"mcp_server-{uuid4().hex[:8]}"
|
|
server_name = f"test_decrypt_server_{uuid4().hex[:8]}"
|
|
plaintext_token = "decryption-test-token"
|
|
encrypted_token = CryptoUtils.encrypt(plaintext_token)
|
|
|
|
async with db_registry.async_session() as session:
|
|
db_server = ORMMCPServer(
|
|
id=server_id,
|
|
server_name=server_name,
|
|
server_type=MCPServerType.SSE.value,
|
|
server_url="https://test.com",
|
|
token_enc=encrypted_token,
|
|
token=None, # No plaintext
|
|
created_by_id=default_user.id,
|
|
last_updated_by_id=default_user.id,
|
|
organization_id=default_user.organization_id,
|
|
created_at=datetime.now(timezone.utc),
|
|
updated_at=datetime.now(timezone.utc),
|
|
)
|
|
session.add(db_server)
|
|
await session.commit()
|
|
|
|
# Retrieve server directly by ID to avoid issues with other servers in DB
|
|
test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user)
|
|
|
|
assert test_server is not None
|
|
assert test_server.server_name == server_name
|
|
# Token should be decrypted when accessed via the secret method
|
|
token_secret = test_server.get_token_secret()
|
|
assert token_secret.get_plaintext() == plaintext_token
|
|
|
|
# Clean up
|
|
async with db_registry.async_session() as session:
|
|
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id))
|
|
db_server = result.scalar_one()
|
|
await session.delete(db_server)
|
|
await session.commit()
|
|
|
|
finally:
|
|
# Restore original encryption key
|
|
settings.encryption_key = original_key
|
|
|
|
@pytest.mark.asyncio
|
|
@patch.dict(os.environ, {}, clear=True) # No encryption key
|
|
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
|
async def test_create_mcp_server_without_encryption_key(self, mock_get_client, server, default_user):
|
|
"""Test that MCP servers work without encryption key (backward compatibility)."""
|
|
# Remove encryption key
|
|
os.environ.pop("LETTA_ENCRYPTION_KEY", None)
|
|
|
|
# Mock the MCP client
|
|
mock_client = AsyncMock()
|
|
mock_client.list_tools.return_value = []
|
|
mock_get_client.return_value = mock_client
|
|
|
|
server_name = f"test_no_encrypt_server_{uuid4().hex[:8]}"
|
|
token = "plaintext-token-no-encryption"
|
|
|
|
mcp_server = PydanticMCPServer(
|
|
server_name=server_name, server_type=MCPServerType.SSE, server_url="https://api.example.com", token=token
|
|
)
|
|
|
|
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
|
|
|
|
# Check database - should store as plaintext
|
|
async with db_registry.async_session() as session:
|
|
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
|
|
db_server = result.scalar_one()
|
|
|
|
# Should store in plaintext column
|
|
assert db_server.token == token
|
|
assert db_server.token_enc is None # No encryption
|
|
|
|
# Clean up
|
|
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
|
|
|
|
|
class TestMCPOAuthEncryption:
|
|
"""Test MCP OAuth session encryption functionality."""
|
|
|
|
MOCK_ENCRYPTION_KEY = "test-oauth-encryption-key-123456"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_oauth_session_with_encryption(self, server, default_user):
|
|
"""Test that OAuth tokens are encrypted when stored."""
|
|
# Set encryption key directly on settings
|
|
original_key = settings.encryption_key
|
|
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
|
|
|
|
try:
|
|
server_url = "https://github.com/mcp"
|
|
server_name = "GitHub MCP"
|
|
|
|
# Step 1: Create OAuth session (without tokens initially)
|
|
oauth_session_create = MCPOAuthSessionCreate(
|
|
server_url=server_url,
|
|
server_name=server_name,
|
|
organization_id=default_user.organization_id,
|
|
user_id=default_user.id,
|
|
)
|
|
|
|
created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user)
|
|
|
|
assert created_session.server_url == server_url
|
|
assert created_session.server_name == server_name
|
|
|
|
# Step 2: Update session with tokens (simulating OAuth callback)
|
|
update_data = MCPOAuthSessionUpdate(
|
|
access_token="github-access-token-abc123",
|
|
refresh_token="github-refresh-token-xyz789",
|
|
client_id="client-id-123",
|
|
client_secret="client-secret-super-secret",
|
|
expires_at=datetime.now(timezone.utc),
|
|
)
|
|
|
|
await server.mcp_manager.update_oauth_session(created_session.id, update_data, actor=default_user)
|
|
|
|
# Check database directly for encryption
|
|
async with db_registry.async_session() as session:
|
|
result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id))
|
|
db_oauth = result.scalar_one()
|
|
|
|
# All sensitive fields should be encrypted
|
|
assert db_oauth.access_token_enc is not None
|
|
assert db_oauth.access_token_enc != update_data.access_token
|
|
|
|
assert db_oauth.refresh_token_enc is not None
|
|
|
|
assert db_oauth.client_secret_enc is not None
|
|
|
|
# Verify decryption
|
|
decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc)
|
|
assert decrypted_access == update_data.access_token
|
|
|
|
decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc)
|
|
assert decrypted_refresh == update_data.refresh_token
|
|
|
|
decrypted_secret = CryptoUtils.decrypt(db_oauth.client_secret_enc)
|
|
assert decrypted_secret == update_data.client_secret
|
|
|
|
finally:
|
|
# Restore original encryption key
|
|
settings.encryption_key = original_key
|
|
|
|
# Clean up not needed - test database is reset
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_retrieve_oauth_session_decrypts_tokens(self, server, default_user):
|
|
"""Test that retrieving OAuth session decrypts tokens."""
|
|
# Set encryption key directly on settings
|
|
original_key = settings.encryption_key
|
|
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
|
|
|
|
try:
|
|
# Manually insert encrypted OAuth session
|
|
session_id = f"mcp-oauth-{str(uuid4())[:8]}"
|
|
access_token = "test-access-token"
|
|
refresh_token = "test-refresh-token"
|
|
client_secret = "test-client-secret"
|
|
|
|
encrypted_access = CryptoUtils.encrypt(access_token)
|
|
encrypted_refresh = CryptoUtils.encrypt(refresh_token)
|
|
encrypted_secret = CryptoUtils.encrypt(client_secret)
|
|
|
|
async with db_registry.async_session() as session:
|
|
db_oauth = MCPOAuth(
|
|
id=session_id,
|
|
state=f"test-state-{uuid4().hex[:8]}",
|
|
server_url="https://test.com/mcp",
|
|
server_name="Test Provider",
|
|
access_token_enc=encrypted_access,
|
|
refresh_token_enc=encrypted_refresh,
|
|
client_id="test-client",
|
|
client_secret_enc=encrypted_secret,
|
|
user_id=default_user.id,
|
|
organization_id=default_user.organization_id,
|
|
created_at=datetime.now(timezone.utc),
|
|
updated_at=datetime.now(timezone.utc),
|
|
)
|
|
session.add(db_oauth)
|
|
await session.commit()
|
|
|
|
# Retrieve through manager by ID
|
|
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
|
assert test_session is not None
|
|
|
|
# Tokens should be decrypted
|
|
assert test_session.access_token == access_token
|
|
assert test_session.refresh_token == refresh_token
|
|
assert test_session.client_secret == client_secret
|
|
|
|
# Clean up not needed - test database is reset
|
|
|
|
finally:
|
|
# Restore original encryption key
|
|
settings.encryption_key = original_key
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_update_oauth_session_maintains_encryption(self, server, default_user):
|
|
"""Test that updating OAuth session maintains encryption."""
|
|
# Set encryption key directly on settings
|
|
original_key = settings.encryption_key
|
|
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
|
|
|
|
try:
|
|
# Create initial session (without tokens)
|
|
oauth_session_create = MCPOAuthSessionCreate(
|
|
server_url="https://test.com/mcp",
|
|
server_name="Test Update Provider",
|
|
organization_id=default_user.organization_id,
|
|
user_id=default_user.id,
|
|
)
|
|
|
|
created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user)
|
|
|
|
# Add initial tokens
|
|
initial_update = MCPOAuthSessionUpdate(
|
|
access_token="initial-token",
|
|
refresh_token="initial-refresh",
|
|
client_id="client-123",
|
|
client_secret="initial-secret",
|
|
)
|
|
|
|
await server.mcp_manager.update_oauth_session(created_session.id, initial_update, actor=default_user)
|
|
|
|
# Update with new tokens
|
|
new_access_token = "updated-access-token"
|
|
new_refresh_token = "updated-refresh-token"
|
|
|
|
new_update = MCPOAuthSessionUpdate(
|
|
access_token=new_access_token,
|
|
refresh_token=new_refresh_token,
|
|
)
|
|
|
|
updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user)
|
|
|
|
# Verify update worked
|
|
assert updated_session.access_token == new_access_token
|
|
assert updated_session.refresh_token == new_refresh_token
|
|
|
|
# Check database encryption
|
|
async with db_registry.async_session() as session:
|
|
result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id))
|
|
db_oauth = result.scalar_one()
|
|
|
|
# New tokens should be encrypted
|
|
decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc)
|
|
assert decrypted_access == new_access_token
|
|
|
|
decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc)
|
|
assert decrypted_refresh == new_refresh_token
|
|
|
|
# Clean up not needed - test database is reset
|
|
|
|
finally:
|
|
# Restore original encryption key
|
|
settings.encryption_key = original_key
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_dual_read_backward_compatibility(self, server, default_user):
|
|
"""Test that system can read both encrypted and plaintext values (migration support)."""
|
|
# Set encryption key directly on settings
|
|
original_key = settings.encryption_key
|
|
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
|
|
|
|
try:
|
|
# Insert a record with both encrypted and plaintext values
|
|
session_id = f"mcp-oauth-{str(uuid4())[:8]}"
|
|
plaintext_token = "legacy-plaintext-token"
|
|
new_encrypted_token = "new-encrypted-token"
|
|
encrypted_new = CryptoUtils.encrypt(new_encrypted_token)
|
|
|
|
async with db_registry.async_session() as session:
|
|
db_oauth = MCPOAuth(
|
|
id=session_id,
|
|
state=f"dual-read-state-{uuid4().hex[:8]}",
|
|
server_url="https://test.com/mcp",
|
|
server_name="Dual Read Test",
|
|
# Both encrypted and plaintext values
|
|
access_token=plaintext_token, # Legacy plaintext
|
|
access_token_enc=encrypted_new, # New encrypted
|
|
client_id="test-client",
|
|
user_id=default_user.id,
|
|
organization_id=default_user.organization_id,
|
|
created_at=datetime.now(timezone.utc),
|
|
updated_at=datetime.now(timezone.utc),
|
|
)
|
|
session.add(db_oauth)
|
|
await session.commit()
|
|
|
|
# Retrieve through manager
|
|
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
|
assert test_session is not None
|
|
|
|
# Should prefer encrypted value over plaintext
|
|
assert test_session.access_token == new_encrypted_token
|
|
|
|
# Clean up not needed - test database is reset
|
|
|
|
finally:
|
|
# Restore original encryption key
|
|
settings.encryption_key = original_key
|