feat: tool function arguments passed in at runtime

This commit is contained in:
Andy Li
2025-08-15 16:24:56 -07:00
committed by GitHub
parent 773a6452d1
commit 81993f23eb
10 changed files with 2817 additions and 12 deletions

View File

@@ -9,6 +9,7 @@ from letta.schemas.agent import AgentState
from letta.schemas.enums import SandboxType
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
from letta.schemas.pip_requirement import PipRequirement
from letta.services.tool_sandbox.modal_constants import DEFAULT_MODAL_TIMEOUT
from letta.settings import tool_settings
# Sandbox Config
@@ -80,7 +81,7 @@ class E2BSandboxConfig(BaseModel):
class ModalSandboxConfig(BaseModel):
timeout: int = Field(5 * 60, description="Time limit for the sandbox (in seconds).")
timeout: int = Field(DEFAULT_MODAL_TIMEOUT, description="Time limit for the sandbox (in seconds).")
pip_requirements: list[str] | None = Field(None, description="A list of pip packages to install in the Modal sandbox")
npm_requirements: list[str] | None = Field(None, description="A list of npm packages to install in the Modal sandbox")
language: Literal["python", "typescript"] = "python"

View File

@@ -164,22 +164,14 @@ class AsyncToolSandboxBase(ABC):
import ast
try:
# Parse the source code to AST
tree = ast.parse(self.tool.source_code)
# Look for function definitions
for node in ast.walk(tree):
if isinstance(node, ast.AsyncFunctionDef) and node.name == self.tool.name:
return True
elif isinstance(node, ast.FunctionDef) and node.name == self.tool.name:
return False
# If we couldn't find the function definition, fall back to string matching
return "async def " + self.tool.name in self.tool.source_code
except SyntaxError:
# If source code can't be parsed, fall back to string matching
return "async def " + self.tool.name in self.tool.source_code
return False
except:
return False
def use_top_level_await(self) -> bool:
"""

View File

@@ -0,0 +1,17 @@
"""Shared constants for Modal sandbox implementations."""
# Deployment and versioning
DEFAULT_CONFIG_KEY = "default"
MODAL_DEPLOYMENTS_KEY = "modal_deployments"
VERSION_HASH_LENGTH = 12
# Cache settings
CACHE_TTL_SECONDS = 60
# Modal execution settings
DEFAULT_MODAL_TIMEOUT = 60
DEFAULT_MAX_CONCURRENT_INPUTS = 1
DEFAULT_PYTHON_VERSION = "3.12"
# Security settings
SAFE_IMPORT_MODULES = {"typing", "pydantic", "datetime", "enum", "uuid", "decimal"}

View File

@@ -0,0 +1,242 @@
"""
Modal Deployment Manager - Handles deployment orchestration with optional locking.
This module separates deployment logic from the main sandbox execution,
making it easier to understand and optionally disable locking/version tracking.
"""
import hashlib
from typing import Tuple
import modal
from letta.log import get_logger
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.services.tool_sandbox.modal_constants import VERSION_HASH_LENGTH
from letta.services.tool_sandbox.modal_version_manager import ModalVersionManager, get_version_manager
logger = get_logger(__name__)
class ModalDeploymentManager:
    """Manages Modal app deployments with optional locking and version tracking.

    Deployment reuse works by hashing the tool source + requirements + sandbox
    config fingerprint into a short version hash; a matching hash means the
    already-deployed Modal app can be looked up and reused instead of redeployed.
    """

    def __init__(
        self,
        tool: Tool,
        version_manager: ModalVersionManager | None = None,
        use_locking: bool = True,
        use_version_tracking: bool = True,
    ):
        """
        Initialize deployment manager.

        Args:
            tool: The tool to deploy
            version_manager: Version manager for tracking deployments (optional)
            use_locking: Whether to use locking for coordinated deployments
            use_version_tracking: Whether to track and reuse existing deployments
        """
        self.tool = tool
        # Precedence note: the conditional binds loosest, so this parses as
        # (version_manager or get_version_manager()) if (use_locking or
        # use_version_tracking) else None. When both flags are False an
        # explicitly supplied version_manager is discarded -- NOTE(review):
        # confirm that is intended.
        self.version_manager = version_manager or get_version_manager() if (use_locking or use_version_tracking) else None
        self.use_locking = use_locking
        self.use_version_tracking = use_version_tracking
        self._app_name = self._generate_app_name()

    def _generate_app_name(self) -> str:
        """Generate app name based on tool ID (truncated to 40 chars)."""
        return self.tool.id[:40]

    def calculate_version_hash(self, sbx_config: SandboxConfig) -> str:
        """Calculate version hash for the current configuration.

        Any change to the tool source, its pip/npm requirements, or the sandbox
        config fingerprint produces a different hash and forces a redeploy.
        """
        components = (
            self.tool.source_code,
            str(self.tool.pip_requirements) if self.tool.pip_requirements else "",
            str(self.tool.npm_requirements) if self.tool.npm_requirements else "",
            sbx_config.fingerprint(),
        )
        combined = "|".join(components)
        # Truncated sha256 (VERSION_HASH_LENGTH hex chars) keeps app names short.
        return hashlib.sha256(combined.encode()).hexdigest()[:VERSION_HASH_LENGTH]

    def get_full_app_name(self, version_hash: str) -> str:
        """Get the full app name including version."""
        app_full_name = f"{self._app_name}-{version_hash}"
        # Ensure total length is under 64 characters (Modal app-name limit --
        # TODO confirm exact limit against Modal docs); truncate the tool-id
        # part, never the version hash.
        if len(app_full_name) > 63:
            max_id_len = 63 - len(version_hash) - 1
            app_full_name = f"{self._app_name[:max_id_len]}-{version_hash}"
        return app_full_name

    async def get_or_deploy_app(
        self,
        sbx_config: SandboxConfig,
        user,
        create_app_func,
    ) -> Tuple[modal.App, str]:
        """
        Get existing app or deploy new one.

        Args:
            sbx_config: Sandbox configuration
            user: User/actor for permissions
            create_app_func: Async function (sbx_config, version_hash) -> modal.App
                that creates and deploys the app

        Returns:
            Tuple of (Modal app, version hash)
        """
        version_hash = self.calculate_version_hash(sbx_config)

        # Simple path: no version tracking -- always deploy fresh.
        if not self.use_version_tracking:
            logger.info(f"Deploying Modal app {self._app_name} (version tracking disabled)")
            app = await create_app_func(sbx_config, version_hash)
            return app, version_hash

        # Try to use existing deployment.
        # NOTE: this condition is always True here (the early return above
        # handled the False case).
        if self.use_version_tracking:
            existing_app = await self._try_get_existing_app(sbx_config, version_hash, user)
            if existing_app:
                return existing_app, version_hash

        # Need to deploy - with or without locking
        if self.use_locking:
            return await self._deploy_with_locking(sbx_config, version_hash, user, create_app_func)
        else:
            return await self._deploy_without_locking(sbx_config, version_hash, user, create_app_func)

    async def _try_get_existing_app(
        self,
        sbx_config: SandboxConfig,
        version_hash: str,
        user,
    ) -> modal.App | None:
        """Try to get an existing deployed app.

        Returns None when there is no recorded deployment, the recorded hash
        differs from ``version_hash``, or the Modal lookup fails.
        """
        if not self.version_manager:
            return None
        deployment = await self.version_manager.get_deployment(
            tool_id=self.tool.id, sandbox_config_id=sbx_config.id if sbx_config else None, actor=user
        )
        if deployment and deployment.version_hash == version_hash:
            app_full_name = self.get_full_app_name(version_hash)
            logger.info(f"Checking for existing Modal app {app_full_name}")
            try:
                app = await modal.App.lookup.aio(app_full_name)
                logger.info(f"Found existing Modal app {app_full_name}")
                return app
            except Exception:
                # Metadata says deployed but Modal disagrees (e.g. app deleted
                # out-of-band): treat as missing and let the caller redeploy.
                logger.info(f"Modal app {app_full_name} not found in Modal, will redeploy")
                return None
        return None

    async def _deploy_without_locking(
        self,
        sbx_config: SandboxConfig,
        version_hash: str,
        user,
        create_app_func,
    ) -> Tuple[modal.App, str]:
        """Deploy without locking - simpler but may have race conditions."""
        app_full_name = self.get_full_app_name(version_hash)
        logger.info(f"Deploying Modal app {app_full_name} (no locking)")
        # Deploy the app
        app = await create_app_func(sbx_config, version_hash)
        # Register deployment if tracking is enabled
        if self.use_version_tracking and self.version_manager:
            await self._register_deployment(sbx_config, version_hash, app, user)
        return app, version_hash

    async def _deploy_with_locking(
        self,
        sbx_config: SandboxConfig,
        version_hash: str,
        user,
        create_app_func,
    ) -> Tuple[modal.App, str]:
        """Deploy with locking to prevent concurrent deployments.

        Flow: (1) under the lock, re-check for an existing app and for an
        in-progress deployment; (2) if someone else is deploying, wait on their
        completion event; (3) otherwise mark ourselves in-progress and deploy.
        """
        cache_key = f"{self.tool.id}:{sbx_config.id if sbx_config else 'default'}"
        deployment_lock = self.version_manager.get_deployment_lock(cache_key)

        async with deployment_lock:
            # Double-check after acquiring lock
            existing_app = await self._try_get_existing_app(sbx_config, version_hash, user)
            if existing_app:
                return existing_app, version_hash
            # Check if another process is deploying
            if self.version_manager.is_deployment_in_progress(cache_key, version_hash):
                logger.info(f"Another process is deploying {self._app_name} v{version_hash}, waiting...")
                # Release lock and wait
                # (Rebinding to None does not release the lock early; it is a
                # flag read after the `async with` exits and actually releases.)
                deployment_lock = None

        # Wait for other deployment if needed
        if deployment_lock is None:
            success = await self.version_manager.wait_for_deployment(cache_key, version_hash, timeout=120)
            if success:
                existing_app = await self._try_get_existing_app(sbx_config, version_hash, user)
                if existing_app:
                    return existing_app, version_hash
                # Completion event fired but the app is not findable -- the
                # other deployer likely failed (failure also sets the event).
                raise RuntimeError(f"Deployment completed but app not found")
            else:
                raise RuntimeError(f"Timeout waiting for deployment")

        # We're deploying - mark as in progress.
        # NOTE(review): the lock is released between the check above and the
        # re-acquisition below, leaving a small window where another coroutine
        # could also reach this point -- confirm whether that race is acceptable.
        deployment_key = None
        async with deployment_lock:
            deployment_key = self.version_manager.mark_deployment_in_progress(cache_key, version_hash)

        try:
            app_full_name = self.get_full_app_name(version_hash)
            logger.info(f"Deploying Modal app {app_full_name} with locking")
            # Deploy the app
            app = await create_app_func(sbx_config, version_hash)
            # Mark deployment complete
            if deployment_key:
                self.version_manager.complete_deployment(deployment_key)
            # Register deployment
            if self.use_version_tracking:
                await self._register_deployment(sbx_config, version_hash, app, user)
            return app, version_hash
        except Exception:
            # Wake any waiters even on failure; they will then fail to find the
            # app and raise "Deployment completed but app not found".
            if deployment_key:
                self.version_manager.complete_deployment(deployment_key)
            raise

    async def _register_deployment(
        self,
        sbx_config: SandboxConfig,
        version_hash: str,
        app: modal.App,
        user,
    ):
        """Record this deployment (app name, hash, dependency set) in the
        version manager, merging tool-level and sandbox-config pip requirements."""
        if not self.version_manager:
            return
        dependencies = set()
        if self.tool.pip_requirements:
            dependencies.update(str(req) for req in self.tool.pip_requirements)
        modal_config = sbx_config.get_modal_config()
        if modal_config.pip_requirements:
            dependencies.update(str(req) for req in modal_config.pip_requirements)
        await self.version_manager.register_deployment(
            tool_id=self.tool.id,
            app_name=self._app_name,
            version_hash=version_hash,
            app=app,
            dependencies=dependencies,
            sandbox_config_id=sbx_config.id if sbx_config else None,
            actor=user,
        )

View File

@@ -0,0 +1,429 @@
"""
This runs tool calls within an isolated modal sandbox. This does this by doing the following:
1. deploying modal functions that embed the original functions
2. dynamically executing tools with arguments passed in at runtime
3. tracking deployment versions to know when a deployment update is needed
"""
from typing import Any, Dict
import modal
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import SandboxType
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
from letta.services.tool_sandbox.modal_constants import DEFAULT_MAX_CONCURRENT_INPUTS, DEFAULT_PYTHON_VERSION
from letta.services.tool_sandbox.modal_deployment_manager import ModalDeploymentManager
from letta.services.tool_sandbox.modal_version_manager import ModalVersionManager
from letta.services.tool_sandbox.safe_pickle import SafePickleError, safe_pickle_dumps, sanitize_for_pickle
from letta.settings import tool_settings
from letta.types import JsonDict
from letta.utils import get_friendly_error_msg
logger = get_logger(__name__)
class AsyncToolSandboxModalV2(AsyncToolSandboxBase):
    """Modal sandbox with dynamic argument passing and version tracking.

    Instead of baking tool arguments into the deployed code, a generic
    ``tool_executor`` Modal function is deployed once per (tool, version hash);
    at call time the tool source and pickled arguments are shipped to it.
    """

    def __init__(
        self,
        tool_name: str,
        args: JsonDict,
        user,
        tool_object: Tool | None = None,
        sandbox_config: SandboxConfig | None = None,
        sandbox_env_vars: dict[str, Any] | None = None,
        version_manager: ModalVersionManager | None = None,
        use_locking: bool = True,
        use_version_tracking: bool = True,
    ):
        """
        Initialize the Modal sandbox.

        Args:
            tool_name: Name of the tool to execute
            args: Arguments to pass to the tool
            user: User/actor for permissions
            tool_object: Tool object (optional)
            sandbox_config: Sandbox configuration (optional)
            sandbox_env_vars: Environment variables (optional)
            version_manager: Version manager, will create default if needed (optional)
            use_locking: Whether to use locking for deployment coordination (default: True)
            use_version_tracking: Whether to track and reuse deployments (default: True)

        Raises:
            ValueError: If Modal credentials are not configured.
        """
        super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
        if not tool_settings.modal_token_id or not tool_settings.modal_token_secret:
            raise ValueError("MODAL_TOKEN_ID and MODAL_TOKEN_SECRET must be set.")
        # Deployment orchestration (locking / version reuse) is delegated.
        self._deployment_manager = ModalDeploymentManager(
            tool=self.tool,
            version_manager=version_manager,
            use_locking=use_locking,
            use_version_tracking=use_version_tracking,
        )
        # Set by _get_or_deploy_modal_app; used for log events.
        self._version_hash = None

    async def _get_or_deploy_modal_app(self, sbx_config: SandboxConfig) -> modal.App:
        """Get existing Modal app or deploy a new version if needed."""
        app, version_hash = await self._deployment_manager.get_or_deploy_app(
            sbx_config=sbx_config,
            user=self.user,
            create_app_func=self._create_and_deploy_app,
        )
        self._version_hash = version_hash
        return app

    async def _create_and_deploy_app(self, sbx_config: SandboxConfig, version: str) -> modal.App:
        """Create and deploy a new Modal app with the executor function.

        The executor module file is mounted into the image and re-executed
        inside the container, so the deployed function itself stays generic.
        """
        import importlib.util
        from pathlib import Path

        # App name = tool_id + version hash
        app_full_name = self._deployment_manager.get_full_app_name(version)
        app = modal.App(app_full_name)
        modal_config = sbx_config.get_modal_config()
        image = self._get_modal_image(sbx_config)

        # Find the sandbox module dynamically
        spec = importlib.util.find_spec("sandbox")
        if not spec or not spec.origin:
            raise ValueError("Could not find sandbox module")
        sandbox_dir = Path(spec.origin).parent

        # Locate the modal_executor module; its content is shipped via
        # add_local_file below (fix: a previous read of the file whose result
        # was discarded has been removed).
        executor_path = sandbox_dir / "modal_executor.py"
        if not executor_path.exists():
            raise ValueError(f"modal_executor.py not found at {executor_path}")

        # Create a single file mount instead of directory mount
        # This avoids sys.path manipulation
        image = image.add_local_file(str(executor_path), remote_path="/modal_executor.py")

        # Register the executor function with Modal
        @app.function(
            image=image,
            timeout=modal_config.timeout,
            restrict_modal_access=True,
            max_inputs=DEFAULT_MAX_CONCURRENT_INPUTS,
            serialized=True,
        )
        def tool_executor(
            tool_source: str,
            tool_name: str,
            args_pickled: bytes,
            agent_state_pickled: bytes | None,
            inject_agent_state: bool,
            is_async: bool,
            args_schema_code: str | None,
            environment_vars: Dict[str, Any],
        ) -> Dict[str, Any]:
            """Execute tool in Modal container."""
            # Execute the modal_executor code in a clean namespace
            # Create a module-like namespace for executor
            executor_namespace = {
                "__name__": "modal_executor",
                "__file__": "/modal_executor.py",
            }
            # Read and execute the module file (mounted at image build time)
            with open("/modal_executor.py", "r") as f:
                exec(compile(f.read(), "/modal_executor.py", "exec"), executor_namespace)
            # Call the wrapper function from the executed namespace
            return executor_namespace["execute_tool_wrapper"](
                tool_source=tool_source,
                tool_name=tool_name,
                args_pickled=args_pickled,
                agent_state_pickled=agent_state_pickled,
                inject_agent_state=inject_agent_state,
                is_async=is_async,
                args_schema_code=args_schema_code,
                environment_vars=environment_vars,
            )

        # Store the function reference so run() can call app.tool_executor
        app.tool_executor = tool_executor

        # Deploy the app
        logger.info(f"Deploying Modal app {app_full_name}")
        log_event("modal_v2_deploy_started", {"app_name": app_full_name, "version": version})
        try:
            # Try to look up the app first to see if it already exists
            try:
                await modal.App.lookup.aio(app_full_name)
                logger.info(f"Modal app {app_full_name} already exists, skipping deployment")
                log_event("modal_v2_deploy_already_exists", {"app_name": app_full_name, "version": version})
                # Return the created app with the function attached
                return app
            except Exception:
                # Fix: was a bare `except:`, which would also swallow
                # asyncio.CancelledError / KeyboardInterrupt. A failed lookup
                # simply means the app doesn't exist yet -- deploy it.
                pass

            with modal.enable_output():
                await app.deploy.aio()
            log_event("modal_v2_deploy_succeeded", {"app_name": app_full_name, "version": version})
        except Exception as e:
            log_event("modal_v2_deploy_failed", {"app_name": app_full_name, "version": version, "error": str(e)})
            raise
        return app

    @trace_method
    async def run(
        self,
        agent_state: AgentState | None = None,
        additional_env_vars: Dict | None = None,
    ) -> ToolExecutionResult:
        """Execute the tool in Modal sandbox with dynamic argument passing.

        Args:
            agent_state: Optional agent state to inject into the tool call.
            additional_env_vars: Extra environment variables for the sandbox.

        Returns:
            ToolExecutionResult with the tool's return value (or a friendly
            error message), captured stdout/stderr, and status.
        """
        if self.provided_sandbox_config:
            sbx_config = self.provided_sandbox_config
        else:
            sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async(
                sandbox_type=SandboxType.MODAL, actor=self.user
            )
        envs = await self._gather_env_vars(agent_state, additional_env_vars or {}, sbx_config.id, is_local=False)

        # Prepare schema code if needed
        args_schema_code = None
        if self.tool.args_json_schema:
            from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args

            args_schema_code = add_imports_and_pydantic_schemas_for_args(self.tool.args_json_schema)

        # Serialize arguments with safety checks; fall back to sanitized args,
        # then to a plain string repr, rather than failing the call outright.
        try:
            args_pickled = safe_pickle_dumps(self.args)
        except SafePickleError as e:
            logger.warning(f"Failed to pickle args, attempting sanitization: {e}")
            sanitized_args = sanitize_for_pickle(self.args)
            try:
                args_pickled = safe_pickle_dumps(sanitized_args)
            except SafePickleError:
                # Final fallback: convert to string representation
                args_pickled = safe_pickle_dumps(str(self.args))

        agent_state_pickled = None
        if self.inject_agent_state and agent_state:
            try:
                agent_state_pickled = safe_pickle_dumps(agent_state)
            except SafePickleError as e:
                logger.warning(f"Failed to pickle agent state: {e}")
                # For agent state, we prefer to skip injection rather than send corrupted data
                agent_state_pickled = None
                self.inject_agent_state = False

        try:
            log_event(
                "modal_execution_started",
                {
                    "tool": self.tool_name,
                    "app_name": self._deployment_manager._app_name,
                    "version": self._version_hash,
                    # list(envs) logs only the variable names (dict keys), not values
                    "env_vars": list(envs),
                    "args_size": len(args_pickled),
                    "agent_state_size": len(agent_state_pickled) if agent_state_pickled else 0,
                    "inject_agent_state": self.inject_agent_state,
                },
            )

            # Get or deploy the Modal app
            app = await self._get_or_deploy_modal_app(sbx_config)
            # Get modal config for timeout settings
            modal_config = sbx_config.get_modal_config()

            # Execute the tool remotely with retry logic
            import asyncio  # hoisted out of the retry loop (fix)

            max_retries = 3
            retry_delay = 1  # seconds
            last_error = None
            for attempt in range(max_retries):
                try:
                    # Add timeout to prevent hanging
                    result = await asyncio.wait_for(
                        app.tool_executor.remote.aio(
                            tool_source=self.tool.source_code,
                            tool_name=self.tool.name,
                            args_pickled=args_pickled,
                            agent_state_pickled=agent_state_pickled,
                            inject_agent_state=self.inject_agent_state,
                            is_async=self.is_async_function,
                            args_schema_code=args_schema_code,
                            environment_vars=envs,
                        ),
                        timeout=modal_config.timeout + 10,  # Add 10s buffer to Modal's own timeout
                    )
                    break  # Success, exit retry loop
                except asyncio.TimeoutError as e:
                    last_error = e
                    logger.warning(f"Modal execution timeout on attempt {attempt + 1}/{max_retries} for tool {self.tool_name}")
                    if attempt < max_retries - 1:
                        await asyncio.sleep(retry_delay)
                        retry_delay *= 2  # Exponential backoff
                except Exception as e:
                    last_error = e
                    # Check if it's a transient error worth retrying
                    error_str = str(e).lower()
                    if any(x in error_str for x in ["segmentation fault", "sigsegv", "connection", "timeout"]):
                        logger.warning(f"Transient error on attempt {attempt + 1}/{max_retries} for tool {self.tool_name}: {e}")
                        if attempt < max_retries - 1:
                            await asyncio.sleep(retry_delay)
                            retry_delay *= 2
                            continue
                    # Non-transient error, don't retry
                    raise
            else:
                # for/else: all retries exhausted without a break
                raise last_error

            # Process the result
            if result["error"]:
                logger.debug(f"Tool {self.tool_name} raised a {result['error']['name']}: {result['error']['value']}")
                logger.debug(f"Traceback from Modal sandbox: \n{result['error']['traceback']}")

                # Check for segfault indicators
                is_segfault = False
                if "SIGSEGV" in str(result["error"]["value"]) or "Segmentation fault" in str(result["error"]["value"]):
                    is_segfault = True
                    logger.error(f"SEGFAULT detected in tool {self.tool_name}: {result['error']['value']}")

                func_return = get_friendly_error_msg(
                    function_name=self.tool_name,
                    exception_name=result["error"]["name"],
                    exception_message=result["error"]["value"],
                )
                log_event(
                    "modal_execution_failed",
                    {
                        "tool": self.tool_name,
                        "app_name": self._deployment_manager._app_name,
                        "version": self._version_hash,
                        "error_type": result["error"]["name"],
                        "error_message": result["error"]["value"],
                        "func_return": func_return,
                        "is_segfault": is_segfault,
                        "stdout": result.get("stdout", ""),
                        "stderr": result.get("stderr", ""),
                    },
                )
                status = "error"
            else:
                func_return = result["result"]
                # Rebinds the parameter: the (possibly updated) agent state
                # returned by the sandbox is propagated to the caller below.
                agent_state = result["agent_state"]
                log_event(
                    "modal_v2_execution_succeeded",
                    {
                        "tool": self.tool_name,
                        "app_name": self._deployment_manager._app_name,
                        "version": self._version_hash,
                        "func_return": str(func_return)[:500],  # Limit logged result size
                        "stdout_size": len(result.get("stdout", "")),
                        "stderr_size": len(result.get("stderr", "")),
                    },
                )
                status = "success"

            return ToolExecutionResult(
                func_return=func_return,
                agent_state=agent_state if not result["error"] else None,
                stdout=[result["stdout"]] if result["stdout"] else [],
                stderr=[result["stderr"]] if result["stderr"] else [],
                status=status,
                sandbox_config_fingerprint=sbx_config.fingerprint(),
            )
        except Exception as e:
            import traceback

            error_context = {
                "tool": self.tool_name,
                "app_name": self._deployment_manager._app_name,
                "version": self._version_hash,
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
            }
            logger.error(f"Modal V2 execution for tool {self.tool_name} encountered an error: {e}", extra=error_context)

            # Determine if this is a deployment error or execution error
            # (heuristic string match on the message)
            if "deploy" in str(e).lower() or "modal" in str(e).lower():
                error_category = "deployment_error"
            else:
                error_category = "execution_error"

            func_return = get_friendly_error_msg(
                function_name=self.tool_name,
                exception_name=type(e).__name__,
                exception_message=str(e),
            )
            log_event(f"modal_v2_{error_category}", error_context)
            return ToolExecutionResult(
                func_return=func_return,
                agent_state=None,
                stdout=[],
                stderr=[f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"],
                status="error",
                sandbox_config_fingerprint=sbx_config.fingerprint(),
            )

    def _get_modal_image(self, sbx_config: SandboxConfig) -> modal.Image:
        """Get Modal image with required public python dependencies.

        Caching and rebuilding is handled in a cascading manner
        https://modal.com/docs/guide/images#image-caching-and-rebuilds
        """
        # Start with a more robust base image with development tools
        image = modal.Image.debian_slim(python_version=DEFAULT_PYTHON_VERSION)

        # Add system packages for better C extension support
        image = image.apt_install(
            "build-essential",  # Compilation tools
            "libsqlite3-dev",  # SQLite development headers
            "libffi-dev",  # Foreign Function Interface library
            "libssl-dev",  # OpenSSL development headers
            "python3-dev",  # Python development headers
        )

        # Include dependencies required by letta's ORM modules
        # These are needed when unpickling agent_state objects
        all_requirements = [
            "letta",
            "sqlite-vec>=0.1.7a2",  # Required for SQLite vector operations
            "numpy<2.0",  # Pin numpy to avoid compatibility issues
        ]

        # Add sandbox-specific pip requirements
        modal_configs = sbx_config.get_modal_config()
        if modal_configs.pip_requirements:
            all_requirements.extend([str(req) for req in modal_configs.pip_requirements])

        # Add tool-specific pip requirements
        if self.tool and self.tool.pip_requirements:
            all_requirements.extend([str(req) for req in self.tool.pip_requirements])

        if all_requirements:
            image = image.pip_install(*all_requirements)

        return image

View File

@@ -0,0 +1,273 @@
"""
This module tracks and manages deployed app versions. We currently use the tools.metadata field
to store the information detailing modal deployments and when we need to redeploy due to changes.
Modal Version Manager - Tracks and manages deployed Modal app versions.
"""
import asyncio
import time
from datetime import datetime
from typing import Any
import modal
from pydantic import BaseModel, ConfigDict, Field
from letta.log import get_logger
from letta.schemas.tool import ToolUpdate
from letta.services.tool_manager import ToolManager
from letta.services.tool_sandbox.modal_constants import CACHE_TTL_SECONDS, DEFAULT_CONFIG_KEY, MODAL_DEPLOYMENTS_KEY
logger = get_logger(__name__)
class DeploymentInfo(BaseModel):
    """Information about a deployed Modal app."""

    # Fix: the docstring above was previously placed AFTER model_config, making
    # it a discarded bare string rather than the class __doc__.
    # arbitrary_types_allowed lets app_reference carry a modal.App instance.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    app_name: str = Field(..., description="The name of the modal app.")
    version_hash: str = Field(..., description="The version hash of the modal app.")
    deployed_at: datetime = Field(..., description="The time the modal app was deployed.")
    dependencies: set[str] = Field(default_factory=set, description="A set of dependencies.")
    # Typed as Any (not modal.App | None) so pydantic does not validate the
    # Modal type; excluded from serialization.
    # app_reference: modal.App | None = Field(None, description="The reference to the modal app.", exclude=True)
    app_reference: Any = Field(None, description="The reference to the modal app.", exclude=True)
class ModalVersionManager:
"""Manages versions and deployments of Modal apps using tools.metadata."""
def __init__(self):
self.tool_manager = ToolManager()
self._deployment_locks: dict[str, asyncio.Lock] = {}
self._cache: dict[str, tuple[DeploymentInfo, float]] = {}
self._deployments_in_progress: dict[str, asyncio.Event] = {}
self._deployments: dict[str, DeploymentInfo] = {} # Track all deployments for stats
@staticmethod
def _make_cache_key(tool_id: str, sandbox_config_id: str | None = None) -> str:
"""Generate cache key for tool and config combination."""
return f"{tool_id}:{sandbox_config_id or DEFAULT_CONFIG_KEY}"
@staticmethod
def _get_config_key(sandbox_config_id: str | None = None) -> str:
"""Get standardized config key."""
return sandbox_config_id or DEFAULT_CONFIG_KEY
def _is_cache_valid(self, timestamp: float) -> bool:
"""Check if cache entry is still valid."""
return time.time() - timestamp < CACHE_TTL_SECONDS
def _get_deployment_metadata(self, tool) -> dict:
"""Get or initialize modal deployments metadata."""
if not tool.metadata_:
tool.metadata_ = {}
if MODAL_DEPLOYMENTS_KEY not in tool.metadata_:
tool.metadata_[MODAL_DEPLOYMENTS_KEY] = {}
return tool.metadata_[MODAL_DEPLOYMENTS_KEY]
def _create_deployment_data(self, app_name: str, version_hash: str, dependencies: set[str]) -> dict:
"""Create deployment data dictionary for metadata storage."""
return {
"app_name": app_name,
"version_hash": version_hash,
"deployed_at": datetime.now().isoformat(),
"dependencies": list(dependencies),
}
async def get_deployment(self, tool_id: str, sandbox_config_id: str | None = None, actor=None) -> DeploymentInfo | None:
"""Get deployment info from tool metadata."""
cache_key = self._make_cache_key(tool_id, sandbox_config_id)
if cache_key in self._cache:
info, timestamp = self._cache[cache_key]
if self._is_cache_valid(timestamp):
return info
tool = self.tool_manager.get_tool_by_id(tool_id, actor=actor)
if not tool or not tool.metadata_:
return None
modal_deployments = tool.metadata_.get(MODAL_DEPLOYMENTS_KEY, {})
config_key = self._get_config_key(sandbox_config_id)
if config_key not in modal_deployments:
return None
deployment_data = modal_deployments[config_key]
info = DeploymentInfo(
app_name=deployment_data["app_name"],
version_hash=deployment_data["version_hash"],
deployed_at=datetime.fromisoformat(deployment_data["deployed_at"]),
dependencies=set(deployment_data.get("dependencies", [])),
app_reference=None,
)
self._cache[cache_key] = (info, time.time())
return info
async def register_deployment(
self,
tool_id: str,
app_name: str,
version_hash: str,
app: modal.App,
dependencies: set[str] | None = None,
sandbox_config_id: str | None = None,
actor=None,
) -> DeploymentInfo:
"""Register a new deployment in tool metadata."""
cache_key = self._make_cache_key(tool_id, sandbox_config_id)
config_key = self._get_config_key(sandbox_config_id)
async with self.get_deployment_lock(cache_key):
tool = self.tool_manager.get_tool_by_id(tool_id, actor=actor)
if not tool:
raise ValueError(f"Tool {tool_id} not found")
modal_deployments = self._get_deployment_metadata(tool)
info = DeploymentInfo(
app_name=app_name,
version_hash=version_hash,
deployed_at=datetime.now(),
dependencies=dependencies or set(),
app_reference=app,
)
modal_deployments[config_key] = self._create_deployment_data(app_name, version_hash, info.dependencies)
# Use ToolUpdate to update metadata
tool_update = ToolUpdate(metadata_=tool.metadata_)
await self.tool_manager.update_tool_by_id_async(
tool_id=tool_id,
tool_update=tool_update,
actor=actor,
)
self._cache[cache_key] = (info, time.time())
self._deployments[cache_key] = info # Track for stats
return info
async def needs_redeployment(self, tool_id: str, current_version: str, sandbox_config_id: str | None = None, actor=None) -> bool:
"""Check if an app needs to be redeployed."""
deployment = await self.get_deployment(tool_id, sandbox_config_id, actor=actor)
if not deployment:
return True
return deployment.version_hash != current_version
def get_deployment_lock(self, cache_key: str) -> asyncio.Lock:
"""Get or create a deployment lock for a tool+config combination."""
if cache_key not in self._deployment_locks:
self._deployment_locks[cache_key] = asyncio.Lock()
return self._deployment_locks[cache_key]
def mark_deployment_in_progress(self, cache_key: str, version_hash: str) -> str:
"""Mark that a deployment is in progress for a specific version.
Returns a unique deployment ID that should be used to complete/fail the deployment.
"""
deployment_key = f"{cache_key}:{version_hash}"
if deployment_key not in self._deployments_in_progress:
self._deployments_in_progress[deployment_key] = asyncio.Event()
return deployment_key
def is_deployment_in_progress(self, cache_key: str, version_hash: str) -> bool:
"""Check if a deployment is currently in progress."""
deployment_key = f"{cache_key}:{version_hash}"
return deployment_key in self._deployments_in_progress
async def wait_for_deployment(self, cache_key: str, version_hash: str, timeout: float = 120) -> bool:
"""Wait for an in-progress deployment to complete.
Returns True if deployment completed within timeout, False otherwise.
"""
deployment_key = f"{cache_key}:{version_hash}"
if deployment_key not in self._deployments_in_progress:
return True # No deployment in progress
event = self._deployments_in_progress[deployment_key]
try:
await asyncio.wait_for(event.wait(), timeout=timeout)
return True
except asyncio.TimeoutError:
return False
def complete_deployment(self, deployment_key: str):
"""Mark a deployment as complete and wake up any waiters."""
if deployment_key in self._deployments_in_progress:
self._deployments_in_progress[deployment_key].set()
# Clean up after a short delay to allow waiters to wake up
asyncio.create_task(self._cleanup_deployment_marker(deployment_key))
async def _cleanup_deployment_marker(self, deployment_key: str):
"""Clean up deployment marker after a delay."""
await asyncio.sleep(5) # Give waiters time to wake up
if deployment_key in self._deployments_in_progress:
del self._deployments_in_progress[deployment_key]
async def force_redeploy(self, tool_id: str, sandbox_config_id: str | None = None, actor=None):
"""Force a redeployment by removing deployment info from tool metadata."""
cache_key = self._make_cache_key(tool_id, sandbox_config_id)
config_key = self._get_config_key(sandbox_config_id)
async with self.get_deployment_lock(cache_key):
tool = self.tool_manager.get_tool_by_id(tool_id, actor=actor)
if not tool or not tool.metadata_:
return
modal_deployments = tool.metadata_.get(MODAL_DEPLOYMENTS_KEY, {})
if config_key in modal_deployments:
del modal_deployments[config_key]
# Use ToolUpdate to update metadata
tool_update = ToolUpdate(metadata_=tool.metadata_)
await self.tool_manager.update_tool_by_id_async(
tool_id=tool_id,
tool_update=tool_update,
actor=actor,
)
if cache_key in self._cache:
del self._cache[cache_key]
def clear_deployments(self):
"""Clear all deployment tracking (for testing purposes)."""
self._deployments.clear()
self._cache.clear()
self._deployments_in_progress.clear()
async def get_deployment_stats(self) -> dict:
    """Summarize tracked deployments: totals plus per-deployment details."""
    total = len(self._deployments)
    # Only truthy entries count as active; None/empty slots are stale.
    details = [
        {
            "app_name": dep.app_name,
            "version": dep.version_hash,
            "usage_count": 1,  # TODO: track real usage counts
            "deployed_at": dep.deployed_at.isoformat(),
        }
        for dep in self._deployments.values()
        if dep
    ]
    active = len(details)
    return {
        "total_deployments": total,
        "active_deployments": active,
        "stale_deployments": total - active,
        "deployments": details,
    }
# Process-wide singleton instance, created lazily by get_version_manager().
_version_manager = None


def get_version_manager() -> ModalVersionManager:
    """Get the global Modal version manager instance.

    NOTE(review): the lazy initialization below is unsynchronized -- confirm
    first use only happens from a single thread / the event loop thread.
    """
    global _version_manager
    if _version_manager is None:
        _version_manager = ModalVersionManager()
    return _version_manager

View File

@@ -0,0 +1,193 @@
"""Safe pickle serialization wrapper for Modal sandbox.
This module provides defensive serialization utilities to prevent segmentation
faults and other crashes when passing complex objects to Modal containers.
"""
import pickle
import sys
from typing import Any, Optional, Tuple
from letta.log import get_logger
logger = get_logger(__name__)
# Serialization limits enforced by safe_pickle_dumps / safe_pickle_loads.
MAX_PICKLE_SIZE = 10 * 1024 * 1024  # 10MB limit on pickled payloads
MAX_RECURSION_DEPTH = 50  # Prevent deep object graphs
PICKLE_PROTOCOL = 4  # Use protocol 4 for better compatibility
class SafePickleError(Exception):
    """Raised when safe pickling fails or a serialization limit is exceeded."""
class RecursionLimiter:
    """Context manager that temporarily lowers the interpreter recursion limit.

    On entry the limit becomes ``min(max_depth, current limit)``; on exit the
    previous limit is restored.

    NOTE(review): ``max_depth`` is an absolute interpreter limit, counting the
    frames already on the stack, not a budget relative to the caller --
    confirm callers account for that when passing small values.
    """

    def __init__(self, max_depth: int):
        self.max_depth = max_depth
        self.original_limit = None

    def __enter__(self):
        self.original_limit = sys.getrecursionlimit()
        sys.setrecursionlimit(min(self.max_depth, self.original_limit))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore only if __enter__ actually recorded a limit.
        if self.original_limit is None:
            return
        sys.setrecursionlimit(self.original_limit)
def safe_pickle_dumps(obj: Any, max_size: int = MAX_PICKLE_SIZE) -> bytes:
    """Safely pickle an object with size and recursion limits.

    The object graph is depth-checked *before* pickling, so overly deep
    structures are rejected cheaply instead of driving pickle into deep
    C-level recursion (the crash class this module exists to prevent).
    Previously the depth check ran after a full ``pickle.dumps``, which
    made it ineffective as a safeguard.

    Args:
        obj: The object to pickle
        max_size: Maximum allowed pickle size in bytes

    Returns:
        bytes: The pickled object

    Raises:
        SafePickleError: If pickling fails or exceeds limits
    """
    try:
        # Reject deep object graphs up front, before handing them to pickle.
        def check_depth(node, depth=0):
            if depth > MAX_RECURSION_DEPTH:
                raise SafePickleError(f"Object graph too deep (depth > {MAX_RECURSION_DEPTH})")
            if isinstance(node, (list, tuple)):
                for item in node:
                    check_depth(item, depth + 1)
            elif isinstance(node, dict):
                for value in node.values():
                    check_depth(value, depth + 1)
            elif hasattr(node, "__dict__"):
                check_depth(node.__dict__, depth + 1)

        check_depth(obj)

        data = pickle.dumps(obj, protocol=PICKLE_PROTOCOL)
        if len(data) > max_size:
            raise SafePickleError(f"Pickle size {len(data)} exceeds limit {max_size}")

        logger.debug(f"Successfully pickled object of size {len(data)} bytes")
        return data
    except SafePickleError:
        raise
    except RecursionError as e:
        # Defensive: containers the depth check does not traverse (e.g. sets)
        # can still drive pickle into deep recursion.
        raise SafePickleError(f"Object graph too deep: {e}")
    except Exception as e:
        raise SafePickleError(f"Failed to pickle object: {e}")
def safe_pickle_loads(data: bytes) -> Any:
    """Safely unpickle data with error handling.

    Args:
        data: The pickled data

    Returns:
        Any: The unpickled object

    Raises:
        SafePickleError: If the data is empty, oversized, or fails to unpickle
    """
    if not data:
        raise SafePickleError("Cannot unpickle empty data")
    if len(data) > MAX_PICKLE_SIZE:
        raise SafePickleError(f"Pickle data size {len(data)} exceeds limit {MAX_PICKLE_SIZE}")
    try:
        restored = pickle.loads(data)
    except Exception as e:
        raise SafePickleError(f"Failed to unpickle data: {e}")
    logger.debug(f"Successfully unpickled object from {len(data)} bytes")
    return restored
def try_pickle_with_fallback(obj: Any, fallback_value: Any = None, max_size: int = MAX_PICKLE_SIZE) -> Tuple[Optional[bytes], bool]:
    """Try to pickle an object, substituting a fallback value on failure.

    Args:
        obj: The object to pickle
        fallback_value: Value to use if pickling fails
        max_size: Maximum allowed pickle size

    Returns:
        Tuple of (pickled_data or None, success_flag); the flag is True only
        when the primary object itself was pickled.
    """
    try:
        return safe_pickle_dumps(obj, max_size), True
    except SafePickleError as primary_err:
        logger.warning(f"Failed to pickle object, using fallback: {primary_err}")
    if fallback_value is not None:
        try:
            return safe_pickle_dumps(fallback_value, max_size), False
        except SafePickleError:
            pass
    return None, False
def validate_pickleable(obj: Any) -> bool:
    """Check if an object can be safely pickled.

    Args:
        obj: The object to validate

    Returns:
        bool: True if the object pickles within the standard size/depth limits
    """
    try:
        # Attempt a full guarded pickle; any limit violation raises.
        safe_pickle_dumps(obj, max_size=MAX_PICKLE_SIZE)
    except SafePickleError:
        return False
    return True
def sanitize_for_pickle(obj: Any) -> Any:
    """Sanitize an object for safe pickling.

    For objects exposing ``__dict__``, build a dict of their public attributes,
    replacing values that cannot be pickled with string placeholders.  Other
    objects are returned unchanged and left for pickle itself to handle.

    Bug fixed: the previous version special-cased ``hasattr(value, "__module__")``,
    which is true for *every* value (the attribute is inherited from the type),
    so all attributes -- including plain strings and ints -- were replaced with
    "<... object>" placeholders.  Pickleability is now tested directly instead.

    Args:
        obj: The object to sanitize

    Returns:
        Any: A sanitized version of the object
    """
    if not hasattr(obj, "__dict__"):
        # For other types, return as-is and let pickle handle it
        return obj

    proto = PICKLE_PROTOCOL  # resolve once; keeps the try block minimal below
    sanitized = {}
    for key, value in obj.__dict__.items():
        if key.startswith("_"):
            continue  # Skip private attributes
        if callable(value):
            # Not all callables have __name__ (e.g. partials); fall back to type name.
            name = getattr(value, "__name__", type(value).__name__)
            sanitized[key] = f"<function {name}>"
            continue
        try:
            # Test if the value is pickleable
            pickle.dumps(value, protocol=proto)
        except Exception:
            sanitized[key] = str(value)
        else:
            sanitized[key] = value
    return sanitized

260
sandbox/modal_executor.py Normal file
View File

@@ -0,0 +1,260 @@
"""Modal function executor for tool sandbox v2.
This module contains the executor function that runs inside Modal containers
to execute tool functions with dynamically passed arguments.
"""
import faulthandler
import signal
from typing import Any, Dict
import modal
# List of safe modules that can be imported in schema code
# (allow-list consulted when executing args_schema_code inside the Modal
# container; any import outside this set is rejected with ValueError).
SAFE_IMPORT_MODULES = {
    "typing",
    "datetime",
    "uuid",
    "enum",
    "decimal",
    "collections",
    "abc",
    "dataclasses",
    "pydantic",
    "typing_extensions",
}
class ModalFunctionExecutor:
"""Executes tool functions in Modal with dynamic argument passing."""
@staticmethod
def execute_tool_dynamic(
tool_source: str,
tool_name: str,
args_pickled: bytes,
agent_state_pickled: bytes | None,
inject_agent_state: bool,
is_async: bool,
args_schema_code: str | None,
) -> dict[str, Any]:
"""Execute a tool function with dynamically passed arguments.
This function runs inside the Modal container and receives all parameters
at runtime rather than having them embedded in a script.
"""
import asyncio
import pickle
import sys
import traceback
from io import StringIO
# Enable fault handler for better debugging of segfaults
faulthandler.enable()
stdout_capture = StringIO()
stderr_capture = StringIO()
old_stdout = sys.stdout
old_stderr = sys.stderr
try:
sys.stdout = stdout_capture
sys.stderr = stderr_capture
# Safely unpickle arguments with size validation
if not args_pickled:
raise ValueError("No arguments provided")
if len(args_pickled) > 10 * 1024 * 1024: # 10MB limit
raise ValueError(f"Pickled args too large: {len(args_pickled)} bytes")
try:
args = pickle.loads(args_pickled)
except Exception as e:
raise ValueError(f"Failed to unpickle arguments: {e}")
agent_state = None
if agent_state_pickled:
if len(agent_state_pickled) > 10 * 1024 * 1024: # 10MB limit
raise ValueError(f"Pickled agent state too large: {len(agent_state_pickled)} bytes")
try:
agent_state = pickle.loads(agent_state_pickled)
except Exception as e:
# Log but don't fail - agent state is optional
print(f"Warning: Failed to unpickle agent state: {e}", file=sys.stderr)
agent_state = None
exec_globals = {
"__name__": "__main__",
"__builtins__": __builtins__,
}
if args_schema_code:
import ast
try:
tree = ast.parse(args_schema_code)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
module_name = alias.name.split(".")[0]
if module_name not in SAFE_IMPORT_MODULES:
raise ValueError(f"Import of '{module_name}' not allowed in schema code")
elif isinstance(node, ast.ImportFrom):
if node.module:
module_name = node.module.split(".")[0]
if module_name not in SAFE_IMPORT_MODULES:
raise ValueError(f"Import from '{module_name}' not allowed in schema code")
exec(compile(tree, "<schema>", "exec"), exec_globals)
except (SyntaxError, ValueError) as e:
raise ValueError(f"Invalid or unsafe schema code: {e}")
exec(tool_source, exec_globals)
if tool_name not in exec_globals:
raise ValueError(f"Function '{tool_name}' not found in tool source code")
func = exec_globals[tool_name]
kwargs = dict(args)
if inject_agent_state:
kwargs["agent_state"] = agent_state
if is_async:
result = asyncio.run(func(**kwargs))
else:
result = func(**kwargs)
try:
from pydantic import BaseModel, ConfigDict
class _TempResultWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
result: Any
wrapped = _TempResultWrapper(result=result)
serialized_result = wrapped.model_dump()["result"]
except (ImportError, Exception):
serialized_result = str(result)
return {
"result": serialized_result,
"agent_state": agent_state,
"stdout": stdout_capture.getvalue(),
"stderr": stderr_capture.getvalue(),
"error": None,
}
except Exception as e:
return {
"result": None,
"agent_state": None,
"stdout": stdout_capture.getvalue(),
"stderr": stderr_capture.getvalue(),
"error": {
"name": type(e).__name__,
"value": str(e),
"traceback": traceback.format_exc(),
},
}
finally:
sys.stdout = old_stdout
sys.stderr = old_stderr
def setup_signal_handlers():
    """Install SIGSEGV/SIGABRT handlers that dump a stack trace before exiting.

    Provides crash diagnostics from inside the Modal container; the exit codes
    follow the conventional 128+signal values (139 for SIGSEGV, 134 for SIGABRT).
    """

    def _report_and_exit(label, exit_code, signum, frame):
        import sys
        import traceback

        print(f"{label} detected! Signal: {signum}", file=sys.stderr)
        print("Stack trace:", file=sys.stderr)
        traceback.print_stack(frame, file=sys.stderr)
        sys.exit(exit_code)

    # Register signal handlers
    signal.signal(signal.SIGSEGV, lambda signum, frame: _report_and_exit("SEGFAULT", 139, signum, frame))
    signal.signal(signal.SIGABRT, lambda signum, frame: _report_and_exit("ABORT", 134, signum, frame))
@modal.method()
def execute_tool_wrapper(
    self,
    tool_source: str,
    tool_name: str,
    args_pickled: bytes,
    agent_state_pickled: bytes | None,
    inject_agent_state: bool,
    is_async: bool,
    args_schema_code: str | None,
    environment_vars: Dict[str, str],
) -> Dict[str, Any]:
    """Wrapper function that runs in Modal container with enhanced error handling.

    Hardens the container before delegating to
    ModalFunctionExecutor.execute_tool_dynamic: installs crash signal handlers,
    enables faulthandler, applies memory/stack rlimits, and exports the
    caller-provided environment variables.
    """
    import os
    import resource
    import sys

    # Setup signal handlers for better crash debugging
    setup_signal_handlers()

    # Enable fault handler with file output
    try:
        faulthandler.enable(file=sys.stderr, all_threads=True)
    except Exception:
        # Bug fixed: was a bare `except:`, which also swallows
        # SystemExit/KeyboardInterrupt. Faulthandler might not be available.
        pass

    # Set resource limits to prevent runaway processes
    try:
        # Limit memory usage to 1GB
        resource.setrlimit(resource.RLIMIT_AS, (1024 * 1024 * 1024, 1024 * 1024 * 1024))
        # Limit stack size to 8MB (default is often unlimited)
        resource.setrlimit(resource.RLIMIT_STACK, (8 * 1024 * 1024, 8 * 1024 * 1024))
    except Exception:
        # Narrowed from a bare `except:`. Resource limits might not be available.
        pass

    # Set environment variables
    for key, value in environment_vars.items():
        os.environ[key] = str(value)

    # Add debugging environment variables
    os.environ["PYTHONFAULTHANDLER"] = "1"
    os.environ["PYTHONDEVMODE"] = "1"

    try:
        # Execute the tool
        return ModalFunctionExecutor.execute_tool_dynamic(
            tool_source=tool_source,
            tool_name=tool_name,
            args_pickled=args_pickled,
            agent_state_pickled=agent_state_pickled,
            inject_agent_state=inject_agent_state,
            is_async=is_async,
            args_schema_code=args_schema_code,
        )
    except Exception as e:
        import traceback

        # Enhanced error reporting
        return {
            "result": None,
            "agent_state": None,
            "stdout": "",
            "stderr": f"Container execution failed: {traceback.format_exc()}",
            "error": {
                "name": type(e).__name__,
                "value": str(e),
                "traceback": traceback.format_exc(),
            },
        }

View File

@@ -0,0 +1,828 @@
"""
Integration tests for Modal Sandbox V2.
These tests cover:
- Basic tool execution with Modal
- Error handling and edge cases
- Async tool execution
- Version tracking and redeployment
- Persistence of deployment metadata
- Concurrent execution handling
- Multiple sandbox configurations
- Service restart scenarios
"""
import asyncio
import os
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from letta.schemas.enums import ToolSourceType
from letta.schemas.organization import Organization
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.sandbox_config import ModalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxType
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_sandbox.modal_sandbox_v2 import AsyncToolSandboxModalV2
from letta.services.tool_sandbox.modal_version_manager import ModalVersionManager, get_version_manager
from letta.services.user_manager import UserManager
@pytest.fixture
def event_loop():
    """Provide a fresh event loop for each test, cancelling leftover tasks
    before the loop is closed.

    NOTE(review): the original docstring said "for the test session", but the
    fixture uses default (function) scope -- confirm the intended scope.
    """
    loop = asyncio.new_event_loop()
    yield loop
    # Cleanup tasks before closing loop
    pending = asyncio.all_tasks(loop)
    for task in pending:
        task.cancel()
    # Let the cancellations propagate; swallow their CancelledError results.
    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
    loop.close()
# ============================================================================
# SHARED FIXTURES
# ============================================================================
@pytest.fixture
def test_organization():
    """Create a test organization in the database."""
    unique_name = f"test-org-{uuid.uuid4().hex[:8]}"
    yield OrganizationManager().create_organization(Organization(name=unique_name))
    # Cleanup would go here if needed
@pytest.fixture
def test_user(test_organization):
    """Create a test user in the database, owned by the test organization."""
    unique_name = f"test-user-{uuid.uuid4().hex[:8]}"
    yield UserManager().create_user(User(name=unique_name, organization_id=test_organization.id))
    # Cleanup would go here if needed
@pytest.fixture
def mock_user():
    """Create a mock user for tests that don't need database persistence."""
    user = MagicMock()
    # Random suffixes keep ids unique across tests sharing the fixture.
    user.id = f"user-{uuid.uuid4().hex[:8]}"
    user.organization_id = f"test-org-{uuid.uuid4().hex[:8]}"
    return user
