feat: modal tool execution - NO FEATURE FLAGS USES MODAL [LET-4357] (#5120)
* initial commit * add delay to deploy * fix tests * add tests * passing tests * cleanup * and use modal * working on modal * gate on tool metadata * agent state * cleanup --------- Co-authored-by: Letta Bot <noreply@letta.com>
This commit is contained in:
committed by
Caren Thomas
parent
2a8523aa01
commit
5730f69ecf
@@ -44,6 +44,11 @@ IN_CONTEXT_MEMORY_KEYWORD = "CORE_MEMORY"
|
||||
# OpenAI error message: Invalid 'messages[1].tool_calls[0].id': string too long. Expected a string with maximum length 29, but got a string with length 36 instead.
|
||||
TOOL_CALL_ID_MAX_LEN = 29
|
||||
|
||||
# Maximum length for tool names to support Modal deployment
|
||||
# Modal function names are limited to 64 characters: tool_name + "_" + project_id
|
||||
# Reserving 16 characters for project_id suffix (e.g., "_project-12345678")
|
||||
MAX_TOOL_NAME_LENGTH = 48
|
||||
|
||||
# Max steps for agent loop
|
||||
DEFAULT_MAX_STEPS = 50
|
||||
|
||||
@@ -440,5 +445,18 @@ EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES = ["claude-4-sonnet", "claude-3-5-so
|
||||
# But include models with these keywords in base tool rules (overrides exclusion)
|
||||
INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES = ["mini"]
|
||||
|
||||
# Deployment and versioning
|
||||
MODAL_DEFAULT_TOOL_NAME = "modal_tool_wrapper.<locals>.modal_function" # NOTE: must stay in sync with modal_tool_wrapper
|
||||
MODAL_DEFAULT_CONFIG_KEY = "default"
|
||||
MODAL_MODAL_DEPLOYMENTS_KEY = "modal_deployments"
|
||||
MODAL_VERSION_HASH_LENGTH = 12
|
||||
|
||||
# Modal execution settings
|
||||
MODAL_DEFAULT_TIMEOUT = 60
|
||||
MODAL_DEFAULT_MAX_CONCURRENT_INPUTS = 1
|
||||
MODAL_DEFAULT_PYTHON_VERSION = "3.12"
|
||||
|
||||
# Security settings
|
||||
MODAL_SAFE_IMPORT_MODULES = {"typing", "pydantic", "datetime", "enum", "uuid", "decimal"}
|
||||
# Default handle for model used to generate tools
|
||||
DEFAULT_GENERATE_TOOL_MODEL_HANDLE = "openai/gpt-4.1"
|
||||
|
||||
69
letta/helpers/tool_helpers.py
Normal file
69
letta/helpers/tool_helpers.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import hashlib
|
||||
|
||||
from letta.constants import MODAL_VERSION_HASH_LENGTH
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
|
||||
def _serialize_dependencies(tool: Tool) -> str:
    """
    Serialize a tool's pip/npm dependencies into a stable string for hashing.

    Requirements are rendered via ``str()`` and sorted, so the result is
    independent of declaration order (LET-3770): reordering a tool's
    dependencies no longer produces a spurious hash change (and hence a
    spurious Modal redeployment).

    Args:
        tool: Tool whose ``pip_requirements`` / ``npm_requirements`` are serialized.

    Returns:
        A string like ``"pip:[...];npm:[...]"``; empty when there are no deps.
    """
    parts = []

    if tool.pip_requirements:
        # Sort the string forms for a declaration-order-independent result.
        parts.append(f"pip:{sorted(str(req) for req in tool.pip_requirements)}")

    if tool.npm_requirements:
        parts.append(f"npm:{sorted(str(req) for req in tool.npm_requirements)}")

    return ";".join(parts)
|
||||
|
||||
|
||||
def compute_tool_hash(tool: Tool):
    """
    Calculate a short hash identifying the current version of a tool.

    The digest changes whenever the tool's source code, source type, or
    serialized dependencies change. The full SHA-256 hex digest is truncated
    to MODAL_VERSION_HASH_LENGTH characters.
    """
    # Join the version-relevant components into a single fingerprint string.
    fingerprint = "|".join(
        [
            tool.source_code or "",
            tool.source_type or "",
            _serialize_dependencies(tool),
        ]
    )
    return hashlib.sha256(fingerprint.encode()).hexdigest()[:MODAL_VERSION_HASH_LENGTH]
|
||||
|
||||
|
||||
def generate_modal_function_name(tool_name: str, organization_id: str, project_id: str = "default") -> str:
    """
    Generate a Modal function name from the tool name and organization ID.

    Modal app/function names are limited to 64 characters, so the tool name is
    truncated to MAX_TOOL_NAME_LENGTH and the organization ID is shortened to
    fill the remaining budget (minus one character for the "_" separator).

    Args:
        tool_name: Name of the tool.
        organization_id: Organization ID; a shortened prefix of it is appended
            to the generated name.
        project_id: Project ID (currently not used in the generated name; kept
            for future use).

    Returns:
        Modal function name, e.g. ``my_tool_org-12345678``.
    """
    from letta.constants import MAX_TOOL_NAME_LENGTH

    # Modal caps app/function names at 64 characters.
    modal_name_limit = 64

    # Budget for the org suffix: total limit minus the tool-name budget minus
    # one character for the "_" separator.
    short_organization_id = organization_id[: (modal_name_limit - MAX_TOOL_NAME_LENGTH - 1)]

    # Truncate the tool name defensively even though create/update validation
    # should already enforce MAX_TOOL_NAME_LENGTH.
    name = f"{tool_name[:MAX_TOOL_NAME_LENGTH]}_{short_organization_id}"

    return name
|
||||
@@ -4,12 +4,11 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR, MODAL_DEFAULT_TIMEOUT
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import PrimitiveType, SandboxType
|
||||
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.services.tool_sandbox.modal_constants import DEFAULT_MODAL_TIMEOUT
|
||||
from letta.settings import tool_settings
|
||||
|
||||
# Sandbox Config
|
||||
@@ -81,7 +80,7 @@ class E2BSandboxConfig(BaseModel):
|
||||
|
||||
|
||||
class ModalSandboxConfig(BaseModel):
|
||||
timeout: int = Field(DEFAULT_MODAL_TIMEOUT, description="Time limit for the sandbox (in seconds).")
|
||||
timeout: int = Field(MODAL_DEFAULT_TIMEOUT, description="Time limit for the sandbox (in seconds).")
|
||||
pip_requirements: list[str] | None = Field(None, description="A list of pip packages to install in the Modal sandbox")
|
||||
npm_requirements: list[str] | None = Field(None, description="A list of npm packages to install in the Modal sandbox")
|
||||
language: Literal["python", "typescript"] = "python"
|
||||
|
||||
@@ -7,6 +7,7 @@ from httpx import ConnectError, HTTPStatusError
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.constants import MAX_TOOL_NAME_LENGTH
|
||||
from letta.constants import DEFAULT_GENERATE_TOOL_MODEL_HANDLE
|
||||
from letta.errors import (
|
||||
LettaInvalidArgumentError,
|
||||
|
||||
@@ -46,27 +46,19 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
|
||||
agent_state_copy = self._create_agent_state_copy(agent_state) if agent_state else None
|
||||
|
||||
# Execute in sandbox depending on API key
|
||||
if tool_settings.sandbox_type == SandboxType.E2B:
|
||||
from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B
|
||||
# Execute in sandbox with Modal first (if configured and requested), then fallback to E2B/LOCAL
|
||||
# Try Modal if: (1) Modal credentials configured AND (2) tool requests Modal via metadata
|
||||
tool_requests_modal = tool.metadata_ and tool.metadata_.get("sandbox") == "modal"
|
||||
modal_configured = tool_settings.modal_sandbox_enabled
|
||||
|
||||
sandbox = AsyncToolSandboxE2B(
|
||||
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
|
||||
)
|
||||
# TODO (cliandy): this is just for testing right now, separate this out into it's own subclass and handling logic
|
||||
elif tool_settings.sandbox_type == SandboxType.MODAL:
|
||||
from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal, TypescriptToolSandboxModal
|
||||
tool_execution_result = None
|
||||
|
||||
if tool.source_type == ToolSourceType.typescript:
|
||||
sandbox = TypescriptToolSandboxModal(
|
||||
function_name,
|
||||
function_args,
|
||||
actor,
|
||||
tool_object=tool,
|
||||
sandbox_config=sandbox_config,
|
||||
sandbox_env_vars=sandbox_env_vars,
|
||||
)
|
||||
elif tool.source_type == ToolSourceType.python:
|
||||
# Try Modal first if both conditions met
|
||||
if tool_requests_modal and modal_configured:
|
||||
try:
|
||||
from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal
|
||||
|
||||
logger.info(f"Attempting Modal execution for tool {tool.name}")
|
||||
sandbox = AsyncToolSandboxModal(
|
||||
function_name,
|
||||
function_args,
|
||||
@@ -74,12 +66,36 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
tool_object=tool,
|
||||
sandbox_config=sandbox_config,
|
||||
sandbox_env_vars=sandbox_env_vars,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
# TODO: pass through letta api key
|
||||
tool_execution_result = await sandbox.run(agent_state=agent_state_copy, additional_env_vars=sandbox_env_vars)
|
||||
except Exception as e:
|
||||
# Modal execution failed, log and fall back to E2B/LOCAL
|
||||
logger.warning(f"Modal execution failed for tool {tool.name}: {e}. Falling back to {tool_settings.sandbox_type.value}")
|
||||
tool_execution_result = None
|
||||
|
||||
# Fallback to E2B or LOCAL if Modal wasn't tried or failed
|
||||
if tool_execution_result is None:
|
||||
if tool_settings.sandbox_type == SandboxType.E2B:
|
||||
from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B
|
||||
|
||||
sandbox = AsyncToolSandboxE2B(
|
||||
function_name,
|
||||
function_args,
|
||||
actor,
|
||||
tool_object=tool,
|
||||
sandbox_config=sandbox_config,
|
||||
sandbox_env_vars=sandbox_env_vars,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Tool source type was {tool.source_type} but is required to be python or typescript to run in Modal.")
|
||||
else:
|
||||
sandbox = AsyncToolSandboxLocal(
|
||||
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
|
||||
function_name,
|
||||
function_args,
|
||||
actor,
|
||||
tool_object=tool,
|
||||
sandbox_config=sandbox_config,
|
||||
sandbox_env_vars=sandbox_env_vars,
|
||||
)
|
||||
|
||||
tool_execution_result = await sandbox.run(agent_state=agent_state_copy)
|
||||
|
||||
@@ -16,30 +16,176 @@ from letta.constants import (
|
||||
LETTA_TOOL_MODULE_NAMES,
|
||||
LETTA_TOOL_SET,
|
||||
LOCAL_ONLY_MULTI_AGENT_TOOLS,
|
||||
MAX_TOOL_NAME_LENGTH,
|
||||
MCP_TOOL_TAG_NAME_PREFIX,
|
||||
MODAL_DEFAULT_TOOL_NAME,
|
||||
)
|
||||
from letta.errors import LettaToolNameConflictError, LettaToolNameSchemaMismatchError
|
||||
from letta.errors import LettaInvalidArgumentError, LettaToolNameConflictError, LettaToolNameSchemaMismatchError
|
||||
from letta.functions.functions import derive_openai_json_schema, load_function_set
|
||||
from letta.helpers.tool_helpers import compute_tool_hash, generate_modal_function_name
|
||||
from letta.log import get_logger
|
||||
|
||||
# TODO: Remove this once we translate all of these to the ORM
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.tool import Tool as ToolModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import PrimitiveType, ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import PrimitiveType, SandboxType, ToolType
|
||||
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.helpers.agent_manager_helper import calculate_multi_agent_tools
|
||||
from letta.services.mcp.types import SSEServerConfig, StdioServerConfig
|
||||
from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update
|
||||
from letta.settings import settings
|
||||
from letta.settings import settings, tool_settings
|
||||
from letta.utils import enforce_types, printd
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# NOTE: function name and nested modal function decorator name must stay in sync with MODAL_DEFAULT_TOOL_NAME
def modal_tool_wrapper(tool: PydanticTool, actor: PydanticUser, sandbox_env_vars: dict = None, project_id: str = "default"):
    """Create a Modal function wrapper for a tool.

    Builds a Modal App named via generate_modal_function_name and registers a
    single serialized function that exec()s the tool's source code remotely.
    Any sandbox_env_vars are baked into a Modal secret at deploy time.

    Args:
        tool: Tool whose source code will be executed inside Modal.
        actor: User whose organization_id feeds into the app/function name.
        sandbox_env_vars: Optional env vars baked into the Modal secret.
        project_id: Project ID passed through to the function-name generator.

    Returns:
        The configured (not yet deployed) modal.App.
    """
    import contextlib
    import io
    import os
    import sys
    from typing import Optional

    import modal
    from letta_client import Letta

    # Image package list: the tool's own pip requirements plus the packages the
    # wrapper itself needs at runtime.
    packages = [str(req) for req in tool.pip_requirements] if tool.pip_requirements else []
    packages.append("letta_client")
    packages.append("letta")  # Base letta without extras
    packages.append("asyncpg>=0.30.0")  # Fixes asyncpg import error
    packages.append("psycopg2-binary>=2.9.10")  # PostgreSQL adapter (pre-compiled, no build required)
    # packages.append("pgvector>=0.3.6")  # Vector operations support

    function_name = generate_modal_function_name(tool.name, actor.organization_id, project_id)
    modal_app = modal.App(function_name)
    logger.info(f"Creating Modal app {tool.id} with name {function_name}")

    # Create secrets dict with sandbox env vars
    secrets_dict = {"LETTA_API_KEY": None}
    if sandbox_env_vars:
        secrets_dict.update(sandbox_env_vars)

    # NOTE(review): timeout=10 is hardcoded although MODAL_DEFAULT_TIMEOUT is 60,
    # and python_version "3.13" differs from MODAL_DEFAULT_PYTHON_VERSION ("3.12")
    # — confirm both are intentional.
    @modal_app.function(
        image=modal.Image.debian_slim(python_version="3.13").pip_install(packages),
        restrict_modal_access=True,
        timeout=10,
        secrets=[modal.Secret.from_dict(secrets_dict)],
        serialized=True,
    )
    def modal_function(
        tool_name: str, agent_state: Optional[dict], agent_id: Optional[str], env_vars: dict, letta_api_key: Optional[str] = None, **kwargs
    ):
        """Wrapper function for running untrusted code in a Modal function"""
        # Reconstruct AgentState from dict if passed (to avoid cloudpickle serialization issues)
        # This is done with extra safety to handle schema mismatches between environments
        reconstructed_agent_state = None
        if agent_state:
            try:
                from letta.schemas.agent import AgentState as AgentStateModel

                # Filter dict to only include fields that exist in Modal's version of AgentState
                # This prevents ValidationError from extra fields in newer schemas
                modal_agent_fields = set(AgentStateModel.model_fields.keys())
                filtered_agent_state = {key: value for key, value in agent_state.items() if key in modal_agent_fields}

                # Try to reconstruct with filtered data
                reconstructed_agent_state = AgentStateModel.model_validate(filtered_agent_state)

                # Log if we filtered out any fields
                filtered_out = set(agent_state.keys()) - modal_agent_fields
                if filtered_out:
                    print(f"Fields not in available in AgentState: {filtered_out}", file=sys.stderr)

            except ImportError as e:
                # letta is not importable inside the container; fall back to the raw dict.
                print(f"Cannot import AgentState: {e}", file=sys.stderr)
                print("Passing agent_state as dict to tool", file=sys.stderr)
                reconstructed_agent_state = agent_state
            except Exception as e:
                # Schema drift between server and container; fall back to the raw dict.
                print(f"Warning: Could not reconstruct AgentState (schema mismatch?): {e}", file=sys.stderr)
                print("Passing agent_state as dict to tool", file=sys.stderr)
                reconstructed_agent_state = agent_state

        # Set environment variables
        if env_vars:
            for key, value in env_vars.items():
                os.environ[key] = str(value)

        # Initialize the Letta client
        if letta_api_key:
            letta_client = Letta(token=letta_api_key, base_url=os.environ.get("LETTA_API_URL", "https://api.letta.com"))
        else:
            letta_client = None

        # Globals namespace handed to the tool's exec'd source code.
        tool_namespace = {
            "__builtins__": __builtins__,  # Include built-in functions
            "_letta_client": letta_client,  # Make letta_client available
            "os": os,  # Include os module for env vars access
            "agent_id": agent_id,
            # Add any other modules/variables the tool might need
        }

        # Initialize the tool code
        # Create a namespace for the tool
        # tool_namespace = {}
        exec(tool.source_code, tool_namespace)

        # Get the tool function
        if tool_name not in tool_namespace:
            raise Exception(f"Tool function {tool_name} not found in {tool.source_code}, globals: {tool_namespace}")
        tool_func = tool_namespace[tool_name]

        # Detect if the tool function is async
        import asyncio
        import inspect
        import traceback

        is_async = inspect.iscoroutinefunction(tool_func)

        # Capture stdout and stderr during tool execution
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()
        result = None
        error_occurred = False

        with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
            try:
                # if `agent_state` is in the tool function arguments, inject it
                # Pass reconstructed AgentState (or dict if reconstruction failed)
                if "agent_state" in tool_func.__code__.co_varnames:
                    kwargs["agent_state"] = reconstructed_agent_state

                # Execute the tool function (async or sync)
                if is_async:
                    result = asyncio.run(tool_func(**kwargs))
                else:
                    result = tool_func(**kwargs)
            except Exception as e:
                # Capture the exception and write to stderr
                error_occurred = True
                traceback.print_exc(file=stderr_capture)

        # Get captured output
        stdout = stdout_capture.getvalue()
        stderr = stderr_capture.getvalue()

        return {
            "result": result,
            "stdout": stdout,
            "stderr": stderr,
            "agent_state": agent_state,  # TODO: deprecate (use letta_client instead)
            # NOTE(review): any stderr output flags the call as an error even
            # when the tool returned normally — confirm this is intended.
            "error": error_occurred or bool(stderr),
        }

    return modal_app
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Manager class to handle business logic related to Tools."""
|
||||
|
||||
@@ -56,12 +202,16 @@ class ToolManager:
|
||||
else:
|
||||
raise ValueError("Failed to generate schema for tool", pydantic_tool.source_code)
|
||||
|
||||
print("SCHEMA", pydantic_tool.json_schema)
|
||||
|
||||
# make sure the name matches the json_schema
|
||||
if not pydantic_tool.name:
|
||||
pydantic_tool.name = pydantic_tool.json_schema.get("name")
|
||||
else:
|
||||
# if a name is provided, make sure it is shorter than MAX_TOOL_NAME_LENGTH
|
||||
if len(pydantic_tool.name) > MAX_TOOL_NAME_LENGTH:
|
||||
raise LettaInvalidArgumentError(
|
||||
f"Tool name {pydantic_tool.name} is too long. It must be less than {MAX_TOOL_NAME_LENGTH} characters."
|
||||
)
|
||||
|
||||
if pydantic_tool.name != pydantic_tool.json_schema.get("name"):
|
||||
raise LettaToolNameSchemaMismatchError(
|
||||
tool_name=pydantic_tool.name,
|
||||
@@ -136,6 +286,13 @@ class ToolManager:
|
||||
# Auto-generate description if not provided
|
||||
if pydantic_tool.description is None and pydantic_tool.json_schema:
|
||||
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
|
||||
|
||||
# Add tool hash to metadata for Modal deployment tracking
|
||||
tool_hash = compute_tool_hash(pydantic_tool)
|
||||
if pydantic_tool.metadata_ is None:
|
||||
pydantic_tool.metadata_ = {}
|
||||
pydantic_tool.metadata_["tool_hash"] = tool_hash
|
||||
|
||||
tool_data = pydantic_tool.model_dump(to_orm=True)
|
||||
# Set the organization id at the ORM layer
|
||||
tool_data["organization_id"] = actor.organization_id
|
||||
@@ -153,7 +310,17 @@ class ToolManager:
|
||||
)
|
||||
|
||||
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
|
||||
return tool.to_pydantic()
|
||||
created_tool = tool.to_pydantic()
|
||||
|
||||
# Deploy Modal app for the new tool
|
||||
# Both Modal credentials configured AND tool metadata must indicate Modal
|
||||
tool_requests_modal = created_tool.metadata_ and created_tool.metadata_.get("sandbox") == "modal"
|
||||
modal_configured = tool_settings.modal_sandbox_enabled
|
||||
|
||||
if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
|
||||
await self.create_or_update_modal_app(created_tool, actor)
|
||||
|
||||
return created_tool
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -599,6 +766,33 @@ class ToolManager:
|
||||
source_code=update_data.get("source_code"),
|
||||
)
|
||||
|
||||
# Create a preview of the updated tool by merging current tool with updates
|
||||
# This allows us to compute the hash before the database session
|
||||
updated_tool_pydantic = current_tool.model_copy(deep=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(updated_tool_pydantic, key, value)
|
||||
if new_schema is not None:
|
||||
updated_tool_pydantic.json_schema = new_schema
|
||||
updated_tool_pydantic.name = new_name
|
||||
if updated_tool_type:
|
||||
updated_tool_pydantic.tool_type = updated_tool_type
|
||||
|
||||
# Check if we need to redeploy the Modal app due to changes
|
||||
# Compute this before the session to avoid issues
|
||||
tool_requests_modal = updated_tool_pydantic.metadata_ and updated_tool_pydantic.metadata_.get("sandbox") == "modal"
|
||||
modal_configured = tool_settings.modal_sandbox_enabled
|
||||
should_check_modal = tool_requests_modal and modal_configured and updated_tool_pydantic.tool_type == ToolType.CUSTOM
|
||||
|
||||
# Compute hash before session if needed
|
||||
new_hash = None
|
||||
old_hash = None
|
||||
needs_modal_deployment = False
|
||||
|
||||
if should_check_modal:
|
||||
new_hash = compute_tool_hash(updated_tool_pydantic)
|
||||
old_hash = current_tool.metadata_.get("tool_hash") if current_tool.metadata_ else None
|
||||
needs_modal_deployment = new_hash != old_hash
|
||||
|
||||
# Now perform the update within the session
|
||||
async with db_registry.async_session() as session:
|
||||
# Fetch the tool by ID
|
||||
@@ -618,7 +812,25 @@ class ToolManager:
|
||||
|
||||
# Save the updated tool to the database
|
||||
tool = await tool.update_async(db_session=session, actor=actor)
|
||||
return tool.to_pydantic()
|
||||
updated_tool = tool.to_pydantic()
|
||||
|
||||
# Update Modal hash in metadata if needed (inside session context)
|
||||
if needs_modal_deployment:
|
||||
if updated_tool.metadata_ is None:
|
||||
updated_tool.metadata_ = {}
|
||||
updated_tool.metadata_["tool_hash"] = new_hash
|
||||
|
||||
# Update the metadata in the database (still inside session)
|
||||
tool.metadata_ = updated_tool.metadata_
|
||||
tool = await tool.update_async(db_session=session, actor=actor)
|
||||
updated_tool = tool.to_pydantic()
|
||||
|
||||
# Deploy Modal app outside of session (it creates its own sessions)
|
||||
if needs_modal_deployment:
|
||||
logger.info(f"Deploying Modal app for tool {updated_tool.id} with new hash: {new_hash}")
|
||||
await self.create_or_update_modal_app(updated_tool, actor)
|
||||
|
||||
return updated_tool
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -783,3 +995,49 @@ class ToolManager:
|
||||
created_tool = await self.create_tool_async(tool, actor=actor)
|
||||
tools.append(created_tool)
|
||||
return tools
|
||||
|
||||
# MODAL RELATED METHODS
@trace_method
async def create_or_update_modal_app(self, tool: PydanticTool, actor: PydanticUser):
    """Create (or redeploy) the Modal app that hosts this tool's function.

    Loads the default Modal sandbox config's env vars (best-effort), builds
    the Modal app via modal_tool_wrapper, deploys it, then nudges the
    autoscaler to drain containers left over from a previous deployment.

    Args:
        tool: Tool whose source code should be deployed to Modal.
        actor: User performing the deployment (provides organization scope).

    Returns:
        The result of ``modal_app.deploy()``.

    Raises:
        LettaInvalidArgumentError: If the Modal deployment fails.
    """
    import asyncio

    import modal

    from letta.services.sandbox_config_manager import SandboxConfigManager

    # Load sandbox env vars to bake them into the Modal secrets
    sandbox_env_vars = {}
    try:
        sandbox_config_manager = SandboxConfigManager()
        sandbox_config = await sandbox_config_manager.get_or_create_default_sandbox_config_async(
            sandbox_type=SandboxType.MODAL, actor=actor
        )
        if sandbox_config:
            sandbox_env_vars = await sandbox_config_manager.get_sandbox_env_vars_as_dict_async(
                sandbox_config_id=sandbox_config.id, actor=actor, limit=None
            )
            logger.info(f"Loaded {len(sandbox_env_vars)} sandbox env vars for Modal app {tool.id}")
    except Exception as e:
        # Best-effort: the tool can still be deployed without sandbox env vars.
        logger.warning(f"Could not load sandbox env vars for Modal app {tool.id}: {e}")

    # Create the Modal app using the global function with sandbox env vars
    modal_app = modal_tool_wrapper(tool, actor, sandbox_env_vars)

    # Deploy the app first
    with modal.enable_output():
        try:
            deploy = modal_app.deploy()
        except Exception as e:
            raise LettaInvalidArgumentError(f"Failed to deploy tool {tool.id} with name {tool.name} to Modal: {e}") from e

    # After deployment, look up the function to configure autoscaler
    try:
        func = modal.Function.from_name(modal_app.name, MODAL_DEFAULT_TOOL_NAME)
        func.update_autoscaler(scaledown_window=2)  # drain inactive old containers
        # Non-blocking pause; time.sleep() here would stall the event loop.
        await asyncio.sleep(5)
        func.update_autoscaler(scaledown_window=60)
    except Exception as e:
        # Autoscaler tuning is an optimization; the deployment already succeeded.
        logger.warning(f"Failed to configure autoscaler for Modal function {modal_app.name}: {e}")

    return deploy
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
from typing import Any, Dict, Optional
|
||||
"""
|
||||
Modal sandbox implementation, which configures one Modal App per tool.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import modal
|
||||
from e2b.sandbox.commands.command_handle import CommandExitException
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
|
||||
from letta.constants import MODAL_DEFAULT_TOOL_NAME
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -9,412 +16,180 @@ from letta.schemas.enums import SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.services.helpers.tool_parser_helper import parse_stdout_best_effort
|
||||
from letta.services.helpers.tool_parser_helper import parse_function_arguments, parse_stdout_best_effort
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
|
||||
from letta.settings import tool_settings
|
||||
from letta.types import JsonDict
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# class AsyncToolSandboxModalBase(AsyncToolSandboxBase):
|
||||
# pass
|
||||
if TYPE_CHECKING:
|
||||
from e2b_code_interpreter import Execution
|
||||
|
||||
|
||||
class AsyncToolSandboxModal(AsyncToolSandboxBase):
|
||||
METADATA_CONFIG_STATE_KEY = "config_state"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
args: JsonDict,
|
||||
user,
|
||||
tool_object: Tool | None = None,
|
||||
sandbox_config: SandboxConfig | None = None,
|
||||
sandbox_env_vars: dict[str, Any] | None = None,
|
||||
force_recreate: bool = True,
|
||||
tool_object: Optional[Tool] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
organization_id: Optional[str] = None,
|
||||
project_id: str = "default",
|
||||
):
|
||||
super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
|
||||
self.force_recreate = force_recreate
|
||||
# Get organization_id from user if not explicitly provided
|
||||
self.organization_id = organization_id if organization_id is not None else user.organization_id
|
||||
self.project_id = project_id
|
||||
|
||||
if not tool_settings.modal_token_id or not tool_settings.modal_token_secret:
|
||||
raise ValueError("MODAL_TOKEN_ID and MODAL_TOKEN_SECRET must be set.")
|
||||
# TODO: check to make sure modal app `App(tool.id)` exists
|
||||
|
||||
# Create a unique app name based on tool and config
|
||||
self._app_name = self._generate_app_name()
|
||||
async def _wait_for_modal_function_deployment(self, timeout: int = 60):
|
||||
"""Wait for Modal app deployment to complete by retrying function lookup."""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
def _generate_app_name(self) -> str:
|
||||
"""Generate a unique app name based on tool and configuration. Created based on tool name and org"""
|
||||
return f"{self.user.organization_id}-{self.tool_name}"
|
||||
import modal
|
||||
|
||||
async def _fetch_or_create_modal_app(self, sbx_config: SandboxConfig, env_vars: Dict[str, str]) -> modal.App:
|
||||
"""Create a Modal app with the tool function registered."""
|
||||
from letta.helpers.tool_helpers import generate_modal_function_name
|
||||
|
||||
# Use the same naming logic as deployment
|
||||
function_name = generate_modal_function_name(self.tool.name, self.organization_id, self.project_id)
|
||||
|
||||
start_time = time.time()
|
||||
retry_delay = 2 # seconds
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
app = await modal.App.lookup.aio(self._app_name)
|
||||
return app
|
||||
except:
|
||||
app = modal.App(self._app_name)
|
||||
f = modal.Function.from_name(function_name, MODAL_DEFAULT_TOOL_NAME)
|
||||
logger.info(f"Modal function found successfully for app {function_name}, function {f}")
|
||||
return f
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed >= timeout:
|
||||
raise TimeoutError(
|
||||
f"Modal app {function_name} deployment timed out after {timeout} seconds. "
|
||||
f"Expected app name: {function_name}, function: {MODAL_DEFAULT_TOOL_NAME}"
|
||||
) from e
|
||||
logger.info(f"Modal app {function_name} not ready yet (elapsed: {elapsed:.1f}s), waiting {retry_delay}s...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
modal_config = sbx_config.get_modal_config()
|
||||
|
||||
# Get the base image with dependencies
|
||||
image = self._get_modal_image(sbx_config)
|
||||
|
||||
# Decorator for the tool, note information on running untrusted code: https://modal.com/docs/guide/restricted-access
|
||||
# The `@app.function` decorator must apply to functions in global scope, unless `serialized=True` is set.
|
||||
@app.function(image=image, timeout=modal_config.timeout, restrict_modal_access=True, max_inputs=1, serialized=True)
|
||||
def execute_tool_with_script(execution_script: str, environment_vars: dict[str, str]):
|
||||
"""Execute the generated tool script in Modal sandbox."""
|
||||
import os
|
||||
|
||||
# Note: We pass environment variables directly instead of relying on Modal secrets
|
||||
# This is more flexible and doesn't require pre-configured secrets
|
||||
for key, value in environment_vars.items():
|
||||
os.environ[key] = str(value)
|
||||
|
||||
exec_globals = {}
|
||||
exec(execution_script, exec_globals)
|
||||
|
||||
# Store the function reference in the app for later use
|
||||
app.remote_executor = execute_tool_with_script
|
||||
return app
|
||||
raise TimeoutError(f"Modal app {function_name} deployment timed out after {timeout} seconds")
|
||||
|
||||
@trace_method
async def run(
    self,
    agent_id: Optional[str] = None,
    agent_state: Optional[AgentState] = None,
    additional_env_vars: Optional[Dict] = None,
) -> ToolExecutionResult:
    """Execute the tool remotely on Modal and return a ToolExecutionResult.

    Waits for the tool's Modal deployment to become available, assembles the
    runtime environment variables, invokes the deployed function, and
    reconstructs any AgentState the tool returned.

    Args:
        agent_id: Optional agent id forwarded to the remote function.
        agent_state: Optional agent state; serialized to a dict for transport.
        additional_env_vars: Highest-priority env vars; may carry LETTA_API_KEY.

    Returns:
        ToolExecutionResult with the function return, captured stdout/stderr,
        and status ("error" when the remote reported an exception).
    """
    # NOTE(review): the original span interleaved two diff hunks, producing
    # duplicate keyword arguments and references to undefined names; this is
    # the reconstructed coherent flow based on the newer hunk.
    await self._init_async()

    try:
        log_event("modal_execution_started", {"tool": self.tool_name, "modal_app_id": self.tool.id})

        logger.info(f"Waiting for Modal function deployment for app {self.tool.id}")
        func = await self._wait_for_modal_function_deployment()
        logger.info(f"Modal function found successfully for app {self.tool.id}, function {str(func)}")

        # TODO: use another mechanism to pass through the key
        letta_api_key = None if additional_env_vars is None else additional_env_vars.get("LETTA_API_KEY", None)

        # Construct dynamic env vars. Priority order (later overrides earlier):
        #   1. Sandbox-level env vars (from database)
        #   2. Agent-specific env vars (secrets)
        #   3. Additional runtime env vars
        env_vars: Dict[str, str] = {}

        # Sandbox-level environment variables can be updated after deployment
        # and are therefore loaded fresh at runtime.
        if self.provided_sandbox_env_vars:
            env_vars.update(self.provided_sandbox_env_vars)
        else:
            try:
                from letta.services.sandbox_config_manager import SandboxConfigManager

                sandbox_config_manager = SandboxConfigManager()
                sandbox_config = await sandbox_config_manager.get_or_create_default_sandbox_config_async(
                    sandbox_type=SandboxType.MODAL, actor=self.user
                )
                if sandbox_config:
                    sandbox_env_vars = await sandbox_config_manager.get_sandbox_env_vars_as_dict_async(
                        sandbox_config_id=sandbox_config.id, actor=self.user, limit=None
                    )
                    env_vars.update(sandbox_env_vars)
            except Exception as e:
                # Best effort: missing sandbox env vars should not abort execution.
                logger.warning(f"Could not load sandbox env vars for tool {self.tool_name}: {e}")

        # Agent-specific environment variables override sandbox-level ones.
        if agent_state and agent_state.secrets:
            for secret in agent_state.secrets:
                env_vars[secret.key] = secret.value

        # Runtime-provided env vars take the highest priority.
        if additional_env_vars:
            env_vars.update(additional_env_vars)

        # Convert agent_state to a dict to avoid cloudpickle serialization issues.
        agent_state_dict = agent_state.model_dump() if agent_state else None

        logger.info(f"Calling function {func} with arguments {self.args}")
        result = await func.remote.aio(
            tool_name=self.tool_name,
            agent_state=agent_state_dict,
            agent_id=agent_id,
            env_vars=env_vars,
            letta_api_key=letta_api_key,
            **self.args,
        )
        logger.info(f"Modal function result: {result}")

        # Reconstruct agent_state if it was returned (use original as fallback).
        result_agent_state = agent_state
        if result.get("agent_state"):
            if isinstance(result["agent_state"], dict):
                try:
                    from letta.schemas.agent import AgentState

                    result_agent_state = AgentState.model_validate(result["agent_state"])
                except Exception as e:
                    logger.warning(f"Failed to reconstruct AgentState: {e}, using original")
            else:
                result_agent_state = result["agent_state"]

        return ToolExecutionResult(
            func_return=result["result"],
            agent_state=result_agent_state,
            stdout=[result["stdout"]],
            stderr=[result["stderr"]],
            status="error" if result["error"] else "success",
        )
    except Exception as e:
        log_event(
            "modal_execution_failed",
            {
                "tool": self.tool_name,
                "modal_app_id": self.tool.id,
                "error": str(e),
            },
        )
        logger.error(f"Modal execution failed for tool {self.tool_name} {self.tool.id}: {e}")
        return ToolExecutionResult(
            func_return=None,
            agent_state=None,
            stdout=[""],
            stderr=[str(e)],
            status="error",
        )
|
||||
|
||||
def _get_modal_image(self, sbx_config: SandboxConfig) -> modal.Image:
    """Get Modal image with required public python dependencies.

    Caching and rebuilding is handled in a cascading manner
    https://modal.com/docs/guide/images#image-caching-and-rebuilds
    """
    base_image = modal.Image.debian_slim(python_version="3.12")

    # letta itself is always installed; sandbox- and tool-level pip
    # requirements are appended after it, in that order.
    requirements = ["letta"]

    modal_configs = sbx_config.get_modal_config()
    if modal_configs.pip_requirements:
        requirements += [str(requirement) for requirement in modal_configs.pip_requirements]

    if self.tool and self.tool.pip_requirements:
        requirements += [str(requirement) for requirement in self.tool.pip_requirements]

    # requirements is never empty here, but keep the guard for safety.
    return base_image.pip_install(*requirements) if requirements else base_image
|
||||
|
||||
def use_top_level_await(self) -> bool:
    """Whether generated scripts may rely on an already-running event loop.

    Modal functions don't have an active event loop by default, so the
    generated script should drive coroutines via asyncio.run(), exactly
    like local execution does.
    """
    # No running loop inside Modal functions -> never use top-level await.
    return False
|
||||
|
||||
|
||||
class TypescriptToolSandboxModal(AsyncToolSandboxModal):
    """Modal sandbox implementation for TypeScript tools.

    TypeScript tools are executed by a Node.js shim server ("NodeShimServer")
    deployed to Modal as a serverless class. Tool arguments are serialized to
    JSON and the shim invokes the user's TypeScript function with them.
    """

    @trace_method
    async def run(
        self,
        agent_state: Optional[AgentState] = None,
        additional_env_vars: Optional[Dict] = None,
    ) -> ToolExecutionResult:
        """Run TypeScript tool in Modal sandbox using Node.js server.

        Args:
            agent_state: Kept for interface parity; TypeScript tools do not
                receive or mutate agent state yet.
            additional_env_vars: Extra environment variables merged into the
                sandbox environment.

        Returns:
            ToolExecutionResult carrying the stringified function return on
            success, or a friendly error message on failure.
        """
        if self.provided_sandbox_config:
            sbx_config = self.provided_sandbox_config
        else:
            sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async(
                sandbox_type=SandboxType.MODAL, actor=self.user
            )

        envs = await self._gather_env_vars(agent_state, additional_env_vars or {}, sbx_config.id, is_local=False)

        # Generate execution script (JSON args for TypeScript)
        json_args = await self.generate_execution_script(agent_state=agent_state)

        try:
            log_event(
                "modal_typescript_execution_started",
                {"tool": self.tool_name, "app_name": self._app_name, "args": json_args},
            )

            # Create Modal app with the TypeScript Node.js server
            app = await self._fetch_or_create_modal_app(sbx_config, envs)

            # Execute the TypeScript tool remotely via the Node.js server
            with app.run():
                # Get the NodeShimServer class from Modal
                node_server = modal.Cls.from_name(self._app_name, "NodeShimServer")

                # Call the remote_executor method with the JSON arguments.
                # The server will parse the JSON and call the TypeScript function.
                result = node_server().remote_executor.remote(json_args)

            # Process the TypeScript execution result
            if isinstance(result, dict) and "error" in result:
                # Tool-level errors are expected behavior; log at debug level only.
                logger.debug(f"TypeScript tool {self.tool_name} raised an error: {result['error']}")
                func_return = get_friendly_error_msg(
                    function_name=self.tool_name,
                    exception_name="TypeScriptError",
                    exception_message=str(result["error"]),
                )
                log_event(
                    "modal_typescript_execution_failed",
                    {
                        "tool": self.tool_name,
                        "app_name": self._app_name,
                        "error": result["error"],
                        "func_return": func_return,
                    },
                )
                return ToolExecutionResult(
                    func_return=func_return,
                    agent_state=None,  # TypeScript tools don't support agent_state yet
                    stdout=[],
                    stderr=[str(result["error"])],
                    status="error",
                    sandbox_config_fingerprint=sbx_config.fingerprint(),
                )
            else:
                # Success case - TypeScript function returned a result
                func_return = str(result) if result is not None else ""
                log_event(
                    "modal_typescript_execution_succeeded",
                    {
                        "tool": self.tool_name,
                        "app_name": self._app_name,
                        "func_return": func_return,
                    },
                )
                return ToolExecutionResult(
                    func_return=func_return,
                    agent_state=None,  # TypeScript tools don't support agent_state yet
                    stdout=[],
                    stderr=[],
                    status="success",
                    sandbox_config_fingerprint=sbx_config.fingerprint(),
                )

        except Exception as e:
            logger.error(f"Modal TypeScript execution for tool {self.tool_name} encountered an error: {e}")
            func_return = get_friendly_error_msg(
                function_name=self.tool_name,
                exception_name=type(e).__name__,
                exception_message=str(e),
            )
            log_event(
                "modal_typescript_execution_error",
                {
                    "tool": self.tool_name,
                    "app_name": self._app_name,
                    "error": str(e),
                    "func_return": func_return,
                },
            )
            return ToolExecutionResult(
                func_return=func_return,
                agent_state=None,
                stdout=[],
                stderr=[str(e)],
                status="error",
                sandbox_config_fingerprint=sbx_config.fingerprint(),
            )

    async def _fetch_or_create_modal_app(self, sbx_config: SandboxConfig, env_vars: Dict[str, str]) -> modal.App:
        """Create or fetch a Modal app with TypeScript execution capabilities.

        Args:
            sbx_config: Sandbox config supplying the Modal settings and image deps.
            env_vars: Accepted for interface parity; not used during deployment.
        """
        try:
            return await modal.App.lookup.aio(self._app_name)
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit; any lookup failure still falls through to deploy.
        except Exception:
            app = modal.App(self._app_name)

            modal_config = sbx_config.get_modal_config()

            # Get the base image with dependencies
            image = self._get_modal_image(sbx_config)

            # Import the NodeShimServer that will handle TypeScript execution
            from sandbox.node_server import NodeShimServer

            # Register the NodeShimServer class with Modal.
            # This creates a serverless function that can handle concurrent requests.
            app.cls(image=image, restrict_modal_access=True, include_source=False, timeout=modal_config.timeout if modal_config else 60)(
                modal.concurrent(max_inputs=100, target_inputs=50)(NodeShimServer)
            )

            # Deploy the app to Modal
            with modal.enable_output():
                await app.deploy.aio()

            return app

    async def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str:
        """Generate the execution script for TypeScript tools.

        For TypeScript tools, this returns the JSON-encoded arguments that will be passed
        to the Node.js server via the remote_executor method.
        """
        import json

        # Convert args to JSON string for TypeScript execution;
        # the Node.js server expects JSON-encoded arguments.
        return json.dumps(self.args)

    def _get_modal_image(self, sbx_config: SandboxConfig) -> modal.Image:
        """Build a Modal image with Node.js, TypeScript, and the user's tool function.

        Raises:
            ValueError: if the sandbox module cannot be located or the tool has
                no source code.
        """
        import importlib.util
        from pathlib import Path

        # Find the sandbox module location (the Node shim server lives next to it).
        spec = importlib.util.find_spec("sandbox")
        if not spec or not spec.origin:
            raise ValueError("Could not find sandbox module")
        server_dir = Path(spec.origin).parent

        # Get the TypeScript function source code
        if not self.tool or not self.tool.source_code:
            raise ValueError("TypeScript tool must have source code")

        ts_function = self.tool.source_code

        # Collect npm dependencies from the sandbox config and the tool itself.
        modal_config = sbx_config.get_modal_config()
        npm_dependencies = []

        if modal_config and modal_config.npm_requirements:
            npm_dependencies.extend(modal_config.npm_requirements)

        if self.tool.npm_requirements:
            npm_dependencies.extend(self.tool.npm_requirements)

        # Build npm install command for user dependencies
        user_dependencies_cmd = ""
        if npm_dependencies:
            # Ensure unique dependencies
            unique_deps = list(set(npm_dependencies))
            user_dependencies_cmd = " && npm install " + " ".join(unique_deps)

        # Escape single quotes in the TypeScript function for shell command.
        # NOTE(review): the tool source is still interpolated into a shell
        # string; writing it via an add-file primitive would avoid shell
        # quoting edge cases entirely — consider for a follow-up.
        escaped_ts_function = ts_function.replace("'", "'\\''")

        # Build the Docker image with Node.js and TypeScript
        image = (
            modal.Image.from_registry("node:22-slim", add_python="3.12")
            .add_local_dir(server_dir, "/root/sandbox", ignore=["node_modules", "build"], copy=True)
            .run_commands(
                # Install dependencies and build the TypeScript server
                f"cd /root/sandbox/resources/server && npm install{user_dependencies_cmd}",
                # Write the user's TypeScript function to a file
                f"echo '{escaped_ts_function}' > /root/sandbox/user-function.ts",
            )
        )
        return image
|
||||
|
||||
|
||||
# TODO: parse the Node.js shim's stdout with parse_stdout_best_effort so TypeScript tools can return structured results, matching the Python sandbox path.
|
||||
|
||||
@@ -39,12 +39,20 @@ class ToolSettings(BaseSettings):
|
||||
mcp_read_from_config: bool = False # if False, will throw if attempting to read/write from file
|
||||
mcp_disable_stdio: bool = False
|
||||
|
||||
@property
def modal_sandbox_enabled(self) -> bool:
    """True only when both Modal credentials (token id and secret) are set."""
    has_token_id = bool(self.modal_token_id)
    has_token_secret = bool(self.modal_token_secret)
    return has_token_id and has_token_secret
|
||||
|
||||
@property
def sandbox_type(self) -> SandboxType:
    """Default sandbox type based on available credentials.

    Note: Modal is checked separately via modal_sandbox_enabled property.
    This property determines the fallback behavior (E2B or LOCAL).
    """
    # Modal is intentionally never returned here; callers consult
    # modal_sandbox_enabled for Modal availability.
    return SandboxType.E2B if self.e2b_api_key else SandboxType.LOCAL
|
||||
|
||||
|
||||
@@ -169,6 +169,15 @@ def check_e2b_key_is_set():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
def check_modal_key_is_set():
    """Guard fixture: fail fast when Modal credentials are not configured."""
    from letta.settings import tool_settings

    token_id = tool_settings.modal_token_id
    token_secret = tool_settings.modal_token_secret
    assert token_id is not None, "Missing modal token id! Cannot execute these tests."
    assert token_secret is not None, "Missing modal token secret! Cannot execute these tests."
    yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def default_organization():
|
||||
"""Fixture to create and return the default organization."""
|
||||
|
||||
770
tests/integration_test_modal.py
Normal file
770
tests/integration_test_modal.py
Normal file
@@ -0,0 +1,770 @@
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.functions.function_sets.base import core_memory_append, core_memory_replace
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig, ModalSandboxConfig, SandboxConfigCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.db import db_registry
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal
|
||||
from letta.services.user_manager import UserManager
|
||||
from tests.helpers.utils import create_tool_from_func
|
||||
|
||||
# Constants
# Deterministic UUIDv5 names so repeated runs map to the same org/user.
namespace = uuid.NAMESPACE_DNS
org_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-org"))
user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user"))

# Set environment variable immediately to prevent pooling issues
# (must happen at import time, before any SQLAlchemy engine is created).
os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true"
|
||||
|
||||
|
||||
# Disable SQLAlchemy connection pooling for tests to prevent event loop issues
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def disable_db_pooling_for_tests():
|
||||
"""Disable database connection pooling for the entire test session."""
|
||||
# Environment variable is already set above and settings reloaded
|
||||
yield
|
||||
# Clean up environment variable after tests
|
||||
if "LETTA_DISABLE_SQLALCHEMY_POOLING" in os.environ:
|
||||
del os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"]
|
||||
|
||||
|
||||
# @pytest.fixture(autouse=True)
|
||||
# async def cleanup_db_connections():
|
||||
# """Cleanup database connections after each test."""
|
||||
# yield
|
||||
#
|
||||
# # Dispose async engines in the current event loop
|
||||
# try:
|
||||
# await close_db()
|
||||
# except Exception as e:
|
||||
# # Log the error but don't fail the test
|
||||
# print(f"Warning: Failed to cleanup database connections: {e}")
|
||||
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
"""
|
||||
Creates a SyncServer instance for testing.
|
||||
|
||||
Loads and saves config to ensure proper initialization.
|
||||
"""
|
||||
config = LettaConfig.load()
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer(init_with_default_org_and_user=True)
|
||||
# create user/org
|
||||
yield server
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
async def clear_tables():
    """Clear the sandbox config and sandbox env var tables before each test."""
    from letta.server.db import db_registry

    async with db_registry.async_session() as session:
        # Delete env vars before configs to respect the foreign key.
        await session.execute(delete(SandboxEnvironmentVariable))
        await session.execute(delete(SandboxConfig))
        await session.commit()  # Commit the deletion
|
||||
|
||||
|
||||
@pytest.fixture
async def test_organization():
    """Create and return a test organization with a deterministic name."""
    org = await OrganizationManager().create_organization_async(Organization(name=org_name))
    yield org
|
||||
|
||||
|
||||
@pytest.fixture
async def test_user(test_organization):
    """Create and return a test user within the test organization."""
    user = await UserManager().create_actor_async(User(name=user_name, organization_id=test_organization.id))
    yield user
|
||||
|
||||
|
||||
@pytest.fixture
async def add_integers_tool(test_user):
    """Register a trivial pure-function tool (integer addition) for the test user."""

    def add(x: int, y: int) -> int:
        """
        Simple function that adds two integers.

        Parameters:
            x (int): The first integer to add.
            y (int): The second integer to add.

        Returns:
            int: The result of adding x and y.
        """
        return x + y

    tool = create_tool_from_func(add)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def cowsay_tool(test_user):
    # This defines a tool for a package we definitely do NOT have in letta.
    # If this test passes, that means the tool was correctly executed in a
    # separate Python environment.
    def cowsay() -> str:
        """
        Simple function that uses the cowsay package to print out the secret word env variable.

        Returns:
            str: The cowsay ASCII art.
        """
        import os

        import cowsay

        # NOTE(review): despite the annotated str return, this returns None;
        # the test relies on captured stdout, not the return value.
        cowsay.cow(os.getenv("secret_word"))

    tool = create_tool_from_func(cowsay)
    # Add cowsay as a pip requirement for Modal
    tool.pip_requirements = [PipRequirement(name="cowsay")]
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def get_env_tool(test_user):
    """Register a tool that reads and echoes the 'secret_word' env variable."""

    def get_env() -> str:
        """
        Simple function that returns the secret word env variable.

        Returns:
            str: The secret word
        """
        import os

        secret_word = os.getenv("secret_word")
        print(secret_word)
        return secret_word

    tool = create_tool_from_func(get_env)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def get_warning_tool(test_user):
    """Register a tool that emits a Python warning (tests stderr capture)."""

    def warn_hello_world() -> str:
        """
        Simple function that warns hello world.

        Returns:
            str: hello world
        """
        import warnings

        msg = "Hello World"
        warnings.warn(msg)
        return msg

    tool = create_tool_from_func(warn_hello_world)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def always_err_tool(test_user):
    """Register a tool that always raises, for error-path testing."""

    def error() -> str:
        """
        Simple function that errors

        Returns:
            str: not important
        """
        # Raise a unusual error so we know it's from this function
        print("Going to error now")
        raise ZeroDivisionError("This is an intentionally weird division!")

    tool = create_tool_from_func(error)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def list_tool(test_user):
    """Register a tool returning a non-scalar (list) value."""

    def create_list() -> list:
        """Simple function that returns a list"""

        return [1] * 5

    tool = create_tool_from_func(create_list)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def clear_core_memory_tool(test_user):
    """Register a stateful tool that mutates the agent's core memory blocks."""

    def clear_memory(agent_state: "AgentState"):
        """Clear the core memory"""
        agent_state.memory.get_block("human").value = ""
        agent_state.memory.get_block("persona").value = ""

    tool = create_tool_from_func(clear_memory)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def external_codebase_tool(test_user):
    """Register a tool whose implementation lives in an external test codebase."""
    from tests.test_tool_sandbox.restaurant_management_system.adjust_menu_prices import adjust_menu_prices

    tool = create_tool_from_func(adjust_menu_prices)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def agent_state(server: SyncServer):
    """Create a throwaway agent (human/persona blocks, base tools) for sandbox tests."""
    await server.init_async(init_with_default_org_and_user=True)
    actor = await server.user_manager.create_default_actor_async()
    agent_state = await server.create_agent_async(
        CreateAgent(
            memory_blocks=[
                CreateBlock(
                    label="human",
                    value="username: sarah",
                ),
                CreateBlock(
                    label="persona",
                    value="This is the persona",
                ),
            ],
            include_base_tools=True,
            model="openai/gpt-4o-mini",
            tags=["test_agents"],
            embedding="letta/letta-free",
        ),
        actor=actor,
    )
    # Drop tool rules so tests can invoke any tool without gating.
    agent_state.tool_rules = []
    yield agent_state
|
||||
|
||||
|
||||
@pytest.fixture
async def custom_test_sandbox_config(test_user):
    """
    Fixture to create a consistent local sandbox configuration for tests.

    Args:
        test_user: The test user to be used for creating the sandbox configuration.

    Returns:
        A tuple containing the SandboxConfigManager and the created sandbox configuration.
    """
    # Create the SandboxConfigManager
    manager = SandboxConfigManager()

    # Set the sandbox to be within the external codebase path and use a venv
    external_codebase_path = str(Path(__file__).parent / "test_tool_sandbox" / "restaurant_management_system")
    # tqdm is used in this codebase, but NOT in the requirements.txt, this tests that we can successfully install pip requirements
    local_sandbox_config = LocalSandboxConfig(
        sandbox_dir=external_codebase_path, use_venv=True, pip_requirements=[PipRequirement(name="tqdm")]
    )

    # Create the sandbox configuration
    config_create = SandboxConfigCreate(config=local_sandbox_config.model_dump())

    # Create or update the sandbox configuration
    await manager.create_or_update_sandbox_config_async(sandbox_config_create=config_create, actor=test_user)

    return manager, local_sandbox_config
|
||||
|
||||
|
||||
@pytest.fixture
async def core_memory_tools(test_user):
    """Create all base tools for testing."""
    # Maps function name -> registered tool for easy lookup in tests.
    tools = {}
    for func in [
        core_memory_replace,
        core_memory_append,
    ]:
        tool = create_tool_from_func(func)
        tool = await ToolManager().create_or_update_tool_async(tool, test_user)
        tools[func.__name__] = tool
    yield tools
|
||||
|
||||
|
||||
@pytest.fixture
async def async_add_integers_tool(test_user):
    """Register an async addition tool (exercises coroutine execution paths)."""

    async def async_add(x: int, y: int) -> int:
        """
        Async function that adds two integers.

        Parameters:
            x (int): The first integer to add.
            y (int): The second integer to add.

        Returns:
            int: The result of adding x and y.
        """
        import asyncio

        # Add a small delay to simulate async work
        await asyncio.sleep(0.1)
        return x + y

    tool = create_tool_from_func(async_add)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def async_get_env_tool(test_user):
    """Register an async tool that reads the 'secret_word' env variable."""

    async def async_get_env() -> str:
        """
        Async function that returns the secret word env variable.

        Returns:
            str: The secret word
        """
        import asyncio
        import os

        # Add a small delay to simulate async work
        await asyncio.sleep(0.1)
        secret_word = os.getenv("secret_word")
        print(secret_word)
        return secret_word

    tool = create_tool_from_func(async_get_env)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def async_stateful_tool(test_user):
    """Register an async tool that mutates agent core memory."""

    async def async_clear_memory(agent_state: "AgentState"):
        """Async function that clears the core memory"""
        import asyncio

        # Add a small delay to simulate async work
        await asyncio.sleep(0.1)
        agent_state.memory.get_block("human").value = ""
        agent_state.memory.get_block("persona").value = ""

    tool = create_tool_from_func(async_clear_memory)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def async_error_tool(test_user):
    """Register an async tool that always raises, for async error-path testing."""

    async def async_error() -> str:
        """
        Async function that errors

        Returns:
            str: not important
        """
        import asyncio

        # Add some async work before erroring
        await asyncio.sleep(0.1)
        print("Going to error now")
        raise ValueError("This is an intentional async error!")

    tool = create_tool_from_func(async_error)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def async_list_tool(test_user):
    """Register an async tool returning a list value."""

    async def async_create_list() -> list:
        """Async function that returns a list"""
        import asyncio

        await asyncio.sleep(0.05)
        return [1, 2, 3, 4, 5]

    tool = create_tool_from_func(async_create_list)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def tool_with_pip_requirements(test_user):
    """Register a tool with tool-level pip requirements (requests pinned + numpy)."""

    def use_requests_and_numpy() -> str:
        """
        Function that uses requests and numpy packages to test tool-specific pip requirements.

        Returns:
            str: Success message if packages are available.
        """
        try:
            import numpy as np
            import requests

            # Simple usage to verify packages work
            # NOTE(review): hits the network (httpbin.org); flaky offline.
            response = requests.get("https://httpbin.org/json", timeout=30)
            arr = np.array([1, 2, 3])
            return f"Success! Status: {response.status_code}, Array sum: {np.sum(arr)}"
        except ImportError as e:
            return f"Import error: {e}"
        except Exception as e:
            return f"Other error: {e}"

    tool = create_tool_from_func(use_requests_and_numpy)
    # Add pip requirements to the tool
    tool.pip_requirements = [
        PipRequirement(name="requests", version="2.31.0"),
        PipRequirement(name="numpy"),
    ]
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def async_complex_tool(test_user):
    """Fixture: register an async tool that performs multiple awaits and returns a dict of results."""

    # NOTE: the inner function's source and docstring are serialized as the tool
    # definition by create_tool_from_func, so they are intentionally left untouched.
    async def async_complex_computation(iterations: int = 3) -> dict:
        """
        Async function that performs complex computation with multiple awaits.

        Parameters:
            iterations (int): Number of iterations to perform.

        Returns:
            dict: Results of the computation.
        """
        import asyncio
        import time

        results = []
        start_time = time.time()

        for i in range(iterations):
            # Simulate async I/O
            await asyncio.sleep(0.1)
            results.append(i * 2)

        end_time = time.time()

        return {
            "results": results,
            "duration": end_time - start_time,
            "iterations": iterations,
            "average": sum(results) / len(results) if results else 0,
        }

    # Persist the tool so sandbox tests can resolve it by name for this user.
    tool = create_tool_from_func(async_complex_computation)
    tool = await ToolManager().create_or_update_tool_async(tool, test_user)
    yield tool
|
||||
|
||||
|
||||
# Modal sandbox tests
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_default(check_modal_key_is_set, add_integers_tool, test_user):
    """Run a simple sync tool in the Modal sandbox and verify its return value.

    The previous version also patched ``AsyncToolSandboxModal.run`` and asserted
    the mock was called — a tautology (it asserted the call it had just made and
    exercised no Modal code), so that section has been removed.
    """
    args = {"x": 10, "y": 5}

    sandbox = AsyncToolSandboxModal(add_integers_tool.name, args, user=test_user)
    result = await sandbox.run()

    # func_return comes back serialized; compare numerically.
    assert int(result.func_return) == args["x"] + args["y"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_pip_installs(check_modal_key_is_set, cowsay_tool, test_user):
    """Test that Modal sandbox installs tool-level pip requirements."""
    # Register a Modal sandbox config for the test user.
    config_manager = SandboxConfigManager()
    sandbox_config = await config_manager.create_or_update_sandbox_config_async(
        SandboxConfigCreate(config=ModalSandboxConfig().model_dump()), test_user
    )

    # Attach a random secret as a sandbox-level environment variable.
    alphabet = string.ascii_letters + string.digits
    secret_value = "".join(secrets.choice(alphabet) for _ in range(20))
    await config_manager.create_sandbox_env_var_async(
        SandboxEnvironmentVariableCreate(key="secret_word", value=secret_value),
        sandbox_config_id=sandbox_config.id,
        actor=test_user,
    )

    # The tool should run inside Modal and echo the secret to stdout.
    run_result = await AsyncToolSandboxModal(cowsay_tool.name, {}, user=test_user).run()
    assert secret_value in run_result.stdout[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_stateful_tool(check_modal_key_is_set, clear_core_memory_tool, test_user, agent_state):
    """A stateful tool executed in Modal should mutate agent memory and return None."""
    outcome = await AsyncToolSandboxModal(clear_core_memory_tool.name, {}, user=test_user).run(agent_state=agent_state)

    # Both core memory blocks should have been wiped by the tool.
    for block_label in ("human", "persona"):
        assert outcome.agent_state.memory.get_block(block_label).value == ""
    assert outcome.func_return is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_inject_env_var_existing_sandbox(check_modal_key_is_set, get_env_tool, test_user):
    """Env vars added after a sandbox config exists should be visible on the next run."""
    config_manager = SandboxConfigManager()
    sandbox_config = await config_manager.create_or_update_sandbox_config_async(
        SandboxConfigCreate(config=ModalSandboxConfig().model_dump()), test_user
    )

    # First run: no env var configured yet, so the tool finds nothing.
    first_result = await AsyncToolSandboxModal(get_env_tool.name, {}, user=test_user).run()
    assert first_result.func_return is None

    # Now register a random secret on the existing sandbox config.
    alphabet = string.ascii_letters + string.digits
    secret_value = "".join(secrets.choice(alphabet) for _ in range(20))
    await config_manager.create_sandbox_env_var_async(
        SandboxEnvironmentVariableCreate(key="secret_word", value=secret_value),
        sandbox_config_id=sandbox_config.id,
        actor=test_user,
    )

    # Second run: the freshly injected env var must be visible to the tool.
    second_result = await AsyncToolSandboxModal(get_env_tool.name, {}, user=test_user).run()
    assert secret_value in second_result.func_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_per_agent_env(check_modal_key_is_set, get_env_tool, agent_state, test_user):
    """Per-agent secrets should override sandbox-level env vars with the same key."""
    config_manager = SandboxConfigManager()
    alphabet = string.ascii_letters + string.digits
    sandbox_level_value = "".join(secrets.choice(alphabet) for _ in range(20))
    agent_level_value = "".join(secrets.choice(alphabet) for _ in range(20))

    # Sandbox-level env var carries the value that should LOSE.
    sandbox_config = await config_manager.create_or_update_sandbox_config_async(
        SandboxConfigCreate(config=ModalSandboxConfig().model_dump()), test_user
    )
    await config_manager.create_sandbox_env_var_async(
        SandboxEnvironmentVariableCreate(key="secret_word", value=sandbox_level_value),
        sandbox_config_id=sandbox_config.id,
        actor=test_user,
    )

    # Agent-level secret carries the value that should WIN.
    agent_state.secrets = [AgentEnvironmentVariable(key="secret_word", value=agent_level_value, agent_id=agent_state.id)]

    run_result = await AsyncToolSandboxModal(get_env_tool.name, {}, user=test_user).run(agent_state=agent_state)
    assert sandbox_level_value not in run_result.func_return
    assert agent_level_value in run_result.func_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_with_list_rv(check_modal_key_is_set, list_tool, test_user):
    """A tool that returns a list should round-trip through the Modal sandbox intact."""
    run_result = await AsyncToolSandboxModal(list_tool.name, {}, user=test_user).run()
    assert len(run_result.func_return) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_with_tool_pip_requirements(check_modal_key_is_set, tool_with_pip_requirements, test_user):
    """Test that Modal sandbox installs tool-specific pip requirements."""
    # Ensure a Modal sandbox config exists for the user.
    config_manager = SandboxConfigManager()
    await config_manager.create_or_update_sandbox_config_async(
        SandboxConfigCreate(config=ModalSandboxConfig().model_dump()), test_user
    )

    run_result = await AsyncToolSandboxModal(
        tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements
    ).run()

    # The tool imports requests + numpy; success implies both were installed.
    for expected_fragment in ("Success!", "Status: 200", "Array sum: 6"):
        assert expected_fragment in run_result.func_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_with_mixed_pip_requirements(check_modal_key_is_set, tool_with_pip_requirements, test_user):
    """Test that Modal sandbox installs tool pip requirements.

    Note: Modal does not support sandbox-level pip requirements - all pip requirements
    must be specified at the tool level since the Modal app is deployed with a fixed image.
    """
    # Register a (pip-less) Modal sandbox config; only tool-level reqs apply.
    config_manager = SandboxConfigManager()
    await config_manager.create_or_update_sandbox_config_async(
        SandboxConfigCreate(config=ModalSandboxConfig().model_dump()), test_user
    )

    run_result = await AsyncToolSandboxModal(
        tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements
    ).run()

    # Tool-level requirements alone must be sufficient for the tool to succeed.
    for expected_fragment in ("Success!", "Status: 200", "Array sum: 6"):
        assert expected_fragment in run_result.func_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_with_broken_tool_pip_requirements_error_handling(check_modal_key_is_set, test_user):
    """Test that Modal sandbox provides informative error messages for broken tool pip requirements."""
    from letta.errors import LettaInvalidArgumentError

    # NOTE: this inner function's source/docstring become the tool definition,
    # so they are kept exactly as-is.
    def use_broken_package() -> str:
        """
        Function that tries to use packages with broken version constraints.

        Returns:
            str: Success message if packages are available.
        """
        return "Should not reach here"

    broken_tool = create_tool_from_func(use_broken_package)
    # Deliberately unsatisfiable requirements: a stale pin and a package that
    # does not exist on PyPI.
    broken_tool.pip_requirements = [
        PipRequirement(name="numpy", version="1.24.0"),  # Old version incompatible with newer Python
        PipRequirement(name="nonexistent-package-12345"),  # Non-existent package
    ]

    # Registering the tool should be rejected up front rather than failing at run time.
    with pytest.raises(LettaInvalidArgumentError):
        await ToolManager().create_or_update_tool_async(broken_tool, test_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_async_function_detection(add_integers_tool, async_add_integers_tool, test_user):
    """Test that async function detection works correctly.

    Consistency fix: ``test_user`` is now passed as ``user=`` like every other
    ``AsyncToolSandboxModal`` call site in this file, instead of positionally.
    """
    # Sync tool: detection flag must be False after async init.
    sync_sandbox = AsyncToolSandboxModal(add_integers_tool.name, {}, user=test_user, tool_object=add_integers_tool)
    await sync_sandbox._init_async()
    assert not sync_sandbox.is_async_function

    # Async tool: detection flag must be True after async init.
    async_sandbox = AsyncToolSandboxModal(
        async_add_integers_tool.name, {}, user=test_user, tool_object=async_add_integers_tool
    )
    await async_sandbox._init_async()
    assert async_sandbox.is_async_function
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_function_execution(check_modal_key_is_set, async_add_integers_tool, test_user):
    """Test that async functions execute correctly in Modal sandbox"""
    call_args = {"x": 20, "y": 30}

    run_result = await AsyncToolSandboxModal(async_add_integers_tool.name, call_args, user=test_user).run()

    # The async tool should add its two integer arguments.
    assert int(run_result.func_return) == call_args["x"] + call_args["y"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_complex_computation(check_modal_key_is_set, async_complex_tool, test_user):
    """Test complex async computation with multiple awaits in Modal sandbox"""
    run_result = await AsyncToolSandboxModal(async_complex_tool.name, {"iterations": 2}, user=test_user).run()
    payload = run_result.func_return

    # The tool returns a dict describing the computation.
    assert isinstance(payload, dict)
    assert payload["results"] == [0, 2]
    assert payload["iterations"] == 2
    assert payload["average"] == 1.0
    # Two 0.1s sleeps ran, so the measured duration must exceed 0.15s.
    assert payload["duration"] > 0.15
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_list_return(check_modal_key_is_set, async_list_tool, test_user):
    """Test async function returning list in Modal sandbox"""
    run_result = await AsyncToolSandboxModal(async_list_tool.name, {}, user=test_user).run()
    # The async tool returns a fixed five-element list.
    assert run_result.func_return == [1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_with_env_vars(check_modal_key_is_set, async_get_env_tool, test_user):
    """Test async function with environment variables in Modal sandbox"""
    config_manager = SandboxConfigManager()
    sandbox_config = await config_manager.create_or_update_sandbox_config_async(
        SandboxConfigCreate(config=ModalSandboxConfig().model_dump()), test_user
    )

    # Attach a known value under the key the tool looks up.
    expected_value = "async_modal_test_value_456"
    await config_manager.create_sandbox_env_var_async(
        SandboxEnvironmentVariableCreate(key="secret_word", value=expected_value),
        sandbox_config_id=sandbox_config.id,
        actor=test_user,
    )

    run_result = await AsyncToolSandboxModal(async_get_env_tool.name, {}, user=test_user).run()
    assert expected_value in run_result.func_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_with_agent_state(check_modal_key_is_set, async_stateful_tool, test_user, agent_state):
    """Test async function with agent state in Modal sandbox"""
    outcome = await AsyncToolSandboxModal(async_stateful_tool.name, {}, user=test_user).run(agent_state=agent_state)

    # The async tool clears both core memory blocks and returns nothing.
    for block_label in ("human", "persona"):
        assert outcome.agent_state.memory.get_block(block_label).value == ""
    assert outcome.func_return is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_error_handling(check_modal_key_is_set, async_error_tool, test_user):
    """Test async function error handling in Modal sandbox"""
    run_result = await AsyncToolSandboxModal(async_error_tool.name, {}, user=test_user).run()

    # stdout should carry the tool's print output emitted before it raised.
    assert len(run_result.stdout) != 0, "stdout not empty"
    assert "error" in run_result.stdout[0], "stdout contains printed string"
    # stderr should carry the traceback of the intentional exception.
    assert len(run_result.stderr) != 0, "stderr not empty"
    assert "ValueError: This is an intentional async error!" in run_result.stderr[0], "stderr contains expected error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.modal_sandbox
async def test_modal_sandbox_async_per_agent_env(check_modal_key_is_set, async_get_env_tool, agent_state, test_user):
    """Test async function with per-agent environment variables in Modal sandbox"""
    config_manager = SandboxConfigManager()
    sandbox_level_value = "wrong_async_modal_value"
    agent_level_value = "correct_async_modal_value"

    # Sandbox-level env var holds the value that should be overridden.
    sandbox_config = await config_manager.create_or_update_sandbox_config_async(
        SandboxConfigCreate(config=ModalSandboxConfig().model_dump()), test_user
    )
    await config_manager.create_sandbox_env_var_async(
        SandboxEnvironmentVariableCreate(key="secret_word", value=sandbox_level_value),
        sandbox_config_id=sandbox_config.id,
        actor=test_user,
    )

    # Agent-level secret with the same key takes precedence.
    agent_state.secrets = [AgentEnvironmentVariable(key="secret_word", value=agent_level_value, agent_id=agent_state.id)]

    run_result = await AsyncToolSandboxModal(async_get_env_tool.name, {}, user=test_user).run(agent_state=agent_state)
    assert sandbox_level_value not in run_result.func_return
    assert agent_level_value in run_result.func_return
|
||||
@@ -11,6 +11,7 @@ filterwarnings =
|
||||
markers =
|
||||
local_sandbox: mark test as part of local sandbox tests
|
||||
e2b_sandbox: mark test as part of E2B sandbox tests
|
||||
modal_sandbox: mark test as part of Modal sandbox tests
|
||||
openai_basic: Tests for OpenAI endpoints
|
||||
anthropic_basic: Tests for Anthropic endpoints
|
||||
azure_basic: Tests for Azure endpoints
|
||||
|
||||
@@ -395,7 +395,7 @@ def test_function_always_error(client: Letta):
|
||||
assert response_message.status == "error"
|
||||
# TODO: add this back
|
||||
# assert "Error executing function testing_method" in response_message.tool_return, response_message.tool_return
|
||||
assert "ZeroDivisionError: division by zero" in response_message.stderr[0]
|
||||
assert "division by zero" in response_message.stderr[0]
|
||||
|
||||
client.agents.delete(agent_id=agent.id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user