Files
letta-server/letta/services/tool_executor/sandbox_tool_executor.py
Shubham Naik acbbccd28a feat: have core ask cloud for any relavent api credentials to allow a… [LET-6179] (#6172)
feat: have core ask cloud for any relevant api credentials to allow an agent to perform letta tasks

Co-authored-by: Shubham Naik <shub@memgpt.ai>
2025-11-24 19:09:32 -08:00

172 lines
7.4 KiB
Python

import traceback
from typing import Any, Dict, Optional
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import SandboxType, ToolSourceType
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.sandbox_credentials_service import SandboxCredentialsService
from letta.services.tool_executor.tool_executor_base import ToolExecutor
from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal
from letta.settings import tool_settings
from letta.types import JsonDict
from letta.utils import get_friendly_error_msg
# Module-level logger, named after this module via letta's get_logger helper.
logger = get_logger(__name__)
class SandboxToolExecutor(ToolExecutor):
    """Executor for sandboxed tools.

    A tool runs on Modal when both of these hold:
    - the tool opts in via ``tool.metadata_["sandbox"] == "modal"``;
    - ``tool_settings.modal_sandbox_enabled`` is set.

    Otherwise the tool runs in the sandbox selected by ``tool_settings.sandbox_type``:
    E2B, or the local sandbox for any other value. If a Modal attempt raises, the
    run falls back to that default sandbox.
    """

    @trace_method
    async def execute(
        self,
        function_name: str,
        function_args: JsonDict,
        tool: Tool,
        actor: User,
        agent_state: Optional[AgentState] = None,
        sandbox_config: Optional[SandboxConfig] = None,
        sandbox_env_vars: Optional[Dict[str, Any]] = None,
    ) -> ToolExecutionResult:
        """Execute ``function_name`` from ``tool`` inside a sandbox.

        Args:
            function_name: Name of the function in ``tool.source_code`` to invoke.
            function_args: Call arguments. They are coerced to the function's
                annotated types when the annotations can be parsed.
            tool: The tool whose source code is executed.
            actor: The user on whose behalf the tool runs.
            agent_state: The calling agent's state, if any. The sandbox receives a
                copy with ``tools`` and ``tool_rules`` cleared. Memory returned by
                the sandbox is persisted through ``AgentManager``.
            sandbox_config: Optional sandbox configuration, forwarded to the sandbox.
            sandbox_env_vars: Extra environment variables for the sandbox. These
                override any credentials fetched from ``SandboxCredentialsService``.

        Returns:
            The sandbox's ``ToolExecutionResult``. If anything in the execution phase
            raises, an ``"error"`` result is returned instead.

        Raises:
            Exceptions from compiling the initial memory snapshot or from fetching
            credentials are not caught here. They propagate to the caller.
        """
        # Snapshot compiled memory so we can verify the sandbox did not mutate the
        # caller's agent_state in place.
        orig_memory_str = (
            agent_state.memory.compile(llm_config=agent_state.llm_config) if agent_state is not None else None
        )

        # Fetch credentials from webhook. This stays outside the try below, so
        # failures here keep propagating to the caller.
        fetched_credentials = await SandboxCredentialsService().fetch_credentials(
            actor=actor,
            tool_name=tool.name,
            agent_id=agent_state.id if agent_state is not None else None,
        )
        sandbox_env_vars = self._merge_env_vars(fetched_credentials, sandbox_env_vars)

        try:
            function_args = self._prepare_function_args(function_args, tool, function_name)
            agent_state_copy = self._create_agent_state_copy(agent_state) if agent_state is not None else None

            # Modal first (if requested and configured), then fall back to E2B/LOCAL.
            tool_execution_result = await self._try_modal_execution(
                function_name, function_args, tool, actor, sandbox_config, sandbox_env_vars, agent_state_copy
            )
            if tool_execution_result is None:
                tool_execution_result = await self._run_default_sandbox(
                    function_name, function_args, tool, actor, sandbox_config, sandbox_env_vars, agent_state_copy
                )

            log_lines = (tool_execution_result.stdout or []) + (tool_execution_result.stderr or [])
            logger.debug("Tool execution log: %s", "\n".join(log_lines))

            if agent_state is not None:
                # Verify memory integrity. Use an explicit raise, not `assert`, so the
                # check still runs under `python -O`. The except below turns it into
                # an error result, as before.
                new_memory_str = agent_state.memory.compile(llm_config=agent_state.llm_config)
                if orig_memory_str != new_memory_str:
                    raise AssertionError("Memory should not be modified in a sandbox tool")
                # Persist memory changes reported by the sandbox. This is only possible
                # when there is an agent to persist to (agent_state.id is needed).
                if tool_execution_result.agent_state is not None:
                    await AgentManager().update_memory_if_changed_async(
                        agent_state.id, tool_execution_result.agent_state.memory, actor
                    )

            return tool_execution_result
        except Exception as e:
            return self._handle_execution_error(e, function_name, traceback.format_exc())

    @staticmethod
    def _merge_env_vars(
        fetched_credentials: Optional[Dict[str, Any]],
        sandbox_env_vars: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Merge fetched credentials with caller-provided env vars; caller values win on key collisions."""
        # Build a new dict so the caller's mapping is never mutated. A None on
        # either side is treated as empty.
        return {**(fetched_credentials or {}), **(sandbox_env_vars or {})}

    @staticmethod
    async def _try_modal_execution(
        function_name: str,
        function_args: dict,
        tool: Tool,
        actor: User,
        sandbox_config: Optional[SandboxConfig],
        sandbox_env_vars: Dict[str, Any],
        agent_state_copy: Optional[AgentState],
    ) -> Optional[ToolExecutionResult]:
        """Run the tool on Modal when requested and enabled.

        Returns ``None`` when Modal was not attempted or its attempt raised, which
        tells the caller to use the default sandbox.
        """
        tool_requests_modal = bool(tool.metadata_) and tool.metadata_.get("sandbox") == "modal"
        if not (tool_requests_modal and tool_settings.modal_sandbox_enabled):
            return None
        try:
            # Imported lazily so the Modal dependency is only needed when used.
            from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal

            logger.info("Attempting Modal execution for tool %s", tool.name)
            sandbox = AsyncToolSandboxModal(
                function_name,
                function_args,
                actor,
                tool_object=tool,
                sandbox_config=sandbox_config,
                sandbox_env_vars=sandbox_env_vars,
                organization_id=actor.organization_id,
            )
            # NOTE(review): env vars are passed both to the constructor and to run();
            # kept as in the original. Confirm whether both are needed.
            # TODO: pass through letta api key
            return await sandbox.run(agent_state=agent_state_copy, additional_env_vars=sandbox_env_vars)
        except Exception as e:
            # Modal execution failed: log it and let the caller fall back to E2B/LOCAL.
            # NOTE(review): a partially-run tool may then execute a second time.
            logger.warning(
                "Modal execution failed for tool %s: %s. Falling back to %s",
                tool.name,
                e,
                tool_settings.sandbox_type.value,
            )
            return None

    @staticmethod
    async def _run_default_sandbox(
        function_name: str,
        function_args: dict,
        tool: Tool,
        actor: User,
        sandbox_config: Optional[SandboxConfig],
        sandbox_env_vars: Dict[str, Any],
        agent_state_copy: Optional[AgentState],
    ) -> ToolExecutionResult:
        """Run the tool in the sandbox selected by ``tool_settings.sandbox_type`` (E2B, else local)."""
        if tool_settings.sandbox_type == SandboxType.E2B:
            # Imported lazily so the E2B dependency is only needed when used.
            from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B

            sandbox_cls = AsyncToolSandboxE2B
        else:
            sandbox_cls = AsyncToolSandboxLocal
        sandbox = sandbox_cls(
            function_name,
            function_args,
            actor,
            tool_object=tool,
            sandbox_config=sandbox_config,
            sandbox_env_vars=sandbox_env_vars,
        )
        return await sandbox.run(agent_state=agent_state_copy)

    @staticmethod
    def _prepare_function_args(function_args: JsonDict, tool: Tool, function_name: str) -> dict:
        """Coerce ``function_args`` to the types annotated on ``function_name`` in the tool source.

        If annotations cannot be obtained or applied (``ValueError``), the original
        arguments are returned unchanged.
        """
        try:
            annotations = get_function_annotations_from_source(tool.source_code, function_name)
            return coerce_dict_args_by_annotations(function_args, annotations)
        except ValueError as e:
            # Best-effort coercion: log the failure and continue with the original args.
            logger.warning(
                "Could not coerce arguments for %s; using original arguments: %s",
                function_name,
                e,
            )
            return function_args

    @staticmethod
    def _create_agent_state_copy(agent_state: AgentState) -> AgentState:
        """Return a deep copy of ``agent_state`` with ``tools`` and ``tool_rules`` cleared for sandbox use."""
        agent_state_copy = agent_state.__deepcopy__()
        # Remove tools from copy to prevent nested tool execution
        agent_state_copy.tools = []
        agent_state_copy.tool_rules = []
        return agent_state_copy

    @staticmethod
    def _handle_execution_error(
        exception: Exception,
        function_name: str,
        stderr: str,
    ) -> ToolExecutionResult:
        """Build an ``"error"`` ToolExecutionResult from ``exception``.

        The result carries a friendly error message, and ``stderr`` (the formatted
        traceback) as its only stderr entry.
        """
        error_message = get_friendly_error_msg(
            function_name=function_name, exception_name=type(exception).__name__, exception_message=str(exception)
        )
        return ToolExecutionResult(
            status="error",
            func_return=error_message,
            stderr=[stderr],
        )