@pytest.fixture
def basic_tool(test_user):
    """Create a basic tool for testing.

    The tool's source string is executed inside the Modal sandbox; its
    docstring and json_schema drive schema generation.
    """
    from letta.services.tool_manager import ToolManager

    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="calculate",
        source_type=ToolSourceType.python,
        source_code="""
def calculate(operation: str, a: float, b: float) -> float:
    '''Perform a calculation on two numbers.

    Args:
        operation: The operation to perform (add, subtract, multiply, divide)
        a: The first number
        b: The second number

    Returns:
        float: The result of the calculation
    '''
    if operation == "add":
        return a + b
    elif operation == "subtract":
        return a - b
    elif operation == "multiply":
        return a * b
    elif operation == "divide":
        if b == 0:
            raise ValueError("Cannot divide by zero")
        return a / b
    else:
        raise ValueError(f"Unknown operation: {operation}")
""",
        json_schema={
            "parameters": {
                "properties": {
                    "operation": {"type": "string", "description": "The operation to perform"},
                    "a": {"type": "number", "description": "The first number"},
                    "b": {"type": "number", "description": "The second number"},
                }
            }
        },
    )
    # Create the tool in the database
    tool_manager = ToolManager()
    created_tool = tool_manager.create_or_update_tool(tool, actor=test_user)
    yield created_tool
    # Cleanup would go here if needed
