Files
letta-server/tests/test_mcp_encryption.py
cthomas 9a95a8f976 fix: duplicate session commit in step logging (#7512)
* fix: duplicate session commit in step logging

* update all callsites
2026-01-12 10:57:19 -08:00

527 lines
23 KiB
Python

"""
Integration tests for MCP server and OAuth session encryption.
Tests the end-to-end encryption functionality in the MCP manager.
"""
import json
import os
from datetime import datetime, timezone
from unittest.mock import AsyncMock, Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy import select
from letta.config import LettaConfig
from letta.helpers.crypto_utils import CryptoUtils
from letta.orm import MCPOAuth, MCPServer as ORMMCPServer
from letta.schemas.mcp import (
MCPOAuthSessionCreate,
MCPOAuthSessionUpdate,
MCPServer as PydanticMCPServer,
MCPServerType,
SSEServerConfig,
StdioServerConfig,
)
from letta.schemas.secret import Secret
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.mcp_manager import MCPManager
from letta.settings import settings
@pytest.fixture(scope="module")
def server():
"""Fixture to create and return a SyncServer instance with MCP manager."""
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=False)
return server
class TestMCPServerEncryption:
"""Test MCP server encryption functionality."""
MOCK_ENCRYPTION_KEY = "test-mcp-encryption-key-123456"
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_token_encryption(self, mock_get_client, server, default_user):
"""Test that MCP server tokens are encrypted when stored."""
# Set encryption key directly on settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
try:
# Mock the MCP client
mock_client = AsyncMock()
mock_client.list_tools.return_value = []
mock_get_client.return_value = mock_client
# Create MCP server with token
server_name = f"test_encrypted_server_{uuid4().hex[:8]}"
token = "super-secret-api-token-12345"
server_url = "https://api.example.com/mcp"
mcp_server = PydanticMCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url, token=token)
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
# Verify server was created
assert created_server.server_name == server_name
assert created_server.server_type == MCPServerType.SSE
# Check database directly to verify encryption
async with db_registry.async_session() as session:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
db_server = result.scalar_one()
# Token should be encrypted in database
assert db_server.token_enc is not None
assert db_server.token_enc != token # Should not be plaintext
# Decrypt to verify correctness
decrypted_token = CryptoUtils.decrypt(db_server.token_enc)
assert decrypted_token == token
# Plaintext column should NOT be written to (encrypted-only)
assert db_server.token is None
# Clean up
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
finally:
# Restore original encryption key
settings.encryption_key = original_key
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_custom_headers_encryption(self, mock_get_client, server, default_user):
"""Test that MCP server custom headers are encrypted when stored."""
# Set encryption key directly on settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
try:
# Mock the MCP client
mock_client = AsyncMock()
mock_client.list_tools.return_value = []
mock_get_client.return_value = mock_client
server_name = f"test_headers_server_{uuid4().hex[:8]}"
custom_headers = {"Authorization": "Bearer secret-token-xyz", "X-API-Key": "api-key-123456", "X-Custom-Header": "custom-value"}
server_url = "https://api.example.com/mcp"
mcp_server = PydanticMCPServer(
server_name=server_name, server_type=MCPServerType.STREAMABLE_HTTP, server_url=server_url, custom_headers=custom_headers
)
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
# Check database directly
async with db_registry.async_session() as session:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
db_server = result.scalar_one()
# Custom headers should be encrypted as JSON
assert db_server.custom_headers_enc is not None
# Decrypt and parse JSON
decrypted_json = CryptoUtils.decrypt(db_server.custom_headers_enc)
decrypted_headers = json.loads(decrypted_json)
assert decrypted_headers == custom_headers
# Clean up
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
finally:
# Restore original encryption key
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_retrieve_mcp_server_decrypts_values(self, server, default_user):
"""Test that retrieving MCP server decrypts encrypted values."""
# Set encryption key directly on settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
try:
# Manually insert encrypted server into database
server_id = f"mcp_server-{uuid4().hex[:8]}"
server_name = f"test_decrypt_server_{uuid4().hex[:8]}"
plaintext_token = "decryption-test-token"
encrypted_token = CryptoUtils.encrypt(plaintext_token)
async with db_registry.async_session() as session:
db_server = ORMMCPServer(
id=server_id,
server_name=server_name,
server_type=MCPServerType.SSE.value,
server_url="https://test.com",
token_enc=encrypted_token,
token=None, # No plaintext
created_by_id=default_user.id,
last_updated_by_id=default_user.id,
organization_id=default_user.organization_id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(db_server)
# context manager now handles commits
# await session.commit()
# Retrieve server directly by ID to avoid issues with other servers in DB
test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user)
assert test_server is not None
assert test_server.server_name == server_name
# Token should be decrypted when accessed via the _enc column
assert test_server.token_enc is not None
assert test_server.token_enc.get_plaintext() == plaintext_token
# Clean up
async with db_registry.async_session() as session:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id))
db_server = result.scalar_one()
await session.delete(db_server)
# context manager now handles commits
# await session.commit()
finally:
# Restore original encryption key
settings.encryption_key = original_key
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_without_encryption_key_stores_plaintext(self, mock_get_client, server, default_user):
"""Test that MCP servers work without encryption key by storing plaintext in _enc column.
Note: In Phase 1 of migration, if no encryption key is configured, the value
is stored as plaintext directly in the _enc column. This allows users without
encryption keys to continue working while migrating off the old plaintext columns.
"""
# Save and clear encryption key
original_key = settings.encryption_key
settings.encryption_key = None
try:
# Mock the MCP client
mock_client = AsyncMock()
mock_client.list_tools.return_value = []
mock_get_client.return_value = mock_client
server_name = f"test_no_encrypt_server_{uuid4().hex[:8]}"
token = "plaintext-token-no-encryption"
mcp_server = PydanticMCPServer(
server_name=server_name, server_type=MCPServerType.SSE, server_url="https://api.example.com", token=token
)
# Should work without encryption key - stores plaintext in _enc column
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
# Check database - should store plaintext in _enc column (no encryption key)
async with db_registry.async_session() as session:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
db_server = result.scalar_one()
# Token should be stored as plaintext in _enc column (not encrypted)
assert db_server.token_enc == token # Plaintext stored directly
# Plaintext column should NOT be written to (encrypted-only)
assert db_server.token is None
# Clean up
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
finally:
# Restore original encryption key
settings.encryption_key = original_key
class TestMCPOAuthEncryption:
"""Test MCP OAuth session encryption functionality."""
MOCK_ENCRYPTION_KEY = "test-oauth-encryption-key-123456"
@pytest.mark.asyncio
async def test_create_oauth_session_with_encryption(self, server, default_user):
"""Test that OAuth tokens are encrypted when stored."""
# Set encryption key directly on settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
try:
server_url = "https://github.com/mcp"
server_name = "GitHub MCP"
# Step 1: Create OAuth session (without tokens initially)
oauth_session_create = MCPOAuthSessionCreate(
server_url=server_url,
server_name=server_name,
organization_id=default_user.organization_id,
user_id=default_user.id,
)
created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user)
assert created_session.server_url == server_url
assert created_session.server_name == server_name
# Step 2: Update session with tokens (simulating OAuth callback)
update_data = MCPOAuthSessionUpdate(
access_token="github-access-token-abc123",
refresh_token="github-refresh-token-xyz789",
client_id="client-id-123",
client_secret="client-secret-super-secret",
expires_at=datetime.now(timezone.utc),
)
await server.mcp_manager.update_oauth_session(created_session.id, update_data, actor=default_user)
# Check database directly for encryption
async with db_registry.async_session() as session:
result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id))
db_oauth = result.scalar_one()
# All sensitive fields should be encrypted
assert db_oauth.access_token_enc is not None
assert db_oauth.access_token_enc != update_data.access_token
assert db_oauth.refresh_token_enc is not None
assert db_oauth.client_secret_enc is not None
# Verify decryption
decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc)
assert decrypted_access == update_data.access_token
decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc)
assert decrypted_refresh == update_data.refresh_token
decrypted_secret = CryptoUtils.decrypt(db_oauth.client_secret_enc)
assert decrypted_secret == update_data.client_secret
finally:
# Restore original encryption key
settings.encryption_key = original_key
# Clean up not needed - test database is reset
@pytest.mark.asyncio
async def test_retrieve_oauth_session_decrypts_tokens(self, server, default_user):
"""Test that retrieving OAuth session decrypts tokens."""
# Set encryption key directly on settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
try:
# Manually insert encrypted OAuth session
session_id = f"mcp-oauth-{str(uuid4())[:8]}"
access_token = "test-access-token"
refresh_token = "test-refresh-token"
client_secret = "test-client-secret"
encrypted_access = CryptoUtils.encrypt(access_token)
encrypted_refresh = CryptoUtils.encrypt(refresh_token)
encrypted_secret = CryptoUtils.encrypt(client_secret)
async with db_registry.async_session() as session:
db_oauth = MCPOAuth(
id=session_id,
state=f"test-state-{uuid4().hex[:8]}",
server_url="https://test.com/mcp",
server_name="Test Provider",
access_token_enc=encrypted_access,
refresh_token_enc=encrypted_refresh,
client_id="test-client",
client_secret_enc=encrypted_secret,
user_id=default_user.id,
organization_id=default_user.organization_id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(db_oauth)
# context manager now handles commits
# await session.commit()
# Retrieve through manager by ID
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
assert test_session is not None
# Tokens should be decrypted from _enc columns
assert test_session.access_token_enc is not None
assert test_session.access_token_enc.get_plaintext() == access_token
assert test_session.refresh_token_enc is not None
assert test_session.refresh_token_enc.get_plaintext() == refresh_token
assert test_session.client_secret_enc is not None
assert test_session.client_secret_enc.get_plaintext() == client_secret
# Clean up not needed - test database is reset
finally:
# Restore original encryption key
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_update_oauth_session_maintains_encryption(self, server, default_user):
"""Test that updating OAuth session maintains encryption."""
# Set encryption key directly on settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
try:
# Create initial session (without tokens)
oauth_session_create = MCPOAuthSessionCreate(
server_url="https://test.com/mcp",
server_name="Test Update Provider",
organization_id=default_user.organization_id,
user_id=default_user.id,
)
created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user)
# Add initial tokens
initial_update = MCPOAuthSessionUpdate(
access_token="initial-token",
refresh_token="initial-refresh",
client_id="client-123",
client_secret="initial-secret",
)
await server.mcp_manager.update_oauth_session(created_session.id, initial_update, actor=default_user)
# Update with new tokens
new_access_token = "updated-access-token"
new_refresh_token = "updated-refresh-token"
new_update = MCPOAuthSessionUpdate(
access_token=new_access_token,
refresh_token=new_refresh_token,
)
updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user)
# Verify update worked - read from _enc columns
assert updated_session.access_token_enc is not None
assert updated_session.access_token_enc.get_plaintext() == new_access_token
assert updated_session.refresh_token_enc is not None
assert updated_session.refresh_token_enc.get_plaintext() == new_refresh_token
# Check database encryption
async with db_registry.async_session() as session:
result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id))
db_oauth = result.scalar_one()
# New tokens should be encrypted
decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc)
assert decrypted_access == new_access_token
decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc)
assert decrypted_refresh == new_refresh_token
# Clean up not needed - test database is reset
finally:
# Restore original encryption key
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_encrypted_only_reads(self, server, default_user):
"""Test that system only reads from encrypted columns, ignoring plaintext.
Note: In Phase 1 of migration, reads are encrypted-only. Plaintext columns
are ignored even if they contain values. This test verifies that the
encrypted value is used and plaintext is never used as fallback.
"""
# Set encryption key directly on settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
try:
# Insert a record with both encrypted and plaintext values
session_id = f"mcp-oauth-{str(uuid4())[:8]}"
plaintext_token = "legacy-plaintext-token"
new_encrypted_token = "new-encrypted-token"
encrypted_new = CryptoUtils.encrypt(new_encrypted_token)
async with db_registry.async_session() as session:
db_oauth = MCPOAuth(
id=session_id,
state=f"dual-read-state-{uuid4().hex[:8]}",
server_url="https://test.com/mcp",
server_name="Encrypted Only Read Test",
# Both encrypted and plaintext values
access_token=plaintext_token, # Legacy plaintext - should be ignored
access_token_enc=encrypted_new, # Encrypted value - should be used
client_id="test-client",
user_id=default_user.id,
organization_id=default_user.organization_id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(db_oauth)
# context manager now handles commits
# await session.commit()
# Retrieve through manager
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
assert test_session is not None
# Should read from encrypted column only (plaintext is ignored)
assert test_session.access_token_enc is not None
assert test_session.access_token_enc.get_plaintext() == new_encrypted_token
# Clean up not needed - test database is reset
finally:
# Restore original encryption key
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_plaintext_only_record_returns_none(self, server, default_user):
"""Test that records with only plaintext values return None for encrypted fields.
With encrypted-only migration complete, if a record only has plaintext value
(no encrypted value), the system returns None for that field since we only
read from _enc columns now.
"""
# Set encryption key directly on settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
try:
# Insert a record with only plaintext value (no encrypted)
session_id = f"mcp-oauth-{str(uuid4())[:8]}"
plaintext_token = "legacy-plaintext-token"
async with db_registry.async_session() as session:
db_oauth = MCPOAuth(
id=session_id,
state=f"plaintext-only-state-{uuid4().hex[:8]}",
server_url="https://test.com/mcp",
server_name="Plaintext Only Test",
# Only plaintext value, no encrypted
access_token=plaintext_token, # Legacy plaintext - should be ignored
access_token_enc=None, # No encrypted value
client_id="test-client",
user_id=default_user.id,
organization_id=default_user.organization_id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(db_oauth)
# context manager now handles commits
# await session.commit()
# Retrieve through manager
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
assert test_session is not None
# Should return None since we only read from _enc columns now
assert test_session.access_token_enc is None
# Clean up not needed - test database is reset
finally:
# Restore original encryption key
settings.encryption_key = original_key