# NOTE: removed file-browser paste artifact that preceded the module docstring
# ("Files / letta-server/tests/integration_test_modal_sandbox_v2.py / 825 lines / 30 KiB / Python").

"""
Integration tests for Modal Sandbox V2.
These tests cover:
- Basic tool execution with Modal
- Error handling and edge cases
- Async tool execution
- Version tracking and redeployment
- Persistence of deployment metadata
- Concurrent execution handling
- Multiple sandbox configurations
- Service restart scenarios
"""
import asyncio
import os
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from letta.schemas.enums import ToolSourceType
from letta.schemas.organization import Organization
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.sandbox_config import ModalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxType
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_sandbox.modal_sandbox_v2 import AsyncToolSandboxModalV2
from letta.services.tool_sandbox.modal_version_manager import ModalVersionManager, get_version_manager
from letta.services.user_manager import UserManager
# ============================================================================
# SHARED FIXTURES
# ============================================================================
@pytest.fixture
def test_organization():
    """Provide a freshly created organization row for integration tests."""
    manager = OrganizationManager()
    unique_name = f"test-org-{uuid.uuid4().hex[:8]}"
    created = manager.create_organization(Organization(name=unique_name))
    yield created
    # No teardown: the test database is discarded after the run.
@pytest.fixture
def test_user(test_organization):
    """Provide a user that belongs to the test organization fixture."""
    manager = UserManager()
    unique_name = f"test-user-{uuid.uuid4().hex[:8]}"
    created = manager.create_user(
        User(name=unique_name, organization_id=test_organization.id)
    )
    yield created
    # No teardown: the test database is discarded after the run.
@pytest.fixture
def mock_user():
    """Build a lightweight stand-in user with no database row behind it."""
    fake = MagicMock()
    # Random ids keep parallel tests from colliding on shared state.
    fake.organization_id = f"test-org-{uuid.uuid4().hex[:8]}"
    fake.id = f"user-{uuid.uuid4().hex[:8]}"
    return fake
@pytest.fixture
def basic_tool(test_user):
    """Create and persist a simple synchronous `calculate` tool.

    The tool performs basic arithmetic and raises ValueError for division
    by zero or an unrecognized operation.
    """
    from letta.services.tool_manager import ToolManager

    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="calculate",
        source_type=ToolSourceType.python,
        # NOTE: this source is executed inside the sandbox, so it must be
        # valid, properly indented stand-alone Python.
        source_code="""
def calculate(operation: str, a: float, b: float) -> float:
    '''Perform a calculation on two numbers.

    Args:
        operation: The operation to perform (add, subtract, multiply, divide)
        a: The first number
        b: The second number

    Returns:
        float: The result of the calculation
    '''
    if operation == "add":
        return a + b
    elif operation == "subtract":
        return a - b
    elif operation == "multiply":
        return a * b
    elif operation == "divide":
        if b == 0:
            raise ValueError("Cannot divide by zero")
        return a / b
    else:
        raise ValueError(f"Unknown operation: {operation}")
""",
        json_schema={
            "parameters": {
                "properties": {
                    "operation": {"type": "string", "description": "The operation to perform"},
                    "a": {"type": "number", "description": "The first number"},
                    "b": {"type": "number", "description": "The second number"},
                }
            }
        },
    )
    # Persist the tool so sandbox runs can look it up by id.
    tool_manager = ToolManager()
    created_tool = tool_manager.create_or_update_tool(tool, actor=test_user)
    yield created_tool
    # Cleanup would go here if needed
@pytest.fixture
def async_tool(test_user):
    """Create and persist an async `fetch_data` tool that simulates a network call."""
    from letta.services.tool_manager import ToolManager

    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="fetch_data",
        source_type=ToolSourceType.python,
        # NOTE: executed inside the sandbox as stand-alone Python. The return
        # annotation uses the builtin `dict` — `typing.Dict` is not imported
        # in the embedded module and would raise NameError at definition time.
        source_code="""
import asyncio


async def fetch_data(url: str, delay: float = 0.1) -> dict:
    '''Simulate fetching data from a URL.

    Args:
        url: The URL to fetch data from
        delay: The delay in seconds before returning

    Returns:
        dict: A dictionary containing the fetched data
    '''
    await asyncio.sleep(delay)
    return {
        "url": url,
        "status": "success",
        "data": f"Data from {url}",
        "timestamp": "2024-01-01T00:00:00Z"
    }
""",
        json_schema={
            "parameters": {
                "properties": {
                    "url": {"type": "string", "description": "The URL to fetch data from"},
                    "delay": {"type": "number", "default": 0.1, "description": "The delay in seconds"},
                }
            }
        },
    )
    # Persist the tool so sandbox runs can look it up by id.
    tool_manager = ToolManager()
    created_tool = tool_manager.create_or_update_tool(tool, actor=test_user)
    yield created_tool
    # Cleanup would go here if needed
