fix: update more plaintext non async callsites (#7223)
* bae * update * fix * clean up * last
This commit is contained in:
@@ -165,68 +165,36 @@ class MCPOAuthSession(BaseMCPOAuth):
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
|
||||
|
||||
def get_access_token_secret(self) -> Secret:
|
||||
"""Get the access token as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.access_token_enc is not None:
|
||||
return self.access_token_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=self.access_token)
|
||||
"""Get the access token as a Secret object."""
|
||||
return self.access_token_enc if self.access_token_enc is not None else Secret.from_plaintext(None)
|
||||
|
||||
def get_refresh_token_secret(self) -> Secret:
|
||||
"""Get the refresh token as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.refresh_token_enc is not None:
|
||||
return self.refresh_token_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=self.refresh_token)
|
||||
"""Get the refresh token as a Secret object."""
|
||||
return self.refresh_token_enc if self.refresh_token_enc is not None else Secret.from_plaintext(None)
|
||||
|
||||
def get_client_secret_secret(self) -> Secret:
|
||||
"""Get the client secret as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.client_secret_enc is not None:
|
||||
return self.client_secret_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=self.client_secret)
|
||||
"""Get the client secret as a Secret object."""
|
||||
return self.client_secret_enc if self.client_secret_enc is not None else Secret.from_plaintext(None)
|
||||
|
||||
def get_authorization_code_secret(self) -> Secret:
|
||||
"""Get the authorization code as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
|
||||
if self.authorization_code_enc is not None:
|
||||
return self.authorization_code_enc
|
||||
# Fallback to plaintext with error logging via Secret.from_db()
|
||||
return Secret.from_db(encrypted_value=None, plaintext_value=self.authorization_code)
|
||||
"""Get the authorization code as a Secret object."""
|
||||
return self.authorization_code_enc if self.authorization_code_enc is not None else Secret.from_plaintext(None)
|
||||
|
||||
def set_access_token_secret(self, secret: Secret) -> None:
|
||||
"""Set access token from a Secret object."""
|
||||
self.access_token_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.access_token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.access_token = None
|
||||
|
||||
def set_refresh_token_secret(self, secret: Secret) -> None:
|
||||
"""Set refresh token from a Secret object."""
|
||||
self.refresh_token_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.refresh_token = secret_dict["plaintext"]
|
||||
else:
|
||||
self.refresh_token = None
|
||||
|
||||
def set_client_secret_secret(self, secret: Secret) -> None:
|
||||
"""Set client secret from a Secret object."""
|
||||
self.client_secret_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.client_secret = secret_dict["plaintext"]
|
||||
else:
|
||||
self.client_secret = None
|
||||
|
||||
def set_authorization_code_secret(self, secret: Secret) -> None:
|
||||
"""Set authorization code from a Secret object."""
|
||||
self.authorization_code_enc = secret
|
||||
secret_dict = secret.to_dict()
|
||||
if not secret.was_encrypted:
|
||||
self.authorization_code = secret_dict["plaintext"]
|
||||
else:
|
||||
self.authorization_code = None
|
||||
|
||||
|
||||
class MCPOAuthSessionCreate(BaseMCPOAuth):
|
||||
@@ -290,7 +258,7 @@ class UpdateMCPServerRequest(LettaBase):
|
||||
]
|
||||
|
||||
|
||||
def convert_generic_to_union(server) -> MCPServerUnion:
|
||||
async def convert_generic_to_union(server) -> MCPServerUnion:
|
||||
"""
|
||||
Convert a generic MCPServer (from letta.schemas.mcp) to the appropriate MCPServerUnion type
|
||||
based on the server_type field.
|
||||
@@ -319,9 +287,9 @@ def convert_generic_to_union(server) -> MCPServerUnion:
|
||||
env=server.stdio_config.env if server.stdio_config else None,
|
||||
)
|
||||
elif server.server_type == MCPServerType.SSE:
|
||||
# Get decrypted values from encrypted columns
|
||||
token = server.token_enc.get_plaintext() if server.token_enc else None
|
||||
headers = server.get_custom_headers_dict()
|
||||
# Get decrypted values from encrypted columns (async)
|
||||
token = await server.token_enc.get_plaintext_async() if server.token_enc else None
|
||||
headers = await server.get_custom_headers_dict_async()
|
||||
return SSEMCPServer(
|
||||
id=server.id,
|
||||
server_name=server.server_name,
|
||||
@@ -332,9 +300,9 @@ def convert_generic_to_union(server) -> MCPServerUnion:
|
||||
custom_headers=headers,
|
||||
)
|
||||
elif server.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||
# Get decrypted values from encrypted columns
|
||||
token = server.token_enc.get_plaintext() if server.token_enc else None
|
||||
headers = server.get_custom_headers_dict()
|
||||
# Get decrypted values from encrypted columns (async)
|
||||
token = await server.token_enc.get_plaintext_async() if server.token_enc else None
|
||||
headers = await server.get_custom_headers_dict_async()
|
||||
return StreamableHTTPMCPServer(
|
||||
id=server.id,
|
||||
server_name=server.server_name,
|
||||
|
||||
@@ -108,7 +108,7 @@ class AnthropicProvider(Provider):
|
||||
base_url: str = "https://api.anthropic.com/v1"
|
||||
|
||||
async def check_api_key(self):
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if api_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
try:
|
||||
@@ -137,7 +137,7 @@ class AnthropicProvider(Provider):
|
||||
|
||||
NOTE: currently there is no GET /models, so we need to hardcode
|
||||
"""
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if api_key:
|
||||
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
|
||||
@@ -60,7 +60,7 @@ class AzureProvider(Provider):
|
||||
async def azure_openai_get_deployed_model_list(self) -> list:
|
||||
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
client = AsyncAzureOpenAI(api_key=api_key, api_version=self.api_version, azure_endpoint=self.base_url)
|
||||
|
||||
try:
|
||||
@@ -170,7 +170,7 @@ class AzureProvider(Provider):
|
||||
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default)
|
||||
|
||||
async def check_api_key(self):
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if not api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ class BedrockProvider(Provider):
|
||||
|
||||
try:
|
||||
# Decrypt credentials before using
|
||||
access_key = self.access_key_enc.get_plaintext() if self.access_key_enc else None
|
||||
secret_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
access_key = await self.access_key_enc.get_plaintext_async() if self.access_key_enc else None
|
||||
secret_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
|
||||
session = Session()
|
||||
async with session.client(
|
||||
|
||||
@@ -41,7 +41,7 @@ class CerebrasProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
|
||||
if "data" in response:
|
||||
|
||||
@@ -34,7 +34,7 @@ class DeepSeekProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
data = response.get("data", response)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class GoogleAIProvider(Provider):
|
||||
async def check_api_key(self):
|
||||
from letta.llm_api.google_ai_client import google_ai_check_valid_api_key_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
await google_ai_check_valid_api_key_async(api_key)
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
|
||||
@@ -36,7 +36,7 @@ class GoogleAIProvider(Provider):
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
|
||||
|
||||
# Get and filter the model list
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key)
|
||||
model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
|
||||
model_options = [str(m["name"]) for m in model_options]
|
||||
@@ -70,7 +70,7 @@ class GoogleAIProvider(Provider):
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
|
||||
|
||||
# TODO: use base_url instead
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key)
|
||||
return self._list_embedding_models(model_options)
|
||||
|
||||
@@ -113,5 +113,5 @@ class GoogleAIProvider(Provider):
|
||||
if model_name in LLM_MAX_CONTEXT_WINDOW:
|
||||
return LLM_MAX_CONTEXT_WINDOW[model_name]
|
||||
else:
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
return await google_ai_get_model_context_window_async(self.base_url, api_key, model_name)
|
||||
|
||||
@@ -16,7 +16,7 @@ class GroqProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
configs = []
|
||||
for model in response["data"]:
|
||||
|
||||
@@ -18,7 +18,7 @@ class MistralProvider(Provider):
|
||||
|
||||
# Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
|
||||
# See: https://openrouter.ai/docs/requests
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await mistral_get_model_list_async(self.base_url, api_key=api_key)
|
||||
|
||||
assert "data" in response, f"Mistral model query response missing 'data' field: {response}"
|
||||
|
||||
@@ -26,7 +26,7 @@ class OpenAIProvider(Provider):
|
||||
from letta.llm_api.openai import openai_check_valid_api_key # TODO: DO NOT USE THIS - old code path
|
||||
|
||||
# Decrypt API key before using
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
openai_check_valid_api_key(self.base_url, api_key)
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
|
||||
@@ -48,7 +48,7 @@ class OpenAIProvider(Provider):
|
||||
extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
|
||||
|
||||
# Decrypt API key before using
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
|
||||
response = await openai_get_model_list_async(
|
||||
self.base_url,
|
||||
|
||||
@@ -30,7 +30,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
models = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
return self._list_llm_models(models)
|
||||
|
||||
@@ -93,7 +93,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
return configs
|
||||
|
||||
async def check_api_key(self):
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if not api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class XAIProvider(OpenAIProvider):
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
|
||||
|
||||
data = response.get("data", response)
|
||||
|
||||
@@ -169,8 +169,6 @@ class Secret(BaseModel):
|
||||
|
||||
# Use cached value if available
|
||||
if self._plaintext_cache is not None:
|
||||
if not self.was_encrypted:
|
||||
return self._plaintext_cache
|
||||
return self._plaintext_cache
|
||||
|
||||
# Try to decrypt (async)
|
||||
|
||||
@@ -50,7 +50,7 @@ async def create_mcp_server(
|
||||
# TODO: add the tools to the MCP server table we made.
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
new_server = await server.mcp_server_manager.create_mcp_server_from_request(request, actor=actor)
|
||||
return convert_generic_to_union(new_server)
|
||||
return await convert_generic_to_union(new_server)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -67,7 +67,10 @@ async def list_mcp_servers(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
mcp_servers = await server.mcp_server_manager.list_mcp_servers(actor=actor)
|
||||
return [convert_generic_to_union(mcp_server) for mcp_server in mcp_servers]
|
||||
result = []
|
||||
for mcp_server in mcp_servers:
|
||||
result.append(await convert_generic_to_union(mcp_server))
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -85,7 +88,7 @@ async def retrieve_mcp_server(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
current_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
|
||||
return convert_generic_to_union(current_server)
|
||||
return await convert_generic_to_union(current_server)
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -125,7 +128,7 @@ async def update_mcp_server(
|
||||
updated_server = await server.mcp_server_manager.update_mcp_server_by_id(
|
||||
mcp_server_id=mcp_server_id, mcp_server_update=internal_update, actor=actor
|
||||
)
|
||||
return convert_generic_to_union(updated_server)
|
||||
return await convert_generic_to_union(updated_server)
|
||||
|
||||
|
||||
@router.get("/{mcp_server_id}/tools", response_model=List[Tool], operation_id="mcp_list_tools_for_mcp_server")
|
||||
@@ -238,7 +241,7 @@ async def connect_mcp_server(
|
||||
mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
|
||||
|
||||
# Convert the MCP server to the appropriate config type
|
||||
config = mcp_server.to_config(resolve_variables=False)
|
||||
config = await mcp_server.to_config_async(resolve_variables=False)
|
||||
|
||||
async def oauth_stream_generator(
|
||||
mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
|
||||
|
||||
@@ -990,25 +990,20 @@ class MCPServerManager:
|
||||
"""
|
||||
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
|
||||
|
||||
Note: Prefers encrypted columns (_enc fields), falls back to plaintext with error logging.
|
||||
This helps identify unmigrated data during the migration period.
|
||||
Note: Prefers encrypted columns (_enc fields), falls back to legacy plaintext columns.
|
||||
"""
|
||||
# Get decrypted values - prefer encrypted, fallback to plaintext with error logging
|
||||
access_token = await Secret.from_db(
|
||||
encrypted_value=oauth_session.access_token_enc, plaintext_value=oauth_session.access_token
|
||||
).get_plaintext_async()
|
||||
# Get decrypted values - prefer encrypted, fallback to legacy plaintext
|
||||
access_token_secret = Secret.from_encrypted(oauth_session.access_token_enc)
|
||||
access_token = await access_token_secret.get_plaintext_async()
|
||||
|
||||
refresh_token = await Secret.from_db(
|
||||
encrypted_value=oauth_session.refresh_token_enc, plaintext_value=oauth_session.refresh_token
|
||||
).get_plaintext_async()
|
||||
refresh_token_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
|
||||
refresh_token = await refresh_token_secret.get_plaintext_async()
|
||||
|
||||
client_secret = await Secret.from_db(
|
||||
encrypted_value=oauth_session.client_secret_enc, plaintext_value=oauth_session.client_secret
|
||||
).get_plaintext_async()
|
||||
client_secret_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
|
||||
client_secret = await client_secret_secret.get_plaintext_async()
|
||||
|
||||
authorization_code = await Secret.from_db(
|
||||
encrypted_value=oauth_session.authorization_code_enc, plaintext_value=oauth_session.authorization_code
|
||||
).get_plaintext_async()
|
||||
authorization_code_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
|
||||
authorization_code = await authorization_code_secret.get_plaintext_async()
|
||||
|
||||
# Create the Pydantic object with encrypted fields as Secret objects
|
||||
pydantic_session = MCPOAuthSession(
|
||||
|
||||
@@ -28,7 +28,6 @@ class TestSecret:
|
||||
# Should store encrypted value
|
||||
assert secret.encrypted_value is not None
|
||||
assert secret.encrypted_value != plaintext
|
||||
assert secret.was_encrypted is False
|
||||
|
||||
# Should decrypt to original value
|
||||
assert secret.get_plaintext() == plaintext
|
||||
@@ -52,7 +51,6 @@ class TestSecret:
|
||||
# Should store the plaintext value directly in encrypted_value
|
||||
assert secret.encrypted_value == plaintext
|
||||
assert secret.get_plaintext() == plaintext
|
||||
assert not secret.was_encrypted
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
@@ -61,7 +59,6 @@ class TestSecret:
|
||||
secret = Secret.from_plaintext(None)
|
||||
|
||||
assert secret.encrypted_value is None
|
||||
assert secret.was_encrypted is False
|
||||
assert secret.get_plaintext() is None
|
||||
assert secret.is_empty() is True
|
||||
|
||||
@@ -79,78 +76,10 @@ class TestSecret:
|
||||
secret = Secret.from_encrypted(encrypted)
|
||||
|
||||
assert secret.encrypted_value == encrypted
|
||||
assert secret.was_encrypted is True
|
||||
assert secret.get_plaintext() == plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_with_encrypted_value(self):
|
||||
"""Test creating a Secret from database with encrypted value."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "database-secret"
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
|
||||
|
||||
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=None)
|
||||
|
||||
assert secret.encrypted_value == encrypted
|
||||
assert secret.was_encrypted is True
|
||||
assert secret.get_plaintext() == plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_with_plaintext_value_fallback(self, caplog):
|
||||
"""Test creating a Secret from database with only plaintext value falls back with error logging.
|
||||
|
||||
Note: In Phase 1 of migration, from_db() prefers encrypted but falls back to plaintext
|
||||
with error logging to help identify unmigrated data.
|
||||
"""
|
||||
import logging
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "legacy-plaintext"
|
||||
|
||||
# When only plaintext is provided, should fall back to plaintext with error logging
|
||||
with caplog.at_level(logging.ERROR):
|
||||
secret = Secret.from_db(encrypted_value=None, plaintext_value=plaintext)
|
||||
|
||||
# Should use the plaintext value (fallback)
|
||||
assert secret.get_plaintext() == plaintext
|
||||
|
||||
# Should have logged an error about reading from plaintext column
|
||||
assert "MIGRATION_NEEDED" in caplog.text
|
||||
assert "plaintext column" in caplog.text
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_from_db_dual_read(self):
|
||||
"""Test dual read functionality - prefer encrypted over plaintext."""
|
||||
from letta.settings import settings
|
||||
|
||||
original_key = settings.encryption_key
|
||||
settings.encryption_key = self.MOCK_KEY
|
||||
|
||||
try:
|
||||
plaintext = "correct-value"
|
||||
old_plaintext = "old-legacy-value"
|
||||
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
|
||||
|
||||
# When both values exist, should prefer encrypted
|
||||
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=old_plaintext)
|
||||
|
||||
assert secret.get_plaintext() == plaintext # Should use encrypted value, not plaintext
|
||||
finally:
|
||||
settings.encryption_key = original_key
|
||||
|
||||
def test_get_encrypted(self):
|
||||
"""Test getting the encrypted value for database storage."""
|
||||
from letta.settings import settings
|
||||
|
||||
Reference in New Issue
Block a user