fix: asyncify encrypt on write (#8339)

* base

* update

* import
This commit is contained in:
jnjpng
2026-01-06 12:26:30 -08:00
committed by Caren Thomas
parent 350436c0cc
commit ccfd3d1432
7 changed files with 241 additions and 61 deletions

View File

@@ -5,6 +5,7 @@ from pydantic_core import core_schema
from letta.helpers.crypto_utils import CryptoUtils
from letta.log import get_logger
from letta.utils import bounded_gather
logger = get_logger(__name__)
@@ -67,6 +68,72 @@ class Secret(BaseModel):
return instance
raise # Re-raise if it's a different error
@classmethod
async def from_plaintext_async(cls, value: Optional[str]) -> "Secret":
"""
Create a Secret from a plaintext value, encrypting it asynchronously.
This async version runs encryption in a thread pool to avoid blocking
the event loop during the CPU-intensive PBKDF2 key derivation (100-500ms).
Use this method in all async contexts (FastAPI endpoints, async services, etc.)
to avoid blocking the event loop.
Args:
value: The plaintext value to encrypt
Returns:
A Secret instance with the encrypted (or plaintext) value
"""
if value is None:
return cls.model_construct(encrypted_value=None)
# Guard against double encryption - check if value is already encrypted
if CryptoUtils.is_encrypted(value):
logger.warning("Creating Secret from already-encrypted value. This can be dangerous.")
# Try to encrypt asynchronously, but fall back to storing plaintext if no encryption key
try:
encrypted = await CryptoUtils.encrypt_async(value)
return cls.model_construct(encrypted_value=encrypted)
except ValueError as e:
# No encryption key available, store as plaintext in the _enc column
if "No encryption key configured" in str(e):
logger.warning(
"No encryption key configured. Storing Secret value as plaintext in _enc column. "
"Set LETTA_ENCRYPTION_KEY environment variable to enable encryption."
)
instance = cls.model_construct(encrypted_value=value)
instance._plaintext_cache = value # Cache it since we know the plaintext
return instance
raise # Re-raise if it's a different error
@classmethod
async def from_plaintexts_async(cls, values: dict[str, str], max_concurrency: int = 10) -> dict[str, "Secret"]:
"""
Create multiple Secrets from plaintexts concurrently with bounded concurrency.
Uses bounded_gather() to encrypt values in parallel while limiting
concurrent operations to prevent overwhelming the event loop.
Args:
values: Dict of key -> plaintext value
max_concurrency: Maximum number of concurrent encryption operations (default: 10)
Returns:
Dict of key -> Secret
"""
if not values:
return {}
keys = list(values.keys())
async def encrypt_one(key: str) -> "Secret":
return await cls.from_plaintext_async(values[key])
secrets = await bounded_gather([encrypt_one(k) for k in keys], max_concurrency=max_concurrency)
return dict(zip(keys, secrets))
@classmethod
def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret":
"""

View File

@@ -556,19 +556,18 @@ class AgentManager:
agent_secrets = agent_create.secrets or agent_create.tool_exec_environment_variables
if agent_secrets:
# Encrypt environment variable values
env_rows = []
for key, val in agent_secrets.items():
# Encrypt value (Secret.from_plaintext handles missing encryption key internally)
value_secret = Secret.from_plaintext(val)
row = {
# Encrypt environment variable values concurrently (async to avoid blocking event loop)
secrets_dict = await Secret.from_plaintexts_async(agent_secrets)
env_rows = [
{
"agent_id": aid,
"key": key,
"value": "", # Empty string for NOT NULL constraint (deprecated, use value_enc)
"value_enc": value_secret.get_encrypted(),
"value_enc": secret.get_encrypted(),
"organization_id": actor.organization_id,
}
env_rows.append(row)
for key, secret in secrets_dict.items()
]
result = await session.execute(insert(AgentEnvironmentVariable).values(env_rows).returning(AgentEnvironmentVariable.id))
env_rows = [{**row, "id": env_var_id} for row, env_var_id in zip(env_rows, result.scalars().all())]
@@ -832,25 +831,35 @@ class AgentManager:
# TODO: do we need to delete each time or can we just upsert?
await session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
# Encrypt environment variable values
# Only re-encrypt if the value has actually changed
# Decrypt existing values to check for changes (async to avoid blocking)
existing_values: dict[str, str | None] = {}
for k, existing_env in existing_env_vars.items():
if existing_env.value_enc:
existing_secret = Secret.from_encrypted(existing_env.value_enc)
existing_values[k] = await existing_secret.get_plaintext_async()
else:
existing_values[k] = None
# Identify values that need encryption (new or changed)
to_encrypt = {
k: v
for k, v in agent_secrets.items()
if k not in existing_env_vars or existing_values.get(k) != v or not existing_env_vars[k].value_enc
}
# Batch encrypt new/changed values concurrently (async to avoid blocking event loop)
new_secrets = await Secret.from_plaintexts_async(to_encrypt) if to_encrypt else {}
# Build rows, reusing existing encrypted values where unchanged
env_rows = []
for k, v in agent_secrets.items():
# Check if value changed to avoid unnecessary re-encryption
existing_env = existing_env_vars.get(k)
existing_value = None
if existing_env and existing_env.value_enc:
existing_secret = Secret.from_encrypted(existing_env.value_enc)
existing_value = await existing_secret.get_plaintext_async()
# Encrypt value (reuse existing encrypted value if unchanged)
if existing_value == v and existing_env and existing_env.value_enc:
# Value unchanged, reuse existing encrypted value
value_enc = existing_env.value_enc
if k in new_secrets:
# New or changed value - use newly encrypted value
value_enc = new_secrets[k].get_encrypted()
else:
# Value changed or new, encrypt
value_secret = Secret.from_plaintext(v)
value_enc = value_secret.get_encrypted()
# Value unchanged - reuse existing encrypted value
value_enc = existing_env_vars[k].value_enc
row = {
"agent_id": aid,

View File

@@ -150,9 +150,10 @@ class MCPOAuthSession:
try:
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
# Encrypt the authorization_code and store only in _enc column
# Encrypt the authorization_code and store only in _enc column (async to avoid blocking event loop)
if code is not None:
oauth_record.authorization_code_enc = Secret.from_plaintext(code).get_encrypted()
code_secret = await Secret.from_plaintext_async(code)
oauth_record.authorization_code_enc = code_secret.get_encrypted()
oauth_record.status = OAuthSessionStatus.AUTHORIZED
oauth_record.state = state

View File

@@ -448,15 +448,15 @@ class MCPServerManager:
# Set the organization id at the ORM layer
pydantic_mcp_server.organization_id = actor.organization_id
# Explicitly populate encrypted fields
# Explicitly populate encrypted fields (async to avoid blocking event loop)
if pydantic_mcp_server.token is not None:
pydantic_mcp_server.token_enc = Secret.from_plaintext(pydantic_mcp_server.token)
pydantic_mcp_server.token_enc = await Secret.from_plaintext_async(pydantic_mcp_server.token)
if pydantic_mcp_server.custom_headers is not None:
# custom_headers is a Dict[str, str], serialize to JSON then encrypt
import json
json_str = json.dumps(pydantic_mcp_server.custom_headers)
pydantic_mcp_server.custom_headers_enc = Secret.from_plaintext(json_str)
pydantic_mcp_server.custom_headers_enc = await Secret.from_plaintext_async(json_str)
mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
@@ -517,15 +517,15 @@ class MCPServerManager:
server_type=server_config.type,
server_url=server_config.server_url,
)
# Encrypt sensitive fields
# Encrypt sensitive fields (async to avoid blocking event loop)
token = server_config.resolve_token()
if token:
token_secret = Secret.from_plaintext(token)
token_secret = await Secret.from_plaintext_async(token)
mcp_server.set_token_secret(token_secret)
if server_config.custom_headers:
# Convert dict to JSON string, then encrypt as Secret
headers_json = json.dumps(server_config.custom_headers)
headers_secret = Secret.from_plaintext(headers_json)
headers_secret = await Secret.from_plaintext_async(headers_json)
mcp_server.set_custom_headers_secret(headers_secret)
elif isinstance(server_config, StreamableHTTPServerConfig):
@@ -534,15 +534,15 @@ class MCPServerManager:
server_type=server_config.type,
server_url=server_config.server_url,
)
# Encrypt sensitive fields
# Encrypt sensitive fields (async to avoid blocking event loop)
token = server_config.resolve_token()
if token:
token_secret = Secret.from_plaintext(token)
token_secret = await Secret.from_plaintext_async(token)
mcp_server.set_token_secret(token_secret)
if server_config.custom_headers:
# Convert dict to JSON string, then encrypt as Secret
headers_json = json.dumps(server_config.custom_headers)
headers_secret = Secret.from_plaintext(headers_json)
headers_secret = await Secret.from_plaintext_async(headers_json)
mcp_server.set_custom_headers_secret(headers_secret)
else:
raise ValueError(f"Unsupported server config type: {type(server_config)}")
@@ -698,9 +698,10 @@ class MCPServerManager:
elif mcp_server.token:
existing_token = mcp_server.token
# Only re-encrypt if different
# Only re-encrypt if different (async to avoid blocking event loop)
if existing_token != update_data["token"]:
mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted()
token_secret = await Secret.from_plaintext_async(update_data["token"])
mcp_server.token_enc = token_secret.get_encrypted()
# Keep plaintext for dual-write during migration
mcp_server.token = update_data["token"]
@@ -725,9 +726,10 @@ class MCPServerManager:
elif mcp_server.custom_headers:
existing_headers_json = json.dumps(mcp_server.custom_headers)
# Only re-encrypt if different
# Only re-encrypt if different (async to avoid blocking event loop)
if existing_headers_json != json_str:
mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted()
headers_secret = await Secret.from_plaintext_async(json_str)
mcp_server.custom_headers_enc = headers_secret.get_encrypted()
# Keep plaintext for dual-write during migration
mcp_server.custom_headers = update_data["custom_headers"]
@@ -1117,9 +1119,10 @@ class MCPServerManager:
elif oauth_session.authorization_code:
existing_code = oauth_session.authorization_code
# Only re-encrypt if different
# Only re-encrypt if different (async to avoid blocking event loop)
if existing_code != session_update.authorization_code:
oauth_session.authorization_code_enc = Secret.from_plaintext(session_update.authorization_code).get_encrypted()
code_secret = await Secret.from_plaintext_async(session_update.authorization_code)
oauth_session.authorization_code_enc = code_secret.get_encrypted()
# Keep plaintext for dual-write during migration
oauth_session.authorization_code = session_update.authorization_code
@@ -1134,9 +1137,10 @@ class MCPServerManager:
elif oauth_session.access_token:
existing_token = oauth_session.access_token
# Only re-encrypt if different
# Only re-encrypt if different (async to avoid blocking event loop)
if existing_token != session_update.access_token:
oauth_session.access_token_enc = Secret.from_plaintext(session_update.access_token).get_encrypted()
token_secret = await Secret.from_plaintext_async(session_update.access_token)
oauth_session.access_token_enc = token_secret.get_encrypted()
# Keep plaintext for dual-write during migration
oauth_session.access_token = session_update.access_token
@@ -1151,9 +1155,10 @@ class MCPServerManager:
elif oauth_session.refresh_token:
existing_refresh = oauth_session.refresh_token
# Only re-encrypt if different
# Only re-encrypt if different (async to avoid blocking event loop)
if existing_refresh != session_update.refresh_token:
oauth_session.refresh_token_enc = Secret.from_plaintext(session_update.refresh_token).get_encrypted()
refresh_secret = await Secret.from_plaintext_async(session_update.refresh_token)
oauth_session.refresh_token_enc = refresh_secret.get_encrypted()
# Keep plaintext for dual-write during migration
oauth_session.refresh_token = session_update.refresh_token
@@ -1177,9 +1182,10 @@ class MCPServerManager:
elif oauth_session.client_secret:
existing_secret_val = oauth_session.client_secret
# Only re-encrypt if different
# Only re-encrypt if different (async to avoid blocking event loop)
if existing_secret_val != session_update.client_secret:
oauth_session.client_secret_enc = Secret.from_plaintext(session_update.client_secret).get_encrypted()
client_secret_encrypted = await Secret.from_plaintext_async(session_update.client_secret)
oauth_session.client_secret_enc = client_secret_encrypted.get_encrypted()
# Keep plaintext for dual-write during migration
oauth_session.client_secret = session_update.client_secret

View File

@@ -89,11 +89,13 @@ class ProviderManager:
deleted_provider.region = request.region
deleted_provider.api_version = request.api_version
# Update encrypted fields
# Update encrypted fields (async to avoid blocking event loop)
if request.api_key is not None:
deleted_provider.api_key_enc = Secret.from_plaintext(request.api_key).get_encrypted()
api_key_secret = await Secret.from_plaintext_async(request.api_key)
deleted_provider.api_key_enc = api_key_secret.get_encrypted()
if request.access_key is not None:
deleted_provider.access_key_enc = Secret.from_plaintext(request.access_key).get_encrypted()
access_key_secret = await Secret.from_plaintext_async(request.access_key)
deleted_provider.access_key_enc = access_key_secret.get_encrypted()
await deleted_provider.update_async(session, actor=actor)
provider_pydantic = deleted_provider.to_pydantic()
@@ -125,11 +127,11 @@ class ProviderManager:
# Lazily create the provider id prior to persistence
provider.resolve_identifier()
# Explicitly populate encrypted fields from plaintext
# Explicitly populate encrypted fields from plaintext (async to avoid blocking event loop)
if request.api_key is not None:
provider.api_key_enc = Secret.from_plaintext(request.api_key)
provider.api_key_enc = await Secret.from_plaintext_async(request.api_key)
if request.access_key is not None:
provider.access_key_enc = Secret.from_plaintext(request.access_key)
provider.access_key_enc = await Secret.from_plaintext_async(request.access_key)
new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
await new_provider.create_async(session, actor=actor)
@@ -164,9 +166,10 @@ class ProviderManager:
existing_secret = Secret.from_encrypted(existing_provider.api_key_enc)
existing_api_key = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
# Only re-encrypt if different (async to avoid blocking event loop)
if existing_api_key != update_data["api_key"]:
existing_provider.api_key_enc = Secret.from_plaintext(update_data["api_key"]).get_encrypted()
api_key_secret = await Secret.from_plaintext_async(update_data["api_key"])
existing_provider.api_key_enc = api_key_secret.get_encrypted()
# Remove from update_data since we set directly on existing_provider
update_data.pop("api_key", None)
@@ -181,9 +184,10 @@ class ProviderManager:
existing_secret = Secret.from_encrypted(existing_provider.access_key_enc)
existing_access_key = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
# Only re-encrypt if different (async to avoid blocking event loop)
if existing_access_key != update_data["access_key"]:
existing_provider.access_key_enc = Secret.from_plaintext(update_data["access_key"]).get_encrypted()
access_key_secret = await Secret.from_plaintext_async(update_data["access_key"])
existing_provider.access_key_enc = access_key_secret.get_encrypted()
# Remove from update_data since we set directly on existing_provider
update_data.pop("access_key", None)

View File

@@ -210,11 +210,11 @@ class SandboxConfigManager:
return db_env_var
else:
async with db_registry.async_session() as session:
# Encrypt the value before storing (only to value_enc, not plaintext)
# Encrypt the value before storing (async to avoid blocking event loop)
from letta.schemas.secret import Secret
if env_var.value:
env_var.value_enc = Secret.from_plaintext(env_var.value)
env_var.value_enc = await Secret.from_plaintext_async(env_var.value)
env_var.value = "" # Don't store plaintext, use empty string for NOT NULL constraint
env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True))
@@ -242,9 +242,10 @@ class SandboxConfigManager:
existing_secret = Secret.from_encrypted(env_var.value_enc)
existing_value = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
# Only re-encrypt if different (async to avoid blocking event loop)
if existing_value != update_data["value"]:
env_var.value_enc = Secret.from_plaintext(update_data["value"]).get_encrypted()
value_secret = await Secret.from_plaintext_async(update_data["value"])
env_var.value_enc = value_secret.get_encrypted()
# Don't store plaintext anymore
# Remove from update_data since we set directly on env_var

View File

@@ -216,3 +216,95 @@ class TestSecret:
assert mock_decrypt.call_count == 1
finally:
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_from_plaintext_async_with_key(self):
"""Test creating a Secret from plaintext value asynchronously with encryption key."""
from letta.settings import settings
# Set encryption key
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "my-async-secret-value"
secret = await Secret.from_plaintext_async(plaintext)
# Should store encrypted value
assert secret.encrypted_value is not None
assert secret.encrypted_value != plaintext
# Should decrypt to original value
result = await secret.get_plaintext_async()
assert result == plaintext
finally:
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_from_plaintext_async_without_key_stores_plaintext(self):
"""Test creating a Secret asynchronously without encryption key stores as plaintext."""
from letta.settings import settings
# Clear encryption key
original_key = settings.encryption_key
settings.encryption_key = None
try:
plaintext = "my-async-plaintext-value"
# Should store as plaintext in _enc column when no encryption key
secret = await Secret.from_plaintext_async(plaintext)
# Should store the plaintext value directly in encrypted_value
assert secret.encrypted_value == plaintext
result = await secret.get_plaintext_async()
assert result == plaintext
finally:
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_from_plaintext_async_with_none(self):
"""Test creating a Secret asynchronously from None value."""
secret = await Secret.from_plaintext_async(None)
assert secret.encrypted_value is None
result = await secret.get_plaintext_async()
assert result is None
assert secret.is_empty() is True
@pytest.mark.asyncio
async def test_from_plaintexts_async(self):
"""Test batch encrypting multiple secrets concurrently."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
values = {
"key1": "value1",
"key2": "value2",
"key3": "value3",
}
secrets = await Secret.from_plaintexts_async(values)
# Should return dict with same keys
assert set(secrets.keys()) == {"key1", "key2", "key3"}
# Each secret should decrypt to original value
for key, secret in secrets.items():
assert isinstance(secret, Secret)
assert secret.encrypted_value is not None
assert secret.encrypted_value != values[key]
result = await secret.get_plaintext_async()
assert result == values[key]
finally:
settings.encryption_key = original_key
@pytest.mark.asyncio
async def test_from_plaintexts_async_empty_dict(self):
"""Test batch encrypting with empty dict."""
secrets = await Secret.from_plaintexts_async({})
assert secrets == {}