* wait, I forgot to commit locally * cp the entire core directory and then rm the .git subdir
301 lines
12 KiB
Python
301 lines
12 KiB
Python
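"""mcp encrypted data migration

Backfills the ``*_enc`` columns of ``mcp_oauth`` and ``mcp_server`` by
encrypting existing plaintext secrets with ``LETTA_ENCRYPTION_KEY``.

Revision ID: eff256d296cb
Revises: 7f7933666957
Create Date: 2025-09-16 16:01:58.943318

"""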
|
|
|
|
import json
|
|
import os
|
|
|
|
# NOTE: no sys.path manipulation happens here; CryptoUtils is imported directly from letta.helpers.crypto_utils below.
|
|
from typing import Sequence, Union
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy import JSON, String, Text
|
|
from sqlalchemy.sql import column, table
|
|
|
|
from alembic import op
|
|
from letta.helpers.crypto_utils import CryptoUtils
|
|
|
|
# Alembic revision metadata: this migration's id and the revision it builds on.
revision: str = "eff256d296cb"
down_revision: Union[str, None] = "7f7933666957"

# Not part of a named branch and no cross-branch dependencies.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
|
|
|
|
|
def upgrade() -> None:
    """Encrypt existing MCP secrets into their ``*_enc`` columns.

    Reads ``LETTA_ENCRYPTION_KEY`` from the environment. If it is unset or
    empty, prints a warning and returns without touching any rows.

    Otherwise, for ``mcp_oauth`` (``access_token``, ``refresh_token``,
    ``client_secret``) and ``mcp_server`` (``token``, ``custom_headers``),
    every non-empty plaintext value whose ``*_enc`` column is still NULL is
    encrypted with ``CryptoUtils.encrypt`` and written to the matching
    ``*_enc`` column. ``custom_headers`` is JSON-serialized with
    ``json.dumps`` before encryption. Plaintext columns are never modified.

    A failure to encrypt a single field is printed as a warning and that
    field is left unencrypted; the migration carries on with the rest.
    """
    encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
    if not encryption_key:
        print("WARNING: LETTA_ENCRYPTION_KEY not set. Skipping data encryption migration.")
        print("You can run a separate migration script later to encrypt existing data.")
        return

    connection = op.get_bind()
    batch_size = 1000  # rows fetched per SELECT

    print("Migrating mcp_oauth encrypted fields...")
    mcp_oauth = table(
        "mcp_oauth",
        column("id", String),
        column("access_token", Text),
        column("access_token_enc", Text),
        column("refresh_token", Text),
        column("refresh_token_enc", Text),
        column("client_secret", Text),
        column("client_secret_enc", Text),
    )
    _encrypt_plaintext_columns(
        connection,
        mcp_oauth,
        "mcp_oauth",
        [
            ("access_token", "access_token_enc", None),
            ("refresh_token", "refresh_token_enc", None),
            ("client_secret", "client_secret_enc", None),
        ],
        encryption_key,
        batch_size,
    )

    print("Migrating mcp_server encrypted fields...")
    mcp_server = table(
        "mcp_server",
        column("id", String),
        column("token", String),
        column("token_enc", Text),
        column("custom_headers", JSON),
        column("custom_headers_enc", Text),
    )
    _encrypt_plaintext_columns(
        connection,
        mcp_server,
        "mcp_server",
        [
            ("token", "token_enc", None),
            # custom_headers is a JSON column: serialize it to a string before encrypting.
            ("custom_headers", "custom_headers_enc", json.dumps),
        ],
        encryption_key,
        batch_size,
    )

    print("Migration complete. Plaintext columns are retained for rollback safety.")


def _needs_encryption_clause(tbl, fields):
    """Build a WHERE clause matching rows that have at least one field to encrypt.

    A field needs encryption when its plaintext column IS NOT NULL and its
    ``*_enc`` column IS NULL. ``fields`` holds ``(plain, enc, serializer)``
    triples; only the two column names are used here.
    """
    return sa.or_(*(sa.and_(tbl.c[plain].isnot(None), tbl.c[enc].is_(None)) for plain, enc, _ in fields))


def _encrypt_plaintext_columns(connection, tbl, label: str, fields, encryption_key: str, batch_size: int) -> None:
    """Backfill the ``*_enc`` columns of ``tbl`` in id-ordered batches.

    Args:
        connection: The connection returned by ``op.get_bind()``.
        tbl: A lightweight ``table()`` with an ``id`` column plus every column
            named in ``fields``.
        label: Table name used in progress and warning messages.
        fields: ``(plain_column, enc_column, serializer)`` triples. When
            ``serializer`` is not None it converts the plaintext value into the
            string passed to ``CryptoUtils.encrypt``; otherwise the value is
            passed through unchanged.
        encryption_key: Key forwarded to ``CryptoUtils.encrypt``.
        batch_size: Maximum number of rows fetched per SELECT.
    """
    needs_encryption = _needs_encryption_clause(tbl, fields)

    total = connection.execute(sa.select(sa.func.count()).select_from(tbl).where(needs_encryption)).scalar()
    if not total:
        print(f"{label}: No records need encryption")
        return
    print(f"Found {total} {label} records that need encryption")

    selected_columns = [tbl.c.id]
    for plain, enc, _ in fields:
        selected_columns.extend((tbl.c[plain], tbl.c[enc]))

    encrypted_count = 0  # rows that received at least one new ciphertext
    skipped_count = 0  # fields that already had a ciphertext
    next_progress_report = 10000
    last_id = None

    # Deliberately no intermediate connection.commit(): Alembic normally runs
    # upgrade() inside a transaction opened by its migration context (see
    # env.py) and records the new revision in it, so committing here would end
    # that transaction early. Re-running after an interruption is safe because
    # only rows whose *_enc column is still NULL are selected.
    while True:
        # Keyset pagination (id > last seen id) instead of OFFSET: rows updated
        # in the previous batch no longer match ``needs_encryption``, so moving
        # an OFFSET over the shrinking result set would skip unprocessed rows.
        # Advancing by id also guarantees progress past rows that failed to encrypt.
        query = sa.select(*selected_columns).where(needs_encryption)
        if last_id is not None:
            query = query.where(tbl.c.id > last_id)
        rows = connection.execute(query.order_by(tbl.c.id).limit(batch_size)).fetchall()
        if not rows:
            break
        last_id = rows[-1].id

        for row in rows:
            values = row._mapping
            updates = {}
            for plain, enc, serialize in fields:
                plaintext = values[plain]
                existing_ciphertext = values[enc]
                if plaintext and not existing_ciphertext:
                    try:
                        payload = serialize(plaintext) if serialize is not None else plaintext
                        updates[enc] = CryptoUtils.encrypt(payload, encryption_key)
                    except Exception as e:
                        # Best effort: report and leave this field for a later run.
                        print(f"Warning: Failed to encrypt {plain} for {label} id={row.id}: {e}")
                elif existing_ciphertext:
                    skipped_count += 1

            if updates:
                connection.execute(tbl.update().where(tbl.c.id == row.id).values(**updates))
                encrypted_count += 1

        # Report whenever another 10k-row threshold is crossed (batches do not
        # necessarily land on exact multiples).
        if encrypted_count >= next_progress_report:
            print(f" Progress: Encrypted {encrypted_count} {label} records...")
            next_progress_report = (encrypted_count // 10000 + 1) * 10000

    print(f"{label}: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")
|
|
|
|
|
|
def downgrade() -> None:
    """Do nothing on downgrade.

    ``upgrade`` only writes the ``*_enc`` columns and never alters the
    plaintext columns, so the plaintext data is still present and there is
    nothing to restore here.
    """
|