Files
letta-server/letta/services/tool_sandbox/modal_sandbox.py

202 lines
8.3 KiB
Python

"""
Modal sandbox implementation, which configures one Modal App per tool.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
import modal
from e2b.sandbox.commands.command_handle import CommandExitException
from e2b_code_interpreter import AsyncSandbox
from letta.constants import MODAL_DEFAULT_TOOL_NAME
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import SandboxType
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.services.helpers.tool_parser_helper import parse_function_arguments, parse_stdout_best_effort
from letta.services.tool_manager import ToolManager
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
from letta.types import JsonDict
from letta.utils import get_friendly_error_msg
# Module-level logger shared by everything in this file.
logger = get_logger(__name__)

if TYPE_CHECKING:
    # Imported only for static type checking; skipped at runtime.
    # NOTE(review): `Execution` is not referenced in the code visible here — possibly a leftover
    # from the E2B sandbox implementation; verify before removing.
    from e2b_code_interpreter import Execution
class AsyncToolSandboxModal(AsyncToolSandboxBase):
    """Run a tool by invoking the Modal function that was deployed for it.

    The deployed function is resolved by the app name produced by
    ``generate_modal_function_name(tool.name, organization_id, project_id)`` and the
    function name ``MODAL_DEFAULT_TOOL_NAME`` (see ``_wait_for_modal_function_deployment``).
    """

    # NOTE(review): not referenced in this class; presumably read by callers or the base class — verify.
    METADATA_CONFIG_STATE_KEY = "config_state"

    def __init__(
        self,
        tool_name: str,
        args: JsonDict,
        user,
        tool_id: str,
        agent_id: Optional[str] = None,
        project_id: Optional[str] = None,
        force_recreate: bool = True,
        tool_object: Optional[Tool] = None,
        sandbox_config: Optional[SandboxConfig] = None,
        sandbox_env_vars: Optional[Dict[str, Any]] = None,
        organization_id: Optional[str] = None,
    ):
        """Create a Modal-backed sandbox for a single tool invocation.

        Args:
            tool_name: Name of the tool to execute; forwarded to the Modal function.
            args: Keyword arguments for the tool, splatted into the remote call.
            user: Acting user; its ``organization_id`` is used when ``organization_id`` is None.
            tool_id: ID of the tool, passed to the base class.
            agent_id: Optional agent ID, forwarded to the Modal function.
            project_id: Optional project ID; used with the tool name and organization to
                derive the deployed app name.
            force_recreate: Stored on the instance; not read elsewhere in this class.
            tool_object: Optional pre-loaded tool, passed to the base class.
            sandbox_config: Optional sandbox config, passed to the base class.
            sandbox_env_vars: Optional agent-scoped env vars, passed to the base class.
            organization_id: Optional organization override used for the deployed app name.
        """
        super().__init__(
            tool_name,
            args,
            user,
            tool_id=tool_id,
            agent_id=agent_id,
            project_id=project_id,
            tool_object=tool_object,
            sandbox_config=sandbox_config,
            sandbox_env_vars=sandbox_env_vars,
        )
        self.force_recreate = force_recreate
        # Get organization_id from user if not explicitly provided
        self.organization_id = organization_id if organization_id is not None else user.organization_id

    async def _wait_for_modal_function_deployment(self, timeout: int = 60):
        """Wait until this tool's deployed Modal function can be resolved, then return it.

        ``modal.Function.from_name`` only builds a lazy reference, so each attempt hydrates
        it explicitly. Otherwise a not-yet-deployed app would not fail here, and the retry
        loop would return immediately.

        Args:
            timeout: Maximum number of seconds to keep retrying.

        Returns:
            The hydrated Modal function handle.

        Raises:
            TimeoutError: If the function cannot be resolved within ``timeout`` seconds.
        """
        import asyncio
        import time

        from letta.helpers.tool_helpers import generate_modal_function_name

        # Use the same naming logic as deployment
        function_name = generate_modal_function_name(self.tool.name, self.organization_id, self.project_id)
        retry_delay = 2  # seconds
        # monotonic() is immune to wall-clock adjustments while measuring elapsed time.
        start_time = time.monotonic()
        while True:
            try:
                func = modal.Function.from_name(function_name, MODAL_DEFAULT_TOOL_NAME)
                # Hydration performs the actual server lookup, so a missing app raises here.
                await func.hydrate.aio()
            except Exception as e:
                elapsed = time.monotonic() - start_time
                remaining = timeout - elapsed
                if remaining <= 0:
                    raise TimeoutError(
                        f"Modal app {function_name} deployment timed out after {timeout} seconds. "
                        f"Expected app name: {function_name}, function: {MODAL_DEFAULT_TOOL_NAME}"
                    ) from e
                # Never sleep past the deadline.
                delay = min(retry_delay, remaining)
                logger.info(f"Modal app {function_name} not ready yet (elapsed: {elapsed:.1f}s), waiting {delay:.1f}s...")
                await asyncio.sleep(delay)
            else:
                logger.info(f"Modal function found successfully for app {function_name}, function {func}")
                return func

    async def _build_env_vars(
        self,
        agent_state: Optional[AgentState],
        additional_env_vars: Optional[Dict],
    ) -> Dict[str, Any]:
        """Merge env vars for the remote call; later layers win on key collision.

        Layers, lowest to highest precedence:
          1. Global env vars of the default MODAL sandbox config (best effort; failures are logged).
          2. Env vars provided to this sandbox (agent-scoped).
          3. Agent-specific env vars from ``agent_state``.
          4. ``additional_env_vars`` passed at runtime.
        """
        env_vars: Dict[str, Any] = {}
        try:
            sandbox_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async(
                sandbox_type=SandboxType.MODAL, actor=self.user
            )
            if sandbox_config:
                global_env_vars = await self.sandbox_config_manager.get_sandbox_env_vars_as_dict_async(
                    sandbox_config_id=sandbox_config.id, actor=self.user, limit=None
                )
                env_vars.update(global_env_vars)
        except Exception as e:
            # Best effort: missing global vars must not block tool execution.
            logger.warning(f"Could not load global sandbox env vars for tool {self.tool_name}: {e}")
        if self.provided_sandbox_env_vars:
            env_vars.update(self.provided_sandbox_env_vars)
        if agent_state:
            env_vars.update(agent_state.get_agent_env_vars_as_dict())
        if additional_env_vars:
            env_vars.update(additional_env_vars)
        return env_vars

    @staticmethod
    def _result_agent_state(returned: Any, original: Optional[AgentState]) -> Optional[AgentState]:
        """Pick the agent state to report: the remotely returned one if usable, else ``original``.

        A dict is validated into ``AgentState``; if validation fails, ``original`` is kept.
        A falsy value falls back to ``original``; any other value is passed through as-is.
        """
        if not returned:
            return original
        if isinstance(returned, dict):
            try:
                return AgentState.model_validate(returned)
            except Exception as e:
                logger.warning(f"Failed to reconstruct AgentState: {e}, using original")
                return original
        return returned

    @trace_method
    async def run(
        self,
        agent_state: Optional[AgentState] = None,
        additional_env_vars: Optional[Dict] = None,
    ) -> ToolExecutionResult:
        """Execute the tool in its deployed Modal function and wrap the outcome.

        Args:
            agent_state: Optional agent state. It is sent as ``model_dump()`` output and
                returned unchanged unless the remote call returns a usable replacement.
            additional_env_vars: Optional runtime env vars with the highest precedence. Their
                ``LETTA_SECRET_API_KEY`` entry, if present, is also passed as ``letta_api_key``.

        Returns:
            A ``ToolExecutionResult``. Exceptions raised after ``_init_async()`` (lookup
            timeout, remote failure, malformed result) are logged and reported as
            ``status="error"`` with the exception text in ``stderr``, not propagated.
        """
        await self._init_async()
        try:
            log_event("modal_execution_started", {"tool": self.tool_name, "modal_app_id": self.tool.id})
            logger.info(f"Waiting for Modal function deployment for app {self.tool.id}")
            func = await self._wait_for_modal_function_deployment()
            logger.info(f"Modal function found successfully for app {self.tool.id}, function {str(func)}")

            # TODO: use another mechanism to pass through the key
            letta_api_key = (additional_env_vars or {}).get("LETTA_SECRET_API_KEY", None)
            env_vars = await self._build_env_vars(agent_state, additional_env_vars)

            # Convert agent_state to dict to avoid cloudpickle serialization issues
            agent_state_dict = agent_state.model_dump() if agent_state else None
            logger.info(f"Calling function {func} with arguments {self.args}")
            result = await func.remote.aio(
                tool_name=self.tool_name,
                agent_state=agent_state_dict,
                agent_id=self.agent_id,
                env_vars=env_vars,
                letta_api_key=letta_api_key,
                **self.args,
            )
            logger.info(f"Modal function result: {result}")
            # Fail with a clear message instead of an opaque AttributeError/KeyError below.
            if not isinstance(result, dict):
                raise TypeError(f"Modal function returned {type(result).__name__}, expected a dict")

            return ToolExecutionResult(
                func_return=result.get("result"),
                agent_state=self._result_agent_state(result.get("agent_state"), agent_state),
                # Missing or None streams become "" so the lists always hold strings.
                stdout=[result.get("stdout") or ""],
                stderr=[result.get("stderr") or ""],
                status="error" if result.get("error") else "success",
            )
        except Exception as e:
            log_event(
                "modal_execution_failed",
                {
                    "tool": self.tool_name,
                    "modal_app_id": self.tool.id,
                    "error": str(e),
                },
            )
            logger.error(f"Modal execution failed for tool {self.tool_name} {self.tool.id}: {e}")
            return ToolExecutionResult(
                func_return=None,
                agent_state=agent_state,
                stdout=[""],
                stderr=[str(e)],
                status="error",
            )