From 3711b5279c098f85a181c69dba38989a0e4241cf Mon Sep 17 00:00:00 2001 From: jnjpng Date: Tue, 16 Sep 2025 11:56:34 -0700 Subject: [PATCH] feat: encryption for mcp (#2937) --- .github/workflows/core-unit-sqlite-test.yaml | 5 +- .github/workflows/core-unit-test.yml | 5 +- ..._add_and_migrate_encrypted_columns_for_.py | 318 ++++++++++++++ letta/helpers/crypto_utils.py | 134 ++++++ letta/orm/__init__.py | 1 + letta/orm/mcp_oauth.py | 6 + letta/orm/mcp_server.py | 8 +- letta/schemas/mcp.py | 185 +++++++- letta/schemas/secret.py | 241 +++++++++++ letta/server/rest_api/routers/v1/tools.py | 33 +- letta/services/mcp_manager.py | 335 ++++++++++---- tests/test_crypto_utils.py | 232 ++++++++++ tests/test_mcp_encryption.py | 407 ++++++++++++++++++ tests/test_secret.py | 373 ++++++++++++++++ 14 files changed, 2166 insertions(+), 117 deletions(-) create mode 100644 alembic/versions/d06594144ef3_add_and_migrate_encrypted_columns_for_.py create mode 100644 letta/helpers/crypto_utils.py create mode 100644 letta/schemas/secret.py create mode 100644 tests/test_crypto_utils.py create mode 100644 tests/test_mcp_encryption.py create mode 100644 tests/test_secret.py diff --git a/.github/workflows/core-unit-sqlite-test.yaml b/.github/workflows/core-unit-sqlite-test.yaml index 76236dea..34f5090c 100644 --- a/.github/workflows/core-unit-sqlite-test.yaml +++ b/.github/workflows/core-unit-sqlite-test.yaml @@ -53,7 +53,10 @@ jobs: {"test_suite": "mcp_tests/", "use_experimental": true}, {"test_suite": "test_timezone_formatting.py"}, {"test_suite": "test_plugins.py"}, - {"test_suite": "test_embeddings.py"} + {"test_suite": "test_embeddings.py"}, + {"test_suite": "test_crypto_utils.py"}, + {"test_suite": "test_mcp_encryption.py"}, + {"test_suite": "test_secret.py"} ] } } diff --git a/.github/workflows/core-unit-test.yml b/.github/workflows/core-unit-test.yml index 28096b60..e1f0c5de 100644 --- a/.github/workflows/core-unit-test.yml +++ b/.github/workflows/core-unit-test.yml @@ -53,7 +53,10 @@ jobs: {"test_suite": "mcp_tests/", "use_experimental": true}, {"test_suite": "test_timezone_formatting.py"}, {"test_suite": "test_plugins.py"}, - {"test_suite": "test_embeddings.py"} + {"test_suite": "test_embeddings.py"}, + {"test_suite": "test_crypto_utils.py"}, + {"test_suite": "test_mcp_encryption.py"}, + {"test_suite": "test_secret.py"} ] } } diff --git a/alembic/versions/d06594144ef3_add_and_migrate_encrypted_columns_for_.py b/alembic/versions/d06594144ef3_add_and_migrate_encrypted_columns_for_.py new file mode 100644 index 00000000..d927532e --- /dev/null +++ b/alembic/versions/d06594144ef3_add_and_migrate_encrypted_columns_for_.py @@ -0,0 +1,318 @@ +"""add and migrate encrypted columns for mcp + +Revision ID: d06594144ef3 +Revises: 5d27a719b24d +Create Date: 2025-09-15 22:02:47.403970 + +""" + +import json +import os + +# Add the app directory to path to import our crypto utils +import sys +from pathlib import Path +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import JSON, String, Text +from sqlalchemy.sql import column, table + +from alembic import op + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from letta.helpers.crypto_utils import CryptoUtils + +# revision identifiers, used by Alembic. +revision: str = "d06594144ef3" +down_revision: Union[str, None] = "5d27a719b24d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # First, add the new encrypted columns + op.add_column("mcp_oauth", sa.Column("access_token_enc", sa.Text(), nullable=True)) + op.add_column("mcp_oauth", sa.Column("refresh_token_enc", sa.Text(), nullable=True)) + op.add_column("mcp_oauth", sa.Column("client_secret_enc", sa.Text(), nullable=True)) + op.add_column("mcp_server", sa.Column("token_enc", sa.Text(), nullable=True)) + op.add_column("mcp_server", sa.Column("custom_headers_enc", sa.Text(), nullable=True)) + + # Check if encryption key is available + encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") + if not encryption_key: + print("WARNING: LETTA_ENCRYPTION_KEY not set. Skipping data encryption migration.") + print(" You can run a separate migration script later to encrypt existing data.") + return + + # Get database connection + connection = op.get_bind() + + # Batch processing configuration + BATCH_SIZE = 1000 # Process 1000 rows at a time + + # Migrate mcp_oauth data + print("Migrating mcp_oauth encrypted fields...") + mcp_oauth = table( + "mcp_oauth", + column("id", String), + column("access_token", Text), + column("access_token_enc", Text), + column("refresh_token", Text), + column("refresh_token_enc", Text), + column("client_secret", Text), + column("client_secret_enc", Text), + ) + + # Count total rows to process + total_count_result = connection.execute( + sa.select(sa.func.count()) + .select_from(mcp_oauth) + .where( + sa.and_( + sa.or_(mcp_oauth.c.access_token.isnot(None), mcp_oauth.c.refresh_token.isnot(None), mcp_oauth.c.client_secret.isnot(None)), + # Only count rows that need encryption + sa.or_( + sa.and_(mcp_oauth.c.access_token.isnot(None), mcp_oauth.c.access_token_enc.is_(None)), + sa.and_(mcp_oauth.c.refresh_token.isnot(None), mcp_oauth.c.refresh_token_enc.is_(None)), + sa.and_(mcp_oauth.c.client_secret.isnot(None), mcp_oauth.c.client_secret_enc.is_(None)), + ), + ) + ) + ).scalar() + + if total_count_result and total_count_result > 0: + print(f"Found {total_count_result} mcp_oauth records that need encryption") + + encrypted_count = 0 + skipped_count = 0 + offset = 0 + + # Process in batches + while True: + # Select batch of rows + oauth_rows = connection.execute( + sa.select( + mcp_oauth.c.id, + mcp_oauth.c.access_token, + mcp_oauth.c.access_token_enc, + mcp_oauth.c.refresh_token, + mcp_oauth.c.refresh_token_enc, + mcp_oauth.c.client_secret, + mcp_oauth.c.client_secret_enc, + ) + .where( + sa.and_( + sa.or_( + mcp_oauth.c.access_token.isnot(None), + mcp_oauth.c.refresh_token.isnot(None), + mcp_oauth.c.client_secret.isnot(None), + ), + # Only select rows that need encryption + sa.or_( + sa.and_(mcp_oauth.c.access_token.isnot(None), mcp_oauth.c.access_token_enc.is_(None)), + sa.and_(mcp_oauth.c.refresh_token.isnot(None), mcp_oauth.c.refresh_token_enc.is_(None)), + sa.and_(mcp_oauth.c.client_secret.isnot(None), mcp_oauth.c.client_secret_enc.is_(None)), + ), + ) + ) + .order_by(mcp_oauth.c.id) # Ensure consistent ordering + .limit(BATCH_SIZE) + .offset(offset) + ).fetchall() + + if not oauth_rows: + break # No more rows to process + + # Prepare batch updates + batch_updates = [] + + for row in oauth_rows: + updates = {"id": row.id} + has_updates = False + + # Encrypt access_token if present and not already encrypted + if row.access_token and not row.access_token_enc: + try: + updates["access_token_enc"] = CryptoUtils.encrypt(row.access_token, encryption_key) + has_updates = True + except Exception as e: + print(f"Warning: Failed to encrypt access_token for mcp_oauth id={row.id}: {e}") + elif row.access_token_enc: + skipped_count += 1 + + # Encrypt refresh_token if present and not already encrypted + if row.refresh_token and not row.refresh_token_enc: + try: + updates["refresh_token_enc"] = CryptoUtils.encrypt(row.refresh_token, encryption_key) + has_updates = True + except Exception as e: + print(f"Warning: Failed to encrypt refresh_token for mcp_oauth id={row.id}: {e}") + elif row.refresh_token_enc: + skipped_count += 1 + + # Encrypt client_secret if present and not already encrypted + if row.client_secret and not row.client_secret_enc: + try: + updates["client_secret_enc"] = CryptoUtils.encrypt(row.client_secret, encryption_key) + has_updates = True + except Exception as e: + print(f"Warning: Failed to encrypt client_secret for mcp_oauth id={row.id}: {e}") + elif row.client_secret_enc: + skipped_count += 1 + + if has_updates: + batch_updates.append(updates) + encrypted_count += 1 + + # Execute batch update if there are updates + if batch_updates: + # Use bulk update for better performance + for update_data in batch_updates: + row_id = update_data.pop("id") + if update_data: # Only update if there are fields to update + connection.execute(mcp_oauth.update().where(mcp_oauth.c.id == row_id).values(**update_data)) + + # Progress indicator for large datasets + if encrypted_count > 0 and encrypted_count % 10000 == 0: + print(f" Progress: Encrypted {encrypted_count} mcp_oauth records...") + + offset += BATCH_SIZE + + # For very large datasets, commit periodically to avoid long transactions + if encrypted_count > 0 and encrypted_count % 50000 == 0: + connection.commit() + + print(f"mcp_oauth: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields") + else: + print("mcp_oauth: No records need encryption") + + # Migrate mcp_server data + print("Migrating mcp_server encrypted fields...") + mcp_server = table( + "mcp_server", + column("id", String), + column("token", String), + column("token_enc", Text), + column("custom_headers", JSON), + column("custom_headers_enc", Text), + ) + + # Count total rows to process + total_count_result = connection.execute( + sa.select(sa.func.count()) + .select_from(mcp_server) + .where( + sa.and_( + sa.or_(mcp_server.c.token.isnot(None), mcp_server.c.custom_headers.isnot(None)), + # Only count rows that need encryption + sa.or_( + sa.and_(mcp_server.c.token.isnot(None), mcp_server.c.token_enc.is_(None)), + sa.and_(mcp_server.c.custom_headers.isnot(None), mcp_server.c.custom_headers_enc.is_(None)), + ), + ) + ) + ).scalar() + + if total_count_result and total_count_result > 0: + print(f"Found {total_count_result} mcp_server records that need encryption") + + encrypted_count = 0 + skipped_count = 0 + offset = 0 + + # Process in batches + while True: + # Select batch of rows + server_rows = connection.execute( + sa.select( + mcp_server.c.id, + mcp_server.c.token, + mcp_server.c.token_enc, + mcp_server.c.custom_headers, + mcp_server.c.custom_headers_enc, + ) + .where( + sa.and_( + sa.or_(mcp_server.c.token.isnot(None), mcp_server.c.custom_headers.isnot(None)), + # Only select rows that need encryption + sa.or_( + sa.and_(mcp_server.c.token.isnot(None), mcp_server.c.token_enc.is_(None)), + sa.and_(mcp_server.c.custom_headers.isnot(None), mcp_server.c.custom_headers_enc.is_(None)), + ), + ) + ) + .order_by(mcp_server.c.id) # Ensure consistent ordering + .limit(BATCH_SIZE) + .offset(offset) + ).fetchall() + + if not server_rows: + break # No more rows to process + + # Prepare batch updates + batch_updates = [] + + for row in server_rows: + updates = {"id": row.id} + has_updates = False + + # Encrypt token if present and not already encrypted + if row.token and not row.token_enc: + try: + updates["token_enc"] = CryptoUtils.encrypt(row.token, encryption_key) + has_updates = True + except Exception as e: + print(f"Warning: Failed to encrypt token for mcp_server id={row.id}: {e}") + elif row.token_enc: + skipped_count += 1 + + # Encrypt custom_headers if present (JSON field) and not already encrypted + if row.custom_headers and not row.custom_headers_enc: + try: + # Convert JSON to string for encryption + headers_json = json.dumps(row.custom_headers) + updates["custom_headers_enc"] = CryptoUtils.encrypt(headers_json, encryption_key) + has_updates = True + except Exception as e: + print(f"Warning: Failed to encrypt custom_headers for mcp_server id={row.id}: {e}") + elif row.custom_headers_enc: + skipped_count += 1 + + if has_updates: + batch_updates.append(updates) + encrypted_count += 1 + + # Execute batch update if there are updates + if batch_updates: + # Use bulk update for better performance + for update_data in batch_updates: + row_id = update_data.pop("id") + if update_data: # Only update if there are fields to update + connection.execute(mcp_server.update().where(mcp_server.c.id == row_id).values(**update_data)) + + # Progress indicator for large datasets + if encrypted_count > 0 and encrypted_count % 10000 == 0: + print(f" Progress: Encrypted {encrypted_count} mcp_server records...") + + offset += BATCH_SIZE + + # For very large datasets, commit periodically to avoid long transactions + if encrypted_count > 0 and encrypted_count % 50000 == 0: + connection.commit() + + print(f"mcp_server: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields") + else: + print("mcp_server: No records need encryption") + print("Migration complete. Plaintext columns are retained for rollback safety.") + # ### end Alembic commands ### + + +def downgrade() -> None: + op.drop_column("mcp_server", "custom_headers_enc") + op.drop_column("mcp_server", "token_enc") + op.drop_column("mcp_oauth", "client_secret_enc") + op.drop_column("mcp_oauth", "refresh_token_enc") + op.drop_column("mcp_oauth", "access_token_enc") + # ### end Alembic commands ### diff --git a/letta/helpers/crypto_utils.py b/letta/helpers/crypto_utils.py new file mode 100644 index 00000000..292e1a2f --- /dev/null +++ b/letta/helpers/crypto_utils.py @@ -0,0 +1,134 @@ +import base64 +import os +from typing import Optional + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + +from letta.settings import settings + + +class CryptoUtils: + """Utility class for AES-256-GCM encryption/decryption of sensitive data.""" + + # AES-256 requires 32 bytes key + KEY_SIZE = 32 + # GCM standard IV size is 12 bytes (96 bits) + IV_SIZE = 12 + # GCM tag size is 16 bytes (128 bits) + TAG_SIZE = 16 + # Salt size for key derivation + SALT_SIZE = 16 + + @classmethod + def _derive_key(cls, master_key: str, salt: bytes) -> bytes: + """Derive an AES key from the master key using PBKDF2.""" + kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=cls.KEY_SIZE, salt=salt, iterations=100000, backend=default_backend()) + return kdf.derive(master_key.encode()) + + @classmethod + def encrypt(cls, plaintext: str, master_key: Optional[str] = None) -> str: + """ + Encrypt a string using AES-256-GCM. + + Args: + plaintext: The string to encrypt + master_key: Optional master key (defaults to settings.encryption_key) + + Returns: + Base64 encoded string containing: salt + iv + ciphertext + tag + + Raises: + ValueError: If no encryption key is configured + """ + if master_key is None: + master_key = settings.encryption_key + + if not master_key: + raise ValueError("No encryption key configured. Set LETTA_ENCRYPTION_KEY environment variable.") + + # Generate random salt and IV + salt = os.urandom(cls.SALT_SIZE) + iv = os.urandom(cls.IV_SIZE) + + # Derive key from master key + key = cls._derive_key(master_key, salt) + + # Create cipher + cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend()) + encryptor = cipher.encryptor() + + # Encrypt the plaintext + ciphertext = encryptor.update(plaintext.encode()) + encryptor.finalize() + + # Get the authentication tag + tag = encryptor.tag + + # Combine salt + iv + ciphertext + tag + encrypted_data = salt + iv + ciphertext + tag + + # Return as base64 encoded string + return base64.b64encode(encrypted_data).decode("utf-8") + + @classmethod + def decrypt(cls, encrypted: str, master_key: Optional[str] = None) -> str: + """ + Decrypt a string that was encrypted using AES-256-GCM. + + Args: + encrypted: Base64 encoded encrypted string + master_key: Optional master key (defaults to settings.encryption_key) + + Returns: + The decrypted plaintext string + + Raises: + ValueError: If no encryption key is configured or decryption fails + """ + if master_key is None: + master_key = settings.encryption_key + + if not master_key: + raise ValueError("No encryption key configured. Set LETTA_ENCRYPTION_KEY environment variable.") + + try: + # Decode from base64 + encrypted_data = base64.b64decode(encrypted) + + # Extract components + salt = encrypted_data[: cls.SALT_SIZE] + iv = encrypted_data[cls.SALT_SIZE : cls.SALT_SIZE + cls.IV_SIZE] + ciphertext = encrypted_data[cls.SALT_SIZE + cls.IV_SIZE : -cls.TAG_SIZE] + tag = encrypted_data[-cls.TAG_SIZE :] + + # Derive key from master key + key = cls._derive_key(master_key, salt) + + # Create cipher + cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend()) + decryptor = cipher.decryptor() + + # Decrypt the ciphertext + plaintext = decryptor.update(ciphertext) + decryptor.finalize() + + return plaintext.decode("utf-8") + + except Exception as e: + raise ValueError(f"Failed to decrypt data: {str(e)}") + + @classmethod + def is_encrypted(cls, value: str) -> bool: + """ + Check if a string appears to be encrypted (base64 encoded with correct size). + + This is a heuristic check and may have false positives. + """ + try: + decoded = base64.b64decode(value) + # Check if length is consistent with our encryption format + # Minimum size: salt(16) + iv(12) + tag(16) + at least 1 byte of ciphertext + return len(decoded) >= cls.SALT_SIZE + cls.IV_SIZE + cls.TAG_SIZE + 1 + except Exception: + return False diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index f2d1bd15..f8057e07 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -18,6 +18,7 @@ from letta.orm.job import Job from letta.orm.job_messages import JobMessage from letta.orm.llm_batch_items import LLMBatchItem from letta.orm.llm_batch_job import LLMBatchJob +from letta.orm.mcp_oauth import MCPOAuth from letta.orm.mcp_server import MCPServer from letta.orm.message import Message from letta.orm.organization import Organization diff --git a/letta/orm/mcp_oauth.py b/letta/orm/mcp_oauth.py index e34f685a..163514b1 100644 --- a/letta/orm/mcp_oauth.py +++ b/letta/orm/mcp_oauth.py @@ -38,7 +38,11 @@ class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin): # Token data access_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth access token") + access_token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted OAuth access token") + refresh_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth refresh token") + refresh_token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted OAuth refresh token") + token_type: Mapped[str] = mapped_column(String(50), default="Bearer", doc="Token type") expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, doc="Token expiry time") scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth scope") @@ -46,6 +50,8 @@ class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin): # Client configuration client_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client ID") client_secret: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client secret") + client_secret_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted OAuth client secret") + redirect_uri: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth redirect URI") # Session state diff --git a/letta/orm/mcp_server.py b/letta/orm/mcp_server.py index 55a2a672..49cffb84 100644 --- a/letta/orm/mcp_server.py +++ b/letta/orm/mcp_server.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Optional -from sqlalchemy import JSON, String, UniqueConstraint +from sqlalchemy import JSON, String, Text, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column from letta.functions.mcp_client.types import StdioServerConfig @@ -39,9 +39,15 @@ class MCPServer(SqlalchemyBase, OrganizationMixin): # access token / api key for MCP servers that require authentication token: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The access token or api key for the MCP server") + # encrypted access token or api key for the MCP server + token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted access token or api key for the MCP server") + # custom headers for authentication (key-value pairs) custom_headers: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Custom authentication headers as key-value pairs") + # encrypted custom headers for authentication (key-value pairs) + custom_headers_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted custom authentication headers") + # stdio server stdio_config: Mapped[Optional[StdioServerConfig]] = mapped_column( MCPStdioServerConfigColumn, nullable=True, doc="The configuration for the stdio server" diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index 5412bc33..6f273816 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -13,6 +13,7 @@ from letta.functions.mcp_client.types import ( ) from letta.orm.mcp_oauth import OAuthSessionStatus from letta.schemas.letta_base import LettaBase +from letta.schemas.secret import Secret, SecretDict class BaseMCPServer(LettaBase): @@ -29,6 +30,9 @@ class MCPServer(BaseMCPServer): token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for authentication)") custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs") + token_enc: Optional[str] = Field(None, description="Encrypted token") + custom_headers_enc: Optional[str] = Field(None, description="Encrypted custom headers") + # stdio config stdio_config: Optional[StdioServerConfig] = Field( None, description="The configuration for the server (MCP 'local' client will run this command)" @@ -41,18 +45,93 @@ class MCPServer(BaseMCPServer): last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.") + def get_token_secret(self) -> Secret: + """Get the token as a Secret object, preferring encrypted over plaintext.""" + return Secret.from_db(self.token_enc, self.token) + + def get_custom_headers_secret(self) -> SecretDict: + """Get custom headers as a SecretDict object, preferring encrypted over plaintext.""" + return SecretDict.from_db(self.custom_headers_enc, self.custom_headers) + + def set_token_secret(self, secret: Secret) -> None: + """Set token from a Secret object, updating both encrypted and plaintext fields.""" + secret_dict = secret.to_dict() + self.token_enc = secret_dict["encrypted"] + # Only set plaintext during migration phase + if not secret._was_encrypted: + self.token = secret_dict["plaintext"] + else: + self.token = None + + def set_custom_headers_secret(self, secret: SecretDict) -> None: + """Set custom headers from a SecretDict object, updating both fields.""" + secret_dict = secret.to_dict() + self.custom_headers_enc = secret_dict["encrypted"] + # Only set plaintext during migration phase + if not secret._was_encrypted: + self.custom_headers = secret_dict["plaintext"] + else: + self.custom_headers = None + + def model_dump(self, to_orm: bool = False, **kwargs): + """Override model_dump to handle encryption when saving to database.""" + import os + + # Check environment variable directly to handle test patching + encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") + if not encryption_key: + from letta.settings import settings + + encryption_key = settings.encryption_key + + data = super().model_dump(to_orm=to_orm, **kwargs) + + if to_orm and encryption_key: + # Temporarily set the encryption key for Secret/SecretDict + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = encryption_key + try: + # Encrypt token if present + if self.token is not None: + token_secret = Secret.from_plaintext(self.token) + secret_dict = token_secret.to_dict() + data["token_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["token"] = secret_dict["plaintext"] + + # Encrypt custom headers if present + if self.custom_headers is not None: + headers_secret = SecretDict.from_plaintext(self.custom_headers) + secret_dict = headers_secret.to_dict() + data["custom_headers_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["custom_headers"] = secret_dict["plaintext"] + finally: + settings.encryption_key = original_key + + return data + def to_config( self, environment_variables: Optional[Dict[str, str]] = None, resolve_variables: bool = True, ) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]: + # Get decrypted values for use in config + token_secret = self.get_token_secret() + token_plaintext = token_secret.get_plaintext() + + headers_secret = self.get_custom_headers_secret() + headers_plaintext = headers_secret.get_plaintext() + if self.server_type == MCPServerType.SSE: config = SSEServerConfig( server_name=self.server_name, server_url=self.server_url, - auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None, - auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None, - custom_headers=self.custom_headers, + auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None, + auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None, + custom_headers=headers_plaintext, ) if resolve_variables: config.resolve_environment_variables(environment_variables) @@ -70,9 +149,9 @@ class MCPServer(BaseMCPServer): config = StreamableHTTPServerConfig( server_name=self.server_name, server_url=self.server_url, - auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None, - auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None, - custom_headers=self.custom_headers, + auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None, + auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None, + custom_headers=headers_plaintext, ) if resolve_variables: config.resolve_environment_variables(environment_variables) @@ -138,11 +217,18 @@ class MCPOAuthSession(BaseMCPOAuth): expires_at: Optional[datetime] = Field(None, description="Token expiry time") scope: Optional[str] = Field(None, description="OAuth scope") + # Encrypted token fields (for internal use) + access_token_enc: Optional[str] = Field(None, description="Encrypted OAuth access token") + refresh_token_enc: Optional[str] = Field(None, description="Encrypted OAuth refresh token") + # Client configuration client_id: Optional[str] = Field(None, description="OAuth client ID") client_secret: Optional[str] = Field(None, description="OAuth client secret") redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI") + # Encrypted client secret (for internal use) + client_secret_enc: Optional[str] = Field(None, description="Encrypted OAuth client secret") + # Session state status: OAuthSessionStatus = Field(default=OAuthSessionStatus.PENDING, description="Session status") @@ -150,6 +236,93 @@ class MCPOAuthSession(BaseMCPOAuth): created_at: datetime = Field(default_factory=datetime.now, description="Session creation time") updated_at: datetime = Field(default_factory=datetime.now, description="Last update time") + def get_access_token_secret(self) -> Secret: + """Get the access token as a Secret object, preferring encrypted over plaintext.""" + return Secret.from_db(self.access_token_enc, self.access_token) + + def get_refresh_token_secret(self) -> Secret: + """Get the refresh token as a Secret object, preferring encrypted over plaintext.""" + return Secret.from_db(self.refresh_token_enc, self.refresh_token) + + def get_client_secret_secret(self) -> Secret: + """Get the client secret as a Secret object, preferring encrypted over plaintext.""" + return Secret.from_db(self.client_secret_enc, self.client_secret) + + def set_access_token_secret(self, secret: Secret) -> None: + """Set access token from a Secret object.""" + secret_dict = secret.to_dict() + self.access_token_enc = secret_dict["encrypted"] + if not secret._was_encrypted: + self.access_token = secret_dict["plaintext"] + else: + self.access_token = None + + def set_refresh_token_secret(self, secret: Secret) -> None: + """Set refresh token from a Secret object.""" + secret_dict = secret.to_dict() + self.refresh_token_enc = secret_dict["encrypted"] + if not secret._was_encrypted: + self.refresh_token = secret_dict["plaintext"] + else: + self.refresh_token = None + + def set_client_secret_secret(self, secret: Secret) -> None: + """Set client secret from a Secret object.""" + secret_dict = secret.to_dict() + self.client_secret_enc = secret_dict["encrypted"] + if not secret._was_encrypted: + self.client_secret = secret_dict["plaintext"] + else: + self.client_secret = None + + def model_dump(self, to_orm: bool = False, **kwargs): + """Override model_dump to handle encryption when saving to database.""" + import os + + # Check environment variable directly to handle test patching + encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") + if not encryption_key: + from letta.settings import settings + + encryption_key = settings.encryption_key + + data = super().model_dump(to_orm=to_orm, **kwargs) + + if to_orm and encryption_key: + # Temporarily set the encryption key for Secret + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = encryption_key + try: + # Encrypt access token if present + if self.access_token is not None: + token_secret = Secret.from_plaintext(self.access_token) + secret_dict = token_secret.to_dict() + data["access_token_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["access_token"] = secret_dict["plaintext"] + + # Encrypt refresh token if present + if self.refresh_token is not None: + token_secret = Secret.from_plaintext(self.refresh_token) + secret_dict = token_secret.to_dict() + data["refresh_token_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["refresh_token"] = secret_dict["plaintext"] + + # Encrypt client secret if present + if self.client_secret is not None: + secret = Secret.from_plaintext(self.client_secret) + secret_dict = secret.to_dict() + data["client_secret_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["client_secret"] = secret_dict["plaintext"] + finally: + settings.encryption_key = original_key + + return data + class MCPOAuthSessionCreate(BaseMCPOAuth): """Create a new OAuth session.""" diff --git a/letta/schemas/secret.py b/letta/schemas/secret.py new file mode 100644 index 00000000..fe2c7cf0 --- /dev/null +++ b/letta/schemas/secret.py @@ -0,0 +1,241 @@ +import json +from typing import Any, Dict, Optional + +from pydantic import BaseModel, ConfigDict, PrivateAttr + +from letta.helpers.crypto_utils import CryptoUtils + + +class Secret(BaseModel): + """ + A wrapper class for encrypted credentials that keeps values encrypted in memory. + + This class ensures that sensitive data remains encrypted as much as possible + while passing through the codebase, only decrypting when absolutely necessary. + + TODO: Once we deprecate plaintext columns in the database: + - Remove the dual-write logic in to_dict() + - Remove the from_db() method's plaintext_value parameter + - Remove the _was_encrypted flag (no longer needed for migration) + - Simplify get_plaintext() to only handle encrypted values + """ + + # Store the encrypted value + _encrypted_value: Optional[str] = PrivateAttr(default=None) + # Cache the decrypted value to avoid repeated decryption + _plaintext_cache: Optional[str] = PrivateAttr(default=None) + # Flag to indicate if the value was originally encrypted + _was_encrypted: bool = PrivateAttr(default=False) + + model_config = ConfigDict(frozen=True) + + @classmethod + def from_plaintext(cls, value: Optional[str]) -> "Secret": + """ + Create a Secret from a plaintext value, encrypting it immediately. + + Args: + value: The plaintext value to encrypt + + Returns: + A Secret instance with the encrypted value + """ + if value is None: + instance = cls() + instance._encrypted_value = None + instance._was_encrypted = False + return instance + + encrypted = CryptoUtils.encrypt(value) + instance = cls() + instance._encrypted_value = encrypted + instance._was_encrypted = False + return instance + + @classmethod + def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret": + """ + Create a Secret from an already encrypted value. + + Args: + encrypted_value: The encrypted value + + Returns: + A Secret instance + """ + instance = cls() + instance._encrypted_value = encrypted_value + instance._was_encrypted = True + return instance + + @classmethod + def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[str]) -> "Secret": + """ + Create a Secret from database values during migration phase. + + Prefers encrypted value if available, falls back to plaintext. + + Args: + encrypted_value: The encrypted value from the database + plaintext_value: The plaintext value from the database + + Returns: + A Secret instance + """ + if encrypted_value is not None: + return cls.from_encrypted(encrypted_value) + elif plaintext_value is not None: + return cls.from_plaintext(plaintext_value) + else: + return cls.from_plaintext(None) + + def get_encrypted(self) -> Optional[str]: + """ + Get the encrypted value. + + Returns: + The encrypted value, or None if the secret is empty + """ + return self._encrypted_value + + def get_plaintext(self) -> Optional[str]: + """ + Get the decrypted plaintext value. + + This should only be called when the plaintext is actually needed, + such as when making an external API call. + + Returns: + The decrypted plaintext value + """ + if self._encrypted_value is None: + return None + + # Use cached value if available + if self._plaintext_cache is not None: + return self._plaintext_cache + + # Decrypt and cache + try: + plaintext = CryptoUtils.decrypt(self._encrypted_value) + # Note: We can't actually cache due to frozen=True, but in practice + # we'll create new instances rather than mutating + return plaintext + except Exception: + # If decryption fails and this wasn't originally encrypted, + # it might be that the value is actually plaintext (during migration) + if not self._was_encrypted: + return None + raise + + def is_empty(self) -> bool: + """Check if the secret is empty/None.""" + return self._encrypted_value is None + + def __str__(self) -> str: + """String representation that doesn't expose the actual value.""" + if self.is_empty(): + return "" + return "" + + def __repr__(self) -> str: + """Representation that doesn't expose the actual value.""" + return self.__str__() + + def to_dict(self) -> Dict[str, Any]: + """ + Convert to dictionary for database storage. + + Returns both encrypted and plaintext values for dual-write during migration. + """ + return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self._was_encrypted else None} + + def __eq__(self, other: Any) -> bool: + """ + Compare two secrets by their plaintext values. + + Note: This decrypts both values, so use sparingly. + """ + if not isinstance(other, Secret): + return False + return self.get_plaintext() == other.get_plaintext() + + +class SecretDict(BaseModel): + """ + A wrapper for dictionaries containing sensitive key-value pairs. + + Used for custom headers and other key-value configurations. + + TODO: Once we deprecate plaintext columns in the database: + - Remove the dual-write logic in to_dict() + - Remove the from_db() method's plaintext_value parameter + - Remove the _was_encrypted flag (no longer needed for migration) + - Simplify get_plaintext() to only handle encrypted JSON values + """ + + _encrypted_value: Optional[str] = PrivateAttr(default=None) + _plaintext_cache: Optional[Dict[str, str]] = PrivateAttr(default=None) + _was_encrypted: bool = PrivateAttr(default=False) + + model_config = ConfigDict(frozen=True) + + @classmethod + def from_plaintext(cls, value: Optional[Dict[str, str]]) -> "SecretDict": + """Create a SecretDict from a plaintext dictionary.""" + if value is None: + instance = cls() + instance._encrypted_value = None + instance._was_encrypted = False + return instance + + # Serialize to JSON then encrypt + json_str = json.dumps(value) + encrypted = CryptoUtils.encrypt(json_str) + instance = cls() + instance._encrypted_value = encrypted + instance._was_encrypted = False + return instance + + @classmethod + def from_encrypted(cls, encrypted_value: Optional[str]) -> "SecretDict": + """Create a SecretDict from an encrypted value.""" + instance = cls() + instance._encrypted_value = encrypted_value + instance._was_encrypted = True + return instance + + @classmethod + def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[Dict[str, str]]) -> "SecretDict": + """Create a SecretDict from database values during migration phase.""" + if encrypted_value is not None: + return cls.from_encrypted(encrypted_value) + elif plaintext_value is not None: + return cls.from_plaintext(plaintext_value) + else: + return cls.from_plaintext(None) + + def get_encrypted(self) -> Optional[str]: + """Get the encrypted value.""" + return self._encrypted_value + + def get_plaintext(self) -> Optional[Dict[str, str]]: + """Get the decrypted dictionary.""" + if self._encrypted_value is None: + return None + + try: + decrypted_json = CryptoUtils.decrypt(self._encrypted_value) + return json.loads(decrypted_json) + except Exception: + if not self._was_encrypted: + return None + raise + + def is_empty(self) -> bool: + """Check if the secret dict is empty/None.""" + return self._encrypted_value is None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for database storage.""" + return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self._was_encrypted else None} diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 0db5df9c..5c7200d9 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -728,35 +728,16 @@ async def add_mcp_server_to_config( return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True) else: # log to DB - from letta.schemas.mcp import MCPServer - - if isinstance(request, StdioServerConfig): - mapped_request = MCPServer(server_name=request.server_name, server_type=request.type, stdio_config=request) - # don't allow stdio servers - if tool_settings.mcp_disable_stdio: # protected server - raise HTTPException( - status_code=400, - detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server", - ) - elif isinstance(request, SSEServerConfig): - mapped_request = MCPServer( - server_name=request.server_name, - server_type=request.type, - server_url=request.server_url, - token=request.resolve_token(), - custom_headers=request.custom_headers, - ) - elif isinstance(request, StreamableHTTPServerConfig): - mapped_request = MCPServer( - server_name=request.server_name, - server_type=request.type, - server_url=request.server_url, - token=request.resolve_token(), - custom_headers=request.custom_headers, + # Check if stdio servers are disabled + if isinstance(request, StdioServerConfig) and tool_settings.mcp_disable_stdio: + raise HTTPException( + status_code=400, + detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server", ) # Create MCP server and optimistically sync tools - await server.mcp_manager.create_mcp_server_with_tools(mapped_request, actor=actor) + # The mcp_manager will handle encryption of sensitive fields + await server.mcp_manager.create_mcp_server_from_config_with_tools(request, actor=actor) # TODO: don't do this in the future (just return MCPServer) all_servers = await server.mcp_manager.list_mcp_servers(actor=actor) diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index e3d2e671..08bf38a4 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -35,6 +35,7 @@ from letta.schemas.mcp import ( UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer, ) +from letta.schemas.secret import Secret, SecretDict from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry @@ -354,6 +355,69 @@ class MCPManager: logger.error(f"Failed to create MCP server: {e}") raise + @enforce_types + async def create_mcp_server_from_config( + self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser + ) -> MCPServer: + """ + Create an MCP server from a config object, handling encryption of sensitive fields. + + This method converts the server config to an MCPServer model and encrypts + sensitive fields like tokens and custom headers. + """ + # Create base MCPServer object + if isinstance(server_config, StdioServerConfig): + mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config) + elif isinstance(server_config, SSEServerConfig): + mcp_server = MCPServer( + server_name=server_config.server_name, + server_type=server_config.type, + server_url=server_config.server_url, + ) + # Encrypt sensitive fields + token = server_config.resolve_token() + if token: + token_secret = Secret.from_plaintext(token) + mcp_server.set_token_secret(token_secret) + if server_config.custom_headers: + headers_secret = SecretDict.from_plaintext(server_config.custom_headers) + mcp_server.set_custom_headers_secret(headers_secret) + + elif isinstance(server_config, StreamableHTTPServerConfig): + mcp_server = MCPServer( + server_name=server_config.server_name, + server_type=server_config.type, + server_url=server_config.server_url, + ) + # Encrypt sensitive fields + token = server_config.resolve_token() + if token: + token_secret = Secret.from_plaintext(token) + mcp_server.set_token_secret(token_secret) + if server_config.custom_headers: + headers_secret = SecretDict.from_plaintext(server_config.custom_headers) + mcp_server.set_custom_headers_secret(headers_secret) + else: + raise ValueError(f"Unsupported server config type: {type(server_config)}") + + return mcp_server + + @enforce_types + async def create_mcp_server_from_config_with_tools( + self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser + ) -> MCPServer: + """ + Create an MCP server from a config object and optimistically sync its tools. + + This method handles encryption of sensitive fields and then creates the server + with automatic tool synchronization. + """ + # Convert config to MCPServer with encryption + mcp_server = await self.create_mcp_server_from_config(server_config, actor) + + # Create the server with tools + return await self.create_mcp_server_with_tools(mcp_server, actor) + @enforce_types async def create_mcp_server_with_tools(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: """ @@ -420,10 +484,33 @@ class MCPManager: # Update tool attributes with only the fields that were explicitly set update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True) - # Ensure custom_headers None is stored as SQL NULL, not JSON null - if update_data.get("custom_headers") is None: - update_data.pop("custom_headers", None) - setattr(mcp_server, "custom_headers", null()) + # Handle encryption for token if provided + if "token" in update_data and update_data["token"] is not None: + token_secret = Secret.from_plaintext(update_data["token"]) + secret_dict = token_secret.to_dict() + update_data["token_enc"] = secret_dict["encrypted"] + # During migration phase, also update plaintext + if not token_secret._was_encrypted: + update_data["token"] = secret_dict["plaintext"] + else: + update_data["token"] = None + + # Handle encryption for custom_headers if provided + if "custom_headers" in update_data: + if update_data["custom_headers"] is not None: + headers_secret = SecretDict.from_plaintext(update_data["custom_headers"]) + secret_dict = headers_secret.to_dict() + update_data["custom_headers_enc"] = secret_dict["encrypted"] + # During migration phase, also update plaintext + if not headers_secret._was_encrypted: + update_data["custom_headers"] = secret_dict["plaintext"] + else: + update_data["custom_headers"] = None + else: + # Ensure custom_headers None is stored as SQL NULL, not JSON null + update_data.pop("custom_headers", None) + setattr(mcp_server, "custom_headers", null()) + setattr(mcp_server, "custom_headers_enc", None) for key, value in update_data.items(): setattr(mcp_server, key, value) @@ -664,6 +751,86 @@ class MCPManager: raise ValueError(f"Unsupported server config type: {type(server_config)}") # OAuth-related methods + def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession: + """ + Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields. + """ + # Check for encryption key from env or settings + import os + + from letta.settings import settings + + encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key + + # Get decrypted values using the dual-read approach + access_token = None + if oauth_session.access_token_enc or oauth_session.access_token: + if encryption_key: + # Temporarily set the key for Secret + original_key = settings.encryption_key + settings.encryption_key = encryption_key + try: + secret = Secret.from_db(oauth_session.access_token_enc, oauth_session.access_token) + access_token = secret.get_plaintext() + finally: + settings.encryption_key = original_key + else: + # No encryption key, use plaintext if available + access_token = oauth_session.access_token + + refresh_token = None + if oauth_session.refresh_token_enc or oauth_session.refresh_token: + if encryption_key: + # Temporarily set the key for Secret + original_key = settings.encryption_key + settings.encryption_key = encryption_key + try: + secret = Secret.from_db(oauth_session.refresh_token_enc, oauth_session.refresh_token) + refresh_token = secret.get_plaintext() + finally: + settings.encryption_key = original_key + else: + # No encryption key, use plaintext if available + refresh_token = oauth_session.refresh_token + + client_secret = None + if oauth_session.client_secret_enc or oauth_session.client_secret: + if encryption_key: + # Temporarily set the key for Secret + original_key = settings.encryption_key + settings.encryption_key = encryption_key + try: + secret = Secret.from_db(oauth_session.client_secret_enc, oauth_session.client_secret) + client_secret = secret.get_plaintext() + finally: + settings.encryption_key = original_key + else: + # No encryption key, use plaintext if available + client_secret = oauth_session.client_secret + + return MCPOAuthSession( + id=oauth_session.id, + state=oauth_session.state, + server_id=oauth_session.server_id, + server_url=oauth_session.server_url, + server_name=oauth_session.server_name, + user_id=oauth_session.user_id, + organization_id=oauth_session.organization_id, + authorization_url=oauth_session.authorization_url, + authorization_code=oauth_session.authorization_code, + access_token=access_token, + refresh_token=refresh_token, + token_type=oauth_session.token_type, + expires_at=oauth_session.expires_at, + scope=oauth_session.scope, + client_id=oauth_session.client_id, + client_secret=client_secret, + redirect_uri=oauth_session.redirect_uri, + status=oauth_session.status, + created_at=oauth_session.created_at, + updated_at=oauth_session.updated_at, + ) + @enforce_types async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession: """Create a new OAuth session for MCP server authentication.""" @@ -682,18 +849,8 @@ class MCPManager: ) oauth_session = await oauth_session.create_async(session, actor=actor) - # Convert to Pydantic model - return MCPOAuthSession( - id=oauth_session.id, - state=oauth_session.state, - server_url=oauth_session.server_url, - server_name=oauth_session.server_name, - user_id=oauth_session.user_id, - organization_id=oauth_session.organization_id, - status=oauth_session.status, - created_at=oauth_session.created_at, - updated_at=oauth_session.updated_at, - ) + # Convert to Pydantic model - note: new sessions won't have tokens yet + return self._oauth_orm_to_pydantic(oauth_session) @enforce_types async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]: @@ -701,27 +858,7 @@ class MCPManager: async with db_registry.async_session() as session: try: oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) - return MCPOAuthSession( - id=oauth_session.id, - state=oauth_session.state, - server_url=oauth_session.server_url, - server_name=oauth_session.server_name, - user_id=oauth_session.user_id, - organization_id=oauth_session.organization_id, - authorization_url=oauth_session.authorization_url, - authorization_code=oauth_session.authorization_code, - access_token=oauth_session.access_token, - refresh_token=oauth_session.refresh_token, - token_type=oauth_session.token_type, - expires_at=oauth_session.expires_at, - scope=oauth_session.scope, - client_id=oauth_session.client_id, - client_secret=oauth_session.client_secret, - redirect_uri=oauth_session.redirect_uri, - status=oauth_session.status, - created_at=oauth_session.created_at, - updated_at=oauth_session.updated_at, - ) + return self._oauth_orm_to_pydantic(oauth_session) except NoResultFound: return None @@ -747,27 +884,7 @@ class MCPManager: if not oauth_session: return None - return MCPOAuthSession( - id=oauth_session.id, - state=oauth_session.state, - server_url=oauth_session.server_url, - server_name=oauth_session.server_name, - user_id=oauth_session.user_id, - organization_id=oauth_session.organization_id, - authorization_url=oauth_session.authorization_url, - authorization_code=oauth_session.authorization_code, - access_token=oauth_session.access_token, - refresh_token=oauth_session.refresh_token, - token_type=oauth_session.token_type, - expires_at=oauth_session.expires_at, - scope=oauth_session.scope, - client_id=oauth_session.client_id, - client_secret=oauth_session.client_secret, - redirect_uri=oauth_session.redirect_uri, - status=oauth_session.status, - created_at=oauth_session.created_at, - updated_at=oauth_session.updated_at, - ) + return self._oauth_orm_to_pydantic(oauth_session) @enforce_types async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession: @@ -780,10 +897,59 @@ class MCPManager: oauth_session.authorization_url = session_update.authorization_url if session_update.authorization_code is not None: oauth_session.authorization_code = session_update.authorization_code + + # Handle encryption for access_token if session_update.access_token is not None: - oauth_session.access_token = session_update.access_token + # Check for encryption key from env or settings + import os + + from letta.settings import settings + + encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key + + if encryption_key: + # Temporarily set the key for Secret + original_key = settings.encryption_key + settings.encryption_key = encryption_key + try: + token_secret = Secret.from_plaintext(session_update.access_token) + secret_dict = token_secret.to_dict() + oauth_session.access_token_enc = secret_dict["encrypted"] + # During migration phase, also update plaintext + oauth_session.access_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None + finally: + settings.encryption_key = original_key + else: + # No encryption, store plaintext + oauth_session.access_token = session_update.access_token + oauth_session.access_token_enc = None + + # Handle encryption for refresh_token if session_update.refresh_token is not None: - oauth_session.refresh_token = session_update.refresh_token + # Check for encryption key from env or settings + import os + + from letta.settings import settings + + encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key + + if encryption_key: + # Temporarily set the key for Secret + original_key = settings.encryption_key + settings.encryption_key = encryption_key + try: + token_secret = Secret.from_plaintext(session_update.refresh_token) + secret_dict = token_secret.to_dict() + oauth_session.refresh_token_enc = secret_dict["encrypted"] + # During migration phase, also update plaintext + oauth_session.refresh_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None + finally: + settings.encryption_key = original_key + else: + # No encryption, store plaintext + oauth_session.refresh_token = session_update.refresh_token + oauth_session.refresh_token_enc = None + if session_update.token_type is not None: oauth_session.token_type = session_update.token_type if session_update.expires_at is not None: @@ -792,8 +958,33 @@ class MCPManager: oauth_session.scope = session_update.scope if session_update.client_id is not None: oauth_session.client_id = session_update.client_id + + # Handle encryption for client_secret if session_update.client_secret is not None: - oauth_session.client_secret = session_update.client_secret + # Check for encryption key from env or settings + import os + + from letta.settings import settings + + encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key + + if encryption_key: + # Temporarily set the key for Secret + original_key = settings.encryption_key + settings.encryption_key = encryption_key + try: + secret_secret = Secret.from_plaintext(session_update.client_secret) + secret_dict = secret_secret.to_dict() + oauth_session.client_secret_enc = secret_dict["encrypted"] + # During migration phase, also update plaintext + oauth_session.client_secret = secret_dict["plaintext"] if not secret_secret._was_encrypted else None + finally: + settings.encryption_key = original_key + else: + # No encryption, store plaintext + oauth_session.client_secret = session_update.client_secret + oauth_session.client_secret_enc = None + if session_update.redirect_uri is not None: oauth_session.redirect_uri = session_update.redirect_uri if session_update.status is not None: @@ -804,27 +995,7 @@ class MCPManager: oauth_session = await oauth_session.update_async(db_session=session, actor=actor) - return MCPOAuthSession( - id=oauth_session.id, - state=oauth_session.state, - server_url=oauth_session.server_url, - server_name=oauth_session.server_name, - user_id=oauth_session.user_id, - organization_id=oauth_session.organization_id, - authorization_url=oauth_session.authorization_url, - authorization_code=oauth_session.authorization_code, - access_token=oauth_session.access_token, - refresh_token=oauth_session.refresh_token, - token_type=oauth_session.token_type, - expires_at=oauth_session.expires_at, - scope=oauth_session.scope, - client_id=oauth_session.client_id, - client_secret=oauth_session.client_secret, - redirect_uri=oauth_session.redirect_uri, - status=oauth_session.status, - created_at=oauth_session.created_at, - updated_at=oauth_session.updated_at, - ) + return self._oauth_orm_to_pydantic(oauth_session) @enforce_types async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None: diff --git a/tests/test_crypto_utils.py b/tests/test_crypto_utils.py new file mode 100644 index 00000000..0922ce84 --- /dev/null +++ b/tests/test_crypto_utils.py @@ -0,0 +1,232 @@ +import base64 +import json +import os +from unittest.mock import patch + +import pytest + +from letta.helpers.crypto_utils import CryptoUtils + + +class TestCryptoUtils: + """Test suite for CryptoUtils encryption/decryption functionality.""" + + # Mock master keys for testing + MOCK_KEY_1 = "test-master-key-1234567890abcdef" + MOCK_KEY_2 = "another-test-key-fedcba0987654321" + + def test_encrypt_decrypt_roundtrip(self): + """Test that encryption followed by decryption returns the original value.""" + test_cases = [ + "simple text", + "text with special chars: !@#$%^&*()", + "unicode text: 你好世界 🌍", + "very long text " * 1000, + '{"json": "data", "nested": {"key": "value"}}', + "", # Empty string + ] + + for plaintext in test_cases: + encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1) + assert encrypted != plaintext, f"Encryption failed for: {plaintext[:50]}" + # Encrypted value is base64 encoded + assert len(encrypted) > 0, "Encrypted value should not be empty" + + decrypted = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1) + assert decrypted == plaintext, f"Roundtrip failed for: {plaintext[:50]}" + + def test_encrypt_with_different_keys(self): + """Test that different keys produce different ciphertexts.""" + plaintext = "sensitive data" + + encrypted1 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1) + encrypted2 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_2) + + # Different keys should produce different ciphertexts + assert encrypted1 != encrypted2 + + # Each should decrypt correctly with its own key + assert CryptoUtils.decrypt(encrypted1, self.MOCK_KEY_1) == plaintext + assert CryptoUtils.decrypt(encrypted2, self.MOCK_KEY_2) == plaintext + + def test_decrypt_with_wrong_key_fails(self): + """Test that decryption with wrong key raises an error.""" + plaintext = "secret message" + encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1) + + with pytest.raises(Exception): # Could be ValueError or cryptography exception + CryptoUtils.decrypt(encrypted, self.MOCK_KEY_2) + + def test_encrypt_none_value(self): + """Test handling of None values.""" + # Encrypt None should raise TypeError (None has no encode method) + with pytest.raises((TypeError, AttributeError)): + CryptoUtils.encrypt(None, self.MOCK_KEY_1) + + def test_decrypt_none_value(self): + """Test that decrypting None raises an error.""" + with pytest.raises(ValueError): + CryptoUtils.decrypt(None, self.MOCK_KEY_1) + + def test_decrypt_empty_string(self): + """Test that decrypting empty string raises an error.""" + with pytest.raises(Exception): # base64 decode error + CryptoUtils.decrypt("", self.MOCK_KEY_1) + + def test_decrypt_plaintext_value(self): + """Test that decrypting non-encrypted value raises an error.""" + plaintext = "not encrypted" + with pytest.raises(Exception): # Will fail base64 decode or decryption + CryptoUtils.decrypt(plaintext, self.MOCK_KEY_1) + + def test_encrypted_format_structure(self): + """Test the structure of encrypted values.""" + plaintext = "test data" + encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1) + + # Should be base64 encoded + encrypted_data = encrypted + + # Should be valid base64 + try: + decoded = base64.b64decode(encrypted_data) + assert len(decoded) > 0 + except Exception as e: + pytest.fail(f"Invalid base64 encoding: {e}") + + # Decoded data should contain salt, IV, tag, and ciphertext + # Total should be at least SALT_SIZE + IV_SIZE + TAG_SIZE bytes + min_size = CryptoUtils.SALT_SIZE + CryptoUtils.IV_SIZE + CryptoUtils.TAG_SIZE + assert len(decoded) >= min_size + + def test_deterministic_with_same_salt(self): + """Test that encryption is deterministic when using the same salt (for testing).""" + plaintext = "deterministic test" + + # Note: In production, each encryption generates a random salt + # This test verifies the encryption mechanism itself + encrypted1 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1) + encrypted2 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1) + + # Due to random salt, these should be different + assert encrypted1 != encrypted2 + + # But both should decrypt to the same value + assert CryptoUtils.decrypt(encrypted1, self.MOCK_KEY_1) == plaintext + assert CryptoUtils.decrypt(encrypted2, self.MOCK_KEY_1) == plaintext + + def test_encrypt_uses_env_key_when_none_provided(self): + """Test that encryption uses environment key when no key is provided.""" + from letta.settings import settings + + # Mock the settings to have an encryption key + original_key = settings.encryption_key + settings.encryption_key = "env-test-key-123" + + try: + plaintext = "test with env key" + + # Should use key from settings + encrypted = CryptoUtils.encrypt(plaintext) + assert len(encrypted) > 0 + + # Should decrypt with same key + decrypted = CryptoUtils.decrypt(encrypted) + assert decrypted == plaintext + finally: + # Restore original key + settings.encryption_key = original_key + + def test_encrypt_without_key_raises_error(self): + """Test that encryption without any key raises an error.""" + from letta.settings import settings + + # Mock settings to have no encryption key + original_key = settings.encryption_key + settings.encryption_key = None + + try: + with pytest.raises(ValueError, match="No encryption key configured"): + CryptoUtils.encrypt("test data") + finally: + # Restore original key + settings.encryption_key = original_key + + def test_large_data_encryption(self): + """Test encryption of large data.""" + # Create 10MB of data + large_data = "x" * (10 * 1024 * 1024) + + encrypted = CryptoUtils.encrypt(large_data, self.MOCK_KEY_1) + assert len(encrypted) > 0 + assert encrypted != large_data + + decrypted = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1) + assert decrypted == large_data + + def test_json_data_encryption(self): + """Test encryption of JSON data.""" + json_data = { + "user": "test_user", + "token": "secret_token_123", + "nested": {"api_key": "sk-1234567890", "headers": {"Authorization": "Bearer token"}}, + } + + json_str = json.dumps(json_data) + encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY_1) + + decrypted_str = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1) + decrypted_data = json.loads(decrypted_str) + + assert decrypted_data == json_data + + def test_invalid_encrypted_format(self): + """Test handling of invalid encrypted data format.""" + invalid_cases = [ + "invalid-base64!@#", # Invalid base64 + "dGVzdA==", # Valid base64 but too short for encrypted data + ] + + for invalid in invalid_cases: + with pytest.raises(Exception): # Could be various exceptions + CryptoUtils.decrypt(invalid, self.MOCK_KEY_1) + + def test_key_derivation_consistency(self): + """Test that key derivation is consistent.""" + plaintext = "test key derivation" + + # Multiple encryptions with same key should work + encrypted_values = [] + for _ in range(5): + encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1) + encrypted_values.append(encrypted) + + # All should decrypt correctly + for encrypted in encrypted_values: + assert CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1) == plaintext + + def test_special_characters_in_key(self): + """Test encryption with keys containing special characters.""" + special_key = "key-with-special-chars!@#$%^&*()_+" + plaintext = "test data" + + encrypted = CryptoUtils.encrypt(plaintext, special_key) + decrypted = CryptoUtils.decrypt(encrypted, special_key) + + assert decrypted == plaintext + + def test_whitespace_handling(self): + """Test encryption of strings with various whitespace.""" + test_cases = [ + " leading spaces", + "trailing spaces ", + " both sides ", + "multiple\n\nlines", + "\ttabs\there\t", + "mixed \t\n whitespace \r\n", + ] + + for plaintext in test_cases: + encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1) + decrypted = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1) + assert decrypted == plaintext, f"Whitespace handling failed for: {repr(plaintext)}" diff --git a/tests/test_mcp_encryption.py b/tests/test_mcp_encryption.py new file mode 100644 index 00000000..bcd2e827 --- /dev/null +++ b/tests/test_mcp_encryption.py @@ -0,0 +1,407 @@ +""" +Integration tests for MCP server and OAuth session encryption. +Tests the end-to-end encryption functionality in the MCP manager. +""" + +import json +import os +from datetime import datetime, timezone +from unittest.mock import AsyncMock, Mock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select + +from letta.config import LettaConfig +from letta.helpers.crypto_utils import CryptoUtils +from letta.orm import MCPOAuth, MCPServer as ORMMCPServer +from letta.schemas.mcp import ( + MCPOAuthSessionCreate, + MCPServer as PydanticMCPServer, + MCPServerType, + SSEServerConfig, + StdioServerConfig, +) +from letta.schemas.secret import Secret, SecretDict +from letta.server.db import db_registry +from letta.server.server import SyncServer +from letta.services.mcp_manager import MCPManager + + +@pytest.fixture(scope="module") +def server(): + """Fixture to create and return a SyncServer instance with MCP manager.""" + config = LettaConfig.load() + config.save() + server = SyncServer(init_with_default_org_and_user=False) + return server + + +class TestMCPServerEncryption: + """Test MCP server encryption functionality.""" + + MOCK_ENCRYPTION_KEY = "test-mcp-encryption-key-123456" + + @pytest.mark.asyncio + @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) + @patch("letta.services.mcp_manager.MCPManager.get_mcp_client") + async def test_create_mcp_server_with_token_encryption(self, mock_get_client, server, default_user): + """Test that MCP server tokens are encrypted when stored.""" + # Mock the MCP client + mock_client = AsyncMock() + mock_client.list_tools.return_value = [] + mock_get_client.return_value = mock_client + + # Create MCP server with token + server_name = f"test_encrypted_server_{uuid4().hex[:8]}" + token = "super-secret-api-token-12345" + server_url = "https://api.example.com/mcp" + + mcp_server = PydanticMCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url, token=token) + + created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user) + + # Verify server was created + assert created_server.server_name == server_name + assert created_server.server_type == MCPServerType.SSE + + # Check database directly to verify encryption + async with db_registry.async_session() as session: + result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id)) + db_server = result.scalar_one() + + # Token should be encrypted in database + assert db_server.token_enc is not None + assert db_server.token_enc != token # Should not be plaintext + + # Decrypt to verify correctness + decrypted_token = CryptoUtils.decrypt(db_server.token_enc, self.MOCK_ENCRYPTION_KEY) + assert decrypted_token == token + + # Legacy plaintext column should be None (or empty for dual-write) + # During migration phase, might store both + if db_server.token: + assert db_server.token == token # Dual-write phase + + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + @pytest.mark.asyncio + @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) + @patch("letta.services.mcp_manager.MCPManager.get_mcp_client") + async def test_create_mcp_server_with_custom_headers_encryption(self, mock_get_client, server, default_user): + """Test that MCP server custom headers are encrypted when stored.""" + # Mock the MCP client + mock_client = AsyncMock() + mock_client.list_tools.return_value = [] + mock_get_client.return_value = mock_client + + server_name = f"test_headers_server_{uuid4().hex[:8]}" + custom_headers = {"Authorization": "Bearer secret-token-xyz", "X-API-Key": "api-key-123456", "X-Custom-Header": "custom-value"} + server_url = "https://api.example.com/mcp" + + mcp_server = PydanticMCPServer( + server_name=server_name, server_type=MCPServerType.STREAMABLE_HTTP, server_url=server_url, custom_headers=custom_headers + ) + + created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user) + + # Check database directly + async with db_registry.async_session() as session: + result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id)) + db_server = result.scalar_one() + + # Custom headers should be encrypted as JSON + assert db_server.custom_headers_enc is not None + + # Decrypt and parse JSON + decrypted_json = CryptoUtils.decrypt(db_server.custom_headers_enc, self.MOCK_ENCRYPTION_KEY) + decrypted_headers = json.loads(decrypted_json) + assert decrypted_headers == custom_headers + + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + @pytest.mark.asyncio + @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) + async def test_retrieve_mcp_server_decrypts_values(self, server, default_user): + """Test that retrieving MCP server decrypts encrypted values.""" + # Manually insert encrypted server into database + server_id = f"mcp_server-{uuid4().hex[:8]}" + server_name = f"test_decrypt_server_{uuid4().hex[:8]}" + plaintext_token = "decryption-test-token" + encrypted_token = CryptoUtils.encrypt(plaintext_token, self.MOCK_ENCRYPTION_KEY) + + async with db_registry.async_session() as session: + db_server = ORMMCPServer( + id=server_id, + server_name=server_name, + server_type=MCPServerType.SSE.value, + server_url="https://test.com", + token_enc=encrypted_token, + token=None, # No plaintext + created_by_id=default_user.id, + last_updated_by_id=default_user.id, + organization_id=default_user.organization_id, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(db_server) + await session.commit() + + # Retrieve server directly by ID to avoid issues with other servers in DB + test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user) + + assert test_server is not None + assert test_server.server_name == server_name + # Token should be decrypted when accessed via the secret method + # Ensure encryption key is available for decryption + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_ENCRYPTION_KEY + try: + token_secret = test_server.get_token_secret() + assert token_secret.get_plaintext() == plaintext_token + finally: + settings.encryption_key = original_key + + # Clean up + async with db_registry.async_session() as session: + result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id)) + db_server = result.scalar_one() + await session.delete(db_server) + await session.commit() + + @pytest.mark.asyncio + @patch.dict(os.environ, {}, clear=True) # No encryption key + @patch("letta.services.mcp_manager.MCPManager.get_mcp_client") + async def test_create_mcp_server_without_encryption_key(self, mock_get_client, server, default_user): + """Test that MCP servers work without encryption key (backward compatibility).""" + # Remove encryption key + os.environ.pop("LETTA_ENCRYPTION_KEY", None) + + # Mock the MCP client + mock_client = AsyncMock() + mock_client.list_tools.return_value = [] + mock_get_client.return_value = mock_client + + server_name = f"test_no_encrypt_server_{uuid4().hex[:8]}" + token = "plaintext-token-no-encryption" + + mcp_server = PydanticMCPServer( + server_name=server_name, server_type=MCPServerType.SSE, server_url="https://api.example.com", token=token + ) + + created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user) + + # Check database - should store as plaintext + async with db_registry.async_session() as session: + result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id)) + db_server = result.scalar_one() + + # Should store in plaintext column + assert db_server.token == token + assert db_server.token_enc is None # No encryption + + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + +class TestMCPOAuthEncryption: + """Test MCP OAuth session encryption functionality.""" + + MOCK_ENCRYPTION_KEY = "test-oauth-encryption-key-123456" + + @pytest.mark.asyncio + @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) + async def test_create_oauth_session_with_encryption(self, server, default_user): + """Test that OAuth tokens are encrypted when stored.""" + server_url = "https://github.com/mcp" + server_name = "GitHub MCP" + + # Step 1: Create OAuth session (without tokens initially) + oauth_session_create = MCPOAuthSessionCreate( + server_url=server_url, + server_name=server_name, + organization_id=default_user.organization_id, + user_id=default_user.id, + ) + + created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user) + + assert created_session.server_url == server_url + assert created_session.server_name == server_name + + # Step 2: Update session with tokens (simulating OAuth callback) + from letta.schemas.mcp import MCPOAuthSessionUpdate + + update_data = MCPOAuthSessionUpdate( + access_token="github-access-token-abc123", + refresh_token="github-refresh-token-xyz789", + client_id="client-id-123", + client_secret="client-secret-super-secret", + expires_at=datetime.now(timezone.utc), + ) + + await server.mcp_manager.update_oauth_session(created_session.id, update_data, actor=default_user) + + # Check database directly for encryption + async with db_registry.async_session() as session: + result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id)) + db_oauth = result.scalar_one() + + # All sensitive fields should be encrypted + assert db_oauth.access_token_enc is not None + assert db_oauth.access_token_enc != update_data.access_token + + assert db_oauth.refresh_token_enc is not None + + assert db_oauth.client_secret_enc is not None + + # Verify decryption + decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc, self.MOCK_ENCRYPTION_KEY) + assert decrypted_access == update_data.access_token + + decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc, self.MOCK_ENCRYPTION_KEY) + assert decrypted_refresh == update_data.refresh_token + + decrypted_secret = CryptoUtils.decrypt(db_oauth.client_secret_enc, self.MOCK_ENCRYPTION_KEY) + assert decrypted_secret == update_data.client_secret + + # Clean up not needed - test database is reset + + @pytest.mark.asyncio + @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) + async def test_retrieve_oauth_session_decrypts_tokens(self, server, default_user): + """Test that retrieving OAuth session decrypts tokens.""" + # Manually insert encrypted OAuth session + session_id = f"mcp-oauth-{str(uuid4())[:8]}" + access_token = "test-access-token" + refresh_token = "test-refresh-token" + client_secret = "test-client-secret" + + encrypted_access = CryptoUtils.encrypt(access_token, self.MOCK_ENCRYPTION_KEY) + encrypted_refresh = CryptoUtils.encrypt(refresh_token, self.MOCK_ENCRYPTION_KEY) + encrypted_secret = CryptoUtils.encrypt(client_secret, self.MOCK_ENCRYPTION_KEY) + + async with db_registry.async_session() as session: + db_oauth = MCPOAuth( + id=session_id, + state=f"test-state-{uuid4().hex[:8]}", + server_url="https://test.com/mcp", + server_name="Test Provider", + access_token_enc=encrypted_access, + refresh_token_enc=encrypted_refresh, + client_id="test-client", + client_secret_enc=encrypted_secret, + user_id=default_user.id, + organization_id=default_user.organization_id, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(db_oauth) + await session.commit() + + # Retrieve through manager by ID + test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) + assert test_session is not None + + # Tokens should be decrypted + assert test_session.access_token == access_token + assert test_session.refresh_token == refresh_token + assert test_session.client_secret == client_secret + + # Clean up not needed - test database is reset + + @pytest.mark.asyncio + @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) + async def test_update_oauth_session_maintains_encryption(self, server, default_user): + """Test that updating OAuth session maintains encryption.""" + # Create initial session (without tokens) + from letta.schemas.mcp import MCPOAuthSessionUpdate + + oauth_session_create = MCPOAuthSessionCreate( + server_url="https://test.com/mcp", + server_name="Test Update Provider", + organization_id=default_user.organization_id, + user_id=default_user.id, + ) + + created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user) + + # Add initial tokens + initial_update = MCPOAuthSessionUpdate( + access_token="initial-token", + refresh_token="initial-refresh", + client_id="client-123", + client_secret="initial-secret", + ) + + await server.mcp_manager.update_oauth_session(created_session.id, initial_update, actor=default_user) + + # Update with new tokens + new_access_token = "updated-access-token" + new_refresh_token = "updated-refresh-token" + + new_update = MCPOAuthSessionUpdate( + access_token=new_access_token, + refresh_token=new_refresh_token, + ) + + updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user) + + # Verify update worked + assert updated_session.access_token == new_access_token + assert updated_session.refresh_token == new_refresh_token + + # Check database encryption + async with db_registry.async_session() as session: + result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id)) + db_oauth = result.scalar_one() + + # New tokens should be encrypted + decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc, self.MOCK_ENCRYPTION_KEY) + assert decrypted_access == new_access_token + + decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc, self.MOCK_ENCRYPTION_KEY) + assert decrypted_refresh == new_refresh_token + + # Clean up not needed - test database is reset + + @pytest.mark.asyncio + @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) + async def test_dual_read_backward_compatibility(self, server, default_user): + """Test that system can read both encrypted and plaintext values (migration support).""" + # Insert a record with both encrypted and plaintext values + session_id = f"mcp-oauth-{str(uuid4())[:8]}" + plaintext_token = "legacy-plaintext-token" + new_encrypted_token = "new-encrypted-token" + encrypted_new = CryptoUtils.encrypt(new_encrypted_token, self.MOCK_ENCRYPTION_KEY) + + async with db_registry.async_session() as session: + db_oauth = MCPOAuth( + id=session_id, + state=f"dual-read-state-{uuid4().hex[:8]}", + server_url="https://test.com/mcp", + server_name="Dual Read Test", + # Both encrypted and plaintext values + access_token=plaintext_token, # Legacy plaintext + access_token_enc=encrypted_new, # New encrypted + client_id="test-client", + user_id=default_user.id, + organization_id=default_user.organization_id, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(db_oauth) + await session.commit() + + # Retrieve through manager + test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) + assert test_session is not None + + # Should prefer encrypted value over plaintext + assert test_session.access_token == new_encrypted_token + + # Clean up not needed - test database is reset diff --git a/tests/test_secret.py b/tests/test_secret.py new file mode 100644 index 00000000..a273344f --- /dev/null +++ b/tests/test_secret.py @@ -0,0 +1,373 @@ +import json + +import pytest + +from letta.helpers.crypto_utils import CryptoUtils +from letta.schemas.secret import Secret, SecretDict + + +class TestSecret: + """Test suite for Secret wrapper class.""" + + MOCK_KEY = "test-secret-key-1234567890" + + def test_from_plaintext_with_key(self): + """Test creating a Secret from plaintext value with encryption key.""" + from letta.settings import settings + + # Set encryption key + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "my-secret-value" + + secret = Secret.from_plaintext(plaintext) + + # Should store encrypted value + assert secret._encrypted_value is not None + assert secret._encrypted_value != plaintext + assert secret._was_encrypted is False + + # Should decrypt to original value + assert secret.get_plaintext() == plaintext + finally: + settings.encryption_key = original_key + + def test_from_plaintext_without_key(self): + """Test creating a Secret from plaintext without encryption key.""" + from letta.settings import settings + + # Clear encryption key + original_key = settings.encryption_key + settings.encryption_key = None + + try: + plaintext = "my-plaintext-value" + + # Should raise error when trying to encrypt without key + with pytest.raises(ValueError): + Secret.from_plaintext(plaintext) + finally: + settings.encryption_key = original_key + + def test_from_plaintext_with_none(self): + """Test creating a Secret from None value.""" + secret = Secret.from_plaintext(None) + + assert secret._encrypted_value is None + assert secret._was_encrypted is False + assert secret.get_plaintext() is None + assert secret.is_empty() is True + + def test_from_encrypted(self): + """Test creating a Secret from already encrypted value.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "database-secret" + encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY) + + secret = Secret.from_encrypted(encrypted) + + assert secret._encrypted_value == encrypted + assert secret._was_encrypted is True + assert secret.get_plaintext() == plaintext + finally: + settings.encryption_key = original_key + + def test_from_db_with_encrypted_value(self): + """Test creating a Secret from database with encrypted value.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "database-secret" + encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY) + + secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=None) + + assert secret._encrypted_value == encrypted + assert secret._was_encrypted is True + assert secret.get_plaintext() == plaintext + finally: + settings.encryption_key = original_key + + def test_from_db_with_plaintext_value(self): + """Test creating a Secret from database with plaintext value (backward compatibility).""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "legacy-plaintext" + + # When only plaintext is provided, should encrypt it + secret = Secret.from_db(encrypted_value=None, plaintext_value=plaintext) + + # Should encrypt the plaintext + assert secret._encrypted_value is not None + assert secret._was_encrypted is False + assert secret.get_plaintext() == plaintext + finally: + settings.encryption_key = original_key + + def test_from_db_dual_read(self): + """Test dual read functionality - prefer encrypted over plaintext.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "correct-value" + old_plaintext = "old-legacy-value" + encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY) + + # When both values exist, should prefer encrypted + secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=old_plaintext) + + assert secret.get_plaintext() == plaintext # Should use encrypted value, not plaintext + finally: + settings.encryption_key = original_key + + def test_get_encrypted(self): + """Test getting the encrypted value for database storage.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "test-encryption" + + secret = Secret.from_plaintext(plaintext) + encrypted_value = secret.get_encrypted() + + assert encrypted_value is not None + + # Should decrypt back to original + decrypted = CryptoUtils.decrypt(encrypted_value, self.MOCK_KEY) + assert decrypted == plaintext + finally: + settings.encryption_key = original_key + + def test_is_empty(self): + """Test checking if secret is empty.""" + # Empty secret + empty_secret = Secret.from_plaintext(None) + assert empty_secret.is_empty() is True + + # Non-empty secret + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + non_empty_secret = Secret.from_plaintext("value") + assert non_empty_secret.is_empty() is False + finally: + settings.encryption_key = original_key + + def test_string_representation(self): + """Test that string representation doesn't expose secret.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + secret = Secret.from_plaintext("sensitive-data") + + # String representation should not contain the actual value + str_repr = str(secret) + assert "sensitive-data" not in str_repr + assert "****" in str_repr + + # Empty secret + empty_secret = Secret.from_plaintext(None) + assert "empty" in str(empty_secret) + finally: + settings.encryption_key = original_key + + def test_equality(self): + """Test comparing two secrets.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "same-value" + + secret1 = Secret.from_plaintext(plaintext) + secret2 = Secret.from_plaintext(plaintext) + + # Should be equal based on plaintext value + assert secret1 == secret2 + + # Different values should not be equal + secret3 = Secret.from_plaintext("different-value") + assert secret1 != secret3 + finally: + settings.encryption_key = original_key + + +class TestSecretDict: + """Test suite for SecretDict wrapper class.""" + + MOCK_KEY = "test-secretdict-key-1234567890" + + def test_from_plaintext_dict(self): + """Test creating a SecretDict from plaintext dictionary.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext_dict = {"api_key": "sk-1234567890", "api_secret": "secret-value", "nested": {"token": "bearer-token"}} + + secret_dict = SecretDict.from_plaintext(plaintext_dict) + + # Should store encrypted JSON + assert secret_dict._encrypted_value is not None + + # Should decrypt to original dict + assert secret_dict.get_plaintext() == plaintext_dict + finally: + settings.encryption_key = original_key + + def test_from_plaintext_none(self): + """Test creating a SecretDict from None value.""" + secret_dict = SecretDict.from_plaintext(None) + + assert secret_dict._encrypted_value is None + assert secret_dict.get_plaintext() is None + + def test_from_encrypted_with_json(self): + """Test creating a SecretDict from encrypted JSON.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext_dict = {"header1": "value1", "Authorization": "Bearer token123"} + + json_str = json.dumps(plaintext_dict) + encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY) + + secret_dict = SecretDict.from_encrypted(encrypted) + + assert secret_dict._encrypted_value == encrypted + assert secret_dict.get_plaintext() == plaintext_dict + finally: + settings.encryption_key = original_key + + def test_from_db_with_encrypted(self): + """Test creating SecretDict from database with encrypted value.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext_dict = {"key": "value"} + json_str = json.dumps(plaintext_dict) + encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY) + + secret_dict = SecretDict.from_db(encrypted_value=encrypted, plaintext_value=None) + + assert secret_dict.get_plaintext() == plaintext_dict + finally: + settings.encryption_key = original_key + + def test_from_db_with_plaintext_json(self): + """Test creating SecretDict from database with plaintext JSON (backward compatibility).""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext_dict = {"legacy": "headers"} + + # from_db expects a Dict, not a JSON string + secret_dict = SecretDict.from_db(encrypted_value=None, plaintext_value=plaintext_dict) + + assert secret_dict.get_plaintext() == plaintext_dict + # Should have encrypted it + assert secret_dict._encrypted_value is not None + finally: + settings.encryption_key = original_key + + def test_complex_nested_structure(self): + """Test SecretDict with complex nested structures.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + complex_dict = { + "level1": {"level2": {"level3": ["item1", "item2"], "secret": "nested-secret"}, "array": [1, 2, {"nested": "value"}]}, + "simple": "value", + "number": 42, + "boolean": True, + "null": None, + } + + secret_dict = SecretDict.from_plaintext(complex_dict) + decrypted = secret_dict.get_plaintext() + + assert decrypted == complex_dict + finally: + settings.encryption_key = original_key + + def test_empty_dict(self): + """Test SecretDict with empty dictionary.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + empty_dict = {} + + secret_dict = SecretDict.from_plaintext(empty_dict) + assert secret_dict.get_plaintext() == empty_dict + + # Encrypted value should still be created + encrypted = secret_dict.get_encrypted() + assert encrypted is not None + finally: + settings.encryption_key = original_key + + def test_dual_read_prefer_encrypted(self): + """Test that SecretDict prefers encrypted value over plaintext when both exist.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + new_dict = {"current": "value"} + old_dict = {"legacy": "value"} + + encrypted = CryptoUtils.encrypt(json.dumps(new_dict), self.MOCK_KEY) + plaintext = json.dumps(old_dict) + + secret_dict = SecretDict.from_db(encrypted_value=encrypted, plaintext_value=plaintext) + + # Should use encrypted value, not plaintext + assert secret_dict.get_plaintext() == new_dict + finally: + settings.encryption_key = original_key