fix: update aysnc get plaintext callsites (#7069)

* base

* resolve

* fix

* fix
This commit is contained in:
jnjpng
2025-12-16 15:10:14 -08:00
committed by Caren Thomas
parent e6a4b3e874
commit 25d75d6528
15 changed files with 154 additions and 75 deletions

View File

@@ -1902,8 +1902,8 @@ class LettaAgent(BaseAgent):
start_time = get_utc_timestamp_ns()
agent_step_span.add_event(name="tool_execution_started")
# Decrypt environment variable values
sandbox_env_vars = {var.key: var.value_enc.get_plaintext() if var.value_enc else None for var in agent_state.secrets}
# Use pre-decrypted environment variable values (populated in from_orm_async)
sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
tool_execution_manager = ToolExecutionManager(
agent_state=agent_state,
message_manager=self.message_manager,

View File

@@ -1184,8 +1184,8 @@ class LettaAgentV2(BaseAgentV2):
start_time = get_utc_timestamp_ns()
agent_step_span.add_event(name="tool_execution_started")
# Decrypt environment variable values
sandbox_env_vars = {var.key: var.value_enc.get_plaintext() if var.value_enc else None for var in agent_state.secrets}
# Use pre-decrypted environment variable values (populated in from_orm_async)
sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
tool_execution_manager = ToolExecutionManager(
agent_state=agent_state,
message_manager=self.message_manager,

View File

@@ -438,8 +438,8 @@ class VoiceAgent(BaseAgent):
)
# Use ToolExecutionManager for modern tool execution
# Decrypt environment variable values
sandbox_env_vars = {var.key: var.value_enc.get_plaintext() if var.value_enc else None for var in agent_state.secrets}
# Use pre-decrypted environment variable values (populated in from_orm_async)
sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
tool_execution_manager = ToolExecutionManager(
agent_state=agent_state,
message_manager=self.message_manager,

View File

@@ -434,7 +434,9 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
state["multi_agent_group"] = multi_agent_group
state["managed_group"] = multi_agent_group
# Convert ORM env vars to Pydantic with async decryption
env_vars_pydantic = [await PydanticAgentEnvVar.from_orm_async(e) for e in tool_exec_environment_variables]
env_vars_pydantic = []
for e in tool_exec_environment_variables:
env_vars_pydantic.append(await PydanticAgentEnvVar.from_orm_async(e))
state["tool_exec_environment_variables"] = env_vars_pydantic
state["secrets"] = env_vars_pydantic
state["model"] = self.llm_config.handle if self.llm_config else None

View File

@@ -70,6 +70,19 @@ class MCPServer(BaseMCPServer):
logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}")
return None
async def get_custom_headers_dict_async(self) -> Optional[Dict[str, str]]:
"""Get custom headers as a plaintext dictionary (async version)."""
secret = self.get_custom_headers_secret()
if secret is None:
return None
json_str = await secret.get_plaintext_async()
if json_str:
try:
return json.loads(json_str)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}")
return None
def set_token_secret(self, secret: Secret) -> None:
"""Set token from a Secret object."""
self.token_enc = secret
@@ -130,6 +143,53 @@ class MCPServer(BaseMCPServer):
else:
raise ValueError(f"Unsupported server type: {self.server_type}")
async def to_config_async(
self,
environment_variables: Optional[Dict[str, str]] = None,
resolve_variables: bool = True,
) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
"""Async version of to_config() that uses async decryption."""
# Get decrypted values for use in config
token_secret = self.get_token_secret()
token_plaintext = await token_secret.get_plaintext_async() if token_secret else None
# Get custom headers as dict
headers_plaintext = await self.get_custom_headers_dict_async()
if self.server_type == MCPServerType.SSE:
config = SSEServerConfig(
server_name=self.server_name,
server_url=self.server_url,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
custom_headers=headers_plaintext,
)
if resolve_variables:
config.resolve_environment_variables(environment_variables)
return config
elif self.server_type == MCPServerType.STDIO:
if self.stdio_config is None:
raise ValueError("stdio_config is required for STDIO server type")
if resolve_variables:
self.stdio_config.resolve_environment_variables(environment_variables)
return self.stdio_config
elif self.server_type == MCPServerType.STREAMABLE_HTTP:
if self.server_url is None:
raise ValueError("server_url is required for STREAMABLE_HTTP server type")
config = StreamableHTTPServerConfig(
server_name=self.server_name,
server_url=self.server_url,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
custom_headers=headers_plaintext,
)
if resolve_variables:
config.resolve_environment_variables(environment_variables)
return config
else:
raise ValueError(f"Unsupported server type: {self.server_type}")
class UpdateSSEMCPServer(LettaBase):
"""Update an SSE MCP server"""

View File

@@ -638,10 +638,11 @@ async def run_tool_for_agent(
)
# Build environment variables dict from agent secrets
# Use pre-decrypted value field (populated in from_orm_async)
sandbox_env_vars = {}
if agent.tool_exec_environment_variables:
for env_var in agent.tool_exec_environment_variables:
sandbox_env_vars[env_var.key] = env_var.value_enc.get_plaintext() if env_var.value_enc else None
sandbox_env_vars[env_var.key] = env_var.value or ""
# Create tool execution manager and execute the tool
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager

View File

@@ -122,8 +122,8 @@ async def check_existing_provider(
provider = await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)
# Create a ProviderCheck from the existing provider
api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None
access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
provider_check = ProviderCheck(
provider_type=provider.provider_type,
api_key=api_key,

View File

@@ -427,7 +427,10 @@ async def list_mcp_servers(
else:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
mcp_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
return {server.server_name: server.to_config(resolve_variables=False) for server in mcp_servers}
result = {}
for mcp_server in mcp_servers:
result[mcp_server.server_name] = await mcp_server.to_config_async(resolve_variables=False)
return result
# NOTE: async because the MCP client/session calls are async
@@ -556,7 +559,10 @@ async def add_mcp_server_to_config(
# TODO: don't do this in the future (just return MCPServer)
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
return [server.to_config() for server in all_servers]
result = []
for mcp_server in all_servers:
result.append(await mcp_server.to_config_async())
return result
@router.patch(
@@ -581,7 +587,7 @@ async def update_mcp_server(
updated_server = await server.mcp_manager.update_mcp_server_by_name(
mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor
)
return updated_server.to_config()
return await updated_server.to_config_async()
@router.delete(
@@ -608,7 +614,10 @@ async def delete_mcp_server_from_config(
# TODO: don't do this in the future (just return MCPServer)
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
return [server.to_config() for server in all_servers]
result = []
for mcp_server in all_servers:
result.append(await mcp_server.to_config_async())
return result
@deprecated("Deprecated in favor of /mcp/servers/connect which handles OAuth flow via SSE stream")
@@ -795,7 +804,7 @@ async def execute_mcp_tool(
raise NoResultFound(f"MCP server '{mcp_server_name}' not found")
# Create client and connect
server_config = mcp_server.to_config()
server_config = await mcp_server.to_config_async()
server_config.resolve_environment_variables()
client = await server.mcp_manager.get_mcp_client(server_config, actor)
await client.connect_to_server()

View File

@@ -841,7 +841,7 @@ class AgentManager:
existing_value = None
if existing_env and existing_env.value_enc:
existing_secret = Secret.from_encrypted(existing_env.value_enc)
existing_value = existing_secret.get_plaintext()
existing_value = await existing_secret.get_plaintext_async()
# Encrypt value (reuse existing encrypted value if unchanged)
if existing_value == v and existing_env and existing_env.value_enc:

View File

@@ -38,11 +38,11 @@ class DatabaseTokenStorage(TokenStorage):
return None
# Read tokens directly from _enc columns
access_token = oauth_session.access_token_enc.get_plaintext() if oauth_session.access_token_enc else None
access_token = await oauth_session.access_token_enc.get_plaintext_async() if oauth_session.access_token_enc else None
if not access_token:
return None
refresh_token = oauth_session.refresh_token_enc.get_plaintext() if oauth_session.refresh_token_enc else None
refresh_token = await oauth_session.refresh_token_enc.get_plaintext_async() if oauth_session.refresh_token_enc else None
return OAuthToken(
access_token=access_token,
@@ -71,7 +71,7 @@ class DatabaseTokenStorage(TokenStorage):
return None
# Read client secret directly from _enc column
client_secret = oauth_session.client_secret_enc.get_plaintext() if oauth_session.client_secret_enc else None
client_secret = await oauth_session.client_secret_enc.get_plaintext_async() if oauth_session.client_secret_enc else None
return OAuthClientInformationFull(
client_id=oauth_session.client_id,
@@ -229,7 +229,7 @@ async def create_oauth_provider(
oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor)
if oauth_session and oauth_session.authorization_code_enc:
# Read authorization code directly from _enc column
auth_code = oauth_session.authorization_code_enc.get_plaintext()
auth_code = await oauth_session.authorization_code_enc.get_plaintext_async()
return auth_code, oauth_session.state
elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR:
raise Exception("OAuth authorization failed")

View File

@@ -70,7 +70,7 @@ class MCPManager:
try:
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config()
server_config = await mcp_config.to_config_async()
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
await mcp_client.connect_to_server()
@@ -116,7 +116,7 @@ class MCPManager:
# read from DB
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config(environment_variables)
server_config = await mcp_config.to_config_async(environment_variables)
else:
# read from config file
mcp_config = await self.read_mcp_config()
@@ -541,7 +541,7 @@ class MCPManager:
existing_token = None
if mcp_server.token_enc:
existing_secret = Secret.from_encrypted(mcp_server.token_enc)
existing_token = existing_secret.get_plaintext()
existing_token = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
if existing_token != update_data["token"]:
@@ -561,7 +561,7 @@ class MCPManager:
existing_headers_json = None
if mcp_server.custom_headers_enc:
existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc)
existing_headers_json = existing_secret.get_plaintext()
existing_headers_json = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
if existing_headers_json != json_str:
@@ -793,8 +793,8 @@ class MCPManager:
# If no OAuth provider is provided, check if we have stored OAuth credentials
if oauth_provider is None and hasattr(server_config, "server_url"):
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
# Check if access token exists by reading directly from _enc column
if oauth_session and oauth_session.access_token_enc and oauth_session.access_token_enc.get_plaintext():
# Check if access token exists by attempting to decrypt it
if oauth_session and oauth_session.access_token_enc and await oauth_session.access_token_enc.get_plaintext_async():
# Create OAuth provider from stored credentials
from letta.services.mcp.oauth_utils import create_oauth_provider
@@ -819,7 +819,7 @@ class MCPManager:
raise ValueError(f"Unsupported server config type: {type(server_config)}")
# OAuth-related methods
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
async def _oauth_orm_to_pydantic_async(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
"""
Convert OAuth ORM model to Pydantic model, reading directly from encrypted columns.
"""
@@ -832,10 +832,10 @@ class MCPManager:
client_secret_enc = Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None
# Get plaintext values from encrypted columns (primary source of truth)
authorization_code = authorization_code_enc.get_plaintext() if authorization_code_enc else None
access_token = access_token_enc.get_plaintext() if access_token_enc else None
refresh_token = refresh_token_enc.get_plaintext() if refresh_token_enc else None
client_secret = client_secret_enc.get_plaintext() if client_secret_enc else None
authorization_code = await authorization_code_enc.get_plaintext_async() if authorization_code_enc else None
access_token = await access_token_enc.get_plaintext_async() if access_token_enc else None
refresh_token = await refresh_token_enc.get_plaintext_async() if refresh_token_enc else None
client_secret = await client_secret_enc.get_plaintext_async() if client_secret_enc else None
# Create the Pydantic object with both encrypted and plaintext fields
pydantic_session = MCPOAuthSession(
@@ -887,7 +887,7 @@ class MCPManager:
oauth_session = await oauth_session.create_async(session, actor=actor)
# Convert to Pydantic model - note: new sessions won't have tokens yet
return self._oauth_orm_to_pydantic(oauth_session)
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
@@ -895,7 +895,7 @@ class MCPManager:
async with db_registry.async_session() as session:
try:
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
return self._oauth_orm_to_pydantic(oauth_session)
return await self._oauth_orm_to_pydantic_async(oauth_session)
except NoResultFound:
return None
@@ -921,7 +921,7 @@ class MCPManager:
if not oauth_session:
return None
return self._oauth_orm_to_pydantic(oauth_session)
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
@@ -939,7 +939,7 @@ class MCPManager:
existing_code = None
if oauth_session.authorization_code_enc:
existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
existing_code = existing_secret.get_plaintext()
existing_code = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
if existing_code != session_update.authorization_code:
@@ -951,7 +951,7 @@ class MCPManager:
existing_token = None
if oauth_session.access_token_enc:
existing_secret = Secret.from_encrypted(oauth_session.access_token_enc)
existing_token = existing_secret.get_plaintext()
existing_token = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
if existing_token != session_update.access_token:
@@ -963,7 +963,7 @@ class MCPManager:
existing_refresh = None
if oauth_session.refresh_token_enc:
existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
existing_refresh = existing_secret.get_plaintext()
existing_refresh = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
if existing_refresh != session_update.refresh_token:
@@ -984,7 +984,7 @@ class MCPManager:
existing_secret_val = None
if oauth_session.client_secret_enc:
existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
existing_secret_val = existing_secret.get_plaintext()
existing_secret_val = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
if existing_secret_val != session_update.client_secret:
@@ -1000,7 +1000,7 @@ class MCPManager:
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
return self._oauth_orm_to_pydantic(oauth_session)
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:

View File

@@ -162,7 +162,7 @@ class MCPServerManager:
mcp_client = None
try:
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config()
server_config = await mcp_config.to_config_async()
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
await mcp_client.connect_to_server()
@@ -210,7 +210,7 @@ class MCPServerManager:
# Get the MCP server config
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config(environment_variables)
server_config = await mcp_config.to_config_async(environment_variables)
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
await mcp_client.connect_to_server()
@@ -691,7 +691,7 @@ class MCPServerManager:
existing_token = None
if mcp_server.token_enc:
existing_secret = Secret.from_encrypted(mcp_server.token_enc)
existing_token = existing_secret.get_plaintext()
existing_token = await existing_secret.get_plaintext_async()
elif mcp_server.token:
existing_token = mcp_server.token
@@ -718,7 +718,7 @@ class MCPServerManager:
existing_headers_json = None
if mcp_server.custom_headers_enc:
existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc)
existing_headers_json = existing_secret.get_plaintext()
existing_headers_json = await existing_secret.get_plaintext_async()
elif mcp_server.custom_headers:
existing_headers_json = json.dumps(mcp_server.custom_headers)
@@ -961,7 +961,7 @@ class MCPServerManager:
if oauth_provider is None and hasattr(server_config, "server_url"):
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
# Check if access token exists by attempting to decrypt it
if oauth_session and oauth_session.get_access_token_secret().get_plaintext():
if oauth_session and await oauth_session.get_access_token_secret().get_plaintext_async():
# Create OAuth provider from stored credentials
from letta.services.mcp.oauth_utils import create_oauth_provider
@@ -986,7 +986,7 @@ class MCPServerManager:
raise ValueError(f"Unsupported server config type: {type(server_config)}")
# OAuth-related methods
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
async def _oauth_orm_to_pydantic_async(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
"""
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
@@ -994,21 +994,21 @@ class MCPServerManager:
This helps identify unmigrated data during the migration period.
"""
# Get decrypted values - prefer encrypted, fallback to plaintext with error logging
access_token = Secret.from_db(
access_token = await Secret.from_db(
encrypted_value=oauth_session.access_token_enc, plaintext_value=oauth_session.access_token
).get_plaintext()
).get_plaintext_async()
refresh_token = Secret.from_db(
refresh_token = await Secret.from_db(
encrypted_value=oauth_session.refresh_token_enc, plaintext_value=oauth_session.refresh_token
).get_plaintext()
).get_plaintext_async()
client_secret = Secret.from_db(
client_secret = await Secret.from_db(
encrypted_value=oauth_session.client_secret_enc, plaintext_value=oauth_session.client_secret
).get_plaintext()
).get_plaintext_async()
authorization_code = Secret.from_db(
authorization_code = await Secret.from_db(
encrypted_value=oauth_session.authorization_code_enc, plaintext_value=oauth_session.authorization_code
).get_plaintext()
).get_plaintext_async()
# Create the Pydantic object with encrypted fields as Secret objects
pydantic_session = MCPOAuthSession(
@@ -1061,7 +1061,7 @@ class MCPServerManager:
oauth_session = await oauth_session.create_async(session, actor=actor)
# Convert to Pydantic model - note: new sessions won't have tokens yet
return self._oauth_orm_to_pydantic(oauth_session)
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
@@ -1069,7 +1069,7 @@ class MCPServerManager:
async with db_registry.async_session() as session:
try:
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
return self._oauth_orm_to_pydantic(oauth_session)
return await self._oauth_orm_to_pydantic_async(oauth_session)
except NoResultFound:
return None
@@ -1095,7 +1095,7 @@ class MCPServerManager:
if not oauth_session:
return None
return self._oauth_orm_to_pydantic(oauth_session)
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
@@ -1114,7 +1114,7 @@ class MCPServerManager:
existing_code = None
if oauth_session.authorization_code_enc:
existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
existing_code = existing_secret.get_plaintext()
existing_code = await existing_secret.get_plaintext_async()
elif oauth_session.authorization_code:
existing_code = oauth_session.authorization_code
@@ -1131,7 +1131,7 @@ class MCPServerManager:
existing_token = None
if oauth_session.access_token_enc:
existing_secret = Secret.from_encrypted(oauth_session.access_token_enc)
existing_token = existing_secret.get_plaintext()
existing_token = await existing_secret.get_plaintext_async()
elif oauth_session.access_token:
existing_token = oauth_session.access_token
@@ -1148,7 +1148,7 @@ class MCPServerManager:
existing_refresh = None
if oauth_session.refresh_token_enc:
existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
existing_refresh = existing_secret.get_plaintext()
existing_refresh = await existing_secret.get_plaintext_async()
elif oauth_session.refresh_token:
existing_refresh = oauth_session.refresh_token
@@ -1174,7 +1174,7 @@ class MCPServerManager:
existing_secret_val = None
if oauth_session.client_secret_enc:
existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
existing_secret_val = existing_secret.get_plaintext()
existing_secret_val = await existing_secret.get_plaintext_async()
elif oauth_session.client_secret:
existing_secret_val = oauth_session.client_secret
@@ -1194,7 +1194,7 @@ class MCPServerManager:
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
return self._oauth_orm_to_pydantic(oauth_session)
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:

View File

@@ -115,7 +115,7 @@ class ProviderManager:
existing_api_key = None
if existing_provider.api_key_enc:
existing_secret = Secret.from_encrypted(existing_provider.api_key_enc)
existing_api_key = existing_secret.get_plaintext()
existing_api_key = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
if existing_api_key != update_data["api_key"]:
@@ -132,7 +132,7 @@ class ProviderManager:
existing_access_key = None
if existing_provider.access_key_enc:
existing_secret = Secret.from_encrypted(existing_provider.access_key_enc)
existing_access_key = existing_secret.get_plaintext()
existing_access_key = await existing_secret.get_plaintext_async()
# Only re-encrypt if different
if existing_access_key != update_data["access_key"]:
@@ -336,7 +336,7 @@ class ProviderManager:
if providers:
# Decrypt the API key before returning
api_key_secret = providers[0].api_key_enc
return api_key_secret.get_plaintext() if api_key_secret else None
return await api_key_secret.get_plaintext_async() if api_key_secret else None
return None
@enforce_types
@@ -349,8 +349,8 @@ class ProviderManager:
# Decrypt the credentials before returning
access_key_secret = providers[0].access_key_enc
api_key_secret = providers[0].api_key_enc
access_key = access_key_secret.get_plaintext() if access_key_secret else None
secret_key = api_key_secret.get_plaintext() if api_key_secret else None
access_key = await access_key_secret.get_plaintext_async() if access_key_secret else None
secret_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
region = providers[0].region
return access_key, secret_key, region
return None, None, None
@@ -379,7 +379,7 @@ class ProviderManager:
if providers:
# Decrypt the API key before returning
api_key_secret = providers[0].api_key_enc
api_key = api_key_secret.get_plaintext() if api_key_secret else None
api_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
base_url = providers[0].base_url
api_version = providers[0].api_version
return api_key, base_url, api_version
@@ -400,7 +400,7 @@ class ProviderManager:
).cast_to_subtype()
# TODO: add more string sanity checks here before we hit actual endpoints
if not provider.api_key_enc or not provider.api_key_enc.get_plaintext():
if not provider.api_key_enc or not await provider.api_key_enc.get_plaintext_async():
raise ValueError("API key is required!")
await provider.check_api_key()
@@ -439,8 +439,8 @@ class ProviderManager:
return
# Create provider instance with necessary parameters
api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None
access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
kwargs = {
"name": provider.name,
"api_key": api_key,
@@ -516,8 +516,8 @@ class ProviderManager:
continue
# Convert Provider to ProviderCreate
api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None
access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
provider_create = ProviderCreate(
name=provider.name,
provider_type=provider.provider_type,

View File

@@ -285,7 +285,10 @@ class SandboxConfigManager:
organization_id=actor.organization_id,
sandbox_config_id=sandbox_config_id,
)
return [await PydanticEnvVar.from_orm_async(env_var) for env_var in env_vars]
result = []
for env_var in env_vars:
result.append(await PydanticEnvVar.from_orm_async(env_var))
return result
@enforce_types
@trace_method
@@ -301,7 +304,10 @@ class SandboxConfigManager:
organization_id=actor.organization_id,
key=key,
)
return [await PydanticEnvVar.from_orm_async(env_var) for env_var in env_vars]
result = []
for env_var in env_vars:
result.append(await PydanticEnvVar.from_orm_async(env_var))
return result
@enforce_types
@trace_method

View File

@@ -143,9 +143,10 @@ class AsyncToolSandboxModal(AsyncToolSandboxBase):
logger.warning(f"Could not load sandbox env vars for tool {self.tool_name}: {e}")
# Add agent-specific environment variables (these override sandbox-level)
# Use the pre-decrypted value field which was populated in from_orm_async()
if agent_state and agent_state.secrets:
for secret in agent_state.secrets:
env_vars[secret.key] = secret.value_enc.get_plaintext() if secret.value_enc else None
env_vars[secret.key] = secret.value or ""
# Add any additional env vars passed at runtime (highest priority)
if additional_env_vars: