Files
Redflag/aggregator-agent/cmd/agent/subsystem_handlers.go
Fimeg e6ac0b1ec4 feat: implement agent migration system
- Fix config version inflation bug in main.go
- Add dynamic subsystem checking to prevent false change detection
- Implement migration detection and execution system
- Add directory migration from /etc/aggregator to /etc/redflag
- Update all path references across codebase to use new directories
- Add configuration schema versioning and automatic migration
- Implement backup and rollback capabilities
- Add security feature detection and hardening
- Update installation scripts and sudoers for new paths
- Complete Phase 1 migration system
2025-11-04 14:25:53 -05:00

894 lines
27 KiB
Go

package main
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"runtime"
"time"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/acknowledgment"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/client"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/config"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/orchestrator"
)
// handleScanUpdatesV2 asks the orchestrator to run every update scanner
// (APT, DNF, Docker, Windows Update, Winget), reports a scan log entry for the
// run, and then forwards any discovered updates to the server.
// This is the orchestrator-based version introduced for v0.1.20.
//
// A failure to deliver the scan log is only logged; a failure to deliver the
// updates themselves is returned to the caller.
func handleScanUpdatesV2(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	log.Println("Scanning for updates (parallel execution)...")

	scanCtx := context.Background()
	began := time.Now()

	// One call covers all update subsystems.
	scanResults, foundUpdates := orch.ScanAll(scanCtx)

	// Summarise the per-scanner results and append the wall-clock duration.
	summaryOut, summaryErr, code := orchestrator.FormatScanSummary(scanResults)
	elapsed := time.Since(began)
	summaryOut += fmt.Sprintf("\nScan completed in %.2f seconds\n", elapsed.Seconds())

	outcome := "failure"
	if code == 0 {
		outcome = "success"
	}

	scanLog := client.LogReport{
		CommandID:       commandID,
		Action:          "scan_updates",
		Result:          outcome,
		Stdout:          summaryOut,
		Stderr:          summaryErr,
		ExitCode:        code,
		DurationSeconds: int(elapsed.Seconds()),
		Metadata: map[string]string{
			"subsystem_label": "Package Updates",
			"subsystem":       "updates",
		},
	}

	// The log entry is secondary to the update data, so a reporting failure
	// here must not stop the updates from being sent.
	if logErr := reportLogWithAck(apiClient, cfg, ackTracker, scanLog); logErr != nil {
		log.Printf("Failed to report scan log: %v\n", logErr)
	}

	if len(foundUpdates) == 0 {
		log.Println("No updates found")
		return nil
	}

	updateReport := client.UpdateReport{
		CommandID: commandID,
		Timestamp: time.Now(),
		Updates:   foundUpdates,
	}
	if sendErr := apiClient.ReportUpdates(cfg.AgentID, updateReport); sendErr != nil {
		return fmt.Errorf("failed to report updates: %w", sendErr)
	}
	log.Printf("✓ Reported %d updates to server\n", len(foundUpdates))
	return nil
}
// handleScanStorage handles a storage-only scan command.
//
// It first runs the orchestrator's "storage" scanner and reports a scan log
// entry built from that result. It then creates a standalone storage scanner
// and, if that scanner is available and yields metrics, sends one metrics item
// per entry (keyed by mountpoint) to the dedicated metrics endpoint.
//
// A failure to deliver the scan log is only logged; scan and metrics-report
// failures are returned.
func handleScanStorage(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	log.Println("Scanning storage...")

	scanCtx := context.Background()
	began := time.Now()

	scanResult, err := orch.ScanSingle(scanCtx, "storage")
	if err != nil {
		return fmt.Errorf("failed to scan storage: %w", err)
	}

	// Summarise the single result and append the wall-clock duration.
	summaryOut, summaryErr, code := orchestrator.FormatScanSummary([]orchestrator.ScanResult{scanResult})
	elapsed := time.Since(began)
	summaryOut += fmt.Sprintf("\nStorage scan completed in %.2f seconds\n", elapsed.Seconds())

	outcome := "failure"
	if code == 0 {
		outcome = "success"
	}

	scanLog := client.LogReport{
		CommandID:       commandID,
		Action:          "scan_storage",
		Result:          outcome,
		Stdout:          summaryOut,
		Stderr:          summaryErr,
		ExitCode:        code,
		DurationSeconds: int(elapsed.Seconds()),
		Metadata: map[string]string{
			"subsystem_label": "Disk Usage",
			"subsystem":       "storage",
		},
	}
	if logErr := reportLogWithAck(apiClient, cfg, ackTracker, scanLog); logErr != nil {
		log.Printf("Failed to report scan log: %v\n", logErr)
	}

	// Detailed metrics go to the dedicated metrics endpoint via a scanner
	// instance created here.
	scanner := orchestrator.NewStorageScanner("unknown") // TODO: Get actual agent version
	if !scanner.IsAvailable() {
		return nil
	}

	metrics, err := scanner.ScanStorage()
	if err != nil {
		return fmt.Errorf("failed to scan storage metrics: %w", err)
	}
	if len(metrics) == 0 {
		return nil
	}

	// Map each storage metric onto the generic metrics report item shape.
	items := make([]client.MetricsReportItem, 0, len(metrics))
	for _, m := range metrics {
		items = append(items, client.MetricsReportItem{
			PackageType:      "storage",
			PackageName:      m.Mountpoint,
			CurrentVersion:   fmt.Sprintf("%d bytes used", m.UsedBytes),
			AvailableVersion: fmt.Sprintf("%d bytes total", m.TotalBytes),
			Severity:         m.Severity,
			RepositorySource: m.Filesystem,
			Metadata:         m.Metadata,
		})
	}

	metricsReport := client.MetricsReport{
		CommandID: commandID,
		Timestamp: time.Now(),
		Metrics:   items,
	}
	if sendErr := apiClient.ReportMetrics(cfg.AgentID, metricsReport); sendErr != nil {
		return fmt.Errorf("failed to report storage metrics: %w", sendErr)
	}
	log.Printf("✓ Reported %d storage metrics to server\n", len(metrics))
	return nil
}
// handleScanSystem handles a system-metrics scan command (the original
// description lists CPU, memory, processes and uptime).
//
// It first runs the orchestrator's "system" scanner and reports a scan log
// entry built from that result. It then creates a standalone system scanner
// and, if that scanner is available and yields metrics, sends them to the
// dedicated metrics endpoint.
//
// A failure to deliver the scan log is only logged; scan and metrics-report
// failures are returned.
func handleScanSystem(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	log.Println("Scanning system metrics...")

	scanCtx := context.Background()
	began := time.Now()

	scanResult, err := orch.ScanSingle(scanCtx, "system")
	if err != nil {
		return fmt.Errorf("failed to scan system: %w", err)
	}

	// Summarise the single result and append the wall-clock duration.
	summaryOut, summaryErr, code := orchestrator.FormatScanSummary([]orchestrator.ScanResult{scanResult})
	elapsed := time.Since(began)
	summaryOut += fmt.Sprintf("\nSystem scan completed in %.2f seconds\n", elapsed.Seconds())

	outcome := "failure"
	if code == 0 {
		outcome = "success"
	}

	scanLog := client.LogReport{
		CommandID:       commandID,
		Action:          "scan_system",
		Result:          outcome,
		Stdout:          summaryOut,
		Stderr:          summaryErr,
		ExitCode:        code,
		DurationSeconds: int(elapsed.Seconds()),
		Metadata: map[string]string{
			"subsystem_label": "System Metrics",
			"subsystem":       "system",
		},
	}
	if logErr := reportLogWithAck(apiClient, cfg, ackTracker, scanLog); logErr != nil {
		log.Printf("Failed to report scan log: %v\n", logErr)
	}

	// Detailed metrics go to the dedicated metrics endpoint via a scanner
	// instance created here.
	scanner := orchestrator.NewSystemScanner("unknown") // TODO: Get actual agent version
	if !scanner.IsAvailable() {
		return nil
	}

	metrics, err := scanner.ScanSystem()
	if err != nil {
		return fmt.Errorf("failed to scan system metrics: %w", err)
	}
	if len(metrics) == 0 {
		return nil
	}

	// Map each system metric onto the generic metrics report item shape.
	items := make([]client.MetricsReportItem, 0, len(metrics))
	for _, m := range metrics {
		items = append(items, client.MetricsReportItem{
			PackageType:      "system",
			PackageName:      m.MetricName,
			CurrentVersion:   m.CurrentValue,
			AvailableVersion: m.AvailableValue,
			Severity:         m.Severity,
			RepositorySource: m.MetricType,
			Metadata:         m.Metadata,
		})
	}

	metricsReport := client.MetricsReport{
		CommandID: commandID,
		Timestamp: time.Now(),
		Metrics:   items,
	}
	if sendErr := apiClient.ReportMetrics(cfg.AgentID, metricsReport); sendErr != nil {
		return fmt.Errorf("failed to report system metrics: %w", sendErr)
	}
	log.Printf("✓ Reported %d system metrics to server\n", len(metrics))
	return nil
}
// handleScanDocker handles a Docker-only scan command.
//
// It first runs the orchestrator's "docker" scanner and reports a scan log
// entry built from that result. It then creates a standalone Docker scanner
// and, if Docker is available, reports every image it finds (not only those
// with a pending update) to the dedicated Docker endpoint.
//
// A failure to deliver the scan log is only logged; scanner creation, scan and
// image-report failures are returned.
func handleScanDocker(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, orch *orchestrator.Orchestrator, commandID string) error {
	log.Println("Scanning Docker images...")

	scanCtx := context.Background()
	began := time.Now()

	scanResult, err := orch.ScanSingle(scanCtx, "docker")
	if err != nil {
		return fmt.Errorf("failed to scan Docker: %w", err)
	}

	// Summarise the single result and append the wall-clock duration.
	summaryOut, summaryErr, code := orchestrator.FormatScanSummary([]orchestrator.ScanResult{scanResult})
	elapsed := time.Since(began)
	summaryOut += fmt.Sprintf("\nDocker scan completed in %.2f seconds\n", elapsed.Seconds())

	outcome := "failure"
	if code == 0 {
		outcome = "success"
	}

	scanLog := client.LogReport{
		CommandID:       commandID,
		Action:          "scan_docker",
		Result:          outcome,
		Stdout:          summaryOut,
		Stderr:          summaryErr,
		ExitCode:        code,
		DurationSeconds: int(elapsed.Seconds()),
		Metadata: map[string]string{
			"subsystem_label": "Docker Images",
			"subsystem":       "docker",
		},
	}
	if logErr := reportLogWithAck(apiClient, cfg, ackTracker, scanLog); logErr != nil {
		log.Printf("Failed to report scan log: %v\n", logErr)
	}

	// Image details go to the dedicated Docker endpoint via a scanner
	// instance created here.
	scanner, err := orchestrator.NewDockerScanner()
	if err != nil {
		return fmt.Errorf("failed to create Docker scanner: %w", err)
	}
	defer scanner.Close()

	if !scanner.IsAvailable() {
		log.Println("Docker not available on this system")
		return nil
	}

	images, err := scanner.ScanDocker()
	if err != nil {
		return fmt.Errorf("failed to scan Docker images: %w", err)
	}
	if len(images) == 0 {
		log.Println("No Docker images found")
		return nil
	}

	// Convert every image for the API call and count the ones with an update
	// so the final log line can mention both numbers.
	items := make([]client.DockerReportItem, 0, len(images))
	withUpdates := 0
	for _, img := range images {
		items = append(items, client.DockerReportItem{
			PackageType:      "docker_image",
			PackageName:      img.ImageName,
			CurrentVersion:   img.ImageID,
			AvailableVersion: img.LatestImageID,
			Severity:         img.Severity,
			RepositorySource: img.RepositorySource,
			Metadata:         img.Metadata,
		})
		if img.HasUpdate {
			withUpdates++
		}
	}

	dockerReport := client.DockerReport{
		CommandID: commandID,
		Timestamp: time.Now(),
		Images:    items,
	}
	if sendErr := apiClient.ReportDockerImages(cfg.AgentID, dockerReport); sendErr != nil {
		return fmt.Errorf("failed to report Docker images: %w", sendErr)
	}
	log.Printf("✓ Reported %d Docker images (%d with updates) to server\n", len(images), withUpdates)
	return nil
}
// handleUpdateAgent handles an agent self-update command.
//
// Required string parameters in params: "version", "platform", "download_url",
// "signature", "checksum", "nonce_uuid", "nonce_timestamp" and
// "nonce_signature". A missing or non-string parameter yields an error before
// any work is done.
//
// The update proceeds as follows:
//  1. validate the signed nonce (replay protection) and report a "started" log;
//  2. download the package to a temporary file (removed on return);
//  3. compare its SHA-256 checksum with the expected one;
//  4. verify its Ed25519 signature;
//  5. back up the current binary to "<binary>.bak";
//  6. install the new binary and restart the agent service;
//  7. poll the server for up to five minutes until it reports the new version.
//
// If the backup was created, a deferred handler restores it whenever the
// function returns without confirmed success, and deletes it on success. If
// the backup could not be created the update continues without rollback.
//
// A final log report is sent once the watchdog finishes; errors returned by
// earlier steps are not reported through that log.
func handleUpdateAgent(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, params map[string]interface{}, commandID string) error {
	log.Println("Processing agent update command...")

	// Extract parameters
	version, ok := params["version"].(string)
	if !ok {
		return fmt.Errorf("missing version parameter")
	}
	platform, ok := params["platform"].(string)
	if !ok {
		return fmt.Errorf("missing platform parameter")
	}
	downloadURL, ok := params["download_url"].(string)
	if !ok {
		return fmt.Errorf("missing download_url parameter")
	}
	signature, ok := params["signature"].(string)
	if !ok {
		return fmt.Errorf("missing signature parameter")
	}
	checksum, ok := params["checksum"].(string)
	if !ok {
		return fmt.Errorf("missing checksum parameter")
	}

	// Extract nonce parameters for replay protection
	nonceUUIDStr, ok := params["nonce_uuid"].(string)
	if !ok {
		return fmt.Errorf("missing nonce_uuid parameter")
	}
	nonceTimestampStr, ok := params["nonce_timestamp"].(string)
	if !ok {
		return fmt.Errorf("missing nonce_timestamp parameter")
	}
	nonceSignature, ok := params["nonce_signature"].(string)
	if !ok {
		return fmt.Errorf("missing nonce_signature parameter")
	}

	log.Printf("Updating agent to version %s (%s)", version, platform)

	// Validate nonce for replay protection
	log.Printf("[tunturi_ed25519] Validating nonce...")
	log.Printf("[SECURITY] Nonce validation - UUID: %s, Timestamp: %s", nonceUUIDStr, nonceTimestampStr)
	if err := validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature); err != nil {
		log.Printf("[SECURITY] ✗ Nonce validation FAILED: %v", err)
		return fmt.Errorf("[tunturi_ed25519] nonce validation failed: %w", err)
	}
	log.Printf("[SECURITY] ✓ Nonce validated successfully")

	// Record start time for duration calculation
	updateStartTime := time.Now()

	// Report the update command as started
	logReport := client.LogReport{
		CommandID:       commandID,
		Action:          "update_agent",
		Result:          "started",
		Stdout:          fmt.Sprintf("Starting agent update to version %s\n", version),
		Stderr:          "",
		ExitCode:        0,
		DurationSeconds: 0,
		Metadata: map[string]string{
			"subsystem_label": "Agent Update",
			"subsystem":       "agent",
			"target_version":  version,
		},
	}
	if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil {
		log.Printf("Failed to report update start log: %v\n", err)
	}

	// Phase 5: Ed25519-signed update implementation
	log.Printf("Starting secure update process for version %s", version)
	log.Printf("Download URL: %s", downloadURL)

	// Log only a short prefix of the signature. The slice is guarded because
	// the value comes from the command payload and may be shorter than 16
	// characters; an unguarded signature[:16] would panic the agent.
	signaturePreview := signature
	if len(signaturePreview) > 16 {
		signaturePreview = signaturePreview[:16]
	}
	log.Printf("Signature: %s...", signaturePreview)
	log.Printf("Expected checksum: %s", checksum)

	// Step 1: Download the update package
	log.Printf("Step 1: Downloading update package...")
	tempBinaryPath, err := downloadUpdatePackage(downloadURL)
	if err != nil {
		return fmt.Errorf("failed to download update package: %w", err)
	}
	defer os.Remove(tempBinaryPath) // Cleanup on exit

	// Step 2: Verify checksum
	log.Printf("Step 2: Verifying checksum...")
	actualChecksum, err := computeSHA256(tempBinaryPath)
	if err != nil {
		return fmt.Errorf("failed to compute checksum: %w", err)
	}
	if actualChecksum != checksum {
		return fmt.Errorf("checksum mismatch: expected %s, got %s", checksum, actualChecksum)
	}
	log.Printf("✓ Checksum verified: %s", actualChecksum)

	// Step 3: Verify Ed25519 signature
	log.Printf("[tunturi_ed25519] Step 3: Verifying Ed25519 signature...")
	if err := verifyBinarySignature(tempBinaryPath, signature); err != nil {
		return fmt.Errorf("[tunturi_ed25519] signature verification failed: %w", err)
	}
	log.Printf("[tunturi_ed25519] ✓ Signature verified")

	// Step 4: Create backup of current binary
	log.Printf("Step 4: Creating backup...")
	currentBinaryPath, err := getCurrentBinaryPath()
	if err != nil {
		return fmt.Errorf("failed to determine current binary path: %w", err)
	}
	backupPath := currentBinaryPath + ".bak"

	// updateSuccess is read by the deferred rollback handler below, so it must
	// only become true once the watchdog has confirmed the new version.
	updateSuccess := false
	if err := createBackup(currentBinaryPath, backupPath); err != nil {
		// Without a backup there is nothing to roll back to; carry on.
		log.Printf("Warning: Failed to create backup: %v", err)
	} else {
		// Defer rollback/cleanup logic
		defer func() {
			if !updateSuccess {
				// Rollback on failure
				log.Printf("[tunturi_ed25519] Rollback: restoring from backup...")
				if restoreErr := restoreFromBackup(backupPath, currentBinaryPath); restoreErr != nil {
					log.Printf("[tunturi_ed25519] CRITICAL: Failed to restore backup: %v", restoreErr)
				} else {
					log.Printf("[tunturi_ed25519] ✓ Successfully rolled back to backup")
				}
			} else {
				// Clean up backup on success
				log.Printf("[tunturi_ed25519] ✓ Update successful, cleaning up backup")
				os.Remove(backupPath)
			}
		}()
	}

	// Step 5: Atomic installation
	log.Printf("Step 5: Installing new binary...")
	if err := installNewBinary(tempBinaryPath, currentBinaryPath); err != nil {
		return fmt.Errorf("failed to install new binary: %w", err)
	}

	// Step 6: Restart agent service
	log.Printf("Step 6: Restarting agent service...")
	if err := restartAgentService(); err != nil {
		return fmt.Errorf("failed to restart agent: %w", err)
	}

	// Step 7: Watchdog timer for confirmation
	log.Printf("Step 7: Starting watchdog for update confirmation...")
	updateSuccess = waitForUpdateConfirmation(apiClient, cfg, ackTracker, version, 5*time.Minute)

	// Derive the outcome-dependent fields of the final log report.
	result, outcomeText, stderrText, exitCode, successFlag := "failure", "failed", "Update verification timeout or restart failure", 1, "false"
	if updateSuccess {
		result, outcomeText, stderrText, exitCode, successFlag = "success", "completed successfully", "", 0, "true"
	}

	finalLogReport := client.LogReport{
		CommandID:       commandID,
		Action:          "update_agent",
		Result:          result,
		Stdout:          fmt.Sprintf("Agent update to version %s %s\n", version, outcomeText),
		Stderr:          stderrText,
		ExitCode:        exitCode,
		DurationSeconds: int(time.Since(updateStartTime).Seconds()),
		Metadata: map[string]string{
			"subsystem_label": "Agent Update",
			"subsystem":       "agent",
			"target_version":  version,
			"success":         successFlag,
		},
	}
	if err := reportLogWithAck(apiClient, cfg, ackTracker, finalLogReport); err != nil {
		log.Printf("Failed to report update completion log: %v\n", err)
	}

	if !updateSuccess {
		return fmt.Errorf("agent update verification failed")
	}
	log.Printf("✓ Agent successfully updated to version %s", version)
	return nil
}
// Helper functions for the update process