@pytest.fixture
def tool_with_dependencies(test_user):
    """Create and persist a `process_json` tool that declares a pip requirement."""
    from letta.services.tool_manager import ToolManager

    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="process_json",
        source_type=ToolSourceType.python,
        # NOTE: executed inside the sandbox as stand-alone Python. Annotated
        # with the builtin `dict` because `typing.Dict` is not imported in the
        # embedded module (NameError at definition time otherwise).
        source_code="""
import json
import hashlib


def process_json(data: str) -> dict:
    '''Process JSON data and return metadata.

    Args:
        data: The JSON string to process

    Returns:
        dict: Metadata about the JSON data
    '''
    try:
        parsed = json.loads(data)
        data_hash = hashlib.md5(data.encode()).hexdigest()
        return {
            "valid": True,
            "keys": list(parsed.keys()) if isinstance(parsed, dict) else None,
            "type": type(parsed).__name__,
            "hash": data_hash,
            "size": len(data),
        }
    except json.JSONDecodeError as e:
        return {
            "valid": False,
            "error": str(e),
            "size": len(data),
        }
""",
        json_schema={
            "parameters": {
                "properties": {
                    "data": {"type": "string", "description": "The JSON string to process"},
                }
            }
        },
        pip_requirements=[PipRequirement(name="hashlib")],  # Actually built-in, but for testing
    )
    # Persist the tool so sandbox runs can look it up by id.
    tool_manager = ToolManager()
    created_tool = tool_manager.create_or_update_tool(tool, actor=test_user)
    yield created_tool
    # Cleanup would go here if needed
@pytest.fixture
def sandbox_config(test_user):
    """Persist a Modal sandbox configuration for use in integration tests."""
    scm = SandboxConfigManager()
    create_request = SandboxConfigCreate(
        config=ModalSandboxConfig(timeout=60, pip_requirements=["pandas==2.0.0"]).model_dump()
    )
    created = scm.create_or_update_sandbox_config(
        sandbox_config_create=create_request, actor=test_user
    )
    yield created
    # No teardown: the test database is discarded after the run.
@pytest.fixture
def mock_sandbox_config():
    """Build an in-memory Modal sandbox config with no database row behind it."""
    config_payload = ModalSandboxConfig(
        timeout=60,
        pip_requirements=["pandas==2.0.0"],
    ).model_dump()
    return SandboxConfig(
        id=f"sandbox-{uuid.uuid4().hex[:8]}",
        type=SandboxType.MODAL,
        config=config_payload,
    )