@pytest.fixture
def async_tool(test_user):
    """Create an async tool for testing.

    Bug fixed: the embedded source annotated the return as ``Dict`` without
    importing typing, so exec'ing the tool inside the sandbox raised
    NameError at function-definition time; the builtin ``dict`` is used now.
    """
    from letta.services.tool_manager import ToolManager

    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="fetch_data",
        source_type=ToolSourceType.python,
        source_code="""
import asyncio


async def fetch_data(url: str, delay: float = 0.1) -> dict:
    '''Simulate fetching data from a URL.

    Args:
        url: The URL to fetch data from
        delay: The delay in seconds before returning

    Returns:
        dict: A dictionary containing the fetched data
    '''
    await asyncio.sleep(delay)
    return {
        "url": url,
        "status": "success",
        "data": f"Data from {url}",
        "timestamp": "2024-01-01T00:00:00Z"
    }
""",
        json_schema={
            "parameters": {
                "properties": {
                    "url": {"type": "string", "description": "The URL to fetch data from"},
                    "delay": {"type": "number", "default": 0.1, "description": "The delay in seconds"},
                }
            }
        },
    )
    # Create the tool in the database
    tool_manager = ToolManager()
    created_tool = tool_manager.create_or_update_tool(tool, actor=test_user)
    yield created_tool
    # Cleanup would go here if needed
