From 5730f69ecff58c3d060f2a4545a25bd0e61f69c6 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 11 Nov 2025 18:21:51 -0800 Subject: [PATCH] feat: modal tool execution - NO FEATURE FLAGS USES MODAL [LET-4357] (#5120) * initial commit * add delay to deploy * fix tests * add tests * passing tests * cleanup * and use modal * working on modal * gate on tool metadata * agent state * cleanup --------- Co-authored-by: Letta Bot --- letta/constants.py | 18 + letta/helpers/tool_helpers.py | 69 ++ letta/schemas/sandbox_config.py | 5 +- letta/server/rest_api/routers/v1/tools.py | 1 + .../tool_executor/sandbox_tool_executor.py | 54 +- letta/services/tool_manager.py | 272 ++++++- letta/services/tool_sandbox/modal_sandbox.py | 491 +++-------- letta/settings.py | 12 +- tests/conftest.py | 9 + tests/integration_test_modal.py | 770 ++++++++++++++++++ tests/pytest.ini | 1 + tests/test_client.py | 2 +- 12 files changed, 1314 insertions(+), 390 deletions(-) create mode 100644 letta/helpers/tool_helpers.py create mode 100644 tests/integration_test_modal.py diff --git a/letta/constants.py b/letta/constants.py index 0e93f8ae..330edff3 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -44,6 +44,11 @@ IN_CONTEXT_MEMORY_KEYWORD = "CORE_MEMORY" # OpenAI error message: Invalid 'messages[1].tool_calls[0].id': string too long. Expected a string with maximum length 29, but got a string with length 36 instead. TOOL_CALL_ID_MAX_LEN = 29 +# Maximum length for tool names to support Modal deployment +# Modal function names are limited to 64 characters: tool_name + "_" + project_id +# Reserving 16 characters for project_id suffix (e.g., "_project-12345678") +MAX_TOOL_NAME_LENGTH = 48 + # Max steps for agent loop DEFAULT_MAX_STEPS = 50 @@ -440,5 +445,18 @@ EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES = ["claude-4-sonnet", "claude-3-5-so # But include models with these keywords in base tool rules (overrides exclusion) INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES = ["mini"] +# Deployment and versioning +MODAL_DEFAULT_TOOL_NAME = "modal_tool_wrapper..modal_function" # NOTE: must stay in sync with modal_tool_wrapper +MODAL_DEFAULT_CONFIG_KEY = "default" +MODAL_MODAL_DEPLOYMENTS_KEY = "modal_deployments" +MODAL_VERSION_HASH_LENGTH = 12 + +# Modal execution settings +MODAL_DEFAULT_TIMEOUT = 60 +MODAL_DEFAULT_MAX_CONCURRENT_INPUTS = 1 +MODAL_DEFAULT_PYTHON_VERSION = "3.12" + +# Security settings +MODAL_SAFE_IMPORT_MODULES = {"typing", "pydantic", "datetime", "enum", "uuid", "decimal"} # Default handle for model used to generate tools DEFAULT_GENERATE_TOOL_MODEL_HANDLE = "openai/gpt-4.1" diff --git a/letta/helpers/tool_helpers.py b/letta/helpers/tool_helpers.py new file mode 100644 index 00000000..2e554da0 --- /dev/null +++ b/letta/helpers/tool_helpers.py @@ -0,0 +1,69 @@ +import hashlib + +from letta.constants import MODAL_VERSION_HASH_LENGTH +from letta.schemas.tool import Tool + + +def _serialize_dependencies(tool: Tool) -> str: + """ + Serialize dependencies in a consistent way for hashing. + TODO: This should be improved per LET-3770 to ensure consistent ordering. + For now, we convert to string representation. + """ + parts = [] + + if tool.pip_requirements: + # TODO: Sort these consistently + parts.append(f"pip:{str(tool.pip_requirements)}") + + if tool.npm_requirements: + # TODO: Sort these consistently + parts.append(f"npm:{str(tool.npm_requirements)}") + + return ";".join(parts) + + +def compute_tool_hash(tool: Tool): + """ + Calculate a hash representing the current version of the tool and configuration. + This hash changes when: + - Tool source code changes + - Tool dependencies change + - Sandbox configuration changes + - Language/runtime changes + """ + components = [ + tool.source_code if tool.source_code else "", + tool.source_type if tool.source_type else "", + _serialize_dependencies(tool), + ] + + combined = "|".join(components) + return hashlib.sha256(combined.encode()).hexdigest()[:MODAL_VERSION_HASH_LENGTH] + + +def generate_modal_function_name(tool_name: str, organization_id: str, project_id: str = "default") -> str: + """ + Generate a Modal function name from tool name and project ID. + Shortens the project ID to just the prefix and first UUID segment. + + Args: + tool_name: Name of the tool + organization_id: Full organization ID (not used in function name, but kept for future use) + project_id: Project ID (e.g., project-12345678-90ab-cdef-1234-567890abcdef or "default") + + Returns: + Modal function name (e.g., tool_name_project-12345678 or tool_name_default) + """ + from letta.constants import MAX_TOOL_NAME_LENGTH + + max_tool_name_length = 64 + + # Shorten the organization_id to just the first segment (e.g., project-12345678) + short_organization_id = organization_id[: (max_tool_name_length - MAX_TOOL_NAME_LENGTH - 1)] + + # make extra sure the tool name is not too long + name = f"{tool_name[:MAX_TOOL_NAME_LENGTH]}_{short_organization_id}" + + # safe fallback + return name diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py index ef282770..306a9faf 100644 --- a/letta/schemas/sandbox_config.py +++ b/letta/schemas/sandbox_config.py @@ -4,12 +4,11 @@ from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field, model_validator -from letta.constants import LETTA_TOOL_EXECUTION_DIR +from letta.constants import LETTA_TOOL_EXECUTION_DIR, MODAL_DEFAULT_TIMEOUT from letta.schemas.agent import AgentState from letta.schemas.enums import PrimitiveType, SandboxType from letta.schemas.letta_base import LettaBase, OrmMetadataBase from letta.schemas.pip_requirement import PipRequirement -from letta.services.tool_sandbox.modal_constants import DEFAULT_MODAL_TIMEOUT from letta.settings import tool_settings # Sandbox Config @@ -81,7 +80,7 @@ class E2BSandboxConfig(BaseModel): class ModalSandboxConfig(BaseModel): - timeout: int = Field(DEFAULT_MODAL_TIMEOUT, description="Time limit for the sandbox (in seconds).") + timeout: int = Field(MODAL_DEFAULT_TIMEOUT, description="Time limit for the sandbox (in seconds).") pip_requirements: list[str] | None = Field(None, description="A list of pip packages to install in the Modal sandbox") npm_requirements: list[str] | None = Field(None, description="A list of npm packages to install in the Modal sandbox") language: Literal["python", "typescript"] = "python" diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 894ba6aa..c6cf0d87 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -7,6 +7,7 @@ from httpx import ConnectError, HTTPStatusError from pydantic import BaseModel, Field from starlette.responses import StreamingResponse +from letta.constants import MAX_TOOL_NAME_LENGTH from letta.constants import DEFAULT_GENERATE_TOOL_MODEL_HANDLE from letta.errors import ( LettaInvalidArgumentError, diff --git a/letta/services/tool_executor/sandbox_tool_executor.py b/letta/services/tool_executor/sandbox_tool_executor.py index 4415bba3..64405bab 100644 --- a/letta/services/tool_executor/sandbox_tool_executor.py +++ b/letta/services/tool_executor/sandbox_tool_executor.py @@ -46,28 +46,41 @@ class SandboxToolExecutor(ToolExecutor): agent_state_copy = self._create_agent_state_copy(agent_state) if agent_state else None - # Execute in sandbox depending on API key - if tool_settings.sandbox_type == SandboxType.E2B: - from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B + # Execute in sandbox with Modal first (if configured and requested), then fallback to E2B/LOCAL + # Try Modal if: (1) Modal credentials configured AND (2) tool requests Modal via metadata + tool_requests_modal = tool.metadata_ and tool.metadata_.get("sandbox") == "modal" + modal_configured = tool_settings.modal_sandbox_enabled - sandbox = AsyncToolSandboxE2B( - function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars - ) - # TODO (cliandy): this is just for testing right now, separate this out into it's own subclass and handling logic - elif tool_settings.sandbox_type == SandboxType.MODAL: - from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal, TypescriptToolSandboxModal + tool_execution_result = None - if tool.source_type == ToolSourceType.typescript: - sandbox = TypescriptToolSandboxModal( + # Try Modal first if both conditions met + if tool_requests_modal and modal_configured: + try: + from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal + + logger.info(f"Attempting Modal execution for tool {tool.name}") + sandbox = AsyncToolSandboxModal( function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars, + organization_id=actor.organization_id, ) - elif tool.source_type == ToolSourceType.python: - sandbox = AsyncToolSandboxModal( + # TODO: pass through letta api key + tool_execution_result = await sandbox.run(agent_state=agent_state_copy, additional_env_vars=sandbox_env_vars) + except Exception as e: + # Modal execution failed, log and fall back to E2B/LOCAL + logger.warning(f"Modal execution failed for tool {tool.name}: {e}. Falling back to {tool_settings.sandbox_type.value}") + tool_execution_result = None + + # Fallback to E2B or LOCAL if Modal wasn't tried or failed + if tool_execution_result is None: + if tool_settings.sandbox_type == SandboxType.E2B: + from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B + + sandbox = AsyncToolSandboxE2B( function_name, function_args, actor, @@ -76,13 +89,16 @@ class SandboxToolExecutor(ToolExecutor): sandbox_env_vars=sandbox_env_vars, ) else: - raise ValueError(f"Tool source type was {tool.source_type} but is required to be python or typescript to run in Modal.") - else: - sandbox = AsyncToolSandboxLocal( - function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars - ) + sandbox = AsyncToolSandboxLocal( + function_name, + function_args, + actor, + tool_object=tool, + sandbox_config=sandbox_config, + sandbox_env_vars=sandbox_env_vars, + ) - tool_execution_result = await sandbox.run(agent_state=agent_state_copy) + tool_execution_result = await sandbox.run(agent_state=agent_state_copy) log_lines = (tool_execution_result.stdout or []) + (tool_execution_result.stderr or []) logger.debug("Tool execution log: %s", "\n".join(log_lines)) diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 33a2e4e4..03fbf28f 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -16,30 +16,176 @@ from letta.constants import ( LETTA_TOOL_MODULE_NAMES, LETTA_TOOL_SET, LOCAL_ONLY_MULTI_AGENT_TOOLS, + MAX_TOOL_NAME_LENGTH, MCP_TOOL_TAG_NAME_PREFIX, + MODAL_DEFAULT_TOOL_NAME, ) -from letta.errors import LettaToolNameConflictError, LettaToolNameSchemaMismatchError +from letta.errors import LettaInvalidArgumentError, LettaToolNameConflictError, LettaToolNameSchemaMismatchError from letta.functions.functions import derive_openai_json_schema, load_function_set +from letta.helpers.tool_helpers import compute_tool_hash, generate_modal_function_name from letta.log import get_logger # TODO: Remove this once we translate all of these to the ORM from letta.orm.errors import NoResultFound from letta.orm.tool import Tool as ToolModel from letta.otel.tracing import trace_method -from letta.schemas.enums import PrimitiveType, ToolType +from letta.schemas.agent import AgentState +from letta.schemas.enums import PrimitiveType, SandboxType, ToolType from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.helpers.agent_manager_helper import calculate_multi_agent_tools from letta.services.mcp.types import SSEServerConfig, StdioServerConfig from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update -from letta.settings import settings +from letta.settings import settings, tool_settings from letta.utils import enforce_types, printd from letta.validators import raise_on_invalid_id logger = get_logger(__name__) +# NOTE: function name and nested modal function decorator name must stay in sync with MODAL_DEFAULT_TOOL_NAME +def modal_tool_wrapper(tool: PydanticTool, actor: PydanticUser, sandbox_env_vars: dict = None, project_id: str = "default"): + """Create a Modal function wrapper for a tool""" + import contextlib + import io + import os + import sys + from typing import Optional + + import modal + from letta_client import Letta + + packages = [str(req) for req in tool.pip_requirements] if tool.pip_requirements else [] + packages.append("letta_client") + packages.append("letta") # Base letta without extras + packages.append("asyncpg>=0.30.0") # Fixes asyncpg import error + packages.append("psycopg2-binary>=2.9.10") # PostgreSQL adapter (pre-compiled, no build required) + # packages.append("pgvector>=0.3.6") # Vector operations support + + function_name = generate_modal_function_name(tool.name, actor.organization_id, project_id) + modal_app = modal.App(function_name) + logger.info(f"Creating Modal app {tool.id} with name {function_name}") + + # Create secrets dict with sandbox env vars + secrets_dict = {"LETTA_API_KEY": None} + if sandbox_env_vars: + secrets_dict.update(sandbox_env_vars) + + @modal_app.function( + image=modal.Image.debian_slim(python_version="3.13").pip_install(packages), + restrict_modal_access=True, + timeout=10, + secrets=[modal.Secret.from_dict(secrets_dict)], + serialized=True, + ) + def modal_function( + tool_name: str, agent_state: Optional[dict], agent_id: Optional[str], env_vars: dict, letta_api_key: Optional[str] = None, **kwargs + ): + """Wrapper function for running untrusted code in a Modal function""" + # Reconstruct AgentState from dict if passed (to avoid cloudpickle serialization issues) + # This is done with extra safety to handle schema mismatches between environments + reconstructed_agent_state = None + if agent_state: + try: + from letta.schemas.agent import AgentState as AgentStateModel + + # Filter dict to only include fields that exist in Modal's version of AgentState + # This prevents ValidationError from extra fields in newer schemas + modal_agent_fields = set(AgentStateModel.model_fields.keys()) + filtered_agent_state = {key: value for key, value in agent_state.items() if key in modal_agent_fields} + + # Try to reconstruct with filtered data + reconstructed_agent_state = AgentStateModel.model_validate(filtered_agent_state) + + # Log if we filtered out any fields + filtered_out = set(agent_state.keys()) - modal_agent_fields + if filtered_out: + print(f"Fields not in available in AgentState: {filtered_out}", file=sys.stderr) + + except ImportError as e: + print(f"Cannot import AgentState: {e}", file=sys.stderr) + print("Passing agent_state as dict to tool", file=sys.stderr) + reconstructed_agent_state = agent_state + except Exception as e: + print(f"Warning: Could not reconstruct AgentState (schema mismatch?): {e}", file=sys.stderr) + print("Passing agent_state as dict to tool", file=sys.stderr) + reconstructed_agent_state = agent_state + + # Set environment variables + if env_vars: + for key, value in env_vars.items(): + os.environ[key] = str(value) + + # Initialize the Letta client + if letta_api_key: + letta_client = Letta(token=letta_api_key, base_url=os.environ.get("LETTA_API_URL", "https://api.letta.com")) + else: + letta_client = None + + tool_namespace = { + "__builtins__": __builtins__, # Include built-in functions + "_letta_client": letta_client, # Make letta_client available + "os": os, # Include os module for env vars access + "agent_id": agent_id, + # Add any other modules/variables the tool might need + } + + # Initialize the tool code + # Create a namespace for the tool + # tool_namespace = {} + exec(tool.source_code, tool_namespace) + + # Get the tool function + if tool_name not in tool_namespace: + raise Exception(f"Tool function {tool_name} not found in {tool.source_code}, globals: {tool_namespace}") + tool_func = tool_namespace[tool_name] + + # Detect if the tool function is async + import asyncio + import inspect + import traceback + + is_async = inspect.iscoroutinefunction(tool_func) + + # Capture stdout and stderr during tool execution + stdout_capture = io.StringIO() + stderr_capture = io.StringIO() + result = None + error_occurred = False + + with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture): + try: + # if `agent_state` is in the tool function arguments, inject it + # Pass reconstructed AgentState (or dict if reconstruction failed) + if "agent_state" in tool_func.__code__.co_varnames: + kwargs["agent_state"] = reconstructed_agent_state + + # Execute the tool function (async or sync) + if is_async: + result = asyncio.run(tool_func(**kwargs)) + else: + result = tool_func(**kwargs) + except Exception as e: + # Capture the exception and write to stderr + error_occurred = True + traceback.print_exc(file=stderr_capture) + + # Get captured output + stdout = stdout_capture.getvalue() + stderr = stderr_capture.getvalue() + + return { + "result": result, + "stdout": stdout, + "stderr": stderr, + "agent_state": agent_state, # TODO: deprecate (use letta_client instead) + "error": error_occurred or bool(stderr), + } + + return modal_app + + class ToolManager: """Manager class to handle business logic related to Tools.""" @@ -56,12 +202,16 @@ class ToolManager: else: raise ValueError("Failed to generate schema for tool", pydantic_tool.source_code) - print("SCHEMA", pydantic_tool.json_schema) - # make sure the name matches the json_schema if not pydantic_tool.name: pydantic_tool.name = pydantic_tool.json_schema.get("name") else: + # if name is provided, make sure its less tahn the MAX_TOOL_NAME_LENGTH + if len(pydantic_tool.name) > MAX_TOOL_NAME_LENGTH: + raise LettaInvalidArgumentError( + f"Tool name {pydantic_tool.name} is too long. It must be less than {MAX_TOOL_NAME_LENGTH} characters." + ) + if pydantic_tool.name != pydantic_tool.json_schema.get("name"): raise LettaToolNameSchemaMismatchError( tool_name=pydantic_tool.name, @@ -136,6 +286,13 @@ class ToolManager: # Auto-generate description if not provided if pydantic_tool.description is None and pydantic_tool.json_schema: pydantic_tool.description = pydantic_tool.json_schema.get("description", None) + + # Add tool hash to metadata for Modal deployment tracking + tool_hash = compute_tool_hash(pydantic_tool) + if pydantic_tool.metadata_ is None: + pydantic_tool.metadata_ = {} + pydantic_tool.metadata_["tool_hash"] = tool_hash + tool_data = pydantic_tool.model_dump(to_orm=True) # Set the organization id at the ORM layer tool_data["organization_id"] = actor.organization_id @@ -153,7 +310,17 @@ class ToolManager: ) await tool.create_async(session, actor=actor) # Re-raise other database-related errors - return tool.to_pydantic() + created_tool = tool.to_pydantic() + + # Deploy Modal app for the new tool + # Both Modal credentials configured AND tool metadata must indicate Modal + tool_requests_modal = created_tool.metadata_ and created_tool.metadata_.get("sandbox") == "modal" + modal_configured = tool_settings.modal_sandbox_enabled + + if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured: + await self.create_or_update_modal_app(created_tool, actor) + + return created_tool @enforce_types @trace_method @@ -599,6 +766,33 @@ class ToolManager: source_code=update_data.get("source_code"), ) + # Create a preview of the updated tool by merging current tool with updates + # This allows us to compute the hash before the database session + updated_tool_pydantic = current_tool.model_copy(deep=True) + for key, value in update_data.items(): + setattr(updated_tool_pydantic, key, value) + if new_schema is not None: + updated_tool_pydantic.json_schema = new_schema + updated_tool_pydantic.name = new_name + if updated_tool_type: + updated_tool_pydantic.tool_type = updated_tool_type + + # Check if we need to redeploy the Modal app due to changes + # Compute this before the session to avoid issues + tool_requests_modal = updated_tool_pydantic.metadata_ and updated_tool_pydantic.metadata_.get("sandbox") == "modal" + modal_configured = tool_settings.modal_sandbox_enabled + should_check_modal = tool_requests_modal and modal_configured and updated_tool_pydantic.tool_type == ToolType.CUSTOM + + # Compute hash before session if needed + new_hash = None + old_hash = None + needs_modal_deployment = False + + if should_check_modal: + new_hash = compute_tool_hash(updated_tool_pydantic) + old_hash = current_tool.metadata_.get("tool_hash") if current_tool.metadata_ else None + needs_modal_deployment = new_hash != old_hash + # Now perform the update within the session async with db_registry.async_session() as session: # Fetch the tool by ID @@ -618,7 +812,25 @@ class ToolManager: # Save the updated tool to the database tool = await tool.update_async(db_session=session, actor=actor) - return tool.to_pydantic() + updated_tool = tool.to_pydantic() + + # Update Modal hash in metadata if needed (inside session context) + if needs_modal_deployment: + if updated_tool.metadata_ is None: + updated_tool.metadata_ = {} + updated_tool.metadata_["tool_hash"] = new_hash + + # Update the metadata in the database (still inside session) + tool.metadata_ = updated_tool.metadata_ + tool = await tool.update_async(db_session=session, actor=actor) + updated_tool = tool.to_pydantic() + + # Deploy Modal app outside of session (it creates its own sessions) + if needs_modal_deployment: + logger.info(f"Deploying Modal app for tool {updated_tool.id} with new hash: {new_hash}") + await self.create_or_update_modal_app(updated_tool, actor) + + return updated_tool @enforce_types @trace_method @@ -783,3 +995,49 @@ class ToolManager: created_tool = await self.create_tool_async(tool, actor=actor) tools.append(created_tool) return tools + + # MODAL RELATED METHODS + @trace_method + async def create_or_update_modal_app(self, tool: PydanticTool, actor: PydanticUser): + """Create a Modal app with the tool function registered""" + import time + + import modal + + from letta.services.sandbox_config_manager import SandboxConfigManager + + # Load sandbox env vars to bake them into the Modal secrets + sandbox_env_vars = {} + try: + sandbox_config_manager = SandboxConfigManager() + sandbox_config = await sandbox_config_manager.get_or_create_default_sandbox_config_async( + sandbox_type=SandboxType.MODAL, actor=actor + ) + if sandbox_config: + sandbox_env_vars = await sandbox_config_manager.get_sandbox_env_vars_as_dict_async( + sandbox_config_id=sandbox_config.id, actor=actor, limit=None + ) + logger.info(f"Loaded {len(sandbox_env_vars)} sandbox env vars for Modal app {tool.id}") + except Exception as e: + logger.warning(f"Could not load sandbox env vars for Modal app {tool.id}: {e}") + + # Create the Modal app using the global function with sandbox env vars + modal_app = modal_tool_wrapper(tool, actor, sandbox_env_vars) + + # Deploy the app first + with modal.enable_output(): + try: + deploy = modal_app.deploy() + except Exception as e: + raise LettaInvalidArgumentError(f"Failed to deploy tool {tool.id} with name {tool.name} to Modal: {e}") + + # After deployment, look up the function to configure autoscaler + try: + func = modal.Function.from_name(modal_app.name, MODAL_DEFAULT_TOOL_NAME) + func.update_autoscaler(scaledown_window=2) # drain inactive old containers + time.sleep(5) + func.update_autoscaler(scaledown_window=60) + except Exception as e: + logger.warning(f"Failed to configure autoscaler for Modal function {modal_app.name}: {e}") + + return deploy diff --git a/letta/services/tool_sandbox/modal_sandbox.py b/letta/services/tool_sandbox/modal_sandbox.py index aa736715..1c2acd33 100644 --- a/letta/services/tool_sandbox/modal_sandbox.py +++ b/letta/services/tool_sandbox/modal_sandbox.py @@ -1,7 +1,14 @@ -from typing import Any, Dict, Optional +""" +Model sandbox implementation, which configures on Modal App per tool. +""" + +from typing import TYPE_CHECKING, Any, Dict, Optional import modal +from e2b.sandbox.commands.command_handle import CommandExitException +from e2b_code_interpreter import AsyncSandbox +from letta.constants import MODAL_DEFAULT_TOOL_NAME from letta.log import get_logger from letta.otel.tracing import log_event, trace_method from letta.schemas.agent import AgentState @@ -9,412 +16,180 @@ from letta.schemas.enums import SandboxType from letta.schemas.sandbox_config import SandboxConfig from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult -from letta.services.helpers.tool_parser_helper import parse_stdout_best_effort +from letta.services.helpers.tool_parser_helper import parse_function_arguments, parse_stdout_best_effort +from letta.services.tool_manager import ToolManager from letta.services.tool_sandbox.base import AsyncToolSandboxBase -from letta.settings import tool_settings from letta.types import JsonDict from letta.utils import get_friendly_error_msg logger = get_logger(__name__) -# class AsyncToolSandboxModalBase(AsyncToolSandboxBase): -# pass +if TYPE_CHECKING: + from e2b_code_interpreter import Execution class AsyncToolSandboxModal(AsyncToolSandboxBase): + METADATA_CONFIG_STATE_KEY = "config_state" + def __init__( self, tool_name: str, args: JsonDict, user, - tool_object: Tool | None = None, - sandbox_config: SandboxConfig | None = None, - sandbox_env_vars: dict[str, Any] | None = None, + force_recreate: bool = True, + tool_object: Optional[Tool] = None, + sandbox_config: Optional[SandboxConfig] = None, + sandbox_env_vars: Optional[Dict[str, Any]] = None, + organization_id: Optional[str] = None, + project_id: str = "default", ): super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars) + self.force_recreate = force_recreate + # Get organization_id from user if not explicitly provided + self.organization_id = organization_id if organization_id is not None else user.organization_id + self.project_id = project_id - if not tool_settings.modal_token_id or not tool_settings.modal_token_secret: - raise ValueError("MODAL_TOKEN_ID and MODAL_TOKEN_SECRET must be set.") + # TODO: check to make sure modal app `App(tool.id)` exists - # Create a unique app name based on tool and config - self._app_name = self._generate_app_name() + async def _wait_for_modal_function_deployment(self, timeout: int = 60): + """Wait for Modal app deployment to complete by retrying function lookup.""" + import asyncio + import time - def _generate_app_name(self) -> str: - """Generate a unique app name based on tool and configuration. Created based on tool name and org""" - return f"{self.user.organization_id}-{self.tool_name}" + import modal - async def _fetch_or_create_modal_app(self, sbx_config: SandboxConfig, env_vars: Dict[str, str]) -> modal.App: - """Create a Modal app with the tool function registered.""" - try: - app = await modal.App.lookup.aio(self._app_name) - return app - except: - app = modal.App(self._app_name) + from letta.helpers.tool_helpers import generate_modal_function_name - modal_config = sbx_config.get_modal_config() + # Use the same naming logic as deployment + function_name = generate_modal_function_name(self.tool.name, self.organization_id, self.project_id) - # Get the base image with dependencies - image = self._get_modal_image(sbx_config) + start_time = time.time() + retry_delay = 2 # seconds - # Decorator for the tool, note information on running untrusted code: https://modal.com/docs/guide/restricted-access - # The `@app.function` decorator must apply to functions in global scope, unless `serialized=True` is set. - @app.function(image=image, timeout=modal_config.timeout, restrict_modal_access=True, max_inputs=1, serialized=True) - def execute_tool_with_script(execution_script: str, environment_vars: dict[str, str]): - """Execute the generated tool script in Modal sandbox.""" - import os + while time.time() - start_time < timeout: + try: + f = modal.Function.from_name(function_name, MODAL_DEFAULT_TOOL_NAME) + logger.info(f"Modal function found successfully for app {function_name}, function {f}") + return f + except Exception as e: + elapsed = time.time() - start_time + if elapsed >= timeout: + raise TimeoutError( + f"Modal app {function_name} deployment timed out after {timeout} seconds. " + f"Expected app name: {function_name}, function: {MODAL_DEFAULT_TOOL_NAME}" + ) from e + logger.info(f"Modal app {function_name} not ready yet (elapsed: {elapsed:.1f}s), waiting {retry_delay}s...") + await asyncio.sleep(retry_delay) - # Note: We pass environment variables directly instead of relying on Modal secrets - # This is more flexible and doesn't require pre-configured secrets - for key, value in environment_vars.items(): - os.environ[key] = str(value) - - exec_globals = {} - exec(execution_script, exec_globals) - - # Store the function reference in the app for later use - app.remote_executor = execute_tool_with_script - return app + raise TimeoutError(f"Modal app {function_name} deployment timed out after {timeout} seconds") @trace_method async def run( self, + agent_id: Optional[str] = None, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None, ) -> ToolExecutionResult: - if self.provided_sandbox_config: - sbx_config = self.provided_sandbox_config - else: - sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async( - sandbox_type=SandboxType.MODAL, actor=self.user - ) - - envs = await self._gather_env_vars(agent_state, additional_env_vars or {}, sbx_config.id, is_local=False) - - # Generate execution script (this includes the tool source code and execution logic) - execution_script = await self.generate_execution_script(agent_state=agent_state) + await self._init_async() try: - log_event( - "modal_execution_started", - {"tool": self.tool_name, "app_name": self._app_name, "env_vars": list(envs)}, - ) + log_event("modal_execution_started", {"tool": self.tool_name, "modal_app_id": self.tool.id}) + logger.info(f"Waiting for Modal function deployment for app {self.tool.id}") + func = await self._wait_for_modal_function_deployment() + logger.info(f"Modal function found successfully for app {self.tool.id}, function {str(func)}") + logger.info(f"Calling with arguments {self.args}") - # Create Modal app with the tool function registered - app = await self._fetch_or_create_modal_app(sbx_config, envs) + # TODO: use another mechanism to pass through the key + if additional_env_vars is None: + letta_api_key = None + else: + letta_api_key = additional_env_vars.get("LETTA_API_KEY", None) - # Execute the tool remotely - with app.run(): - # app = modal.Cls.from_name(app.name, "NodeShimServer")() - result = app.remote_executor.remote(execution_script, envs) + # Construct dynamic env vars + # Priority order (later overrides earlier): + # 1. Sandbox-level env vars (from database) + # 2. Agent-specific env vars + # 3. Additional runtime env vars + env_vars = {} - # Process the result - if result["error"]: - # Tool errors are expected behavior - tools can raise exceptions as part of their normal operation - # Only log at debug level to avoid triggering Sentry alerts for expected errors - logger.debug(f"Tool {self.tool_name} raised a {result['error']['name']}: {result['error']['value']}") - logger.debug(f"Traceback from Modal sandbox: \n{result['error']['traceback']}") - func_return = get_friendly_error_msg( - function_name=self.tool_name, exception_name=result["error"]["name"], exception_message=result["error"]["value"] - ) - log_event( - "modal_execution_failed", - { - "tool": self.tool_name, - "app_name": self._app_name, - "error_type": result["error"]["name"], - "error_message": result["error"]["value"], - "func_return": func_return, - }, - ) - # Parse the result from stdout even if there was an error - # (in case the function returned something before failing) - agent_state = None # Initialize agent_state + # Load sandbox-level environment variables from the database + # These can be updated after deployment and will be available at runtime + if self.provided_sandbox_env_vars: + env_vars.update(self.provided_sandbox_env_vars) + else: try: - func_return_parsed, agent_state_parsed = parse_stdout_best_effort(result["stdout"]) - if func_return_parsed is not None: - func_return = func_return_parsed - agent_state = agent_state_parsed - except Exception: - # If parsing fails, keep the error message - pass - else: - func_return, agent_state = parse_stdout_best_effort(result["stdout"]) - log_event( - "modal_execution_succeeded", - { - "tool": self.tool_name, - "app_name": self._app_name, - "func_return": func_return, - }, - ) + from letta.services.sandbox_config_manager import SandboxConfigManager + + sandbox_config_manager = SandboxConfigManager() + sandbox_config = await sandbox_config_manager.get_or_create_default_sandbox_config_async( + sandbox_type=SandboxType.MODAL, actor=self.user + ) + if sandbox_config: + sandbox_env_vars = await sandbox_config_manager.get_sandbox_env_vars_as_dict_async( + sandbox_config_id=sandbox_config.id, actor=self.user, limit=None + ) + env_vars.update(sandbox_env_vars) + except Exception as e: + logger.warning(f"Could not load sandbox env vars for tool {self.tool_name}: {e}") + + # Add agent-specific environment variables (these override sandbox-level) + if agent_state and agent_state.secrets: + for secret in agent_state.secrets: + env_vars[secret.key] = secret.value + + # Add any additional env vars passed at runtime (highest priority) + if additional_env_vars: + env_vars.update(additional_env_vars) + + # Call the modal function (already retrieved at line 101) + # Convert agent_state to dict to avoid cloudpickle serialization issues + agent_state_dict = agent_state.model_dump() if agent_state else None + + logger.info(f"Calling function {func} with arguments {self.args}") + result = await func.remote.aio( + tool_name=self.tool_name, + agent_state=agent_state_dict, + agent_id=agent_id, + env_vars=env_vars, + letta_api_key=letta_api_key, + **self.args, + ) + logger.info(f"Modal function result: {result}") + + # Reconstruct agent_state if it was returned (use original as fallback) + result_agent_state = agent_state + if result.get("agent_state"): + if isinstance(result["agent_state"], dict): + try: + from letta.schemas.agent import AgentState + + result_agent_state = AgentState.model_validate(result["agent_state"]) + except Exception as e: + logger.warning(f"Failed to reconstruct AgentState: {e}, using original") + else: + result_agent_state = result["agent_state"] return ToolExecutionResult( - func_return=func_return, - agent_state=agent_state, - stdout=[result["stdout"]] if result["stdout"] else [], - stderr=[result["stderr"]] if result["stderr"] else [], + func_return=result["result"], + agent_state=result_agent_state, + stdout=[result["stdout"]], + stderr=[result["stderr"]], status="error" if result["error"] else "success", - sandbox_config_fingerprint=sbx_config.fingerprint(), ) - except Exception as e: - logger.error(f"Modal execution for tool {self.tool_name} encountered an error: {e}") - func_return = get_friendly_error_msg( - function_name=self.tool_name, - exception_name=type(e).__name__, - exception_message=str(e), - ) log_event( - "modal_execution_error", + "modal_execution_failed", { "tool": self.tool_name, - "app_name": self._app_name, + "modal_app_id": self.tool.id, "error": str(e), - "func_return": func_return, }, ) + logger.error(f"Modal execution failed for tool {self.tool_name} {self.tool.id}: {e}") return ToolExecutionResult( - func_return=func_return, - agent_state=None, - stdout=[], + func_return=None, + agent_state=agent_state, + stdout=[""], stderr=[str(e)], status="error", - sandbox_config_fingerprint=sbx_config.fingerprint(), ) - - def _get_modal_image(self, sbx_config: SandboxConfig) -> modal.Image: - """Get Modal image with required public python dependencies. - - Caching and rebuilding is handled in a cascading manner - https://modal.com/docs/guide/images#image-caching-and-rebuilds - """ - image = modal.Image.debian_slim(python_version="3.12") - - all_requirements = ["letta"] - - # Add sandbox-specific pip requirements - modal_configs = sbx_config.get_modal_config() - if modal_configs.pip_requirements: - all_requirements.extend([str(req) for req in modal_configs.pip_requirements]) - - # Add tool-specific pip requirements - if self.tool and self.tool.pip_requirements: - all_requirements.extend([str(req) for req in self.tool.pip_requirements]) - - if all_requirements: - image = image.pip_install(*all_requirements) - - return image - - def use_top_level_await(self) -> bool: - """ - Modal functions don't have an active event loop by default, - so we should use asyncio.run() like local execution. - """ - return False - - -class TypescriptToolSandboxModal(AsyncToolSandboxModal): - """Modal sandbox implementation for TypeScript tools.""" - - @trace_method - async def run( - self, - agent_state: Optional[AgentState] = None, - additional_env_vars: Optional[Dict] = None, - ) -> ToolExecutionResult: - """Run TypeScript tool in Modal sandbox using Node.js server.""" - if self.provided_sandbox_config: - sbx_config = self.provided_sandbox_config - else: - sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async( - sandbox_type=SandboxType.MODAL, actor=self.user - ) - - envs = await self._gather_env_vars(agent_state, additional_env_vars or {}, sbx_config.id, is_local=False) - - # Generate execution script (JSON args for TypeScript) - json_args = await self.generate_execution_script(agent_state=agent_state) - - try: - log_event( - "modal_typescript_execution_started", - {"tool": self.tool_name, "app_name": self._app_name, "args": json_args}, - ) - - # Create Modal app with the TypeScript Node.js server - app = await self._fetch_or_create_modal_app(sbx_config, envs) - - # Execute the TypeScript tool remotely via the Node.js server - with app.run(): - # Get the NodeShimServer class from Modal - node_server = modal.Cls.from_name(self._app_name, "NodeShimServer") - - # Call the remote_executor method with the JSON arguments - # The server will parse the JSON and call the TypeScript function - result = node_server().remote_executor.remote(json_args) - - # Process the TypeScript execution result - if isinstance(result, dict) and "error" in result: - # Handle errors from TypeScript execution - logger.debug(f"TypeScript tool {self.tool_name} raised an error: {result['error']}") - func_return = get_friendly_error_msg( - function_name=self.tool_name, - exception_name="TypeScriptError", - exception_message=str(result["error"]), - ) - log_event( - "modal_typescript_execution_failed", - { - "tool": self.tool_name, - "app_name": self._app_name, - "error": result["error"], - "func_return": func_return, - }, - ) - return ToolExecutionResult( - func_return=func_return, - agent_state=None, # TypeScript tools don't support agent_state yet - stdout=[], - stderr=[str(result["error"])], - status="error", - sandbox_config_fingerprint=sbx_config.fingerprint(), - ) - else: - # Success case - TypeScript function returned a result - func_return = str(result) if result is not None else "" - log_event( - "modal_typescript_execution_succeeded", - { - "tool": self.tool_name, - "app_name": self._app_name, - "func_return": func_return, - }, - ) - return ToolExecutionResult( - func_return=func_return, - agent_state=None, # TypeScript tools don't support agent_state yet - stdout=[], - stderr=[], - status="success", - sandbox_config_fingerprint=sbx_config.fingerprint(), - ) - - except Exception as e: - logger.error(f"Modal TypeScript execution for tool {self.tool_name} encountered an error: {e}") - func_return = get_friendly_error_msg( - function_name=self.tool_name, - exception_name=type(e).__name__, - exception_message=str(e), - ) - log_event( - "modal_typescript_execution_error", - { - "tool": self.tool_name, - "app_name": self._app_name, - "error": str(e), - "func_return": func_return, - }, - ) - return ToolExecutionResult( - func_return=func_return, - agent_state=None, - stdout=[], - stderr=[str(e)], - status="error", - sandbox_config_fingerprint=sbx_config.fingerprint(), - ) - - async def _fetch_or_create_modal_app(self, sbx_config: SandboxConfig, env_vars: Dict[str, str]) -> modal.App: - """Create or fetch a Modal app with TypeScript execution capabilities.""" - try: - return await modal.App.lookup.aio(self._app_name) - except: - app = modal.App(self._app_name) - - modal_config = sbx_config.get_modal_config() - - # Get the base image with dependencies - image = self._get_modal_image(sbx_config) - - # Import the NodeShimServer that will handle TypeScript execution - from sandbox.node_server import NodeShimServer - - # Register the NodeShimServer class with Modal - # This creates a serverless function that can handle concurrent requests - app.cls(image=image, restrict_modal_access=True, include_source=False, timeout=modal_config.timeout if modal_config else 60)( - modal.concurrent(max_inputs=100, target_inputs=50)(NodeShimServer) - ) - - # Deploy the app to Modal - with modal.enable_output(): - await app.deploy.aio() - - return app - - async def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str: - """Generate the execution script for TypeScript tools. - - For TypeScript tools, this returns the JSON-encoded arguments that will be passed - to the Node.js server via the remote_executor method. - """ - import json - - # Convert args to JSON string for TypeScript execution - # The Node.js server expects JSON-encoded arguments - return json.dumps(self.args) - - def _get_modal_image(self, sbx_config: SandboxConfig) -> modal.Image: - """Build a Modal image with Node.js, TypeScript, and the user's tool function.""" - import importlib.util - from pathlib import Path - - # Find the sandbox module location - spec = importlib.util.find_spec("sandbox") - if not spec or not spec.origin: - raise ValueError("Could not find sandbox module") - server_dir = Path(spec.origin).parent - - # Get the TypeScript function source code - if not self.tool or not self.tool.source_code: - raise ValueError("TypeScript tool must have source code") - - ts_function = self.tool.source_code - - # Get npm dependencies from sandbox config and tool - modal_config = sbx_config.get_modal_config() - npm_dependencies = [] - - # Add dependencies from sandbox config - if modal_config and modal_config.npm_requirements: - npm_dependencies.extend(modal_config.npm_requirements) - - # Add dependencies from the tool itself - if self.tool.npm_requirements: - npm_dependencies.extend(self.tool.npm_requirements) - - # Build npm install command for user dependencies - user_dependencies_cmd = "" - if npm_dependencies: - # Ensure unique dependencies - unique_deps = list(set(npm_dependencies)) - user_dependencies_cmd = " && npm install " + " ".join(unique_deps) - - # Escape single quotes in the TypeScript function for shell command - escaped_ts_function = ts_function.replace("'", "'\\''") - - # Build the Docker image with Node.js and TypeScript - image = ( - modal.Image.from_registry("node:22-slim", add_python="3.12") - .add_local_dir(server_dir, "/root/sandbox", ignore=["node_modules", "build"], copy=True) - .run_commands( - # Install dependencies and build the TypeScript server - f"cd /root/sandbox/resources/server && npm install{user_dependencies_cmd}", - # Write the user's TypeScript function to a file - f"echo '{escaped_ts_function}' > /root/sandbox/user-function.ts", - ) - ) - return image - - -# probably need to do parse_stdout_best_effort diff --git a/letta/settings.py b/letta/settings.py index e5061d2a..56f369b8 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -39,12 +39,20 @@ class ToolSettings(BaseSettings): mcp_read_from_config: bool = False # if False, will throw if attempting to read/write from file mcp_disable_stdio: bool = False + @property + def modal_sandbox_enabled(self) -> bool: + """Check if Modal credentials are configured.""" + return bool(self.modal_token_id and self.modal_token_secret) + @property def sandbox_type(self) -> SandboxType: + """Default sandbox type based on available credentials. + + Note: Modal is checked separately via modal_sandbox_enabled property. + This property determines the fallback behavior (E2B or LOCAL). + """ if self.e2b_api_key: return SandboxType.E2B - # elif self.modal_token_id and self.modal_token_secret: - # return SandboxType.MODAL else: return SandboxType.LOCAL diff --git a/tests/conftest.py b/tests/conftest.py index 9f77a917..82e9ff8e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,6 +169,15 @@ def check_e2b_key_is_set(): yield +@pytest.fixture +def check_modal_key_is_set(): + from letta.settings import tool_settings + + assert tool_settings.modal_token_id is not None, "Missing modal token id! Cannot execute these tests." + assert tool_settings.modal_token_secret is not None, "Missing modal token secret! Cannot execute these tests." + yield + + @pytest.fixture async def default_organization(): """Fixture to create and return the default organization.""" diff --git a/tests/integration_test_modal.py b/tests/integration_test_modal.py new file mode 100644 index 00000000..fab27b92 --- /dev/null +++ b/tests/integration_test_modal.py @@ -0,0 +1,770 @@ +import os +import secrets +import string +import uuid +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import delete + +from letta.config import LettaConfig +from letta.functions.function_sets.base import core_memory_append, core_memory_replace +from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable +from letta.schemas.agent import AgentState, CreateAgent +from letta.schemas.block import CreateBlock +from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate +from letta.schemas.organization import Organization +from letta.schemas.pip_requirement import PipRequirement +from letta.schemas.sandbox_config import LocalSandboxConfig, ModalSandboxConfig, SandboxConfigCreate +from letta.schemas.user import User +from letta.server.db import db_registry +from letta.server.server import SyncServer +from letta.services.organization_manager import OrganizationManager +from letta.services.sandbox_config_manager import SandboxConfigManager +from letta.services.tool_manager import ToolManager +from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal +from letta.services.user_manager import UserManager +from tests.helpers.utils import create_tool_from_func + +# Constants +namespace = uuid.NAMESPACE_DNS +org_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-org")) +user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user")) + +# Set environment variable immediately to prevent pooling issues +os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true" + + +# Disable SQLAlchemy connection pooling for tests to prevent event loop issues +@pytest.fixture(scope="session", autouse=True) +def disable_db_pooling_for_tests(): + """Disable database connection pooling for the entire test session.""" + # Environment variable is already set above and settings reloaded + yield + # Clean up environment variable after tests + if "LETTA_DISABLE_SQLALCHEMY_POOLING" in os.environ: + del os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] + + +# @pytest.fixture(autouse=True) +# async def cleanup_db_connections(): +# """Cleanup database connections after each test.""" +# yield +# +# # Dispose async engines in the current event loop +# try: +# await close_db() +# except Exception as e: +# # Log the error but don't fail the test +# print(f"Warning: Failed to cleanup database connections: {e}") + + +# Fixtures +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. + + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + # create user/org + yield server + + +@pytest.fixture(autouse=True) +async def clear_tables(): + """Fixture to clear the organization table before each test.""" + from letta.server.db import db_registry + + async with db_registry.async_session() as session: + await session.execute(delete(SandboxEnvironmentVariable)) + await session.execute(delete(SandboxConfig)) + await session.commit() # Commit the deletion + + +@pytest.fixture +async def test_organization(): + """Fixture to create and return the default organization.""" + org = await OrganizationManager().create_organization_async(Organization(name=org_name)) + yield org + + +@pytest.fixture +async def test_user(test_organization): + """Fixture to create and return the default user within the default organization.""" + user = await UserManager().create_actor_async(User(name=user_name, organization_id=test_organization.id)) + yield user + + +@pytest.fixture +async def add_integers_tool(test_user): + def add(x: int, y: int) -> int: + """ + Simple function that adds two integers. + + Parameters: + x (int): The first integer to add. + y (int): The second integer to add. + + Returns: + int: The result of adding x and y. + """ + return x + y + + tool = create_tool_from_func(add) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def cowsay_tool(test_user): + # This defines a tool for a package we definitely do NOT have in letta + # If this test passes, that means the tool was correctly executed in a separate Python environment + def cowsay() -> str: + """ + Simple function that uses the cowsay package to print out the secret word env variable. + + Returns: + str: The cowsay ASCII art. + """ + import os + + import cowsay + + cowsay.cow(os.getenv("secret_word")) + + tool = create_tool_from_func(cowsay) + # Add cowsay as a pip requirement for Modal + tool.pip_requirements = [PipRequirement(name="cowsay")] + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def get_env_tool(test_user): + def get_env() -> str: + """ + Simple function that returns the secret word env variable. + + Returns: + str: The secret word + """ + import os + + secret_word = os.getenv("secret_word") + print(secret_word) + return secret_word + + tool = create_tool_from_func(get_env) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def get_warning_tool(test_user): + def warn_hello_world() -> str: + """ + Simple function that warns hello world. + + Returns: + str: hello world + """ + import warnings + + msg = "Hello World" + warnings.warn(msg) + return msg + + tool = create_tool_from_func(warn_hello_world) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def always_err_tool(test_user): + def error() -> str: + """ + Simple function that errors + + Returns: + str: not important + """ + # Raise a unusual error so we know it's from this function + print("Going to error now") + raise ZeroDivisionError("This is an intentionally weird division!") + + tool = create_tool_from_func(error) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def list_tool(test_user): + def create_list(): + """Simple function that returns a list""" + + return [1] * 5 + + tool = create_tool_from_func(create_list) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def clear_core_memory_tool(test_user): + def clear_memory(agent_state: "AgentState"): + """Clear the core memory""" + agent_state.memory.get_block("human").value = "" + agent_state.memory.get_block("persona").value = "" + + tool = create_tool_from_func(clear_memory) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def external_codebase_tool(test_user): + from tests.test_tool_sandbox.restaurant_management_system.adjust_menu_prices import adjust_menu_prices + + tool = create_tool_from_func(adjust_menu_prices) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def agent_state(server: SyncServer): + await server.init_async(init_with_default_org_and_user=True) + actor = await server.user_manager.create_default_actor_async() + agent_state = await server.create_agent_async( + CreateAgent( + memory_blocks=[ + CreateBlock( + label="human", + value="username: sarah", + ), + CreateBlock( + label="persona", + value="This is the persona", + ), + ], + include_base_tools=True, + model="openai/gpt-4o-mini", + tags=["test_agents"], + embedding="letta/letta-free", + ), + actor=actor, + ) + agent_state.tool_rules = [] + yield agent_state + + +@pytest.fixture +async def custom_test_sandbox_config(test_user): + """ + Fixture to create a consistent local sandbox configuration for tests. + + Args: + test_user: The test user to be used for creating the sandbox configuration. + + Returns: + A tuple containing the SandboxConfigManager and the created sandbox configuration. + """ + # Create the SandboxConfigManager + manager = SandboxConfigManager() + + # Set the sandbox to be within the external codebase path and use a venv + external_codebase_path = str(Path(__file__).parent / "test_tool_sandbox" / "restaurant_management_system") + # tqdm is used in this codebase, but NOT in the requirements.txt, this tests that we can successfully install pip requirements + local_sandbox_config = LocalSandboxConfig( + sandbox_dir=external_codebase_path, use_venv=True, pip_requirements=[PipRequirement(name="tqdm")] + ) + + # Create the sandbox configuration + config_create = SandboxConfigCreate(config=local_sandbox_config.model_dump()) + + # Create or update the sandbox configuration + await manager.create_or_update_sandbox_config_async(sandbox_config_create=config_create, actor=test_user) + + return manager, local_sandbox_config + + +@pytest.fixture +async def core_memory_tools(test_user): + """Create all base tools for testing.""" + tools = {} + for func in [ + core_memory_replace, + core_memory_append, + ]: + tool = create_tool_from_func(func) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + tools[func.__name__] = tool + yield tools + + +@pytest.fixture +async def async_add_integers_tool(test_user): + async def async_add(x: int, y: int) -> int: + """ + Async function that adds two integers. + + Parameters: + x (int): The first integer to add. + y (int): The second integer to add. + + Returns: + int: The result of adding x and y. + """ + import asyncio + + # Add a small delay to simulate async work + await asyncio.sleep(0.1) + return x + y + + tool = create_tool_from_func(async_add) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def async_get_env_tool(test_user): + async def async_get_env() -> str: + """ + Async function that returns the secret word env variable. + + Returns: + str: The secret word + """ + import asyncio + import os + + # Add a small delay to simulate async work + await asyncio.sleep(0.1) + secret_word = os.getenv("secret_word") + print(secret_word) + return secret_word + + tool = create_tool_from_func(async_get_env) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def async_stateful_tool(test_user): + async def async_clear_memory(agent_state: "AgentState"): + """Async function that clears the core memory""" + import asyncio + + # Add a small delay to simulate async work + await asyncio.sleep(0.1) + agent_state.memory.get_block("human").value = "" + agent_state.memory.get_block("persona").value = "" + + tool = create_tool_from_func(async_clear_memory) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def async_error_tool(test_user): + async def async_error() -> str: + """ + Async function that errors + + Returns: + str: not important + """ + import asyncio + + # Add some async work before erroring + await asyncio.sleep(0.1) + print("Going to error now") + raise ValueError("This is an intentional async error!") + + tool = create_tool_from_func(async_error) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def async_list_tool(test_user): + async def async_create_list() -> list: + """Async function that returns a list""" + import asyncio + + await asyncio.sleep(0.05) + return [1, 2, 3, 4, 5] + + tool = create_tool_from_func(async_create_list) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def tool_with_pip_requirements(test_user): + def use_requests_and_numpy() -> str: + """ + Function that uses requests and numpy packages to test tool-specific pip requirements. + + Returns: + str: Success message if packages are available. + """ + try: + import numpy as np + import requests + + # Simple usage to verify packages work + response = requests.get("https://httpbin.org/json", timeout=30) + arr = np.array([1, 2, 3]) + return f"Success! Status: {response.status_code}, Array sum: {np.sum(arr)}" + except ImportError as e: + return f"Import error: {e}" + except Exception as e: + return f"Other error: {e}" + + tool = create_tool_from_func(use_requests_and_numpy) + # Add pip requirements to the tool + tool.pip_requirements = [ + PipRequirement(name="requests", version="2.31.0"), + PipRequirement(name="numpy"), + ] + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +@pytest.fixture +async def async_complex_tool(test_user): + async def async_complex_computation(iterations: int = 3) -> dict: + """ + Async function that performs complex computation with multiple awaits. + + Parameters: + iterations (int): Number of iterations to perform. + + Returns: + dict: Results of the computation. + """ + import asyncio + import time + + results = [] + start_time = time.time() + + for i in range(iterations): + # Simulate async I/O + await asyncio.sleep(0.1) + results.append(i * 2) + + end_time = time.time() + + return { + "results": results, + "duration": end_time - start_time, + "iterations": iterations, + "average": sum(results) / len(results) if results else 0, + } + + tool = create_tool_from_func(async_complex_computation) + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + yield tool + + +# Modal sandbox tests + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_default(check_modal_key_is_set, add_integers_tool, test_user): + args = {"x": 10, "y": 5} + + # Mock and assert correct pathway was invoked + with patch.object(AsyncToolSandboxModal, "run") as mock_run: + sandbox = AsyncToolSandboxModal(add_integers_tool.name, args, user=test_user) + await sandbox.run() + mock_run.assert_called_once() + + # Run again to get actual response + sandbox = AsyncToolSandboxModal(add_integers_tool.name, args, user=test_user) + result = await sandbox.run() + assert int(result.func_return) == args["x"] + args["y"] + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_pip_installs(check_modal_key_is_set, cowsay_tool, test_user): + """Test that Modal sandbox installs tool-level pip requirements.""" + manager = SandboxConfigManager() + config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump()) + config = await manager.create_or_update_sandbox_config_async(config_create, test_user) + + key = "secret_word" + long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20)) + await manager.create_sandbox_env_var_async( + SandboxEnvironmentVariableCreate(key=key, value=long_random_string), + sandbox_config_id=config.id, + actor=test_user, + ) + + sandbox = AsyncToolSandboxModal(cowsay_tool.name, {}, user=test_user) + result = await sandbox.run() + assert long_random_string in result.stdout[0] + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_stateful_tool(check_modal_key_is_set, clear_core_memory_tool, test_user, agent_state): + sandbox = AsyncToolSandboxModal(clear_core_memory_tool.name, {}, user=test_user) + result = await sandbox.run(agent_state=agent_state) + assert result.agent_state.memory.get_block("human").value == "" + assert result.agent_state.memory.get_block("persona").value == "" + assert result.func_return is None + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_inject_env_var_existing_sandbox(check_modal_key_is_set, get_env_tool, test_user): + manager = SandboxConfigManager() + config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump()) + config = await manager.create_or_update_sandbox_config_async(config_create, test_user) + + sandbox = AsyncToolSandboxModal(get_env_tool.name, {}, user=test_user) + result = await sandbox.run() + assert result.func_return is None + + key = "secret_word" + long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20)) + await manager.create_sandbox_env_var_async( + SandboxEnvironmentVariableCreate(key=key, value=long_random_string), + sandbox_config_id=config.id, + actor=test_user, + ) + + sandbox = AsyncToolSandboxModal(get_env_tool.name, {}, user=test_user) + result = await sandbox.run() + assert long_random_string in result.func_return + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_per_agent_env(check_modal_key_is_set, get_env_tool, agent_state, test_user): + manager = SandboxConfigManager() + key = "secret_word" + wrong_val = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20)) + correct_val = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20)) + + config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump()) + config = await manager.create_or_update_sandbox_config_async(config_create, test_user) + await manager.create_sandbox_env_var_async( + SandboxEnvironmentVariableCreate(key=key, value=wrong_val), + sandbox_config_id=config.id, + actor=test_user, + ) + + agent_state.secrets = [AgentEnvironmentVariable(key=key, value=correct_val, agent_id=agent_state.id)] + + sandbox = AsyncToolSandboxModal(get_env_tool.name, {}, user=test_user) + result = await sandbox.run(agent_state=agent_state) + assert wrong_val not in result.func_return + assert correct_val in result.func_return + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_with_list_rv(check_modal_key_is_set, list_tool, test_user): + sandbox = AsyncToolSandboxModal(list_tool.name, {}, user=test_user) + result = await sandbox.run() + assert len(result.func_return) == 5 + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_with_tool_pip_requirements(check_modal_key_is_set, tool_with_pip_requirements, test_user): + """Test that Modal sandbox installs tool-specific pip requirements.""" + manager = SandboxConfigManager() + config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump()) + await manager.create_or_update_sandbox_config_async(config_create, test_user) + + sandbox = AsyncToolSandboxModal(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements) + result = await sandbox.run() + + # Should succeed since tool pip requirements were installed + assert "Success!" in result.func_return + assert "Status: 200" in result.func_return + assert "Array sum: 6" in result.func_return + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_with_mixed_pip_requirements(check_modal_key_is_set, tool_with_pip_requirements, test_user): + """Test that Modal sandbox installs tool pip requirements. + + Note: Modal does not support sandbox-level pip requirements - all pip requirements + must be specified at the tool level since the Modal app is deployed with a fixed image. + """ + manager = SandboxConfigManager() + config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump()) + await manager.create_or_update_sandbox_config_async(config_create, test_user) + + sandbox = AsyncToolSandboxModal(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements) + result = await sandbox.run() + + # Should succeed since tool pip requirements were installed + assert "Success!" in result.func_return + assert "Status: 200" in result.func_return + assert "Array sum: 6" in result.func_return + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_with_broken_tool_pip_requirements_error_handling(check_modal_key_is_set, test_user): + """Test that Modal sandbox provides informative error messages for broken tool pip requirements.""" + + def use_broken_package() -> str: + """ + Function that tries to use packages with broken version constraints. + + Returns: + str: Success message if packages are available. + """ + return "Should not reach here" + + tool = create_tool_from_func(use_broken_package) + # Add broken pip requirements + tool.pip_requirements = [ + PipRequirement(name="numpy", version="1.24.0"), # Old version incompatible with newer Python + PipRequirement(name="nonexistent-package-12345"), # Non-existent package + ] + # expect a LettaInvalidArgumentError + from letta.errors import LettaInvalidArgumentError + + with pytest.raises(LettaInvalidArgumentError): + tool = await ToolManager().create_or_update_tool_async(tool, test_user) + + +@pytest.mark.asyncio +async def test_async_function_detection(add_integers_tool, async_add_integers_tool, test_user): + """Test that async function detection works correctly""" + # Test sync function detection + sync_sandbox = AsyncToolSandboxModal(add_integers_tool.name, {}, test_user, tool_object=add_integers_tool) + await sync_sandbox._init_async() + assert not sync_sandbox.is_async_function + + # Test async function detection + async_sandbox = AsyncToolSandboxModal(async_add_integers_tool.name, {}, test_user, tool_object=async_add_integers_tool) + await async_sandbox._init_async() + assert async_sandbox.is_async_function + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_async_function_execution(check_modal_key_is_set, async_add_integers_tool, test_user): + """Test that async functions execute correctly in Modal sandbox""" + args = {"x": 20, "y": 30} + + sandbox = AsyncToolSandboxModal(async_add_integers_tool.name, args, user=test_user) + result = await sandbox.run() + assert int(result.func_return) == args["x"] + args["y"] + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_async_complex_computation(check_modal_key_is_set, async_complex_tool, test_user): + """Test complex async computation with multiple awaits in Modal sandbox""" + args = {"iterations": 2} + + sandbox = AsyncToolSandboxModal(async_complex_tool.name, args, user=test_user) + result = await sandbox.run() + + func_return = result.func_return + assert isinstance(func_return, dict) + assert func_return["results"] == [0, 2] + assert func_return["iterations"] == 2 + assert func_return["average"] == 1.0 + assert func_return["duration"] > 0.15 + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_async_list_return(check_modal_key_is_set, async_list_tool, test_user): + """Test async function returning list in Modal sandbox""" + sandbox = AsyncToolSandboxModal(async_list_tool.name, {}, user=test_user) + result = await sandbox.run() + assert result.func_return == [1, 2, 3, 4, 5] + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_async_with_env_vars(check_modal_key_is_set, async_get_env_tool, test_user): + """Test async function with environment variables in Modal sandbox""" + manager = SandboxConfigManager() + config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump()) + config = await manager.create_or_update_sandbox_config_async(config_create, test_user) + + # Create environment variable + key = "secret_word" + test_value = "async_modal_test_value_456" + await manager.create_sandbox_env_var_async( + SandboxEnvironmentVariableCreate(key=key, value=test_value), sandbox_config_id=config.id, actor=test_user + ) + + sandbox = AsyncToolSandboxModal(async_get_env_tool.name, {}, user=test_user) + result = await sandbox.run() + + assert test_value in result.func_return + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_async_with_agent_state(check_modal_key_is_set, async_stateful_tool, test_user, agent_state): + """Test async function with agent state in Modal sandbox""" + sandbox = AsyncToolSandboxModal(async_stateful_tool.name, {}, user=test_user) + result = await sandbox.run(agent_state=agent_state) + + assert result.agent_state.memory.get_block("human").value == "" + assert result.agent_state.memory.get_block("persona").value == "" + assert result.func_return is None + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_async_error_handling(check_modal_key_is_set, async_error_tool, test_user): + """Test async function error handling in Modal sandbox""" + sandbox = AsyncToolSandboxModal(async_error_tool.name, {}, user=test_user) + result = await sandbox.run() + + # Check that error was captured + assert len(result.stdout) != 0, "stdout not empty" + assert "error" in result.stdout[0], "stdout contains printed string" + assert len(result.stderr) != 0, "stderr not empty" + assert "ValueError: This is an intentional async error!" in result.stderr[0], "stderr contains expected error" + + +@pytest.mark.asyncio +@pytest.mark.modal_sandbox +async def test_modal_sandbox_async_per_agent_env(check_modal_key_is_set, async_get_env_tool, agent_state, test_user): + """Test async function with per-agent environment variables in Modal sandbox""" + manager = SandboxConfigManager() + key = "secret_word" + wrong_val = "wrong_async_modal_value" + correct_val = "correct_async_modal_value" + + config_create = SandboxConfigCreate(config=ModalSandboxConfig().model_dump()) + config = await manager.create_or_update_sandbox_config_async(config_create, test_user) + await manager.create_sandbox_env_var_async( + SandboxEnvironmentVariableCreate(key=key, value=wrong_val), + sandbox_config_id=config.id, + actor=test_user, + ) + + agent_state.secrets = [AgentEnvironmentVariable(key=key, value=correct_val, agent_id=agent_state.id)] + + sandbox = AsyncToolSandboxModal(async_get_env_tool.name, {}, user=test_user) + result = await sandbox.run(agent_state=agent_state) + assert wrong_val not in result.func_return + assert correct_val in result.func_return diff --git a/tests/pytest.ini b/tests/pytest.ini index 48a51023..b4e80dd2 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -11,6 +11,7 @@ filterwarnings = markers = local_sandbox: mark test as part of local sandbox tests e2b_sandbox: mark test as part of E2B sandbox tests + modal_sandbox: mark test as part of Modal sandbox tests openai_basic: Tests for OpenAI endpoints anthropic_basic: Tests for Anthropic endpoints azure_basic: Tests for Azure endpoints diff --git a/tests/test_client.py b/tests/test_client.py index 436ddb40..2d2c9fb8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -395,7 +395,7 @@ def test_function_always_error(client: Letta): assert response_message.status == "error" # TODO: add this back # assert "Error executing function testing_method" in response_message.tool_return, response_message.tool_return - assert "ZeroDivisionError: division by zero" in response_message.stderr[0] + assert "division by zero" in response_message.stderr[0] client.agents.delete(agent_id=agent.id)