diff --git a/letta/orm/mcp_server.py b/letta/orm/mcp_server.py index a829dd53..a62ff1d6 100644 --- a/letta/orm/mcp_server.py +++ b/letta/orm/mcp_server.py @@ -1,3 +1,4 @@ +import json from typing import TYPE_CHECKING, Optional from sqlalchemy import JSON, String, Text, UniqueConstraint @@ -11,6 +12,7 @@ from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.enums import MCPServerType from letta.schemas.mcp import MCPServer +from letta.schemas.secret import Secret if TYPE_CHECKING: from letta.orm.organization import Organization @@ -60,6 +62,23 @@ class MCPServer(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="mcp_servers") + def to_pydantic(self): + """Convert ORM model to Pydantic model, handling encrypted fields.""" + # Parse custom_headers from JSON if stored as string + return self.__pydantic_model__( + id=self.id, + server_type=self.server_type, + server_name=self.server_name, + server_url=self.server_url, + token_enc=Secret.from_encrypted(self.token_enc) if self.token_enc else None, + custom_headers_enc=Secret.from_encrypted(self.custom_headers_enc) if self.custom_headers_enc else None, + stdio_config=self.stdio_config, + organization_id=self.organization_id, + created_by_id=self.created_by_id, + last_updated_by_id=self.last_updated_by_id, + metadata_=self.metadata_, + ) + class MCPTools(SqlalchemyBase, OrganizationMixin): """Represents a mapping of MCP server ID to tool ID""" diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index 85ab7e23..ef8a10aa 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -1,9 +1,12 @@ import json +import logging from datetime import datetime from typing import Any, Dict, List, Optional, Union from pydantic import Field +logger = logging.getLogger(__name__) + from letta.functions.mcp_client.types import ( MCP_AUTH_HEADER_AUTHORIZATION, MCP_AUTH_TOKEN_BEARER_PREFIX, @@ -48,68 +51,50 @@ class MCPServer(BaseMCPServer): last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.") - def get_token_secret(self) -> Secret: - """Get the token as a Secret object. Prefers encrypted, falls back to plaintext with error logging.""" - if self.token_enc is not None: - return self.token_enc - # Fallback to plaintext with error logging via Secret.from_db() - return Secret.from_db(encrypted_value=None, plaintext_value=self.token) + def get_token_secret(self) -> Optional[Secret]: + """Get the token as a Secret object.""" + return self.token_enc - def get_custom_headers_secret(self) -> Secret: - """Get custom headers as a Secret object (stores JSON string). Prefers encrypted, falls back to plaintext with error logging.""" - if self.custom_headers_enc is not None: - return self.custom_headers_enc - # Fallback to plaintext with error logging via Secret.from_db() - # Convert dict to JSON string for Secret storage - plaintext_json = json.dumps(self.custom_headers) if self.custom_headers else None - return Secret.from_db(encrypted_value=None, plaintext_value=plaintext_json) + def get_custom_headers_secret(self) -> Optional[Secret]: + """Get the custom headers as a Secret object (JSON string).""" + return self.custom_headers_enc def get_custom_headers_dict(self) -> Optional[Dict[str, str]]: - """Get custom headers as a plaintext dictionary.""" - secret = self.get_custom_headers_secret() - json_str = secret.get_plaintext() - if json_str: - try: - return json.loads(json_str) - except (json.JSONDecodeError, TypeError): - return None + """Get the custom headers as a dictionary.""" + if self.custom_headers_enc: + json_str = self.custom_headers_enc.get_plaintext() + if json_str: + try: + return json.loads(json_str) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}") return None def set_token_secret(self, secret: Secret) -> None: - """Set token from a Secret object, updating both encrypted and plaintext fields.""" + """Set token from a Secret object.""" self.token_enc = secret - secret_dict = secret.to_dict() - # Only set plaintext during migration phase - if not secret.was_encrypted: - self.token = secret_dict["plaintext"] - else: - self.token = None def set_custom_headers_secret(self, secret: Secret) -> None: - """Set custom headers from a Secret object (containing JSON string), updating both fields.""" + """Set custom headers from a Secret object (JSON string).""" self.custom_headers_enc = secret - secret_dict = secret.to_dict() - # Parse JSON string to dict for plaintext field - json_str = secret_dict.get("plaintext") - if json_str and not secret.was_encrypted: - try: - self.custom_headers = json.loads(json_str) - except (json.JSONDecodeError, TypeError): - self.custom_headers = None - else: - self.custom_headers = None def to_config( self, environment_variables: Optional[Dict[str, str]] = None, resolve_variables: bool = True, ) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]: - # Get decrypted values for use in config - token_secret = self.get_token_secret() - token_plaintext = token_secret.get_plaintext() + # Get decrypted values directly from encrypted columns + token_plaintext = self.token_enc.get_plaintext() if self.token_enc else None - # Get custom headers as dict - headers_plaintext = self.get_custom_headers_dict() + # Get custom headers as dict from encrypted column + headers_plaintext = None + if self.custom_headers_enc: + json_str = self.custom_headers_enc.get_plaintext() + if json_str: + try: + headers_plaintext = json.loads(json_str) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}") if self.server_type == MCPServerType.SSE: config = SSEServerConfig( @@ -228,66 +213,6 @@ class MCPOAuthSession(BaseMCPOAuth): created_at: datetime = Field(default_factory=datetime.now, description="Session creation time") updated_at: datetime = Field(default_factory=datetime.now, description="Last update time") - def get_access_token_secret(self) -> Secret: - """Get the access token as a Secret object, preferring encrypted over plaintext.""" - if self.access_token_enc is not None: - return self.access_token_enc - return Secret.from_db(None, self.access_token) - - def get_refresh_token_secret(self) -> Secret: - """Get the refresh token as a Secret object, preferring encrypted over plaintext.""" - if self.refresh_token_enc is not None: - return self.refresh_token_enc - return Secret.from_db(None, self.refresh_token) - - def get_client_secret_secret(self) -> Secret: - """Get the client secret as a Secret object, preferring encrypted over plaintext.""" - if self.client_secret_enc is not None: - return self.client_secret_enc - return Secret.from_db(None, self.client_secret) - - def get_authorization_code_secret(self) -> Secret: - """Get the authorization code as a Secret object, preferring encrypted over plaintext.""" - if self.authorization_code_enc is not None: - return self.authorization_code_enc - return Secret.from_db(None, self.authorization_code) - - def set_access_token_secret(self, secret: Secret) -> None: - """Set access token from a Secret object.""" - self.access_token_enc = secret - secret_dict = secret.to_dict() - if not secret.was_encrypted: - self.access_token = secret_dict["plaintext"] - else: - self.access_token = None - - def set_refresh_token_secret(self, secret: Secret) -> None: - """Set refresh token from a Secret object.""" - self.refresh_token_enc = secret - secret_dict = secret.to_dict() - if not secret.was_encrypted: - self.refresh_token = secret_dict["plaintext"] - else: - self.refresh_token = None - - def set_client_secret_secret(self, secret: Secret) -> None: - """Set client secret from a Secret object.""" - self.client_secret_enc = secret - secret_dict = secret.to_dict() - if not secret.was_encrypted: - self.client_secret = secret_dict["plaintext"] - else: - self.client_secret = None - - def set_authorization_code_secret(self, secret: Secret) -> None: - """Set authorization code from a Secret object.""" - self.authorization_code_enc = secret - secret_dict = secret.to_dict() - if not secret.was_encrypted: - self.authorization_code = secret_dict["plaintext"] - else: - self.authorization_code = None - class MCPOAuthSessionCreate(BaseMCPOAuth): """Create a new OAuth session.""" diff --git a/letta/schemas/mcp_server.py b/letta/schemas/mcp_server.py index 459292bc..c8112e3c 100644 --- a/letta/schemas/mcp_server.py +++ b/letta/schemas/mcp_server.py @@ -319,24 +319,30 @@ def convert_generic_to_union(server) -> MCPServerUnion: env=server.stdio_config.env if server.stdio_config else None, ) elif server.server_type == MCPServerType.SSE: + # Get decrypted values from encrypted columns + token = server.token_enc.get_plaintext() if server.token_enc else None + headers = server.get_custom_headers_dict() return SSEMCPServer( id=server.id, server_name=server.server_name, mcp_server_type=MCPServerType.SSE, server_url=server.server_url, - auth_header="Authorization" if server.token else None, - auth_token=f"Bearer {server.token}" if server.token else None, - custom_headers=server.custom_headers, + auth_header="Authorization" if token else None, + auth_token=f"Bearer {token}" if token else None, + custom_headers=headers, ) elif server.server_type == MCPServerType.STREAMABLE_HTTP: + # Get decrypted values from encrypted columns + token = server.token_enc.get_plaintext() if server.token_enc else None + headers = server.get_custom_headers_dict() return StreamableHTTPMCPServer( id=server.id, server_name=server.server_name, mcp_server_type=MCPServerType.STREAMABLE_HTTP, server_url=server.server_url, - auth_header="Authorization" if server.token else None, - auth_token=f"Bearer {server.token}" if server.token else None, - custom_headers=server.custom_headers, + auth_header="Authorization" if token else None, + auth_token=f"Bearer {token}" if token else None, + custom_headers=headers, ) else: raise ValueError(f"Unknown server type: {server.server_type}") diff --git a/letta/schemas/secret.py b/letta/schemas/secret.py index e28f7641..64a1df61 100644 --- a/letta/schemas/secret.py +++ b/letta/schemas/secret.py @@ -1,4 +1,3 @@ -import json from typing import Any, Dict, Optional from pydantic import BaseModel, ConfigDict, PrivateAttr @@ -17,22 +16,17 @@ class Secret(BaseModel): This class ensures that sensitive data remains encrypted as much as possible while passing through the codebase, only decrypting when absolutely necessary. - Migration status (Phase 1 - encrypted-first reads with plaintext fallback): - - Reads: Prefer _enc columns, fallback to plaintext columns with ERROR logging - - Writes: Still dual-write to both _enc and plaintext columns for backward compatibility - - Encryption: Optional - if LETTA_ENCRYPTION_KEY is not set, stores plaintext in _enc column - - TODO (Phase 2): Remove plaintext fallback in from_db() after verifying no error logs - TODO (Phase 3): Remove dual-write logic in to_dict() and set_*_secret() methods - TODO (Phase 4): Remove from_db() plaintext_value parameter, was_encrypted flag, and plaintext columns + Usage: + - Create from plaintext: Secret.from_plaintext(value) + - Create from encrypted DB value: Secret.from_encrypted(encrypted_value) + - Get encrypted for storage: secret.get_encrypted() + - Get plaintext when needed: secret.get_plaintext() """ # Store the encrypted value as a regular field encrypted_value: Optional[str] = None # Cache the decrypted value to avoid repeated decryption (not serialized for security) _plaintext_cache: Optional[str] = PrivateAttr(default=None) - # Flag to indicate if the value was originally encrypted - was_encrypted: bool = False model_config = ConfigDict(frozen=True) @@ -51,7 +45,7 @@ class Secret(BaseModel): A Secret instance with the encrypted (or plaintext) value """ if value is None: - return cls.model_construct(encrypted_value=None, was_encrypted=False) + return cls.model_construct(encrypted_value=None) # Guard against double encryption - check if value is already encrypted if CryptoUtils.is_encrypted(value): @@ -60,7 +54,7 @@ class Secret(BaseModel): # Try to encrypt, but fall back to storing plaintext if no encryption key try: encrypted = CryptoUtils.encrypt(value) - return cls.model_construct(encrypted_value=encrypted, was_encrypted=False) + return cls.model_construct(encrypted_value=encrypted) except ValueError as e: # No encryption key available, store as plaintext in the _enc column if "No encryption key configured" in str(e): @@ -68,7 +62,7 @@ class Secret(BaseModel): "No encryption key configured. Storing Secret value as plaintext in _enc column. " "Set LETTA_ENCRYPTION_KEY environment variable to enable encryption." ) - instance = cls.model_construct(encrypted_value=value, was_encrypted=False) + instance = cls.model_construct(encrypted_value=value) instance._plaintext_cache = value # Cache it since we know the plaintext return instance raise # Re-raise if it's a different error @@ -76,47 +70,15 @@ class Secret(BaseModel): @classmethod def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret": """ - Create a Secret from an already encrypted value. + Create a Secret from an already encrypted value (read from DB). Args: - encrypted_value: The encrypted value + encrypted_value: The encrypted value from the _enc column Returns: A Secret instance """ - return cls.model_construct(encrypted_value=encrypted_value, was_encrypted=True) - - @classmethod - def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[str] = None) -> "Secret": - """ - Create a Secret from database values. Prefers encrypted column, falls back to plaintext with error logging. - - During Phase 1 of migration, this method: - 1. Uses encrypted_value if available (preferred) - 2. Falls back to plaintext_value with ERROR logging if encrypted is unavailable - 3. Returns empty Secret if neither is available - - The error logging helps identify any records that haven't been migrated to encrypted columns. - - Args: - encrypted_value: The encrypted value from the database (_enc column) - plaintext_value: The plaintext value from the database (legacy column, fallback only) - - Returns: - A Secret instance with the value from encrypted or plaintext column - """ - if encrypted_value is not None: - return cls.from_encrypted(encrypted_value) - # Fallback to plaintext with error logging - this helps identify unmigrated data - if plaintext_value is not None: - logger.error( - "MIGRATION_NEEDED: Reading from plaintext column instead of encrypted column. " - "This indicates data that hasn't been migrated to the _enc column yet. " - "Please run migrate data to _enc columns as plaintext columns will be deprecated.", - stack_info=True, - ) - return cls.from_plaintext(plaintext_value) - return cls.from_plaintext(None) + return cls.model_construct(encrypted_value=encrypted_value) def get_encrypted(self) -> Optional[str]: """ @@ -146,14 +108,8 @@ class Secret(BaseModel): if self.encrypted_value is None: return None - # Use cached value if available, but only if it looks like plaintext - # or we're confident we can decrypt it + # Use cached value if available if self._plaintext_cache is not None: - # If this was explicitly created as plaintext, trust the cache - # This prevents false positives from is_encrypted() heuristic - if not self.was_encrypted: - return self._plaintext_cache - # For encrypted values, trust the cache (already decrypted previously) return self._plaintext_cache # Try to decrypt @@ -265,14 +221,6 @@ class Secret(BaseModel): """Representation that doesn't expose the actual value.""" return self.__str__() - def to_dict(self) -> Dict[str, Any]: - """ - Convert to dictionary for database storage. - - Returns both encrypted and plaintext values for dual-write during migration. - """ - return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self.was_encrypted else None} - def __eq__(self, other: Any) -> bool: """ Compare two secrets by their plaintext values. diff --git a/letta/services/mcp/oauth_utils.py b/letta/services/mcp/oauth_utils.py index 5ff6085b..e565c22b 100644 --- a/letta/services/mcp/oauth_utils.py +++ b/letta/services/mcp/oauth_utils.py @@ -37,14 +37,12 @@ class DatabaseTokenStorage(TokenStorage): if not oauth_session: return None - # Decrypt tokens using getter methods - access_token_secret = oauth_session.get_access_token_secret() - access_token = access_token_secret.get_plaintext() + # Read tokens directly from _enc columns + access_token = oauth_session.access_token_enc.get_plaintext() if oauth_session.access_token_enc else None if not access_token: return None - refresh_token_secret = oauth_session.get_refresh_token_secret() - refresh_token = refresh_token_secret.get_plaintext() + refresh_token = oauth_session.refresh_token_enc.get_plaintext() if oauth_session.refresh_token_enc else None return OAuthToken( access_token=access_token, @@ -72,9 +70,8 @@ class DatabaseTokenStorage(TokenStorage): if not oauth_session or not oauth_session.client_id: return None - # Decrypt client secret using getter method - client_secret_secret = oauth_session.get_client_secret_secret() - client_secret = client_secret_secret.get_plaintext() + # Read client secret directly from _enc column + client_secret = oauth_session.client_secret_enc.get_plaintext() if oauth_session.client_secret_enc else None return OAuthClientInformationFull( client_id=oauth_session.client_id, @@ -147,19 +144,15 @@ class MCPOAuthSession: async def store_authorization_code(self, code: str, state: str) -> Optional[MCPOAuth]: """Store the authorization code from OAuth callback.""" - # Use mcp_manager to ensure proper encryption - from letta.schemas.mcp import MCPOAuthSessionUpdate from letta.schemas.secret import Secret async with db_registry.async_session() as session: try: oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) - # Encrypt the authorization_code before storing + # Encrypt the authorization_code and store only in _enc column if code is not None: oauth_record.authorization_code_enc = Secret.from_plaintext(code).get_encrypted() - # Keep plaintext for dual-write during migration - oauth_record.authorization_code = code oauth_record.status = OAuthSessionStatus.AUTHORIZED oauth_record.state = state @@ -234,10 +227,10 @@ async def create_oauth_provider( logger.info(f"Waiting for authorization code for session {session_id}") while time.time() - start_time < timeout: oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor) - if oauth_session and oauth_session.authorization_code: - # Decrypt the authorization code before returning - auth_code_secret = oauth_session.get_authorization_code_secret() - return auth_code_secret.get_plaintext(), oauth_session.state + if oauth_session and oauth_session.authorization_code_enc: + # Read authorization code directly from _enc column + auth_code = oauth_session.authorization_code_enc.get_plaintext() + return auth_code, oauth_session.state elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR: raise Exception("OAuth authorization failed") await asyncio.sleep(1) diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 37f2d33d..a184ccf6 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -419,16 +419,14 @@ class MCPManager: server_type=server_config.type, server_url=server_config.server_url, ) - # Encrypt sensitive fields + # Encrypt sensitive fields - write only to _enc columns token = server_config.resolve_token() if token: - token_secret = Secret.from_plaintext(token) - mcp_server.set_token_secret(token_secret) + mcp_server.token_enc = Secret.from_plaintext(token) if server_config.custom_headers: # Convert dict to JSON string, then encrypt as Secret headers_json = json.dumps(server_config.custom_headers) - headers_secret = Secret.from_plaintext(headers_json) - mcp_server.set_custom_headers_secret(headers_secret) + mcp_server.custom_headers_enc = Secret.from_plaintext(headers_json) elif isinstance(server_config, StreamableHTTPServerConfig): mcp_server = MCPServer( @@ -436,16 +434,14 @@ class MCPManager: server_type=server_config.type, server_url=server_config.server_url, ) - # Encrypt sensitive fields + # Encrypt sensitive fields - write only to _enc columns token = server_config.resolve_token() if token: - token_secret = Secret.from_plaintext(token) - mcp_server.set_token_secret(token_secret) + mcp_server.token_enc = Secret.from_plaintext(token) if server_config.custom_headers: # Convert dict to JSON string, then encrypt as Secret headers_json = json.dumps(server_config.custom_headers) - headers_secret = Secret.from_plaintext(headers_json) - mcp_server.set_custom_headers_secret(headers_secret) + mcp_server.custom_headers_enc = Secret.from_plaintext(headers_json) else: raise ValueError(f"Unsupported server config type: {type(server_config)}") @@ -539,57 +535,44 @@ class MCPManager: # Update tool attributes with only the fields that were explicitly set update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True) - # Handle encryption for token if provided - # Only re-encrypt if the value has actually changed + # Handle encryption for token if provided - write only to _enc column if "token" in update_data and update_data["token"] is not None: - # Check if value changed + # Check if value changed by reading from _enc column only existing_token = None if mcp_server.token_enc: existing_secret = Secret.from_encrypted(mcp_server.token_enc) existing_token = existing_secret.get_plaintext() - elif mcp_server.token: - existing_token = mcp_server.token # Only re-encrypt if different if existing_token != update_data["token"]: mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted() - # Keep plaintext for dual-write during migration - mcp_server.token = update_data["token"] # Remove from update_data since we set directly on mcp_server update_data.pop("token", None) update_data.pop("token_enc", None) - # Handle encryption for custom_headers if provided - # Only re-encrypt if the value has actually changed + # Handle encryption for custom_headers if provided - write only to _enc column if "custom_headers" in update_data: if update_data["custom_headers"] is not None: # custom_headers is a Dict[str, str], serialize to JSON then encrypt - import json - json_str = json.dumps(update_data["custom_headers"]) - # Check if value changed + # Check if value changed by reading from _enc column only existing_headers_json = None if mcp_server.custom_headers_enc: existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc) existing_headers_json = existing_secret.get_plaintext() - elif mcp_server.custom_headers: - existing_headers_json = json.dumps(mcp_server.custom_headers) # Only re-encrypt if different if existing_headers_json != json_str: mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted() - # Keep plaintext for dual-write during migration - mcp_server.custom_headers = update_data["custom_headers"] # Remove from update_data since we set directly on mcp_server update_data.pop("custom_headers", None) update_data.pop("custom_headers_enc", None) else: - # Ensure custom_headers None is stored as SQL NULL, not JSON null + # Ensure custom_headers_enc None is stored as SQL NULL update_data.pop("custom_headers", None) - setattr(mcp_server, "custom_headers", null()) setattr(mcp_server, "custom_headers_enc", None) for key, value in update_data.items(): @@ -810,8 +793,8 @@ class MCPManager: # If no OAuth provider is provided, check if we have stored OAuth credentials if oauth_provider is None and hasattr(server_config, "server_url"): oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor) - # Check if access token exists by attempting to decrypt it - if oauth_session and oauth_session.get_access_token_secret().get_plaintext(): + # Check if access token exists by reading directly from _enc column + if oauth_session and oauth_session.access_token_enc and oauth_session.access_token_enc.get_plaintext(): # Create OAuth provider from stored credentials from letta.services.mcp.oauth_utils import create_oauth_provider @@ -838,29 +821,23 @@ class MCPManager: # OAuth-related methods def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession: """ - Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields. - - Note: Prefers encrypted columns (_enc fields), falls back to plaintext with error logging. - This helps identify unmigrated data during the migration period. + Convert OAuth ORM model to Pydantic model, reading directly from encrypted columns. """ - # Get decrypted values - prefer encrypted, fallback to plaintext with error logging - access_token = Secret.from_db( - encrypted_value=oauth_session.access_token_enc, plaintext_value=oauth_session.access_token - ).get_plaintext() + # Convert encrypted columns to Secret objects + authorization_code_enc = ( + Secret.from_encrypted(oauth_session.authorization_code_enc) if oauth_session.authorization_code_enc else None + ) + access_token_enc = Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None + refresh_token_enc = Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None + client_secret_enc = Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None - refresh_token = Secret.from_db( - encrypted_value=oauth_session.refresh_token_enc, plaintext_value=oauth_session.refresh_token - ).get_plaintext() + # Get plaintext values from encrypted columns (primary source of truth) + authorization_code = authorization_code_enc.get_plaintext() if authorization_code_enc else None + access_token = access_token_enc.get_plaintext() if access_token_enc else None + refresh_token = refresh_token_enc.get_plaintext() if refresh_token_enc else None + client_secret = client_secret_enc.get_plaintext() if client_secret_enc else None - client_secret = Secret.from_db( - encrypted_value=oauth_session.client_secret_enc, plaintext_value=oauth_session.client_secret - ).get_plaintext() - - authorization_code = Secret.from_db( - encrypted_value=oauth_session.authorization_code_enc, plaintext_value=oauth_session.authorization_code - ).get_plaintext() - - # Create the Pydantic object with encrypted fields as Secret objects + # Create the Pydantic object with both encrypted and plaintext fields pydantic_session = MCPOAuthSession( id=oauth_session.id, state=oauth_session.state, @@ -870,25 +847,24 @@ class MCPManager: user_id=oauth_session.user_id, organization_id=oauth_session.organization_id, authorization_url=oauth_session.authorization_url, - authorization_code=authorization_code, - access_token=access_token, - refresh_token=refresh_token, token_type=oauth_session.token_type, expires_at=oauth_session.expires_at, scope=oauth_session.scope, client_id=oauth_session.client_id, - client_secret=client_secret, redirect_uri=oauth_session.redirect_uri, status=oauth_session.status, created_at=oauth_session.created_at, updated_at=oauth_session.updated_at, - # Encrypted fields as Secret objects (converted from encrypted strings in DB) - authorization_code_enc=Secret.from_encrypted(oauth_session.authorization_code_enc) - if oauth_session.authorization_code_enc - else None, - access_token_enc=Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None, - refresh_token_enc=Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None, - client_secret_enc=Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None, + # Plaintext fields populated from encrypted columns + authorization_code=authorization_code, + access_token=access_token, + refresh_token=refresh_token, + client_secret=client_secret, + # Encrypted fields as Secret objects + authorization_code_enc=authorization_code_enc, + access_token_enc=access_token_enc, + refresh_token_enc=refresh_token_enc, + client_secret_enc=client_secret_enc, ) return pydantic_session @@ -957,56 +933,41 @@ class MCPManager: if session_update.authorization_url is not None: oauth_session.authorization_url = session_update.authorization_url - # Handle encryption for authorization_code - # Only re-encrypt if the value has actually changed + # Handle encryption for authorization_code - write only to _enc column if session_update.authorization_code is not None: - # Check if value changed + # Check if value changed by reading from _enc column only existing_code = None if oauth_session.authorization_code_enc: existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc) existing_code = existing_secret.get_plaintext() - elif oauth_session.authorization_code: - existing_code = oauth_session.authorization_code # Only re-encrypt if different if existing_code != session_update.authorization_code: oauth_session.authorization_code_enc = Secret.from_plaintext(session_update.authorization_code).get_encrypted() - # Keep plaintext for dual-write during migration - oauth_session.authorization_code = session_update.authorization_code - # Handle encryption for access_token - # Only re-encrypt if the value has actually changed + # Handle encryption for access_token - write only to _enc column if session_update.access_token is not None: - # Check if value changed + # Check if value changed by reading from _enc column only existing_token = None if oauth_session.access_token_enc: existing_secret = Secret.from_encrypted(oauth_session.access_token_enc) existing_token = existing_secret.get_plaintext() - elif oauth_session.access_token: - existing_token = oauth_session.access_token # Only re-encrypt if different if existing_token != session_update.access_token: oauth_session.access_token_enc = Secret.from_plaintext(session_update.access_token).get_encrypted() - # Keep plaintext for dual-write during migration - oauth_session.access_token = session_update.access_token - # Handle encryption for refresh_token - # Only re-encrypt if the value has actually changed + # Handle encryption for refresh_token - write only to _enc column if session_update.refresh_token is not None: - # Check if value changed + # Check if value changed by reading from _enc column only existing_refresh = None if oauth_session.refresh_token_enc: existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc) existing_refresh = existing_secret.get_plaintext() - elif oauth_session.refresh_token: - existing_refresh = oauth_session.refresh_token # Only re-encrypt if different if existing_refresh != session_update.refresh_token: oauth_session.refresh_token_enc = Secret.from_plaintext(session_update.refresh_token).get_encrypted() - # Keep plaintext for dual-write during migration - oauth_session.refresh_token = session_update.refresh_token if session_update.token_type is not None: oauth_session.token_type = session_update.token_type @@ -1017,22 +978,17 @@ class MCPManager: if session_update.client_id is not None: oauth_session.client_id = session_update.client_id - # Handle encryption for client_secret - # Only re-encrypt if the value has actually changed + # Handle encryption for client_secret - write only to _enc column if session_update.client_secret is not None: - # Check if value changed + # Check if value changed by reading from _enc column only existing_secret_val = None if oauth_session.client_secret_enc: existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc) existing_secret_val = existing_secret.get_plaintext() - elif oauth_session.client_secret: - existing_secret_val = oauth_session.client_secret # Only re-encrypt if different if existing_secret_val != session_update.client_secret: oauth_session.client_secret_enc = Secret.from_plaintext(session_update.client_secret).get_encrypted() - # Keep plaintext for dual-write during migration - oauth_session.client_secret = session_update.client_secret if session_update.redirect_uri is not None: oauth_session.redirect_uri = session_update.redirect_uri diff --git a/tests/managers/test_mcp_manager.py b/tests/managers/test_mcp_manager.py index e4bb9efc..ba36820d 100644 --- a/tests/managers/test_mcp_manager.py +++ b/tests/managers/test_mcp_manager.py @@ -943,12 +943,13 @@ async def test_mcp_server_token_encryption_on_create(server, default_user, encry assert created_server is not None assert created_server.server_name == "test-encrypted-server" - # Verify plaintext token is accessible (dual-write during migration) - assert created_server.token == "sk-test-secret-token-12345" + # Verify plaintext token field is NOT set (no dual-write) + assert created_server.token is None - # Verify token_enc is a Secret object + # Verify token_enc is a Secret object and decrypts correctly assert created_server.token_enc is not None assert isinstance(created_server.token_enc, Secret) + assert created_server.token_enc.get_plaintext() == "sk-test-secret-token-12345" # Read directly from database to verify encryption async with db_registry.async_session() as session: @@ -958,9 +959,6 @@ async def test_mcp_server_token_encryption_on_create(server, default_user, encry actor=default_user, ) - # Verify plaintext column has the value (dual-write) - assert server_orm.token == "sk-test-secret-token-12345" - # Verify encrypted column is populated and different from plaintext assert server_orm.token_enc is not None assert server_orm.token_enc != "sk-test-secret-token-12345" @@ -994,8 +992,12 @@ async def test_mcp_server_token_decryption_on_read(server, default_user, encrypt # Read the server back retrieved_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user) - # Verify the token is decrypted correctly - assert retrieved_server.token == "sk-test-decrypt-token-67890" + # Verify plaintext token field is NOT set (no dual-write) + assert retrieved_server.token is None + + # Verify the token is decrypted correctly via token_enc + assert retrieved_server.token_enc is not None + assert retrieved_server.token_enc.get_plaintext() == "sk-test-decrypt-token-67890" # Verify we can get the decrypted token through the secret getter token_secret = retrieved_server.get_token_secret() @@ -1028,8 +1030,11 @@ async def test_mcp_server_custom_headers_encryption(server, default_user, encryp created_server = await server.mcp_manager.create_mcp_server(mcp_server, actor=default_user) try: - # Verify custom_headers are accessible - assert created_server.custom_headers == custom_headers + # Verify plaintext custom_headers field is NOT set (no dual-write) + assert created_server.custom_headers is None + + # Verify custom_headers are accessible via encrypted field + assert created_server.get_custom_headers_dict() == custom_headers # Verify custom_headers_enc is a Secret object (stores JSON string) assert created_server.custom_headers_enc is not None diff --git a/tests/test_mcp_encryption.py b/tests/test_mcp_encryption.py index 91dd77eb..1e5c046f 100644 --- a/tests/test_mcp_encryption.py +++ b/tests/test_mcp_encryption.py @@ -84,10 +84,8 @@ class TestMCPServerEncryption: decrypted_token = CryptoUtils.decrypt(db_server.token_enc) assert decrypted_token == token - # Legacy plaintext column should be None (or empty for dual-write) - # During migration phase, might store both - if db_server.token: - assert db_server.token == token # Dual-write phase + # Plaintext column should NOT be written to (encrypted-only) + assert db_server.token is None # Clean up await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) @@ -176,9 +174,9 @@ class TestMCPServerEncryption: assert test_server is not None assert test_server.server_name == server_name - # Token should be decrypted when accessed via the secret method - token_secret = test_server.get_token_secret() - assert token_secret.get_plaintext() == plaintext_token + # Token should be decrypted when accessed via the _enc column + assert test_server.token_enc is not None + assert test_server.token_enc.get_plaintext() == plaintext_token # Clean up async with db_registry.async_session() as session: @@ -220,15 +218,15 @@ class TestMCPServerEncryption: # Should work without encryption key - stores plaintext in _enc column created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user) - # Check database - should store plaintext in _enc column + # Check database - should store plaintext in _enc column (no encryption key) async with db_registry.async_session() as session: result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id)) db_server = result.scalar_one() # Token should be stored as plaintext in _enc column (not encrypted) assert db_server.token_enc == token # Plaintext stored directly - # Legacy plaintext column should also be populated (dual-write) - assert db_server.token == token + # Plaintext column should NOT be written to (encrypted-only) + assert db_server.token is None # Clean up await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) @@ -346,10 +344,13 @@ class TestMCPOAuthEncryption: test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) assert test_session is not None - # Tokens should be decrypted - assert test_session.access_token == access_token - assert test_session.refresh_token == refresh_token - assert test_session.client_secret == client_secret + # Tokens should be decrypted from _enc columns + assert test_session.access_token_enc is not None + assert test_session.access_token_enc.get_plaintext() == access_token + assert test_session.refresh_token_enc is not None + assert test_session.refresh_token_enc.get_plaintext() == refresh_token + assert test_session.client_secret_enc is not None + assert test_session.client_secret_enc.get_plaintext() == client_secret # Clean up not needed - test database is reset @@ -396,9 +397,11 @@ class TestMCPOAuthEncryption: updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user) - # Verify update worked - assert updated_session.access_token == new_access_token - assert updated_session.refresh_token == new_refresh_token + # Verify update worked - read from _enc columns + assert updated_session.access_token_enc is not None + assert updated_session.access_token_enc.get_plaintext() == new_access_token + assert updated_session.refresh_token_enc is not None + assert updated_session.refresh_token_enc.get_plaintext() == new_refresh_token # Check database encryption async with db_registry.async_session() as session: @@ -459,8 +462,9 @@ class TestMCPOAuthEncryption: test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) assert test_session is not None - # Should use encrypted value only (plaintext is ignored) - assert test_session.access_token == new_encrypted_token + # Should read from encrypted column only (plaintext is ignored) + assert test_session.access_token_enc is not None + assert test_session.access_token_enc.get_plaintext() == new_encrypted_token # Clean up not needed - test database is reset @@ -469,15 +473,13 @@ class TestMCPOAuthEncryption: settings.encryption_key = original_key @pytest.mark.asyncio - async def test_plaintext_only_record_fallback_with_error_logging(self, server, default_user, caplog): - """Test that records with only plaintext values fall back to plaintext with error logging. + async def test_plaintext_only_record_returns_none(self, server, default_user): + """Test that records with only plaintext values return None for encrypted fields. - Note: In Phase 1 of migration, if a record only has plaintext value - (no encrypted value), the system falls back to plaintext but logs an error - to help identify unmigrated data. + With encrypted-only migration complete, if a record only has plaintext value + (no encrypted value), the system returns None for that field since we only + read from _enc columns now. """ - import logging - # Set encryption key directly on settings original_key = settings.encryption_key settings.encryption_key = self.MOCK_ENCRYPTION_KEY @@ -494,7 +496,7 @@ class TestMCPOAuthEncryption: server_url="https://test.com/mcp", server_name="Plaintext Only Test", # Only plaintext value, no encrypted - access_token=plaintext_token, # Legacy plaintext - should fallback with error log + access_token=plaintext_token, # Legacy plaintext - should be ignored access_token_enc=None, # No encrypted value client_id="test-client", user_id=default_user.id, @@ -505,17 +507,12 @@ class TestMCPOAuthEncryption: session.add(db_oauth) await session.commit() - # Retrieve through manager - should log error about plaintext fallback - with caplog.at_level(logging.ERROR): - test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) - + # Retrieve through manager + test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) assert test_session is not None - # Should fall back to plaintext value - assert test_session.access_token == plaintext_token - - # Should have logged an error about reading from plaintext column - assert "MIGRATION_NEEDED" in caplog.text + # Should return None since we only read from _enc columns now + assert test_session.access_token_enc is None # Clean up not needed - test database is reset