fix: update more plaintext non async callsites (#7223)

* bae

* update

* fix

* clean up

* last
This commit is contained in:
jnjpng
2025-12-16 17:16:27 -08:00
committed by Caren Thomas
parent 591420876a
commit 350f3a751c
16 changed files with 52 additions and 159 deletions

View File

@@ -165,68 +165,36 @@ class MCPOAuthSession(BaseMCPOAuth):
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
def get_access_token_secret(self) -> Secret:
"""Get the access token as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
if self.access_token_enc is not None:
return self.access_token_enc
# Fallback to plaintext with error logging via Secret.from_db()
return Secret.from_db(encrypted_value=None, plaintext_value=self.access_token)
"""Get the access token as a Secret object."""
return self.access_token_enc if self.access_token_enc is not None else Secret.from_plaintext(None)
def get_refresh_token_secret(self) -> Secret:
"""Get the refresh token as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
if self.refresh_token_enc is not None:
return self.refresh_token_enc
# Fallback to plaintext with error logging via Secret.from_db()
return Secret.from_db(encrypted_value=None, plaintext_value=self.refresh_token)
"""Get the refresh token as a Secret object."""
return self.refresh_token_enc if self.refresh_token_enc is not None else Secret.from_plaintext(None)
def get_client_secret_secret(self) -> Secret:
"""Get the client secret as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
if self.client_secret_enc is not None:
return self.client_secret_enc
# Fallback to plaintext with error logging via Secret.from_db()
return Secret.from_db(encrypted_value=None, plaintext_value=self.client_secret)
"""Get the client secret as a Secret object."""
return self.client_secret_enc if self.client_secret_enc is not None else Secret.from_plaintext(None)
def get_authorization_code_secret(self) -> Secret:
"""Get the authorization code as a Secret object. Prefers encrypted, falls back to plaintext with error logging."""
if self.authorization_code_enc is not None:
return self.authorization_code_enc
# Fallback to plaintext with error logging via Secret.from_db()
return Secret.from_db(encrypted_value=None, plaintext_value=self.authorization_code)
"""Get the authorization code as a Secret object."""
return self.authorization_code_enc if self.authorization_code_enc is not None else Secret.from_plaintext(None)
def set_access_token_secret(self, secret: Secret) -> None:
"""Set access token from a Secret object."""
self.access_token_enc = secret
secret_dict = secret.to_dict()
if not secret.was_encrypted:
self.access_token = secret_dict["plaintext"]
else:
self.access_token = None
def set_refresh_token_secret(self, secret: Secret) -> None:
"""Set refresh token from a Secret object."""
self.refresh_token_enc = secret
secret_dict = secret.to_dict()
if not secret.was_encrypted:
self.refresh_token = secret_dict["plaintext"]
else:
self.refresh_token = None
def set_client_secret_secret(self, secret: Secret) -> None:
"""Set client secret from a Secret object."""
self.client_secret_enc = secret
secret_dict = secret.to_dict()
if not secret.was_encrypted:
self.client_secret = secret_dict["plaintext"]
else:
self.client_secret = None
def set_authorization_code_secret(self, secret: Secret) -> None:
"""Set authorization code from a Secret object."""
self.authorization_code_enc = secret
secret_dict = secret.to_dict()
if not secret.was_encrypted:
self.authorization_code = secret_dict["plaintext"]
else:
self.authorization_code = None
class MCPOAuthSessionCreate(BaseMCPOAuth):
@@ -290,7 +258,7 @@ class UpdateMCPServerRequest(LettaBase):
]
def convert_generic_to_union(server) -> MCPServerUnion:
async def convert_generic_to_union(server) -> MCPServerUnion:
"""
Convert a generic MCPServer (from letta.schemas.mcp) to the appropriate MCPServerUnion type
based on the server_type field.
@@ -319,9 +287,9 @@ def convert_generic_to_union(server) -> MCPServerUnion:
env=server.stdio_config.env if server.stdio_config else None,
)
elif server.server_type == MCPServerType.SSE:
# Get decrypted values from encrypted columns
token = server.token_enc.get_plaintext() if server.token_enc else None
headers = server.get_custom_headers_dict()
# Get decrypted values from encrypted columns (async)
token = await server.token_enc.get_plaintext_async() if server.token_enc else None
headers = await server.get_custom_headers_dict_async()
return SSEMCPServer(
id=server.id,
server_name=server.server_name,
@@ -332,9 +300,9 @@ def convert_generic_to_union(server) -> MCPServerUnion:
custom_headers=headers,
)
elif server.server_type == MCPServerType.STREAMABLE_HTTP:
# Get decrypted values from encrypted columns
token = server.token_enc.get_plaintext() if server.token_enc else None
headers = server.get_custom_headers_dict()
# Get decrypted values from encrypted columns (async)
token = await server.token_enc.get_plaintext_async() if server.token_enc else None
headers = await server.get_custom_headers_dict_async()
return StreamableHTTPMCPServer(
id=server.id,
server_name=server.server_name,

View File

@@ -108,7 +108,7 @@ class AnthropicProvider(Provider):
base_url: str = "https://api.anthropic.com/v1"
async def check_api_key(self):
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
if api_key:
anthropic_client = anthropic.Anthropic(api_key=api_key)
try:
@@ -137,7 +137,7 @@ class AnthropicProvider(Provider):
NOTE: currently there is no GET /models, so we need to hardcode
"""
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
if api_key:
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
elif model_settings.anthropic_api_key:

View File

@@ -60,7 +60,7 @@ class AzureProvider(Provider):
async def azure_openai_get_deployed_model_list(self) -> list:
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
client = AsyncAzureOpenAI(api_key=api_key, api_version=self.api_version, azure_endpoint=self.base_url)
try:
@@ -170,7 +170,7 @@ class AzureProvider(Provider):
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default)
async def check_api_key(self):
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
if not api_key:
raise ValueError("No API key provided")

View File

@@ -26,8 +26,8 @@ class BedrockProvider(Provider):
try:
# Decrypt credentials before using
access_key = self.access_key_enc.get_plaintext() if self.access_key_enc else None
secret_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
access_key = await self.access_key_enc.get_plaintext_async() if self.access_key_enc else None
secret_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
session = Session()
async with session.client(

View File

@@ -41,7 +41,7 @@ class CerebrasProvider(OpenAIProvider):
async def list_llm_models_async(self) -> list[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list_async
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
if "data" in response:

View File

@@ -34,7 +34,7 @@ class DeepSeekProvider(OpenAIProvider):
async def list_llm_models_async(self) -> list[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list_async
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
data = response.get("data", response)

View File

@@ -23,7 +23,7 @@ class GoogleAIProvider(Provider):
async def check_api_key(self):
from letta.llm_api.google_ai_client import google_ai_check_valid_api_key_async
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
await google_ai_check_valid_api_key_async(api_key)
def get_default_max_output_tokens(self, model_name: str) -> int:
@@ -36,7 +36,7 @@ class GoogleAIProvider(Provider):
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
# Get and filter the model list
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key)
model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
model_options = [str(m["name"]) for m in model_options]
@@ -70,7 +70,7 @@ class GoogleAIProvider(Provider):
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
# TODO: use base_url instead
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key)
return self._list_embedding_models(model_options)
@@ -113,5 +113,5 @@ class GoogleAIProvider(Provider):
if model_name in LLM_MAX_CONTEXT_WINDOW:
return LLM_MAX_CONTEXT_WINDOW[model_name]
else:
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
return await google_ai_get_model_context_window_async(self.base_url, api_key, model_name)

View File

@@ -16,7 +16,7 @@ class GroqProvider(OpenAIProvider):
async def list_llm_models_async(self) -> list[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list_async
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
configs = []
for model in response["data"]:

View File

@@ -18,7 +18,7 @@ class MistralProvider(Provider):
# Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
# See: https://openrouter.ai/docs/requests
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
response = await mistral_get_model_list_async(self.base_url, api_key=api_key)
assert "data" in response, f"Mistral model query response missing 'data' field: {response}"

View File

@@ -26,7 +26,7 @@ class OpenAIProvider(Provider):
from letta.llm_api.openai import openai_check_valid_api_key # TODO: DO NOT USE THIS - old code path
# Decrypt API key before using
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
openai_check_valid_api_key(self.base_url, api_key)
def get_default_max_output_tokens(self, model_name: str) -> int:
@@ -48,7 +48,7 @@ class OpenAIProvider(Provider):
extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
# Decrypt API key before using
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
response = await openai_get_model_list_async(
self.base_url,

View File

@@ -30,7 +30,7 @@ class TogetherProvider(OpenAIProvider):
async def list_llm_models_async(self) -> list[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list_async
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
models = await openai_get_model_list_async(self.base_url, api_key=api_key)
return self._list_llm_models(models)
@@ -93,7 +93,7 @@ class TogetherProvider(OpenAIProvider):
return configs
async def check_api_key(self):
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
if not api_key:
raise ValueError("No API key provided")

View File

@@ -38,7 +38,7 @@ class XAIProvider(OpenAIProvider):
async def list_llm_models_async(self) -> list[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list_async
api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
response = await openai_get_model_list_async(self.base_url, api_key=api_key)
data = response.get("data", response)

View File

@@ -169,8 +169,6 @@ class Secret(BaseModel):
# Use cached value if available
if self._plaintext_cache is not None:
if not self.was_encrypted:
return self._plaintext_cache
return self._plaintext_cache
# Try to decrypt (async)

View File

@@ -50,7 +50,7 @@ async def create_mcp_server(
# TODO: add the tools to the MCP server table we made.
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
new_server = await server.mcp_server_manager.create_mcp_server_from_request(request, actor=actor)
return convert_generic_to_union(new_server)
return await convert_generic_to_union(new_server)
@router.get(
@@ -67,7 +67,10 @@ async def list_mcp_servers(
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
mcp_servers = await server.mcp_server_manager.list_mcp_servers(actor=actor)
return [convert_generic_to_union(mcp_server) for mcp_server in mcp_servers]
result = []
for mcp_server in mcp_servers:
result.append(await convert_generic_to_union(mcp_server))
return result
@router.get(
@@ -85,7 +88,7 @@ async def retrieve_mcp_server(
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
current_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
return convert_generic_to_union(current_server)
return await convert_generic_to_union(current_server)
@router.delete(
@@ -125,7 +128,7 @@ async def update_mcp_server(
updated_server = await server.mcp_server_manager.update_mcp_server_by_id(
mcp_server_id=mcp_server_id, mcp_server_update=internal_update, actor=actor
)
return convert_generic_to_union(updated_server)
return await convert_generic_to_union(updated_server)
@router.get("/{mcp_server_id}/tools", response_model=List[Tool], operation_id="mcp_list_tools_for_mcp_server")
@@ -238,7 +241,7 @@ async def connect_mcp_server(
mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
# Convert the MCP server to the appropriate config type
config = mcp_server.to_config(resolve_variables=False)
config = await mcp_server.to_config_async(resolve_variables=False)
async def oauth_stream_generator(
mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],

View File

@@ -990,25 +990,20 @@ class MCPServerManager:
"""
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
Note: Prefers encrypted columns (_enc fields), falls back to plaintext with error logging.
This helps identify unmigrated data during the migration period.
Note: Prefers encrypted columns (_enc fields), falls back to legacy plaintext columns.
"""
# Get decrypted values - prefer encrypted, fallback to plaintext with error logging
access_token = await Secret.from_db(
encrypted_value=oauth_session.access_token_enc, plaintext_value=oauth_session.access_token
).get_plaintext_async()
# Get decrypted values - prefer encrypted, fallback to legacy plaintext
access_token_secret = Secret.from_encrypted(oauth_session.access_token_enc)
access_token = await access_token_secret.get_plaintext_async()
refresh_token = await Secret.from_db(
encrypted_value=oauth_session.refresh_token_enc, plaintext_value=oauth_session.refresh_token
).get_plaintext_async()
refresh_token_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
refresh_token = await refresh_token_secret.get_plaintext_async()
client_secret = await Secret.from_db(
encrypted_value=oauth_session.client_secret_enc, plaintext_value=oauth_session.client_secret
).get_plaintext_async()
client_secret_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
client_secret = await client_secret_secret.get_plaintext_async()
authorization_code = await Secret.from_db(
encrypted_value=oauth_session.authorization_code_enc, plaintext_value=oauth_session.authorization_code
).get_plaintext_async()
authorization_code_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
authorization_code = await authorization_code_secret.get_plaintext_async()
# Create the Pydantic object with encrypted fields as Secret objects
pydantic_session = MCPOAuthSession(

View File

@@ -28,7 +28,6 @@ class TestSecret:
# Should store encrypted value
assert secret.encrypted_value is not None
assert secret.encrypted_value != plaintext
assert secret.was_encrypted is False
# Should decrypt to original value
assert secret.get_plaintext() == plaintext
@@ -52,7 +51,6 @@ class TestSecret:
# Should store the plaintext value directly in encrypted_value
assert secret.encrypted_value == plaintext
assert secret.get_plaintext() == plaintext
assert not secret.was_encrypted
finally:
settings.encryption_key = original_key
@@ -61,7 +59,6 @@ class TestSecret:
secret = Secret.from_plaintext(None)
assert secret.encrypted_value is None
assert secret.was_encrypted is False
assert secret.get_plaintext() is None
assert secret.is_empty() is True
@@ -79,78 +76,10 @@ class TestSecret:
secret = Secret.from_encrypted(encrypted)
assert secret.encrypted_value == encrypted
assert secret.was_encrypted is True
assert secret.get_plaintext() == plaintext
finally:
settings.encryption_key = original_key
def test_from_db_with_encrypted_value(self):
"""Test creating a Secret from database with encrypted value."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "database-secret"
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=None)
assert secret.encrypted_value == encrypted
assert secret.was_encrypted is True
assert secret.get_plaintext() == plaintext
finally:
settings.encryption_key = original_key
def test_from_db_with_plaintext_value_fallback(self, caplog):
"""Test creating a Secret from database with only plaintext value falls back with error logging.
Note: In Phase 1 of migration, from_db() prefers encrypted but falls back to plaintext
with error logging to help identify unmigrated data.
"""
import logging
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "legacy-plaintext"
# When only plaintext is provided, should fall back to plaintext with error logging
with caplog.at_level(logging.ERROR):
secret = Secret.from_db(encrypted_value=None, plaintext_value=plaintext)
# Should use the plaintext value (fallback)
assert secret.get_plaintext() == plaintext
# Should have logged an error about reading from plaintext column
assert "MIGRATION_NEEDED" in caplog.text
assert "plaintext column" in caplog.text
finally:
settings.encryption_key = original_key
def test_from_db_dual_read(self):
"""Test dual read functionality - prefer encrypted over plaintext."""
from letta.settings import settings
original_key = settings.encryption_key
settings.encryption_key = self.MOCK_KEY
try:
plaintext = "correct-value"
old_plaintext = "old-legacy-value"
encrypted = CryptoUtils.encrypt(plaintext, self.MOCK_KEY)
# When both values exist, should prefer encrypted
secret = Secret.from_db(encrypted_value=encrypted, plaintext_value=old_plaintext)
assert secret.get_plaintext() == plaintext # Should use encrypted value, not plaintext
finally:
settings.encryption_key = original_key
def test_get_encrypted(self):
"""Test getting the encrypted value for database storage."""
from letta.settings import settings