From ccfd3d1432af2d5e749deed2fe1c87bbcfe70f51 Mon Sep 17 00:00:00 2001 From: jnjpng Date: Tue, 6 Jan 2026 12:26:30 -0800 Subject: [PATCH] fix: asyncify encrypt on write (#8339) * base * update * import --- letta/schemas/secret.py | 67 +++++++++++++++++ letta/services/agent_manager.py | 57 ++++++++------- letta/services/mcp/oauth_utils.py | 5 +- letta/services/mcp_server_manager.py | 48 +++++++------ letta/services/provider_manager.py | 24 ++++--- letta/services/sandbox_config_manager.py | 9 +-- tests/test_secret.py | 92 ++++++++++++++++++++++++ 7 files changed, 241 insertions(+), 61 deletions(-) diff --git a/letta/schemas/secret.py b/letta/schemas/secret.py index 8868691d..b2714c20 100644 --- a/letta/schemas/secret.py +++ b/letta/schemas/secret.py @@ -5,6 +5,7 @@ from pydantic_core import core_schema from letta.helpers.crypto_utils import CryptoUtils from letta.log import get_logger +from letta.utils import bounded_gather logger = get_logger(__name__) @@ -67,6 +68,72 @@ class Secret(BaseModel): return instance raise # Re-raise if it's a different error + @classmethod + async def from_plaintext_async(cls, value: Optional[str]) -> "Secret": + """ + Create a Secret from a plaintext value, encrypting it asynchronously. + + This async version runs encryption in a thread pool to avoid blocking + the event loop during the CPU-intensive PBKDF2 key derivation (100-500ms). + + Use this method in all async contexts (FastAPI endpoints, async services, etc.) + to avoid blocking the event loop. + + Args: + value: The plaintext value to encrypt + + Returns: + A Secret instance with the encrypted (or plaintext) value + """ + if value is None: + return cls.model_construct(encrypted_value=None) + + # Guard against double encryption - check if value is already encrypted + if CryptoUtils.is_encrypted(value): + logger.warning("Creating Secret from already-encrypted value. This can be dangerous.") + + # Try to encrypt asynchronously, but fall back to storing plaintext if no encryption key + try: + encrypted = await CryptoUtils.encrypt_async(value) + return cls.model_construct(encrypted_value=encrypted) + except ValueError as e: + # No encryption key available, store as plaintext in the _enc column + if "No encryption key configured" in str(e): + logger.warning( + "No encryption key configured. Storing Secret value as plaintext in _enc column. " + "Set LETTA_ENCRYPTION_KEY environment variable to enable encryption." + ) + instance = cls.model_construct(encrypted_value=value) + instance._plaintext_cache = value # Cache it since we know the plaintext + return instance + raise # Re-raise if it's a different error + + @classmethod + async def from_plaintexts_async(cls, values: dict[str, str], max_concurrency: int = 10) -> dict[str, "Secret"]: + """ + Create multiple Secrets from plaintexts concurrently with bounded concurrency. + + Uses bounded_gather() to encrypt values in parallel while limiting + concurrent operations to prevent overwhelming the event loop. + + Args: + values: Dict of key -> plaintext value + max_concurrency: Maximum number of concurrent encryption operations (default: 10) + + Returns: + Dict of key -> Secret + """ + if not values: + return {} + + keys = list(values.keys()) + + async def encrypt_one(key: str) -> "Secret": + return await cls.from_plaintext_async(values[key]) + + secrets = await bounded_gather([encrypt_one(k) for k in keys], max_concurrency=max_concurrency) + return dict(zip(keys, secrets)) + @classmethod def from_encrypted(cls, encrypted_value: Optional[str]) -> "Secret": """ diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 5ddc3b89..d32541fd 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -556,19 +556,18 @@ class AgentManager: agent_secrets = agent_create.secrets or agent_create.tool_exec_environment_variables if agent_secrets: - # Encrypt environment variable values - env_rows = [] - for key, val in agent_secrets.items(): - # Encrypt value (Secret.from_plaintext handles missing encryption key internally) - value_secret = Secret.from_plaintext(val) - row = { + # Encrypt environment variable values concurrently (async to avoid blocking event loop) + secrets_dict = await Secret.from_plaintexts_async(agent_secrets) + env_rows = [ + { "agent_id": aid, "key": key, "value": "", # Empty string for NOT NULL constraint (deprecated, use value_enc) - "value_enc": value_secret.get_encrypted(), + "value_enc": secret.get_encrypted(), "organization_id": actor.organization_id, } - env_rows.append(row) + for key, secret in secrets_dict.items() + ] result = await session.execute(insert(AgentEnvironmentVariable).values(env_rows).returning(AgentEnvironmentVariable.id)) env_rows = [{**row, "id": env_var_id} for row, env_var_id in zip(env_rows, result.scalars().all())] @@ -832,25 +831,35 @@ class AgentManager: # TODO: do we need to delete each time or can we just upsert? await session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid)) - # Encrypt environment variable values - # Only re-encrypt if the value has actually changed + + # Decrypt existing values to check for changes (async to avoid blocking) + existing_values: dict[str, str | None] = {} + for k, existing_env in existing_env_vars.items(): + if existing_env.value_enc: + existing_secret = Secret.from_encrypted(existing_env.value_enc) + existing_values[k] = await existing_secret.get_plaintext_async() + else: + existing_values[k] = None + + # Identify values that need encryption (new or changed) + to_encrypt = { + k: v + for k, v in agent_secrets.items() + if k not in existing_env_vars or existing_values.get(k) != v or not existing_env_vars[k].value_enc + } + + # Batch encrypt new/changed values concurrently (async to avoid blocking event loop) + new_secrets = await Secret.from_plaintexts_async(to_encrypt) if to_encrypt else {} + + # Build rows, reusing existing encrypted values where unchanged env_rows = [] for k, v in agent_secrets.items(): - # Check if value changed to avoid unnecessary re-encryption - existing_env = existing_env_vars.get(k) - existing_value = None - if existing_env and existing_env.value_enc: - existing_secret = Secret.from_encrypted(existing_env.value_enc) - existing_value = await existing_secret.get_plaintext_async() - - # Encrypt value (reuse existing encrypted value if unchanged) - if existing_value == v and existing_env and existing_env.value_enc: - # Value unchanged, reuse existing encrypted value - value_enc = existing_env.value_enc + if k in new_secrets: + # New or changed value - use newly encrypted value + value_enc = new_secrets[k].get_encrypted() else: - # Value changed or new, encrypt - value_secret = Secret.from_plaintext(v) - value_enc = value_secret.get_encrypted() + # Value unchanged - reuse existing encrypted value + value_enc = existing_env_vars[k].value_enc row = { "agent_id": aid, diff --git a/letta/services/mcp/oauth_utils.py b/letta/services/mcp/oauth_utils.py index e4cbb992..b2c90d5c 100644 --- a/letta/services/mcp/oauth_utils.py +++ b/letta/services/mcp/oauth_utils.py @@ -150,9 +150,10 @@ class MCPOAuthSession: try: oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) - # Encrypt the authorization_code and store only in _enc column + # Encrypt the authorization_code and store only in _enc column (async to avoid blocking event loop) if code is not None: - oauth_record.authorization_code_enc = Secret.from_plaintext(code).get_encrypted() + code_secret = await Secret.from_plaintext_async(code) + oauth_record.authorization_code_enc = code_secret.get_encrypted() oauth_record.status = OAuthSessionStatus.AUTHORIZED oauth_record.state = state diff --git a/letta/services/mcp_server_manager.py b/letta/services/mcp_server_manager.py index a81c1ec0..f05caf80 100644 --- a/letta/services/mcp_server_manager.py +++ b/letta/services/mcp_server_manager.py @@ -448,15 +448,15 @@ class MCPServerManager: # Set the organization id at the ORM layer pydantic_mcp_server.organization_id = actor.organization_id - # Explicitly populate encrypted fields + # Explicitly populate encrypted fields (async to avoid blocking event loop) if pydantic_mcp_server.token is not None: - pydantic_mcp_server.token_enc = Secret.from_plaintext(pydantic_mcp_server.token) + pydantic_mcp_server.token_enc = await Secret.from_plaintext_async(pydantic_mcp_server.token) if pydantic_mcp_server.custom_headers is not None: # custom_headers is a Dict[str, str], serialize to JSON then encrypt import json json_str = json.dumps(pydantic_mcp_server.custom_headers) - pydantic_mcp_server.custom_headers_enc = Secret.from_plaintext(json_str) + pydantic_mcp_server.custom_headers_enc = await Secret.from_plaintext_async(json_str) mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True) @@ -517,15 +517,15 @@ class MCPServerManager: server_type=server_config.type, server_url=server_config.server_url, ) - # Encrypt sensitive fields + # Encrypt sensitive fields (async to avoid blocking event loop) token = server_config.resolve_token() if token: - token_secret = Secret.from_plaintext(token) + token_secret = await Secret.from_plaintext_async(token) mcp_server.set_token_secret(token_secret) if server_config.custom_headers: # Convert dict to JSON string, then encrypt as Secret headers_json = json.dumps(server_config.custom_headers) - headers_secret = Secret.from_plaintext(headers_json) + headers_secret = await Secret.from_plaintext_async(headers_json) mcp_server.set_custom_headers_secret(headers_secret) elif isinstance(server_config, StreamableHTTPServerConfig): @@ -534,15 +534,15 @@ class MCPServerManager: server_type=server_config.type, server_url=server_config.server_url, ) - # Encrypt sensitive fields + # Encrypt sensitive fields (async to avoid blocking event loop) token = server_config.resolve_token() if token: - token_secret = Secret.from_plaintext(token) + token_secret = await Secret.from_plaintext_async(token) mcp_server.set_token_secret(token_secret) if server_config.custom_headers: # Convert dict to JSON string, then encrypt as Secret headers_json = json.dumps(server_config.custom_headers) - headers_secret = Secret.from_plaintext(headers_json) + headers_secret = await Secret.from_plaintext_async(headers_json) mcp_server.set_custom_headers_secret(headers_secret) else: raise ValueError(f"Unsupported server config type: {type(server_config)}") @@ -698,9 +698,10 @@ class MCPServerManager: elif mcp_server.token: existing_token = mcp_server.token - # Only re-encrypt if different + # Only re-encrypt if different (async to avoid blocking event loop) if existing_token != update_data["token"]: - mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted() + token_secret = await Secret.from_plaintext_async(update_data["token"]) + mcp_server.token_enc = token_secret.get_encrypted() # Keep plaintext for dual-write during migration mcp_server.token = update_data["token"] @@ -725,9 +726,10 @@ class MCPServerManager: elif mcp_server.custom_headers: existing_headers_json = json.dumps(mcp_server.custom_headers) - # Only re-encrypt if different + # Only re-encrypt if different (async to avoid blocking event loop) if existing_headers_json != json_str: - mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted() + headers_secret = await Secret.from_plaintext_async(json_str) + mcp_server.custom_headers_enc = headers_secret.get_encrypted() # Keep plaintext for dual-write during migration mcp_server.custom_headers = update_data["custom_headers"] @@ -1117,9 +1119,10 @@ class MCPServerManager: elif oauth_session.authorization_code: existing_code = oauth_session.authorization_code - # Only re-encrypt if different + # Only re-encrypt if different (async to avoid blocking event loop) if existing_code != session_update.authorization_code: - oauth_session.authorization_code_enc = Secret.from_plaintext(session_update.authorization_code).get_encrypted() + code_secret = await Secret.from_plaintext_async(session_update.authorization_code) + oauth_session.authorization_code_enc = code_secret.get_encrypted() # Keep plaintext for dual-write during migration oauth_session.authorization_code = session_update.authorization_code @@ -1134,9 +1137,10 @@ class MCPServerManager: elif oauth_session.access_token: existing_token = oauth_session.access_token - # Only re-encrypt if different + # Only re-encrypt if different (async to avoid blocking event loop) if existing_token != session_update.access_token: - oauth_session.access_token_enc = Secret.from_plaintext(session_update.access_token).get_encrypted() + token_secret = await Secret.from_plaintext_async(session_update.access_token) + oauth_session.access_token_enc = token_secret.get_encrypted() # Keep plaintext for dual-write during migration oauth_session.access_token = session_update.access_token @@ -1151,9 +1155,10 @@ class MCPServerManager: elif oauth_session.refresh_token: existing_refresh = oauth_session.refresh_token - # Only re-encrypt if different + # Only re-encrypt if different (async to avoid blocking event loop) if existing_refresh != session_update.refresh_token: - oauth_session.refresh_token_enc = Secret.from_plaintext(session_update.refresh_token).get_encrypted() + refresh_secret = await Secret.from_plaintext_async(session_update.refresh_token) + oauth_session.refresh_token_enc = refresh_secret.get_encrypted() # Keep plaintext for dual-write during migration oauth_session.refresh_token = session_update.refresh_token @@ -1177,9 +1182,10 @@ class MCPServerManager: elif oauth_session.client_secret: existing_secret_val = oauth_session.client_secret - # Only re-encrypt if different + # Only re-encrypt if different (async to avoid blocking event loop) if existing_secret_val != session_update.client_secret: - oauth_session.client_secret_enc = Secret.from_plaintext(session_update.client_secret).get_encrypted() + client_secret_encrypted = await Secret.from_plaintext_async(session_update.client_secret) + oauth_session.client_secret_enc = client_secret_encrypted.get_encrypted() # Keep plaintext for dual-write during migration oauth_session.client_secret = session_update.client_secret diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 14515ac5..87fee46c 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -89,11 +89,13 @@ class ProviderManager: deleted_provider.region = request.region deleted_provider.api_version = request.api_version - # Update encrypted fields + # Update encrypted fields (async to avoid blocking event loop) if request.api_key is not None: - deleted_provider.api_key_enc = Secret.from_plaintext(request.api_key).get_encrypted() + api_key_secret = await Secret.from_plaintext_async(request.api_key) + deleted_provider.api_key_enc = api_key_secret.get_encrypted() if request.access_key is not None: - deleted_provider.access_key_enc = Secret.from_plaintext(request.access_key).get_encrypted() + access_key_secret = await Secret.from_plaintext_async(request.access_key) + deleted_provider.access_key_enc = access_key_secret.get_encrypted() await deleted_provider.update_async(session, actor=actor) provider_pydantic = deleted_provider.to_pydantic() @@ -125,11 +127,11 @@ class ProviderManager: # Lazily create the provider id prior to persistence provider.resolve_identifier() - # Explicitly populate encrypted fields from plaintext + # Explicitly populate encrypted fields from plaintext (async to avoid blocking event loop) if request.api_key is not None: - provider.api_key_enc = Secret.from_plaintext(request.api_key) + provider.api_key_enc = await Secret.from_plaintext_async(request.api_key) if request.access_key is not None: - provider.access_key_enc = Secret.from_plaintext(request.access_key) + provider.access_key_enc = await Secret.from_plaintext_async(request.access_key) new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True)) await new_provider.create_async(session, actor=actor) @@ -164,9 +166,10 @@ class ProviderManager: existing_secret = Secret.from_encrypted(existing_provider.api_key_enc) existing_api_key = await existing_secret.get_plaintext_async() - # Only re-encrypt if different + # Only re-encrypt if different (async to avoid blocking event loop) if existing_api_key != update_data["api_key"]: - existing_provider.api_key_enc = Secret.from_plaintext(update_data["api_key"]).get_encrypted() + api_key_secret = await Secret.from_plaintext_async(update_data["api_key"]) + existing_provider.api_key_enc = api_key_secret.get_encrypted() # Remove from update_data since we set directly on existing_provider update_data.pop("api_key", None) @@ -181,9 +184,10 @@ class ProviderManager: existing_secret = Secret.from_encrypted(existing_provider.access_key_enc) existing_access_key = await existing_secret.get_plaintext_async() - # Only re-encrypt if different + # Only re-encrypt if different (async to avoid blocking event loop) if existing_access_key != update_data["access_key"]: - existing_provider.access_key_enc = Secret.from_plaintext(update_data["access_key"]).get_encrypted() + access_key_secret = await Secret.from_plaintext_async(update_data["access_key"]) + existing_provider.access_key_enc = access_key_secret.get_encrypted() # Remove from update_data since we set directly on existing_provider update_data.pop("access_key", None) diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index 365ae612..c34611e3 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -210,11 +210,11 @@ class SandboxConfigManager: return db_env_var else: async with db_registry.async_session() as session: - # Encrypt the value before storing (only to value_enc, not plaintext) + # Encrypt the value before storing (async to avoid blocking event loop) from letta.schemas.secret import Secret if env_var.value: - env_var.value_enc = Secret.from_plaintext(env_var.value) + env_var.value_enc = await Secret.from_plaintext_async(env_var.value) env_var.value = "" # Don't store plaintext, use empty string for NOT NULL constraint env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True)) @@ -242,9 +242,10 @@ class SandboxConfigManager: existing_secret = Secret.from_encrypted(env_var.value_enc) existing_value = await existing_secret.get_plaintext_async() - # Only re-encrypt if different + # Only re-encrypt if different (async to avoid blocking event loop) if existing_value != update_data["value"]: - env_var.value_enc = Secret.from_plaintext(update_data["value"]).get_encrypted() + value_secret = await Secret.from_plaintext_async(update_data["value"]) + env_var.value_enc = value_secret.get_encrypted() # Don't store plaintext anymore # Remove from update_data since we set directly on env_var diff --git a/tests/test_secret.py b/tests/test_secret.py index 713e565a..0dbc80d9 100644 --- a/tests/test_secret.py +++ b/tests/test_secret.py @@ -216,3 +216,95 @@ class TestSecret: assert mock_decrypt.call_count == 1 finally: settings.encryption_key = original_key + + @pytest.mark.asyncio + async def test_from_plaintext_async_with_key(self): + """Test creating a Secret from plaintext value asynchronously with encryption key.""" + from letta.settings import settings + + # Set encryption key + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "my-async-secret-value" + + secret = await Secret.from_plaintext_async(plaintext) + + # Should store encrypted value + assert secret.encrypted_value is not None + assert secret.encrypted_value != plaintext + + # Should decrypt to original value + result = await secret.get_plaintext_async() + assert result == plaintext + finally: + settings.encryption_key = original_key + + @pytest.mark.asyncio + async def test_from_plaintext_async_without_key_stores_plaintext(self): + """Test creating a Secret asynchronously without encryption key stores as plaintext.""" + from letta.settings import settings + + # Clear encryption key + original_key = settings.encryption_key + settings.encryption_key = None + + try: + plaintext = "my-async-plaintext-value" + + # Should store as plaintext in _enc column when no encryption key + secret = await Secret.from_plaintext_async(plaintext) + + # Should store the plaintext value directly in encrypted_value + assert secret.encrypted_value == plaintext + result = await secret.get_plaintext_async() + assert result == plaintext + finally: + settings.encryption_key = original_key + + @pytest.mark.asyncio + async def test_from_plaintext_async_with_none(self): + """Test creating a Secret asynchronously from None value.""" + secret = await Secret.from_plaintext_async(None) + + assert secret.encrypted_value is None + result = await secret.get_plaintext_async() + assert result is None + assert secret.is_empty() is True + + @pytest.mark.asyncio + async def test_from_plaintexts_async(self): + """Test batch encrypting multiple secrets concurrently.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + values = { + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + secrets = await Secret.from_plaintexts_async(values) + + # Should return dict with same keys + assert set(secrets.keys()) == {"key1", "key2", "key3"} + + # Each secret should decrypt to original value + for key, secret in secrets.items(): + assert isinstance(secret, Secret) + assert secret.encrypted_value is not None + assert secret.encrypted_value != values[key] + result = await secret.get_plaintext_async() + assert result == values[key] + finally: + settings.encryption_key = original_key + + @pytest.mark.asyncio + async def test_from_plaintexts_async_empty_dict(self): + """Test batch encrypting with empty dict.""" + secrets = await Secret.from_plaintexts_async({}) + assert secrets == {}