feat: typescript sandbox
This commit is contained in:
196
letta/functions/typescript_parser.py
Normal file
196
letta/functions/typescript_parser.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""TypeScript function parsing for JSON schema generation."""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from letta.errors import LettaToolCreateError
|
||||
|
||||
|
||||
def derive_typescript_json_schema(source_code: str, name: Optional[str] = None) -> dict:
|
||||
"""Derives the OpenAI JSON schema for a given TypeScript function source code.
|
||||
|
||||
This parser extracts the function signature, parameters, and types from TypeScript
|
||||
code and generates a JSON schema compatible with OpenAI's function calling format.
|
||||
|
||||
Args:
|
||||
source_code: TypeScript source code containing an exported function
|
||||
name: Optional function name override
|
||||
|
||||
Returns:
|
||||
JSON schema dict with name, description, and parameters
|
||||
|
||||
Raises:
|
||||
LettaToolCreateError: If parsing fails or no exported function is found
|
||||
"""
|
||||
try:
|
||||
# Find the exported function
|
||||
function_pattern = r"export\s+function\s+(\w+)\s*\((.*?)\)\s*:\s*([\w<>\[\]|]+)?"
|
||||
match = re.search(function_pattern, source_code, re.DOTALL)
|
||||
|
||||
if not match:
|
||||
# Try async function
|
||||
async_pattern = r"export\s+async\s+function\s+(\w+)\s*\((.*?)\)\s*:\s*([\w<>\[\]|]+)?"
|
||||
match = re.search(async_pattern, source_code, re.DOTALL)
|
||||
|
||||
if not match:
|
||||
raise LettaToolCreateError("No exported function found in TypeScript source code")
|
||||
|
||||
func_name = match.group(1)
|
||||
params_str = match.group(2).strip()
|
||||
# return_type = match.group(3) if match.group(3) else 'any'
|
||||
|
||||
# Use provided name or extracted name
|
||||
schema_name = name or func_name
|
||||
|
||||
# Extract JSDoc comment for description
|
||||
description = extract_jsdoc_description(source_code, func_name)
|
||||
if not description:
|
||||
description = f"TypeScript function {func_name}"
|
||||
|
||||
# Parse parameters
|
||||
parameters = parse_typescript_parameters(params_str)
|
||||
|
||||
# Build OpenAI-compatible JSON schema
|
||||
schema = {
|
||||
"name": schema_name,
|
||||
"description": description,
|
||||
"parameters": {"type": "object", "properties": parameters["properties"], "required": parameters["required"]},
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
except Exception as e:
|
||||
raise LettaToolCreateError(f"TypeScript schema generation failed: {str(e)}") from e
|
||||
|
||||
|
||||
def extract_jsdoc_description(source_code: str, func_name: str) -> Optional[str]:
|
||||
"""Extract JSDoc description for a function."""
|
||||
# Look for JSDoc comment before the function
|
||||
jsdoc_pattern = r"/\*\*(.*?)\*/\s*export\s+(?:async\s+)?function\s+" + re.escape(func_name)
|
||||
match = re.search(jsdoc_pattern, source_code, re.DOTALL)
|
||||
|
||||
if match:
|
||||
jsdoc_content = match.group(1)
|
||||
# Extract the main description (text before @param tags)
|
||||
lines = jsdoc_content.split("\n")
|
||||
description_lines = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip().lstrip("*").strip()
|
||||
if line and not line.startswith("@"):
|
||||
description_lines.append(line)
|
||||
elif line.startswith("@"):
|
||||
break
|
||||
|
||||
if description_lines:
|
||||
return " ".join(description_lines)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_typescript_parameters(params_str: str) -> Dict[str, Any]:
|
||||
"""Parse TypeScript function parameters and generate JSON schema properties."""
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
if not params_str:
|
||||
return {"properties": properties, "required": required}
|
||||
|
||||
# Split parameters by comma (handling nested types)
|
||||
params = split_parameters(params_str)
|
||||
|
||||
for param in params:
|
||||
param = param.strip()
|
||||
if not param:
|
||||
continue
|
||||
|
||||
# Parse parameter name, optional flag, and type
|
||||
param_match = re.match(r"(\w+)(\?)?\s*:\s*(.+)", param)
|
||||
if param_match:
|
||||
param_name = param_match.group(1)
|
||||
is_optional = param_match.group(2) == "?"
|
||||
param_type = param_match.group(3).strip()
|
||||
|
||||
# Convert TypeScript type to JSON schema type
|
||||
json_type = typescript_to_json_schema_type(param_type)
|
||||
|
||||
properties[param_name] = json_type
|
||||
|
||||
# Add to required list if not optional
|
||||
if not is_optional:
|
||||
required.append(param_name)
|
||||
|
||||
return {"properties": properties, "required": required}
|
||||
|
||||
|
||||
def split_parameters(params_str: str) -> list:
|
||||
"""Split parameter string by commas, handling nested types."""
|
||||
params = []
|
||||
current_param = ""
|
||||
depth = 0
|
||||
|
||||
for char in params_str:
|
||||
if char in "<[{(":
|
||||
depth += 1
|
||||
elif char in ">]})":
|
||||
depth -= 1
|
||||
elif char == "," and depth == 0:
|
||||
params.append(current_param)
|
||||
current_param = ""
|
||||
continue
|
||||
|
||||
current_param += char
|
||||
|
||||
if current_param:
|
||||
params.append(current_param)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def typescript_to_json_schema_type(ts_type: str) -> Dict[str, Any]:
|
||||
"""Convert TypeScript type to JSON schema type definition."""
|
||||
ts_type = ts_type.strip()
|
||||
|
||||
# Basic type mappings
|
||||
type_map = {
|
||||
"string": {"type": "string"},
|
||||
"number": {"type": "number"},
|
||||
"boolean": {"type": "boolean"},
|
||||
"any": {"type": "string"}, # Default to string for any
|
||||
"void": {"type": "null"},
|
||||
"null": {"type": "null"},
|
||||
"undefined": {"type": "null"},
|
||||
}
|
||||
|
||||
# Check for basic types
|
||||
if ts_type in type_map:
|
||||
return type_map[ts_type]
|
||||
|
||||
# Handle arrays
|
||||
if ts_type.endswith("[]"):
|
||||
item_type = ts_type[:-2].strip()
|
||||
return {"type": "array", "items": typescript_to_json_schema_type(item_type)}
|
||||
|
||||
# Handle Array<T> syntax
|
||||
array_match = re.match(r"Array<(.+)>", ts_type)
|
||||
if array_match:
|
||||
item_type = array_match.group(1)
|
||||
return {"type": "array", "items": typescript_to_json_schema_type(item_type)}
|
||||
|
||||
# Handle union types (simplified - just use string)
|
||||
if "|" in ts_type:
|
||||
# For union types, we'll default to string for simplicity
|
||||
# A more sophisticated parser could handle this better
|
||||
return {"type": "string"}
|
||||
|
||||
# Handle object types (simplified)
|
||||
if ts_type.startswith("{") and ts_type.endswith("}"):
|
||||
return {"type": "object"}
|
||||
|
||||
# Handle Record<K, V> and similar generic types
|
||||
record_match = re.match(r"Record<(.+),\s*(.+)>", ts_type)
|
||||
if record_match:
|
||||
return {"type": "object", "additionalProperties": typescript_to_json_schema_type(record_match.group(2))}
|
||||
|
||||
# Default case - treat unknown types as objects
|
||||
return {"type": "object"}
|
||||
@@ -22,7 +22,7 @@ from letta.functions.schema_generator import (
|
||||
generate_tool_schema_for_mcp,
|
||||
)
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ToolType
|
||||
from letta.schemas.enums import ToolSourceType, ToolType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.npm_requirement import NpmRequirement
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
@@ -76,27 +76,42 @@ class Tool(BaseTool):
|
||||
"""
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
|
||||
if self.tool_type is ToolType.CUSTOM:
|
||||
if self.tool_type == ToolType.CUSTOM:
|
||||
if not self.source_code:
|
||||
logger.error("Custom tool with id=%s is missing source_code field", self.id)
|
||||
raise ValueError(f"Custom tool with id={self.id} is missing source_code field.")
|
||||
|
||||
# Always derive json_schema for freshest possible json_schema
|
||||
if self.args_json_schema is not None:
|
||||
name, description = get_function_name_and_docstring(self.source_code, self.name)
|
||||
args_schema = generate_model_from_args_json_schema(self.args_json_schema)
|
||||
self.json_schema = generate_schema_from_args_schema_v2(
|
||||
args_schema=args_schema,
|
||||
name=name,
|
||||
description=description,
|
||||
append_heartbeat=False,
|
||||
)
|
||||
else: # elif not self.json_schema: # TODO: JSON schema is not being derived correctly the first time?
|
||||
# If there's not a json_schema provided, then we need to re-derive
|
||||
try:
|
||||
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
|
||||
except Exception as e:
|
||||
logger.error("Failed to derive json schema for tool with id=%s name=%s: %s", self.id, self.name, e)
|
||||
if self.source_type == ToolSourceType.typescript:
|
||||
# TypeScript tools don't support args_json_schema, only direct schema generation
|
||||
if not self.json_schema:
|
||||
try:
|
||||
from letta.functions.typescript_parser import derive_typescript_json_schema
|
||||
|
||||
self.json_schema = derive_typescript_json_schema(source_code=self.source_code)
|
||||
except Exception as e:
|
||||
logger.error("Failed to derive TypeScript json schema for tool with id=%s name=%s: %s", self.id, self.name, e)
|
||||
elif (
|
||||
self.source_type == ToolSourceType.python or self.source_type is None
|
||||
): # default to python if not provided for backwards compatability
|
||||
# Python tool handling
|
||||
# Always derive json_schema for freshest possible json_schema
|
||||
if self.args_json_schema is not None:
|
||||
name, description = get_function_name_and_docstring(self.source_code, self.name)
|
||||
args_schema = generate_model_from_args_json_schema(self.args_json_schema)
|
||||
self.json_schema = generate_schema_from_args_schema_v2(
|
||||
args_schema=args_schema,
|
||||
name=name,
|
||||
description=description,
|
||||
append_heartbeat=False,
|
||||
)
|
||||
else: # elif not self.json_schema: # TODO: JSON schema is not being derived correctly the first time?
|
||||
# If there's not a json_schema provided, then we need to re-derive
|
||||
try:
|
||||
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
|
||||
except Exception as e:
|
||||
logger.error("Failed to derive json schema for tool with id=%s name=%s: %s", self.id, self.name, e)
|
||||
else:
|
||||
raise ValueError(f"Unknown tool source type: {self.source_type}")
|
||||
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE, ToolType.LETTA_SLEEPTIME_CORE}:
|
||||
# If it's letta core tool, we generate the json_schema on the fly here
|
||||
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
|
||||
|
||||
@@ -40,7 +40,7 @@ from letta.schemas.block import Block, BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType, SandboxType
|
||||
from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType, SandboxType, ToolSourceType
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
@@ -1903,12 +1903,19 @@ class SyncServer(Server):
|
||||
pip_requirements: Optional[List[PipRequirement]] = None,
|
||||
) -> ToolReturnMessage:
|
||||
"""Run a tool from source code"""
|
||||
if tool_source_type is not None and tool_source_type != "python":
|
||||
raise ValueError("Only Python source code is supported at this time")
|
||||
|
||||
if tool_source_type not in (None, ToolSourceType.python, ToolSourceType.typescript):
|
||||
raise ValueError("Tool source type is not supported at this time. Found {tool_source_type}")
|
||||
|
||||
# If tools_json_schema is explicitly passed in, override it on the created Tool object
|
||||
if tool_json_schema:
|
||||
tool = Tool(name=tool_name, source_code=tool_source, json_schema=tool_json_schema, pip_requirements=pip_requirements)
|
||||
tool = Tool(
|
||||
name=tool_name,
|
||||
source_code=tool_source,
|
||||
json_schema=tool_json_schema,
|
||||
pip_requirements=pip_requirements,
|
||||
source_type=tool_source_type,
|
||||
)
|
||||
else:
|
||||
# NOTE: we're creating a floating Tool object and NOT persisting to DB
|
||||
tool = Tool(
|
||||
@@ -1916,6 +1923,7 @@ class SyncServer(Server):
|
||||
source_code=tool_source,
|
||||
args_json_schema=tool_args_json_schema,
|
||||
pip_requirements=pip_requirements,
|
||||
source_type=tool_source_type,
|
||||
)
|
||||
|
||||
assert tool.name is not None, "Failed to create tool object"
|
||||
|
||||
@@ -5,7 +5,7 @@ from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_fun
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import SandboxType
|
||||
from letta.schemas.enums import SandboxType, ToolSourceType
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
@@ -19,11 +19,6 @@ from letta.utils import get_friendly_error_msg
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if tool_settings.e2b_api_key:
|
||||
from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B
|
||||
if tool_settings.modal_api_key:
|
||||
from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal
|
||||
|
||||
|
||||
class SandboxToolExecutor(ToolExecutor):
|
||||
"""Executor for sandboxed tools."""
|
||||
@@ -54,13 +49,35 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
|
||||
# Execute in sandbox depending on API key
|
||||
if tool_settings.sandbox_type == SandboxType.E2B:
|
||||
from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B
|
||||
|
||||
sandbox = AsyncToolSandboxE2B(
|
||||
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
|
||||
)
|
||||
# TODO (cliandy): this is just for testing right now, separate this out into it's own subclass and handling logic
|
||||
elif tool_settings.sandbox_type == SandboxType.MODAL:
|
||||
sandbox = AsyncToolSandboxModal(
|
||||
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
|
||||
)
|
||||
from letta.services.tool_sandbox.modal_sandbox import AsyncToolSandboxModal, TypescriptToolSandboxModal
|
||||
|
||||
if tool.source_type == ToolSourceType.typescript:
|
||||
sandbox = TypescriptToolSandboxModal(
|
||||
function_name,
|
||||
function_args,
|
||||
actor,
|
||||
tool_object=tool,
|
||||
sandbox_config=sandbox_config,
|
||||
sandbox_env_vars=sandbox_env_vars,
|
||||
)
|
||||
elif tool.source_type == ToolSourceType.python:
|
||||
sandbox = AsyncToolSandboxModal(
|
||||
function_name,
|
||||
function_args,
|
||||
actor,
|
||||
tool_object=tool,
|
||||
sandbox_config=sandbox_config,
|
||||
sandbox_env_vars=sandbox_env_vars,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Tool source type was {tool.source_type} but is required to be python or typescript to run in Modal.")
|
||||
else:
|
||||
sandbox = AsyncToolSandboxLocal(
|
||||
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
|
||||
|
||||
@@ -17,6 +17,9 @@ from letta.utils import get_friendly_error_msg
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# class AsyncToolSandboxModalBase(AsyncToolSandboxBase):
|
||||
# pass
|
||||
|
||||
|
||||
class AsyncToolSandboxModal(AsyncToolSandboxBase):
|
||||
def __init__(
|
||||
@@ -30,8 +33,8 @@ class AsyncToolSandboxModal(AsyncToolSandboxBase):
|
||||
):
|
||||
super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
|
||||
|
||||
if not tool_settings.modal_api_key:
|
||||
raise ValueError("Modal API key is required but not set in tool_settings.modal_api_key")
|
||||
if not tool_settings.modal_token_id or not tool_settings.modal_token_secret:
|
||||
raise ValueError("MODAL_TOKEN_ID and MODAL_TOKEN_SECRET must be set.")
|
||||
|
||||
# Create a unique app name based on tool and config
|
||||
self._app_name = self._generate_app_name()
|
||||
@@ -42,7 +45,12 @@ class AsyncToolSandboxModal(AsyncToolSandboxBase):
|
||||
|
||||
async def _fetch_or_create_modal_app(self, sbx_config: SandboxConfig, env_vars: Dict[str, str]) -> modal.App:
|
||||
"""Create a Modal app with the tool function registered."""
|
||||
app = await modal.App.lookup.aio(self._app_name)
|
||||
try:
|
||||
app = await modal.App.lookup.aio(self._app_name)
|
||||
return app
|
||||
except:
|
||||
app = modal.App(self._app_name)
|
||||
|
||||
modal_config = sbx_config.get_modal_config()
|
||||
|
||||
# Get the base image with dependencies
|
||||
@@ -96,6 +104,7 @@ class AsyncToolSandboxModal(AsyncToolSandboxBase):
|
||||
|
||||
# Execute the tool remotely
|
||||
with app.run():
|
||||
# app = modal.Cls.from_name(app.name, "NodeShimServer")()
|
||||
result = app.remote_executor.remote(execution_script, envs)
|
||||
|
||||
# Process the result
|
||||
@@ -203,3 +212,209 @@ class AsyncToolSandboxModal(AsyncToolSandboxBase):
|
||||
so we should use asyncio.run() like local execution.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
class TypescriptToolSandboxModal(AsyncToolSandboxModal):
|
||||
"""Modal sandbox implementation for TypeScript tools."""
|
||||
|
||||
@trace_method
|
||||
async def run(
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""Run TypeScript tool in Modal sandbox using Node.js server."""
|
||||
if self.provided_sandbox_config:
|
||||
sbx_config = self.provided_sandbox_config
|
||||
else:
|
||||
sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async(
|
||||
sandbox_type=SandboxType.MODAL, actor=self.user
|
||||
)
|
||||
|
||||
envs = await self._gather_env_vars(agent_state, additional_env_vars or {}, sbx_config.id, is_local=False)
|
||||
|
||||
# Generate execution script (JSON args for TypeScript)
|
||||
json_args = await self.generate_execution_script(agent_state=agent_state)
|
||||
|
||||
try:
|
||||
log_event(
|
||||
"modal_typescript_execution_started",
|
||||
{"tool": self.tool_name, "app_name": self._app_name, "args": json_args},
|
||||
)
|
||||
|
||||
# Create Modal app with the TypeScript Node.js server
|
||||
app = await self._fetch_or_create_modal_app(sbx_config, envs)
|
||||
|
||||
# Execute the TypeScript tool remotely via the Node.js server
|
||||
with app.run():
|
||||
# Get the NodeShimServer class from Modal
|
||||
node_server = modal.Cls.from_name(self._app_name, "NodeShimServer")
|
||||
|
||||
# Call the remote_executor method with the JSON arguments
|
||||
# The server will parse the JSON and call the TypeScript function
|
||||
result = node_server().remote_executor.remote(json_args)
|
||||
|
||||
# Process the TypeScript execution result
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
# Handle errors from TypeScript execution
|
||||
logger.debug(f"TypeScript tool {self.tool_name} raised an error: {result['error']}")
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=self.tool_name,
|
||||
exception_name="TypeScriptError",
|
||||
exception_message=str(result["error"]),
|
||||
)
|
||||
log_event(
|
||||
"modal_typescript_execution_failed",
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"app_name": self._app_name,
|
||||
"error": result["error"],
|
||||
"func_return": func_return,
|
||||
},
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
func_return=func_return,
|
||||
agent_state=None, # TypeScript tools don't support agent_state yet
|
||||
stdout=[],
|
||||
stderr=[str(result["error"])],
|
||||
status="error",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
else:
|
||||
# Success case - TypeScript function returned a result
|
||||
func_return = str(result) if result is not None else ""
|
||||
log_event(
|
||||
"modal_typescript_execution_succeeded",
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"app_name": self._app_name,
|
||||
"func_return": func_return,
|
||||
},
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
func_return=func_return,
|
||||
agent_state=None, # TypeScript tools don't support agent_state yet
|
||||
stdout=[],
|
||||
stderr=[],
|
||||
status="success",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Modal TypeScript execution for tool {self.tool_name} encountered an error: {e}")
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=self.tool_name,
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
log_event(
|
||||
"modal_typescript_execution_error",
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"app_name": self._app_name,
|
||||
"error": str(e),
|
||||
"func_return": func_return,
|
||||
},
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
func_return=func_return,
|
||||
agent_state=None,
|
||||
stdout=[],
|
||||
stderr=[str(e)],
|
||||
status="error",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
async def _fetch_or_create_modal_app(self, sbx_config: SandboxConfig, env_vars: Dict[str, str]) -> modal.App:
|
||||
"""Create or fetch a Modal app with TypeScript execution capabilities."""
|
||||
try:
|
||||
return await modal.App.lookup.aio(self._app_name)
|
||||
except:
|
||||
app = modal.App(self._app_name)
|
||||
|
||||
modal_config = sbx_config.get_modal_config()
|
||||
|
||||
# Get the base image with dependencies
|
||||
image = self._get_modal_image(sbx_config)
|
||||
|
||||
# Import the NodeShimServer that will handle TypeScript execution
|
||||
from sandbox.node_server import NodeShimServer
|
||||
|
||||
# Register the NodeShimServer class with Modal
|
||||
# This creates a serverless function that can handle concurrent requests
|
||||
app.cls(image=image, restrict_modal_access=True, include_source=False, timeout=modal_config.timeout if modal_config else 60)(
|
||||
modal.concurrent(max_inputs=100, target_inputs=50)(NodeShimServer)
|
||||
)
|
||||
|
||||
# Deploy the app to Modal
|
||||
with modal.enable_output():
|
||||
await app.deploy.aio()
|
||||
|
||||
return app
|
||||
|
||||
async def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str:
|
||||
"""Generate the execution script for TypeScript tools.
|
||||
|
||||
For TypeScript tools, this returns the JSON-encoded arguments that will be passed
|
||||
to the Node.js server via the remote_executor method.
|
||||
"""
|
||||
import json
|
||||
|
||||
# Convert args to JSON string for TypeScript execution
|
||||
# The Node.js server expects JSON-encoded arguments
|
||||
return json.dumps(self.args)
|
||||
|
||||
def _get_modal_image(self, sbx_config: SandboxConfig) -> modal.Image:
|
||||
"""Build a Modal image with Node.js, TypeScript, and the user's tool function."""
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
# Find the sandbox module location
|
||||
spec = importlib.util.find_spec("sandbox")
|
||||
if not spec or not spec.origin:
|
||||
raise ValueError("Could not find sandbox module")
|
||||
server_dir = Path(spec.origin).parent
|
||||
|
||||
# Get the TypeScript function source code
|
||||
if not self.tool or not self.tool.source_code:
|
||||
raise ValueError("TypeScript tool must have source code")
|
||||
|
||||
ts_function = self.tool.source_code
|
||||
|
||||
# Get npm dependencies from sandbox config and tool
|
||||
modal_config = sbx_config.get_modal_config()
|
||||
npm_dependencies = []
|
||||
|
||||
# Add dependencies from sandbox config
|
||||
if modal_config and modal_config.npm_requirements:
|
||||
npm_dependencies.extend(modal_config.npm_requirements)
|
||||
|
||||
# Add dependencies from the tool itself
|
||||
if self.tool.npm_requirements:
|
||||
npm_dependencies.extend(self.tool.npm_requirements)
|
||||
|
||||
# Build npm install command for user dependencies
|
||||
user_dependencies_cmd = ""
|
||||
if npm_dependencies:
|
||||
# Ensure unique dependencies
|
||||
unique_deps = list(set(npm_dependencies))
|
||||
user_dependencies_cmd = " && npm install " + " ".join(unique_deps)
|
||||
|
||||
# Escape single quotes in the TypeScript function for shell command
|
||||
escaped_ts_function = ts_function.replace("'", "'\\''")
|
||||
|
||||
# Build the Docker image with Node.js and TypeScript
|
||||
image = (
|
||||
modal.Image.from_registry("node:22-slim", add_python="3.12")
|
||||
.add_local_dir(server_dir, "/root/sandbox", ignore=["node_modules", "build"], copy=True)
|
||||
.run_commands(
|
||||
# Install dependencies and build the TypeScript server
|
||||
f"cd /root/sandbox/resources/server && npm install{user_dependencies_cmd}",
|
||||
# Write the user's TypeScript function to a file
|
||||
f"echo '{escaped_ts_function}' > /root/sandbox/user-function.ts",
|
||||
)
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
# probably need to do parse_stdout_best_effort
|
||||
|
||||
@@ -18,7 +18,8 @@ class ToolSettings(BaseSettings):
|
||||
e2b_api_key: str | None = Field(default=None, description="API key for using E2B as a tool sandbox")
|
||||
e2b_sandbox_template_id: str | None = Field(default=None, description="Template ID for E2B Sandbox. Updated Manually.")
|
||||
|
||||
modal_api_key: str | None = Field(default=None, description="API key for using Modal as a tool sandbox")
|
||||
modal_token_id: str | None = Field(default=None, description="Token id for using Modal as a tool sandbox")
|
||||
modal_token_secret: str | None = Field(default=None, description="Token secret for using Modal as a tool sandbox")
|
||||
|
||||
# Search Providers
|
||||
tavily_api_key: str | None = Field(default=None, description="API key for using Tavily as a search provider.")
|
||||
@@ -41,7 +42,7 @@ class ToolSettings(BaseSettings):
|
||||
def sandbox_type(self) -> SandboxType:
|
||||
if self.e2b_api_key:
|
||||
return SandboxType.E2B
|
||||
elif self.modal_api_key:
|
||||
elif self.modal_token_id and self.modal_token_secret:
|
||||
return SandboxType.MODAL
|
||||
else:
|
||||
return SandboxType.LOCAL
|
||||
|
||||
0
sandbox/__init__.py
Normal file
0
sandbox/__init__.py
Normal file
79
sandbox/node_server.py
Normal file
79
sandbox/node_server.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import modal
|
||||
|
||||
|
||||
class NodeShimServer:
|
||||
# This runs once startup
|
||||
@modal.enter()
|
||||
def start_server(self):
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
server_root_dir = "/root/sandbox/resources/server"
|
||||
# /app/server
|
||||
|
||||
# Comment this in to show the updated user-function.ts file
|
||||
# subprocess.run(["sh", "-c", "cat /app/server/user-function.ts"], check=True)
|
||||
|
||||
subprocess.run(["sh", "-c", f"cd {server_root_dir} && npm run build"], check=True)
|
||||
subprocess.Popen(
|
||||
[
|
||||
"sh",
|
||||
"-c",
|
||||
f"cd {server_root_dir} && npm run start",
|
||||
],
|
||||
)
|
||||
|
||||
time.sleep(1)
|
||||
print("🔮 Node server started and listening on /tmp/my_unix_socket.sock")
|
||||
|
||||
@modal.method()
|
||||
def remote_executor(self, json_args: str): # Dynamic TypeScript function execution
|
||||
"""Execute a TypeScript function with JSON-encoded arguments.
|
||||
|
||||
Args:
|
||||
json_args: JSON string containing the function arguments
|
||||
|
||||
Returns:
|
||||
The result from the TypeScript function execution
|
||||
"""
|
||||
import http.client
|
||||
import json
|
||||
import socket
|
||||
|
||||
class UnixSocketHTTPConnection(http.client.HTTPConnection):
|
||||
def __init__(self, path):
|
||||
super().__init__("localhost")
|
||||
self.unix_path = path
|
||||
|
||||
def connect(self):
|
||||
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self.sock.connect(self.unix_path)
|
||||
|
||||
try:
|
||||
# Connect to the Node.js server via Unix socket
|
||||
conn = UnixSocketHTTPConnection("/tmp/my_unix_socket.sock")
|
||||
|
||||
# Send the JSON arguments directly to the server
|
||||
# The server will parse them and call the TypeScript function
|
||||
conn.request("POST", "/", body=json_args)
|
||||
response = conn.getresponse()
|
||||
output = response.read().decode()
|
||||
|
||||
# Parse the response from the server
|
||||
try:
|
||||
output_json = json.loads(output)
|
||||
|
||||
# Check if there was an error
|
||||
if "error" in output_json:
|
||||
return {"error": output_json["error"]}
|
||||
|
||||
# Return the successful result
|
||||
return output_json.get("result")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# If the response isn't valid JSON, it's likely an error message
|
||||
return {"error": f"Invalid JSON response from TypeScript server: {output}"}
|
||||
|
||||
except Exception as e:
|
||||
# Handle connection or other errors
|
||||
return {"error": f"Error executing TypeScript function: {str(e)}"}
|
||||
15
sandbox/resources/server/README.md
Normal file
15
sandbox/resources/server/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# TS Server
|
||||
|
||||
Skeleton typescript app to support user-defined tool call function. Runs inside Modal container.
|
||||
|
||||
## Overview
|
||||
|
||||
- `server.ts` - node process listening on a unix socket
|
||||
- `entrypoint.ts` - light function that deserializes JSON encoded input string to inputs into user defined function
|
||||
- `user-function.ts` - fully defined by the user
|
||||
|
||||
## Instructions
|
||||
|
||||
1. `npm install`
|
||||
2. `npm run build`
|
||||
3. `npm run start` to start the server
|
||||
42
sandbox/resources/server/entrypoint.ts
Normal file
42
sandbox/resources/server/entrypoint.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
import * as userModule from "./user-function.js";
|
||||
|
||||
/**
|
||||
* Entrypoint for the user function.
|
||||
* Dynamically finds and executes the exported TypeScript function.
|
||||
*
|
||||
* @param encoded_input - JSON encoded input
|
||||
*/
|
||||
export function runUserFunction(encoded_input: string): { result: any; error?: string } {
|
||||
try {
|
||||
const input = JSON.parse(encoded_input);
|
||||
|
||||
// Find the first exported function from the user module
|
||||
const functionNames = Object.keys(userModule).filter(
|
||||
key => typeof userModule[key] === 'function'
|
||||
);
|
||||
|
||||
if (functionNames.length === 0) {
|
||||
return {
|
||||
result: null,
|
||||
error: "No exported function found in user-function.ts"
|
||||
};
|
||||
}
|
||||
|
||||
// Use the first exported function (TypeScript tools should only export one)
|
||||
const functionName = functionNames[0];
|
||||
const userFunction = userModule[functionName];
|
||||
|
||||
// Call the function with the provided arguments
|
||||
// The arguments are passed as an object, so we need to extract them
|
||||
// in the order expected by the function
|
||||
const result = userFunction(...Object.values(input));
|
||||
|
||||
return { result };
|
||||
} catch (error) {
|
||||
// Return error information for debugging
|
||||
return {
|
||||
result: null,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
};
|
||||
}
|
||||
}
|
||||
45
sandbox/resources/server/package-lock.json
generated
Normal file
45
sandbox/resources/server/package-lock.json
generated
Normal file
@@ -0,0 +1,45 @@
|
||||
{
|
||||
"name": "app",
|
||||
"version": "1.0.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "app",
|
||||
"version": "1.0.0",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"@types/node": "^24.1.0",
|
||||
"typescript": "^5.8.3"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/node": {
|
||||
"version": "24.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-24.1.0.tgz",
|
||||
"integrity": "sha512-ut5FthK5moxFKH2T1CUOC6ctR67rQRvvHdFLCD2Ql6KXmMuCrjsSsRI9UsLCm9M18BMwClv4pn327UvB7eeO1w==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"undici-types": "~7.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/typescript": {
|
||||
"version": "5.8.3",
|
||||
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.8.3.tgz",
|
||||
"integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==",
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.17"
|
||||
}
|
||||
},
|
||||
"node_modules/undici-types": {
|
||||
"version": "7.8.0",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.8.0.tgz",
|
||||
"integrity": "sha512-9UJ2xGDvQ43tYyVMpuHlsgApydB8ZKfVYTsLDhXkFL/6gfkp+U8xTGdh8pMJv1SpZna0zxG1DwsKZsreLbXBxw==",
|
||||
"license": "MIT"
|
||||
}
|
||||
}
|
||||
}
|
||||
19
sandbox/resources/server/package.json
Normal file
19
sandbox/resources/server/package.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"name": "app",
|
||||
"type": "module",
|
||||
"version": "1.0.0",
|
||||
"description": "Skeleton typescript app to support user-defined tool call function",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"start": "node build/server.js",
|
||||
"test": "echo \"Error: no test specified\" && exit 1"
|
||||
},
|
||||
"keywords": [],
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"@types/node": "^24.1.0",
|
||||
"typescript": "^5.8.3"
|
||||
}
|
||||
}
|
||||
43
sandbox/resources/server/server.ts
Normal file
43
sandbox/resources/server/server.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { createServer } from "http";
|
||||
import { unlinkSync, existsSync } from "fs";
|
||||
import { runUserFunction } from "./entrypoint.js";
|
||||
|
||||
const SOCKET_PATH = "/tmp/my_unix_socket.sock";
|
||||
|
||||
// Remove old socket if it exists
|
||||
if (existsSync(SOCKET_PATH)) {
|
||||
try {
|
||||
unlinkSync(SOCKET_PATH);
|
||||
} catch (err) {
|
||||
console.error("Failed to remove old socket:", err);
|
||||
}
|
||||
}
|
||||
|
||||
const server = createServer((req, res) => {
|
||||
let data = "";
|
||||
|
||||
req.on("data", chunk => {
|
||||
data += chunk;
|
||||
});
|
||||
|
||||
req.on("end", () => {
|
||||
try {
|
||||
if (data.length > 0){
|
||||
const response = runUserFunction(data);
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify(response));
|
||||
}
|
||||
} catch (err) {
|
||||
res.writeHead(400);
|
||||
res.end("[Server] Error: " + err);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
server.on("error", (err) => {
|
||||
console.error("[Server] Error:", err);
|
||||
});
|
||||
|
||||
server.listen(SOCKET_PATH, () => {
|
||||
console.log("[Server] Listening on", SOCKET_PATH);
|
||||
});
|
||||
12
sandbox/resources/server/tsconfig.json
Normal file
12
sandbox/resources/server/tsconfig.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "Node",
|
||||
"strict": true,
|
||||
"outDir": "build",
|
||||
"types": ["node"],
|
||||
},
|
||||
"include": ["entrypoint.ts", "server.ts", "user-function.ts"]
|
||||
}
|
||||
|
||||
2
sandbox/resources/server/user-function.ts
Normal file
2
sandbox/resources/server/user-function.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
// THIS FILE CONTAINS USER DEFINED CODE THAT MAY BE OVERWRITTEN.
|
||||
export function repeatString(str: string, multiplier: number): string {return str.repeat(multiplier);}
|
||||
Reference in New Issue
Block a user