@pytest.fixture
def tool_with_dependencies(test_user):
    """Create a tool that requires external dependencies.

    Bug fixed: the embedded source annotated the return as ``Dict`` without
    importing typing, so exec'ing the tool inside the sandbox raised
    NameError at function-definition time; the builtin ``dict`` is used now.
    """
    from letta.services.tool_manager import ToolManager

    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="process_json",
        source_type=ToolSourceType.python,
        source_code="""
import json
import hashlib


def process_json(data: str) -> dict:
    '''Process JSON data and return metadata.

    Args:
        data: The JSON string to process

    Returns:
        dict: Metadata about the JSON data
    '''
    try:
        parsed = json.loads(data)
        data_hash = hashlib.md5(data.encode()).hexdigest()
        return {
            "valid": True,
            "keys": list(parsed.keys()) if isinstance(parsed, dict) else None,
            "type": type(parsed).__name__,
            "hash": data_hash,
            "size": len(data),
        }
    except json.JSONDecodeError as e:
        return {
            "valid": False,
            "error": str(e),
            "size": len(data),
        }
""",
        json_schema={
            "parameters": {
                "properties": {
                    "data": {"type": "string", "description": "The JSON string to process"},
                }
            }
        },
        pip_requirements=[PipRequirement(name="hashlib")],  # Actually built-in, but for testing
    )
    # Create the tool in the database
    tool_manager = ToolManager()
    created_tool = tool_manager.create_or_update_tool(tool, actor=test_user)
    yield created_tool
    # Cleanup would go here if needed
@pytest.fixture
def sandbox_config(test_user):
    """Create a test sandbox configuration in the database."""
    modal_config = ModalSandboxConfig(timeout=60, pip_requirements=["pandas==2.0.0"])
    create_request = SandboxConfigCreate(config=modal_config.model_dump())
    manager = SandboxConfigManager()
    yield manager.create_or_update_sandbox_config(sandbox_config_create=create_request, actor=test_user)
    # Cleanup would go here if needed
@pytest.fixture
def mock_sandbox_config():
    """Create a mock sandbox configuration for tests that don't need database persistence."""
    config_payload = ModalSandboxConfig(timeout=60, pip_requirements=["pandas==2.0.0"]).model_dump()
    return SandboxConfig(
        id=f"sandbox-{uuid.uuid4().hex[:8]}",
        type=SandboxType.MODAL,
        config=config_payload,
    )
# ============================================================================
# BASIC EXECUTION TESTS (Requires Modal credentials)
# ============================================================================
# NOTE(review): the leading `True or` below makes the whole class
# unconditionally skipped, even when Modal credentials are configured --
# confirm whether this is a temporary disable that should be removed.
@pytest.mark.skipif(
    True or not os.getenv("MODAL_TOKEN_ID") or not os.getenv("MODAL_TOKEN_SECRET"), reason="Modal credentials not configured"
)
class TestModalV2BasicExecution:
    """Basic execution tests with Modal (real deployments; needs credentials)."""

    @pytest.mark.asyncio
    async def test_basic_execution(self, basic_tool, test_user):
        """Test basic tool execution with different operations."""
        sandbox = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "add", "a": 5, "b": 3},
            user=test_user,
            tool_object=basic_tool,
        )
        result = await sandbox.run()
        assert result.status == "success"
        assert result.func_return == 8.0
        # Test division
        sandbox2 = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "divide", "a": 10, "b": 2},
            user=test_user,
            tool_object=basic_tool,
        )
        result2 = await sandbox2.run()
        assert result2.status == "success"
        assert result2.func_return == 5.0

    @pytest.mark.asyncio
    async def test_error_handling(self, basic_tool, test_user):
        """Test error handling in tool execution."""
        # Test division by zero
        sandbox = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "divide", "a": 10, "b": 0},
            user=test_user,
            tool_object=basic_tool,
        )
        result = await sandbox.run()
        assert result.status == "error"
        assert "Cannot divide by zero" in str(result.func_return)
        # Test unknown operation
        sandbox2 = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "unknown", "a": 1, "b": 2},
            user=test_user,
            tool_object=basic_tool,
        )
        result2 = await sandbox2.run()
        assert result2.status == "error"
        assert "Unknown operation" in str(result2.func_return)

    @pytest.mark.asyncio
    async def test_async_tool_execution(self, async_tool, test_user):
        """Test execution of async tools."""
        sandbox = AsyncToolSandboxModalV2(
            tool_name="fetch_data",
            args={"url": "https://example.com", "delay": 0.01},
            user=test_user,
            tool_object=async_tool,
        )
        result = await sandbox.run()
        assert result.status == "success"
        # Parse the result (it should be a dict)
        data = result.func_return
        assert isinstance(data, dict)
        assert data["url"] == "https://example.com"
        assert data["status"] == "success"
        assert "Data from https://example.com" in data["data"]

    @pytest.mark.asyncio
    async def test_concurrent_executions(self, basic_tool, test_user):
        """Test that concurrent executions work correctly."""
        # Create multiple sandboxes with different arguments
        sandboxes = [
            AsyncToolSandboxModalV2(
                tool_name="calculate",
                args={"operation": "add", "a": i, "b": i + 1},
                user=test_user,
                tool_object=basic_tool,
            )
            for i in range(5)
        ]
        # Execute all concurrently
        results = await asyncio.gather(*[s.run() for s in sandboxes])
        # Verify all succeeded with correct results
        for i, result in enumerate(results):
            assert result.status == "success"
            expected = i + (i + 1)  # a + b
            assert result.func_return == expected
# ============================================================================
# PERSISTENCE AND VERSION TRACKING TESTS
# ============================================================================
@pytest.mark.asyncio
class TestModalV2Persistence:
    """Tests for deployment persistence and version tracking (fully mocked)."""

    async def test_deployment_persists_in_tool_metadata(self, mock_user, sandbox_config):
        """Test that deployment info is correctly stored in tool metadata."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="calculate",
            source_code="def calculate(x: float) -> float:\n    '''Double a number.\n    \n    Args:\n        x: The number to double\n    \n    Returns:\n        The doubled value\n    '''\n    return x * 2",
            json_schema={"parameters": {"properties": {"x": {"type": "number"}}}},
            metadata_={},
        )
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool
            mock_tool_manager.update_tool_by_id_async = AsyncMock(return_value=tool)

            version_manager = ModalVersionManager()

            # Register a deployment
            app_name = f"{mock_user.organization_id}-{tool.name}-v2"
            version_hash = "abc123def456"
            mock_app = MagicMock()

            await version_manager.register_deployment(
                tool_id=tool.id,
                app_name=app_name,
                version_hash=version_hash,
                app=mock_app,
                dependencies={"pandas", "numpy"},
                sandbox_config_id=sandbox_config.id,
                actor=mock_user,
            )

            # Verify update was called with correct metadata
            mock_tool_manager.update_tool_by_id_async.assert_called_once()
            call_args = mock_tool_manager.update_tool_by_id_async.call_args
            # The deployment record is persisted inside the ToolUpdate payload.
            metadata = call_args[1]["tool_update"].metadata_
            assert "modal_deployments" in metadata
            assert sandbox_config.id in metadata["modal_deployments"]
            deployment_data = metadata["modal_deployments"][sandbox_config.id]
            assert deployment_data["app_name"] == app_name
            assert deployment_data["version_hash"] == version_hash
            assert set(deployment_data["dependencies"]) == {"pandas", "numpy"}

    async def test_version_tracking_and_redeployment(self, mock_user, basic_tool, sandbox_config):
        """Test version tracking and redeployment on code changes."""
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = basic_tool

            # Track metadata updates
            metadata_store = {}

            async def update_tool(*args, **kwargs):
                # NOTE(review): register_deployment passes `tool_update=ToolUpdate(...)`,
                # not `metadata_=...`, so this .get() always returns {} and the store
                # stays empty -- the assertions below appear to pass via the version
                # manager's in-memory cache. Confirm the intended tracking behavior.
                metadata_store.update(kwargs.get("metadata_", {}))
                basic_tool.metadata_ = metadata_store
                return basic_tool

            mock_tool_manager.update_tool_by_id_async = AsyncMock(side_effect=update_tool)

            version_manager = ModalVersionManager()
            app_name = f"{mock_user.organization_id}-{basic_tool.name}-v2"

            # First deployment
            version1 = "version1hash"
            await version_manager.register_deployment(
                tool_id=basic_tool.id,
                app_name=app_name,
                version_hash=version1,
                app=MagicMock(),
                sandbox_config_id=sandbox_config.id,
                actor=mock_user,
            )

            # Should not need redeployment with same version
            assert not await version_manager.needs_redeployment(basic_tool.id, version1, sandbox_config.id, actor=mock_user)

            # Should need redeployment with different version
            version2 = "version2hash"
            assert await version_manager.needs_redeployment(basic_tool.id, version2, sandbox_config.id, actor=mock_user)

    async def test_deployment_survives_service_restart(self, mock_user, sandbox_config):
        """Test that deployment info survives a service restart."""
        tool_id = f"tool-{uuid.uuid4().hex[:8]}"
        app_name = f"{mock_user.organization_id}-calculate-v2"
        version_hash = "restart-test-v1"

        # Simulate existing deployment in metadata
        existing_metadata = {
            "modal_deployments": {
                sandbox_config.id: {
                    "app_name": app_name,
                    "version_hash": version_hash,
                    "deployed_at": datetime.now().isoformat(),
                    "dependencies": ["pandas"],
                }
            }
        }
        tool = Tool(
            id=tool_id,
            name="calculate",
            source_code="def calculate(x: float) -> float:\n    '''Identity function.\n    \n    Args:\n        x: The input value\n    \n    Returns:\n        The same value\n    '''\n    return x",
            json_schema={"parameters": {"properties": {}}},
            metadata_=existing_metadata,
        )
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool

            # Create new version manager (simulating service restart)
            version_manager = ModalVersionManager()

            # Should be able to retrieve existing deployment
            deployment = await version_manager.get_deployment(tool_id, sandbox_config.id, actor=mock_user)
            assert deployment is not None
            assert deployment.app_name == app_name
            assert deployment.version_hash == version_hash
            # Dependencies persisted as a list are rehydrated into a set.
            assert deployment.dependencies == {"pandas"}

            # Should not need redeployment with same version
            assert not await version_manager.needs_redeployment(tool_id, version_hash, sandbox_config.id, actor=mock_user)

    async def test_different_sandbox_configs_same_tool(self, mock_user):
        """Test that different sandbox configs can have different deployments for the same tool."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="multi_config",
            source_code="def test(x: int) -> int:\n    '''Test function.\n    \n    Args:\n        x: The input value\n    \n    Returns:\n        The same value\n    '''\n    return x",
            json_schema={"parameters": {"properties": {}}},
            metadata_={},
        )

        # Create two different sandbox configs
        config1 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(timeout=30, pip_requirements=["pandas"]).model_dump(),
        )
        config2 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(timeout=60, pip_requirements=["numpy"]).model_dump(),
        )

        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool

            # Track all metadata updates
            all_metadata = {"modal_deployments": {}}

            async def update_tool(*args, **kwargs):
                # NOTE(review): same kwargs.get("metadata_") mismatch as above --
                # the real call passes `tool_update=`, so new_meta is always {}.
                new_meta = kwargs.get("metadata_", {})
                if "modal_deployments" in new_meta:
                    all_metadata["modal_deployments"].update(new_meta["modal_deployments"])
                tool.metadata_ = all_metadata
                return tool

            mock_tool_manager.update_tool_by_id_async = AsyncMock(side_effect=update_tool)

            version_manager = ModalVersionManager()
            app_name = f"{mock_user.organization_id}-{tool.name}-v2"

            # Deploy with config1
            await version_manager.register_deployment(
                tool_id=tool.id,
                app_name=app_name,
                version_hash="config1-hash",
                app=MagicMock(),
                sandbox_config_id=config1.id,
                actor=mock_user,
            )

            # Deploy with config2
            await version_manager.register_deployment(
                tool_id=tool.id,
                app_name=app_name,
                version_hash="config2-hash",
                app=MagicMock(),
                sandbox_config_id=config2.id,
                actor=mock_user,
            )

            # Both deployments should exist
            deployment1 = await version_manager.get_deployment(tool.id, config1.id, actor=mock_user)
            deployment2 = await version_manager.get_deployment(tool.id, config2.id, actor=mock_user)
            assert deployment1 is not None
            assert deployment2 is not None
            assert deployment1.version_hash == "config1-hash"
            assert deployment2.version_hash == "config2-hash"

    async def test_sandbox_config_changes_trigger_redeployment(self, basic_tool, mock_user):
        """Test that sandbox config changes trigger redeployment."""
        # Skip the actual Modal deployment part in this test
        # Just test the version hash calculation changes
        config1 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(timeout=30).model_dump(),
        )
        config2 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(
                timeout=60,
                pip_requirements=["requests"],
            ).model_dump(),
        )

        # Mock the Modal credentials to allow sandbox instantiation
        with patch("letta.services.tool_sandbox.modal_sandbox_v2.tool_settings") as mock_settings:
            mock_settings.modal_token_id = "test-token-id"
            mock_settings.modal_token_secret = "test-token-secret"

            sandbox1 = AsyncToolSandboxModalV2(
                tool_name="calculate",
                args={"operation": "add", "a": 1, "b": 1},
                user=mock_user,
                tool_object=basic_tool,
                sandbox_config=config1,
            )
            sandbox2 = AsyncToolSandboxModalV2(
                tool_name="calculate",
                args={"operation": "add", "a": 2, "b": 2},
                user=mock_user,
                tool_object=basic_tool,
                sandbox_config=config2,
            )

            # Version hashes should be different due to config changes
            version1 = sandbox1._deployment_manager.calculate_version_hash(config1)
            version2 = sandbox2._deployment_manager.calculate_version_hash(config2)
            assert version1 != version2
