feat: encryption for mcp (#2937)

This commit is contained in:
jnjpng
2025-09-16 11:56:34 -07:00
committed by GitHub
parent c8d3616864
commit 3711b5279c
14 changed files with 2166 additions and 117 deletions

View File

@@ -53,7 +53,10 @@ jobs:
{"test_suite": "mcp_tests/", "use_experimental": true},
{"test_suite": "test_timezone_formatting.py"},
{"test_suite": "test_plugins.py"},
{"test_suite": "test_embeddings.py"}
{"test_suite": "test_embeddings.py"},
{"test_suite": "test_crypto_utils.py"},
{"test_suite": "test_mcp_encryption.py"},
{"test_suite": "test_secret.py"}
]
}
}

View File

@@ -53,7 +53,10 @@ jobs:
{"test_suite": "mcp_tests/", "use_experimental": true},
{"test_suite": "test_timezone_formatting.py"},
{"test_suite": "test_plugins.py"},
{"test_suite": "test_embeddings.py"}
{"test_suite": "test_embeddings.py"},
{"test_suite": "test_crypto_utils.py"},
{"test_suite": "test_mcp_encryption.py"},
{"test_suite": "test_secret.py"}
]
}
}

View File

@@ -0,0 +1,318 @@
"""add and migrate encrypted columns for mcp
Revision ID: d06594144ef3
Revises: 5d27a719b24d
Create Date: 2025-09-15 22:02:47.403970
"""
import json
import os
# Add the app directory to path to import our crypto utils
import sys
from pathlib import Path
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy import JSON, String, Text
from sqlalchemy.sql import column, table
from alembic import op
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from letta.helpers.crypto_utils import CryptoUtils
# revision identifiers, used by Alembic.
revision: str = "d06594144ef3"
down_revision: Union[str, None] = "5d27a719b24d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# First, add the new encrypted columns
op.add_column("mcp_oauth", sa.Column("access_token_enc", sa.Text(), nullable=True))
op.add_column("mcp_oauth", sa.Column("refresh_token_enc", sa.Text(), nullable=True))
op.add_column("mcp_oauth", sa.Column("client_secret_enc", sa.Text(), nullable=True))
op.add_column("mcp_server", sa.Column("token_enc", sa.Text(), nullable=True))
op.add_column("mcp_server", sa.Column("custom_headers_enc", sa.Text(), nullable=True))
# Check if encryption key is available
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
if not encryption_key:
print("WARNING: LETTA_ENCRYPTION_KEY not set. Skipping data encryption migration.")
print(" You can run a separate migration script later to encrypt existing data.")
return
# Get database connection
connection = op.get_bind()
# Batch processing configuration
BATCH_SIZE = 1000 # Process 1000 rows at a time
# Migrate mcp_oauth data
print("Migrating mcp_oauth encrypted fields...")
mcp_oauth = table(
"mcp_oauth",
column("id", String),
column("access_token", Text),
column("access_token_enc", Text),
column("refresh_token", Text),
column("refresh_token_enc", Text),
column("client_secret", Text),
column("client_secret_enc", Text),
)
# Count total rows to process
total_count_result = connection.execute(
sa.select(sa.func.count())
.select_from(mcp_oauth)
.where(
sa.and_(
sa.or_(mcp_oauth.c.access_token.isnot(None), mcp_oauth.c.refresh_token.isnot(None), mcp_oauth.c.client_secret.isnot(None)),
# Only count rows that need encryption
sa.or_(
sa.and_(mcp_oauth.c.access_token.isnot(None), mcp_oauth.c.access_token_enc.is_(None)),
sa.and_(mcp_oauth.c.refresh_token.isnot(None), mcp_oauth.c.refresh_token_enc.is_(None)),
sa.and_(mcp_oauth.c.client_secret.isnot(None), mcp_oauth.c.client_secret_enc.is_(None)),
),
)
)
).scalar()
if total_count_result and total_count_result > 0:
print(f"Found {total_count_result} mcp_oauth records that need encryption")
encrypted_count = 0
skipped_count = 0
offset = 0
# Process in batches
while True:
# Select batch of rows
oauth_rows = connection.execute(
sa.select(
mcp_oauth.c.id,
mcp_oauth.c.access_token,
mcp_oauth.c.access_token_enc,
mcp_oauth.c.refresh_token,
mcp_oauth.c.refresh_token_enc,
mcp_oauth.c.client_secret,
mcp_oauth.c.client_secret_enc,
)
.where(
sa.and_(
sa.or_(
mcp_oauth.c.access_token.isnot(None),
mcp_oauth.c.refresh_token.isnot(None),
mcp_oauth.c.client_secret.isnot(None),
),
# Only select rows that need encryption
sa.or_(
sa.and_(mcp_oauth.c.access_token.isnot(None), mcp_oauth.c.access_token_enc.is_(None)),
sa.and_(mcp_oauth.c.refresh_token.isnot(None), mcp_oauth.c.refresh_token_enc.is_(None)),
sa.and_(mcp_oauth.c.client_secret.isnot(None), mcp_oauth.c.client_secret_enc.is_(None)),
),
)
)
.order_by(mcp_oauth.c.id) # Ensure consistent ordering
.limit(BATCH_SIZE)
.offset(offset)
).fetchall()
if not oauth_rows:
break # No more rows to process
# Prepare batch updates
batch_updates = []
for row in oauth_rows:
updates = {"id": row.id}
has_updates = False
# Encrypt access_token if present and not already encrypted
if row.access_token and not row.access_token_enc:
try:
updates["access_token_enc"] = CryptoUtils.encrypt(row.access_token, encryption_key)
has_updates = True
except Exception as e:
print(f"Warning: Failed to encrypt access_token for mcp_oauth id={row.id}: {e}")
elif row.access_token_enc:
skipped_count += 1
# Encrypt refresh_token if present and not already encrypted
if row.refresh_token and not row.refresh_token_enc:
try:
updates["refresh_token_enc"] = CryptoUtils.encrypt(row.refresh_token, encryption_key)
has_updates = True
except Exception as e:
print(f"Warning: Failed to encrypt refresh_token for mcp_oauth id={row.id}: {e}")
elif row.refresh_token_enc:
skipped_count += 1
# Encrypt client_secret if present and not already encrypted
if row.client_secret and not row.client_secret_enc:
try:
updates["client_secret_enc"] = CryptoUtils.encrypt(row.client_secret, encryption_key)
has_updates = True
except Exception as e:
print(f"Warning: Failed to encrypt client_secret for mcp_oauth id={row.id}: {e}")
elif row.client_secret_enc:
skipped_count += 1
if has_updates:
batch_updates.append(updates)
encrypted_count += 1
# Execute batch update if there are updates
if batch_updates:
# Use bulk update for better performance
for update_data in batch_updates:
row_id = update_data.pop("id")
if update_data: # Only update if there are fields to update
connection.execute(mcp_oauth.update().where(mcp_oauth.c.id == row_id).values(**update_data))
# Progress indicator for large datasets
if encrypted_count > 0 and encrypted_count % 10000 == 0:
print(f" Progress: Encrypted {encrypted_count} mcp_oauth records...")
offset += BATCH_SIZE
# For very large datasets, commit periodically to avoid long transactions
if encrypted_count > 0 and encrypted_count % 50000 == 0:
connection.commit()
print(f"mcp_oauth: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")
else:
print("mcp_oauth: No records need encryption")
# Migrate mcp_server data
print("Migrating mcp_server encrypted fields...")
mcp_server = table(
"mcp_server",
column("id", String),
column("token", String),
column("token_enc", Text),
column("custom_headers", JSON),
column("custom_headers_enc", Text),
)
# Count total rows to process
total_count_result = connection.execute(
sa.select(sa.func.count())
.select_from(mcp_server)
.where(
sa.and_(
sa.or_(mcp_server.c.token.isnot(None), mcp_server.c.custom_headers.isnot(None)),
# Only count rows that need encryption
sa.or_(
sa.and_(mcp_server.c.token.isnot(None), mcp_server.c.token_enc.is_(None)),
sa.and_(mcp_server.c.custom_headers.isnot(None), mcp_server.c.custom_headers_enc.is_(None)),
),
)
)
).scalar()
if total_count_result and total_count_result > 0:
print(f"Found {total_count_result} mcp_server records that need encryption")
encrypted_count = 0
skipped_count = 0
offset = 0
# Process in batches
while True:
# Select batch of rows
server_rows = connection.execute(
sa.select(
mcp_server.c.id,
mcp_server.c.token,
mcp_server.c.token_enc,
mcp_server.c.custom_headers,
mcp_server.c.custom_headers_enc,
)
.where(
sa.and_(
sa.or_(mcp_server.c.token.isnot(None), mcp_server.c.custom_headers.isnot(None)),
# Only select rows that need encryption
sa.or_(
sa.and_(mcp_server.c.token.isnot(None), mcp_server.c.token_enc.is_(None)),
sa.and_(mcp_server.c.custom_headers.isnot(None), mcp_server.c.custom_headers_enc.is_(None)),
),
)
)
.order_by(mcp_server.c.id) # Ensure consistent ordering
.limit(BATCH_SIZE)
.offset(offset)
).fetchall()
if not server_rows:
break # No more rows to process
# Prepare batch updates
batch_updates = []
for row in server_rows:
updates = {"id": row.id}
has_updates = False
# Encrypt token if present and not already encrypted
if row.token and not row.token_enc:
try:
updates["token_enc"] = CryptoUtils.encrypt(row.token, encryption_key)
has_updates = True
except Exception as e:
print(f"Warning: Failed to encrypt token for mcp_server id={row.id}: {e}")
elif row.token_enc:
skipped_count += 1
# Encrypt custom_headers if present (JSON field) and not already encrypted
if row.custom_headers and not row.custom_headers_enc:
try:
# Convert JSON to string for encryption
headers_json = json.dumps(row.custom_headers)
updates["custom_headers_enc"] = CryptoUtils.encrypt(headers_json, encryption_key)
has_updates = True
except Exception as e:
print(f"Warning: Failed to encrypt custom_headers for mcp_server id={row.id}: {e}")
elif row.custom_headers_enc:
skipped_count += 1
if has_updates:
batch_updates.append(updates)
encrypted_count += 1
# Execute batch update if there are updates
if batch_updates:
# Use bulk update for better performance
for update_data in batch_updates:
row_id = update_data.pop("id")
if update_data: # Only update if there are fields to update
connection.execute(mcp_server.update().where(mcp_server.c.id == row_id).values(**update_data))
# Progress indicator for large datasets
if encrypted_count > 0 and encrypted_count % 10000 == 0:
print(f" Progress: Encrypted {encrypted_count} mcp_server records...")
offset += BATCH_SIZE
# For very large datasets, commit periodically to avoid long transactions
if encrypted_count > 0 and encrypted_count % 50000 == 0:
connection.commit()
print(f"mcp_server: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")
else:
print("mcp_server: No records need encryption")
print("Migration complete. Plaintext columns are retained for rollback safety.")
# ### end Alembic commands ###
def downgrade() -> None:
op.drop_column("mcp_server", "custom_headers_enc")
op.drop_column("mcp_server", "token_enc")
op.drop_column("mcp_oauth", "client_secret_enc")
op.drop_column("mcp_oauth", "refresh_token_enc")
op.drop_column("mcp_oauth", "access_token_enc")
# ### end Alembic commands ###

