feat: modal tool execution - NO FEATURE FLAGS USES MODAL [LET-4357] (#5120)

* initial commit

* add delay to deploy

* fix tests

* add tests

* passing tests

* cleanup

* and use modal

* working on modal

* gate on tool metadata

* agent state

* cleanup

---------

Co-authored-by: Letta Bot <noreply@letta.com>
Sarah Wooders
2025-11-11 18:21:51 -08:00
committed by Caren Thomas
parent 2a8523aa01
commit 5730f69ecf
12 changed files with 1314 additions and 390 deletions

View File

@@ -44,6 +44,11 @@ IN_CONTEXT_MEMORY_KEYWORD = "CORE_MEMORY"
# OpenAI error message: Invalid 'messages[1].tool_calls[0].id': string too long. Expected a string with maximum length 29, but got a string with length 36 instead.
TOOL_CALL_ID_MAX_LEN = 29
# Maximum length for tool names to support Modal deployment
# Modal function names are limited to 64 characters: tool_name + "_" + project_id
# Reserving 16 characters for the "_" separator plus a shortened project_id (e.g., "_project-1234567")
MAX_TOOL_NAME_LENGTH = 48
# Max steps for agent loop
DEFAULT_MAX_STEPS = 50
@@ -440,5 +445,18 @@ EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES = ["claude-4-sonnet", "claude-3-5-so
# But include models with these keywords in base tool rules (overrides exclusion)
INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES = ["mini"]
# Deployment and versioning
MODAL_DEFAULT_TOOL_NAME = "modal_tool_wrapper.<locals>.modal_function" # NOTE: must stay in sync with modal_tool_wrapper
MODAL_DEFAULT_CONFIG_KEY = "default"
MODAL_MODAL_DEPLOYMENTS_KEY = "modal_deployments"
MODAL_VERSION_HASH_LENGTH = 12
# Modal execution settings
MODAL_DEFAULT_TIMEOUT = 60
MODAL_DEFAULT_MAX_CONCURRENT_INPUTS = 1
MODAL_DEFAULT_PYTHON_VERSION = "3.12"
# Security settings
MODAL_SAFE_IMPORT_MODULES = {"typing", "pydantic", "datetime", "enum", "uuid", "decimal"}
# Default handle for model used to generate tools
DEFAULT_GENERATE_TOOL_MODEL_HANDLE = "openai/gpt-4.1"
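
As a sanity check on the name-length budget these constants encode, here is a minimal standalone sketch (constants copied from above; the helper name is hypothetical):

MODAL_NAME_LIMIT = 64  # Modal's function-name limit, per the comment above
MAX_TOOL_NAME_LENGTH = 48

def modal_name(tool_name: str, project_id: str) -> str:
    # Truncate both parts so tool_name + "_" + shortened project_id fits the limit
    short_project = project_id[: MODAL_NAME_LIMIT - MAX_TOOL_NAME_LENGTH - 1]
    return f"{tool_name[:MAX_TOOL_NAME_LENGTH]}_{short_project}"

assert len(modal_name("a" * 100, "project-12345678-90ab-cdef")) <= MODAL_NAME_LIMIT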

View File

@@ -0,0 +1,69 @@
import hashlib
from letta.constants import MODAL_VERSION_HASH_LENGTH
from letta.schemas.tool import Tool
def _serialize_dependencies(tool: Tool) -> str:
"""
Serialize dependencies in a consistent way for hashing.
TODO: This should be improved per LET-3770 to ensure consistent ordering.
For now, we convert to string representation.
"""
parts = []
if tool.pip_requirements:
# TODO: Sort these consistently
parts.append(f"pip:{str(tool.pip_requirements)}")
if tool.npm_requirements:
# TODO: Sort these consistently
parts.append(f"npm:{str(tool.npm_requirements)}")
return ";".join(parts)
def compute_tool_hash(tool: Tool) -> str:
"""
Calculate a hash representing the current version of the tool and configuration.
This hash changes when:
- Tool source code changes
- Tool dependencies change
- Sandbox configuration changes
- Language/runtime changes
"""
components = [
tool.source_code if tool.source_code else "",
tool.source_type if tool.source_type else "",
_serialize_dependencies(tool),
]
combined = "|".join(components)
return hashlib.sha256(combined.encode()).hexdigest()[:MODAL_VERSION_HASH_LENGTH]
def generate_modal_function_name(tool_name: str, organization_id: str, project_id: str = "default") -> str:
"""
Generate a Modal function name from tool name and project ID.
Shortens the project ID to just the prefix and first UUID segment.
Args:
tool_name: Name of the tool
organization_id: Full organization ID (not used in function name, but kept for future use)
project_id: Project ID (e.g., project-12345678-90ab-cdef-1234-567890abcdef or "default")
Returns:
Modal function name (e.g., tool_name_project-12345678 or tool_name_default)
"""
from letta.constants import MAX_TOOL_NAME_LENGTH
max_function_name_length = 64  # Modal's limit on function names
# Shorten the project_id to just the prefix and first UUID segment (e.g., project-1234567)
short_project_id = project_id[: (max_function_name_length - MAX_TOOL_NAME_LENGTH - 1)]
# Make extra sure the tool name is not too long, then append the shortened project_id
return f"{tool_name[:MAX_TOOL_NAME_LENGTH]}_{short_project_id}"
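
For illustration, the same hashing scheme applied to literal inputs (a standalone sketch; the real helper takes a Tool object):

import hashlib

source_code = "def add(x: int, y: int) -> int:\n    return x + y"
combined = "|".join([source_code, "python", ""])  # source_code | source_type | serialized deps
version_hash = hashlib.sha256(combined.encode()).hexdigest()[:12]  # MODAL_VERSION_HASH_LENGTH
print(version_hash)  # stable 12-character version tag used for deployment tracking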

View File

@@ -4,12 +4,11 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, model_validator
from letta.constants import LETTA_TOOL_EXECUTION_DIR
from letta.constants import LETTA_TOOL_EXECUTION_DIR, MODAL_DEFAULT_TIMEOUT
from letta.schemas.agent import AgentState
from letta.schemas.enums import PrimitiveType, SandboxType
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
from letta.schemas.pip_requirement import PipRequirement
from letta.services.tool_sandbox.modal_constants import DEFAULT_MODAL_TIMEOUT
from letta.settings import tool_settings
# Sandbox Config
@@ -81,7 +80,7 @@ class E2BSandboxConfig(BaseModel):
class ModalSandboxConfig(BaseModel):
timeout: int = Field(DEFAULT_MODAL_TIMEOUT, description="Time limit for the sandbox (in seconds).")
timeout: int = Field(MODAL_DEFAULT_TIMEOUT, description="Time limit for the sandbox (in seconds).")
pip_requirements: list[str] | None = Field(None, description="A list of pip packages to install in the Modal sandbox")
npm_requirements: list[str] | None = Field(None, description="A list of npm packages to install in the Modal sandbox")
language: Literal["python", "typescript"] = "python"
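
A minimal usage sketch for the config model above (field values are illustrative):

from letta.schemas.sandbox_config import ModalSandboxConfig

cfg = ModalSandboxConfig(
    timeout=120,                 # overrides MODAL_DEFAULT_TIMEOUT (60s)
    pip_requirements=["numpy"],  # installed into the Modal image
    language="python",
)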

View File

@@ -7,6 +7,7 @@ from httpx import ConnectError, HTTPStatusError
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
from letta.constants import MAX_TOOL_NAME_LENGTH
from letta.constants import DEFAULT_GENERATE_TOOL_MODEL_HANDLE
from letta.errors import (
LettaInvalidArgumentError,

View File

@@ -46,27 +46,19 @@ class SandboxToolExecutor(ToolExecutor):
agent_state_copy = self._create_agent_state_copy(agent_state) if agent_state else None
# Execute in sandbox depending on API key
if tool_settings.sandbox_type == SandboxType.E2B:
from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B
# Execute in sandbox with Modal first (if configured and requested), then fallback to E2B/LOCAL
# Try Modal if: (1) Modal credentials configured AND (2) tool requests Modal via metadata
tool_requests_modal = tool.metadata_ and tool.metadata_.get("sandbox") == "modal"
modal_configured = tool_settings.modal_sandbox_enabled
sandbox = AsyncToolSandboxE2B(
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
)
# TODO (cliandy): this is just for testing right now, separate this out into its own subclass and handling logic
elif tool_settings.sandbox_type == SandboxType.MODAL:
from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal, TypescriptToolSandboxModal
tool_execution_result = None
if tool.source_type == ToolSourceType.typescript:
sandbox = TypescriptToolSandboxModal(
function_name,
function_args,
actor,
tool_object=tool,
sandbox_config=sandbox_config,
sandbox_env_vars=sandbox_env_vars,
)
elif tool.source_type == ToolSourceType.python:
# Try Modal first if both conditions met
if tool_requests_modal and modal_configured:
try:
from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal
logger.info(f"Attempting Modal execution for tool {tool.name}")
sandbox = AsyncToolSandboxModal(
function_name,
function_args,
@@ -74,12 +66,36 @@ class SandboxToolExecutor(ToolExecutor):
tool_object=tool,
sandbox_config=sandbox_config,
sandbox_env_vars=sandbox_env_vars,
organization_id=actor.organization_id,
)
# TODO: pass through letta api key
tool_execution_result = await sandbox.run(agent_state=agent_state_copy, additional_env_vars=sandbox_env_vars)
except Exception as e:
# Modal execution failed, log and fall back to E2B/LOCAL
logger.warning(f"Modal execution failed for tool {tool.name}: {e}. Falling back to {tool_settings.sandbox_type.value}")
tool_execution_result = None
# Fallback to E2B or LOCAL if Modal wasn't tried or failed
if tool_execution_result is None:
if tool_settings.sandbox_type == SandboxType.E2B:
from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B
sandbox = AsyncToolSandboxE2B(
function_name,
function_args,
actor,
tool_object=tool,
sandbox_config=sandbox_config,
sandbox_env_vars=sandbox_env_vars,
)
else:
raise ValueError(f"Tool source type was {tool.source_type} but is required to be python or typescript to run in Modal.")
else:
sandbox = AsyncToolSandboxLocal(
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
function_name,
function_args,
actor,
tool_object=tool,
sandbox_config=sandbox_config,
sandbox_env_vars=sandbox_env_vars,
)
tool_execution_result = await sandbox.run(agent_state=agent_state_copy)
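
The executor change above reduces to a try-Modal-then-fall-back pattern; a simplified standalone sketch (names here are placeholders, not the real executor API):

from typing import Any, Awaitable, Callable, Optional

async def execute_with_fallback(
    metadata: Optional[dict],
    modal_configured: bool,
    run_modal: Callable[[], Awaitable[Any]],
    run_fallback: Callable[[], Awaitable[Any]],
) -> Any:
    # Modal is tried only when the tool opts in via metadata AND credentials exist
    wants_modal = bool(metadata) and metadata.get("sandbox") == "modal"
    result = None
    if wants_modal and modal_configured:
        try:
            result = await run_modal()
        except Exception:
            result = None  # swallow and fall back to E2B/LOCAL
    if result is None:
        result = await run_fallback()
    return result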

View File

@@ -16,30 +16,176 @@ from letta.constants import (
LETTA_TOOL_MODULE_NAMES,
LETTA_TOOL_SET,
LOCAL_ONLY_MULTI_AGENT_TOOLS,
MAX_TOOL_NAME_LENGTH,
MCP_TOOL_TAG_NAME_PREFIX,
MODAL_DEFAULT_TOOL_NAME,
)
from letta.errors import LettaToolNameConflictError, LettaToolNameSchemaMismatchError
from letta.errors import LettaInvalidArgumentError, LettaToolNameConflictError, LettaToolNameSchemaMismatchError
from letta.functions.functions import derive_openai_json_schema, load_function_set
from letta.helpers.tool_helpers import compute_tool_hash, generate_modal_function_name
from letta.log import get_logger
# TODO: Remove this once we translate all of these to the ORM
from letta.orm.errors import NoResultFound
from letta.orm.tool import Tool as ToolModel
from letta.otel.tracing import trace_method
from letta.schemas.enums import PrimitiveType, ToolType
from letta.schemas.agent import AgentState
from letta.schemas.enums import PrimitiveType, SandboxType, ToolType
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.agent_manager_helper import calculate_multi_agent_tools
from letta.services.mcp.types import SSEServerConfig, StdioServerConfig
from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update
from letta.settings import settings
from letta.settings import settings, tool_settings
from letta.utils import enforce_types, printd
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
# NOTE: function name and nested modal function decorator name must stay in sync with MODAL_DEFAULT_TOOL_NAME
def modal_tool_wrapper(tool: PydanticTool, actor: PydanticUser, sandbox_env_vars: dict = None, project_id: str = "default"):
"""Create a Modal function wrapper for a tool"""
import contextlib
import io
import os
import sys
from typing import Optional
import modal
from letta_client import Letta
packages = [str(req) for req in tool.pip_requirements] if tool.pip_requirements else []
packages.append("letta_client")
packages.append("letta") # Base letta without extras
packages.append("asyncpg>=0.30.0") # Fixes asyncpg import error
packages.append("psycopg2-binary>=2.9.10") # PostgreSQL adapter (pre-compiled, no build required)
# packages.append("pgvector>=0.3.6") # Vector operations support
function_name = generate_modal_function_name(tool.name, actor.organization_id, project_id)
modal_app = modal.App(function_name)
logger.info(f"Creating Modal app {tool.id} with name {function_name}")
# Create secrets dict with sandbox env vars
secrets_dict = {"LETTA_API_KEY": None}
if sandbox_env_vars:
secrets_dict.update(sandbox_env_vars)
@modal_app.function(
image=modal.Image.debian_slim(python_version="3.13").pip_install(packages),
restrict_modal_access=True,
timeout=10,
secrets=[modal.Secret.from_dict(secrets_dict)],
serialized=True,
)
def modal_function(
tool_name: str, agent_state: Optional[dict], agent_id: Optional[str], env_vars: dict, letta_api_key: Optional[str] = None, **kwargs
):
"""Wrapper function for running untrusted code in a Modal function"""
# Reconstruct AgentState from dict if passed (to avoid cloudpickle serialization issues)
# This is done with extra safety to handle schema mismatches between environments
reconstructed_agent_state = None
if agent_state:
try:
from letta.schemas.agent import AgentState as AgentStateModel
# Filter dict to only include fields that exist in Modal's version of AgentState
# This prevents ValidationError from extra fields in newer schemas
modal_agent_fields = set(AgentStateModel.model_fields.keys())
filtered_agent_state = {key: value for key, value in agent_state.items() if key in modal_agent_fields}
# Try to reconstruct with filtered data
reconstructed_agent_state = AgentStateModel.model_validate(filtered_agent_state)
# Log if we filtered out any fields
filtered_out = set(agent_state.keys()) - modal_agent_fields
if filtered_out:
print(f"Fields not in available in AgentState: {filtered_out}", file=sys.stderr)
except ImportError as e:
print(f"Cannot import AgentState: {e}", file=sys.stderr)
print("Passing agent_state as dict to tool", file=sys.stderr)
reconstructed_agent_state = agent_state
except Exception as e:
print(f"Warning: Could not reconstruct AgentState (schema mismatch?): {e}", file=sys.stderr)
print("Passing agent_state as dict to tool", file=sys.stderr)
reconstructed_agent_state = agent_state
# Set environment variables
if env_vars:
for key, value in env_vars.items():
os.environ[key] = str(value)
# Initialize the Letta client
if letta_api_key:
letta_client = Letta(token=letta_api_key, base_url=os.environ.get("LETTA_API_URL", "https://api.letta.com"))
else:
letta_client = None
tool_namespace = {
"__builtins__": __builtins__, # Include built-in functions
"_letta_client": letta_client, # Make letta_client available
"os": os, # Include os module for env vars access
"agent_id": agent_id,
# Add any other modules/variables the tool might need
}
# Execute the tool source code in the prepared namespace
exec(tool.source_code, tool_namespace)
# Get the tool function
if tool_name not in tool_namespace:
raise Exception(f"Tool function {tool_name} not found in {tool.source_code}, globals: {tool_namespace}")
tool_func = tool_namespace[tool_name]
# Detect if the tool function is async
import asyncio
import inspect
import traceback
is_async = inspect.iscoroutinefunction(tool_func)
# Capture stdout and stderr during tool execution
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
result = None
error_occurred = False
with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
try:
# if `agent_state` is in the tool function arguments, inject it
# Pass reconstructed AgentState (or dict if reconstruction failed)
if "agent_state" in tool_func.__code__.co_varnames:
kwargs["agent_state"] = reconstructed_agent_state
# Execute the tool function (async or sync)
if is_async:
result = asyncio.run(tool_func(**kwargs))
else:
result = tool_func(**kwargs)
except Exception as e:
# Capture the exception and write to stderr
error_occurred = True
traceback.print_exc(file=stderr_capture)
# Get captured output
stdout = stdout_capture.getvalue()
stderr = stderr_capture.getvalue()
return {
"result": result,
"stdout": stdout,
"stderr": stderr,
"agent_state": agent_state, # TODO: deprecate (use letta_client instead)
"error": error_occurred or bool(stderr),
}
return modal_app
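
The wrapper's inner function returns a plain dict; a hypothetical caller-side sketch of consuming that shape (values illustrative):

result = {"result": 3, "stdout": "", "stderr": "", "agent_state": None, "error": False}
if result["error"]:
    print("tool failed:", result["stderr"])
else:
    print("tool returned:", result["result"])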
class ToolManager:
"""Manager class to handle business logic related to Tools."""
@@ -56,12 +202,16 @@ class ToolManager:
else:
raise ValueError("Failed to generate schema for tool", pydantic_tool.source_code)
print("SCHEMA", pydantic_tool.json_schema)
# make sure the name matches the json_schema
if not pydantic_tool.name:
pydantic_tool.name = pydantic_tool.json_schema.get("name")
else:
# if name is provided, make sure it's no longer than MAX_TOOL_NAME_LENGTH
if len(pydantic_tool.name) > MAX_TOOL_NAME_LENGTH:
raise LettaInvalidArgumentError(
f"Tool name {pydantic_tool.name} is too long. It must be at most {MAX_TOOL_NAME_LENGTH} characters."
)
if pydantic_tool.name != pydantic_tool.json_schema.get("name"):
raise LettaToolNameSchemaMismatchError(
tool_name=pydantic_tool.name,
@@ -136,6 +286,13 @@ class ToolManager:
# Auto-generate description if not provided
if pydantic_tool.description is None and pydantic_tool.json_schema:
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
# Add tool hash to metadata for Modal deployment tracking
tool_hash = compute_tool_hash(pydantic_tool)
if pydantic_tool.metadata_ is None:
pydantic_tool.metadata_ = {}
pydantic_tool.metadata_["tool_hash"] = tool_hash
tool_data = pydantic_tool.model_dump(to_orm=True)
# Set the organization id at the ORM layer
tool_data["organization_id"] = actor.organization_id
@@ -153,7 +310,17 @@ class ToolManager:
)
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
return tool.to_pydantic()
created_tool = tool.to_pydantic()
# Deploy Modal app for the new tool
# Both Modal credentials configured AND tool metadata must indicate Modal
tool_requests_modal = created_tool.metadata_ and created_tool.metadata_.get("sandbox") == "modal"
modal_configured = tool_settings.modal_sandbox_enabled
if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
await self.create_or_update_modal_app(created_tool, actor)
return created_tool
@enforce_types
@trace_method
@@ -599,6 +766,33 @@ class ToolManager:
source_code=update_data.get("source_code"),
)
# Create a preview of the updated tool by merging current tool with updates
# This allows us to compute the hash before the database session
updated_tool_pydantic = current_tool.model_copy(deep=True)
for key, value in update_data.items():
setattr(updated_tool_pydantic, key, value)
if new_schema is not None:
updated_tool_pydantic.json_schema = new_schema
updated_tool_pydantic.name = new_name
if updated_tool_type:
updated_tool_pydantic.tool_type = updated_tool_type
# Check if we need to redeploy the Modal app due to changes
# Compute this before opening the database session
tool_requests_modal = updated_tool_pydantic.metadata_ and updated_tool_pydantic.metadata_.get("sandbox") == "modal"
modal_configured = tool_settings.modal_sandbox_enabled
should_check_modal = tool_requests_modal and modal_configured and updated_tool_pydantic.tool_type == ToolType.CUSTOM
# Compute hash before session if needed
new_hash = None
old_hash = None
needs_modal_deployment = False
if should_check_modal:
new_hash = compute_tool_hash(updated_tool_pydantic)
old_hash = current_tool.metadata_.get("tool_hash") if current_tool.metadata_ else None
needs_modal_deployment = new_hash != old_hash
# Now perform the update within the session
async with db_registry.async_session() as session:
# Fetch the tool by ID
@@ -618,7 +812,25 @@ class ToolManager:
# Save the updated tool to the database
tool = await tool.update_async(db_session=session, actor=actor)
return tool.to_pydantic()
updated_tool = tool.to_pydantic()
# Update Modal hash in metadata if needed (inside session context)
if needs_modal_deployment:
if updated_tool.metadata_ is None:
updated_tool.metadata_ = {}
updated_tool.metadata_["tool_hash"] = new_hash
# Update the metadata in the database (still inside session)
tool.metadata_ = updated_tool.metadata_
tool = await tool.update_async(db_session=session, actor=actor)
updated_tool = tool.to_pydantic()
# Deploy Modal app outside of session (it creates its own sessions)
if needs_modal_deployment:
logger.info(f"Deploying Modal app for tool {updated_tool.id} with new hash: {new_hash}")
await self.create_or_update_modal_app(updated_tool, actor)
return updated_tool
@enforce_types
@trace_method
@@ -783,3 +995,49 @@ class ToolManager:
created_tool = await self.create_tool_async(tool, actor=actor)
tools.append(created_tool)
return tools
# MODAL RELATED METHODS
@trace_method
async def create_or_update_modal_app(self, tool: PydanticTool, actor: PydanticUser):
"""Create a Modal app with the tool function registered"""
import time
import modal
from letta.services.sandbox_config_manager import SandboxConfigManager
# Load sandbox env vars to bake them into the Modal secrets
sandbox_env_vars = {}
try:
sandbox_config_manager = SandboxConfigManager()
sandbox_config = await sandbox_config_manager.get_or_create_default_sandbox_config_async(
sandbox_type=SandboxType.MODAL, actor=actor
)
if sandbox_config:
sandbox_env_vars = await sandbox_config_manager.get_sandbox_env_vars_as_dict_async(
sandbox_config_id=sandbox_config.id, actor=actor, limit=None
)
logger.info(f"Loaded {len(sandbox_env_vars)} sandbox env vars for Modal app {tool.id}")
except Exception as e:
logger.warning(f"Could not load sandbox env vars for Modal app {tool.id}: {e}")
# Create the Modal app using the global function with sandbox env vars
modal_app = modal_tool_wrapper(tool, actor, sandbox_env_vars)
# Deploy the app first
with modal.enable_output():
try:
deploy = modal_app.deploy()
except Exception as e:
raise LettaInvalidArgumentError(f"Failed to deploy tool {tool.id} with name {tool.name} to Modal: {e}")
# After deployment, look up the function to configure autoscaler
try:
func = modal.Function.from_name(modal_app.name, MODAL_DEFAULT_TOOL_NAME)
func.update_autoscaler(scaledown_window=2) # drain inactive old containers
time.sleep(5)
func.update_autoscaler(scaledown_window=60)
except Exception as e:
logger.warning(f"Failed to configure autoscaler for Modal function {modal_app.name}: {e}")
return deploy
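
Taken together, the create/update hooks above implement hash-gated redeployment; schematically (names from this diff, control flow simplified):

# Recompute the version hash from the updated tool; redeploy only on change
new_hash = compute_tool_hash(updated_tool)
old_hash = (current_tool.metadata_ or {}).get("tool_hash")
if new_hash != old_hash and tool_settings.modal_sandbox_enabled:
    updated_tool.metadata_["tool_hash"] = new_hash
    await tool_manager.create_or_update_modal_app(updated_tool, actor)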

View File

@@ -1,7 +1,14 @@
from typing import Any, Dict, Optional
"""
Modal sandbox implementation, which configures one Modal App per tool.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
import modal
from e2b.sandbox.commands.command_handle import CommandExitException
from e2b_code_interpreter import AsyncSandbox
from letta.constants import MODAL_DEFAULT_TOOL_NAME
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.agent import AgentState
@@ -9,412 +16,180 @@ from letta.schemas.enums import SandboxType
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.services.helpers.tool_parser_helper import parse_stdout_best_effort
from letta.services.helpers.tool_parser_helper import parse_function_arguments, parse_stdout_best_effort
from letta.services.tool_manager import ToolManager
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
from letta.settings import tool_settings
from letta.types import JsonDict
from letta.utils import get_friendly_error_msg
logger = get_logger(__name__)
# class AsyncToolSandboxModalBase(AsyncToolSandboxBase):
# pass
if TYPE_CHECKING:
from e2b_code_interpreter import Execution
class AsyncToolSandboxModal(AsyncToolSandboxBase):
METADATA_CONFIG_STATE_KEY = "config_state"
def __init__(
self,
tool_name: str,
args: JsonDict,
user,
tool_object: Tool | None = None,
sandbox_config: SandboxConfig | None = None,
sandbox_env_vars: dict[str, Any] | None = None,
force_recreate: bool = True,
tool_object: Optional[Tool] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
organization_id: Optional[str] = None,
project_id: str = "default",
):
super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
self.force_recreate = force_recreate
# Get organization_id from user if not explicitly provided
self.organization_id = organization_id if organization_id is not None else user.organization_id
self.project_id = project_id
if not tool_settings.modal_token_id or not tool_settings.modal_token_secret:
raise ValueError("MODAL_TOKEN_ID and MODAL_TOKEN_SECRET must be set.")
# TODO: check to make sure modal app `App(tool.id)` exists
# Create a unique app name based on tool and config
self._app_name = self._generate_app_name()
def _generate_app_name(self) -> str:
"""Generate a unique app name based on tool and configuration. Created based on tool name and org"""
return f"{self.user.organization_id}-{self.tool_name}"
async def _fetch_or_create_modal_app(self, sbx_config: SandboxConfig, env_vars: Dict[str, str]) -> modal.App:
"""Create a Modal app with the tool function registered."""
try:
app = await modal.App.lookup.aio(self._app_name)
return app
except:
app = modal.App(self._app_name)
modal_config = sbx_config.get_modal_config()
# Get the base image with dependencies
image = self._get_modal_image(sbx_config)
# Decorator for the tool, note information on running untrusted code: https://modal.com/docs/guide/restricted-access
# The `@app.function` decorator must apply to functions in global scope, unless `serialized=True` is set.
@app.function(image=image, timeout=modal_config.timeout, restrict_modal_access=True, max_inputs=1, serialized=True)
def execute_tool_with_script(execution_script: str, environment_vars: dict[str, str]):
"""Execute the generated tool script in Modal sandbox."""
import os
# Note: We pass environment variables directly instead of relying on Modal secrets
# This is more flexible and doesn't require pre-configured secrets
for key, value in environment_vars.items():
os.environ[key] = str(value)
exec_globals = {}
exec(execution_script, exec_globals)
# Store the function reference in the app for later use
app.remote_executor = execute_tool_with_script
return app
async def _wait_for_modal_function_deployment(self, timeout: int = 60):
"""Wait for Modal app deployment to complete by retrying function lookup."""
import asyncio
import time
import modal
from letta.helpers.tool_helpers import generate_modal_function_name
# Use the same naming logic as deployment
function_name = generate_modal_function_name(self.tool.name, self.organization_id, self.project_id)
start_time = time.time()
retry_delay = 2  # seconds
while time.time() - start_time < timeout:
try:
f = modal.Function.from_name(function_name, MODAL_DEFAULT_TOOL_NAME)
logger.info(f"Modal function found successfully for app {function_name}, function {f}")
return f
except Exception as e:
elapsed = time.time() - start_time
if elapsed >= timeout:
raise TimeoutError(
f"Modal app {function_name} deployment timed out after {timeout} seconds. "
f"Expected app name: {function_name}, function: {MODAL_DEFAULT_TOOL_NAME}"
) from e
logger.info(f"Modal app {function_name} not ready yet (elapsed: {elapsed:.1f}s), waiting {retry_delay}s...")
await asyncio.sleep(retry_delay)
raise TimeoutError(f"Modal app {function_name} deployment timed out after {timeout} seconds")
@trace_method
async def run(
self,
agent_id: Optional[str] = None,
agent_state: Optional[AgentState] = None,
additional_env_vars: Optional[Dict] = None,
) -> ToolExecutionResult:
if self.provided_sandbox_config:
sbx_config = self.provided_sandbox_config
else:
sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async(
sandbox_type=SandboxType.MODAL, actor=self.user
)
envs = await self._gather_env_vars(agent_state, additional_env_vars or {}, sbx_config.id, is_local=False)
# Generate execution script (this includes the tool source code and execution logic)
execution_script = await self.generate_execution_script(agent_state=agent_state)
await self._init_async()
try:
log_event(
"modal_execution_started",
{"tool": self.tool_name, "app_name": self._app_name, "env_vars": list(envs)},
log_event("modal_execution_started", {"tool": self.tool_name, "modal_app_id": self.tool.id})
logger.info(f"Waiting for Modal function deployment for app {self.tool.id}")
func = await self._wait_for_modal_function_deployment()
logger.info(f"Modal function found successfully for app {self.tool.id}, function {str(func)}")
logger.info(f"Calling with arguments {self.args}")
# TODO: use another mechanism to pass through the key
if additional_env_vars is None:
letta_api_key = None
else:
letta_api_key = additional_env_vars.get("LETTA_API_KEY", None)
# Construct dynamic env vars
# Priority order (later overrides earlier):
# 1. Sandbox-level env vars (from database)
# 2. Agent-specific env vars
# 3. Additional runtime env vars
env_vars = {}
# Load sandbox-level environment variables from the database
# These can be updated after deployment and will be available at runtime
if self.provided_sandbox_env_vars:
env_vars.update(self.provided_sandbox_env_vars)
else:
try:
from letta.services.sandbox_config_manager import SandboxConfigManager
sandbox_config_manager = SandboxConfigManager()
sandbox_config = await sandbox_config_manager.get_or_create_default_sandbox_config_async(
sandbox_type=SandboxType.MODAL, actor=self.user
)
# Create Modal app with the tool function registered
app = await self._fetch_or_create_modal_app(sbx_config, envs)
# Execute the tool remotely
with app.run():
# app = modal.Cls.from_name(app.name, "NodeShimServer")()
result = app.remote_executor.remote(execution_script, envs)
# Process the result
if result["error"]:
# Tool errors are expected behavior - tools can raise exceptions as part of their normal operation
# Only log at debug level to avoid triggering Sentry alerts for expected errors
logger.debug(f"Tool {self.tool_name} raised a {result['error']['name']}: {result['error']['value']}")
logger.debug(f"Traceback from Modal sandbox: \n{result['error']['traceback']}")
func_return = get_friendly_error_msg(
function_name=self.tool_name, exception_name=result["error"]["name"], exception_message=result["error"]["value"]
if sandbox_config:
sandbox_env_vars = await sandbox_config_manager.get_sandbox_env_vars_as_dict_async(
sandbox_config_id=sandbox_config.id, actor=self.user, limit=None
)
env_vars.update(sandbox_env_vars)
except Exception as e:
logger.warning(f"Could not load sandbox env vars for tool {self.tool_name}: {e}")
# Add agent-specific environment variables (these override sandbox-level)
if agent_state and agent_state.secrets:
for secret in agent_state.secrets:
env_vars[secret.key] = secret.value
# Add any additional env vars passed at runtime (highest priority)
if additional_env_vars:
env_vars.update(additional_env_vars)
# Call the modal function (already retrieved above)
# Convert agent_state to dict to avoid cloudpickle serialization issues
agent_state_dict = agent_state.model_dump() if agent_state else None
logger.info(f"Calling function {func} with arguments {self.args}")
result = await func.remote.aio(
tool_name=self.tool_name,
agent_state=agent_state_dict,
agent_id=agent_id,
env_vars=env_vars,
letta_api_key=letta_api_key,
**self.args,
)
logger.info(f"Modal function result: {result}")
# Reconstruct agent_state if it was returned (use original as fallback)
result_agent_state = agent_state
if result.get("agent_state"):
if isinstance(result["agent_state"], dict):
try:
from letta.schemas.agent import AgentState
result_agent_state = AgentState.model_validate(result["agent_state"])
except Exception as e:
logger.warning(f"Failed to reconstruct AgentState: {e}, using original")
else:
result_agent_state = result["agent_state"]
return ToolExecutionResult(
func_return=result["result"],
agent_state=result_agent_state,
stdout=[result["stdout"]],
stderr=[result["stderr"]],
status="error" if result["error"] else "success",
)
except Exception as e:
log_event(
"modal_execution_failed",
{
"tool": self.tool_name,
"app_name": self._app_name,
"error_type": result["error"]["name"],
"error_message": result["error"]["value"],
"func_return": func_return,
"modal_app_id": self.tool.id,
"error": str(e),
},
)
# Parse the result from stdout even if there was an error
# (in case the function returned something before failing)
agent_state = None # Initialize agent_state
try:
func_return_parsed, agent_state_parsed = parse_stdout_best_effort(result["stdout"])
if func_return_parsed is not None:
func_return = func_return_parsed
agent_state = agent_state_parsed
except Exception:
# If parsing fails, keep the error message
pass
else:
func_return, agent_state = parse_stdout_best_effort(result["stdout"])
log_event(
"modal_execution_succeeded",
{
"tool": self.tool_name,
"app_name": self._app_name,
"func_return": func_return,
},
)
logger.error(f"Modal execution failed for tool {self.tool_name} {self.tool.id}: {e}")
return ToolExecutionResult(
func_return=func_return,
func_return=None,
agent_state=agent_state,
stdout=[result["stdout"]] if result["stdout"] else [],
stderr=[result["stderr"]] if result["stderr"] else [],
status="error" if result["error"] else "success",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
except Exception as e:
logger.error(f"Modal execution for tool {self.tool_name} encountered an error: {e}")
func_return = get_friendly_error_msg(
function_name=self.tool_name,
exception_name=type(e).__name__,
exception_message=str(e),
)
log_event(
"modal_execution_error",
{
"tool": self.tool_name,
"app_name": self._app_name,
"error": str(e),
"func_return": func_return,
},
)
return ToolExecutionResult(
func_return=func_return,
agent_state=None,
stdout=[],
stdout=[""],
stderr=[str(e)],
status="error",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
def _get_modal_image(self, sbx_config: SandboxConfig) -> modal.Image:
"""Get Modal image with required public python dependencies.
Caching and rebuilding is handled in a cascading manner
https://modal.com/docs/guide/images#image-caching-and-rebuilds
"""
image = modal.Image.debian_slim(python_version="3.12")
all_requirements = ["letta"]
# Add sandbox-specific pip requirements
modal_configs = sbx_config.get_modal_config()
if modal_configs.pip_requirements:
all_requirements.extend([str(req) for req in modal_configs.pip_requirements])
# Add tool-specific pip requirements
if self.tool and self.tool.pip_requirements:
all_requirements.extend([str(req) for req in self.tool.pip_requirements])
if all_requirements:
image = image.pip_install(*all_requirements)
return image
def use_top_level_await(self) -> bool:
"""
Modal functions don't have an active event loop by default,
so we should use asyncio.run() like local execution.
"""
return False
class TypescriptToolSandboxModal(AsyncToolSandboxModal):
"""Modal sandbox implementation for TypeScript tools."""
@trace_method
async def run(
self,
agent_state: Optional[AgentState] = None,
additional_env_vars: Optional[Dict] = None,
) -> ToolExecutionResult:
"""Run TypeScript tool in Modal sandbox using Node.js server."""
if self.provided_sandbox_config:
sbx_config = self.provided_sandbox_config
else:
sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async(
sandbox_type=SandboxType.MODAL, actor=self.user
)
envs = await self._gather_env_vars(agent_state, additional_env_vars or {}, sbx_config.id, is_local=False)
# Generate execution script (JSON args for TypeScript)
json_args = await self.generate_execution_script(agent_state=agent_state)
try:
log_event(
"modal_typescript_execution_started",
{"tool": self.tool_name, "app_name": self._app_name, "args": json_args},
)
# Create Modal app with the TypeScript Node.js server
app = await self._fetch_or_create_modal_app(sbx_config, envs)
# Execute the TypeScript tool remotely via the Node.js server
with app.run():
# Get the NodeShimServer class from Modal
node_server = modal.Cls.from_name(self._app_name, "NodeShimServer")
# Call the remote_executor method with the JSON arguments
# The server will parse the JSON and call the TypeScript function
result = node_server().remote_executor.remote(json_args)
# Process the TypeScript execution result
if isinstance(result, dict) and "error" in result:
# Handle errors from TypeScript execution
logger.debug(f"TypeScript tool {self.tool_name} raised an error: {result['error']}")
func_return = get_friendly_error_msg(
function_name=self.tool_name,
exception_name="TypeScriptError",
exception_message=str(result["error"]),
)
log_event(
"modal_typescript_execution_failed",
{
"tool": self.tool_name,
"app_name": self._app_name,
"error": result["error"],
"func_return": func_return,
},
)
return ToolExecutionResult(
func_return=func_return,
agent_state=None, # TypeScript tools don't support agent_state yet
stdout=[],
stderr=[str(result["error"])],
status="error",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
else:
# Success case - TypeScript function returned a result
func_return = str(result) if result is not None else ""
log_event(
"modal_typescript_execution_succeeded",
{
"tool": self.tool_name,
"app_name": self._app_name,
"func_return": func_return,
},
)
return ToolExecutionResult(
func_return=func_return,
agent_state=None, # TypeScript tools don't support agent_state yet
stdout=[],
stderr=[],
status="success",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
except Exception as e:
logger.error(f"Modal TypeScript execution for tool {self.tool_name} encountered an error: {e}")
func_return = get_friendly_error_msg(
function_name=self.tool_name,
exception_name=type(e).__name__,
exception_message=str(e),
)
log_event(
"modal_typescript_execution_error",
{
"tool": self.tool_name,
"app_name": self._app_name,
"error": str(e),
"func_return": func_return,
},
)
return ToolExecutionResult(
func_return=func_return,
agent_state=None,
stdout=[],
stderr=[str(e)],
status="error",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
async def _fetch_or_create_modal_app(self, sbx_config: SandboxConfig, env_vars: Dict[str, str]) -> modal.App:
"""Create or fetch a Modal app with TypeScript execution capabilities."""
try:
return await modal.App.lookup.aio(self._app_name)
except:
app = modal.App(self._app_name)
modal_config = sbx_config.get_modal_config()
# Get the base image with dependencies
image = self._get_modal_image(sbx_config)
# Import the NodeShimServer that will handle TypeScript execution
from sandbox.node_server import NodeShimServer
# Register the NodeShimServer class with Modal
# This creates a serverless function that can handle concurrent requests
app.cls(image=image, restrict_modal_access=True, include_source=False, timeout=modal_config.timeout if modal_config else 60)(
modal.concurrent(max_inputs=100, target_inputs=50)(NodeShimServer)
)
# Deploy the app to Modal
with modal.enable_output():
await app.deploy.aio()
return app
async def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str:
"""Generate the execution script for TypeScript tools.
For TypeScript tools, this returns the JSON-encoded arguments that will be passed
to the Node.js server via the remote_executor method.
"""
import json
# Convert args to JSON string for TypeScript execution
# The Node.js server expects JSON-encoded arguments
return json.dumps(self.args)
def _get_modal_image(self, sbx_config: SandboxConfig) -> modal.Image:
"""Build a Modal image with Node.js, TypeScript, and the user's tool function."""
import importlib.util
from pathlib import Path
# Find the sandbox module location
spec = importlib.util.find_spec("sandbox")
if not spec or not spec.origin:
raise ValueError("Could not find sandbox module")
server_dir = Path(spec.origin).parent
# Get the TypeScript function source code
if not self.tool or not self.tool.source_code:
raise ValueError("TypeScript tool must have source code")
ts_function = self.tool.source_code
# Get npm dependencies from sandbox config and tool
modal_config = sbx_config.get_modal_config()
npm_dependencies = []
# Add dependencies from sandbox config
if modal_config and modal_config.npm_requirements:
npm_dependencies.extend(modal_config.npm_requirements)
# Add dependencies from the tool itself
if self.tool.npm_requirements:
npm_dependencies.extend(self.tool.npm_requirements)
# Build npm install command for user dependencies
user_dependencies_cmd = ""
if npm_dependencies:
# Ensure unique dependencies
unique_deps = list(set(npm_dependencies))
user_dependencies_cmd = " && npm install " + " ".join(unique_deps)
# Escape single quotes in the TypeScript function for shell command
escaped_ts_function = ts_function.replace("'", "'\\''")
# Build the Docker image with Node.js and TypeScript
image = (
modal.Image.from_registry("node:22-slim", add_python="3.12")
.add_local_dir(server_dir, "/root/sandbox", ignore=["node_modules", "build"], copy=True)
.run_commands(
# Install dependencies and build the TypeScript server
f"cd /root/sandbox/resources/server && npm install{user_dependencies_cmd}",
# Write the user's TypeScript function to a file
f"echo '{escaped_ts_function}' > /root/sandbox/user-function.ts",
)
)
return image
# probably need to do parse_stdout_best_effort
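
The environment-variable priority described in run() above (later updates win) reduces to three dict merges; a standalone sketch:

sandbox_env_vars = {"secret_word": "from-sandbox-config"}  # 1. sandbox-level (database)
agent_env_vars = {"secret_word": "from-agent-secrets"}     # 2. agent-specific secrets
additional_env_vars = {}                                   # 3. runtime overrides (highest)

env_vars: dict[str, str] = {}
env_vars.update(sandbox_env_vars)
env_vars.update(agent_env_vars)
env_vars.update(additional_env_vars)
assert env_vars["secret_word"] == "from-agent-secrets"     # the agent value wins here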

View File

@@ -39,12 +39,20 @@ class ToolSettings(BaseSettings):
mcp_read_from_config: bool = False # if False, will throw if attempting to read/write from file
mcp_disable_stdio: bool = False
@property
def modal_sandbox_enabled(self) -> bool:
"""Check if Modal credentials are configured."""
return bool(self.modal_token_id and self.modal_token_secret)
@property
def sandbox_type(self) -> SandboxType:
"""Default sandbox type based on available credentials.
Note: Modal is checked separately via modal_sandbox_enabled property.
This property determines the fallback behavior (E2B or LOCAL).
"""
if self.e2b_api_key:
return SandboxType.E2B
# elif self.modal_token_id and self.modal_token_secret:
# return SandboxType.MODAL
else:
return SandboxType.LOCAL
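
How the two properties above combine at call sites, as a minimal sketch (using the settings object from this diff):

from letta.settings import tool_settings

if tool_settings.modal_sandbox_enabled:
    print("Modal credentials present; Modal is tried first for opted-in tools")
print(f"Fallback sandbox: {tool_settings.sandbox_type}")  # E2B if key is set, else LOCAL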

View File

@@ -169,6 +169,15 @@ def check_e2b_key_is_set():
yield
@pytest.fixture
def check_modal_key_is_set():
from letta.settings import tool_settings
assert tool_settings.modal_token_id is not None, "Missing modal token id! Cannot execute these tests."
assert tool_settings.modal_token_secret is not None, "Missing modal token secret! Cannot execute these tests."
yield
@pytest.fixture
async def default_organization():
"""Fixture to create and return the default organization."""

View File

@@ -0,0 +1,770 @@
import os
import secrets
import string
import uuid
from pathlib import Path
from unittest.mock import patch
import pytest
from sqlalchemy import delete
from letta.config import LettaConfig
from letta.functions.function_sets.base import core_memory_append, core_memory_replace
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
from letta.schemas.organization import Organization
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.sandbox_config import LocalSandboxConfig, ModalSandboxConfig, SandboxConfigCreate
from letta.schemas.user import User
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal
from letta.services.user_manager import UserManager
from tests.helpers.utils import create_tool_from_func
# Constants
namespace = uuid.NAMESPACE_DNS
org_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-org"))
user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user"))
# Set environment variable immediately to prevent pooling issues
os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true"
# Disable SQLAlchemy connection pooling for tests to prevent event loop issues
@pytest.fixture(scope="session", autouse=True)
def disable_db_pooling_for_tests():
"""Disable database connection pooling for the entire test session."""
# Environment variable is already set above and settings reloaded
yield
# Clean up environment variable after tests
if "LETTA_DISABLE_SQLALCHEMY_POOLING" in os.environ:
del os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"]
# @pytest.fixture(autouse=True)
# async def cleanup_db_connections():
# """Cleanup database connections after each test."""
# yield
#
# # Dispose async engines in the current event loop
# try:
# await close_db()
# except Exception as e:
# # Log the error but don't fail the test
# print(f"Warning: Failed to cleanup database connections: {e}")
# Fixtures
@pytest.fixture(scope="module")
def server():
"""
Creates a SyncServer instance for testing.
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=True)
# create user/org
yield server
@pytest.fixture(autouse=True)
async def clear_tables():
"""Fixture to clear the organization table before each test."""
from letta.server.db import db_registry
async with db_registry.async_session() as session:
await session.execute(delete(SandboxEnvironmentVariable))
await session.execute(delete(SandboxConfig))
await session.commit() # Commit the deletion
@pytest.fixture
async def test_organization():
"""Fixture to create and return the default organization."""
org = await OrganizationManager().create_organization_async(Organization(name=org_name))
yield org
@pytest.fixture
async def test_user(test_organization):
"""Fixture to create and return the default user within the default organization."""
user = await UserManager().create_actor_async(User(name=user_name, organization_id=test_organization.id))
yield user
@pytest.fixture
async def add_integers_tool(test_user):
def add(x: int, y: int) -> int:
"""
Simple function that adds two integers.
Parameters:
x (int): The first integer to add.
y (int): The second integer to add.
Returns:
int: The result of adding x and y.
"""
return x + y
tool = create_tool_from_func(add)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def cowsay_tool(test_user):
# This defines a tool for a package we definitely do NOT have in letta
# If this test passes, that means the tool was correctly executed in a separate Python environment
def cowsay() -> str:
"""
Simple function that uses the cowsay package to print out the secret word env variable.
Returns:
str: The cowsay ASCII art.
"""
import os
import cowsay
cowsay.cow(os.getenv("secret_word"))
tool = create_tool_from_func(cowsay)
# Add cowsay as a pip requirement for Modal
tool.pip_requirements = [PipRequirement(name="cowsay")]
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def get_env_tool(test_user):
def get_env() -> str:
"""
Simple function that returns the secret word env variable.
Returns:
str: The secret word
"""
import os
secret_word = os.getenv("secret_word")
print(secret_word)
return secret_word
tool = create_tool_from_func(get_env)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def get_warning_tool(test_user):
def warn_hello_world() -> str:
"""
Simple function that warns hello world.
Returns:
str: hello world
"""
import warnings
msg = "Hello World"
warnings.warn(msg)
return msg
tool = create_tool_from_func(warn_hello_world)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def always_err_tool(test_user):
def error() -> str:
"""
Simple function that errors
Returns:
str: not important
"""
# Raise an unusual error so we know it's from this function
print("Going to error now")
raise ZeroDivisionError("This is an intentionally weird division!")
tool = create_tool_from_func(error)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def list_tool(test_user):
def create_list():
"""Simple function that returns a list"""
return [1] * 5
tool = create_tool_from_func(create_list)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def clear_core_memory_tool(test_user):
def clear_memory(agent_state: "AgentState"):
"""Clear the core memory"""
agent_state.memory.get_block("human").value = ""
agent_state.memory.get_block("persona").value = ""
tool = create_tool_from_func(clear_memory)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def external_codebase_tool(test_user):
from tests.test_tool_sandbox.restaurant_management_system.adjust_menu_prices import adjust_menu_prices
tool = create_tool_from_func(adjust_menu_prices)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def agent_state(server: SyncServer):
await server.init_async(init_with_default_org_and_user=True)
actor = await server.user_manager.create_default_actor_async()
agent_state = await server.create_agent_async(
CreateAgent(
memory_blocks=[
CreateBlock(
label="human",
value="username: sarah",
),
CreateBlock(
label="persona",
value="This is the persona",
),
],
include_base_tools=True,
model="openai/gpt-4o-mini",
tags=["test_agents"],
embedding="letta/letta-free",
),
actor=actor,
)
agent_state.tool_rules = []
yield agent_state
@pytest.fixture
async def custom_test_sandbox_config(test_user):
"""
Fixture to create a consistent local sandbox configuration for tests.
Args:
test_user: The test user to be used for creating the sandbox configuration.
Returns:
A tuple containing the SandboxConfigManager and the created sandbox configuration.
"""
# Create the SandboxConfigManager
manager = SandboxConfigManager()
# Set the sandbox to be within the external codebase path and use a venv
external_codebase_path = str(Path(__file__).parent / "test_tool_sandbox" / "restaurant_management_system")
# tqdm is used in this codebase, but NOT in the requirements.txt; this tests that we can successfully install pip requirements
local_sandbox_config = LocalSandboxConfig(
sandbox_dir=external_codebase_path, use_venv=True, pip_requirements=[PipRequirement(name="tqdm")]
)
# Create the sandbox configuration
config_create = SandboxConfigCreate(config=local_sandbox_config.model_dump())
# Create or update the sandbox configuration
await manager.create_or_update_sandbox_config_async(sandbox_config_create=config_create, actor=test_user)
return manager, local_sandbox_config
@pytest.fixture
async def core_memory_tools(test_user):
"""Create all base tools for testing."""
tools = {}
for func in [
core_memory_replace,
core_memory_append,
]:
tool = create_tool_from_func(func)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
tools[func.__name__] = tool
yield tools
@pytest.fixture
async def async_add_integers_tool(test_user):
async def async_add(x: int, y: int) -> int:
"""
Async function that adds two integers.
Parameters:
x (int): The first integer to add.
y (int): The second integer to add.
Returns:
int: The result of adding x and y.
"""
import asyncio
# Add a small delay to simulate async work
await asyncio.sleep(0.1)
return x + y
tool = create_tool_from_func(async_add)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def async_get_env_tool(test_user):
async def async_get_env() -> str:
"""
Async function that returns the secret word env variable.
Returns:
str: The secret word
"""
import asyncio
import os
# Add a small delay to simulate async work
await asyncio.sleep(0.1)
secret_word = os.getenv("secret_word")
print(secret_word)
return secret_word
tool = create_tool_from_func(async_get_env)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def async_stateful_tool(test_user):
async def async_clear_memory(agent_state: "AgentState"):
"""Async function that clears the core memory"""
import asyncio
# Add a small delay to simulate async work
await asyncio.sleep(0.1)
agent_state.memory.get_block("human").value = ""
agent_state.memory.get_block("persona").value = ""
tool = create_tool_from_func(async_clear_memory)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def async_error_tool(test_user):
async def async_error() -> str:
"""
Async function that errors
Returns:
str: not important
"""
import asyncio
# Add some async work before erroring
await asyncio.sleep(0.1)
print("Going to error now")
raise ValueError("This is an intentional async error!")
tool = create_tool_from_func(async_error)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def async_list_tool(test_user):
async def async_create_list() -> list:
"""Async function that returns a list"""
import asyncio
await asyncio.sleep(0.05)
return [1, 2, 3, 4, 5]
tool = create_tool_from_func(async_create_list)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def tool_with_pip_requirements(test_user):
def use_requests_and_numpy() -> str:
"""
Function that uses requests and numpy packages to test tool-specific pip requirements.
Returns:
str: Success message if packages are available.
"""
try:
import numpy as np
import requests
# Simple usage to verify packages work
response = requests.get("https://httpbin.org/json", timeout=30)
arr = np.array([1, 2, 3])
return f"Success! Status: {response.status_code}, Array sum: {np.sum(arr)}"
except ImportError as e:
return f"Import error: {e}"
except Exception as e:
return f"Other error: {e}"
tool = create_tool_from_func(use_requests_and_numpy)
# Add pip requirements to the tool
tool.pip_requirements = [
PipRequirement(name="requests", version="2.31.0"),
PipRequirement(name="numpy"),
]
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
@pytest.fixture
async def async_complex_tool(test_user):
async def async_complex_computation(iterations: int = 3) -> dict:
"""
Async function that performs complex computation with multiple awaits.
Parameters:
iterations (int): Number of iterations to perform.
Returns:
dict: Results of the computation.
"""
import asyncio
import time
results = []
start_time = time.time()
for i in range(iterations):
# Simulate async I/O
await asyncio.sleep(0.1)
results.append(i * 2)
end_time = time.time()
return {
"results": results,
"duration": end_time - start_time,
"iterations": iterations,
"average": sum(results) / len(results) if results else 0,
}
tool = create_tool_from_func(async_complex_computation)
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
yield tool
# Modal sandbox tests
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_default(check_modal_key_is_set, add_integers_tool, test_user):
args = {"x": 10, "y": 5}
# Mock and assert correct pathway was invoked
with patch.object(AsyncToolSandboxModal, "run") as mock_run:
sandbox = AsyncToolSandboxModal(add_integers_tool.name, args, user=test_user)
await sandbox.run()
mock_run.assert_called_once()
# Run again to get actual response
sandbox = AsyncToolSandboxModal(add_integers_tool.name, args, user=test_user)
result = await sandbox.run()
assert int(result.func_return) == args["x"] + args["y"]
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_pip_installs(check_modal_key_is_set, cowsay_tool, test_user):
"""Test that Modal sandbox installs tool-level pip requirements."""
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump())
config = await manager.create_or_update_sandbox_config_async(config_create, test_user)
key = "secret_word"
long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20))
await manager.create_sandbox_env_var_async(
SandboxEnvironmentVariableCreate(key=key, value=long_random_string),
sandbox_config_id=config.id,
actor=test_user,
)
sandbox = AsyncToolSandboxModal(cowsay_tool.name, {}, user=test_user)
result = await sandbox.run()
assert long_random_string in result.stdout[0]
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_stateful_tool(check_modal_key_is_set, clear_core_memory_tool, test_user, agent_state):
sandbox = AsyncToolSandboxModal(clear_core_memory_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert result.agent_state.memory.get_block("human").value == ""
assert result.agent_state.memory.get_block("persona").value == ""
assert result.func_return is None
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_inject_env_var_existing_sandbox(check_modal_key_is_set, get_env_tool, test_user):
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump())
config = await manager.create_or_update_sandbox_config_async(config_create, test_user)
sandbox = AsyncToolSandboxModal(get_env_tool.name, {}, user=test_user)
result = await sandbox.run()
assert result.func_return is None
key = "secret_word"
long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20))
await manager.create_sandbox_env_var_async(
SandboxEnvironmentVariableCreate(key=key, value=long_random_string),
sandbox_config_id=config.id,
actor=test_user,
)
sandbox = AsyncToolSandboxModal(get_env_tool.name, {}, user=test_user)
result = await sandbox.run()
assert long_random_string in result.func_return
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_per_agent_env(check_modal_key_is_set, get_env_tool, agent_state, test_user):
manager = SandboxConfigManager()
key = "secret_word"
wrong_val = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20))
correct_val = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20))
config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump())
config = await manager.create_or_update_sandbox_config_async(config_create, test_user)
await manager.create_sandbox_env_var_async(
SandboxEnvironmentVariableCreate(key=key, value=wrong_val),
sandbox_config_id=config.id,
actor=test_user,
)
agent_state.secrets = [AgentEnvironmentVariable(key=key, value=correct_val, agent_id=agent_state.id)]
sandbox = AsyncToolSandboxModal(get_env_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert wrong_val not in result.func_return
assert correct_val in result.func_return
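# The test above pins down precedence: per-agent secrets shadow sandbox-level
# environment variables that share a key. A sketch of that merge order, assuming a
# plain dict merge where later entries win (the real merge lives in the sandbox
# implementation, not in this hunk):
def _sketch_merge_env(sandbox_env: dict, agent_env: dict) -> dict:
    """Agent-level values override sandbox-level values for duplicate keys (sketch)."""
    return {**sandbox_env, **agent_env}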
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_with_list_rv(check_modal_key_is_set, list_tool, test_user):
sandbox = AsyncToolSandboxModal(list_tool.name, {}, user=test_user)
result = await sandbox.run()
assert len(result.func_return) == 5
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_with_tool_pip_requirements(check_modal_key_is_set, tool_with_pip_requirements, test_user):
"""Test that Modal sandbox installs tool-specific pip requirements."""
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump())
await manager.create_or_update_sandbox_config_async(config_create, test_user)
sandbox = AsyncToolSandboxModal(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements)
result = await sandbox.run()
# Should succeed since tool pip requirements were installed
assert "Success!" in result.func_return
assert "Status: 200" in result.func_return
assert "Array sum: 6" in result.func_return
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_with_mixed_pip_requirements(check_modal_key_is_set, tool_with_pip_requirements, test_user):
"""Test that Modal sandbox installs tool pip requirements.
Note: Modal does not support sandbox-level pip requirements - all pip requirements
must be specified at the tool level since the Modal app is deployed with a fixed image.
"""
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump())
await manager.create_or_update_sandbox_config_async(config_create, test_user)
sandbox = AsyncToolSandboxModal(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements)
result = await sandbox.run()
# Should succeed since tool pip requirements were installed
assert "Success!" in result.func_return
assert "Status: 200" in result.func_return
assert "Array sum: 6" in result.func_return
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_with_broken_tool_pip_requirements_error_handling(check_modal_key_is_set, test_user):
"""Test that Modal sandbox provides informative error messages for broken tool pip requirements."""
def use_broken_package() -> str:
"""
Function that tries to use packages with broken version constraints.
Returns:
str: Success message if packages are available.
"""
return "Should not reach here"
tool = create_tool_from_func(use_broken_package)
# Add broken pip requirements
tool.pip_requirements = [
PipRequirement(name="numpy", version="1.24.0"), # Old version incompatible with newer Python
PipRequirement(name="nonexistent-package-12345"), # Non-existent package
]
# Creating the tool should raise LettaInvalidArgumentError during requirement validation
from letta.errors import LettaInvalidArgumentError
with pytest.raises(LettaInvalidArgumentError):
tool = await ToolManager().create_or_update_tool_async(tool, test_user)
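# Note that the failure surfaces here, at tool creation time, rather than later inside
# sandbox.run(), so broken requirements are rejected before any execution is attempted.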
@pytest.mark.asyncio
async def test_async_function_detection(add_integers_tool, async_add_integers_tool, test_user):
"""Test that async function detection works correctly"""
# Test sync function detection
sync_sandbox = AsyncToolSandboxModal(add_integers_tool.name, {}, test_user, tool_object=add_integers_tool)
await sync_sandbox._init_async()
assert not sync_sandbox.is_async_function
# Test async function detection
async_sandbox = AsyncToolSandboxModal(async_add_integers_tool.name, {}, test_user, tool_object=async_add_integers_tool)
await async_sandbox._init_async()
assert async_sandbox.is_async_function
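# The detection logic itself is not part of this hunk. A common approach, and a
# reasonable guess at what is_async_function reflects, is the standard-library
# coroutine check (this is an assumption about the implementation, not a quote of it):
def _sketch_is_async(func) -> bool:
    """True for `async def` functions, False for plain ones (sketch only)."""
    import inspect

    return inspect.iscoroutinefunction(func)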
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_function_execution(check_modal_key_is_set, async_add_integers_tool, test_user):
"""Test that async functions execute correctly in Modal sandbox"""
args = {"x": 20, "y": 30}
sandbox = AsyncToolSandboxModal(async_add_integers_tool.name, args, user=test_user)
result = await sandbox.run()
assert int(result.func_return) == args["x"] + args["y"]
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_complex_computation(check_modal_key_is_set, async_complex_tool, test_user):
"""Test complex async computation with multiple awaits in Modal sandbox"""
args = {"iterations": 2}
sandbox = AsyncToolSandboxModal(async_complex_tool.name, args, user=test_user)
result = await sandbox.run()
func_return = result.func_return
assert isinstance(func_return, dict)
assert func_return["results"] == [0, 2]
assert func_return["iterations"] == 2
assert func_return["average"] == 1.0
assert func_return["duration"] > 0.15
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_list_return(check_modal_key_is_set, async_list_tool, test_user):
"""Test async function returning list in Modal sandbox"""
sandbox = AsyncToolSandboxModal(async_list_tool.name, {}, user=test_user)
result = await sandbox.run()
assert result.func_return == [1, 2, 3, 4, 5]
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_with_env_vars(check_modal_key_is_set, async_get_env_tool, test_user):
"""Test async function with environment variables in Modal sandbox"""
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump())
config = await manager.create_or_update_sandbox_config_async(config_create, test_user)
# Create environment variable
key = "secret_word"
test_value = "async_modal_test_value_456"
await manager.create_sandbox_env_var_async(
SandboxEnvironmentVariableCreate(key=key, value=test_value), sandbox_config_id=config.id, actor=test_user
)
sandbox = AsyncToolSandboxModal(async_get_env_tool.name, {}, user=test_user)
result = await sandbox.run()
assert test_value in result.func_return
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_with_agent_state(check_modal_key_is_set, async_stateful_tool, test_user, agent_state):
"""Test async function with agent state in Modal sandbox"""
sandbox = AsyncToolSandboxModal(async_stateful_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert result.agent_state.memory.get_block("human").value == ""
assert result.agent_state.memory.get_block("persona").value == ""
assert result.func_return is None
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_error_handling(check_modal_key_is_set, async_error_tool, test_user):
"""Test async function error handling in Modal sandbox"""
sandbox = AsyncToolSandboxModal(async_error_tool.name, {}, user=test_user)
result = await sandbox.run()
# Check that the error was captured on both output streams
assert len(result.stdout) != 0, "stdout should not be empty"
assert "error" in result.stdout[0], "stdout should contain the printed error text"
assert len(result.stderr) != 0, "stderr should not be empty"
assert "ValueError: This is an intentional async error!" in result.stderr[0], "stderr should contain the expected error"
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_per_agent_env(check_modal_key_is_set, async_get_env_tool, agent_state, test_user):
"""Test async function with per-agent environment variables in Modal sandbox"""
manager = SandboxConfigManager()
key = "secret_word"
wrong_val = "wrong_async_modal_value"
correct_val = "correct_async_modal_value"
config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump())
config = await manager.create_or_update_sandbox_config_async(config_create, test_user)
await manager.create_sandbox_env_var_async(
SandboxEnvironmentVariableCreate(key=key, value=wrong_val),
sandbox_config_id=config.id,
actor=test_user,
)
agent_state.secrets = [AgentEnvironmentVariable(key=key, value=correct_val, agent_id=agent_state.id)]
sandbox = AsyncToolSandboxModal(async_get_env_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert wrong_val not in result.func_return
assert correct_val in result.func_return

View File

@@ -11,6 +11,7 @@ filterwarnings =
markers =
local_sandbox: mark test as part of local sandbox tests
e2b_sandbox: mark test as part of E2B sandbox tests
modal_sandbox: mark test as part of Modal sandbox tests
openai_basic: Tests for OpenAI endpoints
anthropic_basic: Tests for Anthropic endpoints
azure_basic: Tests for Azure endpoints
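# With the marker registered above, Modal-only test runs can be selected from the
# command line, e.g. `pytest -m modal_sandbox`, or skipped with `-m "not modal_sandbox"`.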

View File

@@ -395,7 +395,7 @@ def test_function_always_error(client: Letta):
assert response_message.status == "error"
# TODO: add this back
# assert "Error executing function testing_method" in response_message.tool_return, response_message.tool_return
assert "ZeroDivisionError: division by zero" in response_message.stderr[0]
assert "division by zero" in response_message.stderr[0]
client.agents.delete(agent_id=agent.id)