# ============================================================================
# MOCKED INTEGRATION TESTS (No Modal credentials required)
# ============================================================================
class TestModalV2MockedIntegration:
"""Integration tests with mocked Modal components."""
@pytest.mark.asyncio
async def test_full_integration_with_persistence(self, mock_user, sandbox_config):
    """Test the full Modal sandbox V2 integration with persistence.

    Mocks ToolManager, the modal module, and SandboxConfigManager so the whole
    deploy-and-execute path runs without Modal credentials.
    """
    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="integration_test",
        source_code="""
def calculate(operation: str, a: float, b: float) -> float:
    '''Perform a simple calculation'''
    if operation == "add":
        return a + b
    return 0
""",
        json_schema={
            "parameters": {
                "properties": {
                    "operation": {"type": "string"},
                    "a": {"type": "number"},
                    "b": {"type": "number"},
                }
            }
        },
        metadata_={},
    )
    with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
        with patch("letta.services.tool_sandbox.modal_sandbox_v2.modal") as mock_modal:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool

            # Track metadata updates
            async def update_tool(*args, **kwargs):
                # NOTE(review): the production call passes `tool_update=ToolUpdate(...)`,
                # so kwargs.get("metadata_") is always missing and this resets
                # tool.metadata_ to {} -- confirm the intended tracking behavior.
                tool.metadata_ = kwargs.get("metadata_", {})
                return tool

            mock_tool_manager.update_tool_by_id_async = update_tool

            # Mock Modal app
            mock_app = MagicMock()
            mock_app.run = MagicMock()

            # Mock the function decorator: whatever function the sandbox registers
            # gets replaced by a stub whose remote.aio returns a canned result of 8.
            def mock_function_decorator(*args, **kwargs):
                def decorator(func):
                    mock_func = MagicMock()
                    mock_func.remote = MagicMock()
                    mock_func.remote.aio = AsyncMock(
                        return_value={
                            "result": 8,
                            "agent_state": None,
                            "stdout": "",
                            "stderr": "",
                            "error": None,
                        }
                    )
                    mock_app.tool_executor = mock_func
                    return mock_func

                return decorator

            mock_app.function = mock_function_decorator
            mock_app.deploy = MagicMock()
            mock_app.deploy.aio = AsyncMock()
            mock_modal.App.return_value = mock_app

            # Mock the sandbox config manager
            with patch("letta.services.tool_sandbox.base.SandboxConfigManager") as MockSCM:
                mock_scm = MockSCM.return_value
                mock_scm.get_sandbox_env_vars_as_dict_async = AsyncMock(return_value={})

                # Create sandbox
                sandbox = AsyncToolSandboxModalV2(
                    tool_name="integration_test",
                    args={"operation": "add", "a": 5, "b": 3},
                    user=mock_user,
                    tool_object=tool,
                    sandbox_config=sandbox_config,
                )

                # Mock version manager methods through deployment manager
                version_manager = sandbox._deployment_manager.version_manager
                if version_manager:
                    with patch.object(version_manager, "get_deployment", return_value=None):
                        with patch.object(version_manager, "register_deployment", return_value=None):
                            # First execution - should deploy
                            result1 = await sandbox.run()
                            assert result1.status == "success"
                            assert result1.func_return == 8
                else:
                    # If no version manager, just run
                    result1 = await sandbox.run()
                    assert result1.status == "success"
                    assert result1.func_return == 8
@pytest.mark.asyncio
async def test_concurrent_deployment_handling(self, mock_user, sandbox_config):
"""Test that concurrent deployment requests are handled correctly."""
tool = Tool(
id=f"tool-{uuid.uuid4().hex[:8]}",
name="concurrent_test",
source_code="def test(x: int) -> int:\n '''Test function.\n \n Args:\n x: The input value\n \n Returns:\n The same value\n '''\n return x",
json_schema={"parameters": {"properties": {}}},
metadata_={},
)
with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
mock_tool_manager = MockToolManager.return_value
mock_tool_manager.get_tool_by_id.return_value = tool
# Track update calls
update_calls = []
async def track_update(*args, **kwargs):
update_calls.append((args, kwargs))
await asyncio.sleep(0.01) # Simulate slight delay
return tool
mock_tool_manager.update_tool_by_id_async = AsyncMock(side_effect=track_update)
version_manager = ModalVersionManager()
app_name = f"{mock_user.organization_id}-{tool.name}-v2"
version_hash = "concurrent123"
# Launch multiple concurrent deployments
tasks = []
for i in range(5):
task = version_manager.register_deployment(
tool_id=tool.id,
app_name=app_name,
version_hash=version_hash,
app=MagicMock(),
sandbox_config_id=sandbox_config.id,
actor=mock_user,
)
tasks.append(task)
# Wait for all to complete
await asyncio.gather(*tasks)
# All calls should complete (current implementation doesn't dedupe)
assert len(update_calls) == 5
# ============================================================================
# DEPLOYMENT STATISTICS TESTS
# ============================================================================
@pytest.mark.skipif(not os.getenv("MODAL_TOKEN_ID") or not os.getenv("MODAL_TOKEN_SECRET"), reason="Modal credentials not configured")
class TestModalV2DeploymentStats:
    """Tests for deployment statistics tracking."""

    @pytest.mark.asyncio
    async def test_deployment_stats(self, basic_tool, async_tool, test_user):
        """Deploy two tools for real, then verify the aggregate stats."""
        manager = get_version_manager()
        # Start from a clean slate so the counts below are meaningful
        # (for test isolation).
        manager.clear_deployments()
        await asyncio.sleep(0.1)  # ensure clean state

        # Deploy each tool once through a fresh sandbox instance
        for current_tool in (basic_tool, async_tool):
            box = AsyncToolSandboxModalV2(
                tool_name=current_tool.name,
                args={},
                user=test_user,
                tool_object=current_tool,
            )
            await box.run()

        # Aggregate counters should reflect at least the two deployments above
        stats = await manager.get_deployment_stats()
        assert stats["total_deployments"] >= 2
        assert stats["active_deployments"] >= 2
        assert stats["stale_deployments"] == 0

        # Each per-deployment record carries the expected fields
        for entry in stats["deployments"]:
            assert "app_name" in entry
            assert "version" in entry
            assert "usage_count" in entry
            assert entry["usage_count"] >= 1
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

# ============================================================================
# A second test module begins below (extraction artifact removed):
# unit tests for ModalFunctionExecutor, ModalVersionManager, and
# AsyncToolSandboxModalV2.
# ============================================================================
import json
import pickle
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.sandbox_config import ModalSandboxConfig, SandboxConfig, SandboxType
from letta.schemas.tool import Tool
from letta.services.tool_sandbox.modal_sandbox_v2 import AsyncToolSandboxModalV2
from letta.services.tool_sandbox.modal_version_manager import ModalVersionManager
from sandbox.modal_executor import ModalFunctionExecutor
class TestModalFunctionExecutor:
    """Test the ModalFunctionExecutor class."""

    def test_execute_tool_dynamic_success(self):
        """A simple synchronous tool executes and returns its actual value."""
        source = """
def add_numbers(a: int, b: int) -> int:
    return a + b
"""
        payload = pickle.dumps({"a": 5, "b": 3})
        outcome = ModalFunctionExecutor.execute_tool_dynamic(
            tool_source=source,
            tool_name="add_numbers",
            args_pickled=payload,
            agent_state_pickled=None,
            inject_agent_state=False,
            is_async=False,
            args_schema_code=None,
        )
        # Success: no error, real integer result, no agent state round-trip
        assert outcome["error"] is None
        assert outcome["result"] == 8  # actual integer value
        assert outcome["agent_state"] is None

    def test_execute_tool_dynamic_with_error(self):
        """An exception inside the tool is surfaced as a structured error."""
        source = """
def divide_numbers(a: int, b: int) -> float:
    return a / b
"""
        payload = pickle.dumps({"a": 5, "b": 0})
        outcome = ModalFunctionExecutor.execute_tool_dynamic(
            tool_source=source,
            tool_name="divide_numbers",
            args_pickled=payload,
            agent_state_pickled=None,
            inject_agent_state=False,
            is_async=False,
            args_schema_code=None,
        )
        # The executor reports the exception type and message, not a result
        assert outcome["error"] is not None
        assert outcome["error"]["name"] == "ZeroDivisionError"
        assert "division by zero" in outcome["error"]["value"]
        assert outcome["result"] is None

    def test_execute_async_tool(self):
        """An async tool is awaited and its result returned."""
        source = """
async def async_add(a: int, b: int) -> int:
    import asyncio
    await asyncio.sleep(0.001)
    return a + b
"""
        payload = pickle.dumps({"a": 10, "b": 20})
        outcome = ModalFunctionExecutor.execute_tool_dynamic(
            tool_source=source,
            tool_name="async_add",
            args_pickled=payload,
            agent_state_pickled=None,
            inject_agent_state=False,
            is_async=True,
            args_schema_code=None,
        )
        assert outcome["error"] is None
        assert outcome["result"] == 30

    def test_execute_with_stdout_capture(self):
        """Anything the tool prints is captured into the stdout field."""
        source = """
def print_and_return(message: str) -> str:
    print(f"Processing: {message}")
    print("Done!")
    return message.upper()
"""
        payload = pickle.dumps({"message": "hello"})
        outcome = ModalFunctionExecutor.execute_tool_dynamic(
            tool_source=source,
            tool_name="print_and_return",
            args_pickled=payload,
            agent_state_pickled=None,
            inject_agent_state=False,
            is_async=False,
            args_schema_code=None,
        )
        assert outcome["error"] is None
        assert outcome["result"] == "HELLO"
        # Both print() lines must appear in the captured output
        assert "Processing: hello" in outcome["stdout"]
        assert "Done!" in outcome["stdout"]