View File

@@ -0,0 +1,134 @@
import base64
import os
from typing import Optional
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from letta.settings import settings
class CryptoUtils:
"""Utility class for AES-256-GCM encryption/decryption of sensitive data."""
# AES-256 requires 32 bytes key
KEY_SIZE = 32
# GCM standard IV size is 12 bytes (96 bits)
IV_SIZE = 12
# GCM tag size is 16 bytes (128 bits)
TAG_SIZE = 16
# Salt size for key derivation
SALT_SIZE = 16
@classmethod
def _derive_key(cls, master_key: str, salt: bytes) -> bytes:
"""Derive an AES key from the master key using PBKDF2."""
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=cls.KEY_SIZE, salt=salt, iterations=100000, backend=default_backend())
return kdf.derive(master_key.encode())
@classmethod
def encrypt(cls, plaintext: str, master_key: Optional[str] = None) -> str:
"""
Encrypt a string using AES-256-GCM.
Args:
plaintext: The string to encrypt
master_key: Optional master key (defaults to settings.encryption_key)
Returns:
Base64 encoded string containing: salt + iv + ciphertext + tag
Raises:
ValueError: If no encryption key is configured
"""
if master_key is None:
master_key = settings.encryption_key
if not master_key:
raise ValueError("No encryption key configured. Set LETTA_ENCRYPTION_KEY environment variable.")
# Generate random salt and IV
salt = os.urandom(cls.SALT_SIZE)
iv = os.urandom(cls.IV_SIZE)
# Derive key from master key
key = cls._derive_key(master_key, salt)
# Create cipher
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend())
encryptor = cipher.encryptor()
# Encrypt the plaintext
ciphertext = encryptor.update(plaintext.encode()) + encryptor.finalize()
# Get the authentication tag
tag = encryptor.tag
# Combine salt + iv + ciphertext + tag
encrypted_data = salt + iv + ciphertext + tag
# Return as base64 encoded string
return base64.b64encode(encrypted_data).decode("utf-8")
@classmethod
def decrypt(cls, encrypted: str, master_key: Optional[str] = None) -> str:
"""
Decrypt a string that was encrypted using AES-256-GCM.
Args:
encrypted: Base64 encoded encrypted string
master_key: Optional master key (defaults to settings.encryption_key)
Returns:
The decrypted plaintext string
Raises:
ValueError: If no encryption key is configured or decryption fails
"""
if master_key is None:
master_key = settings.encryption_key
if not master_key:
raise ValueError("No encryption key configured. Set LETTA_ENCRYPTION_KEY environment variable.")
try:
# Decode from base64
encrypted_data = base64.b64decode(encrypted)
# Extract components
salt = encrypted_data[: cls.SALT_SIZE]
iv = encrypted_data[cls.SALT_SIZE : cls.SALT_SIZE + cls.IV_SIZE]
ciphertext = encrypted_data[cls.SALT_SIZE + cls.IV_SIZE : -cls.TAG_SIZE]
tag = encrypted_data[-cls.TAG_SIZE :]
# Derive key from master key
key = cls._derive_key(master_key, salt)
# Create cipher
cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend())
decryptor = cipher.decryptor()
# Decrypt the ciphertext
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
return plaintext.decode("utf-8")
except Exception as e:
raise ValueError(f"Failed to decrypt data: {str(e)}")
@classmethod
def is_encrypted(cls, value: str) -> bool:
"""
Check if a string appears to be encrypted (base64 encoded with correct size).
This is a heuristic check and may have false positives.
"""
try:
decoded = base64.b64decode(value)
# Check if length is consistent with our encryption format
# Minimum size: salt(16) + iv(12) + tag(16) + at least 1 byte of ciphertext
return len(decoded) >= cls.SALT_SIZE + cls.IV_SIZE + cls.TAG_SIZE + 1
except Exception:
return False

View File

@@ -18,6 +18,7 @@ from letta.orm.job import Job
from letta.orm.job_messages import JobMessage
from letta.orm.llm_batch_items import LLMBatchItem
from letta.orm.llm_batch_job import LLMBatchJob
from letta.orm.mcp_oauth import MCPOAuth
from letta.orm.mcp_server import MCPServer
from letta.orm.message import Message
from letta.orm.organization import Organization

View File

