fix: update aysnc get plaintext callsites (#7069)
* base * resolve * fix * fix
This commit is contained in:
@@ -1902,8 +1902,8 @@ class LettaAgent(BaseAgent):
|
||||
start_time = get_utc_timestamp_ns()
|
||||
agent_step_span.add_event(name="tool_execution_started")
|
||||
|
||||
# Decrypt environment variable values
|
||||
sandbox_env_vars = {var.key: var.value_enc.get_plaintext() if var.value_enc else None for var in agent_state.secrets}
|
||||
# Use pre-decrypted environment variable values (populated in from_orm_async)
|
||||
sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
|
||||
@@ -1184,8 +1184,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
start_time = get_utc_timestamp_ns()
|
||||
agent_step_span.add_event(name="tool_execution_started")
|
||||
|
||||
# Decrypt environment variable values
|
||||
sandbox_env_vars = {var.key: var.value_enc.get_plaintext() if var.value_enc else None for var in agent_state.secrets}
|
||||
# Use pre-decrypted environment variable values (populated in from_orm_async)
|
||||
sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
|
||||
@@ -438,8 +438,8 @@ class VoiceAgent(BaseAgent):
|
||||
)
|
||||
|
||||
# Use ToolExecutionManager for modern tool execution
|
||||
# Decrypt environment variable values
|
||||
sandbox_env_vars = {var.key: var.value_enc.get_plaintext() if var.value_enc else None for var in agent_state.secrets}
|
||||
# Use pre-decrypted environment variable values (populated in from_orm_async)
|
||||
sandbox_env_vars = {var.key: var.value or "" for var in agent_state.secrets}
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
|
||||
@@ -434,7 +434,9 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
state["multi_agent_group"] = multi_agent_group
|
||||
state["managed_group"] = multi_agent_group
|
||||
# Convert ORM env vars to Pydantic with async decryption
|
||||
env_vars_pydantic = [await PydanticAgentEnvVar.from_orm_async(e) for e in tool_exec_environment_variables]
|
||||
env_vars_pydantic = []
|
||||
for e in tool_exec_environment_variables:
|
||||
env_vars_pydantic.append(await PydanticAgentEnvVar.from_orm_async(e))
|
||||
state["tool_exec_environment_variables"] = env_vars_pydantic
|
||||
state["secrets"] = env_vars_pydantic
|
||||
state["model"] = self.llm_config.handle if self.llm_config else None
|
||||
|
||||
@@ -70,6 +70,19 @@ class MCPServer(BaseMCPServer):
|
||||
logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_custom_headers_dict_async(self) -> Optional[Dict[str, str]]:
|
||||
"""Get custom headers as a plaintext dictionary (async version)."""
|
||||
secret = self.get_custom_headers_secret()
|
||||
if secret is None:
|
||||
return None
|
||||
json_str = await secret.get_plaintext_async()
|
||||
if json_str:
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse custom_headers_enc for MCP server {self.id}: {e}")
|
||||
return None
|
||||
|
||||
def set_token_secret(self, secret: Secret) -> None:
|
||||
"""Set token from a Secret object."""
|
||||
self.token_enc = secret
|
||||
@@ -130,6 +143,53 @@ class MCPServer(BaseMCPServer):
|
||||
else:
|
||||
raise ValueError(f"Unsupported server type: {self.server_type}")
|
||||
|
||||
async def to_config_async(
|
||||
self,
|
||||
environment_variables: Optional[Dict[str, str]] = None,
|
||||
resolve_variables: bool = True,
|
||||
) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
|
||||
"""Async version of to_config() that uses async decryption."""
|
||||
# Get decrypted values for use in config
|
||||
token_secret = self.get_token_secret()
|
||||
token_plaintext = await token_secret.get_plaintext_async() if token_secret else None
|
||||
|
||||
# Get custom headers as dict
|
||||
headers_plaintext = await self.get_custom_headers_dict_async()
|
||||
|
||||
if self.server_type == MCPServerType.SSE:
|
||||
config = SSEServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
|
||||
custom_headers=headers_plaintext,
|
||||
)
|
||||
if resolve_variables:
|
||||
config.resolve_environment_variables(environment_variables)
|
||||
return config
|
||||
elif self.server_type == MCPServerType.STDIO:
|
||||
if self.stdio_config is None:
|
||||
raise ValueError("stdio_config is required for STDIO server type")
|
||||
if resolve_variables:
|
||||
self.stdio_config.resolve_environment_variables(environment_variables)
|
||||
return self.stdio_config
|
||||
elif self.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||
if self.server_url is None:
|
||||
raise ValueError("server_url is required for STREAMABLE_HTTP server type")
|
||||
|
||||
config = StreamableHTTPServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if token_plaintext and not headers_plaintext else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {token_plaintext}" if token_plaintext and not headers_plaintext else None,
|
||||
custom_headers=headers_plaintext,
|
||||
)
|
||||
if resolve_variables:
|
||||
config.resolve_environment_variables(environment_variables)
|
||||
return config
|
||||
else:
|
||||
raise ValueError(f"Unsupported server type: {self.server_type}")
|
||||
|
||||
|
||||
class UpdateSSEMCPServer(LettaBase):
|
||||
"""Update an SSE MCP server"""
|
||||
|
||||
@@ -638,10 +638,11 @@ async def run_tool_for_agent(
|
||||
)
|
||||
|
||||
# Build environment variables dict from agent secrets
|
||||
# Use pre-decrypted value field (populated in from_orm_async)
|
||||
sandbox_env_vars = {}
|
||||
if agent.tool_exec_environment_variables:
|
||||
for env_var in agent.tool_exec_environment_variables:
|
||||
sandbox_env_vars[env_var.key] = env_var.value_enc.get_plaintext() if env_var.value_enc else None
|
||||
sandbox_env_vars[env_var.key] = env_var.value or ""
|
||||
|
||||
# Create tool execution manager and execute the tool
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
|
||||
@@ -122,8 +122,8 @@ async def check_existing_provider(
|
||||
provider = await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)
|
||||
|
||||
# Create a ProviderCheck from the existing provider
|
||||
api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None
|
||||
access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None
|
||||
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
||||
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
||||
provider_check = ProviderCheck(
|
||||
provider_type=provider.provider_type,
|
||||
api_key=api_key,
|
||||
|
||||
@@ -427,7 +427,10 @@ async def list_mcp_servers(
|
||||
else:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
mcp_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return {server.server_name: server.to_config(resolve_variables=False) for server in mcp_servers}
|
||||
result = {}
|
||||
for mcp_server in mcp_servers:
|
||||
result[mcp_server.server_name] = await mcp_server.to_config_async(resolve_variables=False)
|
||||
return result
|
||||
|
||||
|
||||
# NOTE: async because the MCP client/session calls are async
|
||||
@@ -556,7 +559,10 @@ async def add_mcp_server_to_config(
|
||||
|
||||
# TODO: don't do this in the future (just return MCPServer)
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return [server.to_config() for server in all_servers]
|
||||
result = []
|
||||
for mcp_server in all_servers:
|
||||
result.append(await mcp_server.to_config_async())
|
||||
return result
|
||||
|
||||
|
||||
@router.patch(
|
||||
@@ -581,7 +587,7 @@ async def update_mcp_server(
|
||||
updated_server = await server.mcp_manager.update_mcp_server_by_name(
|
||||
mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor
|
||||
)
|
||||
return updated_server.to_config()
|
||||
return await updated_server.to_config_async()
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -608,7 +614,10 @@ async def delete_mcp_server_from_config(
|
||||
|
||||
# TODO: don't do this in the future (just return MCPServer)
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return [server.to_config() for server in all_servers]
|
||||
result = []
|
||||
for mcp_server in all_servers:
|
||||
result.append(await mcp_server.to_config_async())
|
||||
return result
|
||||
|
||||
|
||||
@deprecated("Deprecated in favor of /mcp/servers/connect which handles OAuth flow via SSE stream")
|
||||
@@ -795,7 +804,7 @@ async def execute_mcp_tool(
|
||||
raise NoResultFound(f"MCP server '{mcp_server_name}' not found")
|
||||
|
||||
# Create client and connect
|
||||
server_config = mcp_server.to_config()
|
||||
server_config = await mcp_server.to_config_async()
|
||||
server_config.resolve_environment_variables()
|
||||
client = await server.mcp_manager.get_mcp_client(server_config, actor)
|
||||
await client.connect_to_server()
|
||||
|
||||
@@ -841,7 +841,7 @@ class AgentManager:
|
||||
existing_value = None
|
||||
if existing_env and existing_env.value_enc:
|
||||
existing_secret = Secret.from_encrypted(existing_env.value_enc)
|
||||
existing_value = existing_secret.get_plaintext()
|
||||
existing_value = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Encrypt value (reuse existing encrypted value if unchanged)
|
||||
if existing_value == v and existing_env and existing_env.value_enc:
|
||||
|
||||
@@ -38,11 +38,11 @@ class DatabaseTokenStorage(TokenStorage):
|
||||
return None
|
||||
|
||||
# Read tokens directly from _enc columns
|
||||
access_token = oauth_session.access_token_enc.get_plaintext() if oauth_session.access_token_enc else None
|
||||
access_token = await oauth_session.access_token_enc.get_plaintext_async() if oauth_session.access_token_enc else None
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
refresh_token = oauth_session.refresh_token_enc.get_plaintext() if oauth_session.refresh_token_enc else None
|
||||
refresh_token = await oauth_session.refresh_token_enc.get_plaintext_async() if oauth_session.refresh_token_enc else None
|
||||
|
||||
return OAuthToken(
|
||||
access_token=access_token,
|
||||
@@ -71,7 +71,7 @@ class DatabaseTokenStorage(TokenStorage):
|
||||
return None
|
||||
|
||||
# Read client secret directly from _enc column
|
||||
client_secret = oauth_session.client_secret_enc.get_plaintext() if oauth_session.client_secret_enc else None
|
||||
client_secret = await oauth_session.client_secret_enc.get_plaintext_async() if oauth_session.client_secret_enc else None
|
||||
|
||||
return OAuthClientInformationFull(
|
||||
client_id=oauth_session.client_id,
|
||||
@@ -229,7 +229,7 @@ async def create_oauth_provider(
|
||||
oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor)
|
||||
if oauth_session and oauth_session.authorization_code_enc:
|
||||
# Read authorization code directly from _enc column
|
||||
auth_code = oauth_session.authorization_code_enc.get_plaintext()
|
||||
auth_code = await oauth_session.authorization_code_enc.get_plaintext_async()
|
||||
return auth_code, oauth_session.state
|
||||
elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR:
|
||||
raise Exception("OAuth authorization failed")
|
||||
|
||||
@@ -70,7 +70,7 @@ class MCPManager:
|
||||
try:
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
server_config = await mcp_config.to_config_async()
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
@@ -116,7 +116,7 @@ class MCPManager:
|
||||
# read from DB
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config(environment_variables)
|
||||
server_config = await mcp_config.to_config_async(environment_variables)
|
||||
else:
|
||||
# read from config file
|
||||
mcp_config = await self.read_mcp_config()
|
||||
@@ -541,7 +541,7 @@ class MCPManager:
|
||||
existing_token = None
|
||||
if mcp_server.token_enc:
|
||||
existing_secret = Secret.from_encrypted(mcp_server.token_enc)
|
||||
existing_token = existing_secret.get_plaintext()
|
||||
existing_token = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_token != update_data["token"]:
|
||||
@@ -561,7 +561,7 @@ class MCPManager:
|
||||
existing_headers_json = None
|
||||
if mcp_server.custom_headers_enc:
|
||||
existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc)
|
||||
existing_headers_json = existing_secret.get_plaintext()
|
||||
existing_headers_json = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_headers_json != json_str:
|
||||
@@ -793,8 +793,8 @@ class MCPManager:
|
||||
# If no OAuth provider is provided, check if we have stored OAuth credentials
|
||||
if oauth_provider is None and hasattr(server_config, "server_url"):
|
||||
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
|
||||
# Check if access token exists by reading directly from _enc column
|
||||
if oauth_session and oauth_session.access_token_enc and oauth_session.access_token_enc.get_plaintext():
|
||||
# Check if access token exists by attempting to decrypt it
|
||||
if oauth_session and oauth_session.access_token_enc and await oauth_session.access_token_enc.get_plaintext_async():
|
||||
# Create OAuth provider from stored credentials
|
||||
from letta.services.mcp.oauth_utils import create_oauth_provider
|
||||
|
||||
@@ -819,7 +819,7 @@ class MCPManager:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
# OAuth-related methods
|
||||
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
|
||||
async def _oauth_orm_to_pydantic_async(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
|
||||
"""
|
||||
Convert OAuth ORM model to Pydantic model, reading directly from encrypted columns.
|
||||
"""
|
||||
@@ -832,10 +832,10 @@ class MCPManager:
|
||||
client_secret_enc = Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None
|
||||
|
||||
# Get plaintext values from encrypted columns (primary source of truth)
|
||||
authorization_code = authorization_code_enc.get_plaintext() if authorization_code_enc else None
|
||||
access_token = access_token_enc.get_plaintext() if access_token_enc else None
|
||||
refresh_token = refresh_token_enc.get_plaintext() if refresh_token_enc else None
|
||||
client_secret = client_secret_enc.get_plaintext() if client_secret_enc else None
|
||||
authorization_code = await authorization_code_enc.get_plaintext_async() if authorization_code_enc else None
|
||||
access_token = await access_token_enc.get_plaintext_async() if access_token_enc else None
|
||||
refresh_token = await refresh_token_enc.get_plaintext_async() if refresh_token_enc else None
|
||||
client_secret = await client_secret_enc.get_plaintext_async() if client_secret_enc else None
|
||||
|
||||
# Create the Pydantic object with both encrypted and plaintext fields
|
||||
pydantic_session = MCPOAuthSession(
|
||||
@@ -887,7 +887,7 @@ class MCPManager:
|
||||
oauth_session = await oauth_session.create_async(session, actor=actor)
|
||||
|
||||
# Convert to Pydantic model - note: new sessions won't have tokens yet
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
|
||||
@@ -895,7 +895,7 @@ class MCPManager:
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@@ -921,7 +921,7 @@ class MCPManager:
|
||||
if not oauth_session:
|
||||
return None
|
||||
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
@@ -939,7 +939,7 @@ class MCPManager:
|
||||
existing_code = None
|
||||
if oauth_session.authorization_code_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
|
||||
existing_code = existing_secret.get_plaintext()
|
||||
existing_code = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_code != session_update.authorization_code:
|
||||
@@ -951,7 +951,7 @@ class MCPManager:
|
||||
existing_token = None
|
||||
if oauth_session.access_token_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.access_token_enc)
|
||||
existing_token = existing_secret.get_plaintext()
|
||||
existing_token = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_token != session_update.access_token:
|
||||
@@ -963,7 +963,7 @@ class MCPManager:
|
||||
existing_refresh = None
|
||||
if oauth_session.refresh_token_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
|
||||
existing_refresh = existing_secret.get_plaintext()
|
||||
existing_refresh = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_refresh != session_update.refresh_token:
|
||||
@@ -984,7 +984,7 @@ class MCPManager:
|
||||
existing_secret_val = None
|
||||
if oauth_session.client_secret_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
|
||||
existing_secret_val = existing_secret.get_plaintext()
|
||||
existing_secret_val = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_secret_val != session_update.client_secret:
|
||||
@@ -1000,7 +1000,7 @@ class MCPManager:
|
||||
|
||||
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
|
||||
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:
|
||||
|
||||
@@ -162,7 +162,7 @@ class MCPServerManager:
|
||||
mcp_client = None
|
||||
try:
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
server_config = await mcp_config.to_config_async()
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
@@ -210,7 +210,7 @@ class MCPServerManager:
|
||||
|
||||
# Get the MCP server config
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config(environment_variables)
|
||||
server_config = await mcp_config.to_config_async(environment_variables)
|
||||
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
@@ -691,7 +691,7 @@ class MCPServerManager:
|
||||
existing_token = None
|
||||
if mcp_server.token_enc:
|
||||
existing_secret = Secret.from_encrypted(mcp_server.token_enc)
|
||||
existing_token = existing_secret.get_plaintext()
|
||||
existing_token = await existing_secret.get_plaintext_async()
|
||||
elif mcp_server.token:
|
||||
existing_token = mcp_server.token
|
||||
|
||||
@@ -718,7 +718,7 @@ class MCPServerManager:
|
||||
existing_headers_json = None
|
||||
if mcp_server.custom_headers_enc:
|
||||
existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc)
|
||||
existing_headers_json = existing_secret.get_plaintext()
|
||||
existing_headers_json = await existing_secret.get_plaintext_async()
|
||||
elif mcp_server.custom_headers:
|
||||
existing_headers_json = json.dumps(mcp_server.custom_headers)
|
||||
|
||||
@@ -961,7 +961,7 @@ class MCPServerManager:
|
||||
if oauth_provider is None and hasattr(server_config, "server_url"):
|
||||
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
|
||||
# Check if access token exists by attempting to decrypt it
|
||||
if oauth_session and oauth_session.get_access_token_secret().get_plaintext():
|
||||
if oauth_session and await oauth_session.get_access_token_secret().get_plaintext_async():
|
||||
# Create OAuth provider from stored credentials
|
||||
from letta.services.mcp.oauth_utils import create_oauth_provider
|
||||
|
||||
@@ -986,7 +986,7 @@ class MCPServerManager:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
# OAuth-related methods
|
||||
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
|
||||
async def _oauth_orm_to_pydantic_async(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
|
||||
"""
|
||||
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
|
||||
|
||||
@@ -994,21 +994,21 @@ class MCPServerManager:
|
||||
This helps identify unmigrated data during the migration period.
|
||||
"""
|
||||
# Get decrypted values - prefer encrypted, fallback to plaintext with error logging
|
||||
access_token = Secret.from_db(
|
||||
access_token = await Secret.from_db(
|
||||
encrypted_value=oauth_session.access_token_enc, plaintext_value=oauth_session.access_token
|
||||
).get_plaintext()
|
||||
).get_plaintext_async()
|
||||
|
||||
refresh_token = Secret.from_db(
|
||||
refresh_token = await Secret.from_db(
|
||||
encrypted_value=oauth_session.refresh_token_enc, plaintext_value=oauth_session.refresh_token
|
||||
).get_plaintext()
|
||||
).get_plaintext_async()
|
||||
|
||||
client_secret = Secret.from_db(
|
||||
client_secret = await Secret.from_db(
|
||||
encrypted_value=oauth_session.client_secret_enc, plaintext_value=oauth_session.client_secret
|
||||
).get_plaintext()
|
||||
).get_plaintext_async()
|
||||
|
||||
authorization_code = Secret.from_db(
|
||||
authorization_code = await Secret.from_db(
|
||||
encrypted_value=oauth_session.authorization_code_enc, plaintext_value=oauth_session.authorization_code
|
||||
).get_plaintext()
|
||||
).get_plaintext_async()
|
||||
|
||||
# Create the Pydantic object with encrypted fields as Secret objects
|
||||
pydantic_session = MCPOAuthSession(
|
||||
@@ -1061,7 +1061,7 @@ class MCPServerManager:
|
||||
oauth_session = await oauth_session.create_async(session, actor=actor)
|
||||
|
||||
# Convert to Pydantic model - note: new sessions won't have tokens yet
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
|
||||
@@ -1069,7 +1069,7 @@ class MCPServerManager:
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@@ -1095,7 +1095,7 @@ class MCPServerManager:
|
||||
if not oauth_session:
|
||||
return None
|
||||
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
@@ -1114,7 +1114,7 @@ class MCPServerManager:
|
||||
existing_code = None
|
||||
if oauth_session.authorization_code_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
|
||||
existing_code = existing_secret.get_plaintext()
|
||||
existing_code = await existing_secret.get_plaintext_async()
|
||||
elif oauth_session.authorization_code:
|
||||
existing_code = oauth_session.authorization_code
|
||||
|
||||
@@ -1131,7 +1131,7 @@ class MCPServerManager:
|
||||
existing_token = None
|
||||
if oauth_session.access_token_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.access_token_enc)
|
||||
existing_token = existing_secret.get_plaintext()
|
||||
existing_token = await existing_secret.get_plaintext_async()
|
||||
elif oauth_session.access_token:
|
||||
existing_token = oauth_session.access_token
|
||||
|
||||
@@ -1148,7 +1148,7 @@ class MCPServerManager:
|
||||
existing_refresh = None
|
||||
if oauth_session.refresh_token_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
|
||||
existing_refresh = existing_secret.get_plaintext()
|
||||
existing_refresh = await existing_secret.get_plaintext_async()
|
||||
elif oauth_session.refresh_token:
|
||||
existing_refresh = oauth_session.refresh_token
|
||||
|
||||
@@ -1174,7 +1174,7 @@ class MCPServerManager:
|
||||
existing_secret_val = None
|
||||
if oauth_session.client_secret_enc:
|
||||
existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
|
||||
existing_secret_val = existing_secret.get_plaintext()
|
||||
existing_secret_val = await existing_secret.get_plaintext_async()
|
||||
elif oauth_session.client_secret:
|
||||
existing_secret_val = oauth_session.client_secret
|
||||
|
||||
@@ -1194,7 +1194,7 @@ class MCPServerManager:
|
||||
|
||||
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
|
||||
|
||||
return self._oauth_orm_to_pydantic(oauth_session)
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:
|
||||
|
||||
@@ -115,7 +115,7 @@ class ProviderManager:
|
||||
existing_api_key = None
|
||||
if existing_provider.api_key_enc:
|
||||
existing_secret = Secret.from_encrypted(existing_provider.api_key_enc)
|
||||
existing_api_key = existing_secret.get_plaintext()
|
||||
existing_api_key = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_api_key != update_data["api_key"]:
|
||||
@@ -132,7 +132,7 @@ class ProviderManager:
|
||||
existing_access_key = None
|
||||
if existing_provider.access_key_enc:
|
||||
existing_secret = Secret.from_encrypted(existing_provider.access_key_enc)
|
||||
existing_access_key = existing_secret.get_plaintext()
|
||||
existing_access_key = await existing_secret.get_plaintext_async()
|
||||
|
||||
# Only re-encrypt if different
|
||||
if existing_access_key != update_data["access_key"]:
|
||||
@@ -336,7 +336,7 @@ class ProviderManager:
|
||||
if providers:
|
||||
# Decrypt the API key before returning
|
||||
api_key_secret = providers[0].api_key_enc
|
||||
return api_key_secret.get_plaintext() if api_key_secret else None
|
||||
return await api_key_secret.get_plaintext_async() if api_key_secret else None
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
@@ -349,8 +349,8 @@ class ProviderManager:
|
||||
# Decrypt the credentials before returning
|
||||
access_key_secret = providers[0].access_key_enc
|
||||
api_key_secret = providers[0].api_key_enc
|
||||
access_key = access_key_secret.get_plaintext() if access_key_secret else None
|
||||
secret_key = api_key_secret.get_plaintext() if api_key_secret else None
|
||||
access_key = await access_key_secret.get_plaintext_async() if access_key_secret else None
|
||||
secret_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
||||
region = providers[0].region
|
||||
return access_key, secret_key, region
|
||||
return None, None, None
|
||||
@@ -379,7 +379,7 @@ class ProviderManager:
|
||||
if providers:
|
||||
# Decrypt the API key before returning
|
||||
api_key_secret = providers[0].api_key_enc
|
||||
api_key = api_key_secret.get_plaintext() if api_key_secret else None
|
||||
api_key = await api_key_secret.get_plaintext_async() if api_key_secret else None
|
||||
base_url = providers[0].base_url
|
||||
api_version = providers[0].api_version
|
||||
return api_key, base_url, api_version
|
||||
@@ -400,7 +400,7 @@ class ProviderManager:
|
||||
).cast_to_subtype()
|
||||
|
||||
# TODO: add more string sanity checks here before we hit actual endpoints
|
||||
if not provider.api_key_enc or not provider.api_key_enc.get_plaintext():
|
||||
if not provider.api_key_enc or not await provider.api_key_enc.get_plaintext_async():
|
||||
raise ValueError("API key is required!")
|
||||
|
||||
await provider.check_api_key()
|
||||
@@ -439,8 +439,8 @@ class ProviderManager:
|
||||
return
|
||||
|
||||
# Create provider instance with necessary parameters
|
||||
api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None
|
||||
access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None
|
||||
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
||||
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
||||
kwargs = {
|
||||
"name": provider.name,
|
||||
"api_key": api_key,
|
||||
@@ -516,8 +516,8 @@ class ProviderManager:
|
||||
continue
|
||||
|
||||
# Convert Provider to ProviderCreate
|
||||
api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None
|
||||
access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None
|
||||
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
||||
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
||||
provider_create = ProviderCreate(
|
||||
name=provider.name,
|
||||
provider_type=provider.provider_type,
|
||||
|
||||
@@ -285,7 +285,10 @@ class SandboxConfigManager:
|
||||
organization_id=actor.organization_id,
|
||||
sandbox_config_id=sandbox_config_id,
|
||||
)
|
||||
return [await PydanticEnvVar.from_orm_async(env_var) for env_var in env_vars]
|
||||
result = []
|
||||
for env_var in env_vars:
|
||||
result.append(await PydanticEnvVar.from_orm_async(env_var))
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -301,7 +304,10 @@ class SandboxConfigManager:
|
||||
organization_id=actor.organization_id,
|
||||
key=key,
|
||||
)
|
||||
return [await PydanticEnvVar.from_orm_async(env_var) for env_var in env_vars]
|
||||
result = []
|
||||
for env_var in env_vars:
|
||||
result.append(await PydanticEnvVar.from_orm_async(env_var))
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -143,9 +143,10 @@ class AsyncToolSandboxModal(AsyncToolSandboxBase):
|
||||
logger.warning(f"Could not load sandbox env vars for tool {self.tool_name}: {e}")
|
||||
|
||||
# Add agent-specific environment variables (these override sandbox-level)
|
||||
# Use the pre-decrypted value field which was populated in from_orm_async()
|
||||
if agent_state and agent_state.secrets:
|
||||
for secret in agent_state.secrets:
|
||||
env_vars[secret.key] = secret.value_enc.get_plaintext() if secret.value_enc else None
|
||||
env_vars[secret.key] = secret.value or ""
|
||||
|
||||
# Add any additional env vars passed at runtime (highest priority)
|
||||
if additional_env_vars:
|
||||
|
||||
Reference in New Issue
Block a user