Files
letta-server/alembic/versions/eff256d296cb_mcp_encrypted_data_migration.py
Kian Jones b8e9a80d93 merge this (#4759)
* wait I forgot to commit locally

* cp the entire core directory and then rm the .git subdir
2025-09-17 15:47:40 -07:00

301 lines
12 KiB
Python

"""mcp encrypted data migration
Revision ID: eff256d296cb
Revises: 7f7933666957
Create Date: 2025-09-16 16:01:58.943318
"""
import json
import os
# Add the app directory to path to import our crypto utils
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy import JSON, String, Text
from sqlalchemy.sql import column, table
from alembic import op
from letta.helpers.crypto_utils import CryptoUtils
# revision identifiers, used by Alembic.
# This migration applies on top of 7f7933666957 in the Alembic revision chain.
revision: str = "eff256d296cb"
down_revision: Union[str, None] = "7f7933666957"
branch_labels: Union[str, Sequence[str], None] = None  # no named branch
depends_on: Union[str, Sequence[str], None] = None  # no cross-branch dependency
def _encrypt_table_fields(connection, encryption_key, table_name, extra_columns, fields, batch_size):
    """Encrypt the plaintext columns of one table into their ``*_enc`` twins.

    Args:
        connection: live connection obtained from ``op.get_bind()``.
        encryption_key: key passed through to ``CryptoUtils.encrypt``.
        table_name: name of the table to migrate.
        extra_columns: ``column()`` constructs (beyond ``id``) for the
            lightweight ``table()`` used to build queries.
        fields: ``(plain_name, enc_name, serializer)`` triples. ``serializer``
            converts the plaintext value to ``str`` before encryption
            (``None`` = use as-is); used for JSON columns like custom_headers.
        batch_size: rows fetched per SELECT, bounding memory use.

    Rows are walked with keyset pagination (``WHERE id > last_seen ORDER BY
    id``) instead of OFFSET. The WHERE clause matches only rows that still
    need encryption, so every successful UPDATE removes a row from the
    filtered result set; combining that shrinking set with a growing OFFSET
    (as the original implementation did) silently skipped roughly every
    other batch. Keyset pagination visits each row exactly once regardless
    of in-flight updates, and still terminates when a row repeatedly fails
    to encrypt, because the cursor always advances past it.
    """
    tbl = table(table_name, column("id", String), *extra_columns)

    # A row needs work when any plaintext field is set but its _enc twin is not.
    needs_encryption = sa.or_(
        *[sa.and_(tbl.c[plain].isnot(None), tbl.c[enc].is_(None)) for plain, enc, _ in fields]
    )
    has_plaintext = sa.or_(*[tbl.c[plain].isnot(None) for plain, _, _ in fields])
    row_filter = sa.and_(has_plaintext, needs_encryption)

    total = connection.execute(
        sa.select(sa.func.count()).select_from(tbl).where(row_filter)
    ).scalar()
    if not total:
        print(f"{table_name}: No records need encryption")
        return

    print(f"Found {total} {table_name} records that need encryption")
    encrypted_count = 0
    skipped_count = 0
    last_id = None  # keyset cursor: id of the last row seen in the previous batch

    # Select id plus each (plaintext, encrypted) column pair.
    selected = [tbl.c.id]
    for plain, enc, _ in fields:
        selected.extend((tbl.c[plain], tbl.c[enc]))

    while True:
        query = sa.select(*selected).where(row_filter)
        if last_id is not None:
            query = query.where(tbl.c.id > last_id)
        rows = connection.execute(
            query.order_by(tbl.c.id).limit(batch_size)  # ordering makes the cursor total
        ).fetchall()
        if not rows:
            break  # no more rows to process
        last_id = rows[-1].id

        for row in rows:
            updates = {}
            for plain, enc, serializer in fields:
                plain_value = getattr(row, plain)
                if plain_value and not getattr(row, enc):
                    try:
                        text = serializer(plain_value) if serializer else plain_value
                        updates[enc] = CryptoUtils.encrypt(text, encryption_key)
                    except Exception as e:
                        # Best-effort: leave the plaintext in place and keep going.
                        print(f"Warning: Failed to encrypt {plain} for {table_name} id={row.id}: {e}")
                elif getattr(row, enc):
                    skipped_count += 1  # already encrypted by an earlier run

            if updates:
                connection.execute(tbl.update().where(tbl.c.id == row.id).values(**updates))
                encrypted_count += 1
                # Progress indicator for large datasets.
                if encrypted_count % 10000 == 0:
                    print(f"  Progress: Encrypted {encrypted_count} {table_name} records...")
                # Commit periodically so very large datasets do not hold one
                # giant transaction open for the entire migration.
                if encrypted_count % 50000 == 0:
                    connection.commit()

    print(f"{table_name}: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")


def upgrade() -> None:
    """Encrypt existing plaintext secrets into the new ``*_enc`` columns.

    Covers mcp_oauth (access_token, refresh_token, client_secret) and
    mcp_server (token, custom_headers). Requires LETTA_ENCRYPTION_KEY in the
    environment; without it the data migration is skipped (and can be run
    later) rather than failing. Plaintext columns are deliberately left
    intact so that downgrade() can remain a no-op.
    """
    encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
    if not encryption_key:
        print("WARNING: LETTA_ENCRYPTION_KEY not set. Skipping data encryption migration.")
        print("You can run a separate migration script later to encrypt existing data.")
        return

    connection = op.get_bind()
    BATCH_SIZE = 1000  # rows fetched per SELECT; bounds memory per batch

    print("Migrating mcp_oauth encrypted fields...")
    _encrypt_table_fields(
        connection,
        encryption_key,
        table_name="mcp_oauth",
        extra_columns=[
            column("access_token", Text),
            column("access_token_enc", Text),
            column("refresh_token", Text),
            column("refresh_token_enc", Text),
            column("client_secret", Text),
            column("client_secret_enc", Text),
        ],
        fields=[
            ("access_token", "access_token_enc", None),
            ("refresh_token", "refresh_token_enc", None),
            ("client_secret", "client_secret_enc", None),
        ],
        batch_size=BATCH_SIZE,
    )

    print("Migrating mcp_server encrypted fields...")
    _encrypt_table_fields(
        connection,
        encryption_key,
        table_name="mcp_server",
        extra_columns=[
            column("token", String),
            column("token_enc", Text),
            column("custom_headers", JSON),
            column("custom_headers_enc", Text),
        ],
        fields=[
            ("token", "token_enc", None),
            # custom_headers is a JSON column: serialize to a string first.
            ("custom_headers", "custom_headers_enc", json.dumps),
        ],
        batch_size=BATCH_SIZE,
    )

    print("Migration complete. Plaintext columns are retained for rollback safety.")
def downgrade() -> None:
    """Intentionally a no-op.

    upgrade() only populates the *_enc columns and leaves the original
    plaintext columns untouched, so there is nothing to restore on
    downgrade; dropping the encrypted values would lose no information.
    """
    pass