@@ -38,7 +38,11 @@ class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin):
# Token data
access_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth access token")
access_token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted OAuth access token")
refresh_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth refresh token")
refresh_token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted OAuth refresh token")
token_type: Mapped[str] = mapped_column(String(50), default="Bearer", doc="Token type")
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, doc="Token expiry time")
scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth scope")
@@ -46,6 +50,8 @@ class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin):
# Client configuration
client_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client ID")
client_secret: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client secret")
client_secret_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted OAuth client secret")
redirect_uri: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth redirect URI")
# Session state

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Optional
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy import JSON, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from letta.functions.mcp_client.types import StdioServerConfig
@@ -39,9 +39,15 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
# access token / api key for MCP servers that require authentication
token: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The access token or api key for the MCP server")
# encrypted access token or api key for the MCP server
token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted access token or api key for the MCP server")
# custom headers for authentication (key-value pairs)
custom_headers: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Custom authentication headers as key-value pairs")
# encrypted custom headers for authentication (key-value pairs)
custom_headers_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted custom authentication headers")
# stdio server
stdio_config: Mapped[Optional[StdioServerConfig]] = mapped_column(
MCPStdioServerConfigColumn, nullable=True, doc="The configuration for the stdio server"

View File

@@ -13,6 +13,7 @@ from letta.functions.mcp_client.types import (
)
from letta.orm.mcp_oauth import OAuthSessionStatus
from letta.schemas.letta_base import LettaBase
from letta.schemas.secret import Secret, SecretDict
class BaseMCPServer(LettaBase):
@@ -29,6 +30,9 @@ class MCPServer(BaseMCPServer):
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for authentication)")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
token_enc: Optional[str] = Field(None, description="Encrypted token")
custom_headers_enc: Optional[str] = Field(None, description="Encrypted custom headers")
# stdio config
stdio_config: Optional[StdioServerConfig] = Field(
None, description="The configuration for the server (MCP 'local' client will run this command)"
@@ -41,18 +45,93 @@ class MCPServer(BaseMCPServer):
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
def get_token_secret(self) -> Secret:
"""Get the token as a Secret object, preferring encrypted over plaintext."""
return Secret.from_db(self.token_enc, self.token)
def get_custom_headers_secret(self) -> SecretDict:
"""Get custom headers as a SecretDict object, preferring encrypted over plaintext."""
return SecretDict.from_db(self.custom_headers_enc, self.custom_headers)
def set_token_secret(self, secret: Secret) -> None:
"""Set token from a Secret object, updating both encrypted and plaintext fields."""
secret_dict = secret.to_dict()
self.token_enc = secret_dict["encrypted"]
# Only set plaintext during migration phase
if not secret._was_encrypted:
self.token = secret_dict["plaintext"]
else:
self.token = None
def set_custom_headers_secret(self, secret: SecretDict) -> None:
"""Set custom headers from a SecretDict object, updating both fields."""
secret_dict = secret.to_dict()
self.custom_headers_enc = secret_dict["encrypted"]
# Only set plaintext during migration phase
if not secret._was_encrypted:
self.custom_headers = secret_dict["plaintext"]
else:
self.custom_headers = None
def model_dump(self, to_orm: bool = False, **kwargs):
"""Override model_dump to handle encryption when saving to database."""
import os
# Check environment variable directly to handle test patching
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
if not encryption_key:
from letta.settings import settings
encryption_key = settings.encryption_key
data = super().model_dump(to_orm=to_orm, **kwargs)
if to_orm and encryption_key:
# Temporarily set the encryption key for Secret/SecretDict
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = encryption_key
try:
# Encrypt token if present
if self.token is not None:
token_secret = Secret.from_plaintext(self.token)
secret_dict = token_secret.to_dict()
data["token_enc"] = secret_dict["encrypted"]
# Keep plaintext for dual-write during migration
data["token"] = secret_dict["plaintext"]
# Encrypt custom headers if present
if self.custom_headers is not None:
headers_secret = SecretDict.from_plaintext(self.custom_headers)
secret_dict = headers_secret.to_dict()
data["custom_headers_enc"] = secret_dict["encrypted"]
# Keep plaintext for dual-write during migration
data["custom_headers"] = secret_dict["plaintext"]
finally:
settings.encryption_key = original_key
return data
def to_config(
self,
environment_variables: Optional[Dict[str, str]] = None,
resolve_variables: bool = True,
) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
# Get decrypted values for use in config
token_secret = self.get_token_secret()
token_plaintext = token_secret.get_plaintext()
headers_secret = self.get_custom_headers_secret()
headers_plaintext = headers_secret.get_plaintext()
if self.server_type == MCPServerType.SSE:
config = SSEServerConfig(
server_name=self.server_name,
server_url=self.server_url,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
custom_headers=self.custom_headers,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
custom_headers=headers_plaintext,
)
if resolve_variables:
config.resolve_environment_variables(environment_variables)
@@ -70,9 +149,9 @@ class MCPServer(BaseMCPServer):
config = StreamableHTTPServerConfig(
server_name=self.server_name,
server_url=self.server_url,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
custom_headers=self.custom_headers,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
custom_headers=headers_plaintext,
)
if resolve_variables:
config.resolve_environment_variables(environment_variables)
@@ -138,11 +217,18 @@ class MCPOAuthSession(BaseMCPOAuth):
expires_at: Optional[datetime] = Field(None, description="Token expiry time")
scope: Optional[str] = Field(None, description="OAuth scope")
# Encrypted token fields (for internal use)
access_token_enc: Optional[str] = Field(None, description="Encrypted OAuth access token")
refresh_token_enc: Optional[str] = Field(None, description="Encrypted OAuth refresh token")
# Client configuration
client_id: Optional[str] = Field(None, description="OAuth client ID")
client_secret: Optional[str] = Field(None, description="OAuth client secret")
redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI")
# Encrypted client secret (for internal use)
client_secret_enc: Optional[str] = Field(None, description="Encrypted OAuth client secret")
# Session state
status: OAuthSessionStatus = Field(default=OAuthSessionStatus.PENDING, description="Session status")
@@ -150,6 +236,93 @@ class MCPOAuthSession(BaseMCPOAuth):
created_at: datetime = Field(default_factory=datetime.now, description="Session creation time")
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
def get_access_token_secret(self) -> Secret:
"""Get the access token as a Secret object, preferring encrypted over plaintext."""
return Secret.from_db(self.access_token_enc, self.access_token)
def get_refresh_token_secret(self) -> Secret:
"""Get the refresh token as a Secret object, preferring encrypted over plaintext."""
return Secret.from_db(self.refresh_token_enc, self.refresh_token)
def get_client_secret_secret(self) -> Secret:
"""Get the client secret as a Secret object, preferring encrypted over plaintext."""
return Secret.from_db(self.client_secret_enc, self.client_secret)
def set_access_token_secret(self, secret: Secret) -> None:
"""Set access token from a Secret object."""
secret_dict = secret.to_dict()
self.access_token_enc = secret_dict["encrypted"]
if not secret._was_encrypted:
self.access_token = secret_dict["plaintext"]
else:
self.access_token = None
def set_refresh_token_secret(self, secret: Secret) -> None:
"""Set refresh token from a Secret object."""
secret_dict = secret.to_dict()
self.refresh_token_enc = secret_dict["encrypted"]
if not secret._was_encrypted:
self.refresh_token = secret_dict["plaintext"]
else:
self.refresh_token = None
def set_client_secret_secret(self, secret: Secret) -> None:
"""Set client secret from a Secret object."""
secret_dict = secret.to_dict()
self.client_secret_enc = secret_dict["encrypted"]
if not secret._was_encrypted:
self.client_secret = secret_dict["plaintext"]
else:
self.client_secret = None
def model_dump(self, to_orm: bool = False, **kwargs):
"""Override model_dump to handle encryption when saving to database."""
import os
# Check environment variable directly to handle test patching
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
if not encryption_key:
from letta.settings import settings
encryption_key = settings.encryption_key
data = super().model_dump(to_orm=to_orm, **kwargs)
if to_orm and encryption_key:
# Temporarily set the encryption key for Secret
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = encryption_key
try:
# Encrypt access token if present
if self.access_token is not None:
token_secret = Secret.from_plaintext(self.access_token)
secret_dict = token_secret.to_dict()
data["access_token_enc"] = secret_dict["encrypted"]
# Keep plaintext for dual-write during migration
data["access_token"] = secret_dict["plaintext"]
# Encrypt refresh token if present
if self.refresh_token is not None:
token_secret = Secret.from_plaintext(self.refresh_token)
secret_dict = token_secret.to_dict()
data["refresh_token_enc"] = secret_dict["encrypted"]
# Keep plaintext for dual-write during migration
data["refresh_token"] = secret_dict["plaintext"]
# Encrypt client secret if present
if self.client_secret is not None:
secret = Secret.from_plaintext(self.client_secret)
secret_dict = secret.to_dict()
data["client_secret_enc"] = secret_dict["encrypted"]
# Keep plaintext for dual-write during migration
data["client_secret"] = secret_dict["plaintext"]
finally:
settings.encryption_key = original_key
return data
class MCPOAuthSessionCreate(BaseMCPOAuth):
"""Create a new OAuth session."""

241
letta/schemas/secret.py Normal file
View File

@@ -0,0 +1,241 @@
import json
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, PrivateAttr
from letta.helpers.crypto_utils import CryptoUtils
class Secret(BaseModel):
"""
A wrapper class for encrypted credentials that keeps values encrypted in memory.
This class ensures that sensitive data remains encrypted as much as possible
while passing through the codebase, only decrypting when absolutely necessary.
TODO: Once we deprecate plaintext columns in the database:
- Remove the dual-write logic in to_dict()
- Remove the from_db() method's plaintext_value parameter
- Remove the _was_encrypted flag (no longer needed for migration)
- Simplify get_plaintext() to only handle encrypted values
"""
# Store the encrypted value
_encrypted_value: Optional[str] = PrivateAttr(default=None)
# Cache the decrypted value to avoid repeated decryption
_plaintext_cache: Optional[str] = PrivateAttr(default=None)
# Flag to indicate if the value was originally encrypted
_was_encrypted: bool = PrivateAttr(default=False)
model_config = ConfigDict(frozen=True)
@classmethod
def from_plaintext(cls, value: Optional[str]) -> "Secret":
"""
Create a Secret from a plaintext value, encrypting it immediately.
Args:
value: The plaintext value to encrypt
Returns:
A Secret instance with the encrypted value
"""
if value is None:
instance = cls()
instance._encrypted_value = None
instance._was_encrypted = False
return instance
encrypted = CryptoUtils.encrypt(value)
instance = cls()
instance._encrypted_value = encrypted
instance._was_encrypted = False
return instance
@classmethod
def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret":
"""
Create a Secret from an already encrypted value.
Args:
encrypted_value: The encrypted value
Returns:
A Secret instance
"""
instance = cls()
instance._encrypted_value = encrypted_value
instance._was_encrypted = True
return instance
@classmethod
def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[str]) -> "Secret":
"""
Create a Secret from database values during migration phase.
Prefers encrypted value if available, falls back to plaintext.
Args:
encrypted_value: The encrypted value from the database
plaintext_value: The plaintext value from the database
Returns:
A Secret instance
"""
if encrypted_value is not None:
return cls.from_encrypted(encrypted_value)
elif plaintext_value is not None:
return cls.from_plaintext(plaintext_value)
else:
return cls.from_plaintext(None)
def get_encrypted(self) -> Optional[str]:
"""
Get the encrypted value.
Returns:
The encrypted value, or None if the secret is empty
"""
return self._encrypted_value
def get_plaintext(self) -> Optional[str]:
"""
Get the decrypted plaintext value.
This should only be called when the plaintext is actually needed,
such as when making an external API call.
Returns:
The decrypted plaintext value
"""
if self._encrypted_value is None:
return None
# Use cached value if available
if self._plaintext_cache is not None:
return self._plaintext_cache
# Decrypt and cache
try:
plaintext = CryptoUtils.decrypt(self._encrypted_value)
# Note: We can't actually cache due to frozen=True, but in practice
# we'll create new instances rather than mutating
return plaintext
except Exception:
# If decryption fails and this wasn't originally encrypted,
# it might be that the value is actually plaintext (during migration)
if not self._was_encrypted:
return None
raise
def is_empty(self) -> bool:
"""Check if the secret is empty/None."""
return self._encrypted_value is None
def __str__(self) -> str:
"""String representation that doesn't expose the actual value."""
if self.is_empty():
return "<Secret: empty>"
return "<Secret: ****>"
def __repr__(self) -> str:
"""Representation that doesn't expose the actual value."""
return self.__str__()
def to_dict(self) -> Dict[str, Any]:
"""
Convert to dictionary for database storage.
Returns both encrypted and plaintext values for dual-write during migration.
"""
return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self._was_encrypted else None}
def __eq__(self, other: Any) -> bool:
"""
Compare two secrets by their plaintext values.
Note: This decrypts both values, so use sparingly.
"""
if not isinstance(other, Secret):
return False
return self.get_plaintext() == other.get_plaintext()
class SecretDict(BaseModel):
"""
A wrapper for dictionaries containing sensitive key-value pairs.
Used for custom headers and other key-value configurations.
TODO: Once we deprecate plaintext columns in the database:
- Remove the dual-write logic in to_dict()
- Remove the from_db() method's plaintext_value parameter
- Remove the _was_encrypted flag (no longer needed for migration)
- Simplify get_plaintext() to only handle encrypted JSON values
"""
_encrypted_value: Optional[str] = PrivateAttr(default=None)
_plaintext_cache: Optional[Dict[str, str]] = PrivateAttr(default=None)
_was_encrypted: bool = PrivateAttr(default=False)
model_config = ConfigDict(frozen=True)
@classmethod
def from_plaintext(cls, value: Optional[Dict[str, str]]) -> "SecretDict":
"""Create a SecretDict from a plaintext dictionary."""
if value is None:
instance = cls()
instance._encrypted_value = None
instance._was_encrypted = False
return instance
# Serialize to JSON then encrypt
json_str = json.dumps(value)
encrypted = CryptoUtils.encrypt(json_str)
instance = cls()
instance._encrypted_value = encrypted
instance._was_encrypted = False
return instance
@classmethod
def from_encrypted(cls, encrypted_value: Optional[str]) -> "SecretDict":
"""Create a SecretDict from an encrypted value."""
instance = cls()
instance._encrypted_value = encrypted_value
instance._was_encrypted = True
return instance
@classmethod
def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[Dict[str, str]]) -> "SecretDict":
"""Create a SecretDict from database values during migration phase."""
if encrypted_value is not None:
return cls.from_encrypted(encrypted_value)
elif plaintext_value is not None:
return cls.from_plaintext(plaintext_value)
else:
return cls.from_plaintext(None)
def get_encrypted(self) -> Optional[str]:
"""Get the encrypted value."""
return self._encrypted_value
def get_plaintext(self) -> Optional[Dict[str, str]]:
"""Get the decrypted dictionary."""
if self._encrypted_value is None:
return None
try:
decrypted_json = CryptoUtils.decrypt(self._encrypted_value)
return json.loads(decrypted_json)
except Exception:
if not self._was_encrypted:
return None
raise
def is_empty(self) -> bool:
"""Check if the secret dict is empty/None."""
return self._encrypted_value is None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for database storage."""
return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self._was_encrypted else None}

View File

@@ -728,35 +728,16 @@ async def add_mcp_server_to_config(
return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
else:
# log to DB
from letta.schemas.mcp import MCPServer
if isinstance(request, StdioServerConfig):
mapped_request = MCPServer(server_name=request.server_name, server_type=request.type, stdio_config=request)
# don't allow stdio servers
if tool_settings.mcp_disable_stdio: # protected server
raise HTTPException(
status_code=400,
detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server",
)
elif isinstance(request, SSEServerConfig):
mapped_request = MCPServer(
server_name=request.server_name,
server_type=request.type,
server_url=request.server_url,
token=request.resolve_token(),
custom_headers=request.custom_headers,
)
elif isinstance(request, StreamableHTTPServerConfig):
mapped_request = MCPServer(
server_name=request.server_name,
server_type=request.type,
server_url=request.server_url,
token=request.resolve_token(),
custom_headers=request.custom_headers,
# Check if stdio servers are disabled
if isinstance(request, StdioServerConfig) and tool_settings.mcp_disable_stdio:
raise HTTPException(
status_code=400,
detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server",
)
# Create MCP server and optimistically sync tools
await server.mcp_manager.create_mcp_server_with_tools(mapped_request, actor=actor)
# The mcp_manager will handle encryption of sensitive fields
await server.mcp_manager.create_mcp_server_from_config_with_tools(request, actor=actor)
# TODO: don't do this in the future (just return MCPServer)
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)

View File

@@ -35,6 +35,7 @@ from letta.schemas.mcp import (
UpdateStdioMCPServer,
UpdateStreamableHTTPMCPServer,
)
from letta.schemas.secret import Secret, SecretDict
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
@@ -354,6 +355,69 @@ class MCPManager:
logger.error(f"Failed to create MCP server: {e}")
raise
@enforce_types
async def create_mcp_server_from_config(
self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser
) -> MCPServer:
"""
Create an MCP server from a config object, handling encryption of sensitive fields.
This method converts the server config to an MCPServer model and encrypts
sensitive fields like tokens and custom headers.
"""
# Create base MCPServer object
if isinstance(server_config, StdioServerConfig):
mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config)
elif isinstance(server_config, SSEServerConfig):
mcp_server = MCPServer(
server_name=server_config.server_name,
server_type=server_config.type,
server_url=server_config.server_url,
)
# Encrypt sensitive fields
token = server_config.resolve_token()
if token:
token_secret = Secret.from_plaintext(token)
mcp_server.set_token_secret(token_secret)
if server_config.custom_headers:
headers_secret = SecretDict.from_plaintext(server_config.custom_headers)
mcp_server.set_custom_headers_secret(headers_secret)
elif isinstance(server_config, StreamableHTTPServerConfig):
mcp_server = MCPServer(
server_name=server_config.server_name,
server_type=server_config.type,
server_url=server_config.server_url,
)
# Encrypt sensitive fields
token = server_config.resolve_token()
if token:
token_secret = Secret.from_plaintext(token)
mcp_server.set_token_secret(token_secret)
if server_config.custom_headers:
headers_secret = SecretDict.from_plaintext(server_config.custom_headers)
mcp_server.set_custom_headers_secret(headers_secret)
else:
raise ValueError(f"Unsupported server config type: {type(server_config)}")
return mcp_server
@enforce_types
async def create_mcp_server_from_config_with_tools(
self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser
) -> MCPServer:
"""
Create an MCP server from a config object and optimistically sync its tools.
This method handles encryption of sensitive fields and then creates the server
with automatic tool synchronization.
"""
# Convert config to MCPServer with encryption
mcp_server = await self.create_mcp_server_from_config(server_config, actor)
# Create the server with tools
return await self.create_mcp_server_with_tools(mcp_server, actor)
@enforce_types
async def create_mcp_server_with_tools(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
"""
@@ -420,10 +484,33 @@ class MCPManager:
# Update tool attributes with only the fields that were explicitly set
update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True)
# Ensure custom_headers None is stored as SQL NULL, not JSON null
if update_data.get("custom_headers") is None:
update_data.pop("custom_headers", None)
setattr(mcp_server, "custom_headers", null())
# Handle encryption for token if provided
if "token" in update_data and update_data["token"] is not None:
token_secret = Secret.from_plaintext(update_data["token"])
secret_dict = token_secret.to_dict()
update_data["token_enc"] = secret_dict["encrypted"]
# During migration phase, also update plaintext
if not token_secret._was_encrypted:
update_data["token"] = secret_dict["plaintext"]
else:
update_data["token"] = None
# Handle encryption for custom_headers if provided
if "custom_headers" in update_data:
if update_data["custom_headers"] is not None:
headers_secret = SecretDict.from_plaintext(update_data["custom_headers"])
secret_dict = headers_secret.to_dict()
update_data["custom_headers_enc"] = secret_dict["encrypted"]
# During migration phase, also update plaintext
if not headers_secret._was_encrypted:
update_data["custom_headers"] = secret_dict["plaintext"]
else:
update_data["custom_headers"] = None
else:
# Ensure custom_headers None is stored as SQL NULL, not JSON null
update_data.pop("custom_headers", None)
setattr(mcp_server, "custom_headers", null())
setattr(mcp_server, "custom_headers_enc", None)
for key, value in update_data.items():
setattr(mcp_server, key, value)
@@ -664,6 +751,86 @@ class MCPManager:
raise ValueError(f"Unsupported server config type: {type(server_config)}")
# OAuth-related methods
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
"""
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
"""
# Check for encryption key from env or settings
import os
from letta.settings import settings
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key
# Get decrypted values using the dual-read approach
access_token = None
if oauth_session.access_token_enc or oauth_session.access_token:
if encryption_key:
# Temporarily set the key for Secret
original_key = settings.encryption_key
settings.encryption_key = encryption_key
try:
secret = Secret.from_db(oauth_session.access_token_enc, oauth_session.access_token)
access_token = secret.get_plaintext()
finally:
settings.encryption_key = original_key
else:
# No encryption key, use plaintext if available
access_token = oauth_session.access_token
refresh_token = None
if oauth_session.refresh_token_enc or oauth_session.refresh_token:
if encryption_key:
# Temporarily set the key for Secret
original_key = settings.encryption_key
settings.encryption_key = encryption_key
try:
secret = Secret.from_db(oauth_session.refresh_token_enc, oauth_session.refresh_token)
refresh_token = secret.get_plaintext()
finally:
settings.encryption_key = original_key
else:
# No encryption key, use plaintext if available
refresh_token = oauth_session.refresh_token
client_secret = None
if oauth_session.client_secret_enc or oauth_session.client_secret:
if encryption_key:
# Temporarily set the key for Secret
original_key = settings.encryption_key
settings.encryption_key = encryption_key
try:
secret = Secret.from_db(oauth_session.client_secret_enc, oauth_session.client_secret)
client_secret = secret.get_plaintext()
finally:
settings.encryption_key = original_key
else:
# No encryption key, use plaintext if available
client_secret = oauth_session.client_secret
return MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
server_id=oauth_session.server_id,
server_url=oauth_session.server_url,
server_name=oauth_session.server_name,
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
authorization_url=oauth_session.authorization_url,
authorization_code=oauth_session.authorization_code,
access_token=access_token,
refresh_token=refresh_token,
token_type=oauth_session.token_type,
expires_at=oauth_session.expires_at,
scope=oauth_session.scope,
client_id=oauth_session.client_id,
client_secret=client_secret,
redirect_uri=oauth_session.redirect_uri,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
)
@enforce_types
async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession:
"""Create a new OAuth session for MCP server authentication."""
@@ -682,18 +849,8 @@ class MCPManager:
)
oauth_session = await oauth_session.create_async(session, actor=actor)
# Convert to Pydantic model
return MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
server_url=oauth_session.server_url,
server_name=oauth_session.server_name,
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
)
# Convert to Pydantic model - note: new sessions won't have tokens yet
return self._oauth_orm_to_pydantic(oauth_session)
@enforce_types
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
@@ -701,27 +858,7 @@ class MCPManager:
async with db_registry.async_session() as session:
try:
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
return MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
server_url=oauth_session.server_url,
server_name=oauth_session.server_name,
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
authorization_url=oauth_session.authorization_url,
authorization_code=oauth_session.authorization_code,
access_token=oauth_session.access_token,
refresh_token=oauth_session.refresh_token,
token_type=oauth_session.token_type,
expires_at=oauth_session.expires_at,
scope=oauth_session.scope,
client_id=oauth_session.client_id,
client_secret=oauth_session.client_secret,
redirect_uri=oauth_session.redirect_uri,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
)
return self._oauth_orm_to_pydantic(oauth_session)
except NoResultFound:
return None
@@ -747,27 +884,7 @@ class MCPManager:
if not oauth_session:
return None
return MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
server_url=oauth_session.server_url,
server_name=oauth_session.server_name,
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
authorization_url=oauth_session.authorization_url,
authorization_code=oauth_session.authorization_code,
access_token=oauth_session.access_token,
refresh_token=oauth_session.refresh_token,
token_type=oauth_session.token_type,
expires_at=oauth_session.expires_at,
scope=oauth_session.scope,
client_id=oauth_session.client_id,
client_secret=oauth_session.client_secret,
redirect_uri=oauth_session.redirect_uri,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
)
return self._oauth_orm_to_pydantic(oauth_session)
@enforce_types
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
@@ -780,10 +897,59 @@ class MCPManager:
oauth_session.authorization_url = session_update.authorization_url
if session_update.authorization_code is not None:
oauth_session.authorization_code = session_update.authorization_code
# Handle encryption for access_token
if session_update.access_token is not None:
oauth_session.access_token = session_update.access_token
# Check for encryption key from env or settings
import os
from letta.settings import settings
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key
if encryption_key:
# Temporarily set the key for Secret
original_key = settings.encryption_key
settings.encryption_key = encryption_key
try:
token_secret = Secret.from_plaintext(session_update.access_token)
secret_dict = token_secret.to_dict()
oauth_session.access_token_enc = secret_dict["encrypted"]
# During migration phase, also update plaintext
oauth_session.access_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None
finally:
settings.encryption_key = original_key
else:
# No encryption, store plaintext
oauth_session.access_token = session_update.access_token
oauth_session.access_token_enc = None
# Handle encryption for refresh_token
if session_update.refresh_token is not None:
oauth_session.refresh_token = session_update.refresh_token
# Check for encryption key from env or settings
import os
from letta.settings import settings
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key
if encryption_key:
# Temporarily set the key for Secret
original_key = settings.encryption_key
settings.encryption_key = encryption_key
try:
token_secret = Secret.from_plaintext(session_update.refresh_token)
secret_dict = token_secret.to_dict()
oauth_session.refresh_token_enc = secret_dict["encrypted"]
# During migration phase, also update plaintext
oauth_session.refresh_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None
finally:
settings.encryption_key = original_key
else:
# No encryption, store plaintext
oauth_session.refresh_token = session_update.refresh_token
oauth_session.refresh_token_enc = None
if session_update.token_type is not None:
oauth_session.token_type = session_update.token_type
if session_update.expires_at is not None:
@@ -792,8 +958,33 @@ class MCPManager:
oauth_session.scope = session_update.scope
if session_update.client_id is not None:
oauth_session.client_id = session_update.client_id
# Handle encryption for client_secret
if session_update.client_secret is not None:
oauth_session.client_secret = session_update.client_secret
# Check for encryption key from env or settings
import os
from letta.settings import settings
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key
if encryption_key:
# Temporarily set the key for Secret
original_key = settings.encryption_key
settings.encryption_key = encryption_key
try:
secret_secret = Secret.from_plaintext(session_update.client_secret)
secret_dict = secret_secret.to_dict()
oauth_session.client_secret_enc = secret_dict["encrypted"]
# During migration phase, also update plaintext
oauth_session.client_secret = secret_dict["plaintext"] if not secret_secret._was_encrypted else None
finally:
settings.encryption_key = original_key
else:
# No encryption, store plaintext
oauth_session.client_secret = session_update.client_secret
oauth_session.client_secret_enc = None
if session_update.redirect_uri is not None:
oauth_session.redirect_uri = session_update.redirect_uri
if session_update.status is not None:
@@ -804,27 +995,7 @@ class MCPManager:
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
return MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
server_url=oauth_session.server_url,
server_name=oauth_session.server_name,
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
authorization_url=oauth_session.authorization_url,
authorization_code=oauth_session.authorization_code,
access_token=oauth_session.access_token,
refresh_token=oauth_session.refresh_token,
token_type=oauth_session.token_type,
expires_at=oauth_session.expires_at,
scope=oauth_session.scope,
client_id=oauth_session.client_id,
client_secret=oauth_session.client_secret,
redirect_uri=oauth_session.redirect_uri,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
)
return self._oauth_orm_to_pydantic(oauth_session)
@enforce_types
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:

