import json import pickle from unittest.mock import AsyncMock, MagicMock, patch import pytest from letta.schemas.pip_requirement import PipRequirement from letta.schemas.sandbox_config import ModalSandboxConfig, SandboxConfig, SandboxType from letta.schemas.tool import Tool from letta.services.tool_sandbox.modal_sandbox_v2 import AsyncToolSandboxModalV2 from letta.services.tool_sandbox.modal_version_manager import ModalVersionManager from sandbox.modal_executor import ModalFunctionExecutor class TestModalFunctionExecutor: """Test the ModalFunctionExecutor class.""" def test_execute_tool_dynamic_success(self): """Test successful execution of a simple tool.""" tool_source = """ def add_numbers(a: int, b: int) -> int: return a + b """ args = {"a": 5, "b": 3} args_pickled = pickle.dumps(args) result = ModalFunctionExecutor.execute_tool_dynamic( tool_source=tool_source, tool_name="add_numbers", args_pickled=args_pickled, agent_state_pickled=None, inject_agent_state=False, is_async=False, args_schema_code=None, ) assert result["error"] is None assert result["result"] == 8 # Actual integer value assert result["agent_state"] is None def test_execute_tool_dynamic_with_error(self): """Test execution with an error.""" tool_source = """ def divide_numbers(a: int, b: int) -> float: return a / b """ args = {"a": 5, "b": 0} args_pickled = pickle.dumps(args) result = ModalFunctionExecutor.execute_tool_dynamic( tool_source=tool_source, tool_name="divide_numbers", args_pickled=args_pickled, agent_state_pickled=None, inject_agent_state=False, is_async=False, args_schema_code=None, ) assert result["error"] is not None assert result["error"]["name"] == "ZeroDivisionError" assert "division by zero" in result["error"]["value"] assert result["result"] is None def test_execute_async_tool(self): """Test execution of an async tool.""" tool_source = """ async def async_add(a: int, b: int) -> int: import asyncio await asyncio.sleep(0.001) return a + b """ args = {"a": 10, "b": 20} args_pickled = pickle.dumps(args) result = ModalFunctionExecutor.execute_tool_dynamic( tool_source=tool_source, tool_name="async_add", args_pickled=args_pickled, agent_state_pickled=None, inject_agent_state=False, is_async=True, args_schema_code=None, ) assert result["error"] is None assert result["result"] == 30 def test_execute_with_stdout_capture(self): """Test that stdout is properly captured.""" tool_source = """ def print_and_return(message: str) -> str: print(f"Processing: {message}") print("Done!") return message.upper() """ args = {"message": "hello"} args_pickled = pickle.dumps(args) result = ModalFunctionExecutor.execute_tool_dynamic( tool_source=tool_source, tool_name="print_and_return", args_pickled=args_pickled, agent_state_pickled=None, inject_agent_state=False, is_async=False, args_schema_code=None, ) assert result["error"] is None assert result["result"] == "HELLO" assert "Processing: hello" in result["stdout"] assert "Done!" in result["stdout"] class TestModalVersionManager: """Test the Modal Version Manager.""" @pytest.mark.asyncio async def test_register_and_get_deployment(self): """Test registering and retrieving deployments.""" from unittest.mock import AsyncMock from letta.schemas.user import User manager = ModalVersionManager() # Mock the tool manager mock_tool = MagicMock() mock_tool.id = "tool-abc12345" mock_tool.metadata_ = {} manager.tool_manager.get_tool_by_id = MagicMock(return_value=mock_tool) manager.tool_manager.update_tool_by_id_async = AsyncMock(return_value=mock_tool) # Create a mock actor mock_actor = MagicMock(spec=User) mock_actor.id = "user-123" # Register a deployment mock_app = MagicMock(spec=["deploy", "stop"]) info = await manager.register_deployment( tool_id="tool-abc12345", app_name="test-app", version_hash="abc123", app=mock_app, dependencies={"pandas", "numpy"}, sandbox_config_id="config-123", actor=mock_actor, ) assert info.app_name == "test-app" assert info.version_hash == "abc123" assert info.dependencies == {"pandas", "numpy"} # Retrieve the deployment retrieved = await manager.get_deployment("tool-abc12345", "config-123", actor=mock_actor) assert retrieved.app_name == info.app_name assert retrieved.version_hash == info.version_hash @pytest.mark.asyncio async def test_needs_redeployment(self): """Test checking if redeployment is needed.""" from unittest.mock import AsyncMock from letta.schemas.user import User manager = ModalVersionManager() # Mock the tool manager mock_tool = MagicMock() mock_tool.id = "tool-def45678" mock_tool.metadata_ = {} manager.tool_manager.get_tool_by_id = MagicMock(return_value=mock_tool) manager.tool_manager.update_tool_by_id_async = AsyncMock(return_value=mock_tool) # Create a mock actor mock_actor = MagicMock(spec=User) # No deployment exists yet assert await manager.needs_redeployment("tool-def45678", "v1", "config-123", actor=mock_actor) is True # Register a deployment mock_app = MagicMock() await manager.register_deployment( tool_id="tool-def45678", app_name="test-app", version_hash="v1", app=mock_app, sandbox_config_id="config-123", actor=mock_actor, ) # Update mock to return the registered deployment mock_tool.metadata_ = { "modal_deployments": { "config-123": { "app_name": "test-app", "version_hash": "v1", "deployed_at": "2024-01-01T00:00:00", "dependencies": [], } } } # Same version - no redeployment needed assert await manager.needs_redeployment("tool-def45678", "v1", "config-123", actor=mock_actor) is False # Different version - redeployment needed assert await manager.needs_redeployment("tool-def45678", "v2", "config-123", actor=mock_actor) is True @pytest.mark.skip(reason="get_deployment_stats method not implemented in ModalVersionManager") @pytest.mark.asyncio async def test_deployment_stats(self): """Test getting deployment statistics.""" from unittest.mock import AsyncMock from letta.schemas.user import User manager = ModalVersionManager() # Mock the tool manager mock_tools = {} for i in range(3): tool_id = f"tool-{i:08x}" mock_tool = MagicMock() mock_tool.id = tool_id mock_tool.metadata_ = {} mock_tools[tool_id] = mock_tool def get_tool_by_id(tool_id, actor=None): return mock_tools.get(tool_id) manager.tool_manager.get_tool_by_id = MagicMock(side_effect=get_tool_by_id) manager.tool_manager.update_tool_by_id_async = AsyncMock() # Create a mock actor mock_actor = MagicMock(spec=User) # Register multiple deployments for i in range(3): tool_id = f"tool-{i:08x}" mock_app = MagicMock() await manager.register_deployment( tool_id=tool_id, app_name=f"app-{i}", version_hash=f"v{i}", app=mock_app, sandbox_config_id="config-123", actor=mock_actor, ) stats = await manager.get_deployment_stats() # Note: The actual implementation may store deployments differently # This test assumes the stats method exists and returns expected format assert stats["total_deployments"] >= 0 # Adjust based on actual implementation assert "deployments" in stats @pytest.mark.skip(reason="export_state and import_state methods not implemented in ModalVersionManager") @pytest.mark.asyncio async def test_export_import_state(self): """Test exporting and importing deployment state.""" from unittest.mock import AsyncMock from letta.schemas.user import User manager1 = ModalVersionManager() # Mock the tool manager for manager1 mock_tools = { "tool-11111111": MagicMock(id="tool-11111111", metadata_={}), "tool-22222222": MagicMock(id="tool-22222222", metadata_={}), } def get_tool_by_id(tool_id, actor=None): return mock_tools.get(tool_id) manager1.tool_manager.get_tool_by_id = MagicMock(side_effect=get_tool_by_id) manager1.tool_manager.update_tool_by_id_async = AsyncMock() # Create a mock actor mock_actor = MagicMock(spec=User) # Register deployments mock_app = MagicMock() await manager1.register_deployment( tool_id="tool-11111111", app_name="app1", version_hash="v1", app=mock_app, dependencies={"dep1"}, sandbox_config_id="config-123", actor=mock_actor, ) await manager1.register_deployment( tool_id="tool-22222222", app_name="app2", version_hash="v2", app=mock_app, dependencies={"dep2", "dep3"}, sandbox_config_id="config-123", actor=mock_actor, ) # Export state state_json = await manager1.export_state() state = json.loads(state_json) # Verify exported state structure assert "tool-11111111" in state or "deployments" in state # Depends on implementation # Import into new manager manager2 = ModalVersionManager() manager2.tool_manager.get_tool_by_id = MagicMock(side_effect=get_tool_by_id) await manager2.import_state(state_json) # Note: The actual implementation may not have export/import methods # This test assumes they exist or should be modified based on actual API class TestAsyncToolSandboxModalV2: """Test the AsyncToolSandboxModalV2 class.""" @pytest.fixture def mock_tool(self): """Create a mock tool for testing.""" return Tool( id="tool-12345678", # Valid tool ID format name="test_function", source_code=""" def test_function(x: int, y: int) -> int: '''Add two numbers together.''' return x + y """, json_schema={ "parameters": { "properties": { "x": {"type": "integer"}, "y": {"type": "integer"}, } } }, pip_requirements=[PipRequirement(name="requests")], ) @pytest.fixture def mock_user(self): """Create a mock user for testing.""" user = MagicMock() user.organization_id = "test-org" return user @pytest.fixture def mock_sandbox_config(self): """Create a mock sandbox configuration.""" modal_config = ModalSandboxConfig( timeout=60, pip_requirements=["pandas"], ) config = SandboxConfig( id="sandbox-12345678", # Valid sandbox ID format type=SandboxType.MODAL, # Changed from sandbox_type to type config=modal_config.model_dump(), ) return config def test_version_hash_calculation(self, mock_tool, mock_user, mock_sandbox_config): """Test that version hash is calculated correctly.""" sandbox = AsyncToolSandboxModalV2( tool_name="test_function", args={"x": 1, "y": 2}, user=mock_user, tool_id=mock_tool.id, tool_object=mock_tool, sandbox_config=mock_sandbox_config, ) # Access through deployment manager version1 = sandbox._deployment_manager.calculate_version_hash(mock_sandbox_config) assert version1 # Should not be empty assert len(version1) == 12 # We take first 12 chars of hash # Same inputs should produce same hash version2 = sandbox._deployment_manager.calculate_version_hash(mock_sandbox_config) assert version1 == version2 # Changing tool code should change hash mock_tool.source_code = "def test_function(x, y): return x * y" sandbox2 = AsyncToolSandboxModalV2( tool_name="test_function", args={"x": 1, "y": 2}, user=mock_user, tool_id=mock_tool.id, tool_object=mock_tool, sandbox_config=mock_sandbox_config, ) version3 = sandbox2._deployment_manager.calculate_version_hash(mock_sandbox_config) assert version3 != version1 # Changing dependencies should also change hash mock_tool.source_code = "def test_function(x, y): return x + y" # Reset mock_tool.pip_requirements = [PipRequirement(name="numpy")] sandbox3 = AsyncToolSandboxModalV2( tool_name="test_function", args={"x": 1, "y": 2}, user=mock_user, tool_id=mock_tool.id, tool_object=mock_tool, sandbox_config=mock_sandbox_config, ) version4 = sandbox3._deployment_manager.calculate_version_hash(mock_sandbox_config) assert version4 != version1 # Changing sandbox config should change hash modal_config2 = ModalSandboxConfig( timeout=120, # Different timeout pip_requirements=["pandas"], ) config2 = SandboxConfig( id="sandbox-87654321", type=SandboxType.MODAL, config=modal_config2.model_dump(), ) version5 = sandbox3._deployment_manager.calculate_version_hash(config2) assert version5 != version4 def test_app_name_generation(self, mock_tool, mock_user): """Test app name generation.""" sandbox = AsyncToolSandboxModalV2( tool_name="test_function", args={"x": 1, "y": 2}, user=mock_user, tool_id=mock_tool.id, tool_object=mock_tool, ) # App name generation is now in deployment manager and uses tool ID app_name = sandbox._deployment_manager._generate_app_name() # App name is based on tool ID truncated to 40 chars assert app_name == mock_tool.id[:40] @pytest.mark.asyncio async def test_run_with_mocked_modal(self, mock_tool, mock_user, mock_sandbox_config): """Test the run method with mocked Modal components.""" with ( patch("letta.services.tool_sandbox.modal_sandbox_v2.modal") as mock_modal, patch("letta.services.tool_sandbox.modal_deployment_manager.modal") as mock_modal2, ): # Mock Modal app mock_app = MagicMock() # Use MagicMock for the app itself mock_app.run = MagicMock() # Mock the function decorator def mock_function_decorator(*args, **kwargs): def decorator(func): # Create a mock that has a remote attribute mock_func = MagicMock() mock_func.remote = mock_remote # Store the mocked function as tool_executor on the app mock_app.tool_executor = mock_func return mock_func return decorator mock_app.function = mock_function_decorator # Mock deployment mock_app.deploy = MagicMock() mock_app.deploy.aio = AsyncMock() # Mock the remote execution mock_remote = MagicMock() mock_remote.aio = AsyncMock( return_value={ "result": 3, # Return actual integer, not string "agent_state": None, "stdout": "Executing...", "stderr": "", "error": None, } ) mock_modal.App.return_value = mock_app mock_modal2.App.return_value = mock_app # Mock App.lookup.aio to handle app lookup attempts mock_modal.App.lookup = MagicMock() mock_modal.App.lookup.aio = AsyncMock(side_effect=Exception("App not found")) mock_modal2.App.lookup = MagicMock() mock_modal2.App.lookup.aio = AsyncMock(side_effect=Exception("App not found")) # Mock enable_output context manager mock_modal.enable_output = MagicMock() mock_modal.enable_output.return_value.__enter__ = MagicMock() mock_modal.enable_output.return_value.__exit__ = MagicMock() mock_modal2.enable_output = MagicMock() mock_modal2.enable_output.return_value.__enter__ = MagicMock() mock_modal2.enable_output.return_value.__exit__ = MagicMock() # Mock the SandboxConfigManager to avoid type checking issues with patch("letta.services.tool_sandbox.base.SandboxConfigManager") as MockSCM: mock_scm = MockSCM.return_value mock_scm.get_sandbox_env_vars_as_dict_async = AsyncMock(return_value={}) # Create sandbox sandbox = AsyncToolSandboxModalV2( tool_name="test_function", args={"x": 1, "y": 2}, user=mock_user, tool_id=mock_tool.id, tool_object=mock_tool, sandbox_config=mock_sandbox_config, ) # Mock the version manager through deployment manager version_manager = sandbox._deployment_manager.version_manager if version_manager: with patch.object(version_manager, "get_deployment", return_value=None): with patch.object(version_manager, "register_deployment", return_value=None): # Run the tool result = await sandbox.run() else: # If no version manager (use_version_tracking=False), just run result = await sandbox.run() assert result.func_return == 3 # Check for actual integer assert result.status == "success" assert "Executing..." in result.stdout[0] def test_detect_async_function(self, mock_user): """Test detection of async functions.""" # Test with sync function sync_tool = Tool( id="tool-abcdef12", # Valid tool ID format name="sync_func", source_code="def sync_func(x): return x", json_schema={"parameters": {"properties": {}}}, ) sandbox_sync = AsyncToolSandboxModalV2( tool_name="sync_func", args={}, user=mock_user, tool_id=sync_tool.id, tool_object=sync_tool, ) assert sandbox_sync._detect_async_function() is False # Test with async function async_tool = Tool( id="tool-fedcba21", # Valid tool ID format name="async_func", source_code="async def async_func(x): return x", json_schema={"parameters": {"properties": {}}}, ) sandbox_async = AsyncToolSandboxModalV2( tool_name="async_func", args={}, user=mock_user, tool_id=async_tool.id, tool_object=async_tool, ) assert sandbox_async._detect_async_function() is True if __name__ == "__main__": pytest.main([__file__, "-v"])