# ============================================================================
# BASIC EXECUTION TESTS (Requires Modal credentials)
# ============================================================================
# NOTE(review): the leading `True or` short-circuits, making this skipif
# unconditional — the whole class is disabled even when Modal credentials ARE
# configured, and the reported reason is then misleading. Confirm whether this
# is a deliberate temporary kill switch; if so, an explicit @pytest.mark.skip
# with an honest reason would be clearer.
@pytest.mark.skipif(
    True or not os.getenv("MODAL_TOKEN_ID") or not os.getenv("MODAL_TOKEN_SECRET"), reason="Modal credentials not configured"
)
class TestModalV2BasicExecution:
    """Basic execution tests with Modal.

    These tests talk to real Modal infrastructure and therefore require
    MODAL_TOKEN_ID / MODAL_TOKEN_SECRET in the environment.
    """

    @pytest.mark.asyncio
    async def test_basic_execution(self, basic_tool, test_user):
        """Run the `calculate` tool for addition and division, checking results."""
        sandbox = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "add", "a": 5, "b": 3},
            user=test_user,
            tool_id=basic_tool.id,
            tool_object=basic_tool,
        )
        result = await sandbox.run()
        assert result.status == "success"
        assert result.func_return == 8.0
        # Test division
        sandbox2 = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "divide", "a": 10, "b": 2},
            user=test_user,
            tool_id=basic_tool.id,
            tool_object=basic_tool,
        )
        result2 = await sandbox2.run()
        assert result2.status == "success"
        assert result2.func_return == 5.0

    @pytest.mark.asyncio
    async def test_error_handling(self, basic_tool, test_user):
        """Errors raised inside the tool surface as status == "error"."""
        # Test division by zero
        sandbox = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "divide", "a": 10, "b": 0},
            user=test_user,
            tool_id=basic_tool.id,
            tool_object=basic_tool,
        )
        result = await sandbox.run()
        assert result.status == "error"
        assert "Cannot divide by zero" in str(result.func_return)
        # Test unknown operation
        sandbox2 = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "unknown", "a": 1, "b": 2},
            user=test_user,
            tool_id=basic_tool.id,
            tool_object=basic_tool,
        )
        result2 = await sandbox2.run()
        assert result2.status == "error"
        assert "Unknown operation" in str(result2.func_return)

    @pytest.mark.asyncio
    async def test_async_tool_execution(self, async_tool, test_user):
        """Async (coroutine) tools execute and return their payload."""
        sandbox = AsyncToolSandboxModalV2(
            tool_name="fetch_data",
            args={"url": "https://example.com", "delay": 0.01},
            user=test_user,
            tool_id=async_tool.id,
            tool_object=async_tool,
        )
        result = await sandbox.run()
        assert result.status == "success"
        # Parse the result (it should be a dict)
        data = result.func_return
        assert isinstance(data, dict)
        assert data["url"] == "https://example.com"
        assert data["status"] == "success"
        assert "Data from https://example.com" in data["data"]

    @pytest.mark.asyncio
    async def test_concurrent_executions(self, basic_tool, test_user):
        """Several sandboxes run concurrently without mixing up their arguments."""
        # Create multiple sandboxes with different arguments
        sandboxes = [
            AsyncToolSandboxModalV2(
                tool_name="calculate",
                args={"operation": "add", "a": i, "b": i + 1},
                user=test_user,
                tool_id=basic_tool.id,
                tool_object=basic_tool,
            )
            for i in range(5)
        ]
        # Execute all concurrently
        results = await asyncio.gather(*[s.run() for s in sandboxes])
        # Verify all succeeded with correct results
        for i, result in enumerate(results):
            assert result.status == "success"
            expected = i + (i + 1)  # a + b
            assert result.func_return == expected