class TestModalVersionManager:
    """Test the Modal Version Manager.

    The manager's ``tool_manager`` is replaced with mocks so no real tool
    storage is touched. ``AsyncMock``/``MagicMock`` come from the
    module-level ``unittest.mock`` import; the redundant per-test
    ``from unittest.mock import AsyncMock`` re-imports were removed.
    ``User`` is still imported locally because it is not needed at module
    scope elsewhere in this file.
    """

    @pytest.mark.asyncio
    async def test_register_and_get_deployment(self):
        """Registering a deployment makes it retrievable with the same data."""
        from letta.schemas.user import User

        manager = ModalVersionManager()
        # Mock the tool manager
        mock_tool = MagicMock()
        mock_tool.id = "tool-abc12345"
        mock_tool.metadata_ = {}
        manager.tool_manager.get_tool_by_id = MagicMock(return_value=mock_tool)
        manager.tool_manager.update_tool_by_id_async = AsyncMock(return_value=mock_tool)
        # Create a mock actor
        mock_actor = MagicMock(spec=User)
        mock_actor.id = "user-123"
        # Register a deployment
        mock_app = MagicMock(spec=["deploy", "stop"])
        info = await manager.register_deployment(
            tool_id="tool-abc12345",
            app_name="test-app",
            version_hash="abc123",
            app=mock_app,
            dependencies={"pandas", "numpy"},
            sandbox_config_id="config-123",
            actor=mock_actor,
        )
        assert info.app_name == "test-app"
        assert info.version_hash == "abc123"
        assert info.dependencies == {"pandas", "numpy"}
        # Retrieve the deployment and check it round-trips
        retrieved = await manager.get_deployment("tool-abc12345", "config-123", actor=mock_actor)
        assert retrieved.app_name == info.app_name
        assert retrieved.version_hash == info.version_hash

    @pytest.mark.asyncio
    async def test_needs_redeployment(self):
        """Redeployment is needed for unknown or changed version hashes only."""
        from letta.schemas.user import User

        manager = ModalVersionManager()
        # Mock the tool manager
        mock_tool = MagicMock()
        mock_tool.id = "tool-def45678"
        mock_tool.metadata_ = {}
        manager.tool_manager.get_tool_by_id = MagicMock(return_value=mock_tool)
        manager.tool_manager.update_tool_by_id_async = AsyncMock(return_value=mock_tool)
        # Create a mock actor
        mock_actor = MagicMock(spec=User)
        # No deployment exists yet
        assert await manager.needs_redeployment("tool-def45678", "v1", "config-123", actor=mock_actor) is True
        # Register a deployment
        mock_app = MagicMock()
        await manager.register_deployment(
            tool_id="tool-def45678",
            app_name="test-app",
            version_hash="v1",
            app=mock_app,
            sandbox_config_id="config-123",
            actor=mock_actor,
        )
        # Update mock to return the registered deployment (the manager reads
        # deployment records back out of the tool's metadata_)
        mock_tool.metadata_ = {
            "modal_deployments": {
                "config-123": {
                    "app_name": "test-app",
                    "version_hash": "v1",
                    "deployed_at": "2024-01-01T00:00:00",
                    "dependencies": [],
                }
            }
        }
        # Same version - no redeployment needed
        assert await manager.needs_redeployment("tool-def45678", "v1", "config-123", actor=mock_actor) is False
        # Different version - redeployment needed
        assert await manager.needs_redeployment("tool-def45678", "v2", "config-123", actor=mock_actor) is True

    @pytest.mark.skip(reason="get_deployment_stats method not implemented in ModalVersionManager")
    @pytest.mark.asyncio
    async def test_deployment_stats(self):
        """Test getting deployment statistics."""
        from letta.schemas.user import User

        manager = ModalVersionManager()
        # Mock the tool manager with three distinct tools
        mock_tools = {}
        for i in range(3):
            tool_id = f"tool-{i:08x}"
            mock_tool = MagicMock()
            mock_tool.id = tool_id
            mock_tool.metadata_ = {}
            mock_tools[tool_id] = mock_tool

        def get_tool_by_id(tool_id, actor=None):
            return mock_tools.get(tool_id)

        manager.tool_manager.get_tool_by_id = MagicMock(side_effect=get_tool_by_id)
        manager.tool_manager.update_tool_by_id_async = AsyncMock()
        # Create a mock actor
        mock_actor = MagicMock(spec=User)
        # Register multiple deployments
        for i in range(3):
            tool_id = f"tool-{i:08x}"
            mock_app = MagicMock()
            await manager.register_deployment(
                tool_id=tool_id,
                app_name=f"app-{i}",
                version_hash=f"v{i}",
                app=mock_app,
                sandbox_config_id="config-123",
                actor=mock_actor,
            )
        stats = await manager.get_deployment_stats()
        # Note: The actual implementation may store deployments differently
        # This test assumes the stats method exists and returns expected format
        assert stats["total_deployments"] >= 0  # Adjust based on actual implementation
        assert "deployments" in stats

    @pytest.mark.skip(reason="export_state and import_state methods not implemented in ModalVersionManager")
    @pytest.mark.asyncio
    async def test_export_import_state(self):
        """Test exporting and importing deployment state."""
        from letta.schemas.user import User

        manager1 = ModalVersionManager()
        # Mock the tool manager for manager1
        mock_tools = {
            "tool-11111111": MagicMock(id="tool-11111111", metadata_={}),
            "tool-22222222": MagicMock(id="tool-22222222", metadata_={}),
        }

        def get_tool_by_id(tool_id, actor=None):
            return mock_tools.get(tool_id)

        manager1.tool_manager.get_tool_by_id = MagicMock(side_effect=get_tool_by_id)
        manager1.tool_manager.update_tool_by_id_async = AsyncMock()
        # Create a mock actor
        mock_actor = MagicMock(spec=User)
        # Register deployments
        mock_app = MagicMock()
        await manager1.register_deployment(
            tool_id="tool-11111111",
            app_name="app1",
            version_hash="v1",
            app=mock_app,
            dependencies={"dep1"},
            sandbox_config_id="config-123",
            actor=mock_actor,
        )
        await manager1.register_deployment(
            tool_id="tool-22222222",
            app_name="app2",
            version_hash="v2",
            app=mock_app,
            dependencies={"dep2", "dep3"},
            sandbox_config_id="config-123",
            actor=mock_actor,
        )
        # Export state
        state_json = await manager1.export_state()
        state = json.loads(state_json)
        # Verify exported state structure
        assert "tool-11111111" in state or "deployments" in state  # Depends on implementation
        # Import into new manager
        manager2 = ModalVersionManager()
        manager2.tool_manager.get_tool_by_id = MagicMock(side_effect=get_tool_by_id)
        await manager2.import_state(state_json)
        # Note: The actual implementation may not have export/import methods
        # This test assumes they exist or should be modified based on actual API
class TestAsyncToolSandboxModalV2:
    """Test the AsyncToolSandboxModalV2 class.

    Covers version-hash derivation, app-name generation, a fully mocked
    end-to-end run, and async-function detection.
    """

    @pytest.fixture
    def mock_tool(self):
        """Create a mock tool for testing."""
        return Tool(
            id="tool-12345678",  # Valid tool ID format
            name="test_function",
            source_code="""
def test_function(x: int, y: int) -> int:
    '''Add two numbers together.'''
    return x + y
""",
            json_schema={
                "parameters": {
                    "properties": {
                        "x": {"type": "integer"},
                        "y": {"type": "integer"},
                    }
                }
            },
            pip_requirements=[PipRequirement(name="requests")],
        )

    @pytest.fixture
    def mock_user(self):
        """Create a mock user for testing."""
        user = MagicMock()
        user.organization_id = "test-org"
        return user

    @pytest.fixture
    def mock_sandbox_config(self):
        """Create a mock sandbox configuration."""
        modal_config = ModalSandboxConfig(
            timeout=60,
            pip_requirements=["pandas"],
        )
        config = SandboxConfig(
            id="sandbox-12345678",  # Valid sandbox ID format
            type=SandboxType.MODAL,  # Changed from sandbox_type to type
            config=modal_config.model_dump(),
        )
        return config

    def test_version_hash_calculation(self, mock_tool, mock_user, mock_sandbox_config):
        """Test that version hash is calculated correctly.

        The hash must be deterministic and must change when the tool source,
        the pip requirements, or the sandbox config change.
        """
        sandbox = AsyncToolSandboxModalV2(
            tool_name="test_function",
            args={"x": 1, "y": 2},
            user=mock_user,
            tool_object=mock_tool,
            sandbox_config=mock_sandbox_config,
        )
        # Access through deployment manager
        version1 = sandbox._deployment_manager.calculate_version_hash(mock_sandbox_config)
        assert version1  # Should not be empty
        assert len(version1) == 12  # We take first 12 chars of hash
        # Same inputs should produce same hash
        version2 = sandbox._deployment_manager.calculate_version_hash(mock_sandbox_config)
        assert version1 == version2
        # Changing tool code should change hash
        mock_tool.source_code = "def test_function(x, y): return x * y"
        sandbox2 = AsyncToolSandboxModalV2(
            tool_name="test_function",
            args={"x": 1, "y": 2},
            user=mock_user,
            tool_object=mock_tool,
            sandbox_config=mock_sandbox_config,
        )
        version3 = sandbox2._deployment_manager.calculate_version_hash(mock_sandbox_config)
        assert version3 != version1
        # Changing dependencies should also change hash
        mock_tool.source_code = "def test_function(x, y): return x + y"  # Reset
        mock_tool.pip_requirements = [PipRequirement(name="numpy")]
        sandbox3 = AsyncToolSandboxModalV2(
            tool_name="test_function",
            args={"x": 1, "y": 2},
            user=mock_user,
            tool_object=mock_tool,
            sandbox_config=mock_sandbox_config,
        )
        version4 = sandbox3._deployment_manager.calculate_version_hash(mock_sandbox_config)
        assert version4 != version1
        # Changing sandbox config should change hash
        modal_config2 = ModalSandboxConfig(
            timeout=120,  # Different timeout
            pip_requirements=["pandas"],
        )
        config2 = SandboxConfig(
            id="sandbox-87654321",
            type=SandboxType.MODAL,
            config=modal_config2.model_dump(),
        )
        version5 = sandbox3._deployment_manager.calculate_version_hash(config2)
        assert version5 != version4

    def test_app_name_generation(self, mock_tool, mock_user):
        """Test app name generation."""
        sandbox = AsyncToolSandboxModalV2(
            tool_name="test_function",
            args={"x": 1, "y": 2},
            user=mock_user,
            tool_object=mock_tool,
        )
        # App name generation is now in deployment manager and uses tool ID
        app_name = sandbox._deployment_manager._generate_app_name()
        # App name is based on tool ID truncated to 40 chars
        assert app_name == mock_tool.id[:40]

    @pytest.mark.asyncio
    async def test_run_with_mocked_modal(self, mock_tool, mock_user, mock_sandbox_config):
        """Test the run method with mocked Modal components."""
        # Patch modal in both modules that import it
        with (
            patch("letta.services.tool_sandbox.modal_sandbox_v2.modal") as mock_modal,
            patch("letta.services.tool_sandbox.modal_deployment_manager.modal") as mock_modal2,
        ):
            # Mock Modal app
            mock_app = MagicMock()  # Use MagicMock for the app itself
            mock_app.run = MagicMock()

            # Mock the function decorator
            def mock_function_decorator(*args, **kwargs):
                def decorator(func):
                    # Create a mock that has a remote attribute.
                    # NOTE: `mock_remote` is assigned further below; the
                    # closure resolves it lazily when the sandbox actually
                    # invokes the decorator, so the forward reference is safe.
                    mock_func = MagicMock()
                    mock_func.remote = mock_remote
                    # Store the mocked function as tool_executor on the app
                    mock_app.tool_executor = mock_func
                    return mock_func

                return decorator

            mock_app.function = mock_function_decorator
            # Mock deployment
            mock_app.deploy = MagicMock()
            mock_app.deploy.aio = AsyncMock()
            # Mock the remote execution
            mock_remote = MagicMock()
            mock_remote.aio = AsyncMock(
                return_value={
                    "result": 3,  # Return actual integer, not string
                    "agent_state": None,
                    "stdout": "Executing...",
                    "stderr": "",
                    "error": None,
                }
            )
            mock_modal.App.return_value = mock_app
            mock_modal2.App.return_value = mock_app
            # Mock App.lookup.aio to handle app lookup attempts
            mock_modal.App.lookup = MagicMock()
            mock_modal.App.lookup.aio = AsyncMock(side_effect=Exception("App not found"))
            mock_modal2.App.lookup = MagicMock()
            mock_modal2.App.lookup.aio = AsyncMock(side_effect=Exception("App not found"))
            # Mock enable_output context manager
            mock_modal.enable_output = MagicMock()
            mock_modal.enable_output.return_value.__enter__ = MagicMock()
            mock_modal.enable_output.return_value.__exit__ = MagicMock()
            mock_modal2.enable_output = MagicMock()
            mock_modal2.enable_output.return_value.__enter__ = MagicMock()
            mock_modal2.enable_output.return_value.__exit__ = MagicMock()
            # Mock the SandboxConfigManager to avoid type checking issues
            with patch("letta.services.tool_sandbox.base.SandboxConfigManager") as MockSCM:
                mock_scm = MockSCM.return_value
                mock_scm.get_sandbox_env_vars_as_dict_async = AsyncMock(return_value={})
                # Create sandbox
                sandbox = AsyncToolSandboxModalV2(
                    tool_name="test_function",
                    args={"x": 1, "y": 2},
                    user=mock_user,
                    tool_object=mock_tool,
                    sandbox_config=mock_sandbox_config,
                )
                # Mock the version manager through deployment manager
                version_manager = sandbox._deployment_manager.version_manager
                if version_manager:
                    with patch.object(version_manager, "get_deployment", return_value=None):
                        with patch.object(version_manager, "register_deployment", return_value=None):
                            # Run the tool
                            result = await sandbox.run()
                else:
                    # If no version manager (use_version_tracking=False), just run
                    result = await sandbox.run()
                assert result.func_return == 3  # Check for actual integer
                assert result.status == "success"
                assert "Executing..." in result.stdout[0]

    def test_detect_async_function(self, mock_user):
        """Test detection of async functions."""
        # Test with sync function
        sync_tool = Tool(
            id="tool-abcdef12",  # Valid tool ID format
            name="sync_func",
            source_code="def sync_func(x): return x",
            json_schema={"parameters": {"properties": {}}},
        )
        sandbox_sync = AsyncToolSandboxModalV2(
            tool_name="sync_func",
            args={},
            user=mock_user,
            tool_object=sync_tool,
        )
        assert sandbox_sync._detect_async_function() is False
        # Test with async function
        async_tool = Tool(
            id="tool-fedcba21",  # Valid tool ID format
            name="async_func",
            source_code="async def async_func(x): return x",
            json_schema={"parameters": {"properties": {}}},
        )
        sandbox_async = AsyncToolSandboxModalV2(
            tool_name="async_func",
            args={},
            user=mock_user,
            tool_object=async_tool,
        )
        assert sandbox_async._detect_async_function() is True
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])