232
tests/test_crypto_utils.py Normal file
View File

@@ -0,0 +1,232 @@
import base64
import json
import os
from unittest.mock import patch
import pytest
from letta.helpers.crypto_utils import CryptoUtils
class TestCryptoUtils:
"""Test suite for CryptoUtils encryption/decryption functionality."""
# Mock master keys for testing
MOCK_KEY_1 = "test-master-key-1234567890abcdef"
MOCK_KEY_2 = "another-test-key-fedcba0987654321"
def test_encrypt_decrypt_roundtrip(self):
"""Test that encryption followed by decryption returns the original value."""
test_cases = [
"simple text",
"text with special chars: !@#$%^&*()",
"unicode text: 你好世界 🌍",
"very long text " * 1000,
'{"json": "data", "nested": {"key": "value"}}',
"", # Empty string
]
for plaintext in test_cases:
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
assert encrypted != plaintext, f"Encryption failed for: {plaintext[:50]}"
# Encrypted value is base64 encoded
assert len(encrypted) > 0, "Encrypted value should not be empty"
decrypted = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1)
assert decrypted == plaintext, f"Roundtrip failed for: {plaintext[:50]}"
def test_encrypt_with_different_keys(self):
"""Test that different keys produce different ciphertexts."""
plaintext = "sensitive data"
encrypted1 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
encrypted2 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_2)
# Different keys should produce different ciphertexts
assert encrypted1 != encrypted2
# Each should decrypt correctly with its own key
assert CryptoUtils.decrypt(encrypted1, self.MOCK_KEY_1) == plaintext
assert CryptoUtils.decrypt(encrypted2, self.MOCK_KEY_2) == plaintext
def test_decrypt_with_wrong_key_fails(self):
"""Test that decryption with wrong key raises an error."""
plaintext = "secret message"
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
with pytest.raises(Exception): # Could be ValueError or cryptography exception
CryptoUtils.decrypt(encrypted, self.MOCK_KEY_2)
def test_encrypt_none_value(self):
"""Test handling of None values."""
# Encrypt None should raise TypeError (None has no encode method)
with pytest.raises((TypeError, AttributeError)):
CryptoUtils.encrypt(None, self.MOCK_KEY_1)
def test_decrypt_none_value(self):
"""Test that decrypting None raises an error."""
with pytest.raises(ValueError):
CryptoUtils.decrypt(None, self.MOCK_KEY_1)
def test_decrypt_empty_string(self):
"""Test that decrypting empty string raises an error."""
with pytest.raises(Exception): # base64 decode error
CryptoUtils.decrypt("", self.MOCK_KEY_1)
def test_decrypt_plaintext_value(self):
"""Test that decrypting non-encrypted value raises an error."""
plaintext = "not encrypted"
with pytest.raises(Exception): # Will fail base64 decode or decryption
CryptoUtils.decrypt(plaintext, self.MOCK_KEY_1)
def test_encrypted_format_structure(self):
"""Test the structure of encrypted values."""
plaintext = "test data"
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
# Should be base64 encoded
encrypted_data = encrypted
# Should be valid base64
try:
decoded = base64.b64decode(encrypted_data)
assert len(decoded) > 0
except Exception as e:
pytest.fail(f"Invalid base64 encoding: {e}")
# Decoded data should contain salt, IV, tag, and ciphertext
# Total should be at least SALT_SIZE + IV_SIZE + TAG_SIZE bytes
min_size = CryptoUtils.SALT_SIZE + CryptoUtils.IV_SIZE + CryptoUtils.TAG_SIZE
assert len(decoded) >= min_size
def test_deterministic_with_same_salt(self):
"""Test that encryption is deterministic when using the same salt (for testing)."""
plaintext = "deterministic test"
# Note: In production, each encryption generates a random salt
# This test verifies the encryption mechanism itself
encrypted1 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
encrypted2 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
# Due to random salt, these should be different
assert encrypted1 != encrypted2
# But both should decrypt to the same value
assert CryptoUtils.decrypt(encrypted1, self.MOCK_KEY_1) == plaintext
assert CryptoUtils.decrypt(encrypted2, self.MOCK_KEY_1) == plaintext
def test_encrypt_uses_env_key_when_none_provided(self):
"""Test that encryption uses environment key when no key is provided."""
from letta.settings import settings
# Mock the settings to have an encryption key
original_key = settings.encryption_key
settings.encryption_key = "env-test-key-123"
try:
plaintext = "test with env key"
# Should use key from settings
encrypted = CryptoUtils.encrypt(plaintext)
assert len(encrypted) > 0
# Should decrypt with same key
decrypted = CryptoUtils.decrypt(encrypted)
assert decrypted == plaintext
finally:
# Restore original key
settings.encryption_key = original_key
def test_encrypt_without_key_raises_error(self):
"""Test that encryption without any key raises an error."""
from letta.settings import settings
# Mock settings to have no encryption key
original_key = settings.encryption_key
settings.encryption_key = None
try:
with pytest.raises(ValueError, match="No encryption key configured"):
CryptoUtils.encrypt("test data")
finally:
# Restore original key
settings.encryption_key = original_key
def test_large_data_encryption(self):
"""Test encryption of large data."""
# Create 10MB of data
large_data = "x" * (10 * 1024 * 1024)
encrypted = CryptoUtils.encrypt(large_data, self.MOCK_KEY_1)
assert len(encrypted) > 0
assert encrypted != large_data
decrypted = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1)
assert decrypted == large_data
def test_json_data_encryption(self):
"""Test encryption of JSON data."""
json_data = {
"user": "test_user",
"token": "secret_token_123",
"nested": {"api_key": "sk-1234567890", "headers": {"Authorization": "Bearer token"}},
}
json_str = json.dumps(json_data)
encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY_1)
decrypted_str = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1)
decrypted_data = json.loads(decrypted_str)
assert decrypted_data == json_data
def test_invalid_encrypted_format(self):
"""Test handling of invalid encrypted data format."""
invalid_cases = [
"invalid-base64!@#", # Invalid base64
"dGVzdA==", # Valid base64 but too short for encrypted data
]
for invalid in invalid_cases:
with pytest.raises(Exception): # Could be various exceptions
CryptoUtils.decrypt(invalid, self.MOCK_KEY_1)
def test_key_derivation_consistency(self):
"""Test that key derivation is consistent."""
plaintext = "test key derivation"
# Multiple encryptions with same key should work
encrypted_values = []
for _ in range(5):
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
encrypted_values.append(encrypted)
# All should decrypt correctly
for encrypted in encrypted_values:
assert CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1) == plaintext
def test_special_characters_in_key(self):
"""Test encryption with keys containing special characters."""
special_key = "key-with-special-chars!@#$%^&*()_+"
plaintext = "test data"
encrypted = CryptoUtils.encrypt(plaintext, special_key)
decrypted = CryptoUtils.decrypt(encrypted, special_key)
assert decrypted == plaintext
def test_whitespace_handling(self):
"""Test encryption of strings with various whitespace."""
test_cases = [
" leading spaces",
"trailing spaces ",
" both sides ",
"multiple\n\nlines",
"\ttabs\there\t",
"mixed \t\n whitespace \r\n",
]
for plaintext in test_cases:
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
decrypted = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1)
assert decrypted == plaintext, f"Whitespace handling failed for: {repr(plaintext)}"

