"""
|
|
Integration tests for Modal Sandbox V2.
|
|
|
|
These tests cover:
|
|
- Basic tool execution with Modal
|
|
- Error handling and edge cases
|
|
- Async tool execution
|
|
- Version tracking and redeployment
|
|
- Persistence of deployment metadata
|
|
- Concurrent execution handling
|
|
- Multiple sandbox configurations
|
|
- Service restart scenarios
|
|
"""
|
|
|
|
import asyncio
|
|
import os
|
|
import uuid
|
|
from datetime import datetime
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from letta.schemas.enums import ToolSourceType
|
|
from letta.schemas.organization import Organization
|
|
from letta.schemas.pip_requirement import PipRequirement
|
|
from letta.schemas.sandbox_config import ModalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxType
|
|
from letta.schemas.tool import Tool
|
|
from letta.schemas.user import User
|
|
from letta.services.organization_manager import OrganizationManager
|
|
from letta.services.sandbox_config_manager import SandboxConfigManager
|
|
from letta.services.tool_sandbox.modal_sandbox_v2 import AsyncToolSandboxModalV2
|
|
from letta.services.tool_sandbox.modal_version_manager import ModalVersionManager, get_version_manager
|
|
from letta.services.user_manager import UserManager
|
|
|
|
# ============================================================================
|
|
# SHARED FIXTURES
|
|
# ============================================================================
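# NOTE: the fixtures below come in two flavors: database-backed objects
# (test_organization, test_user, basic_tool, ...) for end-to-end runs, and
# MagicMock stand-ins (mock_user, mock_sandbox_config) for tests that do not
# need database persistence.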


@pytest.fixture
def test_organization():
    """Create a test organization in the database."""
    org_manager = OrganizationManager()
    org = org_manager.create_organization(Organization(name=f"test-org-{uuid.uuid4().hex[:8]}"))
    yield org
    # Cleanup would go here if needed


@pytest.fixture
def test_user(test_organization):
    """Create a test user in the database."""
    user_manager = UserManager()
    user = user_manager.create_user(User(name=f"test-user-{uuid.uuid4().hex[:8]}", organization_id=test_organization.id))
    yield user
    # Cleanup would go here if needed


@pytest.fixture
def mock_user():
    """Create a mock user for tests that don't need database persistence."""
    user = MagicMock()
    user.organization_id = f"test-org-{uuid.uuid4().hex[:8]}"
    user.id = f"user-{uuid.uuid4().hex[:8]}"
    return user


@pytest.fixture
def basic_tool(test_user):
    """Create a basic tool for testing."""
    from letta.services.tool_manager import ToolManager

    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="calculate",
        source_type=ToolSourceType.python,
        source_code="""
def calculate(operation: str, a: float, b: float) -> float:
    '''Perform a calculation on two numbers.

    Args:
        operation: The operation to perform (add, subtract, multiply, divide)
        a: The first number
        b: The second number

    Returns:
        float: The result of the calculation
    '''
    if operation == "add":
        return a + b
    elif operation == "subtract":
        return a - b
    elif operation == "multiply":
        return a * b
    elif operation == "divide":
        if b == 0:
            raise ValueError("Cannot divide by zero")
        return a / b
    else:
        raise ValueError(f"Unknown operation: {operation}")
""",
        json_schema={
            "parameters": {
                "properties": {
                    "operation": {"type": "string", "description": "The operation to perform"},
                    "a": {"type": "number", "description": "The first number"},
                    "b": {"type": "number", "description": "The second number"},
                }
            }
        },
    )

    # Create the tool in the database
    tool_manager = ToolManager()
    created_tool = tool_manager.create_or_update_tool(tool, actor=test_user)
    yield created_tool

    # Cleanup would go here if needed


@pytest.fixture
def async_tool(test_user):
    """Create an async tool for testing."""
    from letta.services.tool_manager import ToolManager

    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="fetch_data",
        source_type=ToolSourceType.python,
        source_code="""
import asyncio
from typing import Dict  # needed so the `-> Dict` annotation resolves when the source is executed

async def fetch_data(url: str, delay: float = 0.1) -> Dict:
    '''Simulate fetching data from a URL.

    Args:
        url: The URL to fetch data from
        delay: The delay in seconds before returning

    Returns:
        Dict: A dictionary containing the fetched data
    '''
    await asyncio.sleep(delay)
    return {
        "url": url,
        "status": "success",
        "data": f"Data from {url}",
        "timestamp": "2024-01-01T00:00:00Z"
    }
""",
        json_schema={
            "parameters": {
                "properties": {
                    "url": {"type": "string", "description": "The URL to fetch data from"},
                    "delay": {"type": "number", "default": 0.1, "description": "The delay in seconds"},
                }
            }
        },
    )

    # Create the tool in the database
    tool_manager = ToolManager()
    created_tool = tool_manager.create_or_update_tool(tool, actor=test_user)
    yield created_tool

    # Cleanup would go here if needed


@pytest.fixture
def tool_with_dependencies(test_user):
    """Create a tool that requires external dependencies."""
    from letta.services.tool_manager import ToolManager

    tool = Tool(
        id=f"tool-{uuid.uuid4().hex[:8]}",
        name="process_json",
        source_type=ToolSourceType.python,
        source_code="""
import json
import hashlib
from typing import Dict  # needed so the `-> Dict` annotation resolves when the source is executed

def process_json(data: str) -> Dict:
    '''Process JSON data and return metadata.

    Args:
        data: The JSON string to process

    Returns:
        Dict: Metadata about the JSON data
    '''
    try:
        parsed = json.loads(data)
        data_hash = hashlib.md5(data.encode()).hexdigest()

        return {
            "valid": True,
            "keys": list(parsed.keys()) if isinstance(parsed, dict) else None,
            "type": type(parsed).__name__,
            "hash": data_hash,
            "size": len(data),
        }
    except json.JSONDecodeError as e:
        return {
            "valid": False,
            "error": str(e),
            "size": len(data),
        }
""",
        json_schema={
            "parameters": {
                "properties": {
                    "data": {"type": "string", "description": "The JSON string to process"},
                }
            }
        },
        pip_requirements=[PipRequirement(name="hashlib")],  # Actually built-in, but for testing
    )

    # Create the tool in the database
    tool_manager = ToolManager()
    created_tool = tool_manager.create_or_update_tool(tool, actor=test_user)
    yield created_tool

    # Cleanup would go here if needed


@pytest.fixture
def sandbox_config(test_user):
    """Create a test sandbox configuration in the database."""
    manager = SandboxConfigManager()
    modal_config = ModalSandboxConfig(
        timeout=60,
        pip_requirements=["pandas==2.0.0"],
    )
    config_create = SandboxConfigCreate(config=modal_config.model_dump())
    config = manager.create_or_update_sandbox_config(sandbox_config_create=config_create, actor=test_user)
    yield config
    # Cleanup would go here if needed


@pytest.fixture
def mock_sandbox_config():
    """Create a mock sandbox configuration for tests that don't need database persistence."""
    modal_config = ModalSandboxConfig(
        timeout=60,
        pip_requirements=["pandas==2.0.0"],
    )
    return SandboxConfig(
        id=f"sandbox-{uuid.uuid4().hex[:8]}",
        type=SandboxType.MODAL,
        config=modal_config.model_dump(),
    )


# ============================================================================
# BASIC EXECUTION TESTS (Requires Modal credentials)
# ============================================================================


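# NOTE: the `True or ...` condition below short-circuits the credential check,
# so this class is currently skipped unconditionally; drop `True or` to re-enable.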
@pytest.mark.skipif(
    True or not os.getenv("MODAL_TOKEN_ID") or not os.getenv("MODAL_TOKEN_SECRET"), reason="Modal credentials not configured"
)
class TestModalV2BasicExecution:
    """Basic execution tests with Modal."""

    @pytest.mark.asyncio
    async def test_basic_execution(self, basic_tool, test_user):
        """Test basic tool execution with different operations."""
        sandbox = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "add", "a": 5, "b": 3},
            user=test_user,
            tool_object=basic_tool,
        )

        result = await sandbox.run()
        assert result.status == "success"
        assert result.func_return == 8.0

        # Test division
        sandbox2 = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "divide", "a": 10, "b": 2},
            user=test_user,
            tool_object=basic_tool,
        )

        result2 = await sandbox2.run()
        assert result2.status == "success"
        assert result2.func_return == 5.0

    @pytest.mark.asyncio
    async def test_error_handling(self, basic_tool, test_user):
        """Test error handling in tool execution."""
        # Test division by zero
        sandbox = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "divide", "a": 10, "b": 0},
            user=test_user,
            tool_object=basic_tool,
        )

        result = await sandbox.run()
        assert result.status == "error"
        assert "Cannot divide by zero" in str(result.func_return)

        # Test unknown operation
        sandbox2 = AsyncToolSandboxModalV2(
            tool_name="calculate",
            args={"operation": "unknown", "a": 1, "b": 2},
            user=test_user,
            tool_object=basic_tool,
        )

        result2 = await sandbox2.run()
        assert result2.status == "error"
        assert "Unknown operation" in str(result2.func_return)

    @pytest.mark.asyncio
    async def test_async_tool_execution(self, async_tool, test_user):
        """Test execution of async tools."""
        sandbox = AsyncToolSandboxModalV2(
            tool_name="fetch_data",
            args={"url": "https://example.com", "delay": 0.01},
            user=test_user,
            tool_object=async_tool,
        )

        result = await sandbox.run()
        assert result.status == "success"

        # Parse the result (it should be a dict)
        data = result.func_return
        assert isinstance(data, dict)
        assert data["url"] == "https://example.com"
        assert data["status"] == "success"
        assert "Data from https://example.com" in data["data"]

    @pytest.mark.asyncio
    async def test_concurrent_executions(self, basic_tool, test_user):
        """Test that concurrent executions work correctly."""
        # Create multiple sandboxes with different arguments
        sandboxes = [
            AsyncToolSandboxModalV2(
                tool_name="calculate",
                args={"operation": "add", "a": i, "b": i + 1},
                user=test_user,
                tool_object=basic_tool,
            )
            for i in range(5)
        ]

        # Execute all concurrently
        results = await asyncio.gather(*[s.run() for s in sandboxes])

        # Verify all succeeded with correct results
        for i, result in enumerate(results):
            assert result.status == "success"
            expected = i + (i + 1)  # a + b
            assert result.func_return == expected


# ============================================================================
# PERSISTENCE AND VERSION TRACKING TESTS
# ============================================================================


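# Applying pytest.mark.asyncio at class level marks every async test method in
# the class, so no per-method decorator (or custom event loop fixture) is needed.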
@pytest.mark.asyncio
class TestModalV2Persistence:
    """Tests for deployment persistence and version tracking."""

    async def test_deployment_persists_in_tool_metadata(self, mock_user, sandbox_config):
        """Test that deployment info is correctly stored in tool metadata."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="calculate",
            source_code="def calculate(x: float) -> float:\n '''Double a number.\n \n Args:\n x: The number to double\n \n Returns:\n The doubled value\n '''\n return x * 2",
            json_schema={"parameters": {"properties": {"x": {"type": "number"}}}},
            metadata_={},
        )

        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool
            mock_tool_manager.update_tool_by_id_async = AsyncMock(return_value=tool)

            version_manager = ModalVersionManager()

            # Register a deployment
            app_name = f"{mock_user.organization_id}-{tool.name}-v2"
            version_hash = "abc123def456"
            mock_app = MagicMock()

            await version_manager.register_deployment(
                tool_id=tool.id,
                app_name=app_name,
                version_hash=version_hash,
                app=mock_app,
                dependencies={"pandas", "numpy"},
                sandbox_config_id=sandbox_config.id,
                actor=mock_user,
            )

            # Verify update was called with correct metadata
            mock_tool_manager.update_tool_by_id_async.assert_called_once()
            call_args = mock_tool_manager.update_tool_by_id_async.call_args

            metadata = call_args[1]["tool_update"].metadata_
            assert "modal_deployments" in metadata
            assert sandbox_config.id in metadata["modal_deployments"]

            deployment_data = metadata["modal_deployments"][sandbox_config.id]
            assert deployment_data["app_name"] == app_name
            assert deployment_data["version_hash"] == version_hash
            assert set(deployment_data["dependencies"]) == {"pandas", "numpy"}

    async def test_version_tracking_and_redeployment(self, mock_user, basic_tool, sandbox_config):
        """Test version tracking and redeployment on code changes."""
        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = basic_tool

            # Track metadata updates
            metadata_store = {}

            async def update_tool(*args, **kwargs):
                metadata_store.update(kwargs.get("metadata_", {}))
                basic_tool.metadata_ = metadata_store
                return basic_tool

            mock_tool_manager.update_tool_by_id_async = AsyncMock(side_effect=update_tool)

            version_manager = ModalVersionManager()
            app_name = f"{mock_user.organization_id}-{basic_tool.name}-v2"

            # First deployment
            version1 = "version1hash"
            await version_manager.register_deployment(
                tool_id=basic_tool.id,
                app_name=app_name,
                version_hash=version1,
                app=MagicMock(),
                sandbox_config_id=sandbox_config.id,
                actor=mock_user,
            )

            # Should not need redeployment with same version
            assert not await version_manager.needs_redeployment(basic_tool.id, version1, sandbox_config.id, actor=mock_user)

            # Should need redeployment with different version
            version2 = "version2hash"
            assert await version_manager.needs_redeployment(basic_tool.id, version2, sandbox_config.id, actor=mock_user)

    async def test_deployment_survives_service_restart(self, mock_user, sandbox_config):
        """Test that deployment info survives a service restart."""
        tool_id = f"tool-{uuid.uuid4().hex[:8]}"
        app_name = f"{mock_user.organization_id}-calculate-v2"
        version_hash = "restart-test-v1"

        # Simulate existing deployment in metadata
        existing_metadata = {
            "modal_deployments": {
                sandbox_config.id: {
                    "app_name": app_name,
                    "version_hash": version_hash,
                    "deployed_at": datetime.now().isoformat(),
                    "dependencies": ["pandas"],
                }
            }
        }

        tool = Tool(
            id=tool_id,
            name="calculate",
            source_code="def calculate(x: float) -> float:\n '''Identity function.\n \n Args:\n x: The input value\n \n Returns:\n The same value\n '''\n return x",
            json_schema={"parameters": {"properties": {}}},
            metadata_=existing_metadata,
        )

        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool

            # Create new version manager (simulating service restart)
            version_manager = ModalVersionManager()

            # Should be able to retrieve existing deployment
            deployment = await version_manager.get_deployment(tool_id, sandbox_config.id, actor=mock_user)

            assert deployment is not None
            assert deployment.app_name == app_name
            assert deployment.version_hash == version_hash
            assert deployment.dependencies == {"pandas"}

            # Should not need redeployment with same version
            assert not await version_manager.needs_redeployment(tool_id, version_hash, sandbox_config.id, actor=mock_user)

    async def test_different_sandbox_configs_same_tool(self, mock_user):
        """Test that different sandbox configs can have different deployments for the same tool."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="multi_config",
            source_code="def test(x: int) -> int:\n '''Test function.\n \n Args:\n x: The input value\n \n Returns:\n The same value\n '''\n return x",
            json_schema={"parameters": {"properties": {}}},
            metadata_={},
        )

        # Create two different sandbox configs
        config1 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(timeout=30, pip_requirements=["pandas"]).model_dump(),
        )

        config2 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(timeout=60, pip_requirements=["numpy"]).model_dump(),
        )

        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool

            # Track all metadata updates
            all_metadata = {"modal_deployments": {}}

            async def update_tool(*args, **kwargs):
                new_meta = kwargs.get("metadata_", {})
                if "modal_deployments" in new_meta:
                    all_metadata["modal_deployments"].update(new_meta["modal_deployments"])
                tool.metadata_ = all_metadata
                return tool

            mock_tool_manager.update_tool_by_id_async = AsyncMock(side_effect=update_tool)

            version_manager = ModalVersionManager()
            app_name = f"{mock_user.organization_id}-{tool.name}-v2"

            # Deploy with config1
            await version_manager.register_deployment(
                tool_id=tool.id,
                app_name=app_name,
                version_hash="config1-hash",
                app=MagicMock(),
                sandbox_config_id=config1.id,
                actor=mock_user,
            )

            # Deploy with config2
            await version_manager.register_deployment(
                tool_id=tool.id,
                app_name=app_name,
                version_hash="config2-hash",
                app=MagicMock(),
                sandbox_config_id=config2.id,
                actor=mock_user,
            )

            # Both deployments should exist
            deployment1 = await version_manager.get_deployment(tool.id, config1.id, actor=mock_user)
            deployment2 = await version_manager.get_deployment(tool.id, config2.id, actor=mock_user)

            assert deployment1 is not None
            assert deployment2 is not None
            assert deployment1.version_hash == "config1-hash"
            assert deployment2.version_hash == "config2-hash"

    async def test_sandbox_config_changes_trigger_redeployment(self, basic_tool, mock_user):
        """Test that sandbox config changes trigger redeployment."""
        # Skip the actual Modal deployment part in this test
        # Just test the version hash calculation changes

        config1 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(timeout=30).model_dump(),
        )

        config2 = SandboxConfig(
            id=f"sandbox-{uuid.uuid4().hex[:8]}",
            type=SandboxType.MODAL,
            config=ModalSandboxConfig(
                timeout=60,
                pip_requirements=["requests"],
            ).model_dump(),
        )

        # Mock the Modal credentials to allow sandbox instantiation
        with patch("letta.services.tool_sandbox.modal_sandbox_v2.tool_settings") as mock_settings:
            mock_settings.modal_token_id = "test-token-id"
            mock_settings.modal_token_secret = "test-token-secret"

            sandbox1 = AsyncToolSandboxModalV2(
                tool_name="calculate",
                args={"operation": "add", "a": 1, "b": 1},
                user=mock_user,
                tool_object=basic_tool,
                sandbox_config=config1,
            )

            sandbox2 = AsyncToolSandboxModalV2(
                tool_name="calculate",
                args={"operation": "add", "a": 2, "b": 2},
                user=mock_user,
                tool_object=basic_tool,
                sandbox_config=config2,
            )

            # Version hashes should be different due to config changes
            version1 = sandbox1._deployment_manager.calculate_version_hash(config1)
            version2 = sandbox2._deployment_manager.calculate_version_hash(config2)
            assert version1 != version2


# ============================================================================
# MOCKED INTEGRATION TESTS (No Modal credentials required)
# ============================================================================


class TestModalV2MockedIntegration:
    """Integration tests with mocked Modal components."""

    @pytest.mark.asyncio
    async def test_full_integration_with_persistence(self, mock_user, sandbox_config):
        """Test the full Modal sandbox V2 integration with persistence."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="integration_test",
            source_code="""
def calculate(operation: str, a: float, b: float) -> float:
    '''Perform a simple calculation'''
    if operation == "add":
        return a + b
    return 0
""",
            json_schema={
                "parameters": {
                    "properties": {
                        "operation": {"type": "string"},
                        "a": {"type": "number"},
                        "b": {"type": "number"},
                    }
                }
            },
            metadata_={},
        )

        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            with patch("letta.services.tool_sandbox.modal_sandbox_v2.modal") as mock_modal:
                mock_tool_manager = MockToolManager.return_value
                mock_tool_manager.get_tool_by_id.return_value = tool

                # Track metadata updates
                async def update_tool(*args, **kwargs):
                    tool.metadata_ = kwargs.get("metadata_", {})
                    return tool

                mock_tool_manager.update_tool_by_id_async = update_tool

                # Mock Modal app
                mock_app = MagicMock()
                mock_app.run = MagicMock()

                # Mock the function decorator
                def mock_function_decorator(*args, **kwargs):
                    def decorator(func):
                        mock_func = MagicMock()
                        mock_func.remote = MagicMock()
                        mock_func.remote.aio = AsyncMock(
                            return_value={
                                "result": 8,
                                "agent_state": None,
                                "stdout": "",
                                "stderr": "",
                                "error": None,
                            }
                        )
                        mock_app.tool_executor = mock_func
                        return mock_func

                    return decorator

                mock_app.function = mock_function_decorator
                mock_app.deploy = MagicMock()
                mock_app.deploy.aio = AsyncMock()

                mock_modal.App.return_value = mock_app

                # Mock the sandbox config manager
                with patch("letta.services.tool_sandbox.base.SandboxConfigManager") as MockSCM:
                    mock_scm = MockSCM.return_value
                    mock_scm.get_sandbox_env_vars_as_dict_async = AsyncMock(return_value={})

                    # Create sandbox
                    sandbox = AsyncToolSandboxModalV2(
                        tool_name="integration_test",
                        args={"operation": "add", "a": 5, "b": 3},
                        user=mock_user,
                        tool_object=tool,
                        sandbox_config=sandbox_config,
                    )

                    # Mock version manager methods through deployment manager
                    version_manager = sandbox._deployment_manager.version_manager
                    if version_manager:
                        with patch.object(version_manager, "get_deployment", return_value=None):
                            with patch.object(version_manager, "register_deployment", return_value=None):
                                # First execution - should deploy
                                result1 = await sandbox.run()
                                assert result1.status == "success"
                                assert result1.func_return == 8
                    else:
                        # If no version manager, just run
                        result1 = await sandbox.run()
                        assert result1.status == "success"
                        assert result1.func_return == 8

    @pytest.mark.asyncio
    async def test_concurrent_deployment_handling(self, mock_user, sandbox_config):
        """Test that concurrent deployment requests are handled correctly."""
        tool = Tool(
            id=f"tool-{uuid.uuid4().hex[:8]}",
            name="concurrent_test",
            source_code="def test(x: int) -> int:\n '''Test function.\n \n Args:\n x: The input value\n \n Returns:\n The same value\n '''\n return x",
            json_schema={"parameters": {"properties": {}}},
            metadata_={},
        )

        with patch("letta.services.tool_sandbox.modal_version_manager.ToolManager") as MockToolManager:
            mock_tool_manager = MockToolManager.return_value
            mock_tool_manager.get_tool_by_id.return_value = tool

            # Track update calls
            update_calls = []

            async def track_update(*args, **kwargs):
                update_calls.append((args, kwargs))
                await asyncio.sleep(0.01)  # Simulate slight delay
                return tool

            mock_tool_manager.update_tool_by_id_async = AsyncMock(side_effect=track_update)

            version_manager = ModalVersionManager()
            app_name = f"{mock_user.organization_id}-{tool.name}-v2"
            version_hash = "concurrent123"

            # Launch multiple concurrent deployments
            tasks = []
            for i in range(5):
                task = version_manager.register_deployment(
                    tool_id=tool.id,
                    app_name=app_name,
                    version_hash=version_hash,
                    app=MagicMock(),
                    sandbox_config_id=sandbox_config.id,
                    actor=mock_user,
                )
                tasks.append(task)

            # Wait for all to complete
            await asyncio.gather(*tasks)

            # All calls should complete (current implementation doesn't dedupe)
            assert len(update_calls) == 5


# ============================================================================
# DEPLOYMENT STATISTICS TESTS
# ============================================================================


@pytest.mark.skipif(not os.getenv("MODAL_TOKEN_ID") or not os.getenv("MODAL_TOKEN_SECRET"), reason="Modal credentials not configured")
class TestModalV2DeploymentStats:
    """Tests for deployment statistics tracking."""

    @pytest.mark.asyncio
    async def test_deployment_stats(self, basic_tool, async_tool, test_user):
        """Test deployment statistics tracking."""
        version_manager = get_version_manager()

        # Clear any existing deployments (for test isolation)
        version_manager.clear_deployments()

        # Ensure clean state
        await asyncio.sleep(0.1)

        # Deploy multiple tools
        tools = [basic_tool, async_tool]
        for tool in tools:
            sandbox = AsyncToolSandboxModalV2(
                tool_name=tool.name,
                args={},
                user=test_user,
                tool_object=tool,
            )
            await sandbox.run()

        # Get stats
        stats = await version_manager.get_deployment_stats()

        assert stats["total_deployments"] >= 2
        assert stats["active_deployments"] >= 2
        assert stats["stale_deployments"] == 0

        # Check individual deployment info
        for deployment in stats["deployments"]:
            assert "app_name" in deployment
            assert "version" in deployment
            assert "usage_count" in deployment
            assert deployment["usage_count"] >= 1


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
|