package main

import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/ed25519"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/exec"
	"runtime"
	"time"

	"github.com/Fimeg/RedFlag/aggregator-agent/internal/acknowledgment"
	"github.com/Fimeg/RedFlag/aggregator-agent/internal/client"
	"github.com/Fimeg/RedFlag/aggregator-agent/internal/config"
	"github.com/Fimeg/RedFlag/aggregator-agent/internal/models"
	"github.com/Fimeg/RedFlag/aggregator-agent/internal/orchestrator"
)

// scanResultLabel maps a scan exit code onto the result string stored in
// history entries: "success" for exit code 0, "failure" for anything else.
func scanResultLabel(exitCode int) string {
	if exitCode == 0 {
		return "success"
	}
	return "failure"
}

// newScanLogReport assembles the history entry for a finished scan.
//
// The action is always "scan_" + subsystem. The metadata always carries
// "subsystem_label" and "subsystem"; the entries of extra (for example
// "updates_found") are merged on top of those two keys.
func newScanLogReport(commandID, subsystem, label, stdout, stderr string, exitCode int, duration time.Duration, extra map[string]string) client.LogReport {
	metadata := map[string]string{
		"subsystem_label": label,
		"subsystem":       subsystem,
	}
	for key, value := range extra {
		metadata[key] = value
	}

	return client.LogReport{
		CommandID:       commandID,
		Action:          "scan_" + subsystem,
		Result:          scanResultLabel(exitCode),
		Stdout:          stdout,
		Stderr:          stderr,
		ExitCode:        exitCode,
		DurationSeconds: int(duration.Seconds()), // truncated to whole seconds
		Metadata:        metadata,
	}
}

// submitScanHistory sends a scan history entry through reportLogWithAck and
// logs the outcome. A delivery failure is only logged, never returned, so a
// scan whose data was already reported is not turned into a failed command.
func submitScanHistory(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, subsystem string, logReport client.LogReport) {
	if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil {
		log.Printf("[ERROR] [agent] [%s] report_log_failed: %v", subsystem, err)
		log.Printf("[HISTORY] [agent] [%s] report_log_failed error=\"%v\" timestamp=%s", subsystem, err, time.Now().Format(time.RFC3339))
		return
	}

	log.Printf("[INFO] [agent] [%s] history_log_created command_id=%s timestamp=%s", subsystem, logReport.CommandID, time.Now().Format(time.RFC3339))
	log.Printf("[HISTORY] [agent] [%s] log_created agent_id=%s command_id=%s result=%s timestamp=%s", logReport.Action, cfg.AgentID, logReport.CommandID, logReport.Result, time.Now().Format(time.RFC3339))
}

// packageScanSpec holds the strings that differ between the package-manager
// scan handlers (APT, DNF, Windows Update, Winget); everything else about
// those handlers is shared by runPackageScan.
type packageScanSpec struct {
	subsystem   string // scanner key passed to ScanSingle and tag used in log lines, e.g. "apt"
	startMsg    string // line logged before the scan starts
	scanNoun    string // name used in the "failed to scan ..." error
	summaryNoun string // name used in the "... scan completed in" stdout line
	updateNoun  string // name used in the "... updates" messages
	label       string // human-readable "subsystem_label" metadata value
}