// downloadUpdatePackage downloads downloadURL with an HTTP GET into a newly
// created temporary file and returns the path of that file. The caller owns
// the returned file and is responsible for removing it.
//
// An error is returned if the temp file cannot be created, the request fails,
// the response status is not 200 OK, or the body cannot be written completely.
// On every error path the temporary file is removed, so failed downloads do
// not accumulate in the temp directory.
func downloadUpdatePackage(downloadURL string) (string, error) {
	// Download to temporary file
	tempFile, err := os.CreateTemp("", "redflag-update-*.bin")
	if err != nil {
		return "", fmt.Errorf("failed to create temp file: %w", err)
	}
	tempPath := tempFile.Name()

	// discard closes and deletes the partially written file on failure.
	discard := func() {
		tempFile.Close()
		os.Remove(tempPath)
	}

	// The default client has no timeout; bound the whole transfer so a
	// stalled server cannot block the agent forever.
	httpClient := &http.Client{Timeout: 10 * time.Minute}
	resp, err := httpClient.Get(downloadURL)
	if err != nil {
		discard()
		return "", fmt.Errorf("failed to download: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		discard()
		return "", fmt.Errorf("download failed with status: %d", resp.StatusCode)
	}

	if _, err := io.Copy(tempFile, resp.Body); err != nil {
		discard()
		return "", fmt.Errorf("failed to write download: %w", err)
	}

	// Close explicitly: a failed flush means the file on disk is incomplete.
	if err := tempFile.Close(); err != nil {
		os.Remove(tempPath)
		return "", fmt.Errorf("failed to write download: %w", err)
	}
	return tempPath, nil
}
// computeSHA256 streams the file at filePath through SHA-256 and returns the
// digest as a lowercase hexadecimal string. It returns an error if the file
// cannot be opened or read.
func computeSHA256(filePath string) (string, error) {
	f, openErr := os.Open(filePath)
	if openErr != nil {
		return "", fmt.Errorf("failed to open file: %w", openErr)
	}
	defer f.Close()

	digest := sha256.New()
	_, copyErr := io.Copy(digest, f)
	if copyErr != nil {
		return "", fmt.Errorf("failed to compute hash: %w", copyErr)
	}

	sum := digest.Sum(nil)
	return hex.EncodeToString(sum), nil
}
// getCurrentBinaryPath returns the path reported by os.Executable for the
// running process, or a wrapped error if that lookup fails.
func getCurrentBinaryPath() (string, error) {
	selfPath, lookupErr := os.Executable()
	if lookupErr == nil {
		return selfPath, nil
	}
	return "", fmt.Errorf("failed to get executable path: %w", lookupErr)
}
// createBackup copies the file at src to dst (creating or truncating dst) and
// then sets dst's permissions to 0755 so the copy is executable.
//
// It returns an error if src cannot be opened, dst cannot be created, the copy
// fails, or the permissions cannot be set. The destination is closed
// explicitly and its Close error is checked, because a failed close can mean
// the data never reached disk and the backup is incomplete.
func createBackup(src, dst string) error {
	srcFile, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("failed to open source: %w", err)
	}
	defer srcFile.Close()

	dstFile, err := os.Create(dst)
	if err != nil {
		return fmt.Errorf("failed to create backup: %w", err)
	}

	if _, err := io.Copy(dstFile, srcFile); err != nil {
		dstFile.Close()
		return fmt.Errorf("failed to copy backup: %w", err)
	}
	if err := dstFile.Close(); err != nil {
		return fmt.Errorf("failed to copy backup: %w", err)
	}

	// Ensure backup is executable
	if err := os.Chmod(dst, 0755); err != nil {
		return fmt.Errorf("failed to set backup permissions: %w", err)
	}
	return nil
}
// restoreFromBackup replaces target with a copy of backup.
//
// The backup is first copied to a staging file next to the target
// ("<target>.restore") and then renamed over the target. The existing target
// is therefore only replaced once a complete copy of the backup exists; if the
// backup is missing or the copy fails, the current target is left untouched
// instead of being deleted. The staging file is removed on failure.
func restoreFromBackup(backup, target string) error {
	staged := target + ".restore"

	// Copy backup to the staging location (createBackup also makes it executable).
	if err := createBackup(backup, staged); err != nil {
		os.Remove(staged)
		return err
	}

	// Swap the staged copy into place in a single rename.
	if err := os.Rename(staged, target); err != nil {
		os.Remove(staged)
		return fmt.Errorf("failed to replace current binary: %w", err)
	}
	return nil
}
// installNewBinary installs the file at src as dst.
//
// The content is first copied to a staging file "<dst>.new", made executable
// (0755), and then renamed over dst so the destination is replaced in one
// step. If any step after creating the staging file fails, the staging file is
// removed. The staging file is closed exactly once and its Close error is
// checked, since a failed close can mean the copy is incomplete.
func installNewBinary(src, dst string) error {
	// Copy new binary to a temporary location first
	tempDst := dst + ".new"

	srcFile, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("failed to open source binary: %w", err)
	}
	defer srcFile.Close()

	dstFile, err := os.Create(tempDst)
	if err != nil {
		return fmt.Errorf("failed to create temp binary: %w", err)
	}

	if _, err := io.Copy(dstFile, srcFile); err != nil {
		dstFile.Close()
		os.Remove(tempDst)
		return fmt.Errorf("failed to copy binary: %w", err)
	}
	if err := dstFile.Close(); err != nil {
		os.Remove(tempDst)
		return fmt.Errorf("failed to copy binary: %w", err)
	}

	// Set executable permissions
	if err := os.Chmod(tempDst, 0755); err != nil {
		os.Remove(tempDst)
		return fmt.Errorf("failed to set binary permissions: %w", err)
	}

	// Atomic rename
	if err := os.Rename(tempDst, dst); err != nil {
		os.Remove(tempDst) // Cleanup temp file
		return fmt.Errorf("failed to atomically replace binary: %w", err)
	}
	return nil
}
// restartAgentService restarts the agent's service using the platform's
// service manager.
//
// On Linux it tries `systemctl restart redflag-agent` first and falls back to
// `service redflag-agent restart` if that fails. On Windows it issues
// `sc stop RedFlagAgent` (result ignored) followed by `sc start RedFlagAgent`.
// Any other operating system yields an error.
func restartAgentService() error {
	var restart *exec.Cmd

	if runtime.GOOS == "linux" {
		// Try systemd first
		if err := exec.Command("systemctl", "restart", "redflag-agent").Run(); err == nil {
			log.Printf("✓ Systemd service restarted")
			return nil
		}
		// Fallback to service command
		restart = exec.Command("service", "redflag-agent", "restart")
	} else if runtime.GOOS == "windows" {
		// Best effort: the outcome of the stop command is deliberately ignored.
		_ = exec.Command("sc", "stop", "RedFlagAgent").Run()
		restart = exec.Command("sc", "start", "RedFlagAgent")
	} else {
		return fmt.Errorf("unsupported OS for service restart: %s", runtime.GOOS)
	}

	if runErr := restart.Run(); runErr != nil {
		return fmt.Errorf("failed to restart service: %w", runErr)
	}
	log.Printf("✓ Agent service restarted")
	return nil
}
// waitForUpdateConfirmation polls the server every 15 seconds until it reports
// expectedVersion as the agent's current version or until timeout has elapsed.
// It returns true when the version is confirmed and false on timeout.
//
// Poll errors and empty (nil) agent responses are logged and retried; neither
// aborts the wait. ackTracker is accepted for signature compatibility and is
// not used here.
func waitForUpdateConfirmation(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, expectedVersion string, timeout time.Duration) bool {
	deadline := time.Now().Add(timeout)
	pollInterval := 15 * time.Second

	log.Printf("[tunturi_ed25519] Watchdog: waiting for version %s confirmation (timeout: %v)...", expectedVersion, timeout)

	for time.Now().Before(deadline) {
		// Poll server for current agent version
		agent, err := apiClient.GetAgent(cfg.AgentID.String())
		if err != nil {
			log.Printf("[tunturi_ed25519] Watchdog: failed to poll server: %v (retrying...)", err)
			time.Sleep(pollInterval)
			continue
		}

		// A nil agent without an error must be handled before any field is
		// read; dereferencing it in the progress log below would panic.
		if agent == nil {
			log.Printf("[tunturi_ed25519] Watchdog: server returned no agent record (retrying...)")
			time.Sleep(pollInterval)
			continue
		}

		// Check if the version matches the expected version
		if agent.CurrentVersion == expectedVersion {
			log.Printf("[tunturi_ed25519] Watchdog: ✓ Version confirmed: %s", expectedVersion)
			return true
		}

		log.Printf("[tunturi_ed25519] Watchdog: Current version: %s, Expected: %s (polling...)",
			agent.CurrentVersion, expectedVersion)
		time.Sleep(pollInterval)
	}

	log.Printf("[tunturi_ed25519] Watchdog: ✗ Timeout after %v - version not confirmed", timeout)
	log.Printf("[tunturi_ed25519] Rollback initiated")
	return false
}
// AES-256-GCM decryption helper functions for encrypted update packages