# ============================================================================
# PERSISTENCE AND VERSION TRACKING TESTS
# ============================================================================
@pytest.mark.asyncio
class TestModalV2Persistence:
    """Tests for deployment persistence and version tracking.

    ToolManager is patched out, so these tests exercise ModalVersionManager's
    metadata bookkeeping without touching Modal or a real database.
    """

    async def test_deployment_persists_in_tool_metadata(self, mock_user, sandbox_config):
        """register_deployment writes app/version/deps under tool.metadata_["modal_deployments"]."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="calculate",
            source_code="def calculate(x: float) -> float:\n '''Double a number.\n \n Args:\n x: The number to double\n \n Returns:\n The doubled value\n '''\n return x * 2",
            json_schema={"parameters": {"properties": {"x": {"type": "number"}}}},
            metadata_={},
        )
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool
            mock_tool_manager.update_tool_by_id_async = AsyncMock(return_value=tool)
            version_manager = ModalVersionManager()
            # Register a deployment
            app_name = f"{mock_user.organization_id}-{tool.name}-v2"
            version_hash = "abc123def456"
            mock_app = MagicMock()
            await version_manager.register_deployment(
                tool_id=tool.id,
                app_name=app_name,
                version_hash=version_hash,
                app=mock_app,
                dependencies={"pandas", "numpy"},
                sandbox_config_id=sandbox_config.id,
                actor=mock_user,
            )
            # Verify update was called with correct metadata
            mock_tool_manager.update_tool_by_id_async.assert_called_once()
            call_args = mock_tool_manager.update_tool_by_id_async.call_args
            # call_args[1] is the kwargs dict of the recorded call
            metadata = call_args[1]["tool_update"].metadata_
            assert "modal_deployments" in metadata
            assert sandbox_config.id in metadata["modal_deployments"]
            deployment_data = metadata["modal_deployments"][sandbox_config.id]
            assert deployment_data["app_name"] == app_name
            assert deployment_data["version_hash"] == version_hash
            # Dependencies may be serialized as a list; compare as a set.
            assert set(deployment_data["dependencies"]) == {"pandas", "numpy"}

    async def test_version_tracking_and_redeployment(self, mock_user, basic_tool, sandbox_config):
        """needs_redeployment is False for the registered hash, True for a new one."""
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = basic_tool
            # Track metadata updates
            metadata_store = {}

            async def update_tool(*args, **kwargs):
                # Accumulate metadata the way a real update would persist it.
                metadata_store.update(kwargs.get("metadata_", {}))
                basic_tool.metadata_ = metadata_store
                return basic_tool

            mock_tool_manager.update_tool_by_id_async = AsyncMock(side_effect=update_tool)
            version_manager = ModalVersionManager()
            app_name = f"{mock_user.organization_id}-{basic_tool.name}-v2"
            # First deployment
            version1 = "version1hash"
            await version_manager.register_deployment(
                tool_id=basic_tool.id,
                app_name=app_name,
                version_hash=version1,
                app=MagicMock(),
                sandbox_config_id=sandbox_config.id,
                actor=mock_user,
            )
            # Should not need redeployment with same version
            assert not await version_manager.needs_redeployment(basic_tool.id, version1, sandbox_config.id, actor=mock_user)
            # Should need redeployment with different version
            version2 = "version2hash"
            assert await version_manager.needs_redeployment(basic_tool.id, version2, sandbox_config.id, actor=mock_user)

    async def test_deployment_survives_service_restart(self, mock_user, sandbox_config):
        """A fresh ModalVersionManager rehydrates deployments from tool metadata."""
        tool_id = f"tool-{uuid.uuid4().hex[:8]}"
        app_name = f"{mock_user.organization_id}-calculate-v2"
        version_hash = "restart-test-v1"
        # Simulate existing deployment in metadata
        existing_metadata = {
            "modal_deployments": {
                sandbox_config.id: {
                    "app_name": app_name,
                    "version_hash": version_hash,
                    "deployed_at": datetime.now().isoformat(),
                    "dependencies": ["pandas"],
                }
            }
        }
        tool = Tool(
            id=tool_id,
            name="calculate",
            source_code="def calculate(x: float) -> float:\n '''Identity function.\n \n Args:\n x: The input value\n \n Returns:\n The same value\n '''\n return x",
            json_schema={"parameters": {"properties": {}}},
            metadata_=existing_metadata,
        )
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool
            # Create new version manager (simulating service restart)
            version_manager = ModalVersionManager()
            # Should be able to retrieve existing deployment
            deployment = await version_manager.get_deployment(tool_id, sandbox_config.id, actor=mock_user)
            assert deployment is not None
            assert deployment.app_name == app_name
            assert deployment.version_hash == version_hash
            # Stored as a list in metadata, exposed as a set on the deployment.
            assert deployment.dependencies == {"pandas"}
            # Should not need redeployment with same version
            assert not await version_manager.needs_redeployment(tool_id, version_hash, sandbox_config.id, actor=mock_user)

    async def test_different_sandbox_configs_same_tool(self, mock_user):
        """Each sandbox config keeps its own deployment entry for the same tool."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="multi_config",
            source_code="def test(x: int) -> int:\n '''Test function.\n \n Args:\n x: The input value\n \n Returns:\n The same value\n '''\n return x",
            json_schema={"parameters": {"properties": {}}},
            metadata_={},
        )
        # Create two different sandbox configs
        config1 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(timeout=30, pip_requirements=["pandas"]).model_dump(),
        )
        config2 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(timeout=60, pip_requirements=["numpy"]).model_dump(),
        )
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool
            # Track all metadata updates
            all_metadata = {"modal_deployments": {}}

            async def update_tool(*args, **kwargs):
                # Merge per-config entries instead of overwriting the whole map.
                new_meta = kwargs.get("metadata_", {})
                if "modal_deployments" in new_meta:
                    all_metadata["modal_deployments"].update(new_meta["modal_deployments"])
                tool.metadata_ = all_metadata
                return tool

            mock_tool_manager.update_tool_by_id_async = AsyncMock(side_effect=update_tool)
            version_manager = ModalVersionManager()
            app_name = f"{mock_user.organization_id}-{tool.name}-v2"
            # Deploy with config1
            await version_manager.register_deployment(
                tool_id=tool.id,
                app_name=app_name,
                version_hash="config1-hash",
                app=MagicMock(),
                sandbox_config_id=config1.id,
                actor=mock_user,
            )
            # Deploy with config2
            await version_manager.register_deployment(
                tool_id=tool.id,
                app_name=app_name,
                version_hash="config2-hash",
                app=MagicMock(),
                sandbox_config_id=config2.id,
                actor=mock_user,
            )
            # Both deployments should exist
            deployment1 = await version_manager.get_deployment(tool.id, config1.id, actor=mock_user)
            deployment2 = await version_manager.get_deployment(tool.id, config2.id, actor=mock_user)
            assert deployment1 is not None
            assert deployment2 is not None
            assert deployment1.version_hash == "config1-hash"
            assert deployment2.version_hash == "config2-hash"

    async def test_sandbox_config_changes_trigger_redeployment(self, basic_tool, mock_user):
        """Different sandbox configs produce different version hashes (no real deploy)."""
        # Skip the actual Modal deployment part in this test
        # Just test the version hash calculation changes
        config1 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(timeout=30).model_dump(),
        )
        config2 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(
                timeout=60,
                pip_requirements=["requests"],
            ).model_dump(),
        )
        # Mock the Modal credentials to allow sandbox instantiation
        with patch("letta.services.tool_sandbox.modal_sandbox_v2.tool_settings") as mock_settings:
            mock_settings.modal_token_id = "test-token-id"
            mock_settings.modal_token_secret = "test-token-secret"
            sandbox1 = AsyncToolSandboxModalV2(
                tool_name="calculate",
                args={"operation": "add", "a": 1, "b": 1},
                user=mock_user,
                tool_id=basic_tool.id,
                tool_object=basic_tool,
                sandbox_config=config1,
            )
            sandbox2 = AsyncToolSandboxModalV2(
                tool_name="calculate",
                args={"operation": "add", "a": 2, "b": 2},
                user=mock_user,
                tool_id=basic_tool.id,
                tool_object=basic_tool,
                sandbox_config=config2,
            )
            # Version hashes should be different due to config changes
            version1 = sandbox1._deployment_manager.calculate_version_hash(config1)
            version2 = sandbox2._deployment_manager.calculate_version_hash(config2)
            assert version1 != version2