// runPackageScan runs one package-manager scanner, reports any updates it
// found to the server, and records a history entry.
//
// It returns an error when the scan itself fails or when the updates cannot
// be reported; in both cases no history entry is created. Updates are only
// reported when the scan status is "success" and at least one update exists.
//
// NOTE(review): assumes orch.ScanSingle takes a plain string scanner name —
// confirm against the orchestrator package.
func runPackageScan(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string, spec packageScanSpec) error {
	log.Println(spec.startMsg)

	ctx := context.Background()
	startTime := time.Now()

	result, err := orch.ScanSingle(ctx, spec.subsystem)
	if err != nil {
		return fmt.Errorf("failed to scan %s: %w", spec.scanNoun, err)
	}

	stdout, stderr, exitCode := orchestrator.FormatScanSummary([]orchestrator.ScanResult{result})
	duration := time.Since(startTime)
	stdout += fmt.Sprintf("\n%s scan completed in %.2f seconds\n", spec.summaryNoun, duration.Seconds())

	// updates stays nil unless something was actually reported, so the
	// "updates_found" metadata below reads 0 in that case.
	var updates []client.UpdateReportItem
	if result.Status == "success" && len(result.Updates) > 0 {
		updates = result.Updates
		report := client.UpdateReport{
			CommandID: commandID,
			Timestamp: time.Now(),
			Updates:   updates,
		}
		if err := apiClient.ReportUpdates(cfg.AgentID, report); err != nil {
			return fmt.Errorf("failed to report %s updates: %w", spec.updateNoun, err)
		}
		log.Printf("[INFO] [agent] [%s] Successfully reported %d %s updates to server\n", spec.subsystem, len(updates), spec.updateNoun)
	}

	logReport := newScanLogReport(commandID, spec.subsystem, spec.label, stdout, stderr, exitCode, duration, map[string]string{
		"updates_found": fmt.Sprintf("%d", len(updates)),
	})
	submitScanHistory(apiClient, cfg, ackTracker, spec.subsystem, logReport)
	return nil
}

// handleScanStorage scans disk usage metrics only.
//
// It runs the "storage" scanner through the orchestrator for the textual
// summary, then asks a StorageScanner for structured metrics and sends them
// to the dedicated storage-metrics endpoint. A history entry is recorded
// last. Scan and metric-report failures are returned; a failure to record
// the history entry is only logged.
func handleScanStorage(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	log.Println("Scanning storage...")

	ctx := context.Background()
	startTime := time.Now()

	result, err := orch.ScanSingle(ctx, "storage")
	if err != nil {
		return fmt.Errorf("failed to scan storage: %w", err)
	}

	stdout, stderr, exitCode := orchestrator.FormatScanSummary([]orchestrator.ScanResult{result})
	duration := time.Since(startTime)
	stdout += fmt.Sprintf("\nStorage scan completed in %.2f seconds\n", duration.Seconds())

	// metrics is declared outside the availability check because its length
	// is recorded in the history entry even when nothing was scanned.
	storageScanner := orchestrator.NewStorageScanner(cfg.AgentVersion)
	var metrics []orchestrator.StorageMetric
	if storageScanner.IsAvailable() {
		metrics, err = storageScanner.ScanStorage()
		if err != nil {
			return fmt.Errorf("failed to scan storage metrics: %w", err)
		}

		if len(metrics) > 0 {
			// Convert orchestrator.StorageMetric into the models.StorageMetric wire type.
			metricItems := make([]models.StorageMetric, 0, len(metrics))
			for _, m := range metrics {
				metricItems = append(metricItems, models.StorageMetric{
					Mountpoint:     m.Mountpoint,
					Device:         m.Device,
					DiskType:       m.DiskType,
					Filesystem:     m.Filesystem,
					TotalBytes:     m.TotalBytes,
					UsedBytes:      m.UsedBytes,
					AvailableBytes: m.AvailableBytes,
					UsedPercent:    m.UsedPercent,
					IsRoot:         m.IsRoot,
					IsLargest:      m.IsLargest,
					Severity:       m.Severity,
					Metadata:       m.Metadata,
				})
			}

			report := models.StorageMetricReport{
				AgentID:   cfg.AgentID,
				CommandID: commandID,
				Timestamp: time.Now(),
				Metrics:   metricItems,
			}
			if err := apiClient.ReportStorageMetrics(cfg.AgentID, report); err != nil {
				return fmt.Errorf("failed to report storage metrics: %w", err)
			}
			log.Printf("[INFO] [storage] Successfully reported %d storage metrics to server\n", len(metrics))
		}
	}

	logReport := newScanLogReport(commandID, "storage", "Disk Usage", stdout, stderr, exitCode, duration, map[string]string{
		"metrics_count": fmt.Sprintf("%d", len(metrics)),
	})
	submitScanHistory(apiClient, cfg, ackTracker, "storage", logReport)
	return nil
}

