- Update scanner default from 15min to 12 hours (backend) - Add 1 week and 2 week frequency options (frontend) - Rename AgentScanners to AgentHealth component - Add OS-aware package manager badges (APT, DNF, Windows/Winget, Docker) - Fix all build errors (types, imports, storage metrics) - Add useMemo optimization for enabled/auto-run counts
901 lines
27 KiB
Go
901 lines
27 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/ed25519"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"runtime"
|
|
"time"
|
|
|
|
"github.com/Fimeg/RedFlag/aggregator-agent/internal/acknowledgment"
|
|
"github.com/Fimeg/RedFlag/aggregator-agent/internal/client"
|
|
"github.com/Fimeg/RedFlag/aggregator-agent/internal/config"
|
|
"github.com/Fimeg/RedFlag/aggregator-agent/internal/models"
|
|
"github.com/Fimeg/RedFlag/aggregator-agent/internal/orchestrator"
|
|
)
|
|
|
|
// handleScanUpdatesV2 scans all update subsystems (APT, DNF, Docker, Windows Update, Winget) in parallel
|
|
// This is the new orchestrator-based version for v0.1.20
|
|
func handleScanUpdatesV2(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
|
|
log.Println("Scanning for updates (parallel execution)...")
|
|
|
|
ctx := context.Background()
|
|
startTime := time.Now()
|
|
|
|
// Execute all update scanners in parallel
|
|
results, allUpdates := orch.ScanAll(ctx)
|
|
|
|
// Format results
|
|
stdout, stderr, exitCode := orchestrator.FormatScanSummary(results)
|
|
|
|
// Add timing information
|
|
duration := time.Since(startTime)
|
|
stdout += fmt.Sprintf("\nScan completed in %.2f seconds\n", duration.Seconds())
|
|
|
|
// Create scan log entry with subsystem metadata
|
|
logReport := client.LogReport{
|
|
CommandID: commandID,
|
|
Action: "scan_updates",
|
|
Result: map[bool]string{true: "success", false: "failure"}[exitCode == 0],
|
|
Stdout: stdout,
|
|
Stderr: stderr,
|
|
ExitCode: exitCode,
|
|
DurationSeconds: int(duration.Seconds()),
|
|
Metadata: map[string]string{
|
|
"subsystem_label": "Package Updates",
|
|
"subsystem": "updates",
|
|
},
|
|
}
|
|
|
|
// Report the scan log
|
|
if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil {
|
|
log.Printf("Failed to report scan log: %v\n", err)
|
|
// Continue anyway - updates are more important
|
|
}
|
|
|
|
// Report updates to server if any were found
|
|
if len(allUpdates) > 0 {
|
|
report := client.UpdateReport{
|
|
CommandID: commandID,
|
|
Timestamp: time.Now(),
|
|
Updates: allUpdates,
|
|
}
|
|
|
|
if err := apiClient.ReportUpdates(cfg.AgentID, report); err != nil {
|
|
return fmt.Errorf("failed to report updates: %w", err)
|
|
}
|
|
|
|
log.Printf("✓ Reported %d updates to server\n", len(allUpdates))
|
|
} else {
|
|
log.Println("No updates found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// handleScanStorage scans disk usage metrics only
// It runs the orchestrator's "storage" scanner to produce the human-readable
// scan log, then builds a dedicated StorageScanner to produce the structured
// metric report sent to the server's storage endpoint.
func handleScanStorage(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	log.Println("Scanning storage...")

	ctx := context.Background()
	startTime := time.Now()

	// Execute storage scanner
	result, err := orch.ScanSingle(ctx, "storage")
	if err != nil {
		return fmt.Errorf("failed to scan storage: %w", err)
	}

	// Format results (FormatScanSummary expects a slice, so wrap the single result)
	results := []orchestrator.ScanResult{result}
	stdout, stderr, exitCode := orchestrator.FormatScanSummary(results)

	duration := time.Since(startTime)
	stdout += fmt.Sprintf("\nStorage scan completed in %.2f seconds\n", duration.Seconds())

	// Create scan log entry
	logReport := client.LogReport{
		CommandID: commandID,
		Action:    "scan_storage",
		// Branchless success/failure selection keyed on the exit code.
		Result:          map[bool]string{true: "success", false: "failure"}[exitCode == 0],
		Stdout:          stdout,
		Stderr:          stderr,
		ExitCode:        exitCode,
		DurationSeconds: int(duration.Seconds()),
		Metadata: map[string]string{
			"subsystem_label": "Disk Usage",
			"subsystem":       "storage",
		},
	}

	// Report the scan log; failure here is non-fatal — the metric report
	// below is the important payload.
	if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil {
		log.Printf("Failed to report scan log: %v\n", err)
	}

	// Report storage metrics to server using dedicated endpoint
	// Use proper StorageMetricReport with clean field names
	// NOTE(review): this performs a second disk scan (ScanSingle above already
	// ran one) — presumably acceptable because the scan is cheap; confirm.
	storageScanner := orchestrator.NewStorageScanner(cfg.AgentVersion)
	if storageScanner.IsAvailable() {
		metrics, err := storageScanner.ScanStorage()
		if err != nil {
			return fmt.Errorf("failed to scan storage metrics: %w", err)
		}

		if len(metrics) > 0 {
			// Convert from orchestrator.StorageMetric to models.StorageMetric
			// (field-for-field copy into the API model type).
			metricItems := make([]models.StorageMetric, 0, len(metrics))
			for _, m := range metrics {
				item := models.StorageMetric{
					Mountpoint:     m.Mountpoint,
					Device:         m.Device,
					DiskType:       m.DiskType,
					Filesystem:     m.Filesystem,
					TotalBytes:     m.TotalBytes,
					UsedBytes:      m.UsedBytes,
					AvailableBytes: m.AvailableBytes,
					UsedPercent:    m.UsedPercent,
					IsRoot:         m.IsRoot,
					IsLargest:      m.IsLargest,
					Severity:       m.Severity,
					Metadata:       m.Metadata,
				}
				metricItems = append(metricItems, item)
			}

			report := models.StorageMetricReport{
				AgentID:   cfg.AgentID,
				CommandID: commandID,
				Timestamp: time.Now(),
				Metrics:   metricItems,
			}

			if err := apiClient.ReportStorageMetrics(cfg.AgentID, report); err != nil {
				return fmt.Errorf("failed to report storage metrics: %w", err)
			}

			log.Printf("[INFO] [storage] Successfully reported %d storage metrics to server\n", len(metrics))
		}
	}

	return nil
}
|
|
|
|
// handleScanSystem scans system metrics (CPU, memory, processes, uptime)
|
|
func handleScanSystem(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
|
|
log.Println("Scanning system metrics...")
|
|
|
|
ctx := context.Background()
|
|
startTime := time.Now()
|
|
|
|
// Execute system scanner
|
|
result, err := orch.ScanSingle(ctx, "system")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to scan system: %w", err)
|
|
}
|
|
|
|
// Format results
|
|
results := []orchestrator.ScanResult{result}
|
|
stdout, stderr, exitCode := orchestrator.FormatScanSummary(results)
|
|
|
|
duration := time.Since(startTime)
|
|
stdout += fmt.Sprintf("\nSystem scan completed in %.2f seconds\n", duration.Seconds())
|
|
|
|
// Create scan log entry
|
|
logReport := client.LogReport{
|
|
CommandID: commandID,
|
|
Action: "scan_system",
|
|
Result: map[bool]string{true: "success", false: "failure"}[exitCode == 0],
|
|
Stdout: stdout,
|
|
Stderr: stderr,
|
|
ExitCode: exitCode,
|
|
DurationSeconds: int(duration.Seconds()),
|
|
Metadata: map[string]string{
|
|
"subsystem_label": "System Metrics",
|
|
"subsystem": "system",
|
|
},
|
|
}
|
|
|
|
// Report the scan log
|
|
if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil {
|
|
log.Printf("Failed to report scan log: %v\n", err)
|
|
}
|
|
|
|
// Report system metrics to server using dedicated endpoint
|
|
// Get system scanner and use proper interface
|
|
systemScanner := orchestrator.NewSystemScanner("unknown") // TODO: Get actual agent version
|
|
if systemScanner.IsAvailable() {
|
|
metrics, err := systemScanner.ScanSystem()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to scan system metrics: %w", err)
|
|
}
|
|
|
|
if len(metrics) > 0 {
|
|
// Convert SystemMetric to MetricsReportItem for API call
|
|
metricItems := make([]client.MetricsReportItem, 0, len(metrics))
|
|
for _, metric := range metrics {
|
|
item := client.MetricsReportItem{
|
|
PackageType: "system",
|
|
PackageName: metric.MetricName,
|
|
CurrentVersion: metric.CurrentValue,
|
|
AvailableVersion: metric.AvailableValue,
|
|
Severity: metric.Severity,
|
|
RepositorySource: metric.MetricType,
|
|
Metadata: metric.Metadata,
|
|
}
|
|
metricItems = append(metricItems, item)
|
|
}
|
|
|
|
report := client.MetricsReport{
|
|
CommandID: commandID,
|
|
Timestamp: time.Now(),
|
|
Metrics: metricItems,
|
|
}
|
|
|
|
if err := apiClient.ReportMetrics(cfg.AgentID, report); err != nil {
|
|
return fmt.Errorf("failed to report system metrics: %w", err)
|
|
}
|
|
|
|
log.Printf("✓ Reported %d system metrics to server\n", len(metrics))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// handleScanDocker scans Docker image updates only
// It runs the orchestrator's "docker" scanner to produce the human-readable
// scan log, then builds a dedicated DockerScanner and reports every local
// image (not just those with updates) to the server's Docker endpoint.
func handleScanDocker(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	log.Println("Scanning Docker images...")

	ctx := context.Background()
	startTime := time.Now()

	// Execute Docker scanner
	result, err := orch.ScanSingle(ctx, "docker")
	if err != nil {
		return fmt.Errorf("failed to scan Docker: %w", err)
	}

	// Format results (FormatScanSummary expects a slice, so wrap the single result)
	results := []orchestrator.ScanResult{result}
	stdout, stderr, exitCode := orchestrator.FormatScanSummary(results)

	duration := time.Since(startTime)
	stdout += fmt.Sprintf("\nDocker scan completed in %.2f seconds\n", duration.Seconds())

	// Create scan log entry
	logReport := client.LogReport{
		CommandID: commandID,
		Action:    "scan_docker",
		// Branchless success/failure selection keyed on the exit code.
		Result:          map[bool]string{true: "success", false: "failure"}[exitCode == 0],
		Stdout:          stdout,
		Stderr:          stderr,
		ExitCode:        exitCode,
		DurationSeconds: int(duration.Seconds()),
		Metadata: map[string]string{
			"subsystem_label": "Docker Images",
			"subsystem":       "docker",
		},
	}

	// Report the scan log; failure here is non-fatal — the image report
	// below is the important payload.
	if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil {
		log.Printf("Failed to report scan log: %v\n", err)
	}

	// Report Docker images to server using dedicated endpoint
	// Get Docker scanner and use proper interface
	// NOTE(review): a constructor failure here aborts after the scan log was
	// already sent; presumably the caller marks the command failed — confirm.
	dockerScanner, err := orchestrator.NewDockerScanner()
	if err != nil {
		return fmt.Errorf("failed to create Docker scanner: %w", err)
	}
	// Release the Docker client connection when this handler returns.
	defer dockerScanner.Close()

	if dockerScanner.IsAvailable() {
		images, err := dockerScanner.ScanDocker()
		if err != nil {
			return fmt.Errorf("failed to scan Docker images: %w", err)
		}

		// Always report all Docker images (not just those with updates)
		if len(images) > 0 {
			// Convert DockerImage to DockerReportItem for API call
			imageItems := make([]client.DockerReportItem, 0, len(images))
			for _, image := range images {
				item := client.DockerReportItem{
					PackageType:      "docker_image",
					PackageName:      image.ImageName,
					CurrentVersion:   image.ImageID,
					AvailableVersion: image.LatestImageID,
					Severity:         image.Severity,
					RepositorySource: image.RepositorySource,
					Metadata:         image.Metadata,
				}
				imageItems = append(imageItems, item)
			}

			report := client.DockerReport{
				CommandID: commandID,
				Timestamp: time.Now(),
				Images:    imageItems,
			}

			if err := apiClient.ReportDockerImages(cfg.AgentID, report); err != nil {
				return fmt.Errorf("failed to report Docker images: %w", err)
			}

			// Count how many of the reported images have a pending update,
			// purely for the log line below.
			updateCount := 0
			for _, image := range images {
				if image.HasUpdate {
					updateCount++
				}
			}
			log.Printf("✓ Reported %d Docker images (%d with updates) to server\n", len(images), updateCount)
		} else {
			log.Println("No Docker images found")
		}
	} else {
		log.Println("Docker not available on this system")
	}

	return nil
}
|
|
|
|
// handleUpdateAgent handles agent update commands with signature verification
|
|
func handleUpdateAgent(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, params map[string]interface{}, commandID string) error {
|
|
log.Println("Processing agent update command...")
|
|
|
|
// Extract parameters
|
|
version, ok := params["version"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing version parameter")
|
|
}
|
|
|
|
platform, ok := params["platform"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing platform parameter")
|
|
}
|
|
|
|
downloadURL, ok := params["download_url"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing download_url parameter")
|
|
}
|
|
|
|
signature, ok := params["signature"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing signature parameter")
|
|
}
|
|
|
|
checksum, ok := params["checksum"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing checksum parameter")
|
|
}
|
|
|
|
// Extract nonce parameters for replay protection
|
|
nonceUUIDStr, ok := params["nonce_uuid"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing nonce_uuid parameter")
|
|
}
|
|
|
|
nonceTimestampStr, ok := params["nonce_timestamp"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing nonce_timestamp parameter")
|
|
}
|
|
|
|
nonceSignature, ok := params["nonce_signature"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing nonce_signature parameter")
|
|
}
|
|
|
|
log.Printf("Updating agent to version %s (%s)", version, platform)
|
|
|
|
// Validate nonce for replay protection
|
|
log.Printf("[tunturi_ed25519] Validating nonce...")
|
|
log.Printf("[SECURITY] Nonce validation - UUID: %s, Timestamp: %s", nonceUUIDStr, nonceTimestampStr)
|
|
if err := validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature); err != nil {
|
|
log.Printf("[SECURITY] ✗ Nonce validation FAILED: %v", err)
|
|
return fmt.Errorf("[tunturi_ed25519] nonce validation failed: %w", err)
|
|
}
|
|
log.Printf("[SECURITY] ✓ Nonce validated successfully")
|
|
|
|
// Record start time for duration calculation
|
|
updateStartTime := time.Now()
|
|
|
|
// Report the update command as started
|
|
logReport := client.LogReport{
|
|
CommandID: commandID,
|
|
Action: "update_agent",
|
|
Result: "started",
|
|
Stdout: fmt.Sprintf("Starting agent update to version %s\n", version),
|
|
Stderr: "",
|
|
ExitCode: 0,
|
|
DurationSeconds: 0,
|
|
Metadata: map[string]string{
|
|
"subsystem_label": "Agent Update",
|
|
"subsystem": "agent",
|
|
"target_version": version,
|
|
},
|
|
}
|
|
|
|
if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil {
|
|
log.Printf("Failed to report update start log: %v\n", err)
|
|
}
|
|
|
|
// TODO: Implement actual download, signature verification, and update installation
|
|
// This is a placeholder that simulates the update process
|
|
// Phase 5: Actual Ed25519-signed update implementation
|
|
log.Printf("Starting secure update process for version %s", version)
|
|
log.Printf("Download URL: %s", downloadURL)
|
|
log.Printf("Signature: %s...", signature[:16]) // Log first 16 chars of signature
|
|
log.Printf("Expected checksum: %s", checksum)
|
|
|
|
// Step 1: Download the update package
|
|
log.Printf("Step 1: Downloading update package...")
|
|
tempBinaryPath, err := downloadUpdatePackage(downloadURL)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to download update package: %w", err)
|
|
}
|
|
defer os.Remove(tempBinaryPath) // Cleanup on exit
|
|
|
|
// Step 2: Verify checksum
|
|
log.Printf("Step 2: Verifying checksum...")
|
|
actualChecksum, err := computeSHA256(tempBinaryPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to compute checksum: %w", err)
|
|
}
|
|
|
|
if actualChecksum != checksum {
|
|
return fmt.Errorf("checksum mismatch: expected %s, got %s", checksum, actualChecksum)
|
|
}
|
|
log.Printf("✓ Checksum verified: %s", actualChecksum)
|
|
|
|
// Step 3: Verify Ed25519 signature
|
|
log.Printf("[tunturi_ed25519] Step 3: Verifying Ed25519 signature...")
|
|
if err := verifyBinarySignature(tempBinaryPath, signature); err != nil {
|
|
return fmt.Errorf("[tunturi_ed25519] signature verification failed: %w", err)
|
|
}
|
|
log.Printf("[tunturi_ed25519] ✓ Signature verified")
|
|
|
|
// Step 4: Create backup of current binary
|
|
log.Printf("Step 4: Creating backup...")
|
|
currentBinaryPath, err := getCurrentBinaryPath()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to determine current binary path: %w", err)
|
|
}
|
|
|
|
backupPath := currentBinaryPath + ".bak"
|
|
var updateSuccess bool = false // Track overall success
|
|
|
|
if err := createBackup(currentBinaryPath, backupPath); err != nil {
|
|
log.Printf("Warning: Failed to create backup: %v", err)
|
|
} else {
|
|
// Defer rollback/cleanup logic
|
|
defer func() {
|
|
if !updateSuccess {
|
|
// Rollback on failure
|
|
log.Printf("[tunturi_ed25519] Rollback: restoring from backup...")
|
|
if restoreErr := restoreFromBackup(backupPath, currentBinaryPath); restoreErr != nil {
|
|
log.Printf("[tunturi_ed25519] CRITICAL: Failed to restore backup: %v", restoreErr)
|
|
} else {
|
|
log.Printf("[tunturi_ed25519] ✓ Successfully rolled back to backup")
|
|
}
|
|
} else {
|
|
// Clean up backup on success
|
|
log.Printf("[tunturi_ed25519] ✓ Update successful, cleaning up backup")
|
|
os.Remove(backupPath)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Step 5: Atomic installation
|
|
log.Printf("Step 5: Installing new binary...")
|
|
if err := installNewBinary(tempBinaryPath, currentBinaryPath); err != nil {
|
|
return fmt.Errorf("failed to install new binary: %w", err)
|
|
}
|
|
|
|
// Step 6: Restart agent service
|
|
log.Printf("Step 6: Restarting agent service...")
|
|
if err := restartAgentService(); err != nil {
|
|
return fmt.Errorf("failed to restart agent: %w", err)
|
|
}
|
|
|
|
// Step 7: Watchdog timer for confirmation
|
|
log.Printf("Step 7: Starting watchdog for update confirmation...")
|
|
updateSuccess = waitForUpdateConfirmation(apiClient, cfg, ackTracker, version, 5*time.Minute)
|
|
success := updateSuccess // Alias for logging below
|
|
|
|
finalLogReport := client.LogReport{
|
|
CommandID: commandID,
|
|
Action: "update_agent",
|
|
Result: map[bool]string{true: "success", false: "failure"}[success],
|
|
Stdout: fmt.Sprintf("Agent update to version %s %s\n", version, map[bool]string{true: "completed successfully", false: "failed"}[success]),
|
|
Stderr: map[bool]string{true: "", false: "Update verification timeout or restart failure"}[success],
|
|
ExitCode: map[bool]int{true: 0, false: 1}[success],
|
|
DurationSeconds: int(time.Since(updateStartTime).Seconds()),
|
|
Metadata: map[string]string{
|
|
"subsystem_label": "Agent Update",
|
|
"subsystem": "agent",
|
|
"target_version": version,
|
|
"success": map[bool]string{true: "true", false: "false"}[success],
|
|
},
|
|
}
|
|
|
|
if err := reportLogWithAck(apiClient, cfg, ackTracker, finalLogReport); err != nil {
|
|
log.Printf("Failed to report update completion log: %v\n", err)
|
|
}
|
|
|
|
if success {
|
|
log.Printf("✓ Agent successfully updated to version %s", version)
|
|
} else {
|
|
return fmt.Errorf("agent update verification failed")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Helper functions for the update process
|
|
|
|
// downloadUpdatePackage fetches the update binary from downloadURL into a
// temporary file and returns the file's path. The caller owns the file and
// must remove it; on any error the temp file is cleaned up here.
func downloadUpdatePackage(downloadURL string) (string, error) {
	// Download to a temporary file.
	tempFile, err := os.CreateTemp("", "redflag-update-*.bin")
	if err != nil {
		return "", fmt.Errorf("failed to create temp file: %w", err)
	}
	tempPath := tempFile.Name()
	defer tempFile.Close()

	// FIX: the original leaked the temp file on every error path below.
	fail := func(err error) (string, error) {
		os.Remove(tempPath)
		return "", err
	}

	resp, err := http.Get(downloadURL)
	if err != nil {
		return fail(fmt.Errorf("failed to download: %w", err))
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fail(fmt.Errorf("download failed with status: %d", resp.StatusCode))
	}

	// Stream the body to disk in constant memory.
	if _, err := io.Copy(tempFile, resp.Body); err != nil {
		return fail(fmt.Errorf("failed to write download: %w", err))
	}

	return tempPath, nil
}
|
|
|
|
// computeSHA256 returns the lowercase hex-encoded SHA-256 digest of the file
// at filePath.
func computeSHA256(filePath string) (string, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return "", fmt.Errorf("failed to open file: %w", err)
	}
	defer f.Close()

	hasher := sha256.New()
	// Stream the file through the hasher in constant memory.
	_, err = io.Copy(hasher, f)
	if err != nil {
		return "", fmt.Errorf("failed to compute hash: %w", err)
	}

	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest), nil
}
|
|
|
|
// getCurrentBinaryPath resolves the filesystem path of the currently running
// agent executable.
func getCurrentBinaryPath() (string, error) {
	path, err := os.Executable()
	if err != nil {
		return "", fmt.Errorf("failed to get executable path: %w", err)
	}
	return path, nil
}
|
|
|
|
// createBackup copies the file at src to dst and marks dst executable (0755)
// so a restored backup can run directly.
func createBackup(src, dst string) error {
	in, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("failed to open source: %w", err)
	}
	defer in.Close()

	out, err := os.Create(dst)
	if err != nil {
		return fmt.Errorf("failed to create backup: %w", err)
	}
	defer out.Close()

	if _, err = io.Copy(out, in); err != nil {
		return fmt.Errorf("failed to copy backup: %w", err)
	}

	// The backup must stay runnable for rollback.
	if err = os.Chmod(dst, 0755); err != nil {
		return fmt.Errorf("failed to set backup permissions: %w", err)
	}

	return nil
}
|
|
|
|
func restoreFromBackup(backup, target string) error {
|
|
// Remove current binary if it exists
|
|
if _, err := os.Stat(target); err == nil {
|
|
if err := os.Remove(target); err != nil {
|
|
return fmt.Errorf("failed to remove current binary: %w", err)
|
|
}
|
|
}
|
|
|
|
// Copy backup to target
|
|
return createBackup(backup, target)
|
|
}
|
|
|
|
// installNewBinary atomically replaces the binary at dst with the one at src:
// it copies src to dst+".new", sets executable permissions, then renames over
// dst. On any failure the temporary ".new" file is removed.
func installNewBinary(src, dst string) error {
	// Copy new binary to a temporary location first.
	tempDst := dst + ".new"

	srcFile, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("failed to open source binary: %w", err)
	}
	defer srcFile.Close()

	dstFile, err := os.Create(tempDst)
	if err != nil {
		return fmt.Errorf("failed to create temp binary: %w", err)
	}

	if _, err := io.Copy(dstFile, srcFile); err != nil {
		dstFile.Close()
		os.Remove(tempDst) // FIX: original leaked tempDst on copy failure
		return fmt.Errorf("failed to copy binary: %w", err)
	}

	// FIX: close exactly once (the original had defer + explicit double close)
	// and surface the error — a failed Close can mean lost writes.
	if err := dstFile.Close(); err != nil {
		os.Remove(tempDst)
		return fmt.Errorf("failed to finalize temp binary: %w", err)
	}

	// Set executable permissions.
	if err := os.Chmod(tempDst, 0755); err != nil {
		os.Remove(tempDst) // FIX: original leaked tempDst on chmod failure
		return fmt.Errorf("failed to set binary permissions: %w", err)
	}

	// Atomic rename.
	if err := os.Rename(tempDst, dst); err != nil {
		os.Remove(tempDst) // Cleanup temp file
		return fmt.Errorf("failed to atomically replace binary: %w", err)
	}

	return nil
}
|
|
|
|
// restartAgentService restarts the agent's OS service so a newly installed
// binary takes effect. On Linux it tries systemd first and falls back to the
// SysV `service` command; on Windows it stops and restarts the RedFlagAgent
// service via `sc`. Other platforms are unsupported.
func restartAgentService() error {
	var cmd *exec.Cmd

	switch runtime.GOOS {
	case "linux":
		// Try systemd first
		cmd = exec.Command("systemctl", "restart", "redflag-agent")
		if err := cmd.Run(); err == nil {
			log.Printf("✓ Systemd service restarted")
			return nil
		}
		// Fallback to service command
		cmd = exec.Command("service", "redflag-agent", "restart")

	case "windows":
		cmd = exec.Command("sc", "stop", "RedFlagAgent")
		// Stop error deliberately ignored: the service may not be running yet.
		cmd.Run()
		cmd = exec.Command("sc", "start", "RedFlagAgent")

	default:
		return fmt.Errorf("unsupported OS for service restart: %s", runtime.GOOS)
	}

	// Run the command chosen above: the Linux `service` fallback or the
	// Windows `sc start`.
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("failed to restart service: %w", err)
	}

	log.Printf("✓ Agent service restarted")
	return nil
}
|
|
|
|
func waitForUpdateConfirmation(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, expectedVersion string, timeout time.Duration) bool {
|
|
deadline := time.Now().Add(timeout)
|
|
pollInterval := 15 * time.Second
|
|
|
|
log.Printf("[tunturi_ed25519] Watchdog: waiting for version %s confirmation (timeout: %v)...", expectedVersion, timeout)
|
|
|
|
for time.Now().Before(deadline) {
|
|
// Poll server for current agent version
|
|
agent, err := apiClient.GetAgent(cfg.AgentID.String())
|
|
if err != nil {
|
|
log.Printf("[tunturi_ed25519] Watchdog: failed to poll server: %v (retrying...)", err)
|
|
time.Sleep(pollInterval)
|
|
continue
|
|
}
|
|
|
|
// Check if the version matches the expected version
|
|
if agent != nil && agent.CurrentVersion == expectedVersion {
|
|
log.Printf("[tunturi_ed25519] Watchdog: ✓ Version confirmed: %s", expectedVersion)
|
|
return true
|
|
}
|
|
|
|
log.Printf("[tunturi_ed25519] Watchdog: Current version: %s, Expected: %s (polling...)",
|
|
agent.CurrentVersion, expectedVersion)
|
|
time.Sleep(pollInterval)
|
|
}
|
|
|
|
log.Printf("[tunturi_ed25519] Watchdog: ✗ Timeout after %v - version not confirmed", timeout)
|
|
log.Printf("[tunturi_ed25519] Rollback initiated")
|
|
return false
|
|
}
|
|
|
|
// AES-256-GCM decryption helper functions for encrypted update packages
|
|
|
|
// deriveKeyFromNonce maps an arbitrary nonce string to a 32-byte AES-256 key
// by taking a single SHA-256 digest of its bytes.
func deriveKeyFromNonce(nonce string) []byte {
	digest := sha256.Sum256([]byte(nonce))
	key := digest[:] // 32 bytes — exactly the AES-256 key size
	return key
}
|
|
|
|
// decryptAES256GCM decrypts data using AES-256-GCM with the provided nonce-derived key
|
|
func decryptAES256GCM(encryptedData, nonce string) ([]byte, error) {
|
|
// Derive key from nonce
|
|
key := deriveKeyFromNonce(nonce)
|
|
|
|
// Decode hex data
|
|
data, err := hex.DecodeString(encryptedData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode hex data: %w", err)
|
|
}
|
|
|
|
// Create AES cipher
|
|
block, err := aes.NewCipher(key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
|
|
}
|
|
|
|
// Create GCM
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
// Check minimum length
|
|
nonceSize := gcm.NonceSize()
|
|
if len(data) < nonceSize {
|
|
return nil, fmt.Errorf("encrypted data too short")
|
|
}
|
|
|
|
// Extract nonce and ciphertext
|
|
nonceBytes, ciphertext := data[:nonceSize], data[nonceSize:]
|
|
|
|
// Decrypt
|
|
plaintext, err := gcm.Open(nil, nonceBytes, ciphertext, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt: %w", err)
|
|
}
|
|
|
|
return plaintext, nil
|
|
}
|
|
|
|
// TODO: Integration with system/machine_id.go for key derivation
|
|
// This stub should be integrated with the existing machine ID system
|
|
// for more sophisticated key management based on hardware fingerprinting
|
|
//
|
|
// Example integration approach:
|
|
// - Use machine_id.go to generate stable hardware fingerprint
|
|
// - Combine hardware fingerprint with nonce for key derivation
|
|
// - Store derived keys securely in memory only
|
|
// - Implement key rotation support for long-running agents
|
|
|
|
// verifyBinarySignature verifies the Ed25519 signature of a binary file
|
|
func verifyBinarySignature(binaryPath, signatureHex string) error {
|
|
// Get the server public key from cache
|
|
publicKey, err := getServerPublicKey()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get server public key: %w", err)
|
|
}
|
|
|
|
// Read the binary content
|
|
content, err := os.ReadFile(binaryPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read binary: %w", err)
|
|
}
|
|
|
|
// Decode signature from hex
|
|
signatureBytes, err := hex.DecodeString(signatureHex)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decode signature: %w", err)
|
|
}
|
|
|
|
// Verify signature length
|
|
if len(signatureBytes) != ed25519.SignatureSize {
|
|
return fmt.Errorf("invalid signature length: expected %d bytes, got %d", ed25519.SignatureSize, len(signatureBytes))
|
|
}
|
|
|
|
// Ed25519 verification
|
|
valid := ed25519.Verify(ed25519.PublicKey(publicKey), content, signatureBytes)
|
|
if !valid {
|
|
return fmt.Errorf("signature verification failed: invalid signature")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getServerPublicKey retrieves the Ed25519 public key from cache
|
|
// The key is fetched from the server at startup and cached locally
|
|
func getServerPublicKey() ([]byte, error) {
|
|
// Load from cache (fetched during agent startup)
|
|
publicKey, err := loadCachedPublicKeyDirect()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load server public key: %w (hint: key is fetched at agent startup)", err)
|
|
}
|
|
|
|
return publicKey, nil
|
|
}
|
|
|
|
// loadCachedPublicKeyDirect reads the cached Ed25519 public key from the
// platform's standard location and validates that it is exactly 32 bytes.
func loadCachedPublicKeyDirect() ([]byte, error) {
	// Pick the per-platform cache location.
	keyPath := "/etc/redflag/server_public_key"
	if runtime.GOOS == "windows" {
		keyPath = "C:\\ProgramData\\RedFlag\\server_public_key"
	}

	data, err := os.ReadFile(keyPath)
	if err != nil {
		return nil, fmt.Errorf("public key not found: %w", err)
	}

	// ed25519.PublicKeySize — reject anything that is not a raw 32-byte key.
	if len(data) != 32 {
		return nil, fmt.Errorf("invalid public key size: expected 32 bytes, got %d", len(data))
	}

	return data, nil
}
|
|
|
|
// validateNonce validates the nonce for replay protection
|
|
func validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature string) error {
|
|
// Parse timestamp
|
|
nonceTimestamp, err := time.Parse(time.RFC3339, nonceTimestampStr)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid nonce timestamp format: %w", err)
|
|
}
|
|
|
|
// Check freshness (< 5 minutes)
|
|
age := time.Since(nonceTimestamp)
|
|
if age > 5*time.Minute {
|
|
return fmt.Errorf("nonce expired: age %v > 5 minutes", age)
|
|
}
|
|
|
|
if age < 0 {
|
|
return fmt.Errorf("nonce timestamp in the future: %v", nonceTimestamp)
|
|
}
|
|
|
|
// Get server public key from cache
|
|
publicKey, err := getServerPublicKey()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get server public key: %w", err)
|
|
}
|
|
|
|
// Recreate nonce data (must match server format)
|
|
nonceData := fmt.Sprintf("%s:%d", nonceUUIDStr, nonceTimestamp.Unix())
|
|
|
|
// Decode signature
|
|
signatureBytes, err := hex.DecodeString(nonceSignature)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid nonce signature format: %w", err)
|
|
}
|
|
|
|
if len(signatureBytes) != ed25519.SignatureSize {
|
|
return fmt.Errorf("invalid nonce signature length: expected %d bytes, got %d",
|
|
ed25519.SignatureSize, len(signatureBytes))
|
|
}
|
|
|
|
// Verify Ed25519 signature
|
|
valid := ed25519.Verify(ed25519.PublicKey(publicKey), []byte(nonceData), signatureBytes)
|
|
if !valid {
|
|
return fmt.Errorf("invalid nonce signature")
|
|
}
|
|
|
|
return nil
|
|
}
|