refactor: migrate mcp_servers and mcp_oauth to encrypted-only columns (#6751)

* refactor: migrate mcp_servers and mcp_oauth to encrypted-only columns

Complete migration to encrypted-only storage for sensitive fields:

- Remove dual-write to plaintext columns (token, custom_headers,
  authorization_code, access_token, refresh_token, client_secret)
- Read only from _enc columns, not from plaintext fallback
- Remove helper methods (get_token_secret, set_token_secret, etc.)
- Remove Secret.from_db() and Secret.to_dict() methods
- Update tests to verify encrypted-only behavior

After this change, plaintext columns can be set to NULL manually
since they are no longer read from or written to.

* fix test

* rename

* update

* union

* fix test
This commit is contained in:
jnjpng
2025-12-15 17:59:53 -08:00
committed by Caren Thomas
parent 03a41f8e8d
commit 00ba2d09f3
8 changed files with 176 additions and 327 deletions

View File

@@ -1,3 +1,4 @@
import json
from typing import TYPE_CHECKING, Optional
from sqlalchemy import JSON, String, Text, UniqueConstraint
@@ -11,6 +12,7 @@ from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.enums import MCPServerType
from letta.schemas.mcp import MCPServer
from letta.schemas.secret import Secret
if TYPE_CHECKING:
from letta.orm.organization import Organization
@@ -60,6 +62,23 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="mcp_servers")
def to_pydantic(self):
"""Convert ORM model to Pydantic model, handling encrypted fields."""
# Parse custom_headers from JSON if stored as string
return self.__pydantic_model__(
id=self.id,
server_type=self.server_type,
server_name=self.server_name,
server_url=self.server_url,
token_enc=Secret.from_encrypted(self.token_enc) if self.token_enc else None,
custom_headers_enc=Secret.from_encrypted(self.custom_headers_enc) if self.custom_headers_enc else None,
stdio_config=self.stdio_config,
organization_id=self.organization_id,
created_by_id=self.created_by_id,
last_updated_by_id=self.last_updated_by_id,
metadata_=self.metadata_,
)
class MCPTools(SqlalchemyBase, OrganizationMixin):
"""Represents a mapping of MCP server ID to tool ID"""

View File