// handleScanSystem scans system metrics (CPU, memory, processes, uptime).
//
// It runs the "system" scanner through the orchestrator for the textual
// summary, then asks a SystemScanner for structured metrics and sends them
// to the metrics endpoint. A history entry is recorded last. Scan and
// metric-report failures are returned; a failure to record the history
// entry is only logged.
func handleScanSystem(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	log.Println("Scanning system metrics...")

	ctx := context.Background()
	startTime := time.Now()

	result, err := orch.ScanSingle(ctx, "system")
	if err != nil {
		return fmt.Errorf("failed to scan system: %w", err)
	}

	stdout, stderr, exitCode := orchestrator.FormatScanSummary([]orchestrator.ScanResult{result})
	duration := time.Since(startTime)
	stdout += fmt.Sprintf("\nSystem scan completed in %.2f seconds\n", duration.Seconds())

	// metrics is declared outside the availability check because its length
	// is recorded in the history entry even when nothing was scanned.
	systemScanner := orchestrator.NewSystemScanner(cfg.AgentVersion)
	var metrics []orchestrator.SystemMetric
	if systemScanner.IsAvailable() {
		metrics, err = systemScanner.ScanSystem()
		if err != nil {
			return fmt.Errorf("failed to scan system metrics: %w", err)
		}

		if len(metrics) > 0 {
			// The metrics endpoint reuses the package-report field names, so
			// each metric is mapped onto a MetricsReportItem.
			metricItems := make([]client.MetricsReportItem, 0, len(metrics))
			for _, metric := range metrics {
				metricItems = append(metricItems, client.MetricsReportItem{
					PackageType:      "system",
					PackageName:      metric.MetricName,
					CurrentVersion:   metric.CurrentValue,
					AvailableVersion: metric.AvailableValue,
					Severity:         metric.Severity,
					RepositorySource: metric.MetricType,
					Metadata:         metric.Metadata,
				})
			}

			report := client.MetricsReport{
				CommandID: commandID,
				Timestamp: time.Now(),
				Metrics:   metricItems,
			}
			if err := apiClient.ReportMetrics(cfg.AgentID, report); err != nil {
				return fmt.Errorf("failed to report system metrics: %w", err)
			}
			log.Printf("✓ Reported %d system metrics to server\n", len(metrics))
		}
	}

	logReport := newScanLogReport(commandID, "system", "System Metrics", stdout, stderr, exitCode, duration, map[string]string{
		"metrics_count": fmt.Sprintf("%d", len(metrics)),
	})
	submitScanHistory(apiClient, cfg, ackTracker, "system", logReport)
	return nil
}