// deriveKeyFromNonce returns the SHA-256 digest of the nonce string's bytes.
// The 32-byte result is the key size AES-256 requires.
func deriveKeyFromNonce(nonce string) []byte {
	digest := sha256.Sum256([]byte(nonce))

	// Copy the fixed-size array into a fresh slice for the caller.
	key := make([]byte, sha256.Size)
	copy(key, digest[:])
	return key
}
// decryptAES256GCM decrypts a hex-encoded AES-256-GCM payload.
//
// The key is derived from nonce via deriveKeyFromNonce. After hex decoding,
// encryptedData is expected to hold the GCM nonce followed by the ciphertext;
// no additional authenticated data is used.
//
// It returns the plaintext, or an error when the input is not valid hex, is
// shorter than the GCM nonce, or fails authenticated decryption.
func decryptAES256GCM(encryptedData, nonce string) ([]byte, error) {
	raw, err := hex.DecodeString(encryptedData)
	if err != nil {
		return nil, fmt.Errorf("failed to decode hex data: %w", err)
	}

	block, err := aes.NewCipher(deriveKeyFromNonce(nonce))
	if err != nil {
		return nil, fmt.Errorf("failed to create AES cipher: %w", err)
	}

	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	// The payload must contain at least the GCM nonce prefix.
	prefixLen := aead.NonceSize()
	if prefixLen > len(raw) {
		return nil, fmt.Errorf("encrypted data too short")
	}

	// Everything after the nonce prefix is ciphertext.
	plaintext, err := aead.Open(nil, raw[:prefixLen], raw[prefixLen:], nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt: %w", err)
	}
	return plaintext, nil
}
// TODO: Integration with system/machine_id.go for key derivation
// This stub should be integrated with the existing machine ID system
// for more sophisticated key management based on hardware fingerprinting
//
// Example integration approach:
// - Use machine_id.go to generate stable hardware fingerprint
// - Combine hardware fingerprint with nonce for key derivation
// - Store derived keys securely in memory only
// - Implement key rotation support for long-running agents
// verifyBinarySignature checks that signatureHex is a valid Ed25519 signature
// over the full contents of the file at binaryPath, using the key returned by
// getServerPublicKey.
//
// It returns nil when the signature verifies. Errors are returned when the key
// cannot be obtained, the file cannot be read, the signature is not valid hex,
// the signature has the wrong length, or verification fails.
func verifyBinarySignature(binaryPath, signatureHex string) error {
	serverKey, err := getServerPublicKey()
	if err != nil {
		return fmt.Errorf("failed to get server public key: %w", err)
	}

	// The signature covers the whole file, so it is read in full.
	payload, err := os.ReadFile(binaryPath)
	if err != nil {
		return fmt.Errorf("failed to read binary: %w", err)
	}

	sig, err := hex.DecodeString(signatureHex)
	if err != nil {
		return fmt.Errorf("failed to decode signature: %w", err)
	}

	// Reject wrong-sized signatures up front with a descriptive error.
	if got := len(sig); got != ed25519.SignatureSize {
		return fmt.Errorf("invalid signature length: expected %d bytes, got %d", ed25519.SignatureSize, got)
	}

	if !ed25519.Verify(ed25519.PublicKey(serverKey), payload, sig) {
		return fmt.Errorf("signature verification failed: invalid signature")
	}
	return nil
}
// getServerPublicKey returns the server's Ed25519 public key as loaded by
// loadCachedPublicKeyDirect. Per the original note, the key is fetched from
// the server at agent startup and cached locally; this function only reads
// that cache. A load failure is wrapped with a hint to that effect.
func getServerPublicKey() ([]byte, error) {
	cachedKey, loadErr := loadCachedPublicKeyDirect()
	if loadErr == nil {
		return cachedKey, nil
	}
	return nil, fmt.Errorf("failed to load server public key: %w (hint: key is fetched at agent startup)", loadErr)
}
// loadCachedPublicKeyDirect reads the cached server public key from its fixed
// location: C:\ProgramData\RedFlag\server_public_key on Windows and
// /etc/redflag/server_public_key everywhere else.
//
// The file must contain exactly 32 bytes (the Ed25519 public key size);
// otherwise, or if the file cannot be read, an error is returned.
func loadCachedPublicKeyDirect() ([]byte, error) {
	keyPath := "/etc/redflag/server_public_key"
	if runtime.GOOS == "windows" {
		keyPath = "C:\\ProgramData\\RedFlag\\server_public_key"
	}

	raw, readErr := os.ReadFile(keyPath)
	if readErr != nil {
		return nil, fmt.Errorf("public key not found: %w", readErr)
	}

	const expectedKeyLen = 32 // ed25519.PublicKeySize
	if got := len(raw); got != expectedKeyLen {
		return nil, fmt.Errorf("invalid public key size: expected 32 bytes, got %d", got)
	}
	return raw, nil
}
// validateNonce checks a server-issued nonce used for replay protection.
//
// nonceTimestampStr must be an RFC 3339 timestamp that is neither more than
// five minutes old nor in the future. nonceSignature must be a hex-encoded
// Ed25519 signature, made with the server key from getServerPublicKey, over
// the string "<nonceUUIDStr>:<unix seconds of the timestamp>".
//
// It returns nil when all checks pass and a descriptive error otherwise. Only
// freshness and the signature are checked here; no record of previously seen
// nonces is kept in this function.
func validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature string) error {
	issuedAt, err := time.Parse(time.RFC3339, nonceTimestampStr)
	if err != nil {
		return fmt.Errorf("invalid nonce timestamp format: %w", err)
	}

	// Freshness window: (now - 5 minutes) .. now.
	switch age := time.Since(issuedAt); {
	case age > 5*time.Minute:
		return fmt.Errorf("nonce expired: age %v > 5 minutes", age)
	case age < 0:
		return fmt.Errorf("nonce timestamp in the future: %v", issuedAt)
	}

	serverKey, err := getServerPublicKey()
	if err != nil {
		return fmt.Errorf("failed to get server public key: %w", err)
	}

	// Rebuild the signed message; per the original note this layout has to
	// match what the server signs.
	message := []byte(fmt.Sprintf("%s:%d", nonceUUIDStr, issuedAt.Unix()))

	sig, err := hex.DecodeString(nonceSignature)
	if err != nil {
		return fmt.Errorf("invalid nonce signature format: %w", err)
	}
	if got := len(sig); got != ed25519.SignatureSize {
		return fmt.Errorf("invalid nonce signature length: expected %d bytes, got %d",
			ed25519.SignatureSize, got)
	}

	if !ed25519.Verify(ed25519.PublicKey(serverKey), message, sig) {
		return fmt.Errorf("invalid nonce signature")
	}
	return nil
}