@@ -1,9 +1,12 @@
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from pydantic import Field
logger = logging.getLogger(__name__)
from letta.functions.mcp_client.types import (
MCP_AUTH_HEADER_AUTHORIZATION,
MCP_AUTH_TOKEN_BEARER_PREFIX,
@@ -48,68 +51,50 @@ class MCPServer(BaseMCPServer):
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
def get_token_secret(self) -> Secret:
"""Get the token as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
if self.token_enc is not None:
return self.token_enc
# Fallback to plaintext with error logging via Secret.from_db()
return Secret.from_db(encrypted_value=None, plaintext_value=self.token)
def get_token_secret(self) -> Optional[Secret]:
"""Get the token as a Secret object."""
return self.token_enc
def get_custom_headers_secret(self) -> Secret:
"""Get custom headers as a Secret object (stores JSON string). Prefers encrypted, falls back to plaintext with error logging."""
if self.custom_headers_enc is not None:
return self.custom_headers_enc
# Fallback to plaintext with error logging via Secret.from_db()
# Convert dict to JSON string for Secret storage
plaintext_json = json.dumps(self.custom_headers) if self.custom_headers else None
return Secret.from_db(encrypted_value=None, plaintext_value=plaintext_json)
def get_custom_headers_secret(self) -> Optional[Secret]:
"""Get the custom headers as a Secret object (JSON string)."""
return self.custom_headers_enc
def get_custom_headers_dict(self) -> Optional[Dict[str, str]]:
"""Get custom headers as a plaintext dictionary."""
secret = self.get_custom_headers_secret()
json_str = secret.get_plaintext()
if json_str:
try:
return json.loads(json_str)
except (json.JSONDecodeError, TypeError):
return None
"""Get the custom headers as a dictionary."""
if self.custom_headers_enc:
json_str = self.custom_headers_enc.get_plaintext()
if json_str:
try:
return json.loads(json_str)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}")
return None
def set_token_secret(self, secret: Secret) -> None:
"""Set token from a Secret object, updating both encrypted and plaintext fields."""
"""Set token from a Secret object."""
self.token_enc = secret
secret_dict = secret.to_dict()
# Only set plaintext during migration phase
if not secret.was_encrypted:
self.token = secret_dict["plaintext"]
else:
self.token = None
def set_custom_headers_secret(self, secret: Secret) -> None:
"""Set custom headers from a Secret object (containing JSON string), updating both fields."""
"""Set custom headers from a Secret object (JSON string)."""
self.custom_headers_enc = secret
secret_dict = secret.to_dict()
# Parse JSON string to dict for plaintext field
json_str = secret_dict.get("plaintext")
if json_str and not secret.was_encrypted:
try:
self.custom_headers = json.loads(json_str)
except (json.JSONDecodeError, TypeError):
self.custom_headers = None
else:
self.custom_headers = None
def to_config(
self,
environment_variables: Optional[Dict[str, str]] = None,
resolve_variables: bool = True,
) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
# Get decrypted values for use in config
token_secret = self.get_token_secret()
token_plaintext = token_secret.get_plaintext()
# Get decrypted values directly from encrypted columns
token_plaintext = self.token_enc.get_plaintext() if self.token_enc else None
# Get custom headers as dict
headers_plaintext = self.get_custom_headers_dict()
# Get custom headers as dict from encrypted column
headers_plaintext = None
if self.custom_headers_enc:
json_str = self.custom_headers_enc.get_plaintext()
if json_str:
try:
headers_plaintext = json.loads(json_str)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}")
if self.server_type == MCPServerType.SSE:
config = SSEServerConfig(
@@ -228,66 +213,6 @@ class MCPOAuthSession(BaseMCPOAuth):
created_at: datetime = Field(default_factory=datetime.now, description="Session creation time")
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
def get_access_token_secret(self) -> Secret:
"""Get the access token as a Secret object, preferring encrypted over plaintext."""
if self.access_token_enc is not None:
return self.access_token_enc
return Secret.from_db(None, self.access_token)
def get_refresh_token_secret(self) -> Secret:
"""Get the refresh token as a Secret object, preferring encrypted over plaintext."""
if self.refresh_token_enc is not None:
return self.refresh_token_enc
return Secret.from_db(None, self.refresh_token)
def get_client_secret_secret(self) -> Secret:
"""Get the client secret as a Secret object, preferring encrypted over plaintext."""
if self.client_secret_enc is not None:
return self.client_secret_enc
return Secret.from_db(None, self.client_secret)
def get_authorization_code_secret(self) -> Secret:
"""Get the authorization code as a Secret object, preferring encrypted over plaintext."""
if self.authorization_code_enc is not None:
return self.authorization_code_enc
return Secret.from_db(None, self.authorization_code)
def set_access_token_secret(self, secret: Secret) -> None:
"""Set access token from a Secret object."""
self.access_token_enc = secret
secret_dict = secret.to_dict()
if not secret.was_encrypted:
self.access_token = secret_dict["plaintext"]
else:
self.access_token = None
def set_refresh_token_secret(self, secret: Secret) -> None:
"""Set refresh token from a Secret object."""
self.refresh_token_enc = secret
secret_dict = secret.to_dict()
if not secret.was_encrypted:
self.refresh_token = secret_dict["plaintext"]
else:
self.refresh_token = None
def set_client_secret_secret(self, secret: Secret) -> None:
"""Set client secret from a Secret object."""
self.client_secret_enc = secret
secret_dict = secret.to_dict()
if not secret.was_encrypted:
self.client_secret = secret_dict["plaintext"]
else:
self.client_secret = None
def set_authorization_code_secret(self, secret: Secret) -> None:
"""Set authorization code from a Secret object."""
self.authorization_code_enc = secret
secret_dict = secret.to_dict()
if not secret.was_encrypted:
self.authorization_code = secret_dict["plaintext"]
else:
self.authorization_code = None
class MCPOAuthSessionCreate(BaseMCPOAuth):
"""Create a new OAuth session."""

View File

@@ -319,24 +319,30 @@ def convert_generic_to_union(server) -> MCPServerUnion:
env=server.stdio_config.env if server.stdio_config else None,
)
elif server.server_type == MCPServerType.SSE:
# Get decrypted values from encrypted columns
token = server.token_enc.get_plaintext() if server.token_enc else None
headers = server.get_custom_headers_dict()
return SSEMCPServer(
id=server.id,
server_name=server.server_name,
mcp_server_type=MCPServerType.SSE,
server_url=server.server_url,
auth_header="Authorization" if server.token else None,
auth_token=f"Bearer {server.token}" if server.token else None,
custom_headers=server.custom_headers,
auth_header="Authorization" if token else None,
auth_token=f"Bearer {token}" if token else None,
custom_headers=headers,
)
elif server.server_type == MCPServerType.STREAMABLE_HTTP:
# Get decrypted values from encrypted columns
token = server.token_enc.get_plaintext() if server.token_enc else None
headers = server.get_custom_headers_dict()
return StreamableHTTPMCPServer(
id=server.id,
server_name=server.server_name,
mcp_server_type=MCPServerType.STREAMABLE_HTTP,
server_url=server.server_url,
auth_header="Authorization" if server.token else None,
auth_token=f"Bearer {server.token}" if server.token else None,
custom_headers=server.custom_headers,
auth_header="Authorization" if token else None,
auth_token=f"Bearer {token}" if token else None,
custom_headers=headers,
)
else:
raise ValueError(f"Unknown server type: {server.server_type}")

View File

@@ -1,4 +1,3 @@
import json
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, PrivateAttr
@@ -17,22 +16,17 @@ class Secret(BaseModel):
This class ensures that sensitive data remains encrypted as much as possible
while passing through the codebase, only decrypting when absolutely necessary.
Migration status (Phase 1 - encrypted-first reads with plaintext fallback):
- Reads: Prefer _enc columns, fallback to plaintext columns with ERROR logging
- Writes: Still dual-write to both _enc and plaintext columns for backward compatibility
- Encryption: Optional - if LETTA_ENCRYPTION_KEY is not set, stores plaintext in _enc column
TODO (Phase 2): Remove plaintext fallback in from_db() after verifying no error logs
TODO (Phase 3): Remove dual-write logic in to_dict() and set_*_secret() methods
TODO (Phase 4): Remove from_db() plaintext_value parameter, was_encrypted flag, and plaintext columns
Usage:
- Create from plaintext: Secret.from_plaintext(value)
- Create from encrypted DB value: Secret.from_encrypted(encrypted_value)
- Get encrypted for storage: secret.get_encrypted()
- Get plaintext when needed: secret.get_plaintext()
"""
# Store the encrypted value as a regular field
encrypted_value: Optional[str] = None
# Cache the decrypted value to avoid repeated decryption (not serialized for security)
_plaintext_cache: Optional[str] = PrivateAttr(default=None)
# Flag to indicate if the value was originally encrypted
was_encrypted: bool = False
model_config = ConfigDict(frozen=True)
@@ -51,7 +45,7 @@ class Secret(BaseModel):
A Secret instance with the encrypted (or plaintext) value
"""
if value is None:
return cls.model_construct(encrypted_value=None, was_encrypted=False)
return cls.model_construct(encrypted_value=None)
# Guard against double encryption - check if value is already encrypted
if CryptoUtils.is_encrypted(value):
@@ -60,7 +54,7 @@ class Secret(BaseModel):
# Try to encrypt, but fall back to storing plaintext if no encryption key
try:
encrypted = CryptoUtils.encrypt(value)
return cls.model_construct(encrypted_value=encrypted, was_encrypted=False)
return cls.model_construct(encrypted_value=encrypted)
except ValueError as e:
# No encryption key available, store as plaintext in the _enc column
if "No encryption key configured" in str(e):
@@ -68,7 +62,7 @@ class Secret(BaseModel):
"No encryption key configured. Storing Secret value as plaintext in _enc column. "
"Set LETTA_ENCRYPTION_KEY environment variable to enable encryption."
)
instance = cls.model_construct(encrypted_value=value, was_encrypted=False)
instance = cls.model_construct(encrypted_value=value)
instance._plaintext_cache = value # Cache it since we know the plaintext
return instance
raise # Re-raise if it's a different error
@@ -76,47 +70,15 @@ class Secret(BaseModel):
@classmethod
def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret":
"""
Create a Secret from an already encrypted value.
Create a Secret from an already encrypted value (read from DB).
Args:
encrypted_value: The encrypted value
encrypted_value: The encrypted value from the _enc column
Returns:
A Secret instance
"""
return cls.model_construct(encrypted_value=encrypted_value, was_encrypted=True)
@classmethod
def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[str] = None) -> "Secret":
"""
Create a Secret from database values. Prefers encrypted column, falls back to plaintext with error logging.
During Phase 1 of migration, this method:
1. Uses encrypted_value if available (preferred)
2. Falls back to plaintext_value with ERROR logging if encrypted is unavailable
3. Returns empty Secret if neither is available
The error logging helps identify any records that haven't been migrated to encrypted columns.
Args:
encrypted_value: The encrypted value from the database (_enc column)
plaintext_value: The plaintext value from the database (legacy column, fallback only)
Returns:
A Secret instance with the value from encrypted or plaintext column
"""
if encrypted_value is not None:
return cls.from_encrypted(encrypted_value)
# Fallback to plaintext with error logging - this helps identify unmigrated data
if plaintext_value is not None:
logger.error(
"MIGRATION_NEEDED: Reading from plaintext column instead of encrypted column. "
"This indicates data that hasn't been migrated to the _enc column yet. "
"Please run migrate data to _enc columns as plaintext columns will be deprecated.",
stack_info=True,
)
return cls.from_plaintext(plaintext_value)
return cls.from_plaintext(None)
return cls.model_construct(encrypted_value=encrypted_value)
def get_encrypted(self) -> Optional[str]:
"""
@@ -146,14 +108,8 @@ class Secret(BaseModel):
if self.encrypted_value is None:
return None
# Use cached value if available, but only if it looks like plaintext
# or we're confident we can decrypt it
# Use cached value if available
if self._plaintext_cache is not None:
# If this was explicitly created as plaintext, trust the cache
# This prevents false positives from is_encrypted() heuristic
if not self.was_encrypted:
return self._plaintext_cache
# For encrypted values, trust the cache (already decrypted previously)
return self._plaintext_cache
# Try to decrypt
@@ -265,14 +221,6 @@ class Secret(BaseModel):
"""Representation that doesn't expose the actual value."""
return self.__str__()
def to_dict(self) -> Dict[str, Any]:
"""
Convert to dictionary for database storage.
Returns both encrypted and plaintext values for dual-write during migration.
"""
return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self.was_encrypted else None}
def __eq__(self, other: Any) -> bool:
"""
Compare two secrets by their plaintext values.

View File

@@ -37,14 +37,12 @@ class DatabaseTokenStorage(TokenStorage):
if not oauth_session:
return None
# Decrypt tokens using getter methods
access_token_secret = oauth_session.get_access_token_secret()
access_token = access_token_secret.get_plaintext()
# Read tokens directly from _enc columns
access_token = oauth_session.access_token_enc.get_plaintext() if oauth_session.access_token_enc else None
if not access_token:
return None
refresh_token_secret = oauth_session.get_refresh_token_secret()
refresh_token = refresh_token_secret.get_plaintext()
refresh_token = oauth_session.refresh_token_enc.get_plaintext() if oauth_session.refresh_token_enc else None
return OAuthToken(
access_token=access_token,
@@ -72,9 +70,8 @@ class DatabaseTokenStorage(TokenStorage):
if not oauth_session or not oauth_session.client_id:
return None
# Decrypt client secret using getter method
client_secret_secret = oauth_session.get_client_secret_secret()
client_secret = client_secret_secret.get_plaintext()
# Read client secret directly from _enc column
client_secret = oauth_session.client_secret_enc.get_plaintext() if oauth_session.client_secret_enc else None
return OAuthClientInformationFull(
client_id=oauth_session.client_id,
@@ -147,19 +144,15 @@ class MCPOAuthSession:
async def store_authorization_code(self, code: str, state: str) -> Optional[MCPOAuth]:
"""Store the authorization code from OAuth callback."""
# Use mcp_manager to ensure proper encryption
from letta.schemas.mcp import MCPOAuthSessionUpdate
from letta.schemas.secret import Secret
async with db_registry.async_session() as session:
try:
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
# Encrypt the authorization_code before storing
# Encrypt the authorization_code and store only in _enc column
if code is not None:
oauth_record.authorization_code_enc = Secret.from_plaintext(code).get_encrypted()
# Keep plaintext for dual-write during migration
oauth_record.authorization_code = code
oauth_record.status = OAuthSessionStatus.AUTHORIZED
oauth_record.state = state
@@ -234,10 +227,10 @@ async def create_oauth_provider(
logger.info(f"Waiting for authorization code for session {session_id}")
while time.time() - start_time < timeout:
oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor)
if oauth_session and oauth_session.authorization_code:
# Decrypt the authorization code before returning
auth_code_secret = oauth_session.get_authorization_code_secret()
return auth_code_secret.get_plaintext(), oauth_session.state
if oauth_session and oauth_session.authorization_code_enc:
# Read authorization code directly from _enc column
auth_code = oauth_session.authorization_code_enc.get_plaintext()
return auth_code, oauth_session.state
elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR:
raise Exception("OAuth authorization failed")
await asyncio.sleep(1)

View File

@@ -419,16 +419,14 @@ class MCPManager:
server_type=server_config.type,
server_url=server_config.server_url,
)
# Encrypt sensitive fields
# Encrypt sensitive fields - write only to _enc columns
token = server_config.resolve_token()
if token:
token_secret = Secret.from_plaintext(token)
mcp_server.set_token_secret(token_secret)
mcp_server.token_enc = Secret.from_plaintext(token)
if server_config.custom_headers:
# Convert dict to JSON string, then encrypt as Secret
headers_json = json.dumps(server_config.custom_headers)
headers_secret = Secret.from_plaintext(headers_json)
mcp_server.set_custom_headers_secret(headers_secret)
mcp_server.custom_headers_enc = Secret.from_plaintext(headers_json)
elif isinstance(server_config, StreamableHTTPServerConfig):
mcp_server = MCPServer(
@@ -436,16 +434,14 @@ class MCPManager:
server_type=server_config.type,
server_url=server_config.server_url,
)
# Encrypt sensitive fields
# Encrypt sensitive fields - write only to _enc columns
token = server_config.resolve_token()
if token:
token_secret = Secret.from_plaintext(token)
mcp_server.set_token_secret(token_secret)
mcp_server.token_enc = Secret.from_plaintext(token)
if server_config.custom_headers:
# Convert dict to JSON string, then encrypt as Secret
headers_json = json.dumps(server_config.custom_headers)
headers_secret = Secret.from_plaintext(headers_json)
mcp_server.set_custom_headers_secret(headers_secret)
mcp_server.custom_headers_enc = Secret.from_plaintext(headers_json)
else:
raise ValueError(f"Unsupported server config type: {type(server_config)}")
@@ -539,57 +535,44 @@ class MCPManager:
# Update tool attributes with only the fields that were explicitly set
update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True)
# Handle encryption for token if provided
# Only re-encrypt if the value has actually changed
# Handle encryption for token if provided - write only to _enc column
if "token" in update_data and update_data["token"] is not None:
# Check if value changed
# Check if value changed by reading from _enc column only
existing_token = None
if mcp_server.token_enc:
existing_secret = Secret.from_encrypted(mcp_server.token_enc)
existing_token = existing_secret.get_plaintext()
elif mcp_server.token:
existing_token = mcp_server.token
# Only re-encrypt if different
if existing_token != update_data["token"]:
mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted()
# Keep plaintext for dual-write during migration
mcp_server.token = update_data["token"]
# Remove from update_data since we set directly on mcp_server
update_data.pop("token", None)
update_data.pop("token_enc", None)
# Handle encryption for custom_headers if provided
# Only re-encrypt if the value has actually changed
# Handle encryption for custom_headers if provided - write only to _enc column
if "custom_headers" in update_data:
if update_data["custom_headers"] is not None:
# custom_headers is a Dict[str, str], serialize to JSON then encrypt
import json
json_str = json.dumps(update_data["custom_headers"])
# Check if value changed
# Check if value changed by reading from _enc column only
existing_headers_json = None
if mcp_server.custom_headers_enc:
existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc)
existing_headers_json = existing_secret.get_plaintext()
elif mcp_server.custom_headers:
existing_headers_json = json.dumps(mcp_server.custom_headers)
# Only re-encrypt if different
if existing_headers_json != json_str:
mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted()
# Keep plaintext for dual-write during migration
mcp_server.custom_headers = update_data["custom_headers"]
# Remove from update_data since we set directly on mcp_server
update_data.pop("custom_headers", None)
update_data.pop("custom_headers_enc", None)
else:
# Ensure custom_headers None is stored as SQL NULL, not JSON null
# Ensure custom_headers_enc None is stored as SQL NULL
update_data.pop("custom_headers", None)
setattr(mcp_server, "custom_headers", null())
setattr(mcp_server, "custom_headers_enc", None)
for key, value in update_data.items():
@@ -810,8 +793,8 @@ class MCPManager:
# If no OAuth provider is provided, check if we have stored OAuth credentials
if oauth_provider is None and hasattr(server_config, "server_url"):
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
# Check if access token exists by attempting to decrypt it
if oauth_session and oauth_session.get_access_token_secret().get_plaintext():
# Check if access token exists by reading directly from _enc column
if oauth_session and oauth_session.access_token_enc and oauth_session.access_token_enc.get_plaintext():
# Create OAuth provider from stored credentials
from letta.services.mcp.oauth_utils import create_oauth_provider
@@ -838,29 +821,23 @@ class MCPManager:
# OAuth-related methods
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
"""
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
Note: Prefers encrypted columns (_enc fields), falls back to plaintext with error logging.
This helps identify unmigrated data during the migration period.
Convert OAuth ORM model to Pydantic model, reading directly from encrypted columns.
"""
# Get decrypted values - prefer encrypted, fallback to plaintext with error logging
access_token = Secret.from_db(
encrypted_value=oauth_session.access_token_enc, plaintext_value=oauth_session.access_token
).get_plaintext()
# Convert encrypted columns to Secret objects
authorization_code_enc = (
Secret.from_encrypted(oauth_session.authorization_code_enc) if oauth_session.authorization_code_enc else None
)
access_token_enc = Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None
refresh_token_enc = Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None
client_secret_enc = Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None
refresh_token = Secret.from_db(
encrypted_value=oauth_session.refresh_token_enc, plaintext_value=oauth_session.refresh_token
).get_plaintext()
# Get plaintext values from encrypted columns (primary source of truth)
authorization_code = authorization_code_enc.get_plaintext() if authorization_code_enc else None
access_token = access_token_enc.get_plaintext() if access_token_enc else None
refresh_token = refresh_token_enc.get_plaintext() if refresh_token_enc else None
client_secret = client_secret_enc.get_plaintext() if client_secret_enc else None
client_secret = Secret.from_db(
encrypted_value=oauth_session.client_secret_enc, plaintext_value=oauth_session.client_secret
).get_plaintext()
authorization_code = Secret.from_db(
encrypted_value=oauth_session.authorization_code_enc, plaintext_value=oauth_session.authorization_code
).get_plaintext()
# Create the Pydantic object with encrypted fields as Secret objects
# Create the Pydantic object with both encrypted and plaintext fields
pydantic_session = MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
@@ -870,25 +847,24 @@ class MCPManager:
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
authorization_url=oauth_session.authorization_url,
authorization_code=authorization_code,
access_token=access_token,
refresh_token=refresh_token,
token_type=oauth_session.token_type,
expires_at=oauth_session.expires_at,
scope=oauth_session.scope,
client_id=oauth_session.client_id,
client_secret=client_secret,
redirect_uri=oauth_session.redirect_uri,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
# Encrypted fields as Secret objects (converted from encrypted strings in DB)
authorization_code_enc=Secret.from_encrypted(oauth_session.authorization_code_enc)
if oauth_session.authorization_code_enc
else None,
access_token_enc=Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None,
refresh_token_enc=Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None,
client_secret_enc=Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None,
# Plaintext fields populated from encrypted columns
authorization_code=authorization_code,
access_token=access_token,
refresh_token=refresh_token,
client_secret=client_secret,
# Encrypted fields as Secret objects
authorization_code_enc=authorization_code_enc,
access_token_enc=access_token_enc,
refresh_token_enc=refresh_token_enc,
client_secret_enc=client_secret_enc,
)
return pydantic_session
@@ -957,56 +933,41 @@ class MCPManager:
if session_update.authorization_url is not None:
oauth_session.authorization_url = session_update.authorization_url
# Handle encryption for authorization_code
# Only re-encrypt if the value has actually changed
# Handle encryption for authorization_code - write only to _enc column
if session_update.authorization_code is not None:
# Check if value changed
# Check if value changed by reading from _enc column only
existing_code = None
if oauth_session.authorization_code_enc:
existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
existing_code = existing_secret.get_plaintext()
elif oauth_session.authorization_code:
existing_code = oauth_session.authorization_code
# Only re-encrypt if different
if existing_code != session_update.authorization_code:
oauth_session.authorization_code_enc = Secret.from_plaintext(session_update.authorization_code).get_encrypted()
# Keep plaintext for dual-write during migration
oauth_session.authorization_code = session_update.authorization_code
# Handle encryption for access_token
# Only re-encrypt if the value has actually changed
# Handle encryption for access_token - write only to _enc column
if session_update.access_token is not None:
# Check if value changed
# Check if value changed by reading from _enc column only
existing_token = None
if oauth_session.access_token_enc:
existing_secret = Secret.from_encrypted(oauth_session.access_token_enc)
existing_token = existing_secret.get_plaintext()
elif oauth_session.access_token:
existing_token = oauth_session.access_token
# Only re-encrypt if different
if existing_token != session_update.access_token:
oauth_session.access_token_enc = Secret.from_plaintext(session_update.access_token).get_encrypted()
# Keep plaintext for dual-write during migration
oauth_session.access_token = session_update.access_token
# Handle encryption for refresh_token
# Only re-encrypt if the value has actually changed
# Handle encryption for refresh_token - write only to _enc column
if session_update.refresh_token is not None:
# Check if value changed
# Check if value changed by reading from _enc column only
existing_refresh = None
if oauth_session.refresh_token_enc:
existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
existing_refresh = existing_secret.get_plaintext()
elif oauth_session.refresh_token:
existing_refresh = oauth_session.refresh_token
# Only re-encrypt if different
if existing_refresh != session_update.refresh_token:
oauth_session.refresh_token_enc = Secret.from_plaintext(session_update.refresh_token).get_encrypted()
# Keep plaintext for dual-write during migration
oauth_session.refresh_token = session_update.refresh_token
if session_update.token_type is not None:
oauth_session.token_type = session_update.token_type
@@ -1017,22 +978,17 @@ class MCPManager:
if session_update.client_id is not None:
oauth_session.client_id = session_update.client_id
# Handle encryption for client_secret
# Only re-encrypt if the value has actually changed
# Handle encryption for client_secret - write only to _enc column
if session_update.client_secret is not None:
# Check if value changed
# Check if value changed by reading from _enc column only
existing_secret_val = None
if oauth_session.client_secret_enc:
existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
existing_secret_val = existing_secret.get_plaintext()
elif oauth_session.client_secret:
existing_secret_val = oauth_session.client_secret
# Only re-encrypt if different
if existing_secret_val != session_update.client_secret:
oauth_session.client_secret_enc = Secret.from_plaintext(session_update.client_secret).get_encrypted()
# Keep plaintext for dual-write during migration
oauth_session.client_secret = session_update.client_secret
if session_update.redirect_uri is not None:
oauth_session.redirect_uri = session_update.redirect_uri

View File

@@ -943,12 +943,13 @@ async def test_mcp_server_token_encryption_on_create(server, default_user, encry
assert created_server is not None
assert created_server.server_name == "test-encrypted-server"
# Verify plaintext token is accessible (dual-write during migration)
assert created_server.token == "sk-test-secret-token-12345"
# Verify plaintext token field is NOT set (no dual-write)
assert created_server.token is None
# Verify token_enc is a Secret object
# Verify token_enc is a Secret object and decrypts correctly
assert created_server.token_enc is not None
assert isinstance(created_server.token_enc, Secret)
assert created_server.token_enc.get_plaintext() == "sk-test-secret-token-12345"
# Read directly from database to verify encryption
async with db_registry.async_session() as session:
@@ -958,9 +959,6 @@ async def test_mcp_server_token_encryption_on_create(server, default_user, encry
actor=default_user,
)
# Verify plaintext column has the value (dual-write)
assert server_orm.token == "sk-test-secret-token-12345"
# Verify encrypted column is populated and different from plaintext
assert server_orm.token_enc is not None
assert server_orm.token_enc != "sk-test-secret-token-12345"
@@ -994,8 +992,12 @@ async def test_mcp_server_token_decryption_on_read(server, default_user, encrypt
# Read the server back
retrieved_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user)
# Verify the token is decrypted correctly
assert retrieved_server.token == "sk-test-decrypt-token-67890"
# Verify plaintext token field is NOT set (no dual-write)
assert retrieved_server.token is None
# Verify the token is decrypted correctly via token_enc
assert retrieved_server.token_enc is not None
assert retrieved_server.token_enc.get_plaintext() == "sk-test-decrypt-token-67890"
# Verify we can get the decrypted token through the secret getter
token_secret = retrieved_server.get_token_secret()
@@ -1028,8 +1030,11 @@ async def test_mcp_server_custom_headers_encryption(server, default_user, encryp
created_server = await server.mcp_manager.create_mcp_server(mcp_server, actor=default_user)
try:
# Verify custom_headers are accessible
assert created_server.custom_headers == custom_headers
# Verify plaintext custom_headers field is NOT set (no dual-write)
assert created_server.custom_headers is None
# Verify custom_headers are accessible via encrypted field
assert created_server.get_custom_headers_dict() == custom_headers
# Verify custom_headers_enc is a Secret object (stores JSON string)
assert created_server.custom_headers_enc is not None

View File

@@ -84,10 +84,8 @@ class TestMCPServerEncryption:
decrypted_token = CryptoUtils.decrypt(db_server.token_enc)
assert decrypted_token == token
# Legacy plaintext column should be None (or empty for dual-write)
# During migration phase, might store both
if db_server.token:
assert db_server.token == token # Dual-write phase
# Plaintext column should NOT be written to (encrypted-only)
assert db_server.token is None
# Clean up
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@@ -176,9 +174,9 @@ class TestMCPServerEncryption:
assert test_server is not None
assert test_server.server_name == server_name
# Token should be decrypted when accessed via the secret method
token_secret = test_server.get_token_secret()
assert token_secret.get_plaintext() == plaintext_token
# Token should be decrypted when accessed via the _enc column
assert test_server.token_enc is not None
assert test_server.token_enc.get_plaintext() == plaintext_token
# Clean up
async with db_registry.async_session() as session:
@@ -220,15 +218,15 @@ class TestMCPServerEncryption:
# Should work without encryption key - stores plaintext in _enc column
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
# Check database - should store plaintext in _enc column
# Check database - should store plaintext in _enc column (no encryption key)
async with db_registry.async_session() as session:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
db_server = result.scalar_one()
# Token should be stored as plaintext in _enc column (not encrypted)
assert db_server.token_enc == token # Plaintext stored directly
# Legacy plaintext column should also be populated (dual-write)
assert db_server.token == token
# Plaintext column should NOT be written to (encrypted-only)
assert db_server.token is None
# Clean up
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@@ -346,10 +344,13 @@ class TestMCPOAuthEncryption:
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
assert test_session is not None
# Tokens should be decrypted
assert test_session.access_token == access_token
assert test_session.refresh_token == refresh_token
assert test_session.client_secret == client_secret
# Tokens should be decrypted from _enc columns
assert test_session.access_token_enc is not None
assert test_session.access_token_enc.get_plaintext() == access_token
assert test_session.refresh_token_enc is not None
assert test_session.refresh_token_enc.get_plaintext() == refresh_token
assert test_session.client_secret_enc is not None
assert test_session.client_secret_enc.get_plaintext() == client_secret
# Clean up not needed - test database is reset
@@ -396,9 +397,11 @@ class TestMCPOAuthEncryption:
updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user)
# Verify update worked
assert updated_session.access_token == new_access_token
assert updated_session.refresh_token == new_refresh_token
# Verify update worked - read from _enc columns
assert updated_session.access_token_enc is not None
assert updated_session.access_token_enc.get_plaintext() == new_access_token
assert updated_session.refresh_token_enc is not None
assert updated_session.refresh_token_enc.get_plaintext() == new_refresh_token
# Check database encryption
async with db_registry.async_session() as session:
@@ -459,8 +462,9 @@ class TestMCPOAuthEncryption:
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
assert test_session is not None
# Should use encrypted value only (plaintext is ignored)
assert test_session.access_token == new_encrypted_token
# Should read from encrypted column only (plaintext is ignored)
assert test_session.access_token_enc is not None
assert test_session.access_token_enc.get_plaintext() == new_encrypted_token
# Clean up not needed - test database is reset
@@ -469,15 +473,13 @@ class TestMCPOAuthEncryption:
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_plaintext_only_record_fallback_with_error_logging(self, server, default_user, caplog):
"""Test that records with only plaintext values fall back to plaintext with error logging.
async def test_plaintext_only_record_returns_none(self, server, default_user):
"""Test that records with only plaintext values return None for encrypted fields.
Note: In Phase 1 of migration, if a record only has plaintext value
(no encrypted value), the system falls back to plaintext but logs an error
to help identify unmigrated data.
With encrypted-only migration complete, if a record only has plaintext value
(no encrypted value), the system returns None for that field since we only
read from _enc columns now.
"""
import logging
# Set encryption key directly on settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
@@ -494,7 +496,7 @@ class TestMCPOAuthEncryption:
server_url="https://test.com/mcp",
server_name="Plaintext Only Test",
# Only plaintext value, no encrypted
access_token=plaintext_token, # Legacy plaintext - should fallback with error log
access_token=plaintext_token, # Legacy plaintext - should be ignored
access_token_enc=None, # No encrypted value
client_id="test-client",
user_id=default_user.id,
@@ -505,17 +507,12 @@ class TestMCPOAuthEncryption:
session.add(db_oauth)
await session.commit()
# Retrieve through manager - should log error about plaintext fallback
with caplog.at_level(logging.ERROR):
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
# Retrieve through manager
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
assert test_session is not None
# Should fall back to plaintext value
assert test_session.access_token == plaintext_token
# Should have logged an error about reading from plaintext column
assert "MIGRATION_NEEDED" in caplog.text
# Should return None since we only read from _enc columns now
assert test_session.access_token_enc is None
# Clean up not needed - test database is reset