// handleScanDocker scans Docker image updates only.
//
// It runs the "docker" scanner through the orchestrator for the textual
// summary, then uses a DockerScanner to list every image (not just those
// with updates) and reports them to the Docker endpoint. A history entry is
// recorded last. Scanner creation, scan and report failures are returned; a
// failure to record the history entry is only logged.
func handleScanDocker(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	log.Println("Scanning Docker images...")

	ctx := context.Background()
	startTime := time.Now()

	result, err := orch.ScanSingle(ctx, "docker")
	if err != nil {
		return fmt.Errorf("failed to scan Docker: %w", err)
	}

	stdout, stderr, exitCode := orchestrator.FormatScanSummary([]orchestrator.ScanResult{result})
	duration := time.Since(startTime)
	stdout += fmt.Sprintf("\nDocker scan completed in %.2f seconds\n", duration.Seconds())

	dockerScanner, err := orchestrator.NewDockerScanner()
	if err != nil {
		return fmt.Errorf("failed to create Docker scanner: %w", err)
	}
	defer dockerScanner.Close()

	// Both values are recorded in the history entry even when Docker is
	// unavailable, in which case they stay at their zero values.
	var (
		images      []orchestrator.DockerImage
		updateCount int
	)

	if !dockerScanner.IsAvailable() {
		log.Println("Docker not available on this system")
	} else {
		images, err = dockerScanner.ScanDocker()
		if err != nil {
			return fmt.Errorf("failed to scan Docker images: %w", err)
		}

		if len(images) == 0 {
			log.Println("No Docker images found")
		} else {
			imageItems := make([]client.DockerReportItem, 0, len(images))
			for _, image := range images {
				imageItems = append(imageItems, client.DockerReportItem{
					PackageType:      "docker_image",
					PackageName:      image.ImageName,
					CurrentVersion:   image.ImageID,
					AvailableVersion: image.LatestImageID,
					Severity:         image.Severity,
					RepositorySource: image.RepositorySource,
					Metadata:         image.Metadata,
				})
			}

			report := client.DockerReport{
				CommandID: commandID,
				Timestamp: time.Now(),
				Images:    imageItems,
			}
			if err := apiClient.ReportDockerImages(cfg.AgentID, report); err != nil {
				return fmt.Errorf("failed to report Docker images: %w", err)
			}

			for _, image := range images {
				if image.HasUpdate {
					updateCount++
				}
			}
			log.Printf("✓ Reported %d Docker images (%d with updates) to server\n", len(images), updateCount)
		}
	}

	logReport := newScanLogReport(commandID, "docker", "Docker Images", stdout, stderr, exitCode, duration, map[string]string{
		"images_count":  fmt.Sprintf("%d", len(images)),
		"updates_found": fmt.Sprintf("%d", updateCount),
	})
	submitScanHistory(apiClient, cfg, ackTracker, "docker", logReport)
	return nil
}

// handleScanAPT scans APT package updates only.
// See runPackageScan for the shared scan, report and history flow.
func handleScanAPT(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	return runPackageScan(apiClient, cfg, ackTracker, orch, commandID, packageScanSpec{
		subsystem:   "apt",
		startMsg:    "Scanning APT packages...",
		scanNoun:    "APT",
		summaryNoun: "APT",
		updateNoun:  "APT",
		label:       "APT Packages",
	})
}

// handleScanDNF scans DNF package updates only.
// See runPackageScan for the shared scan, report and history flow.
func handleScanDNF(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	return runPackageScan(apiClient, cfg, ackTracker, orch, commandID, packageScanSpec{
		subsystem:   "dnf",
		startMsg:    "Scanning DNF packages...",
		scanNoun:    "DNF",
		summaryNoun: "DNF",
		updateNoun:  "DNF",
		label:       "DNF Packages",
	})
}

// handleScanWindows scans Windows Updates only.
// See runPackageScan for the shared scan, report and history flow.
func handleScanWindows(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	return runPackageScan(apiClient, cfg, ackTracker, orch, commandID, packageScanSpec{
		subsystem:   "windows",
		startMsg:    "Scanning Windows Updates...",
		scanNoun:    "Windows Updates",
		summaryNoun: "Windows Update",
		updateNoun:  "Windows",
		label:       "Windows Updates",
	})
}

