diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index b49675eb..1da8bf3f 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -1868,7 +1868,8 @@ class LettaAgent(BaseAgent): start_time = get_utc_timestamp_ns() agent_step_span.add_event(name="tool_execution_started") - sandbox_env_vars = {var.key: var.value for var in agent_state.secrets} + # Decrypt environment variable values + sandbox_env_vars = {var.key: var.get_value_secret().get_plaintext() for var in agent_state.secrets} tool_execution_manager = ToolExecutionManager( agent_state=agent_state, message_manager=self.message_manager, diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index 1f56dd6d..0070dee1 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -1106,7 +1106,8 @@ class LettaAgentV2(BaseAgentV2): start_time = get_utc_timestamp_ns() agent_step_span.add_event(name="tool_execution_started") - sandbox_env_vars = {var.key: var.value for var in agent_state.secrets} + # Decrypt environment variable values + sandbox_env_vars = {var.key: var.get_value_secret().get_plaintext() for var in agent_state.secrets} tool_execution_manager = ToolExecutionManager( agent_state=agent_state, message_manager=self.message_manager, diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index c8f3415b..38d6c16a 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -437,7 +437,8 @@ class VoiceAgent(BaseAgent): ) # Use ToolExecutionManager for modern tool execution - sandbox_env_vars = {var.key: var.value for var in agent_state.secrets} + # Decrypt environment variable values + sandbox_env_vars = {var.key: var.get_value_secret().get_plaintext() for var in agent_state.secrets} tool_execution_manager = ToolExecutionManager( agent_state=agent_state, message_manager=self.message_manager, diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index e68d614e..856760fd 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -1,3 +1,4 @@ +import json from datetime import datetime from typing import Any, Dict, List, Optional, Union @@ -13,7 +14,7 @@ from letta.functions.mcp_client.types import ( ) from letta.orm.mcp_oauth import OAuthSessionStatus from letta.schemas.letta_base import LettaBase -from letta.schemas.secret import Secret, SecretDict +from letta.schemas.secret import Secret from letta.settings import settings @@ -31,8 +32,8 @@ class MCPServer(BaseMCPServer): token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for authentication)") custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs") - token_enc: Optional[str] = Field(None, description="Encrypted token") - custom_headers_enc: Optional[str] = Field(None, description="Encrypted custom headers") + token_enc: Secret | None = Field(None, description="Encrypted token as Secret object") + custom_headers_enc: Secret | None = Field(None, description="Encrypted custom headers as Secret object") # stdio config stdio_config: Optional[StdioServerConfig] = Field( @@ -48,55 +49,55 @@ class MCPServer(BaseMCPServer): def get_token_secret(self) -> Secret: """Get the token as a Secret object, preferring encrypted over plaintext.""" - return Secret.from_db(self.token_enc, self.token) + if self.token_enc is not None: + return self.token_enc + return Secret.from_db(None, self.token) - def get_custom_headers_secret(self) -> SecretDict: - """Get custom headers as a SecretDict object, preferring encrypted over plaintext.""" - return SecretDict.from_db(self.custom_headers_enc, self.custom_headers) + def get_custom_headers_secret(self) -> Secret: + """Get custom headers as a Secret object (stores JSON string), preferring encrypted over plaintext.""" + if self.custom_headers_enc is not None: + return self.custom_headers_enc + # Fallback: convert plaintext dict to JSON string and wrap in Secret + if self.custom_headers is not None: + json_str = json.dumps(self.custom_headers) + return Secret.from_plaintext(json_str) + return Secret.from_plaintext(None) + + def get_custom_headers_dict(self) -> Optional[Dict[str, str]]: + """Get custom headers as a plaintext dictionary.""" + secret = self.get_custom_headers_secret() + json_str = secret.get_plaintext() + if json_str: + try: + return json.loads(json_str) + except (json.JSONDecodeError, TypeError): + return None + return None def set_token_secret(self, secret: Secret) -> None: """Set token from a Secret object, updating both encrypted and plaintext fields.""" + self.token_enc = secret secret_dict = secret.to_dict() - self.token_enc = secret_dict["encrypted"] # Only set plaintext during migration phase - if not secret._was_encrypted: + if not secret.was_encrypted: self.token = secret_dict["plaintext"] else: self.token = None - def set_custom_headers_secret(self, secret: SecretDict) -> None: - """Set custom headers from a SecretDict object, updating both fields.""" + def set_custom_headers_secret(self, secret: Secret) -> None: + """Set custom headers from a Secret object (containing JSON string), updating both fields.""" + self.custom_headers_enc = secret secret_dict = secret.to_dict() - self.custom_headers_enc = secret_dict["encrypted"] - # Only set plaintext during migration phase - if not secret._was_encrypted: - self.custom_headers = secret_dict["plaintext"] + # Parse JSON string to dict for plaintext field + json_str = secret_dict.get("plaintext") + if json_str and not secret.was_encrypted: + try: + self.custom_headers = json.loads(json_str) + except (json.JSONDecodeError, TypeError): + self.custom_headers = None else: self.custom_headers = None - def model_dump(self, to_orm: bool = False, **kwargs): - """Override model_dump to handle encryption when saving to database.""" - data = super().model_dump(to_orm=to_orm, **kwargs) - - if to_orm and settings.encryption_key: - # Encrypt token if present - if self.token is not None: - token_secret = Secret.from_plaintext(self.token) - secret_dict = token_secret.to_dict() - data["token_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["token"] = secret_dict["plaintext"] - - # Encrypt custom headers if present - if self.custom_headers is not None: - headers_secret = SecretDict.from_plaintext(self.custom_headers) - secret_dict = headers_secret.to_dict() - data["custom_headers_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["custom_headers"] = secret_dict["plaintext"] - - return data - def to_config( self, environment_variables: Optional[Dict[str, str]] = None, @@ -106,8 +107,8 @@ class MCPServer(BaseMCPServer): token_secret = self.get_token_secret() token_plaintext = token_secret.get_plaintext() - headers_secret = self.get_custom_headers_secret() - headers_plaintext = headers_secret.get_plaintext() + # Get custom headers as dict + headers_plaintext = self.get_custom_headers_dict() if self.server_type == MCPServerType.SSE: config = SSEServerConfig( @@ -194,6 +195,9 @@ class MCPOAuthSession(BaseMCPOAuth): authorization_url: Optional[str] = Field(None, description="OAuth authorization URL") authorization_code: Optional[str] = Field(None, description="OAuth authorization code") + # Encrypted authorization code (for internal use) + authorization_code_enc: Secret | None = Field(None, description="Encrypted OAuth authorization code as Secret object") + # Token data access_token: Optional[str] = Field(None, description="OAuth access token") refresh_token: Optional[str] = Field(None, description="OAuth refresh token") @@ -202,8 +206,8 @@ class MCPOAuthSession(BaseMCPOAuth): scope: Optional[str] = Field(None, description="OAuth scope") # Encrypted token fields (for internal use) - access_token_enc: Optional[str] = Field(None, description="Encrypted OAuth access token") - refresh_token_enc: Optional[str] = Field(None, description="Encrypted OAuth refresh token") + access_token_enc: Secret | None = Field(None, description="Encrypted OAuth access token as Secret object") + refresh_token_enc: Secret | None = Field(None, description="Encrypted OAuth refresh token as Secret object") # Client configuration client_id: Optional[str] = Field(None, description="OAuth client ID") @@ -211,7 +215,7 @@ class MCPOAuthSession(BaseMCPOAuth): redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI") # Encrypted client secret (for internal use) - client_secret_enc: Optional[str] = Field(None, description="Encrypted OAuth client secret") + client_secret_enc: Secret | None = Field(None, description="Encrypted OAuth client secret as Secret object") # Session state status: OAuthSessionStatus = Field(default=OAuthSessionStatus.PENDING, description="Session status") @@ -222,73 +226,63 @@ class MCPOAuthSession(BaseMCPOAuth): def get_access_token_secret(self) -> Secret: """Get the access token as a Secret object, preferring encrypted over plaintext.""" - return Secret.from_db(self.access_token_enc, self.access_token) + if self.access_token_enc is not None: + return self.access_token_enc + return Secret.from_db(None, self.access_token) def get_refresh_token_secret(self) -> Secret: """Get the refresh token as a Secret object, preferring encrypted over plaintext.""" - return Secret.from_db(self.refresh_token_enc, self.refresh_token) + if self.refresh_token_enc is not None: + return self.refresh_token_enc + return Secret.from_db(None, self.refresh_token) def get_client_secret_secret(self) -> Secret: """Get the client secret as a Secret object, preferring encrypted over plaintext.""" - return Secret.from_db(self.client_secret_enc, self.client_secret) + if self.client_secret_enc is not None: + return self.client_secret_enc + return Secret.from_db(None, self.client_secret) + + def get_authorization_code_secret(self) -> Secret: + """Get the authorization code as a Secret object, preferring encrypted over plaintext.""" + if self.authorization_code_enc is not None: + return self.authorization_code_enc + return Secret.from_db(None, self.authorization_code) def set_access_token_secret(self, secret: Secret) -> None: """Set access token from a Secret object.""" + self.access_token_enc = secret secret_dict = secret.to_dict() - self.access_token_enc = secret_dict["encrypted"] - if not secret._was_encrypted: + if not secret.was_encrypted: self.access_token = secret_dict["plaintext"] else: self.access_token = None def set_refresh_token_secret(self, secret: Secret) -> None: """Set refresh token from a Secret object.""" + self.refresh_token_enc = secret secret_dict = secret.to_dict() - self.refresh_token_enc = secret_dict["encrypted"] - if not secret._was_encrypted: + if not secret.was_encrypted: self.refresh_token = secret_dict["plaintext"] else: self.refresh_token = None def set_client_secret_secret(self, secret: Secret) -> None: """Set client secret from a Secret object.""" + self.client_secret_enc = secret secret_dict = secret.to_dict() - self.client_secret_enc = secret_dict["encrypted"] - if not secret._was_encrypted: + if not secret.was_encrypted: self.client_secret = secret_dict["plaintext"] else: self.client_secret = None - def model_dump(self, to_orm: bool = False, **kwargs): - """Override model_dump to handle encryption when saving to database.""" - data = super().model_dump(to_orm=to_orm, **kwargs) - - if to_orm and settings.encryption_key: - # Encrypt access token if present - if self.access_token is not None: - token_secret = Secret.from_plaintext(self.access_token) - secret_dict = token_secret.to_dict() - data["access_token_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["access_token"] = secret_dict["plaintext"] - - # Encrypt refresh token if present - if self.refresh_token is not None: - token_secret = Secret.from_plaintext(self.refresh_token) - secret_dict = token_secret.to_dict() - data["refresh_token_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["refresh_token"] = secret_dict["plaintext"] - - # Encrypt client secret if present - if self.client_secret is not None: - secret = Secret.from_plaintext(self.client_secret) - secret_dict = secret.to_dict() - data["client_secret_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["client_secret"] = secret_dict["plaintext"] - - return data + def set_authorization_code_secret(self, secret: Secret) -> None: + """Set authorization code from a Secret object.""" + self.authorization_code_enc = secret + secret_dict = secret.to_dict() + if not secret.was_encrypted: + self.authorization_code = secret_dict["plaintext"] + else: + self.authorization_code = None class MCPOAuthSessionCreate(BaseMCPOAuth): diff --git a/letta/schemas/providers/anthropic.py b/letta/schemas/providers/anthropic.py index c72f315e..9ed2ce5f 100644 --- a/letta/schemas/providers/anthropic.py +++ b/letta/schemas/providers/anthropic.py @@ -4,9 +4,11 @@ from typing import Literal import anthropic from pydantic import Field +from letta.errors import ErrorCode, LLMAuthenticationError, LLMError from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.providers.base import Provider +from letta.settings import model_settings # https://docs.anthropic.com/claude/docs/models-overview # Sadly hardcoded @@ -98,8 +100,9 @@ class AnthropicProvider(Provider): base_url: str = "https://api.anthropic.com/v1" async def check_api_key(self): - if self.api_key: - anthropic_client = anthropic.Anthropic(api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + if api_key: + anthropic_client = anthropic.Anthropic(api_key=api_key) try: # just use a cheap model to count some tokens - as of 5/7/2025 this is faster than fetching the list of models anthropic_client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}]) @@ -116,8 +119,9 @@ class AnthropicProvider(Provider): NOTE: currently there is no GET /models, so we need to hardcode """ - if self.api_key: - anthropic_client = anthropic.AsyncAnthropic(api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + if api_key: + anthropic_client = anthropic.AsyncAnthropic(api_key=api_key) elif model_settings.anthropic_api_key: anthropic_client = anthropic.AsyncAnthropic() else: diff --git a/letta/schemas/providers/azure.py b/letta/schemas/providers/azure.py index 0da8c5fa..3264c7fd 100644 --- a/letta/schemas/providers/azure.py +++ b/letta/schemas/providers/azure.py @@ -60,7 +60,8 @@ class AzureProvider(Provider): def azure_openai_get_deployed_model_list(self) -> list: """https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP""" - client = AzureOpenAI(api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.base_url) + api_key = self.get_api_key_secret().get_plaintext() + client = AzureOpenAI(api_key=api_key, api_version=self.api_version, azure_endpoint=self.base_url) try: models_list = client.models.list() @@ -71,8 +72,8 @@ class AzureProvider(Provider): # https://xxx.openai.azure.com/openai/models?api-version=xxx headers = {"Content-Type": "application/json"} - if self.api_key is not None: - headers["api-key"] = f"{self.api_key}" + if api_key is not None: + headers["api-key"] = f"{api_key}" # 2. Get all the deployed models url = self.get_azure_deployment_list_endpoint() @@ -165,7 +166,8 @@ class AzureProvider(Provider): return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default) async def check_api_key(self): - if not self.api_key: + api_key = self.get_api_key_secret().get_plaintext() + if not api_key: raise ValueError("No API key provided") try: diff --git a/letta/schemas/providers/bedrock.py b/letta/schemas/providers/bedrock.py index 94b0ffa9..461b77fa 100644 --- a/letta/schemas/providers/bedrock.py +++ b/letta/schemas/providers/bedrock.py @@ -25,11 +25,15 @@ class BedrockProvider(Provider): from aioboto3.session import Session try: + # Decrypt credentials before using + access_key = self.get_access_key_secret().get_plaintext() + secret_key = self.get_api_key_secret().get_plaintext() + session = Session() async with session.client( "bedrock", - aws_access_key_id=self.access_key, - aws_secret_access_key=self.api_key, + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, region_name=self.region, ) as bedrock: response = await bedrock.list_inference_profiles() diff --git a/letta/schemas/providers/cerebras.py b/letta/schemas/providers/cerebras.py index 173dc4ba..85ef6d1a 100644 --- a/letta/schemas/providers/cerebras.py +++ b/letta/schemas/providers/cerebras.py @@ -38,7 +38,8 @@ class CerebrasProvider(OpenAIProvider): async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - response = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + response = await openai_get_model_list_async(self.base_url, api_key=api_key) if "data" in response: data = response["data"] diff --git a/letta/schemas/providers/deepseek.py b/letta/schemas/providers/deepseek.py index 0c1ae0c2..ac0144e3 100644 --- a/letta/schemas/providers/deepseek.py +++ b/letta/schemas/providers/deepseek.py @@ -34,7 +34,8 @@ class DeepSeekProvider(OpenAIProvider): async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - response = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + response = await openai_get_model_list_async(self.base_url, api_key=api_key) data = response.get("data", response) configs = [] diff --git a/letta/schemas/providers/google_gemini.py b/letta/schemas/providers/google_gemini.py index 6404e0fc..ba7a2021 100644 --- a/letta/schemas/providers/google_gemini.py +++ b/letta/schemas/providers/google_gemini.py @@ -19,13 +19,15 @@ class GoogleAIProvider(Provider): async def check_api_key(self): from letta.llm_api.google_ai_client import google_ai_check_valid_api_key - google_ai_check_valid_api_key(self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + google_ai_check_valid_api_key(api_key) async def list_llm_models_async(self): from letta.llm_api.google_ai_client import google_ai_get_model_list_async # Get and filter the model list - model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key) model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options] @@ -58,7 +60,8 @@ class GoogleAIProvider(Provider): from letta.llm_api.google_ai_client import google_ai_get_model_list_async # TODO: use base_url instead - model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key) return self._list_embedding_models(model_options) def _list_embedding_models(self, model_options): @@ -91,7 +94,8 @@ class GoogleAIProvider(Provider): if model_name in LLM_MAX_TOKENS: return LLM_MAX_TOKENS[model_name] else: - return google_ai_get_model_context_window(self.base_url, self.api_key, model_name) + api_key = self.get_api_key_secret().get_plaintext() + return google_ai_get_model_context_window(self.base_url, api_key, model_name) async def get_model_context_window_async(self, model_name: str) -> int | None: from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async @@ -99,4 +103,5 @@ class GoogleAIProvider(Provider): if model_name in LLM_MAX_TOKENS: return LLM_MAX_TOKENS[model_name] else: - return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name) + api_key = self.get_api_key_secret().get_plaintext() + return await google_ai_get_model_context_window_async(self.base_url, api_key, model_name) diff --git a/letta/schemas/providers/groq.py b/letta/schemas/providers/groq.py index 18b4cb31..9945e4ff 100644 --- a/letta/schemas/providers/groq.py +++ b/letta/schemas/providers/groq.py @@ -16,7 +16,8 @@ class GroqProvider(OpenAIProvider): async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - response = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + response = await openai_get_model_list_async(self.base_url, api_key=api_key) configs = [] for model in response["data"]: if "context_window" not in model: diff --git a/letta/schemas/providers/mistral.py b/letta/schemas/providers/mistral.py index 2eeb3a23..c4777eba 100644 --- a/letta/schemas/providers/mistral.py +++ b/letta/schemas/providers/mistral.py @@ -18,7 +18,8 @@ class MistralProvider(Provider): # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... # See: https://openrouter.ai/docs/requests - response = await mistral_get_model_list_async(self.base_url, api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + response = await mistral_get_model_list_async(self.base_url, api_key=api_key) assert "data" in response, f"Mistral model query response missing 'data' field: {response}" diff --git a/letta/schemas/providers/openai.py b/letta/schemas/providers/openai.py index d4f2fce9..52ff323d 100644 --- a/letta/schemas/providers/openai.py +++ b/letta/schemas/providers/openai.py @@ -25,7 +25,9 @@ class OpenAIProvider(Provider): async def check_api_key(self): from letta.llm_api.openai import openai_check_valid_api_key # TODO: DO NOT USE THIS - old code path - openai_check_valid_api_key(self.base_url, self.api_key) + # Decrypt API key before using + api_key = self.get_api_key_secret().get_plaintext() + openai_check_valid_api_key(self.base_url, api_key) async def _get_models_async(self) -> list[dict]: from letta.llm_api.openai import openai_get_model_list_async @@ -37,9 +39,12 @@ class OpenAIProvider(Provider): # Similar to Nebius extra_params = {"verbose": True} if "nebius.com" in self.base_url else None + # Decrypt API key before using + api_key = self.get_api_key_secret().get_plaintext() + response = await openai_get_model_list_async( self.base_url, - api_key=self.api_key, + api_key=api_key, extra_params=extra_params, # fix_url=True, # NOTE: make sure together ends with /v1 ) diff --git a/letta/schemas/providers/together.py b/letta/schemas/providers/together.py index 4b0259f5..6dc7b083 100644 --- a/letta/schemas/providers/together.py +++ b/letta/schemas/providers/together.py @@ -26,7 +26,8 @@ class TogetherProvider(OpenAIProvider): async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - models = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + models = await openai_get_model_list_async(self.base_url, api_key=api_key) return self._list_llm_models(models) async def list_embedding_models_async(self) -> list[EmbeddingConfig]: @@ -88,7 +89,8 @@ class TogetherProvider(OpenAIProvider): return configs async def check_api_key(self): - if not self.api_key: + api_key = self.get_api_key_secret().get_plaintext() + if not api_key: raise ValueError("No API key provided") try: diff --git a/letta/schemas/providers/xai.py b/letta/schemas/providers/xai.py index d042aad0..ed8902ed 100644 --- a/letta/schemas/providers/xai.py +++ b/letta/schemas/providers/xai.py @@ -32,7 +32,8 @@ class XAIProvider(OpenAIProvider): async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - response = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + api_key = self.get_api_key_secret().get_plaintext() + response = await openai_get_model_list_async(self.base_url, api_key=api_key) data = response.get("data", response) diff --git a/letta/schemas/secret.py b/letta/schemas/secret.py index 8986a1f5..687b8896 100644 --- a/letta/schemas/secret.py +++ b/letta/schemas/secret.py @@ -271,11 +271,14 @@ class Secret(BaseModel): def __get_pydantic_json_schema__(cls, core_schema: core_schema.CoreSchema, handler) -> Dict[str, Any]: """ Define JSON schema representation for Secret fields. + In JSON schema (OpenAPI docs), Secret fields appear as nullable strings. The actual encryption/decryption happens at runtime via __get_pydantic_core_schema__. + Args: core_schema: The core schema for this type handler: Handler for generating JSON schema + Returns: A JSON schema dict representing this type as a nullable string """ @@ -285,161 +288,3 @@ class Secret(BaseModel): "nullable": True, "description": "Encrypted secret value (stored as encrypted string)", } - - -class SecretDict(BaseModel): - """ - A wrapper for dictionaries containing sensitive key-value pairs. - - Used for custom headers and other key-value configurations. - - TODO: Once we deprecate plaintext columns in the database: - - Remove the dual-write logic in to_dict() - - Remove the from_db() method's plaintext_value parameter - - Remove the _was_encrypted flag (no longer needed for migration) - - Simplify get_plaintext() to only handle encrypted JSON values - """ - - _encrypted_value: Optional[str] = PrivateAttr(default=None) - _plaintext_cache: Optional[Dict[str, str]] = PrivateAttr(default=None) - _was_encrypted: bool = PrivateAttr(default=False) - - model_config = ConfigDict(frozen=True) - - @classmethod - def from_plaintext(cls, value: Optional[Dict[str, str]]) -> "SecretDict": - """Create a SecretDict from a plaintext dictionary.""" - if value is None: - instance = cls() - instance._encrypted_value = None - instance._was_encrypted = False - return instance - - # Serialize to JSON then try to encrypt - json_str = json.dumps(value) - try: - encrypted = CryptoUtils.encrypt(json_str) - instance = cls() - instance._encrypted_value = encrypted - instance._was_encrypted = False - return instance - except ValueError as e: - # No encryption key available, store as plaintext JSON - if "No encryption key configured" in str(e): - logger.warning( - "No encryption key configured. Storing SecretDict value as plaintext JSON. " - "Set LETTA_ENCRYPTION_KEY environment variable to enable encryption." - ) - instance = cls() - instance._encrypted_value = json_str # Store JSON string - instance._plaintext_cache = value # Cache the dict - instance._was_encrypted = False - return instance - raise # Re-raise if it's a different error - - @classmethod - def from_encrypted(cls, encrypted_value: Optional[str]) -> "SecretDict": - """Create a SecretDict from an encrypted value.""" - instance = cls() - instance._encrypted_value = encrypted_value - instance._was_encrypted = True - return instance - - @classmethod - def from_db(cls, encrypted_value: Optional[str], plaintext_value: Optional[Dict[str, str]]) -> "SecretDict": - """Create a SecretDict from database values during migration phase.""" - if encrypted_value is not None: - return cls.from_encrypted(encrypted_value) - elif plaintext_value is not None: - return cls.from_plaintext(plaintext_value) - else: - return cls.from_plaintext(None) - - def get_encrypted(self) -> Optional[str]: - """Get the encrypted value.""" - return self._encrypted_value - - def get_plaintext(self) -> Optional[Dict[str, str]]: - """Get the decrypted dictionary.""" - if self._encrypted_value is None: - return None - - # Use cached value if available, but only if it looks like plaintext - # or we're confident we can decrypt it - if self._plaintext_cache is not None: - # If we have a cache but the stored value looks encrypted and we have no key, - # we should not use the cache - if CryptoUtils.is_encrypted(self._encrypted_value) and not CryptoUtils.is_encryption_available(): - self._plaintext_cache = None # Clear invalid cache - else: - return self._plaintext_cache - - try: - decrypted_json = CryptoUtils.decrypt(self._encrypted_value) - plaintext_dict = json.loads(decrypted_json) - # Cache the decrypted value (PrivateAttr fields can be mutated even with frozen=True) - self._plaintext_cache = plaintext_dict - return plaintext_dict - except ValueError as e: - error_msg = str(e) - - # Handle missing encryption key - if "No encryption key configured" in error_msg: - # Check if the value looks encrypted - if CryptoUtils.is_encrypted(self._encrypted_value): - # Value was encrypted, but now we have no key - can't decrypt - logger.warning( - "Cannot decrypt SecretDict value - no encryption key configured. " - "The value was encrypted and requires the original key to decrypt." - ) - # Return None to indicate we can't get the plaintext - return None - else: - # Value is plaintext JSON (stored when no key was available) - logger.debug("SecretDict value is plaintext JSON (stored without encryption)") - try: - plaintext_dict = json.loads(self._encrypted_value) - self._plaintext_cache = plaintext_dict - return plaintext_dict - except json.JSONDecodeError: - logger.error("Failed to parse SecretDict plaintext as JSON") - return None - - # Handle decryption failure (might be plaintext JSON) - elif "Failed to decrypt data" in error_msg: - # Check if it might be plaintext JSON - if not CryptoUtils.is_encrypted(self._encrypted_value): - # It's plaintext JSON that was stored when no key was available - logger.debug("SecretDict value appears to be plaintext JSON (stored without encryption)") - try: - plaintext_dict = json.loads(self._encrypted_value) - self._plaintext_cache = plaintext_dict - return plaintext_dict - except json.JSONDecodeError: - logger.error("Failed to parse SecretDict plaintext as JSON") - return None - # Otherwise, it's corrupted or wrong key - logger.error("Failed to decrypt SecretDict value - data may be corrupted or wrong key") - raise - - # Migration case: handle legacy plaintext - elif not self._was_encrypted: - if self._encrypted_value: - try: - plaintext_dict = json.loads(self._encrypted_value) - self._plaintext_cache = plaintext_dict - return plaintext_dict - except json.JSONDecodeError: - pass - return None - - # Re-raise for other errors - raise - - def is_empty(self) -> bool: - """Check if the secret dict is empty/None.""" - return self._encrypted_value is None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary for database storage.""" - return {"encrypted": self.get_encrypted(), "plaintext": self.get_plaintext() if not self._was_encrypted else None} diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 20240338..d92aec16 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -64,6 +64,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message, Message as PydanticMessage, MessageCreate, MessageUpdate from letta.schemas.passage import Passage as PydanticPassage +from letta.schemas.secret import Secret from letta.schemas.source import Source as PydanticSource from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool_rule import ContinueToolRule, RequiresApprovalToolRule, TerminalToolRule @@ -521,16 +522,22 @@ class AgentManager: env_rows = [] agent_secrets = agent_create.secrets or agent_create.tool_exec_environment_variables + if agent_secrets: - env_rows = [ - { + # Encrypt environment variable values + env_rows = [] + for key, val in agent_secrets.items(): + row = { "agent_id": aid, "key": key, "value": val, "organization_id": actor.organization_id, } - for key, val in agent_secrets.items() - ] + # Encrypt value (Secret.from_plaintext handles missing encryption key internally) + value_secret = Secret.from_plaintext(val) + row["value_enc"] = value_secret.get_encrypted() + env_rows.append(row) + result = await session.execute(insert(AgentEnvironmentVariable).values(env_rows).returning(AgentEnvironmentVariable.id)) env_rows = [{**row, "id": env_var_id} for row, env_var_id in zip(env_rows, result.scalars().all())] @@ -742,16 +749,44 @@ class AgentManager: agent_secrets = agent_update.secrets or agent_update.tool_exec_environment_variables if agent_secrets is not None: + # Fetch existing environment variables to check if values changed + result = await session.execute(select(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid)) + existing_env_vars = {env.key: env for env in result.scalars().all()} + + # TODO: do we need to delete each time or can we just upsert? await session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid)) - env_rows = [ - { + # Encrypt environment variable values + # Only re-encrypt if the value has actually changed + env_rows = [] + for k, v in agent_secrets.items(): + row = { "agent_id": aid, "key": k, "value": v, "organization_id": agent.organization_id, } - for k, v in agent_secrets.items() - ] + + # Check if value changed to avoid unnecessary re-encryption + existing_env = existing_env_vars.get(k) + existing_value = None + if existing_env: + if existing_env.value_enc: + existing_secret = Secret.from_encrypted(existing_env.value_enc) + existing_value = existing_secret.get_plaintext() + elif existing_env.value: + existing_value = existing_env.value + + # Encrypt value (reuse existing encrypted value if unchanged) + if existing_value == v and existing_env and existing_env.value_enc: + # Value unchanged, reuse existing encrypted value + row["value_enc"] = existing_env.value_enc + else: + # Value changed or new, encrypt + value_secret = Secret.from_plaintext(v) + row["value_enc"] = value_secret.get_encrypted() + + env_rows.append(row) + if env_rows: await self._bulk_insert_pivot_async(session, AgentEnvironmentVariable.__table__, env_rows) session.expire(agent, ["tool_exec_environment_variables"]) diff --git a/letta/services/mcp/oauth_utils.py b/letta/services/mcp/oauth_utils.py index cab6f833..5ff6085b 100644 --- a/letta/services/mcp/oauth_utils.py +++ b/letta/services/mcp/oauth_utils.py @@ -34,12 +34,21 @@ class DatabaseTokenStorage(TokenStorage): async def get_tokens(self) -> Optional[OAuthToken]: """Retrieve tokens from database.""" oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor) - if not oauth_session or not oauth_session.access_token: + if not oauth_session: return None + # Decrypt tokens using getter methods + access_token_secret = oauth_session.get_access_token_secret() + access_token = access_token_secret.get_plaintext() + if not access_token: + return None + + refresh_token_secret = oauth_session.get_refresh_token_secret() + refresh_token = refresh_token_secret.get_plaintext() + return OAuthToken( - access_token=oauth_session.access_token, - refresh_token=oauth_session.refresh_token, + access_token=access_token, + refresh_token=refresh_token, token_type=oauth_session.token_type, expires_in=int(oauth_session.expires_at.timestamp() - time.time()), scope=oauth_session.scope, @@ -63,9 +72,13 @@ class DatabaseTokenStorage(TokenStorage): if not oauth_session or not oauth_session.client_id: return None + # Decrypt client secret using getter method + client_secret_secret = oauth_session.get_client_secret_secret() + client_secret = client_secret_secret.get_plaintext() + return OAuthClientInformationFull( client_id=oauth_session.client_id, - client_secret=oauth_session.client_secret, + client_secret=client_secret, redirect_uris=[oauth_session.redirect_uri] if oauth_session.redirect_uri else [], ) @@ -134,13 +147,23 @@ class MCPOAuthSession: async def store_authorization_code(self, code: str, state: str) -> Optional[MCPOAuth]: """Store the authorization code from OAuth callback.""" + # Use mcp_manager to ensure proper encryption + from letta.schemas.mcp import MCPOAuthSessionUpdate + from letta.schemas.secret import Secret + async with db_registry.async_session() as session: try: oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) - oauth_record.authorization_code = code - oauth_record.state = state + + # Encrypt the authorization_code before storing + if code is not None: + oauth_record.authorization_code_enc = Secret.from_plaintext(code).get_encrypted() + # Keep plaintext for dual-write during migration + oauth_record.authorization_code = code + oauth_record.status = OAuthSessionStatus.AUTHORIZED - oauth_record.updated_at = datetime.now() + oauth_record.state = state + return await oauth_record.update_async(db_session=session, actor=None) except Exception: return None @@ -212,7 +235,9 @@ async def create_oauth_provider( while time.time() - start_time < timeout: oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor) if oauth_session and oauth_session.authorization_code: - return oauth_session.authorization_code, oauth_session.state + # Decrypt the authorization code before returning + auth_code_secret = oauth_session.get_authorization_code_secret() + return auth_code_secret.get_plaintext(), oauth_session.state elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR: raise Exception("OAuth authorization failed") await asyncio.sleep(1) diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 8790f909..c1e1fa10 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -36,7 +36,7 @@ from letta.schemas.mcp import ( UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer, ) -from letta.schemas.secret import Secret, SecretDict +from letta.schemas.secret import Secret from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry @@ -44,7 +44,7 @@ from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPCl from letta.services.mcp.stdio_client import AsyncStdioMCPClient from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient from letta.services.tool_manager import ToolManager -from letta.settings import tool_settings +from letta.settings import settings, tool_settings from letta.utils import enforce_types, printd, safe_create_task logger = get_logger(__name__) @@ -318,6 +318,7 @@ class MCPManager: update_data = pydantic_mcp_server.model_dump(exclude_unset=True, exclude_none=True) # If there's anything to update (can only update the configs, not the name) + # TODO: pass in custom headers for update as well? if update_data: if pydantic_mcp_server.server_type == MCPServerType.SSE: update_request = UpdateSSEMCPServer(server_url=pydantic_mcp_server.server_url, token=pydantic_mcp_server.token) @@ -325,7 +326,7 @@ class MCPManager: update_request = UpdateStdioMCPServer(stdio_config=pydantic_mcp_server.stdio_config) elif pydantic_mcp_server.server_type == MCPServerType.STREAMABLE_HTTP: update_request = UpdateStreamableHTTPMCPServer( - server_url=pydantic_mcp_server.server_url, token=pydantic_mcp_server.token + server_url=pydantic_mcp_server.server_url, auth_token=pydantic_mcp_server.token ) else: raise ValueError(f"Unsupported server type: {pydantic_mcp_server.server_type}") @@ -347,6 +348,17 @@ class MCPManager: try: # Set the organization id at the ORM layer pydantic_mcp_server.organization_id = actor.organization_id + + # Explicitly populate encrypted fields + if pydantic_mcp_server.token is not None: + pydantic_mcp_server.token_enc = Secret.from_plaintext(pydantic_mcp_server.token) + if pydantic_mcp_server.custom_headers is not None: + # custom_headers is a Dict[str, str], serialize to JSON then encrypt + import json + + json_str = json.dumps(pydantic_mcp_server.custom_headers) + pydantic_mcp_server.custom_headers_enc = Secret.from_plaintext(json_str) + mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True) # Ensure custom_headers None is stored as SQL NULL, not JSON null @@ -412,7 +424,9 @@ class MCPManager: token_secret = Secret.from_plaintext(token) mcp_server.set_token_secret(token_secret) if server_config.custom_headers: - headers_secret = SecretDict.from_plaintext(server_config.custom_headers) + # Convert dict to JSON string, then encrypt as Secret + headers_json = json.dumps(server_config.custom_headers) + headers_secret = Secret.from_plaintext(headers_json) mcp_server.set_custom_headers_secret(headers_secret) elif isinstance(server_config, StreamableHTTPServerConfig): @@ -427,7 +441,9 @@ class MCPManager: token_secret = Secret.from_plaintext(token) mcp_server.set_token_secret(token_secret) if server_config.custom_headers: - headers_secret = SecretDict.from_plaintext(server_config.custom_headers) + # Convert dict to JSON string, then encrypt as Secret + headers_json = json.dumps(server_config.custom_headers) + headers_secret = Secret.from_plaintext(headers_json) mcp_server.set_custom_headers_secret(headers_secret) else: raise ValueError(f"Unsupported server config type: {type(server_config)}") @@ -517,27 +533,52 @@ class MCPManager: update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True) # Handle encryption for token if provided + # Only re-encrypt if the value has actually changed if "token" in update_data and update_data["token"] is not None: - token_secret = Secret.from_plaintext(update_data["token"]) - secret_dict = token_secret.to_dict() - update_data["token_enc"] = secret_dict["encrypted"] - # During migration phase, also update plaintext - if not token_secret._was_encrypted: - update_data["token"] = secret_dict["plaintext"] - else: - update_data["token"] = None + # Check if value changed + existing_token = None + if mcp_server.token_enc: + existing_secret = Secret.from_encrypted(mcp_server.token_enc) + existing_token = existing_secret.get_plaintext() + elif mcp_server.token: + existing_token = mcp_server.token + + # Only re-encrypt if different + if existing_token != update_data["token"]: + mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted() + # Keep plaintext for dual-write during migration + mcp_server.token = update_data["token"] + + # Remove from update_data since we set directly on mcp_server + update_data.pop("token", None) + update_data.pop("token_enc", None) # Handle encryption for custom_headers if provided + # Only re-encrypt if the value has actually changed if "custom_headers" in update_data: if update_data["custom_headers"] is not None: - headers_secret = SecretDict.from_plaintext(update_data["custom_headers"]) - secret_dict = headers_secret.to_dict() - update_data["custom_headers_enc"] = secret_dict["encrypted"] - # During migration phase, also update plaintext - if not headers_secret._was_encrypted: - update_data["custom_headers"] = secret_dict["plaintext"] - else: - update_data["custom_headers"] = None + # custom_headers is a Dict[str, str], serialize to JSON then encrypt + import json + + json_str = json.dumps(update_data["custom_headers"]) + + # Check if value changed + existing_headers_json = None + if mcp_server.custom_headers_enc: + existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc) + existing_headers_json = existing_secret.get_plaintext() + elif mcp_server.custom_headers: + existing_headers_json = json.dumps(mcp_server.custom_headers) + + # Only re-encrypt if different + if existing_headers_json != json_str: + mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted() + # Keep plaintext for dual-write during migration + mcp_server.custom_headers = update_data["custom_headers"] + + # Remove from update_data since we set directly on mcp_server + update_data.pop("custom_headers", None) + update_data.pop("custom_headers_enc", None) else: # Ensure custom_headers None is stored as SQL NULL, not JSON null update_data.pop("custom_headers", None) @@ -758,7 +799,8 @@ class MCPManager: # If no OAuth provider is provided, check if we have stored OAuth credentials if oauth_provider is None and hasattr(server_config, "server_url"): oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor) - if oauth_session and oauth_session.access_token: + # Check if access token exists by attempting to decrypt it + if oauth_session and oauth_session.get_access_token_secret().get_plaintext(): # Create OAuth provider from stored credentials from letta.services.mcp.oauth_utils import create_oauth_provider @@ -787,8 +829,6 @@ class MCPManager: """ Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields. """ - from letta.settings import settings - # Get decrypted values using the dual-read approach # Secret.from_db() will automatically use settings.encryption_key if available access_token = None @@ -818,7 +858,17 @@ class MCPManager: # No encryption key, use plaintext if available client_secret = oauth_session.client_secret - return MCPOAuthSession( + authorization_code = None + if oauth_session.authorization_code_enc or oauth_session.authorization_code: + if settings.encryption_key: + secret = Secret.from_db(oauth_session.authorization_code_enc, oauth_session.authorization_code) + authorization_code = secret.get_plaintext() + else: + # No encryption key, use plaintext if available + authorization_code = oauth_session.authorization_code + + # Create the Pydantic object with encrypted fields as Secret objects + pydantic_session = MCPOAuthSession( id=oauth_session.id, state=oauth_session.state, server_id=oauth_session.server_id, @@ -827,7 +877,7 @@ class MCPManager: user_id=oauth_session.user_id, organization_id=oauth_session.organization_id, authorization_url=oauth_session.authorization_url, - authorization_code=oauth_session.authorization_code, + authorization_code=authorization_code, access_token=access_token, refresh_token=refresh_token, token_type=oauth_session.token_type, @@ -839,7 +889,15 @@ class MCPManager: status=oauth_session.status, created_at=oauth_session.created_at, updated_at=oauth_session.updated_at, + # Encrypted fields as Secret objects (converted from encrypted strings in DB) + authorization_code_enc=Secret.from_encrypted(oauth_session.authorization_code_enc) + if oauth_session.authorization_code_enc + else None, + access_token_enc=Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None, + refresh_token_enc=Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None, + client_secret_enc=Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None, ) + return pydantic_session @enforce_types async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession: @@ -905,38 +963,57 @@ class MCPManager: # Update fields that are provided if session_update.authorization_url is not None: oauth_session.authorization_url = session_update.authorization_url + + # Handle encryption for authorization_code + # Only re-encrypt if the value has actually changed if session_update.authorization_code is not None: - oauth_session.authorization_code = session_update.authorization_code + # Check if value changed + existing_code = None + if oauth_session.authorization_code_enc: + existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc) + existing_code = existing_secret.get_plaintext() + elif oauth_session.authorization_code: + existing_code = oauth_session.authorization_code + + # Only re-encrypt if different + if existing_code != session_update.authorization_code: + oauth_session.authorization_code_enc = Secret.from_plaintext(session_update.authorization_code).get_encrypted() + # Keep plaintext for dual-write during migration + oauth_session.authorization_code = session_update.authorization_code # Handle encryption for access_token + # Only re-encrypt if the value has actually changed if session_update.access_token is not None: - from letta.settings import settings + # Check if value changed + existing_token = None + if oauth_session.access_token_enc: + existing_secret = Secret.from_encrypted(oauth_session.access_token_enc) + existing_token = existing_secret.get_plaintext() + elif oauth_session.access_token: + existing_token = oauth_session.access_token - if settings.encryption_key: - token_secret = Secret.from_plaintext(session_update.access_token) - secret_dict = token_secret.to_dict() - oauth_session.access_token_enc = secret_dict["encrypted"] - # During migration phase, also update plaintext - oauth_session.access_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None - else: - # No encryption, store plaintext + # Only re-encrypt if different + if existing_token != session_update.access_token: + oauth_session.access_token_enc = Secret.from_plaintext(session_update.access_token).get_encrypted() + # Keep plaintext for dual-write during migration oauth_session.access_token = session_update.access_token - oauth_session.access_token_enc = None # Handle encryption for refresh_token + # Only re-encrypt if the value has actually changed if session_update.refresh_token is not None: - from letta.settings import settings + # Check if value changed + existing_refresh = None + if oauth_session.refresh_token_enc: + existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc) + existing_refresh = existing_secret.get_plaintext() + elif oauth_session.refresh_token: + existing_refresh = oauth_session.refresh_token - if settings.encryption_key: - token_secret = Secret.from_plaintext(session_update.refresh_token) - secret_dict = token_secret.to_dict() - oauth_session.refresh_token_enc = secret_dict["encrypted"] - # During migration phase, also update plaintext - oauth_session.refresh_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None - else: - # No encryption, store plaintext + # Only re-encrypt if different + if existing_refresh != session_update.refresh_token: + oauth_session.refresh_token_enc = Secret.from_plaintext(session_update.refresh_token).get_encrypted() + # Keep plaintext for dual-write during migration oauth_session.refresh_token = session_update.refresh_token - oauth_session.refresh_token_enc = None if session_update.token_type is not None: oauth_session.token_type = session_update.token_type @@ -948,19 +1025,21 @@ class MCPManager: oauth_session.client_id = session_update.client_id # Handle encryption for client_secret + # Only re-encrypt if the value has actually changed if session_update.client_secret is not None: - from letta.settings import settings + # Check if value changed + existing_secret_val = None + if oauth_session.client_secret_enc: + existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc) + existing_secret_val = existing_secret.get_plaintext() + elif oauth_session.client_secret: + existing_secret_val = oauth_session.client_secret - if settings.encryption_key: - secret_secret = Secret.from_plaintext(session_update.client_secret) - secret_dict = secret_secret.to_dict() - oauth_session.client_secret_enc = secret_dict["encrypted"] - # During migration phase, also update plaintext - oauth_session.client_secret = secret_dict["plaintext"] if not secret_secret._was_encrypted else None - else: - # No encryption, store plaintext + # Only re-encrypt if different + if existing_secret_val != session_update.client_secret: + oauth_session.client_secret_enc = Secret.from_plaintext(session_update.client_secret).get_encrypted() + # Keep plaintext for dual-write during migration oauth_session.client_secret = session_update.client_secret - oauth_session.client_secret_enc = None if session_update.redirect_uri is not None: oauth_session.redirect_uri = session_update.redirect_uri diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index ff4e0406..108b9a19 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -4,6 +4,7 @@ from letta.orm.provider import Provider as ProviderModel from letta.otel.tracing import trace_method from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate +from letta.schemas.secret import Secret from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.utils import enforce_types @@ -27,6 +28,12 @@ class ProviderManager: # Lazily create the provider id prior to persistence provider.resolve_identifier() + # Explicitly populate encrypted fields from plaintext + if provider.api_key is not None: + provider.api_key_enc = Secret.from_plaintext(provider.api_key) + if provider.access_key is not None: + provider.access_key_enc = Secret.from_plaintext(provider.access_key) + new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True)) await new_provider.create_async(session, actor=actor) return new_provider.to_pydantic() @@ -43,6 +50,50 @@ class ProviderManager: # Update only the fields that are provided in ProviderUpdate update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + + # Handle encryption for api_key if provided + # Only re-encrypt if the value has actually changed + if "api_key" in update_data and update_data["api_key"] is not None: + # Check if value changed + existing_api_key = None + if existing_provider.api_key_enc: + existing_secret = Secret.from_encrypted(existing_provider.api_key_enc) + existing_api_key = existing_secret.get_plaintext() + elif existing_provider.api_key: + existing_api_key = existing_provider.api_key + + # Only re-encrypt if different + if existing_api_key != update_data["api_key"]: + existing_provider.api_key_enc = Secret.from_plaintext(update_data["api_key"]).get_encrypted() + # Keep plaintext for dual-write during migration + existing_provider.api_key = update_data["api_key"] + + # Remove from update_data since we set directly on existing_provider + update_data.pop("api_key", None) + update_data.pop("api_key_enc", None) + + # Handle encryption for access_key if provided + # Only re-encrypt if the value has actually changed + if "access_key" in update_data and update_data["access_key"] is not None: + # Check if value changed + existing_access_key = None + if existing_provider.access_key_enc: + existing_secret = Secret.from_encrypted(existing_provider.access_key_enc) + existing_access_key = existing_secret.get_plaintext() + elif existing_provider.access_key: + existing_access_key = existing_provider.access_key + + # Only re-encrypt if different + if existing_access_key != update_data["access_key"]: + existing_provider.access_key_enc = Secret.from_plaintext(update_data["access_key"]).get_encrypted() + # Keep plaintext for dual-write during migration + existing_provider.access_key = update_data["access_key"] + + # Remove from update_data since we set directly on existing_provider + update_data.pop("access_key", None) + update_data.pop("access_key_enc", None) + + # Apply remaining updates for key, value in update_data.items(): setattr(existing_provider, key, value) @@ -117,13 +168,21 @@ class ProviderManager: @trace_method def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]: providers = self.list_providers(name=provider_name, actor=actor) - return providers[0].api_key if providers else None + if providers: + # Decrypt the API key before returning + api_key_secret = providers[0].get_api_key_secret() + return api_key_secret.get_plaintext() + return None @enforce_types @trace_method async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]: providers = await self.list_providers_async(name=provider_name, actor=actor) - return providers[0].api_key if providers else None + if providers: + # Decrypt the API key before returning + api_key_secret = providers[0].get_api_key_secret() + return api_key_secret.get_plaintext() + return None @enforce_types @trace_method @@ -131,10 +190,15 @@ class ProviderManager: self, provider_name: Union[str, None], actor: PydanticUser ) -> Tuple[Optional[str], Optional[str], Optional[str]]: providers = await self.list_providers_async(name=provider_name, actor=actor) - access_key = providers[0].access_key if providers else None - secret_key = providers[0].api_key if providers else None - region = providers[0].region if providers else None - return access_key, secret_key, region + if providers: + # Decrypt the credentials before returning + access_key_secret = providers[0].get_access_key_secret() + api_key_secret = providers[0].get_api_key_secret() + access_key = access_key_secret.get_plaintext() + secret_key = api_key_secret.get_plaintext() + region = providers[0].region + return access_key, secret_key, region + return None, None, None @enforce_types @trace_method @@ -142,10 +206,14 @@ class ProviderManager: self, provider_name: Union[str, None], actor: PydanticUser ) -> Tuple[Optional[str], Optional[str], Optional[str]]: providers = self.list_providers(name=provider_name, actor=actor) - api_key = providers[0].api_key if providers else None - base_url = providers[0].base_url if providers else None - api_version = providers[0].api_version if providers else None - return api_key, base_url, api_version + if providers: + # Decrypt the API key before returning + api_key_secret = providers[0].get_api_key_secret() + api_key = api_key_secret.get_plaintext() + base_url = providers[0].base_url + api_version = providers[0].api_version + return api_key, base_url, api_version + return None, None, None @enforce_types @trace_method @@ -153,10 +221,14 @@ class ProviderManager: self, provider_name: Union[str, None], actor: PydanticUser ) -> Tuple[Optional[str], Optional[str], Optional[str]]: providers = await self.list_providers_async(name=provider_name, actor=actor) - api_key = providers[0].api_key if providers else None - base_url = providers[0].base_url if providers else None - api_version = providers[0].api_version if providers else None - return api_key, base_url, api_version + if providers: + # Decrypt the API key before returning + api_key_secret = providers[0].get_api_key_secret() + api_key = api_key_secret.get_plaintext() + base_url = providers[0].base_url + api_version = providers[0].api_version + return api_key, base_url, api_version + return None, None, None @enforce_types @trace_method diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index 8bc67824..ad4a7563 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -198,6 +198,12 @@ class SandboxConfigManager: return db_env_var else: async with db_registry.async_session() as session: + # Explicitly encrypt the value before storing + from letta.schemas.secret import Secret + + if env_var.value is not None: + env_var.value_enc = Secret.from_plaintext(env_var.value) + env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True)) await env_var.create_async(session, actor=actor) return env_var.to_pydantic() @@ -211,6 +217,31 @@ class SandboxConfigManager: async with db_registry.async_session() as session: env_var = await SandboxEnvVarModel.read_async(db_session=session, identifier=env_var_id, actor=actor) update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + + # Handle encryption for value if provided + # Only re-encrypt if the value has actually changed + if "value" in update_data and update_data["value"] is not None: + from letta.schemas.secret import Secret + + # Check if value changed + existing_value = None + if env_var.value_enc: + existing_secret = Secret.from_encrypted(env_var.value_enc) + existing_value = existing_secret.get_plaintext() + elif env_var.value: + existing_value = env_var.value + + # Only re-encrypt if different + if existing_value != update_data["value"]: + env_var.value_enc = Secret.from_plaintext(update_data["value"]).get_encrypted() + # Keep plaintext for dual-write during migration + env_var.value = update_data["value"] + + # Remove from update_data since we set directly on env_var + update_data.pop("value", None) + update_data.pop("value_enc", None) + + # Apply remaining updates update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value} if update_data: @@ -277,7 +308,9 @@ class SandboxConfigManager: env_vars = self.list_sandbox_env_vars(sandbox_config_id, actor, after, limit) result = {} for env_var in env_vars: - result[env_var.key] = env_var.value + # Decrypt the value before returning + value_secret = env_var.get_value_secret() + result[env_var.key] = value_secret.get_plaintext() return result @enforce_types @@ -286,7 +319,8 @@ class SandboxConfigManager: self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 ) -> Dict[str, str]: env_vars = await self.list_sandbox_env_vars_async(sandbox_config_id, actor, after, limit) - return {env_var.key: env_var.value for env_var in env_vars} + # Decrypt values before returning + return {env_var.key: env_var.get_value_secret().get_plaintext() for env_var in env_vars} @enforce_types @trace_method diff --git a/tests/managers/test_agent_manager.py b/tests/managers/test_agent_manager.py index 14972440..94c7cb92 100644 --- a/tests/managers/test_agent_manager.py +++ b/tests/managers/test_agent_manager.py @@ -853,3 +853,170 @@ async def test_list_agents_ordering_and_pagination(server: SyncServer, default_u before_alpha_desc = await server.agent_manager.list_agents_async(actor=default_user, before=agent_ids["alpha_agent"], ascending=False) before_names_desc = [a.name for a in before_alpha_desc] assert before_names_desc == ["gamma_agent", "beta_agent"] + + +# ====================================================================================================================== +# AgentManager Tests - Environment Variable Encryption +# ====================================================================================================================== + + +@pytest.fixture +def encryption_key(): + """Fixture to ensure encryption key is set for tests.""" + original_key = settings.encryption_key + # Set a test encryption key if not already set + if not settings.encryption_key: + settings.encryption_key = "test-encryption-key-32-bytes!!" + yield settings.encryption_key + # Restore original + settings.encryption_key = original_key + + +@pytest.mark.asyncio +async def test_agent_environment_variables_encrypt_on_create(server: SyncServer, default_user, encryption_key): + """Test that creating an agent with secrets encrypts the values in the database.""" + from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel + from letta.schemas.secret import Secret + + # Create agent with secrets + agent_create = CreateAgent( + name="test-agent-with-secrets", + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=DEFAULT_EMBEDDING_CONFIG, + include_base_tools=False, + secrets={ + "API_KEY": "sk-test-secret-12345", + "DATABASE_URL": "postgres://user:pass@localhost/db", + }, + ) + + created_agent = await server.agent_manager.create_agent_async(agent_create, actor=default_user) + + # Verify agent has secrets + assert created_agent.secrets is not None + assert len(created_agent.secrets) == 2 + + # Verify secrets are AgentEnvironmentVariable objects with Secret fields + for secret_obj in created_agent.secrets: + assert secret_obj.key in ["API_KEY", "DATABASE_URL"] + assert secret_obj.value_enc is not None + assert isinstance(secret_obj.value_enc, Secret) + + # Verify values are encrypted in the database + async with db_registry.async_session() as session: + env_vars = await session.execute( + select(AgentEnvironmentVariableModel).where(AgentEnvironmentVariableModel.agent_id == created_agent.id) + ) + env_var_list = list(env_vars.scalars().all()) + + assert len(env_var_list) == 2 + for env_var in env_var_list: + # Check that value_enc is not None and is encrypted + assert env_var.value_enc is not None + assert isinstance(env_var.value_enc, str) + + # Decrypt and verify + decrypted = Secret.from_encrypted(env_var.value_enc).get_plaintext() + if env_var.key == "API_KEY": + assert decrypted == "sk-test-secret-12345" + elif env_var.key == "DATABASE_URL": + assert decrypted == "postgres://user:pass@localhost/db" + + +@pytest.mark.asyncio +async def test_agent_environment_variables_decrypt_on_read(server: SyncServer, default_user, encryption_key): + """Test that reading an agent deserializes secrets correctly to AgentEnvironmentVariable objects.""" + from letta.schemas.environment_variables import AgentEnvironmentVariable + from letta.schemas.secret import Secret + + # Create agent with secrets + agent_create = CreateAgent( + name="test-agent-read-secrets", + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=DEFAULT_EMBEDDING_CONFIG, + include_base_tools=False, + secrets={ + "TEST_KEY": "test-value-67890", + }, + ) + + created_agent = await server.agent_manager.create_agent_async(agent_create, actor=default_user) + agent_id = created_agent.id + + # Read the agent back + retrieved_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=default_user) + + # Verify secrets are properly deserialized + assert retrieved_agent.secrets is not None + assert len(retrieved_agent.secrets) == 1 + + secret_obj = retrieved_agent.secrets[0] + assert isinstance(secret_obj, AgentEnvironmentVariable) + assert secret_obj.key == "TEST_KEY" + assert secret_obj.value == "test-value-67890" + + # Verify value_enc is a Secret object (not a string) + assert secret_obj.value_enc is not None + assert isinstance(secret_obj.value_enc, Secret) + + # Verify we can decrypt through the Secret object + decrypted = secret_obj.value_enc.get_plaintext() + assert decrypted == "test-value-67890" + + # Verify get_value_secret() method works + value_secret = secret_obj.get_value_secret() + assert isinstance(value_secret, Secret) + assert value_secret.get_plaintext() == "test-value-67890" + + +@pytest.mark.asyncio +async def test_agent_environment_variables_update_encryption(server: SyncServer, default_user, encryption_key): + """Test that updating agent secrets encrypts new values.""" + from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel + from letta.schemas.secret import Secret + + # Create agent with initial secrets + agent_create = CreateAgent( + name="test-agent-update-secrets", + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=DEFAULT_EMBEDDING_CONFIG, + include_base_tools=False, + secrets={ + "INITIAL_KEY": "initial-value", + }, + ) + + created_agent = await server.agent_manager.create_agent_async(agent_create, actor=default_user) + agent_id = created_agent.id + + # Update with new secrets + agent_update = UpdateAgent( + secrets={ + "UPDATED_KEY": "updated-value-abc", + "NEW_KEY": "new-value-xyz", + }, + ) + + updated_agent = await server.agent_manager.update_agent_async(agent_id=agent_id, agent_update=agent_update, actor=default_user) + + # Verify updated secrets + assert updated_agent.secrets is not None + assert len(updated_agent.secrets) == 2 + + # Verify in database + async with db_registry.async_session() as session: + env_vars = await session.execute(select(AgentEnvironmentVariableModel).where(AgentEnvironmentVariableModel.agent_id == agent_id)) + env_var_list = list(env_vars.scalars().all()) + + assert len(env_var_list) == 2 + for env_var in env_var_list: + assert env_var.value_enc is not None + + # Decrypt and verify + decrypted = Secret.from_encrypted(env_var.value_enc).get_plaintext() + if env_var.key == "UPDATED_KEY": + assert decrypted == "updated-value-abc" + elif env_var.key == "NEW_KEY": + assert decrypted == "new-value-xyz" + else: + pytest.fail(f"Unexpected key: {env_var.key}") diff --git a/tests/managers/test_mcp_manager.py b/tests/managers/test_mcp_manager.py index 2aad82c0..e4bb9efc 100644 --- a/tests/managers/test_mcp_manager.py +++ b/tests/managers/test_mcp_manager.py @@ -901,3 +901,237 @@ async def test_mcp_server_resync_tools(server, default_user, default_organizatio finally: # Clean up await server.mcp_manager.delete_mcp_server_by_id(mcp_server_id, actor=default_user) + + +# ====================================================================================================================== +# MCPManager Tests - Encryption +# ====================================================================================================================== + + +@pytest.fixture +def encryption_key(): + """Fixture to ensure encryption key is set for tests.""" + original_key = settings.encryption_key + # Set a test encryption key if not already set + if not settings.encryption_key: + settings.encryption_key = "test-encryption-key-32-bytes!!" + yield settings.encryption_key + # Restore original + settings.encryption_key = original_key + + +@pytest.mark.asyncio +async def test_mcp_server_token_encryption_on_create(server, default_user, encryption_key): + """Test that creating an MCP server encrypts the token in the database.""" + from letta.functions.mcp_client.types import MCPServerType + from letta.orm.mcp_server import MCPServer as MCPServerModel + from letta.schemas.mcp import MCPServer + from letta.schemas.secret import Secret + + # Create MCP server with token + mcp_server = MCPServer( + server_name="test-encrypted-server", + server_type=MCPServerType.STREAMABLE_HTTP, + server_url="https://api.example.com/mcp", + token="sk-test-secret-token-12345", + ) + + created_server = await server.mcp_manager.create_mcp_server(mcp_server, actor=default_user) + + try: + # Verify server was created + assert created_server is not None + assert created_server.server_name == "test-encrypted-server" + + # Verify plaintext token is accessible (dual-write during migration) + assert created_server.token == "sk-test-secret-token-12345" + + # Verify token_enc is a Secret object + assert created_server.token_enc is not None + assert isinstance(created_server.token_enc, Secret) + + # Read directly from database to verify encryption + async with db_registry.async_session() as session: + server_orm = await MCPServerModel.read_async( + db_session=session, + identifier=created_server.id, + actor=default_user, + ) + + # Verify plaintext column has the value (dual-write) + assert server_orm.token == "sk-test-secret-token-12345" + + # Verify encrypted column is populated and different from plaintext + assert server_orm.token_enc is not None + assert server_orm.token_enc != "sk-test-secret-token-12345" + # Encrypted value should be longer + assert len(server_orm.token_enc) > len("sk-test-secret-token-12345") + + finally: + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + +@pytest.mark.asyncio +async def test_mcp_server_token_decryption_on_read(server, default_user, encryption_key): + """Test that reading an MCP server decrypts the token correctly.""" + from letta.functions.mcp_client.types import MCPServerType + from letta.schemas.mcp import MCPServer + from letta.schemas.secret import Secret + + # Create MCP server + mcp_server = MCPServer( + server_name="test-decrypt-server", + server_type=MCPServerType.STREAMABLE_HTTP, + server_url="https://api.example.com/mcp", + token="sk-test-decrypt-token-67890", + ) + + created_server = await server.mcp_manager.create_mcp_server(mcp_server, actor=default_user) + server_id = created_server.id + + try: + # Read the server back + retrieved_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user) + + # Verify the token is decrypted correctly + assert retrieved_server.token == "sk-test-decrypt-token-67890" + + # Verify we can get the decrypted token through the secret getter + token_secret = retrieved_server.get_token_secret() + assert isinstance(token_secret, Secret) + decrypted_token = token_secret.get_plaintext() + assert decrypted_token == "sk-test-decrypt-token-67890" + + finally: + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(server_id, actor=default_user) + + +@pytest.mark.asyncio +async def test_mcp_server_custom_headers_encryption(server, default_user, encryption_key): + """Test that custom headers are encrypted as JSON strings.""" + from letta.functions.mcp_client.types import MCPServerType + from letta.orm.mcp_server import MCPServer as MCPServerModel + from letta.schemas.mcp import MCPServer + from letta.schemas.secret import Secret + + # Create MCP server with custom headers + custom_headers = {"Authorization": "Bearer token123", "X-API-Key": "secret-key-456"} + mcp_server = MCPServer( + server_name="test-headers-server", + server_type=MCPServerType.STREAMABLE_HTTP, + server_url="https://api.example.com/mcp", + custom_headers=custom_headers, + ) + + created_server = await server.mcp_manager.create_mcp_server(mcp_server, actor=default_user) + + try: + # Verify custom_headers are accessible + assert created_server.custom_headers == custom_headers + + # Verify custom_headers_enc is a Secret object (stores JSON string) + assert created_server.custom_headers_enc is not None + assert isinstance(created_server.custom_headers_enc, Secret) + + # Verify the getter method returns a Secret (JSON string) + headers_secret = created_server.get_custom_headers_secret() + assert isinstance(headers_secret, Secret) + # Verify the Secret contains JSON string + json_str = headers_secret.get_plaintext() + assert json_str is not None + import json + + assert json.loads(json_str) == custom_headers + + # Verify the convenience method returns dict directly + headers_dict = created_server.get_custom_headers_dict() + assert headers_dict == custom_headers + + # Read from DB to verify encryption + async with db_registry.async_session() as session: + server_orm = await MCPServerModel.read_async( + db_session=session, + identifier=created_server.id, + actor=default_user, + ) + + # Verify encrypted column contains encrypted JSON string + assert server_orm.custom_headers_enc is not None + # Decrypt and verify it's valid JSON matching original headers + decrypted_json = Secret.from_encrypted(server_orm.custom_headers_enc).get_plaintext() + import json + + decrypted_headers = json.loads(decrypted_json) + assert decrypted_headers == custom_headers + + finally: + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + +@pytest.mark.asyncio +async def test_oauth_session_tokens_encryption(server, default_user, encryption_key): + """Test that OAuth session tokens are encrypted in the database.""" + from letta.orm.mcp_oauth import MCPOAuth as MCPOAuthModel + from letta.schemas.mcp import MCPOAuthSessionCreate, MCPOAuthSessionUpdate + from letta.schemas.secret import Secret + + # Create OAuth session + session_create = MCPOAuthSessionCreate( + server_url="https://oauth.example.com", + server_name="test-oauth-server", + organization_id=default_user.organization_id, + user_id=default_user.id, + ) + + created_session = await server.mcp_manager.create_oauth_session(session_create, actor=default_user) + session_id = created_session.id + + try: + # Update with OAuth tokens + session_update = MCPOAuthSessionUpdate( + access_token="access-token-abc123", + refresh_token="refresh-token-xyz789", + client_secret="client-secret-def456", + authorization_code="auth-code-ghi012", + ) + + updated_session = await server.mcp_manager.update_oauth_session(session_id, session_update, actor=default_user) + + # Verify tokens are accessible + assert updated_session.access_token == "access-token-abc123" + assert updated_session.refresh_token == "refresh-token-xyz789" + assert updated_session.client_secret == "client-secret-def456" + assert updated_session.authorization_code == "auth-code-ghi012" + + # Verify encrypted fields are Secret objects + assert isinstance(updated_session.access_token_enc, Secret) + assert isinstance(updated_session.refresh_token_enc, Secret) + assert isinstance(updated_session.client_secret_enc, Secret) + assert isinstance(updated_session.authorization_code_enc, Secret) + + # Read from DB to verify all tokens are encrypted + async with db_registry.async_session() as session: + oauth_orm = await MCPOAuthModel.read_async( + db_session=session, + identifier=session_id, + actor=default_user, + ) + + # Verify all encrypted columns are populated and encrypted + assert oauth_orm.access_token_enc is not None + assert oauth_orm.refresh_token_enc is not None + assert oauth_orm.client_secret_enc is not None + assert oauth_orm.authorization_code_enc is not None + + # Decrypt and verify + assert Secret.from_encrypted(oauth_orm.access_token_enc).get_plaintext() == "access-token-abc123" + assert Secret.from_encrypted(oauth_orm.refresh_token_enc).get_plaintext() == "refresh-token-xyz789" + assert Secret.from_encrypted(oauth_orm.client_secret_enc).get_plaintext() == "client-secret-def456" + assert Secret.from_encrypted(oauth_orm.authorization_code_enc).get_plaintext() == "auth-code-ghi012" + + finally: + # Clean up + await server.mcp_manager.delete_oauth_session(session_id, actor=default_user) diff --git a/tests/managers/test_provider_manager.py b/tests/managers/test_provider_manager.py new file mode 100644 index 00000000..bdc3082c --- /dev/null +++ b/tests/managers/test_provider_manager.py @@ -0,0 +1,322 @@ +"""Tests for ProviderManager encryption/decryption logic.""" + +import os + +import pytest + +from letta.orm.provider import Provider as ProviderModel +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate +from letta.schemas.secret import Secret +from letta.server.db import db_registry +from letta.services.organization_manager import OrganizationManager +from letta.services.provider_manager import ProviderManager +from letta.services.user_manager import UserManager +from letta.settings import settings + + +@pytest.fixture +async def default_organization(): + """Fixture to create and return the default organization.""" + manager = OrganizationManager() + org = await manager.create_default_organization_async() + yield org + + +@pytest.fixture +async def default_user(default_organization): + """Fixture to create and return the default user within the default organization.""" + manager = UserManager() + user = await manager.create_default_actor_async(org_id=default_organization.id) + yield user + + +@pytest.fixture +async def provider_manager(): + """Fixture to create and return a ProviderManager instance.""" + return ProviderManager() + + +@pytest.fixture +def encryption_key(): + """Fixture to ensure encryption key is set for tests.""" + original_key = settings.encryption_key + # Set a test encryption key if not already set + if not settings.encryption_key: + settings.encryption_key = "test-encryption-key-32-bytes!!" + yield settings.encryption_key + # Restore original + settings.encryption_key = original_key + + +# ====================================================================================================================== +# Provider Encryption Tests +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_provider_create_encrypts_api_key(provider_manager, default_user, encryption_key): + """Test that creating a provider encrypts the api_key and stores it in api_key_enc.""" + # Create a provider with plaintext api_key + provider_create = ProviderCreate( + name="test-openai-provider", + provider_type=ProviderType.openai, + api_key="sk-test-plaintext-api-key-12345", + base_url="https://api.openai.com/v1", + ) + + # Create provider through manager + created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + + # Verify provider was created + assert created_provider is not None + assert created_provider.name == "test-openai-provider" + assert created_provider.provider_type == ProviderType.openai + + # Verify plaintext api_key is still accessible (dual-write during migration) + assert created_provider.api_key == "sk-test-plaintext-api-key-12345" + + # Read directly from database to verify encryption + async with db_registry.async_session() as session: + provider_orm = await ProviderModel.read_async( + db_session=session, + identifier=created_provider.id, + actor=default_user, + ) + + # Verify plaintext column has the value (dual-write) + assert provider_orm.api_key == "sk-test-plaintext-api-key-12345" + + # Verify encrypted column is populated and different from plaintext + assert provider_orm.api_key_enc is not None + assert provider_orm.api_key_enc != "sk-test-plaintext-api-key-12345" + # Encrypted value should be base64-encoded and longer + assert len(provider_orm.api_key_enc) > len("sk-test-plaintext-api-key-12345") + + +@pytest.mark.asyncio +async def test_provider_read_decrypts_api_key(provider_manager, default_user, encryption_key): + """Test that reading a provider decrypts the api_key from api_key_enc.""" + # Create a provider + provider_create = ProviderCreate( + name="test-anthropic-provider", + provider_type=ProviderType.anthropic, + api_key="sk-ant-test-key-67890", + ) + + created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + provider_id = created_provider.id + + # Read the provider back + retrieved_provider = await provider_manager.get_provider_async(provider_id, actor=default_user) + + # Verify the api_key is decrypted correctly + assert retrieved_provider.api_key == "sk-ant-test-key-67890" + + # Verify we can get the decrypted key through the secret getter + api_key_secret = retrieved_provider.get_api_key_secret() + assert isinstance(api_key_secret, Secret) + decrypted_key = api_key_secret.get_plaintext() + assert decrypted_key == "sk-ant-test-key-67890" + + +@pytest.mark.asyncio +async def test_provider_update_encrypts_new_api_key(provider_manager, default_user, encryption_key): + """Test that updating a provider's api_key encrypts the new value.""" + # Create initial provider + provider_create = ProviderCreate( + name="test-groq-provider", + provider_type=ProviderType.groq, + api_key="gsk-initial-key-123", + ) + + created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + provider_id = created_provider.id + + # Update the api_key + provider_update = ProviderUpdate( + api_key="gsk-updated-key-456", + ) + + updated_provider = await provider_manager.update_provider_async(provider_id, provider_update, actor=default_user) + + # Verify the updated key is accessible + assert updated_provider.api_key == "gsk-updated-key-456" + + # Read from DB to verify new encrypted value + async with db_registry.async_session() as session: + provider_orm = await ProviderModel.read_async( + db_session=session, + identifier=provider_id, + actor=default_user, + ) + + # Verify both columns are updated + assert provider_orm.api_key == "gsk-updated-key-456" + assert provider_orm.api_key_enc is not None + + # Decrypt and verify + decrypted = Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() + assert decrypted == "gsk-updated-key-456" + + +@pytest.mark.asyncio +async def test_bedrock_credentials_encryption(provider_manager, default_user, encryption_key): + """Test that Bedrock provider encrypts both access_key and api_key (secret_key).""" + # Create Bedrock provider with both keys + provider_create = ProviderCreate( + name="test-bedrock-provider", + provider_type=ProviderType.bedrock, + api_key="secret-access-key-xyz", # This is the secret key + access_key="access-key-id-abc", # This is the access key ID + region="us-east-1", + ) + + created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + + # Verify both keys are accessible + assert created_provider.api_key == "secret-access-key-xyz" + assert created_provider.access_key == "access-key-id-abc" + + # Read from DB to verify both are encrypted + async with db_registry.async_session() as session: + provider_orm = await ProviderModel.read_async( + db_session=session, + identifier=created_provider.id, + actor=default_user, + ) + + # Verify both encrypted columns are populated + assert provider_orm.api_key_enc is not None + assert provider_orm.access_key_enc is not None + + # Verify encrypted values are different from plaintext + assert provider_orm.api_key_enc != "secret-access-key-xyz" + assert provider_orm.access_key_enc != "access-key-id-abc" + + # Test the manager method for getting Bedrock credentials + access_key, secret_key, region = await provider_manager.get_bedrock_credentials_async("test-bedrock-provider", actor=default_user) + + assert access_key == "access-key-id-abc" + assert secret_key == "secret-access-key-xyz" + assert region == "us-east-1" + + +@pytest.mark.asyncio +async def test_provider_secret_not_exposed_in_logs(provider_manager, default_user, encryption_key): + """Test that Secret objects don't expose plaintext in string representations.""" + # Create a provider + provider_create = ProviderCreate( + name="test-secret-provider", + provider_type=ProviderType.openai, + api_key="sk-very-secret-key-do-not-log", + ) + + created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + + # Get the Secret object + api_key_secret = created_provider.get_api_key_secret() + + # Verify string representation doesn't expose the key + secret_str = str(api_key_secret) + secret_repr = repr(api_key_secret) + + assert "sk-very-secret-key-do-not-log" not in secret_str + assert "sk-very-secret-key-do-not-log" not in secret_repr + assert "****" in secret_str or "Secret" in secret_str + assert "****" in secret_repr or "Secret" in secret_repr + + +@pytest.mark.asyncio +async def test_provider_pydantic_to_orm_serialization(provider_manager, default_user, encryption_key): + """Test the full Pydantic → ORM → Pydantic round-trip maintains data integrity.""" + # Create a provider through the normal flow + provider_create = ProviderCreate( + name="test-roundtrip-provider", + provider_type=ProviderType.openai, + api_key="sk-roundtrip-test-key-999", + base_url="https://api.openai.com/v1", + ) + + # Step 1: Create provider (Pydantic → ORM) + created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + original_api_key = created_provider.api_key + + # Step 2: Read provider back (ORM → Pydantic) + retrieved_provider = await provider_manager.get_provider_async(created_provider.id, actor=default_user) + + # Verify data integrity + assert retrieved_provider.api_key == original_api_key + assert retrieved_provider.name == "test-roundtrip-provider" + assert retrieved_provider.provider_type == ProviderType.openai + assert retrieved_provider.base_url == "https://api.openai.com/v1" + + # Verify Secret object works correctly + api_key_secret = retrieved_provider.get_api_key_secret() + assert api_key_secret.get_plaintext() == original_api_key + + # Step 3: Convert to ORM again (should preserve encrypted field) + orm_data = retrieved_provider.model_dump(to_orm=True) + + # Verify encrypted field is in the ORM data + assert "api_key_enc" in orm_data + assert orm_data["api_key_enc"] is not None + assert orm_data["api_key"] == original_api_key + + +@pytest.mark.asyncio +async def test_provider_with_none_api_key(provider_manager, default_user, encryption_key): + """Test that providers can be created with None api_key (some providers may not need it).""" + # Create a provider without an api_key + provider_create = ProviderCreate( + name="test-no-key-provider", + provider_type=ProviderType.ollama, + api_key="", # Empty string + base_url="http://localhost:11434", + ) + + created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + + # Verify provider was created + assert created_provider is not None + assert created_provider.name == "test-no-key-provider" + + # Read from DB + async with db_registry.async_session() as session: + provider_orm = await ProviderModel.read_async( + db_session=session, + identifier=created_provider.id, + actor=default_user, + ) + + # api_key_enc should handle empty string appropriately + # (encrypt empty string or store as None) + assert provider_orm.api_key_enc is not None or provider_orm.api_key == "" + + +@pytest.mark.asyncio +async def test_list_providers_decrypts_all(provider_manager, default_user, encryption_key): + """Test that listing multiple providers decrypts all their api_keys correctly.""" + # Create multiple providers + providers_to_create = [ + ProviderCreate(name=f"test-provider-{i}", provider_type=ProviderType.openai, api_key=f"sk-key-{i}") for i in range(3) + ] + + created_ids = [] + for provider_create in providers_to_create: + provider = await provider_manager.create_provider_async(provider_create, actor=default_user) + created_ids.append(provider.id) + + # List all providers + all_providers = await provider_manager.list_providers_async(actor=default_user) + + # Filter to our test providers + test_providers = [p for p in all_providers if p.id in created_ids] + + # Verify all are decrypted correctly + assert len(test_providers) == 3 + for i, provider in enumerate(sorted(test_providers, key=lambda p: p.name)): + assert provider.api_key == f"sk-key-{i}" + # Verify Secret getter works + secret = provider.get_api_key_secret() + assert secret.get_plaintext() == f"sk-key-{i}" diff --git a/tests/test_mcp_encryption.py b/tests/test_mcp_encryption.py index cec05c3a..f5a5b989 100644 --- a/tests/test_mcp_encryption.py +++ b/tests/test_mcp_encryption.py @@ -23,7 +23,7 @@ from letta.schemas.mcp import ( SSEServerConfig, StdioServerConfig, ) -from letta.schemas.secret import Secret, SecretDict +from letta.schemas.secret import Secret from letta.server.db import db_registry from letta.server.server import SyncServer from letta.services.mcp_manager import MCPManager diff --git a/tests/test_secret.py b/tests/test_secret.py index 24cdce00..b4151875 100644 --- a/tests/test_secret.py +++ b/tests/test_secret.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch import pytest from letta.helpers.crypto_utils import CryptoUtils -from letta.schemas.secret import Secret, SecretDict +from letta.schemas.secret import Secret class TestSecret: @@ -26,9 +26,9 @@ class TestSecret: secret = Secret.from_plaintext(plaintext) # Should store encrypted value - assert secret._encrypted_value is not None - assert secret._encrypted_value != plaintext - assert secret._was_encrypted is False + assert secret.encrypted_value is not None + assert secret.encrypted_value != plaintext + assert secret.was_encrypted is False # Should decrypt to original value assert secret.get_plaintext() == plaintext @@ -50,9 +50,9 @@ class TestSecret: secret = Secret.from_plaintext(plaintext) # Should store the plaintext value - assert secret._encrypted_value == plaintext + assert secret.encrypted_value == plaintext assert secret.get_plaintext() == plaintext - assert not secret._was_encrypted + assert not secret.was_encrypted finally: settings.encryption_key = original_key @@ -60,8 +60,8 @@ class TestSecret: """Test creating a Secret from None value.""" secret = Secret.from_plaintext(None) - assert secret._encrypted_value is None - assert secret._was_encrypted is False + assert secret.encrypted_value is None + assert secret.was_encrypted is False assert secret.get_plaintext() is None assert secret.is_empty() is True @@ -78,8 +78,8 @@ class TestSecret: secret = Secret.from_encrypted(encrypted) - assert secret._encrypted_value == encrypted - assert secret._was_encrypted is True + assert secret.encrypted_value == encrypted + assert secret.was_encrypted is True assert secret.get_plaintext() == plaintext finally: settings.encryption_key = original_key @@ -97,8 +97,8 @@ class TestSecret: secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=None) - assert secret._encrypted_value == encrypted - assert secret._was_encrypted is True + assert secret.encrypted_value == encrypted + assert secret.was_encrypted is True assert secret.get_plaintext() == plaintext finally: settings.encryption_key = original_key @@ -117,8 +117,8 @@ class TestSecret: secret = Secret.from_db(encrypted_value=None, plaintext_value=plaintext) # Should encrypt the plaintext - assert secret._encrypted_value is not None - assert secret._was_encrypted is False + assert secret.encrypted_value is not None + assert secret.was_encrypted is False assert secret.get_plaintext() == plaintext finally: settings.encryption_key = original_key @@ -278,354 +278,3 @@ class TestSecret: assert mock_decrypt.call_count == 1 finally: settings.encryption_key = original_key - - -class TestSecretDict: - """Test suite for SecretDict wrapper class.""" - - MOCK_KEY = "test-secretdict-key-1234567890" - - def test_from_plaintext_dict(self): - """Test creating a SecretDict from plaintext dictionary.""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = self.MOCK_KEY - - try: - plaintext_dict = {"api_key": "sk-1234567890", "api_secret": "secret-value", "nested": {"token": "bearer-token"}} - - secret_dict = SecretDict.from_plaintext(plaintext_dict) - - # Should store encrypted JSON - assert secret_dict._encrypted_value is not None - - # Should decrypt to original dict - assert secret_dict.get_plaintext() == plaintext_dict - finally: - settings.encryption_key = original_key - - def test_from_plaintext_none(self): - """Test creating a SecretDict from None value.""" - secret_dict = SecretDict.from_plaintext(None) - - assert secret_dict._encrypted_value is None - assert secret_dict.get_plaintext() is None - - def test_from_encrypted_with_json(self): - """Test creating a SecretDict from encrypted JSON.""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = self.MOCK_KEY - - try: - plaintext_dict = {"header1": "value1", "Authorization": "Bearer token123"} - - json_str = json.dumps(plaintext_dict) - encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY) - - secret_dict = SecretDict.from_encrypted(encrypted) - - assert secret_dict._encrypted_value == encrypted - assert secret_dict.get_plaintext() == plaintext_dict - finally: - settings.encryption_key = original_key - - def test_from_db_with_encrypted(self): - """Test creating SecretDict from database with encrypted value.""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = self.MOCK_KEY - - try: - plaintext_dict = {"key": "value"} - json_str = json.dumps(plaintext_dict) - encrypted = CryptoUtils.encrypt(json_str, self.MOCK_KEY) - - secret_dict = SecretDict.from_db(encrypted_value=encrypted, plaintext_value=None) - - assert secret_dict.get_plaintext() == plaintext_dict - finally: - settings.encryption_key = original_key - - def test_from_db_with_plaintext_json(self): - """Test creating SecretDict from database with plaintext JSON (backward compatibility).""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = self.MOCK_KEY - - try: - plaintext_dict = {"legacy": "headers"} - - # from_db expects a Dict, not a JSON string - secret_dict = SecretDict.from_db(encrypted_value=None, plaintext_value=plaintext_dict) - - assert secret_dict.get_plaintext() == plaintext_dict - # Should have encrypted it - assert secret_dict._encrypted_value is not None - finally: - settings.encryption_key = original_key - - def test_complex_nested_structure(self): - """Test SecretDict with complex nested structures.""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = self.MOCK_KEY - - try: - complex_dict = { - "level1": {"level2": {"level3": ["item1", "item2"], "secret": "nested-secret"}, "array": [1, 2, {"nested": "value"}]}, - "simple": "value", - "number": 42, - "boolean": True, - "null": None, - } - - secret_dict = SecretDict.from_plaintext(complex_dict) - decrypted = secret_dict.get_plaintext() - - assert decrypted == complex_dict - finally: - settings.encryption_key = original_key - - def test_empty_dict(self): - """Test SecretDict with empty dictionary.""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = self.MOCK_KEY - - try: - empty_dict = {} - - secret_dict = SecretDict.from_plaintext(empty_dict) - assert secret_dict.get_plaintext() == empty_dict - - # Encrypted value should still be created - encrypted = secret_dict.get_encrypted() - assert encrypted is not None - finally: - settings.encryption_key = original_key - - def test_dual_read_prefer_encrypted(self): - """Test that SecretDict prefers encrypted value over plaintext when both exist.""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = self.MOCK_KEY - - try: - new_dict = {"current": "value"} - old_dict = {"legacy": "value"} - - encrypted = CryptoUtils.encrypt(json.dumps(new_dict), self.MOCK_KEY) - plaintext = json.dumps(old_dict) - - secret_dict = SecretDict.from_db(encrypted_value=encrypted, plaintext_value=plaintext) - - # Should use encrypted value, not plaintext - assert secret_dict.get_plaintext() == new_dict - finally: - settings.encryption_key = original_key - - def test_plaintext_dict_caching(self): - """Test that plaintext dictionary values are cached after first decryption.""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = self.MOCK_KEY - - try: - plaintext_dict = {"key1": "value1", "key2": "value2", "nested": {"inner": "value"}} - secret_dict = SecretDict.from_plaintext(plaintext_dict) - - # First call should decrypt and cache - result1 = secret_dict.get_plaintext() - assert result1 == plaintext_dict - assert secret_dict._plaintext_cache == plaintext_dict - - # Second call should use cache - result2 = secret_dict.get_plaintext() - assert result2 == plaintext_dict - assert result1 is result2 # Should be the same object reference - finally: - settings.encryption_key = original_key - - def test_dict_caching_only_decrypts_once(self): - """Test that SecretDict decryption only happens once when caching is enabled.""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = self.MOCK_KEY - - try: - plaintext_dict = {"api_key": "sk-12345", "api_secret": "secret-value"} - encrypted = CryptoUtils.encrypt(json.dumps(plaintext_dict), self.MOCK_KEY) - - # Create a SecretDict from encrypted value - secret_dict = SecretDict.from_encrypted(encrypted) - - # Mock the decrypt method to track calls - with patch.object(CryptoUtils, "decrypt", wraps=CryptoUtils.decrypt) as mock_decrypt: - # First call should decrypt - result1 = secret_dict.get_plaintext() - assert result1 == plaintext_dict - assert mock_decrypt.call_count == 1 - - # Second and third calls should use cache - result2 = secret_dict.get_plaintext() - result3 = secret_dict.get_plaintext() - assert result2 == plaintext_dict - assert result3 == plaintext_dict - - # Decrypt should still have been called only once - assert mock_decrypt.call_count == 1 - finally: - settings.encryption_key = original_key - - def test_cache_handles_none_values(self): - """Test that caching works correctly with None/empty values.""" - # Test with None value - secret_dict = SecretDict.from_plaintext(None) - - # First call - result1 = secret_dict.get_plaintext() - assert result1 is None - - # Second call should also return None (not trying to decrypt) - result2 = secret_dict.get_plaintext() - assert result2 is None - - def test_from_plaintext_dict_without_key(self): - """Test creating a SecretDict from plaintext dictionary without encryption key (fallback).""" - from letta.settings import settings - - original_key = settings.encryption_key - settings.encryption_key = None - - try: - plaintext_dict = {"key1": "value1", "key2": "value2"} - - # Should handle gracefully and store as JSON plaintext - secret_dict = SecretDict.from_plaintext(plaintext_dict) - - # Should store the JSON string - assert secret_dict._encrypted_value == json.dumps(plaintext_dict) - assert secret_dict.get_plaintext() == plaintext_dict - assert not secret_dict._was_encrypted - finally: - settings.encryption_key = original_key - - def test_encryption_key_transition_no_key_to_has_key(self): - """Test transition from no encryption key to having one.""" - from letta.settings import settings - - original_key = settings.encryption_key - - try: - # Start with no encryption key - settings.encryption_key = None - - # Create secrets without encryption - plaintext = "test-value-123" - secret = Secret.from_plaintext(plaintext) - - plaintext_dict = {"api_key": "sk-12345", "api_secret": "secret"} - secret_dict = SecretDict.from_plaintext(plaintext_dict) - - # Verify they're stored as plaintext - assert secret._encrypted_value == plaintext - assert secret_dict._encrypted_value == json.dumps(plaintext_dict) - - # Now add an encryption key - settings.encryption_key = self.MOCK_KEY - - # Should still be able to read the plaintext values - assert secret.get_plaintext() == plaintext - assert secret_dict.get_plaintext() == plaintext_dict - - # Create new secrets with encryption enabled - new_secret = Secret.from_plaintext("new-encrypted-value") - assert new_secret._encrypted_value != "new-encrypted-value" # Should be encrypted - - finally: - settings.encryption_key = original_key - - def test_encryption_key_transition_has_key_to_no_key(self): - """Test transition from having encryption key to not having one.""" - from letta.settings import settings - - original_key = settings.encryption_key - - try: - # Start with encryption key - settings.encryption_key = self.MOCK_KEY - - # Create secrets with encryption - plaintext = "encrypted-test-value" - secret = Secret.from_plaintext(plaintext) - - plaintext_dict = {"token": "bearer-xyz", "key": "value"} - secret_dict = SecretDict.from_plaintext(plaintext_dict) - - # Verify they're encrypted - assert secret._encrypted_value != plaintext - assert secret_dict._encrypted_value != json.dumps(plaintext_dict) - - # Remove encryption key - settings.encryption_key = None - - # Should handle gracefully - return None for encrypted values - # (can't decrypt without key) - result = secret.get_plaintext() - assert result is None # Can't decrypt without key - - dict_result = secret_dict.get_plaintext() - assert dict_result is None # Can't decrypt without key - - finally: - settings.encryption_key = original_key - - def test_round_trip_compatibility(self): - """Test that values can be read correctly regardless of when they were stored.""" - from letta.settings import settings - - original_key = settings.encryption_key - - try: - # Create some values without encryption - settings.encryption_key = None - unencrypted_secret = Secret.from_plaintext("unencrypted") - unencrypted_dict = SecretDict.from_plaintext({"plain": "text"}) - - # Create some values with encryption - settings.encryption_key = self.MOCK_KEY - encrypted_secret = Secret.from_plaintext("encrypted") - encrypted_dict = SecretDict.from_plaintext({"secure": "data"}) - - # Mix them - can read unencrypted with key present - assert unencrypted_secret.get_plaintext() == "unencrypted" - assert unencrypted_dict.get_plaintext() == {"plain": "text"} - assert encrypted_secret.get_plaintext() == "encrypted" - assert encrypted_dict.get_plaintext() == {"secure": "data"} - - # Remove key - can only read unencrypted - settings.encryption_key = None - assert unencrypted_secret.get_plaintext() == "unencrypted" - assert unencrypted_dict.get_plaintext() == {"plain": "text"} - assert encrypted_secret.get_plaintext() is None # Can't decrypt - assert encrypted_dict.get_plaintext() is None # Can't decrypt - - # Restore key - can read all again - settings.encryption_key = self.MOCK_KEY - assert unencrypted_secret.get_plaintext() == "unencrypted" - assert unencrypted_dict.get_plaintext() == {"plain": "text"} - assert encrypted_secret.get_plaintext() == "encrypted" - assert encrypted_dict.get_plaintext() == {"secure": "data"} - - finally: - settings.encryption_key = original_key