# ============================================================================
# MOCKED INTEGRATION TESTS (No Modal credentials required)
# ============================================================================
class TestModalV2MockedIntegration:
    """Integration tests with mocked Modal components.

    No Modal credentials are required: the `modal` module, ToolManager, and
    SandboxConfigManager are all patched out.
    """

    @pytest.mark.asyncio
    async def test_full_integration_with_persistence(self, mock_user, sandbox_config):
        """Test the full Modal sandbox V2 integration with persistence."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="integration_test",
            # NOTE: the embedded source must be valid stand-alone Python; its
            # indentation is significant.
            source_code="""
def calculate(operation: str, a: float, b: float) -> float:
    '''Perform a simple calculation'''
    if operation == "add":
        return a + b
    return 0
""",
            json_schema={
                "parameters": {
                    "properties": {
                        "operation": {"type": "string"},
                        "a": {"type": "number"},
                        "b": {"type": "number"},
                    }
                }
            },
            metadata_={},
        )
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            with patch("letta.services.tool_sandbox.modal_sandbox_v2.modal") as mock_modal:
                mock_tool_manager = MockToolManager.return_value
                mock_tool_manager.get_tool_by_id.return_value = tool

                # Track metadata updates on the in-memory tool object.
                async def update_tool(*args, **kwargs):
                    tool.metadata_ = kwargs.get("metadata_", {})
                    return tool

                mock_tool_manager.update_tool_by_id_async = update_tool
                # Mock Modal app
                mock_app = MagicMock()
                mock_app.run = MagicMock()

                # Mock the @app.function decorator: it must return a callable
                # whose .remote.aio coroutine yields the executor payload.
                def mock_function_decorator(*args, **kwargs):
                    def decorator(func):
                        mock_func = MagicMock()
                        mock_func.remote = MagicMock()
                        mock_func.remote.aio = AsyncMock(
                            return_value={
                                "result": 8,
                                "agent_state": None,
                                "stdout": "",
                                "stderr": "",
                                "error": None,
                            }
                        )
                        mock_app.tool_executor = mock_func
                        return mock_func

                    return decorator

                mock_app.function = mock_function_decorator
                mock_app.deploy = MagicMock()
                mock_app.deploy.aio = AsyncMock()
                mock_modal.App.return_value = mock_app
                # Mock the sandbox config manager so no env-var lookups hit a DB.
                with patch("letta.services.tool_sandbox.base.SandboxConfigManager") as MockSCM:
                    mock_scm = MockSCM.return_value
                    mock_scm.get_sandbox_env_vars_as_dict_async = AsyncMock(return_value={})
                    # Create sandbox
                    sandbox = AsyncToolSandboxModalV2(
                        tool_name="integration_test",
                        args={"operation": "add", "a": 5, "b": 3},
                        user=mock_user,
                        tool_id=tool.id,
                        tool_object=tool,
                        sandbox_config=sandbox_config,
                    )
                    # Mock version manager methods through deployment manager
                    version_manager = sandbox._deployment_manager.version_manager
                    if version_manager:
                        with patch.object(version_manager, "get_deployment", return_value=None):
                            with patch.object(version_manager, "register_deployment", return_value=None):
                                # First execution - should deploy
                                result1 = await sandbox.run()
                                assert result1.status == "success"
                                assert result1.func_return == 8
                    else:
                        # If no version manager, just run
                        result1 = await sandbox.run()
                        assert result1.status == "success"
                        assert result1.func_return == 8

    @pytest.mark.asyncio
    async def test_concurrent_deployment_handling(self, mock_user, sandbox_config):
        """Test that concurrent deployment requests are handled correctly."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="concurrent_test",
            source_code="def test(x: int) -> int:\n '''Test function.\n \n Args:\n x: The input value\n \n Returns:\n The same value\n '''\n return x",
            json_schema={"parameters": {"properties": {}}},
            metadata_={},
        )
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool
            # Track update calls
            update_calls = []

            async def track_update(*args, **kwargs):
                update_calls.append((args, kwargs))
                await asyncio.sleep(0.01)  # Simulate slight delay
                return tool

            mock_tool_manager.update_tool_by_id_async = AsyncMock(side_effect=track_update)
            version_manager = ModalVersionManager()
            app_name = f"{mock_user.organization_id}-{tool.name}-v2"
            version_hash = "concurrent123"
            # Launch multiple concurrent deployments
            tasks = []
            for _ in range(5):
                task = version_manager.register_deployment(
                    tool_id=tool.id,
                    app_name=app_name,
                    version_hash=version_hash,
                    app=MagicMock(),
                    sandbox_config_id=sandbox_config.id,
                    actor=mock_user,
                )
                tasks.append(task)
            # Wait for all to complete
            await asyncio.gather(*tasks)
            # All calls should complete (current implementation doesn't dedupe)
            assert len(update_calls) == 5
