From db00320126236baa1fb423926d97e09be564b8fb Mon Sep 17 00:00:00 2001 From: jnjpng Date: Tue, 16 Sep 2025 15:56:17 -0700 Subject: [PATCH] fix: clean up mcp encryption tests and logic (#2958) --- letta/schemas/mcp.py | 109 +++---- letta/schemas/secret.py | 13 +- letta/services/mcp_manager.py | 111 ++----- tests/test_mcp_encryption.py | 540 ++++++++++++++++++---------------- tests/test_secret.py | 124 ++++++++ 5 files changed, 493 insertions(+), 404 deletions(-) diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index 6f273816..e68d614e 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -14,6 +14,7 @@ from letta.functions.mcp_client.types import ( from letta.orm.mcp_oauth import OAuthSessionStatus from letta.schemas.letta_base import LettaBase from letta.schemas.secret import Secret, SecretDict +from letta.settings import settings class BaseMCPServer(LettaBase): @@ -75,41 +76,24 @@ class MCPServer(BaseMCPServer): def model_dump(self, to_orm: bool = False, **kwargs): """Override model_dump to handle encryption when saving to database.""" - import os - - # Check environment variable directly to handle test patching - encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") - if not encryption_key: - from letta.settings import settings - - encryption_key = settings.encryption_key - data = super().model_dump(to_orm=to_orm, **kwargs) - if to_orm and encryption_key: - # Temporarily set the encryption key for Secret/SecretDict - from letta.settings import settings + if to_orm and settings.encryption_key: + # Encrypt token if present + if self.token is not None: + token_secret = Secret.from_plaintext(self.token) + secret_dict = token_secret.to_dict() + data["token_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["token"] = secret_dict["plaintext"] - original_key = settings.encryption_key - settings.encryption_key = encryption_key - try: - # Encrypt token if present - if self.token is not None: - token_secret = Secret.from_plaintext(self.token) - secret_dict = token_secret.to_dict() - data["token_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["token"] = secret_dict["plaintext"] - - # Encrypt custom headers if present - if self.custom_headers is not None: - headers_secret = SecretDict.from_plaintext(self.custom_headers) - secret_dict = headers_secret.to_dict() - data["custom_headers_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["custom_headers"] = secret_dict["plaintext"] - finally: - settings.encryption_key = original_key + # Encrypt custom headers if present + if self.custom_headers is not None: + headers_secret = SecretDict.from_plaintext(self.custom_headers) + secret_dict = headers_secret.to_dict() + data["custom_headers_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["custom_headers"] = secret_dict["plaintext"] return data @@ -277,49 +261,32 @@ class MCPOAuthSession(BaseMCPOAuth): def model_dump(self, to_orm: bool = False, **kwargs): """Override model_dump to handle encryption when saving to database.""" - import os - - # Check environment variable directly to handle test patching - encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") - if not encryption_key: - from letta.settings import settings - - encryption_key = settings.encryption_key - data = super().model_dump(to_orm=to_orm, **kwargs) - if to_orm and encryption_key: - # Temporarily set the encryption key for Secret - from letta.settings import settings + if to_orm and settings.encryption_key: + # Encrypt access token if present + if self.access_token is not None: + token_secret = Secret.from_plaintext(self.access_token) + secret_dict = token_secret.to_dict() + data["access_token_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["access_token"] = secret_dict["plaintext"] - original_key = settings.encryption_key - settings.encryption_key = encryption_key - try: - # Encrypt access token if present - if self.access_token is not None: - token_secret = Secret.from_plaintext(self.access_token) - secret_dict = token_secret.to_dict() - data["access_token_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["access_token"] = secret_dict["plaintext"] + # Encrypt refresh token if present + if self.refresh_token is not None: + token_secret = Secret.from_plaintext(self.refresh_token) + secret_dict = token_secret.to_dict() + data["refresh_token_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["refresh_token"] = secret_dict["plaintext"] - # Encrypt refresh token if present - if self.refresh_token is not None: - token_secret = Secret.from_plaintext(self.refresh_token) - secret_dict = token_secret.to_dict() - data["refresh_token_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["refresh_token"] = secret_dict["plaintext"] - - # Encrypt client secret if present - if self.client_secret is not None: - secret = Secret.from_plaintext(self.client_secret) - secret_dict = secret.to_dict() - data["client_secret_enc"] = secret_dict["encrypted"] - # Keep plaintext for dual-write during migration - data["client_secret"] = secret_dict["plaintext"] - finally: - settings.encryption_key = original_key + # Encrypt client secret if present + if self.client_secret is not None: + secret = Secret.from_plaintext(self.client_secret) + secret_dict = secret.to_dict() + data["client_secret_enc"] = secret_dict["encrypted"] + # Keep plaintext for dual-write during migration + data["client_secret"] = secret_dict["plaintext"] return data diff --git a/letta/schemas/secret.py b/letta/schemas/secret.py index fe2c7cf0..6cba8249 100644 --- a/letta/schemas/secret.py +++ b/letta/schemas/secret.py @@ -118,8 +118,8 @@ class Secret(BaseModel): # Decrypt and cache try: plaintext = CryptoUtils.decrypt(self._encrypted_value) - # Note: We can't actually cache due to frozen=True, but in practice - # we'll create new instances rather than mutating + # Cache the decrypted value (PrivateAttr fields can be mutated even with frozen=True) + self._plaintext_cache = plaintext return plaintext except Exception: # If decryption fails and this wasn't originally encrypted, @@ -224,9 +224,16 @@ class SecretDict(BaseModel): if self._encrypted_value is None: return None + # Use cached value if available + if self._plaintext_cache is not None: + return self._plaintext_cache + try: decrypted_json = CryptoUtils.decrypt(self._encrypted_value) - return json.loads(decrypted_json) + plaintext_dict = json.loads(decrypted_json) + # Cache the decrypted value (PrivateAttr fields can be mutated even with frozen=True) + self._plaintext_cache = plaintext_dict + return plaintext_dict except Exception: if not self._was_encrypted: return None diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 08bf38a4..89e320a3 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -755,55 +755,33 @@ class MCPManager: """ Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields. """ - # Check for encryption key from env or settings - import os - from letta.settings import settings - encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key - # Get decrypted values using the dual-read approach + # Secret.from_db() will automatically use settings.encryption_key if available access_token = None if oauth_session.access_token_enc or oauth_session.access_token: - if encryption_key: - # Temporarily set the key for Secret - original_key = settings.encryption_key - settings.encryption_key = encryption_key - try: - secret = Secret.from_db(oauth_session.access_token_enc, oauth_session.access_token) - access_token = secret.get_plaintext() - finally: - settings.encryption_key = original_key + if settings.encryption_key: + secret = Secret.from_db(oauth_session.access_token_enc, oauth_session.access_token) + access_token = secret.get_plaintext() else: # No encryption key, use plaintext if available access_token = oauth_session.access_token refresh_token = None if oauth_session.refresh_token_enc or oauth_session.refresh_token: - if encryption_key: - # Temporarily set the key for Secret - original_key = settings.encryption_key - settings.encryption_key = encryption_key - try: - secret = Secret.from_db(oauth_session.refresh_token_enc, oauth_session.refresh_token) - refresh_token = secret.get_plaintext() - finally: - settings.encryption_key = original_key + if settings.encryption_key: + secret = Secret.from_db(oauth_session.refresh_token_enc, oauth_session.refresh_token) + refresh_token = secret.get_plaintext() else: # No encryption key, use plaintext if available refresh_token = oauth_session.refresh_token client_secret = None if oauth_session.client_secret_enc or oauth_session.client_secret: - if encryption_key: - # Temporarily set the key for Secret - original_key = settings.encryption_key - settings.encryption_key = encryption_key - try: - secret = Secret.from_db(oauth_session.client_secret_enc, oauth_session.client_secret) - client_secret = secret.get_plaintext() - finally: - settings.encryption_key = original_key + if settings.encryption_key: + secret = Secret.from_db(oauth_session.client_secret_enc, oauth_session.client_secret) + client_secret = secret.get_plaintext() else: # No encryption key, use plaintext if available client_secret = oauth_session.client_secret @@ -900,25 +878,14 @@ class MCPManager: # Handle encryption for access_token if session_update.access_token is not None: - # Check for encryption key from env or settings - import os - from letta.settings import settings - encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key - - if encryption_key: - # Temporarily set the key for Secret - original_key = settings.encryption_key - settings.encryption_key = encryption_key - try: - token_secret = Secret.from_plaintext(session_update.access_token) - secret_dict = token_secret.to_dict() - oauth_session.access_token_enc = secret_dict["encrypted"] - # During migration phase, also update plaintext - oauth_session.access_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None - finally: - settings.encryption_key = original_key + if settings.encryption_key: + token_secret = Secret.from_plaintext(session_update.access_token) + secret_dict = token_secret.to_dict() + oauth_session.access_token_enc = secret_dict["encrypted"] + # During migration phase, also update plaintext + oauth_session.access_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None else: # No encryption, store plaintext oauth_session.access_token = session_update.access_token @@ -926,25 +893,14 @@ class MCPManager: # Handle encryption for refresh_token if session_update.refresh_token is not None: - # Check for encryption key from env or settings - import os - from letta.settings import settings - encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key - - if encryption_key: - # Temporarily set the key for Secret - original_key = settings.encryption_key - settings.encryption_key = encryption_key - try: - token_secret = Secret.from_plaintext(session_update.refresh_token) - secret_dict = token_secret.to_dict() - oauth_session.refresh_token_enc = secret_dict["encrypted"] - # During migration phase, also update plaintext - oauth_session.refresh_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None - finally: - settings.encryption_key = original_key + if settings.encryption_key: + token_secret = Secret.from_plaintext(session_update.refresh_token) + secret_dict = token_secret.to_dict() + oauth_session.refresh_token_enc = secret_dict["encrypted"] + # During migration phase, also update plaintext + oauth_session.refresh_token = secret_dict["plaintext"] if not token_secret._was_encrypted else None else: # No encryption, store plaintext oauth_session.refresh_token = session_update.refresh_token @@ -961,25 +917,14 @@ class MCPManager: # Handle encryption for client_secret if session_update.client_secret is not None: - # Check for encryption key from env or settings - import os - from letta.settings import settings - encryption_key = os.environ.get("LETTA_ENCRYPTION_KEY") or settings.encryption_key - - if encryption_key: - # Temporarily set the key for Secret - original_key = settings.encryption_key - settings.encryption_key = encryption_key - try: - secret_secret = Secret.from_plaintext(session_update.client_secret) - secret_dict = secret_secret.to_dict() - oauth_session.client_secret_enc = secret_dict["encrypted"] - # During migration phase, also update plaintext - oauth_session.client_secret = secret_dict["plaintext"] if not secret_secret._was_encrypted else None - finally: - settings.encryption_key = original_key + if settings.encryption_key: + secret_secret = Secret.from_plaintext(session_update.client_secret) + secret_dict = secret_secret.to_dict() + oauth_session.client_secret_enc = secret_dict["encrypted"] + # During migration phase, also update plaintext + oauth_session.client_secret = secret_dict["plaintext"] if not secret_secret._was_encrypted else None else: # No encryption, store plaintext oauth_session.client_secret = session_update.client_secret diff --git a/tests/test_mcp_encryption.py b/tests/test_mcp_encryption.py index bcd2e827..cec05c3a 100644 --- a/tests/test_mcp_encryption.py +++ b/tests/test_mcp_encryption.py @@ -17,6 +17,7 @@ from letta.helpers.crypto_utils import CryptoUtils from letta.orm import MCPOAuth, MCPServer as ORMMCPServer from letta.schemas.mcp import ( MCPOAuthSessionCreate, + MCPOAuthSessionUpdate, MCPServer as PydanticMCPServer, MCPServerType, SSEServerConfig, @@ -26,6 +27,7 @@ from letta.schemas.secret import Secret, SecretDict from letta.server.db import db_registry from letta.server.server import SyncServer from letta.services.mcp_manager import MCPManager +from letta.settings import settings @pytest.fixture(scope="module") @@ -43,135 +45,151 @@ class TestMCPServerEncryption: MOCK_ENCRYPTION_KEY = "test-mcp-encryption-key-123456" @pytest.mark.asyncio - @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) @patch("letta.services.mcp_manager.MCPManager.get_mcp_client") async def test_create_mcp_server_with_token_encryption(self, mock_get_client, server, default_user): """Test that MCP server tokens are encrypted when stored.""" - # Mock the MCP client - mock_client = AsyncMock() - mock_client.list_tools.return_value = [] - mock_get_client.return_value = mock_client + # Set encryption key directly on settings + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_ENCRYPTION_KEY - # Create MCP server with token - server_name = f"test_encrypted_server_{uuid4().hex[:8]}" - token = "super-secret-api-token-12345" - server_url = "https://api.example.com/mcp" + try: + # Mock the MCP client + mock_client = AsyncMock() + mock_client.list_tools.return_value = [] + mock_get_client.return_value = mock_client - mcp_server = PydanticMCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url, token=token) + # Create MCP server with token + server_name = f"test_encrypted_server_{uuid4().hex[:8]}" + token = "super-secret-api-token-12345" + server_url = "https://api.example.com/mcp" - created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user) + mcp_server = PydanticMCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url, token=token) - # Verify server was created - assert created_server.server_name == server_name - assert created_server.server_type == MCPServerType.SSE + created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user) - # Check database directly to verify encryption - async with db_registry.async_session() as session: - result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id)) - db_server = result.scalar_one() + # Verify server was created + assert created_server.server_name == server_name + assert created_server.server_type == MCPServerType.SSE - # Token should be encrypted in database - assert db_server.token_enc is not None - assert db_server.token_enc != token # Should not be plaintext + # Check database directly to verify encryption + async with db_registry.async_session() as session: + result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id)) + db_server = result.scalar_one() - # Decrypt to verify correctness - decrypted_token = CryptoUtils.decrypt(db_server.token_enc, self.MOCK_ENCRYPTION_KEY) - assert decrypted_token == token + # Token should be encrypted in database + assert db_server.token_enc is not None + assert db_server.token_enc != token # Should not be plaintext - # Legacy plaintext column should be None (or empty for dual-write) - # During migration phase, might store both - if db_server.token: - assert db_server.token == token # Dual-write phase + # Decrypt to verify correctness + decrypted_token = CryptoUtils.decrypt(db_server.token_enc) + assert decrypted_token == token - # Clean up - await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + # Legacy plaintext column should be None (or empty for dual-write) + # During migration phase, might store both + if db_server.token: + assert db_server.token == token # Dual-write phase + + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + finally: + # Restore original encryption key + settings.encryption_key = original_key @pytest.mark.asyncio - @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) @patch("letta.services.mcp_manager.MCPManager.get_mcp_client") async def test_create_mcp_server_with_custom_headers_encryption(self, mock_get_client, server, default_user): """Test that MCP server custom headers are encrypted when stored.""" - # Mock the MCP client - mock_client = AsyncMock() - mock_client.list_tools.return_value = [] - mock_get_client.return_value = mock_client - - server_name = f"test_headers_server_{uuid4().hex[:8]}" - custom_headers = {"Authorization": "Bearer secret-token-xyz", "X-API-Key": "api-key-123456", "X-Custom-Header": "custom-value"} - server_url = "https://api.example.com/mcp" - - mcp_server = PydanticMCPServer( - server_name=server_name, server_type=MCPServerType.STREAMABLE_HTTP, server_url=server_url, custom_headers=custom_headers - ) - - created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user) - - # Check database directly - async with db_registry.async_session() as session: - result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id)) - db_server = result.scalar_one() - - # Custom headers should be encrypted as JSON - assert db_server.custom_headers_enc is not None - - # Decrypt and parse JSON - decrypted_json = CryptoUtils.decrypt(db_server.custom_headers_enc, self.MOCK_ENCRYPTION_KEY) - decrypted_headers = json.loads(decrypted_json) - assert decrypted_headers == custom_headers - - # Clean up - await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) - - @pytest.mark.asyncio - @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) - async def test_retrieve_mcp_server_decrypts_values(self, server, default_user): - """Test that retrieving MCP server decrypts encrypted values.""" - # Manually insert encrypted server into database - server_id = f"mcp_server-{uuid4().hex[:8]}" - server_name = f"test_decrypt_server_{uuid4().hex[:8]}" - plaintext_token = "decryption-test-token" - encrypted_token = CryptoUtils.encrypt(plaintext_token, self.MOCK_ENCRYPTION_KEY) - - async with db_registry.async_session() as session: - db_server = ORMMCPServer( - id=server_id, - server_name=server_name, - server_type=MCPServerType.SSE.value, - server_url="https://test.com", - token_enc=encrypted_token, - token=None, # No plaintext - created_by_id=default_user.id, - last_updated_by_id=default_user.id, - organization_id=default_user.organization_id, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - ) - session.add(db_server) - await session.commit() - - # Retrieve server directly by ID to avoid issues with other servers in DB - test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user) - - assert test_server is not None - assert test_server.server_name == server_name - # Token should be decrypted when accessed via the secret method - # Ensure encryption key is available for decryption - from letta.settings import settings - + # Set encryption key directly on settings original_key = settings.encryption_key settings.encryption_key = self.MOCK_ENCRYPTION_KEY + try: - token_secret = test_server.get_token_secret() - assert token_secret.get_plaintext() == plaintext_token + # Mock the MCP client + mock_client = AsyncMock() + mock_client.list_tools.return_value = [] + mock_get_client.return_value = mock_client + + server_name = f"test_headers_server_{uuid4().hex[:8]}" + custom_headers = {"Authorization": "Bearer secret-token-xyz", "X-API-Key": "api-key-123456", "X-Custom-Header": "custom-value"} + server_url = "https://api.example.com/mcp" + + mcp_server = PydanticMCPServer( + server_name=server_name, server_type=MCPServerType.STREAMABLE_HTTP, server_url=server_url, custom_headers=custom_headers + ) + + created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user) + + # Check database directly + async with db_registry.async_session() as session: + result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == created_server.id)) + db_server = result.scalar_one() + + # Custom headers should be encrypted as JSON + assert db_server.custom_headers_enc is not None + + # Decrypt and parse JSON + decrypted_json = CryptoUtils.decrypt(db_server.custom_headers_enc) + decrypted_headers = json.loads(decrypted_json) + assert decrypted_headers == custom_headers + + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + finally: + # Restore original encryption key settings.encryption_key = original_key - # Clean up - async with db_registry.async_session() as session: - result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id)) - db_server = result.scalar_one() - await session.delete(db_server) - await session.commit() + @pytest.mark.asyncio + async def test_retrieve_mcp_server_decrypts_values(self, server, default_user): + """Test that retrieving MCP server decrypts encrypted values.""" + # Set encryption key directly on settings + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_ENCRYPTION_KEY + + try: + # Manually insert encrypted server into database + server_id = f"mcp_server-{uuid4().hex[:8]}" + server_name = f"test_decrypt_server_{uuid4().hex[:8]}" + plaintext_token = "decryption-test-token" + encrypted_token = CryptoUtils.encrypt(plaintext_token) + + async with db_registry.async_session() as session: + db_server = ORMMCPServer( + id=server_id, + server_name=server_name, + server_type=MCPServerType.SSE.value, + server_url="https://test.com", + token_enc=encrypted_token, + token=None, # No plaintext + created_by_id=default_user.id, + last_updated_by_id=default_user.id, + organization_id=default_user.organization_id, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(db_server) + await session.commit() + + # Retrieve server directly by ID to avoid issues with other servers in DB + test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user) + + assert test_server is not None + assert test_server.server_name == server_name + # Token should be decrypted when accessed via the secret method + token_secret = test_server.get_token_secret() + assert token_secret.get_plaintext() == plaintext_token + + # Clean up + async with db_registry.async_session() as session: + result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id)) + db_server = result.scalar_one() + await session.delete(db_server) + await session.commit() + + finally: + # Restore original encryption key + settings.encryption_key = original_key @pytest.mark.asyncio @patch.dict(os.environ, {}, clear=True) # No encryption key @@ -214,194 +232,222 @@ class TestMCPOAuthEncryption: MOCK_ENCRYPTION_KEY = "test-oauth-encryption-key-123456" @pytest.mark.asyncio - @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) async def test_create_oauth_session_with_encryption(self, server, default_user): """Test that OAuth tokens are encrypted when stored.""" - server_url = "https://github.com/mcp" - server_name = "GitHub MCP" + # Set encryption key directly on settings + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_ENCRYPTION_KEY - # Step 1: Create OAuth session (without tokens initially) - oauth_session_create = MCPOAuthSessionCreate( - server_url=server_url, - server_name=server_name, - organization_id=default_user.organization_id, - user_id=default_user.id, - ) + try: + server_url = "https://github.com/mcp" + server_name = "GitHub MCP" - created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user) + # Step 1: Create OAuth session (without tokens initially) + oauth_session_create = MCPOAuthSessionCreate( + server_url=server_url, + server_name=server_name, + organization_id=default_user.organization_id, + user_id=default_user.id, + ) - assert created_session.server_url == server_url - assert created_session.server_name == server_name + created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user) - # Step 2: Update session with tokens (simulating OAuth callback) - from letta.schemas.mcp import MCPOAuthSessionUpdate + assert created_session.server_url == server_url + assert created_session.server_name == server_name - update_data = MCPOAuthSessionUpdate( - access_token="github-access-token-abc123", - refresh_token="github-refresh-token-xyz789", - client_id="client-id-123", - client_secret="client-secret-super-secret", - expires_at=datetime.now(timezone.utc), - ) + # Step 2: Update session with tokens (simulating OAuth callback) + update_data = MCPOAuthSessionUpdate( + access_token="github-access-token-abc123", + refresh_token="github-refresh-token-xyz789", + client_id="client-id-123", + client_secret="client-secret-super-secret", + expires_at=datetime.now(timezone.utc), + ) - await server.mcp_manager.update_oauth_session(created_session.id, update_data, actor=default_user) + await server.mcp_manager.update_oauth_session(created_session.id, update_data, actor=default_user) - # Check database directly for encryption - async with db_registry.async_session() as session: - result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id)) - db_oauth = result.scalar_one() + # Check database directly for encryption + async with db_registry.async_session() as session: + result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id)) + db_oauth = result.scalar_one() - # All sensitive fields should be encrypted - assert db_oauth.access_token_enc is not None - assert db_oauth.access_token_enc != update_data.access_token + # All sensitive fields should be encrypted + assert db_oauth.access_token_enc is not None + assert db_oauth.access_token_enc != update_data.access_token - assert db_oauth.refresh_token_enc is not None + assert db_oauth.refresh_token_enc is not None - assert db_oauth.client_secret_enc is not None + assert db_oauth.client_secret_enc is not None - # Verify decryption - decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc, self.MOCK_ENCRYPTION_KEY) - assert decrypted_access == update_data.access_token + # Verify decryption + decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc) + assert decrypted_access == update_data.access_token - decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc, self.MOCK_ENCRYPTION_KEY) - assert decrypted_refresh == update_data.refresh_token + decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc) + assert decrypted_refresh == update_data.refresh_token - decrypted_secret = CryptoUtils.decrypt(db_oauth.client_secret_enc, self.MOCK_ENCRYPTION_KEY) - assert decrypted_secret == update_data.client_secret + decrypted_secret = CryptoUtils.decrypt(db_oauth.client_secret_enc) + assert decrypted_secret == update_data.client_secret + + finally: + # Restore original encryption key + settings.encryption_key = original_key # Clean up not needed - test database is reset @pytest.mark.asyncio - @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) async def test_retrieve_oauth_session_decrypts_tokens(self, server, default_user): """Test that retrieving OAuth session decrypts tokens.""" - # Manually insert encrypted OAuth session - session_id = f"mcp-oauth-{str(uuid4())[:8]}" - access_token = "test-access-token" - refresh_token = "test-refresh-token" - client_secret = "test-client-secret" + # Set encryption key directly on settings + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_ENCRYPTION_KEY - encrypted_access = CryptoUtils.encrypt(access_token, self.MOCK_ENCRYPTION_KEY) - encrypted_refresh = CryptoUtils.encrypt(refresh_token, self.MOCK_ENCRYPTION_KEY) - encrypted_secret = CryptoUtils.encrypt(client_secret, self.MOCK_ENCRYPTION_KEY) + try: + # Manually insert encrypted OAuth session + session_id = f"mcp-oauth-{str(uuid4())[:8]}" + access_token = "test-access-token" + refresh_token = "test-refresh-token" + client_secret = "test-client-secret" - async with db_registry.async_session() as session: - db_oauth = MCPOAuth( - id=session_id, - state=f"test-state-{uuid4().hex[:8]}", - server_url="https://test.com/mcp", - server_name="Test Provider", - access_token_enc=encrypted_access, - refresh_token_enc=encrypted_refresh, - client_id="test-client", - client_secret_enc=encrypted_secret, - user_id=default_user.id, - organization_id=default_user.organization_id, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - ) - session.add(db_oauth) - await session.commit() + encrypted_access = CryptoUtils.encrypt(access_token) + encrypted_refresh = CryptoUtils.encrypt(refresh_token) + encrypted_secret = CryptoUtils.encrypt(client_secret) - # Retrieve through manager by ID - test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) - assert test_session is not None + async with db_registry.async_session() as session: + db_oauth = MCPOAuth( + id=session_id, + state=f"test-state-{uuid4().hex[:8]}", + server_url="https://test.com/mcp", + server_name="Test Provider", + access_token_enc=encrypted_access, + refresh_token_enc=encrypted_refresh, + client_id="test-client", + client_secret_enc=encrypted_secret, + user_id=default_user.id, + organization_id=default_user.organization_id, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(db_oauth) + await session.commit() - # Tokens should be decrypted - assert test_session.access_token == access_token - assert test_session.refresh_token == refresh_token - assert test_session.client_secret == client_secret + # Retrieve through manager by ID + test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) + assert test_session is not None - # Clean up not needed - test database is reset + # Tokens should be decrypted + assert test_session.access_token == access_token + assert test_session.refresh_token == refresh_token + assert test_session.client_secret == client_secret + + # Clean up not needed - test database is reset + + finally: + # Restore original encryption key + settings.encryption_key = original_key @pytest.mark.asyncio - @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) async def test_update_oauth_session_maintains_encryption(self, server, default_user): """Test that updating OAuth session maintains encryption.""" - # Create initial session (without tokens) - from letta.schemas.mcp import MCPOAuthSessionUpdate + # Set encryption key directly on settings + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_ENCRYPTION_KEY - oauth_session_create = MCPOAuthSessionCreate( - server_url="https://test.com/mcp", - server_name="Test Update Provider", - organization_id=default_user.organization_id, - user_id=default_user.id, - ) + try: + # Create initial session (without tokens) + oauth_session_create = MCPOAuthSessionCreate( + server_url="https://test.com/mcp", + server_name="Test Update Provider", + organization_id=default_user.organization_id, + user_id=default_user.id, + ) - created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user) + created_session = await server.mcp_manager.create_oauth_session(oauth_session_create, actor=default_user) - # Add initial tokens - initial_update = MCPOAuthSessionUpdate( - access_token="initial-token", - refresh_token="initial-refresh", - client_id="client-123", - client_secret="initial-secret", - ) + # Add initial tokens + initial_update = MCPOAuthSessionUpdate( + access_token="initial-token", + refresh_token="initial-refresh", + client_id="client-123", + client_secret="initial-secret", + ) - await server.mcp_manager.update_oauth_session(created_session.id, initial_update, actor=default_user) + await server.mcp_manager.update_oauth_session(created_session.id, initial_update, actor=default_user) - # Update with new tokens - new_access_token = "updated-access-token" - new_refresh_token = "updated-refresh-token" + # Update with new tokens + new_access_token = "updated-access-token" + new_refresh_token = "updated-refresh-token" - new_update = MCPOAuthSessionUpdate( - access_token=new_access_token, - refresh_token=new_refresh_token, - ) + new_update = MCPOAuthSessionUpdate( + access_token=new_access_token, + refresh_token=new_refresh_token, + ) - updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user) + updated_session = await server.mcp_manager.update_oauth_session(created_session.id, new_update, actor=default_user) - # Verify update worked - assert updated_session.access_token == new_access_token - assert updated_session.refresh_token == new_refresh_token + # Verify update worked + assert updated_session.access_token == new_access_token + assert updated_session.refresh_token == new_refresh_token - # Check database encryption - async with db_registry.async_session() as session: - result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id)) - db_oauth = result.scalar_one() + # Check database encryption + async with db_registry.async_session() as session: + result = await session.execute(select(MCPOAuth).where(MCPOAuth.id == created_session.id)) + db_oauth = result.scalar_one() - # New tokens should be encrypted - decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc, self.MOCK_ENCRYPTION_KEY) - assert decrypted_access == new_access_token + # New tokens should be encrypted + decrypted_access = CryptoUtils.decrypt(db_oauth.access_token_enc) + assert decrypted_access == new_access_token - decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc, self.MOCK_ENCRYPTION_KEY) - assert decrypted_refresh == new_refresh_token + decrypted_refresh = CryptoUtils.decrypt(db_oauth.refresh_token_enc) + assert decrypted_refresh == new_refresh_token - # Clean up not needed - test database is reset + # Clean up not needed - test database is reset + + finally: + # Restore original encryption key + settings.encryption_key = original_key @pytest.mark.asyncio - @patch.dict(os.environ, {"LETTA_ENCRYPTION_KEY": MOCK_ENCRYPTION_KEY}) async def test_dual_read_backward_compatibility(self, server, default_user): """Test that system can read both encrypted and plaintext values (migration support).""" - # Insert a record with both encrypted and plaintext values - session_id = f"mcp-oauth-{str(uuid4())[:8]}" - plaintext_token = "legacy-plaintext-token" - new_encrypted_token = "new-encrypted-token" - encrypted_new = CryptoUtils.encrypt(new_encrypted_token, self.MOCK_ENCRYPTION_KEY) + # Set encryption key directly on settings + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_ENCRYPTION_KEY - async with db_registry.async_session() as session: - db_oauth = MCPOAuth( - id=session_id, - state=f"dual-read-state-{uuid4().hex[:8]}", - server_url="https://test.com/mcp", - server_name="Dual Read Test", - # Both encrypted and plaintext values - access_token=plaintext_token, # Legacy plaintext - access_token_enc=encrypted_new, # New encrypted - client_id="test-client", - user_id=default_user.id, - organization_id=default_user.organization_id, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - ) - session.add(db_oauth) - await session.commit() + try: + # Insert a record with both encrypted and plaintext values + session_id = f"mcp-oauth-{str(uuid4())[:8]}" + plaintext_token = "legacy-plaintext-token" + new_encrypted_token = "new-encrypted-token" + encrypted_new = CryptoUtils.encrypt(new_encrypted_token) - # Retrieve through manager - test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) - assert test_session is not None + async with db_registry.async_session() as session: + db_oauth = MCPOAuth( + id=session_id, + state=f"dual-read-state-{uuid4().hex[:8]}", + server_url="https://test.com/mcp", + server_name="Dual Read Test", + # Both encrypted and plaintext values + access_token=plaintext_token, # Legacy plaintext + access_token_enc=encrypted_new, # New encrypted + client_id="test-client", + user_id=default_user.id, + organization_id=default_user.organization_id, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(db_oauth) + await session.commit() - # Should prefer encrypted value over plaintext - assert test_session.access_token == new_encrypted_token + # Retrieve through manager + test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) + assert test_session is not None - # Clean up not needed - test database is reset + # Should prefer encrypted value over plaintext + assert test_session.access_token == new_encrypted_token + + # Clean up not needed - test database is reset + + finally: + # Restore original encryption key + settings.encryption_key = original_key diff --git a/tests/test_secret.py b/tests/test_secret.py index a273344f..59533918 100644 --- a/tests/test_secret.py +++ b/tests/test_secret.py @@ -1,4 +1,5 @@ import json +from unittest.mock import MagicMock, patch import pytest @@ -219,6 +220,61 @@ class TestSecret: finally: settings.encryption_key = original_key + def test_plaintext_caching(self): + """Test that plaintext values are cached after first decryption.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "cached-value" + secret = Secret.from_plaintext(plaintext) + + # First call should decrypt and cache + result1 = secret.get_plaintext() + assert result1 == plaintext + assert secret._plaintext_cache == plaintext + + # Second call should use cache + result2 = secret.get_plaintext() + assert result2 == plaintext + assert result1 is result2 # Should be the same object reference + finally: + settings.encryption_key = original_key + + def test_caching_only_decrypts_once(self): + """Test that decryption only happens once when caching is enabled.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext = "test-single-decrypt" + encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY) + + # Create a Secret from encrypted value + secret = Secret.from_encrypted(encrypted) + + # Mock the decrypt method to track calls + with patch.object(CryptoUtils, "decrypt", wraps=CryptoUtils.decrypt) as mock_decrypt: + # First call should decrypt + result1 = secret.get_plaintext() + assert result1 == plaintext + assert mock_decrypt.call_count == 1 + + # Second and third calls should use cache + result2 = secret.get_plaintext() + result3 = secret.get_plaintext() + assert result2 == plaintext + assert result3 == plaintext + + # Decrypt should still have been called only once + assert mock_decrypt.call_count == 1 + finally: + settings.encryption_key = original_key + class TestSecretDict: """Test suite for SecretDict wrapper class.""" @@ -371,3 +427,71 @@ class TestSecretDict: assert secret_dict.get_plaintext() == new_dict finally: settings.encryption_key = original_key + + def test_plaintext_dict_caching(self): + """Test that plaintext dictionary values are cached after first decryption.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext_dict = {"key1": "value1", "key2": "value2", "nested": {"inner": "value"}} + secret_dict = SecretDict.from_plaintext(plaintext_dict) + + # First call should decrypt and cache + result1 = secret_dict.get_plaintext() + assert result1 == plaintext_dict + assert secret_dict._plaintext_cache == plaintext_dict + + # Second call should use cache + result2 = secret_dict.get_plaintext() + assert result2 == plaintext_dict + assert result1 is result2 # Should be the same object reference + finally: + settings.encryption_key = original_key + + def test_dict_caching_only_decrypts_once(self): + """Test that SecretDict decryption only happens once when caching is enabled.""" + from letta.settings import settings + + original_key = settings.encryption_key + settings.encryption_key = self.MOCK_KEY + + try: + plaintext_dict = {"api_key": "sk-12345", "api_secret": "secret-value"} + encrypted = CryptoUtils.encrypt(json.dumps(plaintext_dict), self.MOCK_KEY) + + # Create a SecretDict from encrypted value + secret_dict = SecretDict.from_encrypted(encrypted) + + # Mock the decrypt method to track calls + with patch.object(CryptoUtils, "decrypt", wraps=CryptoUtils.decrypt) as mock_decrypt: + # First call should decrypt + result1 = secret_dict.get_plaintext() + assert result1 == plaintext_dict + assert mock_decrypt.call_count == 1 + + # Second and third calls should use cache + result2 = secret_dict.get_plaintext() + result3 = secret_dict.get_plaintext() + assert result2 == plaintext_dict + assert result3 == plaintext_dict + + # Decrypt should still have been called only once + assert mock_decrypt.call_count == 1 + finally: + settings.encryption_key = original_key + + def test_cache_handles_none_values(self): + """Test that caching works correctly with None/empty values.""" + # Test with None value + secret_dict = SecretDict.from_plaintext(None) + + # First call + result1 = secret_dict.get_plaintext() + assert result1 is None + + # Second call should also return None (not trying to decrypt) + result2 = secret_dict.get_plaintext() + assert result2 is None