feat: backfill providers, sandbox_environment_variables, mcp_oauth *_enc columns [LET-5458] (#5382)
* base * remove print --------- Co-authored-by: Letta Bot <noreply@letta.com>
This commit is contained in:
349
alembic/versions/8149a781ac1b_backfill_encrypted_columns_for_.py
Normal file
349
alembic/versions/8149a781ac1b_backfill_encrypted_columns_for_.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""backfill encrypted columns for providers, mcp, sandbox
|
||||
|
||||
Revision ID: 8149a781ac1b
|
||||
Revises: 066857381578
|
||||
Create Date: 2025-10-13 13:35:55.929562
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.sql import column, table
|
||||
|
||||
from alembic import op
|
||||
from letta.helpers.crypto_utils import CryptoUtils
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "8149a781ac1b"
|
||||
down_revision: Union[str, None] = "066857381578"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Check if encryption key is available
|
||||
encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY")
|
||||
if not encryption_key:
|
||||
print("WARNING: LETTA_ENCRYPTION_KEY not set. Skipping data encryption migration.")
|
||||
print("You can run a separate migration script later to encrypt existing data.")
|
||||
return
|
||||
|
||||
# Get database connection
|
||||
connection = op.get_bind()
|
||||
|
||||
# Batch processing configuration
|
||||
BATCH_SIZE = 1000 # Process 1000 rows at a time
|
||||
|
||||
# Migrate providers data
|
||||
print("Migrating providers encrypted fields...")
|
||||
providers = table(
|
||||
"providers",
|
||||
column("id", String),
|
||||
column("api_key", String),
|
||||
column("api_key_enc", Text),
|
||||
column("access_key", String),
|
||||
column("access_key_enc", Text),
|
||||
)
|
||||
|
||||
# Count total rows to process
|
||||
total_count_result = connection.execute(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(providers)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(providers.c.api_key.isnot(None), providers.c.access_key.isnot(None)),
|
||||
# Only count rows that need encryption
|
||||
sa.or_(
|
||||
sa.and_(providers.c.api_key.isnot(None), providers.c.api_key_enc.is_(None)),
|
||||
sa.and_(providers.c.access_key.isnot(None), providers.c.access_key_enc.is_(None)),
|
||||
),
|
||||
)
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if total_count_result and total_count_result > 0:
|
||||
print(f"Found {total_count_result} providers records that need encryption")
|
||||
|
||||
encrypted_count = 0
|
||||
skipped_count = 0
|
||||
offset = 0
|
||||
|
||||
# Process in batches
|
||||
while True:
|
||||
# Select batch of rows
|
||||
provider_rows = connection.execute(
|
||||
sa.select(
|
||||
providers.c.id,
|
||||
providers.c.api_key,
|
||||
providers.c.api_key_enc,
|
||||
providers.c.access_key,
|
||||
providers.c.access_key_enc,
|
||||
)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(providers.c.api_key.isnot(None), providers.c.access_key.isnot(None)),
|
||||
# Only select rows that need encryption
|
||||
sa.or_(
|
||||
sa.and_(providers.c.api_key.isnot(None), providers.c.api_key_enc.is_(None)),
|
||||
sa.and_(providers.c.access_key.isnot(None), providers.c.access_key_enc.is_(None)),
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(providers.c.id) # Ensure consistent ordering
|
||||
.limit(BATCH_SIZE)
|
||||
.offset(offset)
|
||||
).fetchall()
|
||||
|
||||
if not provider_rows:
|
||||
break # No more rows to process
|
||||
|
||||
# Prepare batch updates
|
||||
batch_updates = []
|
||||
|
||||
for row in provider_rows:
|
||||
updates = {"id": row.id}
|
||||
has_updates = False
|
||||
|
||||
# Encrypt api_key if present and not already encrypted
|
||||
if row.api_key and not row.api_key_enc:
|
||||
try:
|
||||
updates["api_key_enc"] = CryptoUtils.encrypt(row.api_key, encryption_key)
|
||||
has_updates = True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt api_key for providers id={row.id}: {e}")
|
||||
elif row.api_key_enc:
|
||||
skipped_count += 1
|
||||
|
||||
# Encrypt access_key if present and not already encrypted
|
||||
if row.access_key and not row.access_key_enc:
|
||||
try:
|
||||
updates["access_key_enc"] = CryptoUtils.encrypt(row.access_key, encryption_key)
|
||||
has_updates = True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt access_key for providers id={row.id}: {e}")
|
||||
elif row.access_key_enc:
|
||||
skipped_count += 1
|
||||
|
||||
if has_updates:
|
||||
batch_updates.append(updates)
|
||||
encrypted_count += 1
|
||||
|
||||
# Execute batch update if there are updates
|
||||
if batch_updates:
|
||||
# Use bulk update for better performance
|
||||
for update_data in batch_updates:
|
||||
row_id = update_data.pop("id")
|
||||
if update_data: # Only update if there are fields to update
|
||||
connection.execute(providers.update().where(providers.c.id == row_id).values(**update_data))
|
||||
|
||||
# Progress indicator for large datasets
|
||||
if encrypted_count > 0 and encrypted_count % 10000 == 0:
|
||||
print(f" Progress: Encrypted {encrypted_count} providers records...")
|
||||
|
||||
offset += BATCH_SIZE
|
||||
|
||||
# For very large datasets, commit periodically to avoid long transactions
|
||||
if encrypted_count > 0 and encrypted_count % 50000 == 0:
|
||||
connection.commit()
|
||||
|
||||
print(f"providers: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")
|
||||
else:
|
||||
print("providers: No records need encryption")
|
||||
|
||||
# Migrate sandbox_environment_variables data
|
||||
print("Migrating sandbox_environment_variables encrypted fields...")
|
||||
sandbox_environment_variables = table(
|
||||
"sandbox_environment_variables",
|
||||
column("id", String),
|
||||
column("value", String),
|
||||
column("value_enc", Text),
|
||||
)
|
||||
|
||||
# Count total rows to process
|
||||
total_count_result = connection.execute(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(sandbox_environment_variables)
|
||||
.where(
|
||||
sa.and_(
|
||||
sandbox_environment_variables.c.value.isnot(None),
|
||||
sandbox_environment_variables.c.value_enc.is_(None),
|
||||
)
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if total_count_result and total_count_result > 0:
|
||||
print(f"Found {total_count_result} sandbox_environment_variables records that need encryption")
|
||||
|
||||
encrypted_count = 0
|
||||
skipped_count = 0
|
||||
offset = 0
|
||||
|
||||
# Process in batches
|
||||
while True:
|
||||
# Select batch of rows
|
||||
env_var_rows = connection.execute(
|
||||
sa.select(
|
||||
sandbox_environment_variables.c.id,
|
||||
sandbox_environment_variables.c.value,
|
||||
sandbox_environment_variables.c.value_enc,
|
||||
)
|
||||
.where(
|
||||
sa.and_(
|
||||
sandbox_environment_variables.c.value.isnot(None),
|
||||
sandbox_environment_variables.c.value_enc.is_(None),
|
||||
)
|
||||
)
|
||||
.order_by(sandbox_environment_variables.c.id) # Ensure consistent ordering
|
||||
.limit(BATCH_SIZE)
|
||||
.offset(offset)
|
||||
).fetchall()
|
||||
|
||||
if not env_var_rows:
|
||||
break # No more rows to process
|
||||
|
||||
# Prepare batch updates
|
||||
batch_updates = []
|
||||
|
||||
for row in env_var_rows:
|
||||
updates = {"id": row.id}
|
||||
has_updates = False
|
||||
|
||||
# Encrypt value if present and not already encrypted
|
||||
if row.value and not row.value_enc:
|
||||
try:
|
||||
updates["value_enc"] = CryptoUtils.encrypt(row.value, encryption_key)
|
||||
has_updates = True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt value for sandbox_environment_variables id={row.id}: {e}")
|
||||
elif row.value_enc:
|
||||
skipped_count += 1
|
||||
|
||||
if has_updates:
|
||||
batch_updates.append(updates)
|
||||
encrypted_count += 1
|
||||
|
||||
# Execute batch update if there are updates
|
||||
if batch_updates:
|
||||
# Use bulk update for better performance
|
||||
for update_data in batch_updates:
|
||||
row_id = update_data.pop("id")
|
||||
if update_data: # Only update if there are fields to update
|
||||
connection.execute(
|
||||
sandbox_environment_variables.update().where(sandbox_environment_variables.c.id == row_id).values(**update_data)
|
||||
)
|
||||
|
||||
# Progress indicator for large datasets
|
||||
if encrypted_count > 0 and encrypted_count % 10000 == 0:
|
||||
print(f" Progress: Encrypted {encrypted_count} sandbox_environment_variables records...")
|
||||
|
||||
offset += BATCH_SIZE
|
||||
|
||||
# For very large datasets, commit periodically to avoid long transactions
|
||||
if encrypted_count > 0 and encrypted_count % 50000 == 0:
|
||||
connection.commit()
|
||||
|
||||
print(f"sandbox_environment_variables: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")
|
||||
else:
|
||||
print("sandbox_environment_variables: No records need encryption")
|
||||
|
||||
# Migrate mcp_oauth data (only authorization_code field)
|
||||
print("Migrating mcp_oauth encrypted fields...")
|
||||
mcp_oauth = table(
|
||||
"mcp_oauth",
|
||||
column("id", String),
|
||||
column("authorization_code", Text),
|
||||
column("authorization_code_enc", Text),
|
||||
)
|
||||
|
||||
# Count total rows to process
|
||||
total_count_result = connection.execute(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(mcp_oauth)
|
||||
.where(
|
||||
sa.and_(
|
||||
mcp_oauth.c.authorization_code.isnot(None),
|
||||
mcp_oauth.c.authorization_code_enc.is_(None),
|
||||
)
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if total_count_result and total_count_result > 0:
|
||||
print(f"Found {total_count_result} mcp_oauth records that need encryption")
|
||||
|
||||
encrypted_count = 0
|
||||
skipped_count = 0
|
||||
offset = 0
|
||||
|
||||
# Process in batches
|
||||
while True:
|
||||
# Select batch of rows
|
||||
oauth_rows = connection.execute(
|
||||
sa.select(
|
||||
mcp_oauth.c.id,
|
||||
mcp_oauth.c.authorization_code,
|
||||
mcp_oauth.c.authorization_code_enc,
|
||||
)
|
||||
.where(
|
||||
sa.and_(
|
||||
mcp_oauth.c.authorization_code.isnot(None),
|
||||
mcp_oauth.c.authorization_code_enc.is_(None),
|
||||
)
|
||||
)
|
||||
.order_by(mcp_oauth.c.id) # Ensure consistent ordering
|
||||
.limit(BATCH_SIZE)
|
||||
.offset(offset)
|
||||
).fetchall()
|
||||
|
||||
if not oauth_rows:
|
||||
break # No more rows to process
|
||||
|
||||
# Prepare batch updates
|
||||
batch_updates = []
|
||||
|
||||
for row in oauth_rows:
|
||||
updates = {"id": row.id}
|
||||
has_updates = False
|
||||
|
||||
# Encrypt authorization_code if present and not already encrypted
|
||||
if row.authorization_code and not row.authorization_code_enc:
|
||||
try:
|
||||
updates["authorization_code_enc"] = CryptoUtils.encrypt(row.authorization_code, encryption_key)
|
||||
has_updates = True
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt authorization_code for mcp_oauth id={row.id}: {e}")
|
||||
elif row.authorization_code_enc:
|
||||
skipped_count += 1
|
||||
|
||||
if has_updates:
|
||||
batch_updates.append(updates)
|
||||
encrypted_count += 1
|
||||
|
||||
# Execute batch update if there are updates
|
||||
if batch_updates:
|
||||
# Use bulk update for better performance
|
||||
for update_data in batch_updates:
|
||||
row_id = update_data.pop("id")
|
||||
if update_data: # Only update if there are fields to update
|
||||
connection.execute(mcp_oauth.update().where(mcp_oauth.c.id == row_id).values(**update_data))
|
||||
|
||||
# Progress indicator for large datasets
|
||||
if encrypted_count > 0 and encrypted_count % 10000 == 0:
|
||||
print(f" Progress: Encrypted {encrypted_count} mcp_oauth records...")
|
||||
|
||||
offset += BATCH_SIZE
|
||||
|
||||
# For very large datasets, commit periodically to avoid long transactions
|
||||
if encrypted_count > 0 and encrypted_count % 50000 == 0:
|
||||
connection.commit()
|
||||
|
||||
print(f"mcp_oauth: Encrypted {encrypted_count} records, skipped {skipped_count} already encrypted fields")
|
||||
else:
|
||||
print("mcp_oauth: No records need encryption")
|
||||
print("Migration complete. Plaintext columns are retained for rollback safety.")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
Reference in New Issue
Block a user