View File

@@ -0,0 +1,407 @@
"""
Integration tests for MCP server and OAuth session encryption.
Tests the end-to-end encryption functionality in the MCP manager.
"""
import json
import os
from datetime import datetime, timezone
from unittest.mock import AsyncMock, Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy import select
from letta.config import LettaConfig
from letta.helpers.crypto_utils import CryptoUtils
from letta.orm import MCPOAuth, MCPServer as ORMMCPServer
from letta.schemas.mcp import (
MCPOAuthSessionCreate,
MCPServer as PydanticMCPServer,
MCPServerType,
SSEServerConfig,
StdioServerConfig,
)
from letta.schemas.secret import Secret, SecretDict
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.mcp_manager import MCPManager
@pytest.fixture(scope="module")
def server():
"""Fixture to create and return a SyncServer instance with MCP manager."""
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=False)
return server
class TestMCPServerEncryption:
"""Test MCP server encryption functionality."""
MOCK_ENCRYPTION_KEY = "test-mcp-encryption-key-123456"
@pytest.mark.asyncio
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_token_encryption(self, mock_get_client, server, default_user):
"""Test that MCP server tokens are encrypted when stored."""
# Mock the MCP client
mock_client = AsyncMock()
mock_client.list_tools.return_value = []
mock_get_client.return_value = mock_client
# Create MCP server with token
server_name = f"test_encrypted_server_{uuid4().hex[:8]}"
token = "super-secret-api-token-12345"
server_url = "https://api.example.com/mcp"
mcp_server = PydanticMCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url, token=token)
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
# Verify server was created
assert created_server.server_name == server_name
assert created_server.server_type == MCPServerType.SSE
# Check database directly to verify encryption
async with db_registry.async_session() as session:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
db_server = result.scalar_one()
# Token should be encrypted in database
assert db_server.token_enc is not None
assert db_server.token_enc != token # Should not be plaintext
# Decrypt to verify correctness
decrypted_token = CryptoUtils.decrypt(db_server.token_enc, self.MOCK_ENCRYPTION_KEY)
assert decrypted_token == token
# Legacy plaintext column should be None (or empty for dual-write)
# During migration phase, might store both
if db_server.token:
assert db_server.token == token # Dual-write phase
# Clean up
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@pytest.mark.asyncio
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_custom_headers_encryption(self, mock_get_client, server, default_user):
"""Test that MCP server custom headers are encrypted when stored."""
# Mock the MCP client
mock_client = AsyncMock()
mock_client.list_tools.return_value = []
mock_get_client.return_value = mock_client
server_name = f"test_headers_server_{uuid4().hex[:8]}"
custom_headers = {"Authorization": "Bearer secret-token-xyz", "X-API-Key": "api-key-123456", "X-Custom-Header": "custom-value"}
server_url = "https://api.example.com/mcp"
mcp_server = PydanticMCPServer(
server_name=server_name, server_type=MCPServerType.STREAMABLE_HTTP, server_url=server_url, custom_headers=custom_headers
)
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
# Check database directly
async with db_registry.async_session() as session:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
db_server = result.scalar_one()
# Custom headers should be encrypted as JSON
assert db_server.custom_headers_enc is not None
# Decrypt and parse JSON
decrypted_json = CryptoUtils.decrypt(db_server.custom_headers_enc, self.MOCK_ENCRYPTION_KEY)
decrypted_headers = json.loads(decrypted_json)
assert decrypted_headers == custom_headers
# Clean up
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@pytest.mark.asyncio
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
async def test_retrieve_mcp_server_decrypts_values(self, server, default_user):
"""Test that retrieving MCP server decrypts encrypted values."""
# Manually insert encrypted server into database
server_id = f"mcp_server-{uuid4().hex[:8]}"
server_name = f"test_decrypt_server_{uuid4().hex[:8]}"
plaintext_token = "decryption-test-token"
encrypted_token = CryptoUtils.encrypt(plaintext_token, self.MOCK_ENCRYPTION_KEY)
async with db_registry.async_session() as session:
db_server = ORMMCPServer(
id=server_id,
server_name=server_name,
server_type=MCPServerType.SSE.value,
server_url="https://test.com",
token_enc=encrypted_token,
token=None, # No plaintext
created_by_id=default_user.id,
last_updated_by_id=default_user.id,
organization_id=default_user.organization_id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(db_server)
await session.commit()
# Retrieve server directly by ID to avoid issues with other servers in DB
test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user)
assert test_server is not None
assert test_server.server_name == server_name
# Token should be decrypted when accessed via the secret method
# Ensure encryption key is available for decryption
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
try:
token_secret = test_server.get_token_secret()
assert token_secret.get_plaintext() == plaintext_token
finally:
settings.encryption_key = original_key
# Clean up
async with db_registry.async_session() as session:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id))
db_server = result.scalar_one()
await session.delete(db_server)
await session.commit()
@pytest.mark.asyncio
@patch.dict(os.environ, {}, clear=True) # No encryption key
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_without_encryption_key(self, mock_get_client, server, default_user):
"""Test that MCP servers work without encryption key (backward compatibility)."""
# Remove encryption key
os.environ.pop("LETTA_ENCRYPTION_KEY", None)
# Mock the MCP client
mock_client = AsyncMock()
mock_client.list_tools.return_value = []
mock_get_client.return_value = mock_client
server_name = f"test_no_encrypt_server_{uuid4().hex[:8]}"
token = "plaintext-token-no-encryption"
mcp_server = PydanticMCPServer(
server_name=server_name, server_type=MCPServerType.SSE, server_url="https://api.example.com", token=token
)
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
# Check database - should store as plaintext
async with db_registry.async_session() as session:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
db_server = result.scalar_one()
# Should store in plaintext column
assert db_server.token == token
assert db_server.token_enc is None # No encryption
# Clean up
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
class TestMCPOAuthEncryption:
"""Test MCP OAuth session encryption functionality."""
MOCK_ENCRYPTION_KEY = "test-oauth-encryption-key-123456"
@pytest.mark.asyncio
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
async def test_create_oauth_session_with_encryption(self, server, default_user):
"""Test that OAuth tokens are encrypted when stored."""
server_url = "https://github.com/mcp"
server_name = "GitHub MCP"
# Step 1: Create OAuth session (without tokens initially)
oauth_session_create = MCPOAuthSessionCreate(
server_url=server_url,
server_name=server_name,
organization_id=default_user.organization_id,
user_id=default_user.id,
)
created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user)
assert created_session.server_url == server_url
assert created_session.server_name == server_name
# Step 2: Update session with tokens (simulating OAuth callback)
from letta.schemas.mcp import MCPOAuthSessionUpdate
update_data = MCPOAuthSessionUpdate(
access_token="github-access-token-abc123",
refresh_token="github-refresh-token-xyz789",
client_id="client-id-123",
client_secret="client-secret-super-secret",
expires_at=datetime.now(timezone.utc),
)
await server.mcp_manager.update_oauth_session(created_session.id, update_data, actor=default_user)
# Check database directly for encryption
async with db_registry.async_session() as session:
result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id))
db_oauth = result.scalar_one()
# All sensitive fields should be encrypted
assert db_oauth.access_token_enc is not None
assert db_oauth.access_token_enc != update_data.access_token
assert db_oauth.refresh_token_enc is not None
assert db_oauth.client_secret_enc is not None
# Verify decryption
decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc, self.MOCK_ENCRYPTION_KEY)
assert decrypted_access == update_data.access_token
decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc, self.MOCK_ENCRYPTION_KEY)
assert decrypted_refresh == update_data.refresh_token
decrypted_secret = CryptoUtils.decrypt(db_oauth.client_secret_enc, self.MOCK_ENCRYPTION_KEY)
assert decrypted_secret == update_data.client_secret
# Clean up not needed - test database is reset
@pytest.mark.asyncio
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
async def test_retrieve_oauth_session_decrypts_tokens(self, server, default_user):
"""Test that retrieving OAuth session decrypts tokens."""
# Manually insert encrypted OAuth session
session_id = f"mcp-oauth-{str(uuid4())[:8]}"
access_token = "test-access-token"
refresh_token = "test-refresh-token"
client_secret = "test-client-secret"
encrypted_access = CryptoUtils.encrypt(access_token, self.MOCK_ENCRYPTION_KEY)
encrypted_refresh = CryptoUtils.encrypt(refresh_token, self.MOCK_ENCRYPTION_KEY)
encrypted_secret = CryptoUtils.encrypt(client_secret, self.MOCK_ENCRYPTION_KEY)
async with db_registry.async_session() as session:
db_oauth = MCPOAuth(
id=session_id,
state=f"test-state-{uuid4().hex[:8]}",
server_url="https://test.com/mcp",
server_name="Test Provider",
access_token_enc=encrypted_access,
refresh_token_enc=encrypted_refresh,
client_id="test-client",
client_secret_enc=encrypted_secret,
user_id=default_user.id,
organization_id=default_user.organization_id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(db_oauth)
await session.commit()
# Retrieve through manager by ID
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
assert test_session is not None
# Tokens should be decrypted
assert test_session.access_token == access_token
assert test_session.refresh_token == refresh_token
assert test_session.client_secret == client_secret
# Clean up not needed - test database is reset
@pytest.mark.asyncio
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
async def test_update_oauth_session_maintains_encryption(self, server, default_user):
"""Test that updating OAuth session maintains encryption."""
# Create initial session (without tokens)
from letta.schemas.mcp import MCPOAuthSessionUpdate
oauth_session_create = MCPOAuthSessionCreate(
server_url="https://test.com/mcp",
server_name="Test Update Provider",
organization_id=default_user.organization_id,
user_id=default_user.id,
)
created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user)
# Add initial tokens
initial_update = MCPOAuthSessionUpdate(
access_token="initial-token",
refresh_token="initial-refresh",
client_id="client-123",
client_secret="initial-secret",
)
await server.mcp_manager.update_oauth_session(created_session.id, initial_update, actor=default_user)
# Update with new tokens
new_access_token = "updated-access-token"
new_refresh_token = "updated-refresh-token"
new_update = MCPOAuthSessionUpdate(
access_token=new_access_token,
refresh_token=new_refresh_token,
)
updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user)
# Verify update worked
assert updated_session.access_token == new_access_token
assert updated_session.refresh_token == new_refresh_token
# Check database encryption
async with db_registry.async_session() as session:
result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id))
db_oauth = result.scalar_one()
# New tokens should be encrypted
decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc, self.MOCK_ENCRYPTION_KEY)
assert decrypted_access == new_access_token
decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc, self.MOCK_ENCRYPTION_KEY)
assert decrypted_refresh == new_refresh_token
# Clean up not needed - test database is reset
@pytest.mark.asyncio
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
async def test_dual_read_backward_compatibility(self, server, default_user):
"""Test that system can read both encrypted and plaintext values (migration support)."""
# Insert a record with both encrypted and plaintext values
session_id = f"mcp-oauth-{str(uuid4())[:8]}"
plaintext_token = "legacy-plaintext-token"
new_encrypted_token = "new-encrypted-token"
encrypted_new = CryptoUtils.encrypt(new_encrypted_token, self.MOCK_ENCRYPTION_KEY)
async with db_registry.async_session() as session:
db_oauth = MCPOAuth(
id=session_id,
state=f"dual-read-state-{uuid4().hex[:8]}",
server_url="https://test.com/mcp",
server_name="Dual Read Test",
# Both encrypted and plaintext values
access_token=plaintext_token, # Legacy plaintext
access_token_enc=encrypted_new, # New encrypted
client_id="test-client",
user_id=default_user.id,
organization_id=default_user.organization_id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(db_oauth)
await session.commit()
# Retrieve through manager
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
assert test_session is not None
# Should prefer encrypted value over plaintext
assert test_session.access_token == new_encrypted_token
# Clean up not needed - test database is reset

373
tests/test_secret.py Normal file
View File

@@ -0,0 +1,373 @@
import json
import pytest
from letta.helpers.crypto_utils import CryptoUtils
from letta.schemas.secret import Secret, SecretDict
class TestSecret:
"""Test suite for Secret wrapper class."""
MOCK_KEY = "test-secret-key-1234567890"
def test_from_plaintext_with_key(self):
"""Test creating a Secret from plaintext value with encryption key."""
from letta.settings import settings
# Set encryption key
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "my-secret-value"
secret = Secret.from_plaintext(plaintext)
# Should store encrypted value
assert secret._encrypted_value is not None
assert secret._encrypted_value != plaintext
assert secret._was_encrypted is False
# Should decrypt to original value
assert secret.get_plaintext() == plaintext
finally:
settings.encryption_key = original_key
def test_from_plaintext_without_key(self):
"""Test creating a Secret from plaintext without encryption key."""
from letta.settings import settings
# Clear encryption key
original_key = settings.encryption_key
settings.encryption_key = None
try:
plaintext = "my-plaintext-value"
# Should raise error when trying to encrypt without key
with pytest.raises(ValueError):
Secret.from_plaintext(plaintext)
finally:
settings.encryption_key = original_key
def test_from_plaintext_with_none(self):
"""Test creating a Secret from None value."""
secret = Secret.from_plaintext(None)
assert secret._encrypted_value is None
assert secret._was_encrypted is False
assert secret.get_plaintext() is None
assert secret.is_empty() is True
def test_from_encrypted(self):
"""Test creating a Secret from already encrypted value."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "database-secret"
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
secret = Secret.from_encrypted(encrypted)
assert secret._encrypted_value == encrypted
assert secret._was_encrypted is True
assert secret.get_plaintext() == plaintext
finally:
settings.encryption_key = original_key
def test_from_db_with_encrypted_value(self):
"""Test creating a Secret from database with encrypted value."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "database-secret"
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=None)
assert secret._encrypted_value == encrypted
assert secret._was_encrypted is True
assert secret.get_plaintext() == plaintext
finally:
settings.encryption_key = original_key
def test_from_db_with_plaintext_value(self):
"""Test creating a Secret from database with plaintext value (backward compatibility)."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "legacy-plaintext"
# When only plaintext is provided, should encrypt it
secret = Secret.from_db(encrypted_value=None, plaintext_value=plaintext)
# Should encrypt the plaintext
assert secret._encrypted_value is not None
assert secret._was_encrypted is False
assert secret.get_plaintext() == plaintext
finally:
settings.encryption_key = original_key
def test_from_db_dual_read(self):
"""Test dual read functionality - prefer encrypted over plaintext."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "correct-value"
old_plaintext = "old-legacy-value"
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
# When both values exist, should prefer encrypted
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=old_plaintext)
assert secret.get_plaintext() == plaintext # Should use encrypted value, not plaintext
finally:
settings.encryption_key = original_key
def test_get_encrypted(self):
"""Test getting the encrypted value for database storage."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "test-encryption"
secret = Secret.from_plaintext(plaintext)
encrypted_value = secret.get_encrypted()
assert encrypted_value is not None
# Should decrypt back to original
decrypted = CryptoUtils.decrypt(encrypted_value, self.MOCK_KEY)
assert decrypted == plaintext
finally:
settings.encryption_key = original_key
def test_is_empty(self):
"""Test checking if secret is empty."""
# Empty secret
empty_secret = Secret.from_plaintext(None)
assert empty_secret.is_empty() is True
# Non-empty secret
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
non_empty_secret = Secret.from_plaintext("value")
assert non_empty_secret.is_empty() is False
finally:
settings.encryption_key = original_key
def test_string_representation(self):
"""Test that string representation doesn't expose secret."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
secret = Secret.from_plaintext("sensitive-data")
# String representation should not contain the actual value
str_repr = str(secret)
assert "sensitive-data" not in str_repr
assert "****" in str_repr
# Empty secret
empty_secret = Secret.from_plaintext(None)
assert "empty" in str(empty_secret)
finally:
settings.encryption_key = original_key
def test_equality(self):
"""Test comparing two secrets."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "same-value"
secret1 = Secret.from_plaintext(plaintext)
secret2 = Secret.from_plaintext(plaintext)
# Should be equal based on plaintext value
assert secret1 == secret2
# Different values should not be equal
secret3 = Secret.from_plaintext("different-value")
assert secret1 != secret3
finally:
settings.encryption_key = original_key
class TestSecretDict:
"""Test suite for SecretDict wrapper class."""
MOCK_KEY = "test-secretdict-key-1234567890"
def test_from_plaintext_dict(self):
"""Test creating a SecretDict from plaintext dictionary."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext_dict = {"api_key": "sk-1234567890", "api_secret": "secret-value", "nested": {"token": "bearer-token"}}
secret_dict = SecretDict.from_plaintext(plaintext_dict)
# Should store encrypted JSON
assert secret_dict._encrypted_value is not None
# Should decrypt to original dict
assert secret_dict.get_plaintext() == plaintext_dict
finally:
settings.encryption_key = original_key
def test_from_plaintext_none(self):
"""Test creating a SecretDict from None value."""
secret_dict = SecretDict.from_plaintext(None)
assert secret_dict._encrypted_value is None
assert secret_dict.get_plaintext() is None
def test_from_encrypted_with_json(self):
"""Test creating a SecretDict from encrypted JSON."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext_dict = {"header1": "value1", "Authorization": "Bearer token123"}
json_str = json.dumps(plaintext_dict)
encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY)
secret_dict = SecretDict.from_encrypted(encrypted)
assert secret_dict._encrypted_value == encrypted
assert secret_dict.get_plaintext() == plaintext_dict
finally:
settings.encryption_key = original_key
def test_from_db_with_encrypted(self):
"""Test creating SecretDict from database with encrypted value."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext_dict = {"key": "value"}
json_str = json.dumps(plaintext_dict)
encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY)
secret_dict = SecretDict.from_db(encrypted_value=encrypted, plaintext_value=None)
assert secret_dict.get_plaintext() == plaintext_dict
finally:
settings.encryption_key = original_key
def test_from_db_with_plaintext_json(self):
"""Test creating SecretDict from database with plaintext JSON (backward compatibility)."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext_dict = {"legacy": "headers"}
# from_db expects a Dict, not a JSON string
secret_dict = SecretDict.from_db(encrypted_value=None, plaintext_value=plaintext_dict)
assert secret_dict.get_plaintext() == plaintext_dict
# Should have encrypted it
assert secret_dict._encrypted_value is not None
finally:
settings.encryption_key = original_key
def test_complex_nested_structure(self):
"""Test SecretDict with complex nested structures."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
complex_dict = {
"level1": {"level2": {"level3": ["item1", "item2"], "secret": "nested-secret"}, "array": [1, 2, {"nested": "value"}]},
"simple": "value",
"number": 42,
"boolean": True,
"null": None,
}
secret_dict = SecretDict.from_plaintext(complex_dict)
decrypted = secret_dict.get_plaintext()
assert decrypted == complex_dict
finally:
settings.encryption_key = original_key
def test_empty_dict(self):
"""Test SecretDict with empty dictionary."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
empty_dict = {}
secret_dict = SecretDict.from_plaintext(empty_dict)
assert secret_dict.get_plaintext() == empty_dict
# Encrypted value should still be created
encrypted = secret_dict.get_encrypted()
assert encrypted is not None
finally:
settings.encryption_key = original_key
def test_dual_read_prefer_encrypted(self):
"""Test that SecretDict prefers encrypted value over plaintext when both exist."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
new_dict = {"current": "value"}
old_dict = {"legacy": "value"}
encrypted = CryptoUtils.encrypt(json.dumps(new_dict), self.MOCK_KEY)
plaintext = json.dumps(old_dict)
secret_dict = SecretDict.from_db(encrypted_value=encrypted, plaintext_value=plaintext)
# Should use encrypted value, not plaintext
assert secret_dict.get_plaintext() == new_dict
finally:
settings.encryption_key = original_key