// handleScanWinget scans Winget package updates only.
// See runPackageScan for the shared scan, report and history flow.
func handleScanWinget(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	return runPackageScan(apiClient, cfg, ackTracker, orch, commandID, packageScanSpec{
		subsystem:   "winget",
		startMsg:    "Scanning Winget packages...",
		scanNoun:    "Winget",
		summaryNoun: "Winget",
		updateNoun:  "Winget",
		label:       "Winget Packages",
	})
}
handleUpdateAgent handles agent update commands with signature verification func handleUpdateAgent(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, params map[string]interface{}, commandID string) error { log.Println("Processing agent update command...") // Extract parameters version, ok := params["version"].(string) if !ok { return fmt.Errorf("missing version parameter") } platform, ok := params["platform"].(string) if !ok { return fmt.Errorf("missing platform parameter") } downloadURL, ok := params["download_url"].(string) if !ok { return fmt.Errorf("missing download_url parameter") } signature, ok := params["signature"].(string) if !ok { return fmt.Errorf("missing signature parameter") } checksum, ok := params["checksum"].(string) if !ok { return fmt.Errorf("missing checksum parameter") } // Extract nonce parameters for replay protection nonceUUIDStr, ok := params["nonce_uuid"].(string) if !ok { return fmt.Errorf("missing nonce_uuid parameter") } nonceTimestampStr, ok := params["nonce_timestamp"].(string) if !ok { return fmt.Errorf("missing nonce_timestamp parameter") } nonceSignature, ok := params["nonce_signature"].(string) if !ok { return fmt.Errorf("missing nonce_signature parameter") } log.Printf("Updating agent to version %s (%s)", version, platform) // Validate nonce for replay protection log.Printf("[tunturi_ed25519] Validating nonce...") log.Printf("[SECURITY] Nonce validation - UUID: %s, Timestamp: %s", nonceUUIDStr, nonceTimestampStr) if err := validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature); err != nil { log.Printf("[SECURITY] ✗ Nonce validation FAILED: %v", err) return fmt.Errorf("[tunturi_ed25519] nonce validation failed: %w", err) } log.Printf("[SECURITY] ✓ Nonce validated successfully") // Record start time for duration calculation updateStartTime := time.Now() // Report the update command as started logReport := client.LogReport{ CommandID: commandID, Action: "update_agent", Result: "started", 
Stdout: fmt.Sprintf("Starting agent update to version %s\n", version), Stderr: "", ExitCode: 0, DurationSeconds: 0, Metadata: map[string]string{ "subsystem_label": "Agent Update", "subsystem": "agent", "target_version": version, }, } if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil { log.Printf("Failed to report update start log: %v\n", err) } log.Printf("Starting secure update process for version %s", version) log.Printf("Download URL: %s", downloadURL) log.Printf("Signature: %s...", signature[:16]) // Log first 16 chars of signature log.Printf("Expected checksum: %s", checksum) // Step 1: Download the update package log.Printf("Step 1: Downloading update package...") tempBinaryPath, err := downloadUpdatePackage(downloadURL) if err != nil { return fmt.Errorf("failed to download update package: %w", err) } defer os.Remove(tempBinaryPath) // Cleanup on exit // Step 2: Verify checksum log.Printf("Step 2: Verifying checksum...") actualChecksum, err := computeSHA256(tempBinaryPath) if err != nil { return fmt.Errorf("failed to compute checksum: %w", err) } if actualChecksum != checksum { return fmt.Errorf("checksum mismatch: expected %s, got %s", checksum, actualChecksum) } log.Printf("✓ Checksum verified: %s", actualChecksum) // Step 3: Verify Ed25519 signature log.Printf("[tunturi_ed25519] Step 3: Verifying Ed25519 signature...") if err := verifyBinarySignature(tempBinaryPath, signature); err != nil { return fmt.Errorf("[tunturi_ed25519] signature verification failed: %w", err) } log.Printf("[tunturi_ed25519] ✓ Signature verified") // Step 4: Create backup of current binary log.Printf("Step 4: Creating backup...") currentBinaryPath, err := getCurrentBinaryPath() if err != nil { return fmt.Errorf("failed to determine current binary path: %w", err) } backupPath := currentBinaryPath + ".bak" var updateSuccess bool = false // Track overall success if err := createBackup(currentBinaryPath, backupPath); err != nil { log.Printf("Warning: Failed 
to create backup: %v", err) } else { // Defer rollback/cleanup logic defer func() { if !updateSuccess { // Rollback on failure log.Printf("[tunturi_ed25519] Rollback: restoring from backup...") if restoreErr := restoreFromBackup(backupPath, currentBinaryPath); restoreErr != nil { log.Printf("[tunturi_ed25519] CRITICAL: Failed to restore backup: %v", restoreErr) } else { log.Printf("[tunturi_ed25519] ✓ Successfully rolled back to backup") } } else { // Clean up backup on success log.Printf("[tunturi_ed25519] ✓ Update successful, cleaning up backup") os.Remove(backupPath) } }() } // Step 5: Atomic installation log.Printf("Step 5: Installing new binary...") if err := installNewBinary(tempBinaryPath, currentBinaryPath); err != nil { return fmt.Errorf("failed to install new binary: %w", err) } // Step 6: Restart agent service log.Printf("Step 6: Restarting agent service...") if err := restartAgentService(); err != nil { return fmt.Errorf("failed to restart agent: %w", err) } // Step 7: Watchdog timer for confirmation log.Printf("Step 7: Starting watchdog for update confirmation...") updateSuccess = waitForUpdateConfirmation(apiClient, cfg, ackTracker, version, 5*time.Minute) success := updateSuccess // Alias for logging below finalLogReport := client.LogReport{ CommandID: commandID, Action: "update_agent", Result: map[bool]string{true: "success", false: "failure"}[success], Stdout: fmt.Sprintf("Agent update to version %s %s\n", version, map[bool]string{true: "completed successfully", false: "failed"}[success]), Stderr: map[bool]string{true: "", false: "Update verification timeout or restart failure"}[success], ExitCode: map[bool]int{true: 0, false: 1}[success], DurationSeconds: int(time.Since(updateStartTime).Seconds()), Metadata: map[string]string{ "subsystem_label": "Agent Update", "subsystem": "agent", "target_version": version, "success": map[bool]string{true: "true", false: "false"}[success], }, } if err := reportLogWithAck(apiClient, cfg, ackTracker, 
finalLogReport); err != nil { log.Printf("Failed to report update completion log: %v\n", err) } if success { log.Printf("✓ Agent successfully updated to version %s", version) } else { return fmt.Errorf("agent update verification failed") } return nil } // Helper functions for the update process func downloadUpdatePackage(downloadURL string) (string, error) { // Download to temporary file tempFile, err := os.CreateTemp("", "redflag-update-*.bin") if err != nil { return "", fmt.Errorf("failed to create temp file: %w", err) } defer tempFile.Close() resp, err := http.Get(downloadURL) if err != nil { return "", fmt.Errorf("failed to download: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("download failed with status: %d", resp.StatusCode) } if _, err := tempFile.ReadFrom(resp.Body); err != nil { return "", fmt.Errorf("failed to write download: %w", err) } return tempFile.Name(), nil } func computeSHA256(filePath string) (string, error) { file, err := os.Open(filePath) if err != nil { return "", fmt.Errorf("failed to open file: %w", err) } defer file.Close() hash := sha256.New() if _, err := io.Copy(hash, file); err != nil { return "", fmt.Errorf("failed to compute hash: %w", err) } return hex.EncodeToString(hash.Sum(nil)), nil } func getCurrentBinaryPath() (string, error) { execPath, err := os.Executable() if err != nil { return "", fmt.Errorf("failed to get executable path: %w", err) } return execPath, nil } func createBackup(src, dst string) error { srcFile, err := os.Open(src) if err != nil { return fmt.Errorf("failed to open source: %w", err) } defer srcFile.Close() dstFile, err := os.Create(dst) if err != nil { return fmt.Errorf("failed to create backup: %w", err) } defer dstFile.Close() if _, err := dstFile.ReadFrom(srcFile); err != nil { return fmt.Errorf("failed to copy backup: %w", err) } // Ensure backup is executable if err := os.Chmod(dst, 0755); err != nil { return fmt.Errorf("failed to set backup 
permissions: %w", err) } return nil } func restoreFromBackup(backup, target string) error { // Remove current binary if it exists if _, err := os.Stat(target); err == nil { if err := os.Remove(target); err != nil { return fmt.Errorf("failed to remove current binary: %w", err) } } // Copy backup to target return createBackup(backup, target) } func installNewBinary(src, dst string) error { // Copy new binary to a temporary location first tempDst := dst + ".new" srcFile, err := os.Open(src) if err != nil { return fmt.Errorf("failed to open source binary: %w", err) } defer srcFile.Close() dstFile, err := os.Create(tempDst) if err != nil { return fmt.Errorf("failed to create temp binary: %w", err) } defer dstFile.Close() if _, err := dstFile.ReadFrom(srcFile); err != nil { return fmt.Errorf("failed to copy binary: %w", err) } dstFile.Close() // Set executable permissions if err := os.Chmod(tempDst, 0755); err != nil { return fmt.Errorf("failed to set binary permissions: %w", err) } // Atomic rename if err := os.Rename(tempDst, dst); err != nil { os.Remove(tempDst) // Cleanup temp file return fmt.Errorf("failed to atomically replace binary: %w", err) } return nil } func restartAgentService() error { var cmd *exec.Cmd switch runtime.GOOS { case "linux": // Try systemd first cmd = exec.Command("systemctl", "restart", "redflag-agent") if err := cmd.Run(); err == nil { log.Printf("✓ Systemd service restarted") return nil } // Fallback to service command cmd = exec.Command("service", "redflag-agent", "restart") case "windows": cmd = exec.Command("sc", "stop", "RedFlagAgent") cmd.Run() cmd = exec.Command("sc", "start", "RedFlagAgent") default: return fmt.Errorf("unsupported OS for service restart: %s", runtime.GOOS) } if err := cmd.Run(); err != nil { return fmt.Errorf("failed to restart service: %w", err) } log.Printf("✓ Agent service restarted") return nil } func waitForUpdateConfirmation(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, 
expectedVersion string, timeout time.Duration) bool { deadline := time.Now().Add(timeout) pollInterval := 15 * time.Second log.Printf("[tunturi_ed25519] Watchdog: waiting for version %s confirmation (timeout: %v)...", expectedVersion, timeout) for time.Now().Before(deadline) { // Poll server for current agent version agent, err := apiClient.GetAgent(cfg.AgentID.String()) if err != nil { log.Printf("[tunturi_ed25519] Watchdog: failed to poll server: %v (retrying...)", err) time.Sleep(pollInterval) continue } // Check if the version matches the expected version if agent != nil && agent.CurrentVersion == expectedVersion { log.Printf("[tunturi_ed25519] Watchdog: ✓ Version confirmed: %s", expectedVersion) return true } log.Printf("[tunturi_ed25519] Watchdog: Current version: %s, Expected: %s (polling...)", agent.CurrentVersion, expectedVersion) time.Sleep(pollInterval) } log.Printf("[tunturi_ed25519] Watchdog: ✗ Timeout after %v - version not confirmed", timeout) log.Printf("[tunturi_ed25519] Rollback initiated") return false } // AES-256-GCM decryption helper functions for encrypted update packages // deriveKeyFromNonce derives an AES-256 key from a nonce using SHA-256 func deriveKeyFromNonce(nonce string) []byte { hash := sha256.Sum256([]byte(nonce)) return hash[:] // 32 bytes for AES-256 } // decryptAES256GCM decrypts data using AES-256-GCM with the provided nonce-derived key func decryptAES256GCM(encryptedData, nonce string) ([]byte, error) { // Derive key from nonce key := deriveKeyFromNonce(nonce) // Decode hex data data, err := hex.DecodeString(encryptedData) if err != nil { return nil, fmt.Errorf("failed to decode hex data: %w", err) } // Create AES cipher block, err := aes.NewCipher(key) if err != nil { return nil, fmt.Errorf("failed to create AES cipher: %w", err) } // Create GCM gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("failed to create GCM: %w", err) } // Check minimum length nonceSize := gcm.NonceSize() if len(data) < 
nonceSize { return nil, fmt.Errorf("encrypted data too short") } // Extract nonce and ciphertext nonceBytes, ciphertext := data[:nonceSize], data[nonceSize:] // Decrypt plaintext, err := gcm.Open(nil, nonceBytes, ciphertext, nil) if err != nil { return nil, fmt.Errorf("failed to decrypt: %w", err) } return plaintext, nil } // TODO: Integration with system/machine_id.go for key derivation // This stub should be integrated with the existing machine ID system // for more sophisticated key management based on hardware fingerprinting // // Example integration approach: // - Use machine_id.go to generate stable hardware fingerprint // - Combine hardware fingerprint with nonce for key derivation // - Store derived keys securely in memory only // - Implement key rotation support for long-running agents // verifyBinarySignature verifies the Ed25519 signature of a binary file func verifyBinarySignature(binaryPath, signatureHex string) error { // Get the server public key from cache publicKey, err := getServerPublicKey() if err != nil { return fmt.Errorf("failed to get server public key: %w", err) } // Read the binary content content, err := os.ReadFile(binaryPath) if err != nil { return fmt.Errorf("failed to read binary: %w", err) } // Decode signature from hex signatureBytes, err := hex.DecodeString(signatureHex) if err != nil { return fmt.Errorf("failed to decode signature: %w", err) } // Verify signature length if len(signatureBytes) != ed25519.SignatureSize { return fmt.Errorf("invalid signature length: expected %d bytes, got %d", ed25519.SignatureSize, len(signatureBytes)) } // Ed25519 verification valid := ed25519.Verify(ed25519.PublicKey(publicKey), content, signatureBytes) if !valid { return fmt.Errorf("signature verification failed: invalid signature") } return nil } // getServerPublicKey retrieves the Ed25519 public key from cache // The key is fetched from the server at startup and cached locally func getServerPublicKey() ([]byte, error) { // Load from cache 
(fetched during agent startup) publicKey, err := loadCachedPublicKeyDirect() if err != nil { return nil, fmt.Errorf("failed to load server public key: %w (hint: key is fetched at agent startup)", err) } return publicKey, nil } // loadCachedPublicKeyDirect loads the cached public key from the standard location func loadCachedPublicKeyDirect() ([]byte, error) { var keyPath string if runtime.GOOS == "windows" { keyPath = "C:\\ProgramData\\RedFlag\\server_public_key" } else { keyPath = "/etc/redflag/server_public_key" } data, err := os.ReadFile(keyPath) if err != nil { return nil, fmt.Errorf("public key not found: %w", err) } if len(data) != 32 { // ed25519.PublicKeySize return nil, fmt.Errorf("invalid public key size: expected 32 bytes, got %d", len(data)) } return data, nil } // validateNonce validates the nonce for replay protection func validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature string) error { // Parse timestamp nonceTimestamp, err := time.Parse(time.RFC3339, nonceTimestampStr) if err != nil { return fmt.Errorf("invalid nonce timestamp format: %w", err) } // Check freshness (< 5 minutes) age := time.Since(nonceTimestamp) if age > 5*time.Minute { return fmt.Errorf("nonce expired: age %v > 5 minutes", age) } if age < 0 { return fmt.Errorf("nonce timestamp in the future: %v", nonceTimestamp) } // Get server public key from cache publicKey, err := getServerPublicKey() if err != nil { return fmt.Errorf("failed to get server public key: %w", err) } // Recreate nonce data (must match server format) nonceData := fmt.Sprintf("%s:%d", nonceUUIDStr, nonceTimestamp.Unix()) // Decode signature signatureBytes, err := hex.DecodeString(nonceSignature) if err != nil { return fmt.Errorf("invalid nonce signature format: %w", err) } if len(signatureBytes) != ed25519.SignatureSize { return fmt.Errorf("invalid nonce signature length: expected %d bytes, got %d", ed25519.SignatureSize, len(signatureBytes)) } // Verify Ed25519 signature valid := 
ed25519.Verify(ed25519.PublicKey(publicKey), []byte(nonceData), signatureBytes) if !valid { return fmt.Errorf("invalid nonce signature") } return nil }