350 lines
13 KiB
Python
350 lines
13 KiB
Python
"""backfill encrypted columns for providers, mcp, sandbox

Revision ID: 8149a781ac1b
Revises: 066857381578
Create Date: 2025-10-13 13:35:55.929562

"""
|
|
|
|
import os
|
|
from typing import Sequence, Union
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy import String, Text
|
|
from sqlalchemy.sql import column, table
|
|
|
|
from alembic import op
|
|
from letta.helpers.crypto_utils import CryptoUtils
|
|
|
|
# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the migration it
# revises; both match the "Revision ID" / "Revises" lines in the module docstring.
revision: str = "8149a781ac1b"
down_revision: Union[str, None] = "066857381578"
# No branch labels and no cross-revision dependencies are declared here.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
|
|
|
|
|
def _needs_encryption_filter(tbl, field_pairs):
    """Build the WHERE clause selecting rows that still need a backfill.

    A row matches when, for at least one ``(plaintext, encrypted)`` column
    pair, the plaintext column is non-NULL and the encrypted column is NULL.

    Args:
        tbl: Lightweight ``sqlalchemy.sql.table(...)`` construct.
        field_pairs: Sequence of ``(plaintext_column, encrypted_column)`` names.

    Returns:
        A SQLAlchemy boolean clause element.
    """
    return sa.or_(
        *[
            sa.and_(tbl.c[plain].isnot(None), tbl.c[enc].is_(None))
            for plain, enc in field_pairs
        ]
    )


def _backfill_encrypted_columns(connection, tbl, field_pairs: Sequence, encryption_key, batch_size=1000):
    """Encrypt plaintext columns of one table into their ``*_enc`` counterparts.

    Rows are walked in ``id`` order using keyset pagination (``id > last_id``)
    instead of ``OFFSET``. The filter only matches rows that are not yet
    encrypted, so every successful UPDATE shrinks the result set; advancing an
    OFFSET over that shrinking set silently skips rows. Keyset pagination
    visits every candidate exactly once, and still moves past rows that stay
    pending (empty-string plaintext or a failed encryption).

    Args:
        connection: Connection returned by ``op.get_bind()``.
        tbl: Lightweight ``sqlalchemy.sql.table(...)`` construct exposing
            ``id`` plus every column named in ``field_pairs``.
        field_pairs: Sequence of ``(plaintext_column, encrypted_column)`` names.
        encryption_key: Key passed through to ``CryptoUtils.encrypt``.
        batch_size: Maximum number of rows fetched per SELECT.

    Returns:
        Number of rows that received at least one newly encrypted value.
    """
    table_name = tbl.name
    print(f"Migrating {table_name} encrypted fields...")

    needs_encryption = _needs_encryption_filter(tbl, field_pairs)

    # Count total rows to process
    pending = connection.execute(
        sa.select(sa.func.count()).select_from(tbl).where(needs_encryption)
    ).scalar()
    if not pending:
        print(f"{table_name}: No records need encryption")
        return 0

    print(f"Found {pending} {table_name} records that need encryption")

    selected = [tbl.c.id]
    for plain, enc in field_pairs:
        selected.extend((tbl.c[plain], tbl.c[enc]))

    encrypted_count = 0
    skipped_count = 0
    last_id = None  # keyset cursor; assumes `id` is unique per row - TODO confirm

    # Process in batches
    while True:
        query = sa.select(*selected).where(needs_encryption)
        if last_id is not None:
            query = query.where(tbl.c.id > last_id)
        rows = connection.execute(query.order_by(tbl.c.id).limit(batch_size)).fetchall()

        if not rows:
            break  # No more rows to process

        # Advance the cursor whether or not the rows below get updated, so a
        # row that cannot be encrypted is never fetched again.
        last_id = rows[-1].id

        batch_updates = []
        for row in rows:
            values = {}
            for plain, enc in field_pairs:
                plaintext = getattr(row, plain)
                ciphertext = getattr(row, enc)
                if plaintext and not ciphertext:
                    # Best effort: a failure on one field is reported and the
                    # migration carries on with the remaining fields/rows.
                    try:
                        values[enc] = CryptoUtils.encrypt(plaintext, encryption_key)
                    except Exception as e:
                        print(f"Warning: Failed to encrypt {plain} for {table_name} id={row.id}: {e}")
                elif ciphertext:
                    skipped_count += 1

            if values:
                batch_updates.append((row.id, values))
                encrypted_count += 1

        if batch_updates:
            for row_id, values in batch_updates:
                connection.execute(tbl.update().where(tbl.c.id == row_id).values(**values))

            # Progress indicator for large datasets
            if encrypted_count % 10000 == 0:
                print(f" Progress: Encrypted {encrypted_count} {table_name} records...")

        # For very large datasets, commit periodically to avoid long transactions
        # NOTE(review): kept from the original; committing inside a migration
        # interacts with Alembic's own transaction handling - confirm intended.
        if encrypted_count > 0 and encrypted_count % 50000 == 0:
            connection.commit()

    print(f"{table_name}: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")
    return encrypted_count


def upgrade() -> None:
    """Backfill the ``*_enc`` columns of providers, sandbox env vars and mcp_oauth.

    Reads the key from the ``LETTA_ENCRYPTION_KEY`` environment variable. When
    it is unset or empty the migration prints a warning and returns without
    touching the database. Otherwise each plaintext secret that has no
    encrypted counterpart yet is encrypted with ``CryptoUtils.encrypt`` and
    written to the matching ``*_enc`` column; the plaintext columns are left
    as they are.
    """
    # Check if encryption key is available
    encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
    if not encryption_key:
        print("WARNING: LETTA_ENCRYPTION_KEY not set. Skipping data encryption migration.")
        print("You can run a separate migration script later to encrypt existing data.")
        return

    # Get database connection
    connection = op.get_bind()

    # Batch processing configuration
    BATCH_SIZE = 1000  # Process 1000 rows at a time

    # Migrate providers data
    providers = table(
        "providers",
        column("id", String),
        column("api_key", String),
        column("api_key_enc", Text),
        column("access_key", String),
        column("access_key_enc", Text),
    )
    _backfill_encrypted_columns(
        connection,
        providers,
        [("api_key", "api_key_enc"), ("access_key", "access_key_enc")],
        encryption_key,
        BATCH_SIZE,
    )

    # Migrate sandbox_environment_variables data
    sandbox_environment_variables = table(
        "sandbox_environment_variables",
        column("id", String),
        column("value", String),
        column("value_enc", Text),
    )
    _backfill_encrypted_columns(
        connection,
        sandbox_environment_variables,
        [("value", "value_enc")],
        encryption_key,
        BATCH_SIZE,
    )

    # Migrate mcp_oauth data (only authorization_code field)
    mcp_oauth = table(
        "mcp_oauth",
        column("id", String),
        column("authorization_code", Text),
        column("authorization_code_enc", Text),
    )
    _backfill_encrypted_columns(
        connection,
        mcp_oauth,
        [("authorization_code", "authorization_code_enc")],
        encryption_key,
        BATCH_SIZE,
    )

    print("Migration complete. Plaintext columns are retained for rollback safety.")
|
|
|
|
|
|
def downgrade() -> None:
    """Do nothing: this revision defines no downgrade steps.

    NOTE(review): presumably safe because the upgrade only fills the
    ``*_enc`` columns and leaves the plaintext columns in place - confirm.
    """
    return None
|