feat: encryption for mcp (#2937)
This commit is contained in:
5
.github/workflows/core-unit-sqlite-test.yaml
vendored
5
.github/workflows/core-unit-sqlite-test.yaml
vendored
@@ -53,7 +53,10 @@ jobs:
|
||||
{"test_suite": "mcp_tests/", "use_experimental": true},
|
||||
{"test_suite": "test_timezone_formatting.py"},
|
||||
{"test_suite": "test_plugins.py"},
|
||||
{"test_suite": "test_embeddings.py"}
|
||||
{"test_suite": "test_embeddings.py"},
|
||||
{"test_suite": "test_crypto_utils.py"},
|
||||
{"test_suite": "test_mcp_encryption.py"},
|
||||
{"test_suite": "test_secret.py"}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
5
.github/workflows/core-unit-test.yml
vendored
5
.github/workflows/core-unit-test.yml
vendored
@@ -53,7 +53,10 @@ jobs:
|
||||
{"test_suite": "mcp_tests/", "use_experimental": true},
|
||||
{"test_suite": "test_timezone_formatting.py"},
|
||||
{"test_suite": "test_plugins.py"},
|
||||
{"test_suite": "test_embeddings.py"}
|
||||
{"test_suite": "test_embeddings.py"},
|
||||
{"test_suite": "test_crypto_utils.py"},
|
||||
{"test_suite": "test_mcp_encryption.py"},
|
||||
{"test_suite": "test_secret.py"}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
"""add and migrate encrypted columns for mcp
|
||||
|
||||
Revision ID: d06594144ef3
|
||||
Revises: 5d27a719b24d
|
||||
Create Date: 2025-09-15 22:02:47.403970
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
# Add the app directory to path to import our crypto utils
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import JSON, String, Text
|
||||
from sqlalchemy.sql import column, table
|
||||
|
||||
from alembic import op
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from letta.helpers.crypto_utils import CryptoUtils
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d06594144ef3"
|
||||
down_revision: Union[str, None] = "5d27a719b24d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# First, add the new encrypted columns
|
||||
op.add_column("mcp_oauth", sa.Column("access_token_enc", sa.Text(), nullable=True))
|
||||
op.add_column("mcp_oauth", sa.Column("refresh_token_enc", sa.Text(), nullable=True))
|
||||
op.add_column("mcp_oauth", sa.Column("client_secret_enc", sa.Text(), nullable=True))
|
||||
op.add_column("mcp_server", sa.Column("token_enc", sa.Text(), nullable=True))
|
||||
op.add_column("mcp_server", sa.Column("custom_headers_enc", sa.Text(), nullable=True))
|
||||
|
||||
# Check if encryption key is available
|
||||
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
|
||||
if not encryption_key:
|
||||
print("WARNING: LETTA_ENCRYPTION_KEY not set. Skipping data encryption migration.")
|
||||
print(" You can run a separate migration script later to encrypt existing data.")
|
||||
return
|
||||
|
||||
# Get database connection
|
||||
connection = op.get_bind()
|
||||
|
||||
# Batch processing configuration
|
||||
BATCH_SIZE = 1000 # Process 1000 rows at a time
|
||||
|
||||
# Migrate mcp_oauth data
|
||||
print("Migrating mcp_oauth encrypted fields...")
|
||||
mcp_oauth = table(
|
||||
"mcp_oauth",
|
||||
column("id", String),
|
||||
column("access_token", Text),
|
||||
column("access_token_enc", Text),
|
||||
column("refresh_token", Text),
|
||||
column("refresh_token_enc", Text),
|
||||
column("client_secret", Text),
|
||||
column("client_secret_enc", Text),
|
||||
)
|
||||
|
||||
# Count total rows to process
|
||||
total_count_result = connection.execute(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(mcp_oauth)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(mcp_oauth.c.access_token.isnot(None), mcp_oauth.c.refresh_token.isnot(None), mcp_oauth.c.client_secret.isnot(None)),
|
||||
# Only count rows that need encryption
|
||||
sa.or_(
|
||||
sa.and_(mcp_oauth.c.access_token.isnot(None), mcp_oauth.c.access_token_enc.is_(None)),
|
||||
sa.and_(mcp_oauth.c.refresh_token.isnot(None), mcp_oauth.c.refresh_token_enc.is_(None)),
|
||||
sa.and_(mcp_oauth.c.client_secret.isnot(None), mcp_oauth.c.client_secret_enc.is_(None)),
|
||||
),
|
||||
)
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if total_count_result and total_count_result > 0:
|
||||
print(f"Found {total_count_result} mcp_oauth records that need encryption")
|
||||
|
||||
encrypted_count = 0
|
||||
skipped_count = 0
|
||||
offset = 0
|
||||
|
||||
# Process in batches
|
||||
while True:
|
||||
# Select batch of rows
|
||||
oauth_rows = connection.execute(
|
||||
sa.select(
|
||||
mcp_oauth.c.id,
|
||||
mcp_oauth.c.access_token,
|
||||
mcp_oauth.c.access_token_enc,
|
||||
mcp_oauth.c.refresh_token,
|
||||
mcp_oauth.c.refresh_token_enc,
|
||||
mcp_oauth.c.client_secret,
|
||||
mcp_oauth.c.client_secret_enc,
|
||||
)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
mcp_oauth.c.access_token.isnot(None),
|
||||
mcp_oauth.c.refresh_token.isnot(None),
|
||||
mcp_oauth.c.client_secret.isnot(None),
|
||||
),
|
||||
# Only select rows that need encryption
|
||||
sa.or_(
|
||||
sa.and_(mcp_oauth.c.access_token.isnot(None), mcp_oauth.c.access_token_enc.is_(None)),
|
||||
sa.and_(mcp_oauth.c.refresh_token.isnot(None), mcp_oauth.c.refresh_token_enc.is_(None)),
|
||||
sa.and_(mcp_oauth.c.client_secret.isnot(None), mcp_oauth.c.client_secret_enc.is_(None)),
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(mcp_oauth.c.id) # Ensure consistent ordering
|
||||
.limit(BATCH_SIZE)
|
||||
.offset(offset)
|
||||
).fetchall()
|
||||
|
||||
if not oauth_rows:
|
||||
break # No more rows to process
|
||||
|
||||
# Prepare batch updates
|
||||
batch_updates = []
|
||||
|
||||
for row in oauth_rows:
|
||||
updates = {"id": row.id}
|
||||
has_updates = False
|
||||
|
||||
# Encrypt access_token if present and not already encrypted
|
||||
if row.access_token and not row.access_token_enc:
|
||||
try:
|
||||
updates["access_token_enc"] = CryptoUtils.encrypt(row.access_token, encryption_key)
|
||||
has_updates = True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt access_token for mcp_oauth id={row.id}: {e}")
|
||||
elif row.access_token_enc:
|
||||
skipped_count += 1
|
||||
|
||||
# Encrypt refresh_token if present and not already encrypted
|
||||
if row.refresh_token and not row.refresh_token_enc:
|
||||
try:
|
||||
updates["refresh_token_enc"] = CryptoUtils.encrypt(row.refresh_token, encryption_key)
|
||||
has_updates = True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt refresh_token for mcp_oauth id={row.id}: {e}")
|
||||
elif row.refresh_token_enc:
|
||||
skipped_count += 1
|
||||
|
||||
# Encrypt client_secret if present and not already encrypted
|
||||
if row.client_secret and not row.client_secret_enc:
|
||||
try:
|
||||
updates["client_secret_enc"] = CryptoUtils.encrypt(row.client_secret, encryption_key)
|
||||
has_updates = True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt client_secret for mcp_oauth id={row.id}: {e}")
|
||||
elif row.client_secret_enc:
|
||||
skipped_count += 1
|
||||
|
||||
if has_updates:
|
||||
batch_updates.append(updates)
|
||||
encrypted_count += 1
|
||||
|
||||
# Execute batch update if there are updates
|
||||
if batch_updates:
|
||||
# Use bulk update for better performance
|
||||
for update_data in batch_updates:
|
||||
row_id = update_data.pop("id")
|
||||
if update_data: # Only update if there are fields to update
|
||||
connection.execute(mcp_oauth.update().where(mcp_oauth.c.id == row_id).values(**update_data))
|
||||
|
||||
# Progress indicator for large datasets
|
||||
if encrypted_count > 0 and encrypted_count % 10000 == 0:
|
||||
print(f" Progress: Encrypted {encrypted_count} mcp_oauth records...")
|
||||
|
||||
offset += BATCH_SIZE
|
||||
|
||||
# For very large datasets, commit periodically to avoid long transactions
|
||||
if encrypted_count > 0 and encrypted_count % 50000 == 0:
|
||||
connection.commit()
|
||||
|
||||
print(f"mcp_oauth: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")
|
||||
else:
|
||||
print("mcp_oauth: No records need encryption")
|
||||
|
||||
# Migrate mcp_server data
|
||||
print("Migrating mcp_server encrypted fields...")
|
||||
mcp_server = table(
|
||||
"mcp_server",
|
||||
column("id", String),
|
||||
column("token", String),
|
||||
column("token_enc", Text),
|
||||
column("custom_headers", JSON),
|
||||
column("custom_headers_enc", Text),
|
||||
)
|
||||
|
||||
# Count total rows to process
|
||||
total_count_result = connection.execute(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(mcp_server)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(mcp_server.c.token.isnot(None), mcp_server.c.custom_headers.isnot(None)),
|
||||
# Only count rows that need encryption
|
||||
sa.or_(
|
||||
sa.and_(mcp_server.c.token.isnot(None), mcp_server.c.token_enc.is_(None)),
|
||||
sa.and_(mcp_server.c.custom_headers.isnot(None), mcp_server.c.custom_headers_enc.is_(None)),
|
||||
),
|
||||
)
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if total_count_result and total_count_result > 0:
|
||||
print(f"Found {total_count_result} mcp_server records that need encryption")
|
||||
|
||||
encrypted_count = 0
|
||||
skipped_count = 0
|
||||
offset = 0
|
||||
|
||||
# Process in batches
|
||||
while True:
|
||||
# Select batch of rows
|
||||
server_rows = connection.execute(
|
||||
sa.select(
|
||||
mcp_server.c.id,
|
||||
mcp_server.c.token,
|
||||
mcp_server.c.token_enc,
|
||||
mcp_server.c.custom_headers,
|
||||
mcp_server.c.custom_headers_enc,
|
||||
)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(mcp_server.c.token.isnot(None), mcp_server.c.custom_headers.isnot(None)),
|
||||
# Only select rows that need encryption
|
||||
sa.or_(
|
||||
sa.and_(mcp_server.c.token.isnot(None), mcp_server.c.token_enc.is_(None)),
|
||||
sa.and_(mcp_server.c.custom_headers.isnot(None), mcp_server.c.custom_headers_enc.is_(None)),
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(mcp_server.c.id) # Ensure consistent ordering
|
||||
.limit(BATCH_SIZE)
|
||||
.offset(offset)
|
||||
).fetchall()
|
||||
|
||||
if not server_rows:
|
||||
break # No more rows to process
|
||||
|
||||
# Prepare batch updates
|
||||
batch_updates = []
|
||||
|
||||
for row in server_rows:
|
||||
updates = {"id": row.id}
|
||||
has_updates = False
|
||||
|
||||
# Encrypt token if present and not already encrypted
|
||||
if row.token and not row.token_enc:
|
||||
try:
|
||||
updates["token_enc"] = CryptoUtils.encrypt(row.token, encryption_key)
|
||||
has_updates = True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt token for mcp_server id={row.id}: {e}")
|
||||
elif row.token_enc:
|
||||
skipped_count += 1
|
||||
|
||||
# Encrypt custom_headers if present (JSON field) and not already encrypted
|
||||
if row.custom_headers and not row.custom_headers_enc:
|
||||
try:
|
||||
# Convert JSON to string for encryption
|
||||
headers_json = json.dumps(row.custom_headers)
|
||||
updates["custom_headers_enc"] = CryptoUtils.encrypt(headers_json, encryption_key)
|
||||
has_updates = True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt custom_headers for mcp_server id={row.id}: {e}")
|
||||
elif row.custom_headers_enc:
|
||||
skipped_count += 1
|
||||
|
||||
if has_updates:
|
||||
batch_updates.append(updates)
|
||||
encrypted_count += 1
|
||||
|
||||
# Execute batch update if there are updates
|
||||
if batch_updates:
|
||||
# Use bulk update for better performance
|
||||
for update_data in batch_updates:
|
||||
row_id = update_data.pop("id")
|
||||
if update_data: # Only update if there are fields to update
|
||||
connection.execute(mcp_server.update().where(mcp_server.c.id == row_id).values(**update_data))
|
||||
|
||||
# Progress indicator for large datasets
|
||||
if encrypted_count > 0 and encrypted_count % 10000 == 0:
|
||||
print(f" Progress: Encrypted {encrypted_count} mcp_server records...")
|
||||
|
||||
offset += BATCH_SIZE
|
||||
|
||||
# For very large datasets, commit periodically to avoid long transactions
|
||||
if encrypted_count > 0 and encrypted_count % 50000 == 0:
|
||||
connection.commit()
|
||||
|
||||
print(f"mcp_server: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")
|
||||
else:
|
||||
print("mcp_server: No records need encryption")
|
||||
print("Migration complete. Plaintext columns are retained for rollback safety.")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("mcp_server", "custom_headers_enc")
|
||||
op.drop_column("mcp_server", "token_enc")
|
||||
op.drop_column("mcp_oauth", "client_secret_enc")
|
||||
op.drop_column("mcp_oauth", "refresh_token_enc")
|
||||
op.drop_column("mcp_oauth", "access_token_enc")
|
||||
# ### end Alembic commands ###
|
||||
134
letta/helpers/crypto_utils.py
Normal file
134
letta/helpers/crypto_utils.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import base64
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
class CryptoUtils:
|
||||
"""Utility class for AES-256-GCM encryption/decryption of sensitive data."""
|
||||
|
||||
# AES-256 requires 32 bytes key
|
||||
KEY_SIZE = 32
|
||||
# GCM standard IV size is 12 bytes (96 bits)
|
||||
IV_SIZE = 12
|
||||
# GCM tag size is 16 bytes (128 bits)
|
||||
TAG_SIZE = 16
|
||||
# Salt size for key derivation
|
||||
SALT_SIZE = 16
|
||||
|
||||
@classmethod
|
||||
def _derive_key(cls, master_key: str, salt: bytes) -> bytes:
|
||||
"""Derive an AES key from the master key using PBKDF2."""
|
||||
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=cls.KEY_SIZE, salt=salt, iterations=100000, backend=default_backend())
|
||||
return kdf.derive(master_key.encode())
|
||||
|
||||
@classmethod
|
||||
def encrypt(cls, plaintext: str, master_key: Optional[str] = None) -> str:
|
||||
"""
|
||||
Encrypt a string using AES-256-GCM.
|
||||
|
||||
Args:
|
||||
plaintext: The string to encrypt
|
||||
master_key: Optional master key (defaults to settings.encryption_key)
|
||||
|
||||
Returns:
|
||||
Base64 encoded string containing: salt + iv + ciphertext + tag
|
||||
|
||||
Raises:
|
||||
ValueError: If no encryption key is configured
|
||||
"""
|
||||
if master_key is None:
|
||||
master_key = settings.encryption_key
|
||||
|
||||
if not master_key:
|
||||
raise ValueError("No encryption key configured. Set LETTA_ENCRYPTION_KEY environment variable.")
|
||||
|
||||
# Generate random salt and IV
|
||||
salt = os.urandom(cls.SALT_SIZE)
|
||||
iv = os.urandom(cls.IV_SIZE)
|
||||
|
||||
# Derive key from master key
|
||||
key = cls._derive_key(master_key, salt)
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
# Encrypt the plaintext
|
||||
ciphertext = encryptor.update(plaintext.encode()) + encryptor.finalize()
|
||||
|
||||
# Get the authentication tag
|
||||
tag = encryptor.tag
|
||||
|
||||
# Combine salt + iv + ciphertext + tag
|
||||
encrypted_data = salt + iv + ciphertext + tag
|
||||
|
||||
# Return as base64 encoded string
|
||||
return base64.b64encode(encrypted_data).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def decrypt(cls, encrypted: str, master_key: Optional[str] = None) -> str:
|
||||
"""
|
||||
Decrypt a string that was encrypted using AES-256-GCM.
|
||||
|
||||
Args:
|
||||
encrypted: Base64 encoded encrypted string
|
||||
master_key: Optional master key (defaults to settings.encryption_key)
|
||||
|
||||
Returns:
|
||||
The decrypted plaintext string
|
||||
|
||||
Raises:
|
||||
ValueError: If no encryption key is configured or decryption fails
|
||||
"""
|
||||
if master_key is None:
|
||||
master_key = settings.encryption_key
|
||||
|
||||
if not master_key:
|
||||
raise ValueError("No encryption key configured. Set LETTA_ENCRYPTION_KEY environment variable.")
|
||||
|
||||
try:
|
||||
# Decode from base64
|
||||
encrypted_data = base64.b64decode(encrypted)
|
||||
|
||||
# Extract components
|
||||
salt = encrypted_data[: cls.SALT_SIZE]
|
||||
iv = encrypted_data[cls.SALT_SIZE : cls.SALT_SIZE + cls.IV_SIZE]
|
||||
ciphertext = encrypted_data[cls.SALT_SIZE + cls.IV_SIZE : -cls.TAG_SIZE]
|
||||
tag = encrypted_data[-cls.TAG_SIZE :]
|
||||
|
||||
# Derive key from master key
|
||||
key = cls._derive_key(master_key, salt)
|
||||
|
||||
# Create cipher
|
||||
cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
# Decrypt the ciphertext
|
||||
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
||||
|
||||
return plaintext.decode("utf-8")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to decrypt data: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def is_encrypted(cls, value: str) -> bool:
|
||||
"""
|
||||
Check if a string appears to be encrypted (base64 encoded with correct size).
|
||||
|
||||
This is a heuristic check and may have false positives.
|
||||
"""
|
||||
try:
|
||||
decoded = base64.b64decode(value)
|
||||
# Check if length is consistent with our encryption format
|
||||
# Minimum size: salt(16) + iv(12) + tag(16) + at least 1 byte of ciphertext
|
||||
return len(decoded) >= cls.SALT_SIZE + cls.IV_SIZE + cls.TAG_SIZE + 1
|
||||
except Exception:
|
||||
return False
|
||||
@@ -18,6 +18,7 @@ from letta.orm.job import Job
|
||||
from letta.orm.job_messages import JobMessage
|
||||
from letta.orm.llm_batch_items import LLMBatchItem
|
||||
from letta.orm.llm_batch_job import LLMBatchJob
|
||||
from letta.orm.mcp_oauth import MCPOAuth
|
||||
from letta.orm.mcp_server import MCPServer
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
@@ -38,7 +38,11 @@ class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin):
|
||||
|
||||
# Token data
|
||||
access_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth access token")
|
||||
access_token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted OAuth access token")
|
||||
|
||||
refresh_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth refresh token")
|
||||
refresh_token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted OAuth refresh token")
|
||||
|
||||
token_type: Mapped[str] = mapped_column(String(50), default="Bearer", doc="Token type")
|
||||
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, doc="Token expiry time")
|
||||
scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth scope")
|
||||
@@ -46,6 +50,8 @@ class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin):
|
||||
# Client configuration
|
||||
client_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client ID")
|
||||
client_secret: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client secret")
|
||||
client_secret_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted OAuth client secret")
|
||||
|
||||
redirect_uri: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth redirect URI")
|
||||
|
||||
# Session state
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy import JSON, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
@@ -39,9 +39,15 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
|
||||
# access token / api key for MCP servers that require authentication
|
||||
token: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The access token or api key for the MCP server")
|
||||
|
||||
# encrypted access token or api key for the MCP server
|
||||
token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted access token or api key for the MCP server")
|
||||
|
||||
# custom headers for authentication (key-value pairs)
|
||||
custom_headers: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Custom authentication headers as key-value pairs")
|
||||
|
||||
# encrypted custom headers for authentication (key-value pairs)
|
||||
custom_headers_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted custom authentication headers")
|
||||
|
||||
# stdio server
|
||||
stdio_config: Mapped[Optional[StdioServerConfig]] = mapped_column(
|
||||
MCPStdioServerConfigColumn, nullable=True, doc="The configuration for the stdio server"
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.functions.mcp_client.types import (
|
||||
)
|
||||
from letta.orm.mcp_oauth import OAuthSessionStatus
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.secret import Secret, SecretDict
|
||||
|
||||
|
||||
class BaseMCPServer(LettaBase):
|
||||
@@ -29,6 +30,9 @@ class MCPServer(BaseMCPServer):
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for authentication)")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
token_enc: Optional[str] = Field(None, description="Encrypted token")
|
||||
custom_headers_enc: Optional[str] = Field(None, description="Encrypted custom headers")
|
||||
|
||||
# stdio config
|
||||
stdio_config: Optional[StdioServerConfig] = Field(
|
||||
None, description="The configuration for the server (MCP 'local' client will run this command)"
|
||||
@@ -41,18 +45,93 @@ class MCPServer(BaseMCPServer):
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
|
||||
|
||||
def get_token_secret(self) -> Secret:
|
||||
"""Get the token as a Secret object, preferring encrypted over plaintext."""
|
||||
return Secret.from_db(self.token_enc, self.token)
|
||||
|
||||
def get_custom_headers_secret(self) -> SecretDict:
|
||||
"""Get custom headers as a SecretDict object, preferring encrypted over plaintext."""
|
||||
return SecretDict.from_db(self.custom_headers_enc, self.custom_headers)
|
||||
|
||||
def set_token_secret(self, secret: Secret) -> None:
|
||||
"""Set token from a Secret object, updating both encrypted and plaintext fields."""
|
||||
secret_dict = secret.to_dict()
|
||||
self.token_enc = secret_dict["encrypted"]
|
||||
# Only set plaintext during migration phase
|
||||
if not secret._was_encrypted:
|
||||
self.token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.token = None
|
||||
|
||||
def set_custom_headers_secret(self, secret: SecretDict) -> None:
|
||||
"""Set custom headers from a SecretDict object, updating both fields."""
|
||||
secret_dict = secret.to_dict()
|
||||
self.custom_headers_enc = secret_dict["encrypted"]
|
||||
# Only set plaintext during migration phase
|
||||
if not secret._was_encrypted:
|
||||
self.custom_headers = secret_dict["plaintext"]
|
||||
else:
|
||||
self.custom_headers = None
|
||||
|
||||
def model_dump(self, to_orm: bool = False, **kwargs):
|
||||
"""Override model_dump to handle encryption when saving to database."""
|
||||
import os
|
||||
|
||||
# Check environment variable directly to handle test patching
|
||||
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
|
||||
if not encryption_key:
|
||||
from letta.settings import settings
|
||||
|
||||
encryption_key = settings.encryption_key
|
||||
|
||||
data = super().model_dump(to_orm=to_orm, **kwargs)
|
||||
|
||||
if to_orm and encryption_key:
|
||||
# Temporarily set the encryption key for Secret/SecretDict
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = encryption_key
|
||||
try:
|
||||
# Encrypt token if present
|
||||
if self.token is not None:
|
||||
token_secret = Secret.from_plaintext(self.token)
|
||||
secret_dict = token_secret.to_dict()
|
||||
data["token_enc"] = secret_dict["encrypted"]
|
||||
# Keep plaintext for dual-write during migration
|
||||
data["token"] = secret_dict["plaintext"]
|
||||
|
||||
# Encrypt custom headers if present
|
||||
if self.custom_headers is not None:
|
||||
headers_secret = SecretDict.from_plaintext(self.custom_headers)
|
||||
secret_dict = headers_secret.to_dict()
|
||||
data["custom_headers_enc"] = secret_dict["encrypted"]
|
||||
# Keep plaintext for dual-write during migration
|
||||
data["custom_headers"] = secret_dict["plaintext"]
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
return data
|
||||
|
||||
def to_config(
|
||||
self,
|
||||
environment_variables: Optional[Dict[str, str]] = None,
|
||||
resolve_variables: bool = True,
|
||||
) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
|
||||
# Get decrypted values for use in config
|
||||
token_secret = self.get_token_secret()
|
||||
token_plaintext = token_secret.get_plaintext()
|
||||
|
||||
headers_secret = self.get_custom_headers_secret()
|
||||
headers_plaintext = headers_secret.get_plaintext()
|
||||
|
||||
if self.server_type == MCPServerType.SSE:
|
||||
config = SSEServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
|
||||
custom_headers=self.custom_headers,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
|
||||
custom_headers=headers_plaintext,
|
||||
)
|
||||
if resolve_variables:
|
||||
config.resolve_environment_variables(environment_variables)
|
||||
@@ -70,9 +149,9 @@ class MCPServer(BaseMCPServer):
|
||||
config = StreamableHTTPServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
|
||||
custom_headers=self.custom_headers,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
|
||||
custom_headers=headers_plaintext,
|
||||
)
|
||||
if resolve_variables:
|
||||
config.resolve_environment_variables(environment_variables)
|
||||
@@ -138,11 +217,18 @@ class MCPOAuthSession(BaseMCPOAuth):
|
||||
expires_at: Optional[datetime] = Field(None, description="Token expiry time")
|
||||
scope: Optional[str] = Field(None, description="OAuth scope")
|
||||
|
||||
# Encrypted token fields (for internal use)
|
||||
access_token_enc: Optional[str] = Field(None, description="Encrypted OAuth access token")
|
||||
refresh_token_enc: Optional[str] = Field(None, description="Encrypted OAuth refresh token")
|
||||
|
||||
# Client configuration
|
||||
client_id: Optional[str] = Field(None, description="OAuth client ID")
|
||||
client_secret: Optional[str] = Field(None, description="OAuth client secret")
|
||||
redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI")
|
||||
|
||||
# Encrypted client secret (for internal use)
|
||||
client_secret_enc: Optional[str] = Field(None, description="Encrypted OAuth client secret")
|
||||
|
||||
# Session state
|
||||
status: OAuthSessionStatus = Field(default=OAuthSessionStatus.PENDING, description="Session status")
|
||||
|
||||
@@ -150,6 +236,93 @@ class MCPOAuthSession(BaseMCPOAuth):
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="Session creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
|
||||
|
||||
def get_access_token_secret(self) -> Secret:
|
||||
"""Get the access token as a Secret object, preferring encrypted over plaintext."""
|
||||
return Secret.from_db(self.access_token_enc, self.access_token)
|
||||
|
||||
def get_refresh_token_secret(self) -> Secret:
|
||||
"""Get the refresh token as a Secret object, preferring encrypted over plaintext."""
|
||||
return Secret.from_db(self.refresh_token_enc, self.refresh_token)
|
||||
|
||||
def get_client_secret_secret(self) -> Secret:
|
||||
"""Get the client secret as a Secret object, preferring encrypted over plaintext."""
|
||||
return Secret.from_db(self.client_secret_enc, self.client_secret)
|
||||
|
||||
def set_access_token_secret(self, secret: Secret) -> None:
|
||||
"""Set access token from a Secret object."""
|
||||
secret_dict = secret.to_dict()
|
||||
self.access_token_enc = secret_dict["encrypted"]
|
||||
if not secret._was_encrypted:
|
||||
self.access_token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.access_token = None
|
||||
|
||||
def set_refresh_token_secret(self, secret: Secret) -> None:
|
||||
"""Set refresh token from a Secret object."""
|
||||
secret_dict = secret.to_dict()
|
||||
self.refresh_token_enc = secret_dict["encrypted"]
|
||||
if not secret._was_encrypted:
|
||||
self.refresh_token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.refresh_token = None
|
||||
|
||||
def set_client_secret_secret(self, secret: Secret) -> None:
|
||||
"""Set client secret from a Secret object."""
|
||||
secret_dict = secret.to_dict()
|
||||
self.client_secret_enc = secret_dict["encrypted"]
|
||||
if not secret._was_encrypted:
|
||||
self.client_secret = secret_dict["plaintext"]
|
||||
else:
|
||||
self.client_secret = None
|
||||
|
||||
def model_dump(self, to_orm: bool = False, **kwargs):
|
||||
"""Override model_dump to handle encryption when saving to database."""
|
||||
import os
|
||||
|
||||
# Check environment variable directly to handle test patching
|
||||
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
|
||||
if not encryption_key:
|
||||
from letta.settings import settings
|
||||
|
||||
encryption_key = settings.encryption_key
|
||||
|
||||
data = super().model_dump(to_orm=to_orm, **kwargs)
|
||||
|
||||
if to_orm and encryption_key:
|
||||
# Temporarily set the encryption key for Secret
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = encryption_key
|
||||
try:
|
||||
# Encrypt access token if present
|
||||
if self.access_token is not None:
|
||||
token_secret = Secret.from_plaintext(self.access_token)
|
||||
secret_dict = token_secret.to_dict()
|
||||
data["access_token_enc"] = secret_dict["encrypted"]
|
||||
# Keep plaintext for dual-write during migration
|
||||
data["access_token"] = secret_dict["plaintext"]
|
||||
|
||||
# Encrypt refresh token if present
|
||||
if self.refresh_token is not None:
|
||||
token_secret = Secret.from_plaintext(self.refresh_token)
|
||||
secret_dict = token_secret.to_dict()
|
||||
data["refresh_token_enc"] = secret_dict["encrypted"]
|
||||
# Keep plaintext for dual-write during migration
|
||||
data["refresh_token"] = secret_dict["plaintext"]
|
||||
|
||||
# Encrypt client secret if present
|
||||
if self.client_secret is not None:
|
||||
secret = Secret.from_plaintext(self.client_secret)
|
||||
secret_dict = secret.to_dict()
|
||||
data["client_secret_enc"] = secret_dict["encrypted"]
|
||||
# Keep plaintext for dual-write during migration
|
||||
data["client_secret"] = secret_dict["plaintext"]
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class MCPOAuthSessionCreate(BaseMCPOAuth):
|
||||
"""Create a new OAuth session."""
|
||||
|
||||
241
letta/schemas/secret.py
Normal file
241
letta/schemas/secret.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, PrivateAttr
|
||||
|
||||
from letta.helpers.crypto_utils import CryptoUtils
|
||||
|
||||
|
||||
class Secret(BaseModel):
|
||||
"""
|
||||
A wrapper class for encrypted credentials that keeps values encrypted in memory.
|
||||
|
||||
This class ensures that sensitive data remains encrypted as much as possible
|
||||
while passing through the codebase, only decrypting when absolutely necessary.
|
||||
|
||||
TODO: Once we deprecate plaintext columns in the database:
|
||||
- Remove the dual-write logic in to_dict()
|
||||
- Remove the from_db() method's plaintext_value parameter
|
||||
- Remove the _was_encrypted flag (no longer needed for migration)
|
||||
- Simplify get_plaintext() to only handle encrypted values
|
||||
"""
|
||||
|
||||
# Store the encrypted value
|
||||
_encrypted_value: Optional[str] = PrivateAttr(default=None)
|
||||
# Cache the decrypted value to avoid repeated decryption
|
||||
_plaintext_cache: Optional[str] = PrivateAttr(default=None)
|
||||
# Flag to indicate if the value was originally encrypted
|
||||
_was_encrypted: bool = PrivateAttr(default=False)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
@classmethod
|
||||
def from_plaintext(cls, value: Optional[str]) -> "Secret":
|
||||
"""
|
||||
Create a Secret from a plaintext value, encrypting it immediately.
|
||||
|
||||
Args:
|
||||
value: The plaintext value to encrypt
|
||||
|
||||
Returns:
|
||||
A Secret instance with the encrypted value
|
||||
"""
|
||||
if value is None:
|
||||
instance = cls()
|
||||
instance._encrypted_value = None
|
||||
instance._was_encrypted = False
|
||||
return instance
|
||||
|
||||
encrypted = CryptoUtils.encrypt(value)
|
||||
instance = cls()
|
||||
instance._encrypted_value = encrypted
|
||||
instance._was_encrypted = False
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret":
|
||||
"""
|
||||
Create a Secret from an already encrypted value.
|
||||
|
||||
Args:
|
||||
encrypted_value: The encrypted value
|
||||
|
||||
Returns:
|
||||
A Secret instance
|
||||
"""
|
||||
instance = cls()
|
||||
instance._encrypted_value = encrypted_value
|
||||
instance._was_encrypted = True
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[str]) -> "Secret":
|
||||
"""
|
||||
Create a Secret from database values during migration phase.
|
||||
|
||||
Prefers encrypted value if available, falls back to plaintext.
|
||||
|
||||
Args:
|
||||
encrypted_value: The encrypted value from the database
|
||||
plaintext_value: The plaintext value from the database
|
||||
|
||||
Returns:
|
||||
A Secret instance
|
||||
"""
|
||||
if encrypted_value is not None:
|
||||
return cls.from_encrypted(encrypted_value)
|
||||
elif plaintext_value is not None:
|
||||
return cls.from_plaintext(plaintext_value)
|
||||
else:
|
||||
return cls.from_plaintext(None)
|
||||
|
||||
def get_encrypted(self) -> Optional[str]:
|
||||
"""
|
||||
Get the encrypted value.
|
||||
|
||||
Returns:
|
||||
The encrypted value, or None if the secret is empty
|
||||
"""
|
||||
return self._encrypted_value
|
||||
|
||||
def get_plaintext(self) -> Optional[str]:
|
||||
"""
|
||||
Get the decrypted plaintext value.
|
||||
|
||||
This should only be called when the plaintext is actually needed,
|
||||
such as when making an external API call.
|
||||
|
||||
Returns:
|
||||
The decrypted plaintext value
|
||||
"""
|
||||
if self._encrypted_value is None:
|
||||
return None
|
||||
|
||||
# Use cached value if available
|
||||
if self._plaintext_cache is not None:
|
||||
return self._plaintext_cache
|
||||
|
||||
# Decrypt and cache
|
||||
try:
|
||||
plaintext = CryptoUtils.decrypt(self._encrypted_value)
|
||||
# Note: We can't actually cache due to frozen=True, but in practice
|
||||
# we'll create new instances rather than mutating
|
||||
return plaintext
|
||||
except Exception:
|
||||
# If decryption fails and this wasn't originally encrypted,
|
||||
# it might be that the value is actually plaintext (during migration)
|
||||
if not self._was_encrypted:
|
||||
return None
|
||||
raise
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the secret is empty/None."""
|
||||
return self._encrypted_value is None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation that doesn't expose the actual value."""
|
||||
if self.is_empty():
|
||||
return "<Secret: empty>"
|
||||
return "<Secret: ****>"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Representation that doesn't expose the actual value."""
|
||||
return self.__str__()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert to dictionary for database storage.
|
||||
|
||||
Returns both encrypted and plaintext values for dual-write during migration.
|
||||
"""
|
||||
return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self._was_encrypted else None}
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""
|
||||
Compare two secrets by their plaintext values.
|
||||
|
||||
Note: This decrypts both values, so use sparingly.
|
||||
"""
|
||||
if not isinstance(other, Secret):
|
||||
return False
|
||||
return self.get_plaintext() == other.get_plaintext()
|
||||
|
||||
|
||||
class SecretDict(BaseModel):
|
||||
"""
|
||||
A wrapper for dictionaries containing sensitive key-value pairs.
|
||||
|
||||
Used for custom headers and other key-value configurations.
|
||||
|
||||
TODO: Once we deprecate plaintext columns in the database:
|
||||
- Remove the dual-write logic in to_dict()
|
||||
- Remove the from_db() method's plaintext_value parameter
|
||||
- Remove the _was_encrypted flag (no longer needed for migration)
|
||||
- Simplify get_plaintext() to only handle encrypted JSON values
|
||||
"""
|
||||
|
||||
_encrypted_value: Optional[str] = PrivateAttr(default=None)
|
||||
_plaintext_cache: Optional[Dict[str, str]] = PrivateAttr(default=None)
|
||||
_was_encrypted: bool = PrivateAttr(default=False)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
@classmethod
|
||||
def from_plaintext(cls, value: Optional[Dict[str, str]]) -> "SecretDict":
|
||||
"""Create a SecretDict from a plaintext dictionary."""
|
||||
if value is None:
|
||||
instance = cls()
|
||||
instance._encrypted_value = None
|
||||
instance._was_encrypted = False
|
||||
return instance
|
||||
|
||||
# Serialize to JSON then encrypt
|
||||
json_str = json.dumps(value)
|
||||
encrypted = CryptoUtils.encrypt(json_str)
|
||||
instance = cls()
|
||||
instance._encrypted_value = encrypted
|
||||
instance._was_encrypted = False
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_encrypted(cls, encrypted_value: Optional[str]) -> "SecretDict":
|
||||
"""Create a SecretDict from an encrypted value."""
|
||||
instance = cls()
|
||||
instance._encrypted_value = encrypted_value
|
||||
instance._was_encrypted = True
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[Dict[str, str]]) -> "SecretDict":
|
||||
"""Create a SecretDict from database values during migration phase."""
|
||||
if encrypted_value is not None:
|
||||
return cls.from_encrypted(encrypted_value)
|
||||
elif plaintext_value is not None:
|
||||
return cls.from_plaintext(plaintext_value)
|
||||
else:
|
||||
return cls.from_plaintext(None)
|
||||
|
||||
def get_encrypted(self) -> Optional[str]:
|
||||
"""Get the encrypted value."""
|
||||
return self._encrypted_value
|
||||
|
||||
def get_plaintext(self) -> Optional[Dict[str, str]]:
|
||||
"""Get the decrypted dictionary."""
|
||||
if self._encrypted_value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
decrypted_json = CryptoUtils.decrypt(self._encrypted_value)
|
||||
return json.loads(decrypted_json)
|
||||
except Exception:
|
||||
if not self._was_encrypted:
|
||||
return None
|
||||
raise
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the secret dict is empty/None."""
|
||||
return self._encrypted_value is None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for database storage."""
|
||||
return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self._was_encrypted else None}
|
||||
@@ -728,35 +728,16 @@ async def add_mcp_server_to_config(
|
||||
return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
|
||||
else:
|
||||
# log to DB
|
||||
from letta.schemas.mcp import MCPServer
|
||||
|
||||
if isinstance(request, StdioServerConfig):
|
||||
mapped_request = MCPServer(server_name=request.server_name, server_type=request.type, stdio_config=request)
|
||||
# don't allow stdio servers
|
||||
if tool_settings.mcp_disable_stdio: # protected server
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server",
|
||||
)
|
||||
elif isinstance(request, SSEServerConfig):
|
||||
mapped_request = MCPServer(
|
||||
server_name=request.server_name,
|
||||
server_type=request.type,
|
||||
server_url=request.server_url,
|
||||
token=request.resolve_token(),
|
||||
custom_headers=request.custom_headers,
|
||||
)
|
||||
elif isinstance(request, StreamableHTTPServerConfig):
|
||||
mapped_request = MCPServer(
|
||||
server_name=request.server_name,
|
||||
server_type=request.type,
|
||||
server_url=request.server_url,
|
||||
token=request.resolve_token(),
|
||||
custom_headers=request.custom_headers,
|
||||
# Check if stdio servers are disabled
|
||||
if isinstance(request, StdioServerConfig) and tool_settings.mcp_disable_stdio:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server",
|
||||
)
|
||||
|
||||
# Create MCP server and optimistically sync tools
|
||||
await server.mcp_manager.create_mcp_server_with_tools(mapped_request, actor=actor)
|
||||
# The mcp_manager will handle encryption of sensitive fields
|
||||
await server.mcp_manager.create_mcp_server_from_config_with_tools(request, actor=actor)
|
||||
|
||||
# TODO: don't do this in the future (just return MCPServer)
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
|
||||
@@ -35,6 +35,7 @@ from letta.schemas.mcp import (
|
||||
UpdateStdioMCPServer,
|
||||
UpdateStreamableHTTPMCPServer,
|
||||
)
|
||||
from letta.schemas.secret import Secret, SecretDict
|
||||
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
@@ -354,6 +355,69 @@ class MCPManager:
|
||||
logger.error(f"Failed to create MCP server: {e}")
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_server_from_config(
|
||||
self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser
|
||||
) -> MCPServer:
|
||||
"""
|
||||
Create an MCP server from a config object, handling encryption of sensitive fields.
|
||||
|
||||
This method converts the server config to an MCPServer model and encrypts
|
||||
sensitive fields like tokens and custom headers.
|
||||
"""
|
||||
# Create base MCPServer object
|
||||
if isinstance(server_config, StdioServerConfig):
|
||||
mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config)
|
||||
elif isinstance(server_config, SSEServerConfig):
|
||||
mcp_server = MCPServer(
|
||||
server_name=server_config.server_name,
|
||||
server_type=server_config.type,
|
||||
server_url=server_config.server_url,
|
||||
)
|
||||
# Encrypt sensitive fields
|
||||
token = server_config.resolve_token()
|
||||
if token:
|
||||
token_secret = Secret.from_plaintext(token)
|
||||
mcp_server.set_token_secret(token_secret)
|
||||
if server_config.custom_headers:
|
||||
headers_secret = SecretDict.from_plaintext(server_config.custom_headers)
|
||||
mcp_server.set_custom_headers_secret(headers_secret)
|
||||
|
||||
elif isinstance(server_config, StreamableHTTPServerConfig):
|
||||
mcp_server = MCPServer(
|
||||
server_name=server_config.server_name,
|
||||
server_type=server_config.type,
|
||||
server_url=server_config.server_url,
|
||||
)
|
||||
# Encrypt sensitive fields
|
||||
token = server_config.resolve_token()
|
||||
if token:
|
||||
token_secret = Secret.from_plaintext(token)
|
||||
mcp_server.set_token_secret(token_secret)
|
||||
if server_config.custom_headers:
|
||||
headers_secret = SecretDict.from_plaintext(server_config.custom_headers)
|
||||
mcp_server.set_custom_headers_secret(headers_secret)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
return mcp_server
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_server_from_config_with_tools(
|
||||
self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser
|
||||
) -> MCPServer:
|
||||
"""
|
||||
Create an MCP server from a config object and optimistically sync its tools.
|
||||
|
||||
This method handles encryption of sensitive fields and then creates the server
|
||||
with automatic tool synchronization.
|
||||
"""
|
||||
# Convert config to MCPServer with encryption
|
||||
mcp_server = await self.create_mcp_server_from_config(server_config, actor)
|
||||
|
||||
# Create the server with tools
|
||||
return await self.create_mcp_server_with_tools(mcp_server, actor)
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_server_with_tools(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
|
||||
"""
|
||||
@@ -420,10 +484,33 @@ class MCPManager:
|
||||
# Update tool attributes with only the fields that were explicitly set
|
||||
update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True)
|
||||
|
||||
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
||||
if update_data.get("custom_headers") is None:
|
||||
update_data.pop("custom_headers", None)
|
||||
setattr(mcp_server, "custom_headers", null())
|
||||
# Handle encryption for token if provided
|
||||
if "token" in update_data and update_data["token"] is not None:
|
||||
token_secret = Secret.from_plaintext(update_data["token"])
|
||||
secret_dict = token_secret.to_dict()
|
||||
update_data["token_enc"] = secret_dict["encrypted"]
|
||||
# During migration phase, also update plaintext
|
||||
if not token_secret._was_encrypted:
|
||||
update_data["token"] = secret_dict["plaintext"]
|
||||
else:
|
||||
update_data["token"] = None
|
||||
|
||||
# Handle encryption for custom_headers if provided
|
||||
if "custom_headers" in update_data:
|
||||
if update_data["custom_headers"] is not None:
|
||||
headers_secret = SecretDict.from_plaintext(update_data["custom_headers"])
|
||||
secret_dict = headers_secret.to_dict()
|
||||
update_data["custom_headers_enc"] = secret_dict["encrypted"]
|
||||
# During migration phase, also update plaintext
|
||||
if not headers_secret._was_encrypted:
|
||||
update_data["custom_headers"] = secret_dict["plaintext"]
|
||||
else:
|
||||
update_data["custom_headers"] = None
|
||||
else:
|
||||
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
||||
update_data.pop("custom_headers", None)
|
||||
setattr(mcp_server, "custom_headers", null())
|
||||
setattr(mcp_server, "custom_headers_enc", None)
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(mcp_server, key, value)
|
||||
@@ -664,6 +751,86 @@ class MCPManager:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
# OAuth-related methods
|
||||
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
|
||||
"""
|
||||
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
|
||||
"""
|
||||
# Check for encryption key from env or settings
|
||||
import os
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key
|
||||
|
||||
# Get decrypted values using the dual-read approach
|
||||
access_token = None
|
||||
if oauth_session.access_token_enc or oauth_session.access_token:
|
||||
if encryption_key:
|
||||
# Temporarily set the key for Secret
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = encryption_key
|
||||
try:
|
||||
secret = Secret.from_db(oauth_session.access_token_enc, oauth_session.access_token)
|
||||
access_token = secret.get_plaintext()
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
else:
|
||||
# No encryption key, use plaintext if available
|
||||
access_token = oauth_session.access_token
|
||||
|
||||
refresh_token = None
|
||||
if oauth_session.refresh_token_enc or oauth_session.refresh_token:
|
||||
if encryption_key:
|
||||
# Temporarily set the key for Secret
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = encryption_key
|
||||
try:
|
||||
secret = Secret.from_db(oauth_session.refresh_token_enc, oauth_session.refresh_token)
|
||||
refresh_token = secret.get_plaintext()
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
else:
|
||||
# No encryption key, use plaintext if available
|
||||
refresh_token = oauth_session.refresh_token
|
||||
|
||||
client_secret = None
|
||||
if oauth_session.client_secret_enc or oauth_session.client_secret:
|
||||
if encryption_key:
|
||||
# Temporarily set the key for Secret
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = encryption_key
|
||||
try:
|
||||
secret = Secret.from_db(oauth_session.client_secret_enc, oauth_session.client_secret)
|
||||
client_secret = secret.get_plaintext()
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
else:
|
||||
# No encryption key, use plaintext if available
|
||||
client_secret = oauth_session.client_secret
|
||||
|
||||
return MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
server_id=oauth_session.server_id,
|
||||
server_url=oauth_session.server_url,
|
||||
server_name=oauth_session.server_name,
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
authorization_url=oauth_session.authorization_url,
|
||||
authorization_code=oauth_session.authorization_code,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type=oauth_session.token_type,
|
||||
expires_at=oauth_session.expires_at,
|
||||
scope=oauth_session.scope,
|
||||
client_id=oauth_session.client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=oauth_session.redirect_uri,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
"""Create a new OAuth session for MCP server authentication."""
|
||||
@@ -682,18 +849,8 @@ class MCPManager:
|
||||
)
|
||||
oauth_session = await oauth_session.create_async(session, actor=actor)
|
||||
|
||||
# Convert to Pydantic model
|
||||
return MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
server_url=oauth_session.server_url,
|
||||
server_name=oauth_session.server_name,
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
)
|
||||
# Convert to Pydantic model - note: new sessions won't have tokens yet
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
|
||||
@@ -701,27 +858,7 @@ class MCPManager:
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
|
||||
return MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
server_url=oauth_session.server_url,
|
||||
server_name=oauth_session.server_name,
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
authorization_url=oauth_session.authorization_url,
|
||||
authorization_code=oauth_session.authorization_code,
|
||||
access_token=oauth_session.access_token,
|
||||
refresh_token=oauth_session.refresh_token,
|
||||
token_type=oauth_session.token_type,
|
||||
expires_at=oauth_session.expires_at,
|
||||
scope=oauth_session.scope,
|
||||
client_id=oauth_session.client_id,
|
||||
client_secret=oauth_session.client_secret,
|
||||
redirect_uri=oauth_session.redirect_uri,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
)
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@@ -747,27 +884,7 @@ class MCPManager:
|
||||
if not oauth_session:
|
||||
return None
|
||||
|
||||
return MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
server_url=oauth_session.server_url,
|
||||
server_name=oauth_session.server_name,
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
authorization_url=oauth_session.authorization_url,
|
||||
authorization_code=oauth_session.authorization_code,
|
||||
access_token=oauth_session.access_token,
|
||||
refresh_token=oauth_session.refresh_token,
|
||||
token_type=oauth_session.token_type,
|
||||
expires_at=oauth_session.expires_at,
|
||||
scope=oauth_session.scope,
|
||||
client_id=oauth_session.client_id,
|
||||
client_secret=oauth_session.client_secret,
|
||||
redirect_uri=oauth_session.redirect_uri,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
)
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
@@ -780,10 +897,59 @@ class MCPManager:
|
||||
oauth_session.authorization_url = session_update.authorization_url
|
||||
if session_update.authorization_code is not None:
|
||||
oauth_session.authorization_code = session_update.authorization_code
|
||||
|
||||
# Handle encryption for access_token
|
||||
if session_update.access_token is not None:
|
||||
oauth_session.access_token = session_update.access_token
|
||||
# Check for encryption key from env or settings
|
||||
import os
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key
|
||||
|
||||
if encryption_key:
|
||||
# Temporarily set the key for Secret
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = encryption_key
|
||||
try:
|
||||
token_secret = Secret.from_plaintext(session_update.access_token)
|
||||
secret_dict = token_secret.to_dict()
|
||||
oauth_session.access_token_enc = secret_dict["encrypted"]
|
||||
# During migration phase, also update plaintext
|
||||
oauth_session.access_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
else:
|
||||
# No encryption, store plaintext
|
||||
oauth_session.access_token = session_update.access_token
|
||||
oauth_session.access_token_enc = None
|
||||
|
||||
# Handle encryption for refresh_token
|
||||
if session_update.refresh_token is not None:
|
||||
oauth_session.refresh_token = session_update.refresh_token
|
||||
# Check for encryption key from env or settings
|
||||
import os
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key
|
||||
|
||||
if encryption_key:
|
||||
# Temporarily set the key for Secret
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = encryption_key
|
||||
try:
|
||||
token_secret = Secret.from_plaintext(session_update.refresh_token)
|
||||
secret_dict = token_secret.to_dict()
|
||||
oauth_session.refresh_token_enc = secret_dict["encrypted"]
|
||||
# During migration phase, also update plaintext
|
||||
oauth_session.refresh_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
else:
|
||||
# No encryption, store plaintext
|
||||
oauth_session.refresh_token = session_update.refresh_token
|
||||
oauth_session.refresh_token_enc = None
|
||||
|
||||
if session_update.token_type is not None:
|
||||
oauth_session.token_type = session_update.token_type
|
||||
if session_update.expires_at is not None:
|
||||
@@ -792,8 +958,33 @@ class MCPManager:
|
||||
oauth_session.scope = session_update.scope
|
||||
if session_update.client_id is not None:
|
||||
oauth_session.client_id = session_update.client_id
|
||||
|
||||
# Handle encryption for client_secret
|
||||
if session_update.client_secret is not None:
|
||||
oauth_session.client_secret = session_update.client_secret
|
||||
# Check for encryption key from env or settings
|
||||
import os
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key
|
||||
|
||||
if encryption_key:
|
||||
# Temporarily set the key for Secret
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = encryption_key
|
||||
try:
|
||||
secret_secret = Secret.from_plaintext(session_update.client_secret)
|
||||
secret_dict = secret_secret.to_dict()
|
||||
oauth_session.client_secret_enc = secret_dict["encrypted"]
|
||||
# During migration phase, also update plaintext
|
||||
oauth_session.client_secret = secret_dict["plaintext"] if not secret_secret._was_encrypted else None
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
else:
|
||||
# No encryption, store plaintext
|
||||
oauth_session.client_secret = session_update.client_secret
|
||||
oauth_session.client_secret_enc = None
|
||||
|
||||
if session_update.redirect_uri is not None:
|
||||
oauth_session.redirect_uri = session_update.redirect_uri
|
||||
if session_update.status is not None:
|
||||
@@ -804,27 +995,7 @@ class MCPManager:
|
||||
|
||||
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
|
||||
|
||||
return MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
server_url=oauth_session.server_url,
|
||||
server_name=oauth_session.server_name,
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
authorization_url=oauth_session.authorization_url,
|
||||
authorization_code=oauth_session.authorization_code,
|
||||
access_token=oauth_session.access_token,
|
||||
refresh_token=oauth_session.refresh_token,
|
||||
token_type=oauth_session.token_type,
|
||||
expires_at=oauth_session.expires_at,
|
||||
scope=oauth_session.scope,
|
||||
client_id=oauth_session.client_id,
|
||||
client_secret=oauth_session.client_secret,
|
||||
redirect_uri=oauth_session.redirect_uri,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
)
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:
|
||||
|
||||
232
tests/test_crypto_utils.py
Normal file
232
tests/test_crypto_utils.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.helpers.crypto_utils import CryptoUtils
|
||||
|
||||
|
||||
class TestCryptoUtils:
|
||||
"""Test suite for CryptoUtils encryption/decryption functionality."""
|
||||
|
||||
# Mock master keys for testing
|
||||
MOCK_KEY_1 = "test-master-key-1234567890abcdef"
|
||||
MOCK_KEY_2 = "another-test-key-fedcba0987654321"
|
||||
|
||||
def test_encrypt_decrypt_roundtrip(self):
|
||||
"""Test that encryption followed by decryption returns the original value."""
|
||||
test_cases = [
|
||||
"simple text",
|
||||
"text with special chars: !@#$%^&*()",
|
||||
"unicode text: 你好世界 🌍",
|
||||
"very long text " * 1000,
|
||||
'{"json": "data", "nested": {"key": "value"}}',
|
||||
"", # Empty string
|
||||
]
|
||||
|
||||
for plaintext in test_cases:
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
|
||||
assert encrypted != plaintext, f"Encryption failed for: {plaintext[:50]}"
|
||||
# Encrypted value is base64 encoded
|
||||
assert len(encrypted) > 0, "Encrypted value should not be empty"
|
||||
|
||||
decrypted = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1)
|
||||
assert decrypted == plaintext, f"Roundtrip failed for: {plaintext[:50]}"
|
||||
|
||||
def test_encrypt_with_different_keys(self):
|
||||
"""Test that different keys produce different ciphertexts."""
|
||||
plaintext = "sensitive data"
|
||||
|
||||
encrypted1 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
|
||||
encrypted2 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_2)
|
||||
|
||||
# Different keys should produce different ciphertexts
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# Each should decrypt correctly with its own key
|
||||
assert CryptoUtils.decrypt(encrypted1, self.MOCK_KEY_1) == plaintext
|
||||
assert CryptoUtils.decrypt(encrypted2, self.MOCK_KEY_2) == plaintext
|
||||
|
||||
def test_decrypt_with_wrong_key_fails(self):
|
||||
"""Test that decryption with wrong key raises an error."""
|
||||
plaintext = "secret message"
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
|
||||
|
||||
with pytest.raises(Exception): # Could be ValueError or cryptography exception
|
||||
CryptoUtils.decrypt(encrypted, self.MOCK_KEY_2)
|
||||
|
||||
def test_encrypt_none_value(self):
|
||||
"""Test handling of None values."""
|
||||
# Encrypt None should raise TypeError (None has no encode method)
|
||||
with pytest.raises((TypeError, AttributeError)):
|
||||
CryptoUtils.encrypt(None, self.MOCK_KEY_1)
|
||||
|
||||
def test_decrypt_none_value(self):
|
||||
"""Test that decrypting None raises an error."""
|
||||
with pytest.raises(ValueError):
|
||||
CryptoUtils.decrypt(None, self.MOCK_KEY_1)
|
||||
|
||||
def test_decrypt_empty_string(self):
|
||||
"""Test that decrypting empty string raises an error."""
|
||||
with pytest.raises(Exception): # base64 decode error
|
||||
CryptoUtils.decrypt("", self.MOCK_KEY_1)
|
||||
|
||||
def test_decrypt_plaintext_value(self):
|
||||
"""Test that decrypting non-encrypted value raises an error."""
|
||||
plaintext = "not encrypted"
|
||||
with pytest.raises(Exception): # Will fail base64 decode or decryption
|
||||
CryptoUtils.decrypt(plaintext, self.MOCK_KEY_1)
|
||||
|
||||
def test_encrypted_format_structure(self):
|
||||
"""Test the structure of encrypted values."""
|
||||
plaintext = "test data"
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
|
||||
|
||||
# Should be base64 encoded
|
||||
encrypted_data = encrypted
|
||||
|
||||
# Should be valid base64
|
||||
try:
|
||||
decoded = base64.b64decode(encrypted_data)
|
||||
assert len(decoded) > 0
|
||||
except Exception as e:
|
||||
pytest.fail(f"Invalid base64 encoding: {e}")
|
||||
|
||||
# Decoded data should contain salt, IV, tag, and ciphertext
|
||||
# Total should be at least SALT_SIZE + IV_SIZE + TAG_SIZE bytes
|
||||
min_size = CryptoUtils.SALT_SIZE + CryptoUtils.IV_SIZE + CryptoUtils.TAG_SIZE
|
||||
assert len(decoded) >= min_size
|
||||
|
||||
def test_deterministic_with_same_salt(self):
|
||||
"""Test that encryption is deterministic when using the same salt (for testing)."""
|
||||
plaintext = "deterministic test"
|
||||
|
||||
# Note: In production, each encryption generates a random salt
|
||||
# This test verifies the encryption mechanism itself
|
||||
encrypted1 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
|
||||
encrypted2 = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
|
||||
|
||||
# Due to random salt, these should be different
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# But both should decrypt to the same value
|
||||
assert CryptoUtils.decrypt(encrypted1, self.MOCK_KEY_1) == plaintext
|
||||
assert CryptoUtils.decrypt(encrypted2, self.MOCK_KEY_1) == plaintext
|
||||
|
||||
def test_encrypt_uses_env_key_when_none_provided(self):
|
||||
"""Test that encryption uses environment key when no key is provided."""
|
||||
from letta.settings import settings
|
||||
|
||||
# Mock the settings to have an encryption key
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = "env-test-key-123"
|
||||
|
||||
try:
|
||||
plaintext = "test with env key"
|
||||
|
||||
# Should use key from settings
|
||||
encrypted = CryptoUtils.encrypt(plaintext)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
# Should decrypt with same key
|
||||
decrypted = CryptoUtils.decrypt(encrypted)
|
||||
assert decrypted == plaintext
|
||||
finally:
|
||||
# Restore original key
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_encrypt_without_key_raises_error(self):
|
||||
"""Test that encryption without any key raises an error."""
|
||||
from letta.settings import settings
|
||||
|
||||
# Mock settings to have no encryption key
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = None
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError, match="No encryption key configured"):
|
||||
CryptoUtils.encrypt("test data")
|
||||
finally:
|
||||
# Restore original key
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_large_data_encryption(self):
|
||||
"""Test encryption of large data."""
|
||||
# Create 10MB of data
|
||||
large_data = "x" * (10 * 1024 * 1024)
|
||||
|
||||
encrypted = CryptoUtils.encrypt(large_data, self.MOCK_KEY_1)
|
||||
assert len(encrypted) > 0
|
||||
assert encrypted != large_data
|
||||
|
||||
decrypted = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1)
|
||||
assert decrypted == large_data
|
||||
|
||||
def test_json_data_encryption(self):
|
||||
"""Test encryption of JSON data."""
|
||||
json_data = {
|
||||
"user": "test_user",
|
||||
"token": "secret_token_123",
|
||||
"nested": {"api_key": "sk-1234567890", "headers": {"Authorization": "Bearer token"}},
|
||||
}
|
||||
|
||||
json_str = json.dumps(json_data)
|
||||
encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY_1)
|
||||
|
||||
decrypted_str = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1)
|
||||
decrypted_data = json.loads(decrypted_str)
|
||||
|
||||
assert decrypted_data == json_data
|
||||
|
||||
def test_invalid_encrypted_format(self):
|
||||
"""Test handling of invalid encrypted data format."""
|
||||
invalid_cases = [
|
||||
"invalid-base64!@#", # Invalid base64
|
||||
"dGVzdA==", # Valid base64 but too short for encrypted data
|
||||
]
|
||||
|
||||
for invalid in invalid_cases:
|
||||
with pytest.raises(Exception): # Could be various exceptions
|
||||
CryptoUtils.decrypt(invalid, self.MOCK_KEY_1)
|
||||
|
||||
def test_key_derivation_consistency(self):
|
||||
"""Test that key derivation is consistent."""
|
||||
plaintext = "test key derivation"
|
||||
|
||||
# Multiple encryptions with same key should work
|
||||
encrypted_values = []
|
||||
for _ in range(5):
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
|
||||
encrypted_values.append(encrypted)
|
||||
|
||||
# All should decrypt correctly
|
||||
for encrypted in encrypted_values:
|
||||
assert CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1) == plaintext
|
||||
|
||||
def test_special_characters_in_key(self):
|
||||
"""Test encryption with keys containing special characters."""
|
||||
special_key = "key-with-special-chars!@#$%^&*()_+"
|
||||
plaintext = "test data"
|
||||
|
||||
encrypted = CryptoUtils.encrypt(plaintext, special_key)
|
||||
decrypted = CryptoUtils.decrypt(encrypted, special_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
"""Test encryption of strings with various whitespace."""
|
||||
test_cases = [
|
||||
" leading spaces",
|
||||
"trailing spaces ",
|
||||
" both sides ",
|
||||
"multiple\n\nlines",
|
||||
"\ttabs\there\t",
|
||||
"mixed \t\n whitespace \r\n",
|
||||
]
|
||||
|
||||
for plaintext in test_cases:
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY_1)
|
||||
decrypted = CryptoUtils.decrypt(encrypted, self.MOCK_KEY_1)
|
||||
assert decrypted == plaintext, f"Whitespace handling failed for: {repr(plaintext)}"
|
||||
407
tests/test_mcp_encryption.py
Normal file
407
tests/test_mcp_encryption.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""
|
||||
Integration tests for MCP server and OAuth session encryption.
|
||||
Tests the end-to-end encryption functionality in the MCP manager.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.helpers.crypto_utils import CryptoUtils
|
||||
from letta.orm import MCPOAuth, MCPServer as ORMMCPServer
|
||||
from letta.schemas.mcp import (
|
||||
MCPOAuthSessionCreate,
|
||||
MCPServer as PydanticMCPServer,
|
||||
MCPServerType,
|
||||
SSEServerConfig,
|
||||
StdioServerConfig,
|
||||
)
|
||||
from letta.schemas.secret import Secret, SecretDict
|
||||
from letta.server.db import db_registry
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.mcp_manager import MCPManager
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
"""Fixture to create and return a SyncServer instance with MCP manager."""
|
||||
config = LettaConfig.load()
|
||||
config.save()
|
||||
server = SyncServer(init_with_default_org_and_user=False)
|
||||
return server
|
||||
|
||||
|
||||
class TestMCPServerEncryption:
|
||||
"""Test MCP server encryption functionality."""
|
||||
|
||||
MOCK_ENCRYPTION_KEY = "test-mcp-encryption-key-123456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_create_mcp_server_with_token_encryption(self, mock_get_client, server, default_user):
|
||||
"""Test that MCP server tokens are encrypted when stored."""
|
||||
# Mock the MCP client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools.return_value = []
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Create MCP server with token
|
||||
server_name = f"test_encrypted_server_{uuid4().hex[:8]}"
|
||||
token = "super-secret-api-token-12345"
|
||||
server_url = "https://api.example.com/mcp"
|
||||
|
||||
mcp_server = PydanticMCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url, token=token)
|
||||
|
||||
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
|
||||
|
||||
# Verify server was created
|
||||
assert created_server.server_name == server_name
|
||||
assert created_server.server_type == MCPServerType.SSE
|
||||
|
||||
# Check database directly to verify encryption
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
|
||||
db_server = result.scalar_one()
|
||||
|
||||
# Token should be encrypted in database
|
||||
assert db_server.token_enc is not None
|
||||
assert db_server.token_enc != token # Should not be plaintext
|
||||
|
||||
# Decrypt to verify correctness
|
||||
decrypted_token = CryptoUtils.decrypt(db_server.token_enc, self.MOCK_ENCRYPTION_KEY)
|
||||
assert decrypted_token == token
|
||||
|
||||
# Legacy plaintext column should be None (or empty for dual-write)
|
||||
# During migration phase, might store both
|
||||
if db_server.token:
|
||||
assert db_server.token == token # Dual-write phase
|
||||
|
||||
# Clean up
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_create_mcp_server_with_custom_headers_encryption(self, mock_get_client, server, default_user):
|
||||
"""Test that MCP server custom headers are encrypted when stored."""
|
||||
# Mock the MCP client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools.return_value = []
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
server_name = f"test_headers_server_{uuid4().hex[:8]}"
|
||||
custom_headers = {"Authorization": "Bearer secret-token-xyz", "X-API-Key": "api-key-123456", "X-Custom-Header": "custom-value"}
|
||||
server_url = "https://api.example.com/mcp"
|
||||
|
||||
mcp_server = PydanticMCPServer(
|
||||
server_name=server_name, server_type=MCPServerType.STREAMABLE_HTTP, server_url=server_url, custom_headers=custom_headers
|
||||
)
|
||||
|
||||
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
|
||||
|
||||
# Check database directly
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
|
||||
db_server = result.scalar_one()
|
||||
|
||||
# Custom headers should be encrypted as JSON
|
||||
assert db_server.custom_headers_enc is not None
|
||||
|
||||
# Decrypt and parse JSON
|
||||
decrypted_json = CryptoUtils.decrypt(db_server.custom_headers_enc, self.MOCK_ENCRYPTION_KEY)
|
||||
decrypted_headers = json.loads(decrypted_json)
|
||||
assert decrypted_headers == custom_headers
|
||||
|
||||
# Clean up
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
|
||||
async def test_retrieve_mcp_server_decrypts_values(self, server, default_user):
|
||||
"""Test that retrieving MCP server decrypts encrypted values."""
|
||||
# Manually insert encrypted server into database
|
||||
server_id = f"mcp_server-{uuid4().hex[:8]}"
|
||||
server_name = f"test_decrypt_server_{uuid4().hex[:8]}"
|
||||
plaintext_token = "decryption-test-token"
|
||||
encrypted_token = CryptoUtils.encrypt(plaintext_token, self.MOCK_ENCRYPTION_KEY)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
db_server = ORMMCPServer(
|
||||
id=server_id,
|
||||
server_name=server_name,
|
||||
server_type=MCPServerType.SSE.value,
|
||||
server_url="https://test.com",
|
||||
token_enc=encrypted_token,
|
||||
token=None, # No plaintext
|
||||
created_by_id=default_user.id,
|
||||
last_updated_by_id=default_user.id,
|
||||
organization_id=default_user.organization_id,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(db_server)
|
||||
await session.commit()
|
||||
|
||||
# Retrieve server directly by ID to avoid issues with other servers in DB
|
||||
test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user)
|
||||
|
||||
assert test_server is not None
|
||||
assert test_server.server_name == server_name
|
||||
# Token should be decrypted when accessed via the secret method
|
||||
# Ensure encryption key is available for decryption
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_ENCRYPTION_KEY
|
||||
try:
|
||||
token_secret = test_server.get_token_secret()
|
||||
assert token_secret.get_plaintext() == plaintext_token
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
# Clean up
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id))
|
||||
db_server = result.scalar_one()
|
||||
await session.delete(db_server)
|
||||
await session.commit()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(os.environ, {}, clear=True) # No encryption key
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_create_mcp_server_without_encryption_key(self, mock_get_client, server, default_user):
|
||||
"""Test that MCP servers work without encryption key (backward compatibility)."""
|
||||
# Remove encryption key
|
||||
os.environ.pop("LETTA_ENCRYPTION_KEY", None)
|
||||
|
||||
# Mock the MCP client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools.return_value = []
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
server_name = f"test_no_encrypt_server_{uuid4().hex[:8]}"
|
||||
token = "plaintext-token-no-encryption"
|
||||
|
||||
mcp_server = PydanticMCPServer(
|
||||
server_name=server_name, server_type=MCPServerType.SSE, server_url="https://api.example.com", token=token
|
||||
)
|
||||
|
||||
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
|
||||
|
||||
# Check database - should store as plaintext
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id))
|
||||
db_server = result.scalar_one()
|
||||
|
||||
# Should store in plaintext column
|
||||
assert db_server.token == token
|
||||
assert db_server.token_enc is None # No encryption
|
||||
|
||||
# Clean up
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
|
||||
class TestMCPOAuthEncryption:
|
||||
"""Test MCP OAuth session encryption functionality."""
|
||||
|
||||
MOCK_ENCRYPTION_KEY = "test-oauth-encryption-key-123456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
|
||||
async def test_create_oauth_session_with_encryption(self, server, default_user):
|
||||
"""Test that OAuth tokens are encrypted when stored."""
|
||||
server_url = "https://github.com/mcp"
|
||||
server_name = "GitHub MCP"
|
||||
|
||||
# Step 1: Create OAuth session (without tokens initially)
|
||||
oauth_session_create = MCPOAuthSessionCreate(
|
||||
server_url=server_url,
|
||||
server_name=server_name,
|
||||
organization_id=default_user.organization_id,
|
||||
user_id=default_user.id,
|
||||
)
|
||||
|
||||
created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user)
|
||||
|
||||
assert created_session.server_url == server_url
|
||||
assert created_session.server_name == server_name
|
||||
|
||||
# Step 2: Update session with tokens (simulating OAuth callback)
|
||||
from letta.schemas.mcp import MCPOAuthSessionUpdate
|
||||
|
||||
update_data = MCPOAuthSessionUpdate(
|
||||
access_token="github-access-token-abc123",
|
||||
refresh_token="github-refresh-token-xyz789",
|
||||
client_id="client-id-123",
|
||||
client_secret="client-secret-super-secret",
|
||||
expires_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
await server.mcp_manager.update_oauth_session(created_session.id, update_data, actor=default_user)
|
||||
|
||||
# Check database directly for encryption
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id))
|
||||
db_oauth = result.scalar_one()
|
||||
|
||||
# All sensitive fields should be encrypted
|
||||
assert db_oauth.access_token_enc is not None
|
||||
assert db_oauth.access_token_enc != update_data.access_token
|
||||
|
||||
assert db_oauth.refresh_token_enc is not None
|
||||
|
||||
assert db_oauth.client_secret_enc is not None
|
||||
|
||||
# Verify decryption
|
||||
decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc, self.MOCK_ENCRYPTION_KEY)
|
||||
assert decrypted_access == update_data.access_token
|
||||
|
||||
decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc, self.MOCK_ENCRYPTION_KEY)
|
||||
assert decrypted_refresh == update_data.refresh_token
|
||||
|
||||
decrypted_secret = CryptoUtils.decrypt(db_oauth.client_secret_enc, self.MOCK_ENCRYPTION_KEY)
|
||||
assert decrypted_secret == update_data.client_secret
|
||||
|
||||
# Clean up not needed - test database is reset
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
|
||||
async def test_retrieve_oauth_session_decrypts_tokens(self, server, default_user):
|
||||
"""Test that retrieving OAuth session decrypts tokens."""
|
||||
# Manually insert encrypted OAuth session
|
||||
session_id = f"mcp-oauth-{str(uuid4())[:8]}"
|
||||
access_token = "test-access-token"
|
||||
refresh_token = "test-refresh-token"
|
||||
client_secret = "test-client-secret"
|
||||
|
||||
encrypted_access = CryptoUtils.encrypt(access_token, self.MOCK_ENCRYPTION_KEY)
|
||||
encrypted_refresh = CryptoUtils.encrypt(refresh_token, self.MOCK_ENCRYPTION_KEY)
|
||||
encrypted_secret = CryptoUtils.encrypt(client_secret, self.MOCK_ENCRYPTION_KEY)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
db_oauth = MCPOAuth(
|
||||
id=session_id,
|
||||
state=f"test-state-{uuid4().hex[:8]}",
|
||||
server_url="https://test.com/mcp",
|
||||
server_name="Test Provider",
|
||||
access_token_enc=encrypted_access,
|
||||
refresh_token_enc=encrypted_refresh,
|
||||
client_id="test-client",
|
||||
client_secret_enc=encrypted_secret,
|
||||
user_id=default_user.id,
|
||||
organization_id=default_user.organization_id,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(db_oauth)
|
||||
await session.commit()
|
||||
|
||||
# Retrieve through manager by ID
|
||||
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
||||
assert test_session is not None
|
||||
|
||||
# Tokens should be decrypted
|
||||
assert test_session.access_token == access_token
|
||||
assert test_session.refresh_token == refresh_token
|
||||
assert test_session.client_secret == client_secret
|
||||
|
||||
# Clean up not needed - test database is reset
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
|
||||
async def test_update_oauth_session_maintains_encryption(self, server, default_user):
|
||||
"""Test that updating OAuth session maintains encryption."""
|
||||
# Create initial session (without tokens)
|
||||
from letta.schemas.mcp import MCPOAuthSessionUpdate
|
||||
|
||||
oauth_session_create = MCPOAuthSessionCreate(
|
||||
server_url="https://test.com/mcp",
|
||||
server_name="Test Update Provider",
|
||||
organization_id=default_user.organization_id,
|
||||
user_id=default_user.id,
|
||||
)
|
||||
|
||||
created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user)
|
||||
|
||||
# Add initial tokens
|
||||
initial_update = MCPOAuthSessionUpdate(
|
||||
access_token="initial-token",
|
||||
refresh_token="initial-refresh",
|
||||
client_id="client-123",
|
||||
client_secret="initial-secret",
|
||||
)
|
||||
|
||||
await server.mcp_manager.update_oauth_session(created_session.id, initial_update, actor=default_user)
|
||||
|
||||
# Update with new tokens
|
||||
new_access_token = "updated-access-token"
|
||||
new_refresh_token = "updated-refresh-token"
|
||||
|
||||
new_update = MCPOAuthSessionUpdate(
|
||||
access_token=new_access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
)
|
||||
|
||||
updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user)
|
||||
|
||||
# Verify update worked
|
||||
assert updated_session.access_token == new_access_token
|
||||
assert updated_session.refresh_token == new_refresh_token
|
||||
|
||||
# Check database encryption
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id))
|
||||
db_oauth = result.scalar_one()
|
||||
|
||||
# New tokens should be encrypted
|
||||
decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc, self.MOCK_ENCRYPTION_KEY)
|
||||
assert decrypted_access == new_access_token
|
||||
|
||||
decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc, self.MOCK_ENCRYPTION_KEY)
|
||||
assert decrypted_refresh == new_refresh_token
|
||||
|
||||
# Clean up not needed - test database is reset
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY})
|
||||
async def test_dual_read_backward_compatibility(self, server, default_user):
|
||||
"""Test that system can read both encrypted and plaintext values (migration support)."""
|
||||
# Insert a record with both encrypted and plaintext values
|
||||
session_id = f"mcp-oauth-{str(uuid4())[:8]}"
|
||||
plaintext_token = "legacy-plaintext-token"
|
||||
new_encrypted_token = "new-encrypted-token"
|
||||
encrypted_new = CryptoUtils.encrypt(new_encrypted_token, self.MOCK_ENCRYPTION_KEY)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
db_oauth = MCPOAuth(
|
||||
id=session_id,
|
||||
state=f"dual-read-state-{uuid4().hex[:8]}",
|
||||
server_url="https://test.com/mcp",
|
||||
server_name="Dual Read Test",
|
||||
# Both encrypted and plaintext values
|
||||
access_token=plaintext_token, # Legacy plaintext
|
||||
access_token_enc=encrypted_new, # New encrypted
|
||||
client_id="test-client",
|
||||
user_id=default_user.id,
|
||||
organization_id=default_user.organization_id,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(db_oauth)
|
||||
await session.commit()
|
||||
|
||||
# Retrieve through manager
|
||||
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
||||
assert test_session is not None
|
||||
|
||||
# Should prefer encrypted value over plaintext
|
||||
assert test_session.access_token == new_encrypted_token
|
||||
|
||||
# Clean up not needed - test database is reset
|
||||
373
tests/test_secret.py
Normal file
373
tests/test_secret.py
Normal file
@@ -0,0 +1,373 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.helpers.crypto_utils import CryptoUtils
|
||||
from letta.schemas.secret import Secret, SecretDict
|
||||
|
||||
|
||||
class TestSecret:
|
||||
"""Test suite for Secret wrapper class."""
|
||||
|
||||
MOCK_KEY = "test-secret-key-1234567890"
|
||||
|
||||
def test_from_plaintext_with_key(self):
|
||||
"""Test creating a Secret from plaintext value with encryption key."""
|
||||
from letta.settings import settings
|
||||
|
||||
# Set encryption key
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "my-secret-value"
|
||||
|
||||
secret = Secret.from_plaintext(plaintext)
|
||||
|
||||
# Should store encrypted value
|
||||
assert secret._encrypted_value is not None
|
||||
assert secret._encrypted_value != plaintext
|
||||
assert secret._was_encrypted is False
|
||||
|
||||
# Should decrypt to original value
|
||||
assert secret.get_plaintext() == plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_plaintext_without_key(self):
|
||||
"""Test creating a Secret from plaintext without encryption key."""
|
||||
from letta.settings import settings
|
||||
|
||||
# Clear encryption key
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = None
|
||||
|
||||
try:
|
||||
plaintext = "my-plaintext-value"
|
||||
|
||||
# Should raise error when trying to encrypt without key
|
||||
with pytest.raises(ValueError):
|
||||
Secret.from_plaintext(plaintext)
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_plaintext_with_none(self):
|
||||
"""Test creating a Secret from None value."""
|
||||
secret = Secret.from_plaintext(None)
|
||||
|
||||
assert secret._encrypted_value is None
|
||||
assert secret._was_encrypted is False
|
||||
assert secret.get_plaintext() is None
|
||||
assert secret.is_empty() is True
|
||||
|
||||
def test_from_encrypted(self):
|
||||
"""Test creating a Secret from already encrypted value."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "database-secret"
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
|
||||
|
||||
secret = Secret.from_encrypted(encrypted)
|
||||
|
||||
assert secret._encrypted_value == encrypted
|
||||
assert secret._was_encrypted is True
|
||||
assert secret.get_plaintext() == plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_with_encrypted_value(self):
|
||||
"""Test creating a Secret from database with encrypted value."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "database-secret"
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
|
||||
|
||||
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=None)
|
||||
|
||||
assert secret._encrypted_value == encrypted
|
||||
assert secret._was_encrypted is True
|
||||
assert secret.get_plaintext() == plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_with_plaintext_value(self):
|
||||
"""Test creating a Secret from database with plaintext value (backward compatibility)."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "legacy-plaintext"
|
||||
|
||||
# When only plaintext is provided, should encrypt it
|
||||
secret = Secret.from_db(encrypted_value=None, plaintext_value=plaintext)
|
||||
|
||||
# Should encrypt the plaintext
|
||||
assert secret._encrypted_value is not None
|
||||
assert secret._was_encrypted is False
|
||||
assert secret.get_plaintext() == plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_dual_read(self):
|
||||
"""Test dual read functionality - prefer encrypted over plaintext."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "correct-value"
|
||||
old_plaintext = "old-legacy-value"
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
|
||||
|
||||
# When both values exist, should prefer encrypted
|
||||
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=old_plaintext)
|
||||
|
||||
assert secret.get_plaintext() == plaintext # Should use encrypted value, not plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_get_encrypted(self):
|
||||
"""Test getting the encrypted value for database storage."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "test-encryption"
|
||||
|
||||
secret = Secret.from_plaintext(plaintext)
|
||||
encrypted_value = secret.get_encrypted()
|
||||
|
||||
assert encrypted_value is not None
|
||||
|
||||
# Should decrypt back to original
|
||||
decrypted = CryptoUtils.decrypt(encrypted_value, self.MOCK_KEY)
|
||||
assert decrypted == plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_is_empty(self):
|
||||
"""Test checking if secret is empty."""
|
||||
# Empty secret
|
||||
empty_secret = Secret.from_plaintext(None)
|
||||
assert empty_secret.is_empty() is True
|
||||
|
||||
# Non-empty secret
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
non_empty_secret = Secret.from_plaintext("value")
|
||||
assert non_empty_secret.is_empty() is False
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_string_representation(self):
|
||||
"""Test that string representation doesn't expose secret."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
secret = Secret.from_plaintext("sensitive-data")
|
||||
|
||||
# String representation should not contain the actual value
|
||||
str_repr = str(secret)
|
||||
assert "sensitive-data" not in str_repr
|
||||
assert "****" in str_repr
|
||||
|
||||
# Empty secret
|
||||
empty_secret = Secret.from_plaintext(None)
|
||||
assert "empty" in str(empty_secret)
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_equality(self):
|
||||
"""Test comparing two secrets."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "same-value"
|
||||
|
||||
secret1 = Secret.from_plaintext(plaintext)
|
||||
secret2 = Secret.from_plaintext(plaintext)
|
||||
|
||||
# Should be equal based on plaintext value
|
||||
assert secret1 == secret2
|
||||
|
||||
# Different values should not be equal
|
||||
secret3 = Secret.from_plaintext("different-value")
|
||||
assert secret1 != secret3
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
|
||||
class TestSecretDict:
|
||||
"""Test suite for SecretDict wrapper class."""
|
||||
|
||||
MOCK_KEY = "test-secretdict-key-1234567890"
|
||||
|
||||
def test_from_plaintext_dict(self):
|
||||
"""Test creating a SecretDict from plaintext dictionary."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext_dict = {"api_key": "sk-1234567890", "api_secret": "secret-value", "nested": {"token": "bearer-token"}}
|
||||
|
||||
secret_dict = SecretDict.from_plaintext(plaintext_dict)
|
||||
|
||||
# Should store encrypted JSON
|
||||
assert secret_dict._encrypted_value is not None
|
||||
|
||||
# Should decrypt to original dict
|
||||
assert secret_dict.get_plaintext() == plaintext_dict
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_plaintext_none(self):
|
||||
"""Test creating a SecretDict from None value."""
|
||||
secret_dict = SecretDict.from_plaintext(None)
|
||||
|
||||
assert secret_dict._encrypted_value is None
|
||||
assert secret_dict.get_plaintext() is None
|
||||
|
||||
def test_from_encrypted_with_json(self):
|
||||
"""Test creating a SecretDict from encrypted JSON."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext_dict = {"header1": "value1", "Authorization": "Bearer token123"}
|
||||
|
||||
json_str = json.dumps(plaintext_dict)
|
||||
encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY)
|
||||
|
||||
secret_dict = SecretDict.from_encrypted(encrypted)
|
||||
|
||||
assert secret_dict._encrypted_value == encrypted
|
||||
assert secret_dict.get_plaintext() == plaintext_dict
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_with_encrypted(self):
|
||||
"""Test creating SecretDict from database with encrypted value."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext_dict = {"key": "value"}
|
||||
json_str = json.dumps(plaintext_dict)
|
||||
encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY)
|
||||
|
||||
secret_dict = SecretDict.from_db(encrypted_value=encrypted, plaintext_value=None)
|
||||
|
||||
assert secret_dict.get_plaintext() == plaintext_dict
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_with_plaintext_json(self):
|
||||
"""Test creating SecretDict from database with plaintext JSON (backward compatibility)."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext_dict = {"legacy": "headers"}
|
||||
|
||||
# from_db expects a Dict, not a JSON string
|
||||
secret_dict = SecretDict.from_db(encrypted_value=None, plaintext_value=plaintext_dict)
|
||||
|
||||
assert secret_dict.get_plaintext() == plaintext_dict
|
||||
# Should have encrypted it
|
||||
assert secret_dict._encrypted_value is not None
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_complex_nested_structure(self):
|
||||
"""Test SecretDict with complex nested structures."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
complex_dict = {
|
||||
"level1": {"level2": {"level3": ["item1", "item2"], "secret": "nested-secret"}, "array": [1, 2, {"nested": "value"}]},
|
||||
"simple": "value",
|
||||
"number": 42,
|
||||
"boolean": True,
|
||||
"null": None,
|
||||
}
|
||||
|
||||
secret_dict = SecretDict.from_plaintext(complex_dict)
|
||||
decrypted = secret_dict.get_plaintext()
|
||||
|
||||
assert decrypted == complex_dict
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_empty_dict(self):
|
||||
"""Test SecretDict with empty dictionary."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
empty_dict = {}
|
||||
|
||||
secret_dict = SecretDict.from_plaintext(empty_dict)
|
||||
assert secret_dict.get_plaintext() == empty_dict
|
||||
|
||||
# Encrypted value should still be created
|
||||
encrypted = secret_dict.get_encrypted()
|
||||
assert encrypted is not None
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_dual_read_prefer_encrypted(self):
|
||||
"""Test that SecretDict prefers encrypted value over plaintext when both exist."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
new_dict = {"current": "value"}
|
||||
old_dict = {"legacy": "value"}
|
||||
|
||||
encrypted = CryptoUtils.encrypt(json.dumps(new_dict), self.MOCK_KEY)
|
||||
plaintext = json.dumps(old_dict)
|
||||
|
||||
secret_dict = SecretDict.from_db(encrypted_value=encrypted, plaintext_value=plaintext)
|
||||
|
||||
# Should use encrypted value, not plaintext
|
||||
assert secret_dict.get_plaintext() == new_dict
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
Reference in New Issue
Block a user