# ============================================================================
# DEPLOYMENT STATISTICS TESTS
# ============================================================================
@pytest.mark.skipif(not os.getenv("MODAL_TOKEN_ID") or not os.getenv("MODAL_TOKEN_SECRET"), reason="Modal credentials not configured")
class TestModalV2DeploymentStats:
    """Tests for deployment statistics tracking (requires real Modal credentials)."""

    @pytest.mark.asyncio
    async def test_deployment_stats(self, basic_tool, async_tool, test_user):
        """Deploying two tools is reflected in the shared version manager's stats."""
        version_manager = get_version_manager()
        # Clear any existing deployments (for test isolation)
        version_manager.clear_deployments()
        # Ensure clean state
        # NOTE(review): the sleep presumably lets in-flight bookkeeping settle —
        # confirm whether it is actually needed after clear_deployments().
        await asyncio.sleep(0.1)
        # Deploy multiple tools
        tools = [basic_tool, async_tool]
        for tool in tools:
            sandbox = AsyncToolSandboxModalV2(
                tool_name=tool.name,
                args={},
                user=test_user,
                tool_id=tool.id,
                tool_object=tool,
            )
            await sandbox.run()
        # Get stats
        stats = await version_manager.get_deployment_stats()
        # >= rather than == tolerates deployments made by concurrent tests.
        assert stats["total_deployments"] >= 2
        assert stats["active_deployments"] >= 2
        assert stats["stale_deployments"] == 0
        # Check individual deployment info
        for deployment in stats["deployments"]:
            assert "app_name" in deployment
            assert "version" in deployment
            assert "usage_count" in deployment
            assert deployment["usage_count"] >= 1
# Allow running this module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])