package main
import (
"context"
"flag"
"fmt"
"log"
"math/rand"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/acknowledgment"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/cache"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/circuitbreaker"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/client"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/config"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/crypto"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/display"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/installer"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/migration"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/orchestrator"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/scanner"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/service"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/system"
"github.com/google/uuid"
)
const (
AgentVersion = "0.1.23" // v0.1.23: Real security metrics and config sync
)
var (
lastConfigVersion int64 = 0 // Track last applied config version
)
// getConfigPath returns the platform-specific config path
func getConfigPath() string {
if runtime.GOOS == "windows" {
return "C:\\ProgramData\\RedFlag\\config.json"
}
return "/etc/redflag/config.json"
}
// getStatePath returns the platform-specific state directory path
func getStatePath() string {
if runtime.GOOS == "windows" {
return "C:\\ProgramData\\RedFlag\\state"
}
return "/var/lib/redflag"
}
// reportLogWithAck reports a command log to the server and tracks it for acknowledgment
func reportLogWithAck(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, logReport client.LogReport) error {
// Track this command result as pending acknowledgment
ackTracker.Add(logReport.CommandID)
// Save acknowledgment state immediately
if err := ackTracker.Save(); err != nil {
log.Printf("Warning: Failed to save acknowledgment for command %s: %v", logReport.CommandID, err)
}
// Report the log to the server (call the API client directly; calling reportLogWithAck here would recurse)
if err := apiClient.ReportLog(cfg.AgentID, logReport); err != nil {
// If reporting failed, increment retry count but don't remove from pending
ackTracker.IncrementRetry(logReport.CommandID)
return err
}
return nil
}
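// Note on delivery semantics: acknowledgment tracking gives command results
// at-least-once delivery. A command ID stays pending until the server echoes
// it back via AcknowledgedIDs on a later check-in (see the main loop), so a
// result reported just before a crash is retried rather than silently lost.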
// getCurrentPollingInterval returns the appropriate polling interval based on rapid mode
func getCurrentPollingInterval(cfg *config.Config) int {
// Check if rapid polling mode is active and not expired
if cfg.RapidPollingEnabled && time.Now().Before(cfg.RapidPollingUntil) {
return 5 // Rapid polling: 5 seconds
}
// Check if rapid polling has expired and clean up
if cfg.RapidPollingEnabled && time.Now().After(cfg.RapidPollingUntil) {
cfg.RapidPollingEnabled = false
cfg.RapidPollingUntil = time.Time{}
// Save the updated config to clean up expired rapid mode
if err := cfg.Save(getConfigPath()); err != nil {
log.Printf("Warning: Failed to cleanup expired rapid polling mode: %v", err)
}
}
return cfg.CheckInInterval // Normal polling: 5 minutes (300 seconds) by default
}
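// Example: with RapidPollingEnabled=true and RapidPollingUntil two minutes in
// the future, this returns 5 (seconds). Once the deadline passes, the rapid
// flags are cleared and the normal CheckInInterval (300s by default) applies.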
// getDefaultServerURL returns the default server URL with environment variable support
func getDefaultServerURL() string {
// Check environment variable first
if envURL := os.Getenv("REDFLAG_SERVER_URL"); envURL != "" {
return envURL
}
// Platform-specific defaults
if runtime.GOOS == "windows" {
// For Windows, use a placeholder that prompts users to configure
return "http://REPLACE_WITH_SERVER_IP:8080"
}
return "http://localhost:8080"
}
func main() {
// Define CLI flags
registerCmd := flag.Bool("register", false, "Register agent with server")
scanCmd := flag.Bool("scan", false, "Scan for updates and display locally")
statusCmd := flag.Bool("status", false, "Show agent status")
listUpdatesCmd := flag.Bool("list-updates", false, "List detailed update information")
versionCmd := flag.Bool("version", false, "Show version information")
serverURL := flag.String("server", "", "Server URL")
registrationToken := flag.String("token", "", "Registration token for secure enrollment")
proxyHTTP := flag.String("proxy-http", "", "HTTP proxy URL")
proxyHTTPS := flag.String("proxy-https", "", "HTTPS proxy URL")
proxyNoProxy := flag.String("proxy-no", "", "Comma-separated hosts to bypass proxy")
logLevel := flag.String("log-level", "", "Log level (debug, info, warn, error)")
configFile := flag.String("config", "", "Configuration file path")
tagsFlag := flag.String("tags", "", "Comma-separated tags for agent")
organization := flag.String("organization", "", "Organization/group name")
displayName := flag.String("name", "", "Display name for agent")
insecureTLS := flag.Bool("insecure-tls", false, "Skip TLS certificate verification")
exportFormat := flag.String("export", "", "Export format: json, csv")
// Windows service management commands
installServiceCmd := flag.Bool("install-service", false, "Install as Windows service")
removeServiceCmd := flag.Bool("remove-service", false, "Remove Windows service")
startServiceCmd := flag.Bool("start-service", false, "Start Windows service")
stopServiceCmd := flag.Bool("stop-service", false, "Stop Windows service")
serviceStatusCmd := flag.Bool("service-status", false, "Show Windows service status")
flag.Parse()
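// Typical invocations (illustrative; binary name depends on your install):
//   redflag-agent -register -server http://server:8080 -token <token>
//   redflag-agent -scan -export json
//   redflag-agent -status
//   redflag-agent.exe -install-service   (Windows only)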
// Handle version command
if *versionCmd {
fmt.Printf("RedFlag Agent v%s\n", AgentVersion)
fmt.Printf("Self-hosted update management platform\n")
os.Exit(0)
}
// Handle Windows service management commands (only on Windows)
if runtime.GOOS == "windows" {
if *installServiceCmd {
if err := service.InstallService(); err != nil {
log.Fatalf("Failed to install service: %v", err)
}
fmt.Println("RedFlag service installed successfully")
os.Exit(0)
}
if *removeServiceCmd {
if err := service.RemoveService(); err != nil {
log.Fatalf("Failed to remove service: %v", err)
}
fmt.Println("RedFlag service removed successfully")
os.Exit(0)
}
if *startServiceCmd {
if err := service.StartService(); err != nil {
log.Fatalf("Failed to start service: %v", err)
}
fmt.Println("RedFlag service started successfully")
os.Exit(0)
}
if *stopServiceCmd {
if err := service.StopService(); err != nil {
log.Fatalf("Failed to stop service: %v", err)
}
fmt.Println("RedFlag service stopped successfully")
os.Exit(0)
}
if *serviceStatusCmd {
if err := service.ServiceStatus(); err != nil {
log.Fatalf("Failed to get service status: %v", err)
}
os.Exit(0)
}
}
// Parse tags from comma-separated string
var tags []string
if *tagsFlag != "" {
tags = strings.Split(*tagsFlag, ",")
for i, tag := range tags {
tags[i] = strings.TrimSpace(tag)
}
}
// Create CLI flags structure
cliFlags := &config.CLIFlags{
ServerURL: *serverURL,
RegistrationToken: *registrationToken,
ProxyHTTP: *proxyHTTP,
ProxyHTTPS: *proxyHTTPS,
ProxyNoProxy: *proxyNoProxy,
LogLevel: *logLevel,
ConfigFile: *configFile,
Tags: tags,
Organization: *organization,
DisplayName: *displayName,
InsecureTLS: *insecureTLS,
}
// Determine config path
configPath := getConfigPath()
if *configFile != "" {
configPath = *configFile
}
// Check for migration requirements before loading configuration
migrationConfig := migration.NewFileDetectionConfig()
// Set old paths to detect existing installations
migrationConfig.OldConfigPath = "/etc/aggregator"
migrationConfig.OldStatePath = "/var/lib/aggregator"
// Set new paths that agent will actually use
migrationConfig.NewConfigPath = filepath.Dir(configPath)
migrationConfig.NewStatePath = getStatePath()
// Detect migration requirements
migrationDetection, err := migration.DetectMigrationRequirements(migrationConfig)
if err != nil {
log.Printf("Warning: Failed to detect migration requirements: %v", err)
} else if migrationDetection.RequiresMigration {
log.Printf("[RedFlag Server Migrator] Migration detected: %s → %s", migrationDetection.CurrentAgentVersion, AgentVersion)
log.Printf("[RedFlag Server Migrator] Required migrations: %v", migrationDetection.RequiredMigrations)
// Create migration plan
migrationPlan := &migration.MigrationPlan{
Detection: migrationDetection,
TargetVersion: AgentVersion,
Config: migrationConfig,
BackupPath: filepath.Join(getStatePath(), "migration_backups"), // Set backup path within agent's state directory
}
// Execute migration
executor := migration.NewMigrationExecutor(migrationPlan)
result, err := executor.ExecuteMigration()
if err != nil {
log.Printf("[RedFlag Server Migrator] Migration failed: %v", err)
if result != nil { // result may be nil when the executor fails early
log.Printf("[RedFlag Server Migrator] Backup available at: %s", result.BackupPath)
}
log.Printf("[RedFlag Server Migrator] Agent may not function correctly until migration is completed")
} else {
log.Printf("[RedFlag Server Migrator] Migration completed successfully")
if result.RollbackAvailable {
log.Printf("[RedFlag Server Migrator] Rollback available at: %s", result.BackupPath)
}
}
}
// Load configuration with priority: CLI > env > file > defaults
cfg, err := config.Load(configPath, cliFlags)
if err != nil {
log.Fatal("Failed to load configuration:", err)
}
// Always set the current agent version in config
if cfg.AgentVersion != AgentVersion {
// Remember whether this is a real version change (vs. a fresh config)
// before overwriting the field, so the logs below fire only on change
versionChanged := cfg.AgentVersion != ""
if versionChanged {
log.Printf("[RedFlag Server Migrator] Version change detected: %s → %s", cfg.AgentVersion, AgentVersion)
log.Printf("[RedFlag Server Migrator] Performing lightweight migration check...")
}
// Update config version to match current agent
cfg.AgentVersion = AgentVersion
// Save updated config
if err := cfg.Save(configPath); err != nil {
log.Printf("Warning: Failed to update agent version in config: %v", err)
} else if versionChanged {
log.Printf("[RedFlag Server Migrator] Agent version updated in configuration")
}
}
// Handle registration
if *registerCmd {
// Validate server URL for Windows users
if runtime.GOOS == "windows" && strings.Contains(*serverURL, "REPLACE_WITH_SERVER_IP") {
fmt.Println("❌ CONFIGURATION REQUIRED!")
fmt.Println("==================================================================")
fmt.Println("Please configure the server URL before registering:")
fmt.Println("")
fmt.Println("Option 1 - Use the -server flag:")
fmt.Printf(" redflag-agent.exe -register -server http://10.10.20.159:8080\n")
fmt.Println("")
fmt.Println("Option 2 - Use environment variable:")
fmt.Println(" set REDFLAG_SERVER_URL=http://10.10.20.159:8080")
fmt.Println(" redflag-agent.exe -register")
fmt.Println("")
fmt.Println("Option 3 - Create a .env file:")
fmt.Println(" REDFLAG_SERVER_URL=http://10.10.20.159:8080")
fmt.Println("==================================================================")
os.Exit(1)
}
if err := registerAgent(cfg, *serverURL); err != nil {
log.Fatal("Registration failed:", err)
}
fmt.Println("==================================================================")
fmt.Println("🎉 AGENT REGISTRATION SUCCESSFUL!")
fmt.Println("==================================================================")
fmt.Printf("📋 Agent ID: %s\n", cfg.AgentID)
fmt.Printf("🌐 Server: %s\n", cfg.ServerURL)
fmt.Printf("⏱️ Check-in Interval: %ds\n", cfg.CheckInInterval)
fmt.Println("==================================================================")
fmt.Println("💡 Save this Agent ID for your records!")
fmt.Println("🚀 You can now start the agent without flags")
fmt.Println("")
return
}
// Handle scan command
if *scanCmd {
if err := handleScanCommand(cfg, *exportFormat); err != nil {
log.Fatal("Scan failed:", err)
}
return
}
// Handle status command
if *statusCmd {
if err := handleStatusCommand(cfg); err != nil {
log.Fatal("Status command failed:", err)
}
return
}
// Handle list-updates command
if *listUpdatesCmd {
if err := handleListUpdatesCommand(cfg, *exportFormat); err != nil {
log.Fatal("List updates failed:", err)
}
return
}
// Check if registered
if !cfg.IsRegistered() {
log.Fatal("Agent not registered. Run with -register flag first.")
}
// Check if running as Windows service
if runtime.GOOS == "windows" && service.IsService() {
// Run as Windows service
if err := service.RunService(cfg); err != nil {
log.Fatal("Service failed:", err)
}
return
}
// Start agent service (console mode)
if err := runAgent(cfg); err != nil {
log.Fatal("Agent failed:", err)
}
}
func registerAgent(cfg *config.Config, serverURL string) error {
// Get detailed system information
sysInfo, err := system.GetSystemInfo(AgentVersion)
if err != nil {
log.Printf("Warning: Failed to get detailed system info: %v\n", err)
// Fall back to basic detection
hostname, _ := os.Hostname()
osType, osVersion, osArch := client.DetectSystem()
sysInfo = &system.SystemInfo{
Hostname: hostname,
OSType: osType,
OSVersion: osVersion,
OSArchitecture: osArch,
AgentVersion: AgentVersion,
Metadata: make(map[string]string),
}
}
// Use registration token from config if available
apiClient := client.NewClient(serverURL, cfg.RegistrationToken)
// Create metadata with system information
metadata := map[string]string{
"installation_time": time.Now().Format(time.RFC3339),
}
// Add system info to metadata
if sysInfo.CPUInfo.ModelName != "" {
metadata["cpu_model"] = sysInfo.CPUInfo.ModelName
}
if sysInfo.CPUInfo.Cores > 0 {
metadata["cpu_cores"] = fmt.Sprintf("%d", sysInfo.CPUInfo.Cores)
}
if sysInfo.MemoryInfo.Total > 0 {
metadata["memory_total"] = fmt.Sprintf("%d", sysInfo.MemoryInfo.Total)
}
if sysInfo.RunningProcesses > 0 {
metadata["processes"] = fmt.Sprintf("%d", sysInfo.RunningProcesses)
}
if sysInfo.Uptime != "" {
metadata["uptime"] = sysInfo.Uptime
}
// Add disk information
for i, disk := range sysInfo.DiskInfo {
if i == 0 {
metadata["disk_mount"] = disk.Mountpoint
metadata["disk_total"] = fmt.Sprintf("%d", disk.Total)
metadata["disk_used"] = fmt.Sprintf("%d", disk.Used)
break // Only add primary disk info
}
}
// Get machine ID for binding
machineID, err := system.GetMachineID()
if err != nil {
log.Printf("Warning: Failed to get machine ID: %v", err)
machineID = "unknown-" + sysInfo.Hostname
}
// Get embedded public key fingerprint
publicKeyFingerprint := system.GetPublicKeyFingerprint()
if publicKeyFingerprint == "" {
log.Printf("Warning: No embedded public key fingerprint found")
}
req := client.RegisterRequest{
Hostname: sysInfo.Hostname,
OSType: sysInfo.OSType,
OSVersion: sysInfo.OSVersion,
OSArchitecture: sysInfo.OSArchitecture,
AgentVersion: sysInfo.AgentVersion,
MachineID: machineID,
PublicKeyFingerprint: publicKeyFingerprint,
Metadata: metadata,
}
resp, err := apiClient.Register(req)
if err != nil {
return err
}
// Update configuration
cfg.ServerURL = serverURL
cfg.AgentID = resp.AgentID
cfg.Token = resp.Token
cfg.RefreshToken = resp.RefreshToken
// Get check-in interval from server config
if interval, ok := resp.Config["check_in_interval"].(float64); ok {
cfg.CheckInInterval = int(interval)
} else {
cfg.CheckInInterval = 300 // Default 5 minutes
}
// Save configuration
if err := cfg.Save(getConfigPath()); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
// Fetch and cache server public key for signature verification
log.Println("Fetching server public key for update signature verification...")
if err := fetchAndCachePublicKey(cfg.ServerURL); err != nil {
log.Printf("Warning: Failed to fetch server public key: %v", err)
log.Printf("Agent will not be able to verify update signatures")
// Don't fail registration - key can be fetched later
} else {
log.Println("✓ Server public key cached successfully")
}
return nil
}
// fetchAndCachePublicKey fetches the server's Ed25519 public key and caches it locally
func fetchAndCachePublicKey(serverURL string) error {
_, err := crypto.FetchAndCacheServerPublicKey(serverURL)
return err
}
// renewTokenIfNeeded handles 401 errors by renewing the agent token using the refresh token
func renewTokenIfNeeded(apiClient *client.Client, cfg *config.Config, err error) (*client.Client, error) {
if err != nil && strings.Contains(err.Error(), "401 Unauthorized") {
log.Printf("🔄 Access token expired - attempting renewal with refresh token...")
// Check if we have a refresh token
if cfg.RefreshToken == "" {
log.Printf("❌ No refresh token available - re-registration required")
return nil, fmt.Errorf("refresh token missing - please re-register agent")
}
// Create temporary client without token for renewal
tempClient := client.NewClient(cfg.ServerURL, "")
// Attempt to renew access token using refresh token
if err := tempClient.RenewToken(cfg.AgentID, cfg.RefreshToken); err != nil {
log.Printf("❌ Refresh token renewal failed: %v", err)
log.Printf("💡 Refresh token may be expired (>90 days) - re-registration required")
return nil, fmt.Errorf("refresh token renewal failed: %w - please re-register agent", err)
}
// Update config with new access token (agent ID and refresh token stay the same!)
cfg.Token = tempClient.GetToken()
// Save updated config
if err := cfg.Save(getConfigPath()); err != nil {
log.Printf("⚠️ Warning: Failed to save renewed access token: %v", err)
}
log.Printf("✅ Access token renewed successfully - agent ID maintained: %s", cfg.AgentID)
return tempClient, nil
}
// Return original client if no 401 error
return apiClient, nil
}
// getCurrentSubsystemEnabled returns the current enabled state for a subsystem
func getCurrentSubsystemEnabled(cfg *config.Config, subsystemName string) bool {
switch subsystemName {
case "system":
return cfg.Subsystems.System.Enabled
case "updates":
return cfg.Subsystems.Updates.Enabled
case "docker":
return cfg.Subsystems.Docker.Enabled
case "storage":
return cfg.Subsystems.Storage.Enabled
case "apt":
return cfg.Subsystems.APT.Enabled
case "dnf":
return cfg.Subsystems.DNF.Enabled
case "windows":
return cfg.Subsystems.Windows.Enabled
case "winget":
return cfg.Subsystems.Winget.Enabled
default:
// Unknown subsystem, assume disabled
return false
}
}
// syncServerConfig checks for and applies server configuration updates
func syncServerConfig(apiClient *client.Client, cfg *config.Config) error {
// Get current config from server
serverConfig, err := apiClient.GetConfig(cfg.AgentID)
if err != nil {
return fmt.Errorf("failed to get server config: %w", err)
}
// Check if config version is newer
if serverConfig.Version <= lastConfigVersion {
return nil // No update needed
}
log.Printf("📡 Server config update detected (version: %d)", serverConfig.Version)
changes := false
// Track potential check-in interval changes separately to avoid inflation
newCheckInInterval := cfg.CheckInInterval
// Apply subsystem configuration from server
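// Each subsystem entry is expected to be a JSON object roughly of the form
// (inferred from the type assertions below, not a documented schema):
//   "apt": {"enabled": true, "interval_minutes": 30, "auto_run": false}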
for subsystemName, subsystemConfig := range serverConfig.Subsystems {
if configMap, ok := subsystemConfig.(map[string]interface{}); ok {
enabled := false
intervalMinutes := 0
autoRun := false
if e, exists := configMap["enabled"]; exists {
if eVal, ok := e.(bool); ok {
enabled = eVal
}
}
if i, exists := configMap["interval_minutes"]; exists {
if iVal, ok := i.(float64); ok {
intervalMinutes = int(iVal)
}
}
if a, exists := configMap["auto_run"]; exists {
if aVal, ok := a.(bool); ok {
autoRun = aVal
}
}
// Get current subsystem enabled state dynamically
currentEnabled := getCurrentSubsystemEnabled(cfg, subsystemName)
if enabled != currentEnabled {
log.Printf(" → %s: enabled=%v (changed)", subsystemName, enabled)
changes = true
}
// Check if the interval actually changed, but don't modify cfg.CheckInInterval yet.
// CheckInInterval is stored in seconds, so convert the server's minutes first.
intervalSeconds := intervalMinutes * 60
if intervalMinutes > 0 && intervalSeconds != newCheckInInterval {
log.Printf(" → %s: interval=%d minutes (changed)", subsystemName, intervalMinutes)
changes = true
newCheckInInterval = intervalSeconds // Update temp variable, not the config
}
if autoRun {
log.Printf(" → %s: auto_run=%v (server-side scheduling)", subsystemName, autoRun)
}
}
}
// Apply the check-in interval change only once after all subsystems processed
if newCheckInInterval != cfg.CheckInInterval {
cfg.CheckInInterval = newCheckInInterval
}
if changes {
log.Printf("✅ Server configuration applied successfully")
} else {
log.Printf(" Server config received but no changes detected")
}
// Update last config version
lastConfigVersion = serverConfig.Version
return nil
}
func runAgent(cfg *config.Config) error {
log.Printf("🚩 RedFlag Agent v%s starting...\n", AgentVersion)
log.Printf("==================================================================")
log.Printf("📋 AGENT ID: %s", cfg.AgentID)
log.Printf("🌐 SERVER: %s", cfg.ServerURL)
log.Printf("⏱️ CHECK-IN INTERVAL: %ds", cfg.CheckInInterval)
log.Printf("==================================================================")
log.Printf("💡 Tip: Use this Agent ID to identify this agent in the web UI")
log.Printf("")
apiClient := client.NewClient(cfg.ServerURL, cfg.Token)
// Initialize scanners for package updates (used by update orchestrator)
aptScanner := scanner.NewAPTScanner()
dnfScanner := scanner.NewDNFScanner()
windowsUpdateScanner := scanner.NewWindowsUpdateScanner()
wingetScanner := scanner.NewWingetScanner()
// Docker, Storage, and System scanners are created by individual subsystem handlers
// dockerScanner is created in handleScanDocker
// storageScanner and systemScanner are created in main for individual handlers
// Initialize circuit breakers for update scanners only
aptCB := circuitbreaker.New("APT", circuitbreaker.Config{
FailureThreshold: cfg.Subsystems.APT.CircuitBreaker.FailureThreshold,
FailureWindow: cfg.Subsystems.APT.CircuitBreaker.FailureWindow,
OpenDuration: cfg.Subsystems.APT.CircuitBreaker.OpenDuration,
HalfOpenAttempts: cfg.Subsystems.APT.CircuitBreaker.HalfOpenAttempts,
})
dnfCB := circuitbreaker.New("DNF", circuitbreaker.Config{
FailureThreshold: cfg.Subsystems.DNF.CircuitBreaker.FailureThreshold,
FailureWindow: cfg.Subsystems.DNF.CircuitBreaker.FailureWindow,
OpenDuration: cfg.Subsystems.DNF.CircuitBreaker.OpenDuration,
HalfOpenAttempts: cfg.Subsystems.DNF.CircuitBreaker.HalfOpenAttempts,
})
windowsCB := circuitbreaker.New("Windows Update", circuitbreaker.Config{
FailureThreshold: cfg.Subsystems.Windows.CircuitBreaker.FailureThreshold,
FailureWindow: cfg.Subsystems.Windows.CircuitBreaker.FailureWindow,
OpenDuration: cfg.Subsystems.Windows.CircuitBreaker.OpenDuration,
HalfOpenAttempts: cfg.Subsystems.Windows.CircuitBreaker.HalfOpenAttempts,
})
wingetCB := circuitbreaker.New("Winget", circuitbreaker.Config{
FailureThreshold: cfg.Subsystems.Winget.CircuitBreaker.FailureThreshold,
FailureWindow: cfg.Subsystems.Winget.CircuitBreaker.FailureWindow,
OpenDuration: cfg.Subsystems.Winget.CircuitBreaker.OpenDuration,
HalfOpenAttempts: cfg.Subsystems.Winget.CircuitBreaker.HalfOpenAttempts,
})
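// The breaker fields follow the usual circuit-breaker pattern: trip open after
// FailureThreshold failures inside FailureWindow, reject calls for OpenDuration,
// then permit HalfOpenAttempts trial calls before closing again. Exact semantics
// live in the internal/circuitbreaker package.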
// Initialize scanner orchestrator for parallel execution and granular subsystem management
scanOrchestrator := orchestrator.NewOrchestrator()
// Register update scanners ONLY - package management systems
scanOrchestrator.RegisterScanner("apt", orchestrator.NewAPTScannerWrapper(aptScanner), aptCB, cfg.Subsystems.APT.Timeout, cfg.Subsystems.APT.Enabled)
scanOrchestrator.RegisterScanner("dnf", orchestrator.NewDNFScannerWrapper(dnfScanner), dnfCB, cfg.Subsystems.DNF.Timeout, cfg.Subsystems.DNF.Enabled)
scanOrchestrator.RegisterScanner("windows", orchestrator.NewWindowsUpdateScannerWrapper(windowsUpdateScanner), windowsCB, cfg.Subsystems.Windows.Timeout, cfg.Subsystems.Windows.Enabled)
scanOrchestrator.RegisterScanner("winget", orchestrator.NewWingetScannerWrapper(wingetScanner), wingetCB, cfg.Subsystems.Winget.Timeout, cfg.Subsystems.Winget.Enabled)
// NOTE: Docker, Storage, and System scanners are NOT registered with the update orchestrator
// They have their own dedicated handlers and endpoints:
// - Docker: handleScanDocker → ReportDockerImages()
// - Storage: handleScanStorage → ReportMetrics()
// - System: handleScanSystem → ReportMetrics()
// Initialize acknowledgment tracker for command result reliability
ackTracker := acknowledgment.NewTracker(getStatePath())
if err := ackTracker.Load(); err != nil {
log.Printf("Warning: Failed to load pending acknowledgments: %v", err)
} else {
pendingCount := len(ackTracker.GetPending())
if pendingCount > 0 {
log.Printf("Loaded %d pending command acknowledgments from previous session", pendingCount)
}
}
// Periodic cleanup of old/stale acknowledgments
go func() {
cleanupTicker := time.NewTicker(1 * time.Hour)
defer cleanupTicker.Stop()
for range cleanupTicker.C {
removed := ackTracker.Cleanup()
if removed > 0 {
log.Printf("Cleaned up %d stale acknowledgments", removed)
if err := ackTracker.Save(); err != nil {
log.Printf("Warning: Failed to save acknowledgments after cleanup: %v", err)
}
}
}
}()
// System info tracking
var lastSystemInfoUpdate time.Time
const systemInfoUpdateInterval = 1 * time.Hour // Update detailed system info every hour
// Main check-in loop
for {
// Add jitter to prevent thundering herd
jitter := time.Duration(rand.Intn(30)) * time.Second
time.Sleep(jitter)
// Check if we need to send detailed system info update
if time.Since(lastSystemInfoUpdate) >= systemInfoUpdateInterval {
log.Printf("Updating detailed system information...")
if err := reportSystemInfo(apiClient, cfg); err != nil {
log.Printf("Failed to report system info: %v\n", err)
} else {
lastSystemInfoUpdate = time.Now()
log.Printf("✓ System information updated\n")
}
}
log.Printf("Checking in with server... (Agent v%s)", AgentVersion)
// Collect lightweight system metrics
sysMetrics, err := system.GetLightweightMetrics()
var metrics *client.SystemMetrics
if err == nil {
metrics = &client.SystemMetrics{
CPUPercent: sysMetrics.CPUPercent,
MemoryPercent: sysMetrics.MemoryPercent,
MemoryUsedGB: sysMetrics.MemoryUsedGB,
MemoryTotalGB: sysMetrics.MemoryTotalGB,
DiskUsedGB: sysMetrics.DiskUsedGB,
DiskTotalGB: sysMetrics.DiskTotalGB,
DiskPercent: sysMetrics.DiskPercent,
Uptime: sysMetrics.Uptime,
Version: AgentVersion,
}
}
// Add heartbeat status to metrics metadata if available
if metrics != nil && cfg.RapidPollingEnabled {
// Check if rapid polling is still valid
if time.Now().Before(cfg.RapidPollingUntil) {
// Include heartbeat metadata in metrics
if metrics.Metadata == nil {
metrics.Metadata = make(map[string]interface{})
}
metrics.Metadata["rapid_polling_enabled"] = true
metrics.Metadata["rapid_polling_until"] = cfg.RapidPollingUntil.Format(time.RFC3339)
metrics.Metadata["rapid_polling_duration_minutes"] = int(time.Until(cfg.RapidPollingUntil).Minutes())
} else {
// Heartbeat expired, disable it
cfg.RapidPollingEnabled = false
cfg.RapidPollingUntil = time.Time{}
}
}
// Add pending acknowledgments to metrics for reliability
if metrics != nil {
pendingAcks := ackTracker.GetPending()
if len(pendingAcks) > 0 {
metrics.PendingAcknowledgments = pendingAcks
log.Printf("Including %d pending acknowledgments in check-in: %v", len(pendingAcks), pendingAcks)
} else {
log.Printf("No pending acknowledgments to send")
}
} else {
log.Printf("Metrics is nil - not sending system information or acknowledgments")
}
// Get commands from server (with optional metrics)
response, err := apiClient.GetCommands(cfg.AgentID, metrics)
if err != nil {
// Try to renew token if we got a 401 error
newClient, renewErr := renewTokenIfNeeded(apiClient, cfg, err)
if renewErr != nil {
log.Printf("Check-in unsuccessful and token renewal failed: %v\n", renewErr)
time.Sleep(time.Duration(getCurrentPollingInterval(cfg)) * time.Second)
continue
}
// If token was renewed, update client and retry
if newClient != apiClient {
log.Printf("🔄 Retrying check-in with renewed token...")
apiClient = newClient
response, err = apiClient.GetCommands(cfg.AgentID, metrics)
if err != nil {
log.Printf("Check-in unsuccessful even after token renewal: %v\n", err)
time.Sleep(time.Duration(getCurrentPollingInterval(cfg)) * time.Second)
continue
}
} else {
log.Printf("Check-in unsuccessful: %v\n", err)
time.Sleep(time.Duration(getCurrentPollingInterval(cfg)) * time.Second)
continue
}
}
// Process acknowledged command results
if response != nil && len(response.AcknowledgedIDs) > 0 {
ackTracker.Acknowledge(response.AcknowledgedIDs)
log.Printf("Server acknowledged %d command result(s)", len(response.AcknowledgedIDs))
// Save acknowledgment state
if err := ackTracker.Save(); err != nil {
log.Printf("Warning: Failed to save acknowledgment state: %v", err)
}
}
// Sync configuration from server (non-blocking)
go func() {
if err := syncServerConfig(apiClient, cfg); err != nil {
log.Printf("Warning: Failed to sync server config: %v", err)
}
}()
commands := response.Commands
if len(commands) == 0 {
log.Printf("Check-in successful - no new commands")
} else {
log.Printf("Check-in successful - received %d command(s)", len(commands))
}
// Process each command
for _, cmd := range commands {
log.Printf("Processing command: %s (%s)\n", cmd.Type, cmd.ID)
switch cmd.Type {
case "scan_updates":
if err := handleScanUpdatesV2(apiClient, cfg, ackTracker, scanOrchestrator, cmd.ID); err != nil {
log.Printf("Error scanning updates: %v\n", err)
}
case "scan_storage":
if err := handleScanStorage(apiClient, cfg, ackTracker, scanOrchestrator, cmd.ID); err != nil {
log.Printf("Error scanning storage: %v\n", err)
}
case "scan_system":
if err := handleScanSystem(apiClient, cfg, ackTracker, scanOrchestrator, cmd.ID); err != nil {
log.Printf("Error scanning system: %v\n", err)
}
case "scan_docker":
if err := handleScanDocker(apiClient, cfg, ackTracker, scanOrchestrator, cmd.ID); err != nil {
log.Printf("Error scanning Docker: %v\n", err)
}
case "collect_specs":
log.Println("Spec collection not yet implemented")
case "dry_run_update":
if err := handleDryRunUpdate(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil {
log.Printf("Error dry running update: %v\n", err)
}
case "install_updates":
if err := handleInstallUpdates(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil {
log.Printf("Error installing updates: %v\n", err)
}
case "confirm_dependencies":
if err := handleConfirmDependencies(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil {
log.Printf("Error confirming dependencies: %v\n", err)
}
case "enable_heartbeat":
if err := handleEnableHeartbeat(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil {
log.Printf("[Heartbeat] Error enabling heartbeat: %v\n", err)
}
case "disable_heartbeat":
if err := handleDisableHeartbeat(apiClient, cfg, ackTracker, cmd.ID); err != nil {
log.Printf("[Heartbeat] Error disabling heartbeat: %v\n", err)
}
case "reboot":
if err := handleReboot(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil {
log.Printf("[Reboot] Error processing reboot command: %v\n", err)
}
case "update_agent":
if err := handleUpdateAgent(apiClient, cfg, ackTracker, cmd.Params, cmd.ID); err != nil {
log.Printf("[Update] Error processing agent update command: %v\n", err)
}
default:
log.Printf("Unknown command type: %s - reporting as invalid command\n", cmd.Type)
// Report invalid command back to server
logReport := client.LogReport{
CommandID: cmd.ID,
Action: "process_command",
Result: "failed",
Stdout: "",
Stderr: fmt.Sprintf("Invalid command type: %s", cmd.Type),
ExitCode: 1,
DurationSeconds: 0,
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("Failed to report invalid command result: %v", reportErr)
}
}
}
// Wait for next check-in
time.Sleep(time.Duration(getCurrentPollingInterval(cfg)) * time.Second)
}
}
// subsystemScan executes a scanner function with circuit breaker and timeout protection
func subsystemScan(name string, cb *circuitbreaker.CircuitBreaker, timeout time.Duration, scanFn func() ([]client.UpdateReportItem, error)) ([]client.UpdateReportItem, error) {
var updates []client.UpdateReportItem
var scanErr error
err := cb.Call(func() error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
type result struct {
updates []client.UpdateReportItem
err error
}
resultChan := make(chan result, 1)
go func() {
u, e := scanFn()
resultChan <- result{u, e}
}()
select {
case <-ctx.Done():
return fmt.Errorf("%s scan timeout after %v", name, timeout)
case res := <-resultChan:
if res.err != nil {
return res.err
}
updates = res.updates
return nil
}
})
if err != nil {
scanErr = err
}
return updates, scanErr
}
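// Example usage, wiring one of the scanners constructed in runAgent (sketch;
// the subsystem Timeout field is assumed to be a time.Duration):
//   updates, err := subsystemScan("apt", aptCB, cfg.Subsystems.APT.Timeout, aptScanner.Scan)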
// handleScanCommand performs a local scan and displays results
func handleScanCommand(cfg *config.Config, exportFormat string) error {
// Initialize scanners
aptScanner := scanner.NewAPTScanner()
dnfScanner := scanner.NewDNFScanner()
dockerScanner, _ := scanner.NewDockerScanner()
windowsUpdateScanner := scanner.NewWindowsUpdateScanner()
wingetScanner := scanner.NewWingetScanner()
fmt.Println("🔍 Scanning for updates...")
var allUpdates []client.UpdateReportItem
// Scan APT updates
if aptScanner.IsAvailable() {
fmt.Println(" - Scanning APT packages...")
updates, err := aptScanner.Scan()
if err != nil {
fmt.Printf(" ⚠️ APT scan failed: %v\n", err)
} else {
fmt.Printf(" ✓ Found %d APT updates\n", len(updates))
allUpdates = append(allUpdates, updates...)
}
}
// Scan DNF updates
if dnfScanner.IsAvailable() {
fmt.Println(" - Scanning DNF packages...")
updates, err := dnfScanner.Scan()
if err != nil {
fmt.Printf(" ⚠️ DNF scan failed: %v\n", err)
} else {
fmt.Printf(" ✓ Found %d DNF updates\n", len(updates))
allUpdates = append(allUpdates, updates...)
}
}
// Scan Docker updates
if dockerScanner != nil && dockerScanner.IsAvailable() {
fmt.Println(" - Scanning Docker images...")
updates, err := dockerScanner.Scan()
if err != nil {
fmt.Printf(" ⚠️ Docker scan failed: %v\n", err)
} else {
fmt.Printf(" ✓ Found %d Docker image updates\n", len(updates))
allUpdates = append(allUpdates, updates...)
}
}
// Scan Windows updates
if windowsUpdateScanner.IsAvailable() {
fmt.Println(" - Scanning Windows updates...")
updates, err := windowsUpdateScanner.Scan()
if err != nil {
fmt.Printf(" ⚠️ Windows Update scan failed: %v\n", err)
} else {
fmt.Printf(" ✓ Found %d Windows updates\n", len(updates))
allUpdates = append(allUpdates, updates...)
}
}
// Scan Winget packages
if wingetScanner.IsAvailable() {
fmt.Println(" - Scanning Winget packages...")
updates, err := wingetScanner.Scan()
if err != nil {
fmt.Printf(" ⚠️ Winget scan failed: %v\n", err)
} else {
fmt.Printf(" ✓ Found %d Winget package updates\n", len(updates))
allUpdates = append(allUpdates, updates...)
}
}
// Load and update cache
localCache, err := cache.Load()
if err != nil {
fmt.Printf("⚠️ Warning: Failed to load cache: %v\n", err)
localCache = &cache.LocalCache{}
}
// Update cache with scan results
localCache.UpdateScanResults(allUpdates)
if cfg.IsRegistered() {
localCache.SetAgentInfo(cfg.AgentID, cfg.ServerURL)
localCache.SetAgentStatus("online")
}
// Save cache
if err := localCache.Save(); err != nil {
fmt.Printf("⚠️ Warning: Failed to save cache: %v\n", err)
}
// Display results
fmt.Println()
return display.PrintScanResults(allUpdates, exportFormat)
}
// handleStatusCommand displays agent status information
func handleStatusCommand(cfg *config.Config) error {
// Load cache
localCache, err := cache.Load()
if err != nil {
return fmt.Errorf("failed to load cache: %w", err)
}
// Determine status
agentStatus := "offline"
if cfg.IsRegistered() {
agentStatus = "online"
}
if localCache.AgentStatus != "" {
agentStatus = localCache.AgentStatus
}
// Use cached info if available, otherwise use config
agentID := cfg.AgentID.String()
if localCache.AgentID != (uuid.UUID{}) {
agentID = localCache.AgentID.String()
}
serverURL := cfg.ServerURL
if localCache.ServerURL != "" {
serverURL = localCache.ServerURL
}
// Display status
display.PrintAgentStatus(
agentID,
serverURL,
localCache.LastCheckIn,
localCache.LastScanTime,
localCache.UpdateCount,
agentStatus,
)
return nil
}
// handleListUpdatesCommand displays detailed update information
func handleListUpdatesCommand(cfg *config.Config, exportFormat string) error {
// Load cache
localCache, err := cache.Load()
if err != nil {
return fmt.Errorf("failed to load cache: %w", err)
}
// Check if we have cached scan results
if len(localCache.Updates) == 0 {
fmt.Println("📋 No cached scan results found.")
fmt.Println("💡 Run '--scan' first to discover available updates.")
return nil
}
// Warn if cache is old
if localCache.IsExpired(24 * time.Hour) {
fmt.Printf("⚠️ Scan results are %s old. Run '--scan' for latest results.\n\n",
formatTimeSince(localCache.LastScanTime))
}
// Display detailed results
return display.PrintDetailedUpdates(localCache.Updates, exportFormat)
}
// handleInstallUpdates handles install_updates command
func handleInstallUpdates(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error {
log.Println("Installing updates...")
// Parse parameters
packageType := ""
packageName := ""
if pt, ok := params["package_type"].(string); ok {
packageType = pt
}
if pn, ok := params["package_name"].(string); ok {
packageName = pn
}
// Validate package type
if packageType == "" {
return fmt.Errorf("package_type parameter is required")
}
// Create installer based on package type
inst, err := installer.InstallerFactory(packageType)
if err != nil {
return fmt.Errorf("failed to create installer for package type %s: %w", packageType, err)
}
// Check if installer is available
if !inst.IsAvailable() {
return fmt.Errorf("%s installer is not available on this system", packageType)
}
var result *installer.InstallResult
var action string
// Perform installation based on what's specified
if packageName != "" {
action = "update"
log.Printf("Updating package: %s (type: %s)", packageName, packageType)
result, err = inst.UpdatePackage(packageName)
} else if len(params) > 1 {
// Multiple packages might be specified in various ways
var packageNames []string
for key, value := range params {
if key != "package_type" {
if name, ok := value.(string); ok && name != "" {
packageNames = append(packageNames, name)
}
}
}
if len(packageNames) > 0 {
action = "install_multiple"
log.Printf("Installing multiple packages: %v (type: %s)", packageNames, packageType)
result, err = inst.InstallMultiple(packageNames)
} else {
// Upgrade all packages if no specific packages named
action = "upgrade"
log.Printf("Upgrading all packages (type: %s)", packageType)
result, err = inst.Upgrade()
}
} else {
// Upgrade all packages if no specific packages named
action = "upgrade"
log.Printf("Upgrading all packages (type: %s)", packageType)
result, err = inst.Upgrade()
}
if err != nil {
// Report installation failure with actual command output
// (result may be nil if the installer failed before running anything)
logReport := client.LogReport{
CommandID: commandID,
Action: action,
Result: "failed",
Stderr: err.Error(),
ExitCode: 1,
DurationSeconds: 0,
}
if result != nil {
logReport.Stdout = result.Stdout
logReport.Stderr = result.Stderr
logReport.ExitCode = result.ExitCode
logReport.DurationSeconds = result.DurationSeconds
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("Failed to report installation failure: %v\n", reportErr)
}
return fmt.Errorf("installation failed: %w", err)
}
// Report installation success
logReport := client.LogReport{
CommandID: commandID,
Action: result.Action,
Result: "success",
Stdout: result.Stdout,
Stderr: result.Stderr,
ExitCode: result.ExitCode,
DurationSeconds: result.DurationSeconds,
}
// Add additional metadata to the log report
if len(result.PackagesInstalled) > 0 {
logReport.Stdout += fmt.Sprintf("\nPackages installed: %v", result.PackagesInstalled)
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("Failed to report installation success: %v\n", reportErr)
}
if result.Success {
log.Printf("✓ Installation completed successfully in %d seconds\n", result.DurationSeconds)
if len(result.PackagesInstalled) > 0 {
log.Printf(" Packages installed: %v\n", result.PackagesInstalled)
}
} else {
log.Printf("✗ Installation failed after %d seconds\n", result.DurationSeconds)
log.Printf(" Error: %s\n", result.ErrorMessage)
}
return nil
}
// handleDryRunUpdate handles dry_run_update command
func handleDryRunUpdate(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error {
log.Println("Performing dry run update...")
// Parse parameters
packageType := ""
packageName := ""
if pt, ok := params["package_type"].(string); ok {
packageType = pt
}
if pn, ok := params["package_name"].(string); ok {
packageName = pn
}
// Validate parameters
if packageType == "" || packageName == "" {
return fmt.Errorf("package_type and package_name parameters are required")
}
// Create installer based on package type
inst, err := installer.InstallerFactory(packageType)
if err != nil {
return fmt.Errorf("failed to create installer for package type %s: %w", packageType, err)
}
// Check if installer is available
if !inst.IsAvailable() {
return fmt.Errorf("%s installer is not available on this system", packageType)
}
// Perform dry run
log.Printf("Dry running package: %s (type: %s)", packageName, packageType)
result, err := inst.DryRun(packageName)
if err != nil {
// Report dry run failure
logReport := client.LogReport{
CommandID: commandID,
Action: "dry_run",
Result: "failed",
Stdout: "",
Stderr: fmt.Sprintf("Dry run error: %v", err),
ExitCode: 1,
DurationSeconds: 0,
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("Failed to report dry run failure: %v\n", reportErr)
}
return fmt.Errorf("dry run failed: %w", err)
}
// Convert installer.InstallResult to client.InstallResult for reporting
clientResult := &client.InstallResult{
Success: result.Success,
ErrorMessage: result.ErrorMessage,
Stdout: result.Stdout,
Stderr: result.Stderr,
ExitCode: result.ExitCode,
DurationSeconds: result.DurationSeconds,
Action: result.Action,
PackagesInstalled: result.PackagesInstalled,
ContainersUpdated: result.ContainersUpdated,
Dependencies: result.Dependencies,
IsDryRun: true,
}
// Report dependencies back to server
// (comma-ok assertion: a missing or non-string update_id must not panic the agent)
updateID, _ := params["update_id"].(string)
depReport := client.DependencyReport{
PackageName: packageName,
PackageType: packageType,
Dependencies: result.Dependencies,
UpdateID: updateID,
DryRunResult: clientResult,
}
if reportErr := apiClient.ReportDependencies(cfg.AgentID, depReport); reportErr != nil {
log.Printf("Failed to report dependencies: %v\n", reportErr)
return fmt.Errorf("failed to report dependencies: %w", reportErr)
}
// Report dry run success
logReport := client.LogReport{
CommandID: commandID,
Action: "dry_run",
Result: "success",
Stdout: result.Stdout,
Stderr: result.Stderr,
ExitCode: result.ExitCode,
DurationSeconds: result.DurationSeconds,
}
if len(result.Dependencies) > 0 {
logReport.Stdout += fmt.Sprintf("\nDependencies found: %v", result.Dependencies)
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("Failed to report dry run success: %v\n", reportErr)
}
if result.Success {
log.Printf("✓ Dry run completed successfully in %d seconds\n", result.DurationSeconds)
if len(result.Dependencies) > 0 {
log.Printf(" Dependencies found: %v\n", result.Dependencies)
} else {
log.Printf(" No additional dependencies found\n")
}
} else {
log.Printf("✗ Dry run failed after %d seconds\n", result.DurationSeconds)
log.Printf(" Error: %s\n", result.ErrorMessage)
}
return nil
}
// handleConfirmDependencies handles confirm_dependencies command
func handleConfirmDependencies(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error {
log.Println("Installing update with confirmed dependencies...")
// Parse parameters
packageType := ""
packageName := ""
var dependencies []string
if pt, ok := params["package_type"].(string); ok {
packageType = pt
}
if pn, ok := params["package_name"].(string); ok {
packageName = pn
}
if deps, ok := params["dependencies"].([]interface{}); ok {
for _, dep := range deps {
if depStr, ok := dep.(string); ok {
dependencies = append(dependencies, depStr)
}
}
}
// Validate parameters
if packageType == "" || packageName == "" {
return fmt.Errorf("package_type and package_name parameters are required")
}
// Create installer based on package type
inst, err := installer.InstallerFactory(packageType)
if err != nil {
return fmt.Errorf("failed to create installer for package type %s: %w", packageType, err)
}
// Check if installer is available
if !inst.IsAvailable() {
return fmt.Errorf("%s installer is not available on this system", packageType)
}
var result *installer.InstallResult
var action string
// Perform installation with dependencies
if len(dependencies) > 0 {
action = "install_with_dependencies"
log.Printf("Installing package with dependencies: %s (dependencies: %v)", packageName, dependencies)
// Install main package + dependencies
allPackages := append([]string{packageName}, dependencies...)
result, err = inst.InstallMultiple(allPackages)
} else {
action = "upgrade"
log.Printf("Installing package: %s (no dependencies)", packageName)
// Use UpdatePackage instead of Install to handle existing packages
result, err = inst.UpdatePackage(packageName)
}
if err != nil {
// Report installation failure with actual command output
// (result may be nil if the installer failed before running anything)
logReport := client.LogReport{
CommandID: commandID,
Action: action,
Result: "failed",
Stderr: err.Error(),
ExitCode: 1,
DurationSeconds: 0,
}
if result != nil {
logReport.Stdout = result.Stdout
logReport.Stderr = result.Stderr
logReport.ExitCode = result.ExitCode
logReport.DurationSeconds = result.DurationSeconds
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("Failed to report installation failure: %v\n", reportErr)
}
return fmt.Errorf("installation failed: %w", err)
}
// Report installation success
logReport := client.LogReport{
CommandID: commandID,
Action: result.Action,
Result: "success",
Stdout: result.Stdout,
Stderr: result.Stderr,
ExitCode: result.ExitCode,
DurationSeconds: result.DurationSeconds,
}
// Add additional metadata to the log report
if len(result.PackagesInstalled) > 0 {
logReport.Stdout += fmt.Sprintf("\nPackages installed: %v", result.PackagesInstalled)
}
if len(dependencies) > 0 {
logReport.Stdout += fmt.Sprintf("\nDependencies included: %v", dependencies)
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("Failed to report installation success: %v\n", reportErr)
}
if result.Success {
log.Printf("✓ Installation with dependencies completed successfully in %d seconds\n", result.DurationSeconds)
if len(result.PackagesInstalled) > 0 {
log.Printf(" Packages installed: %v\n", result.PackagesInstalled)
}
} else {
log.Printf("✗ Installation with dependencies failed after %d seconds\n", result.DurationSeconds)
log.Printf(" Error: %s\n", result.ErrorMessage)
}
return nil
}
// handleEnableHeartbeat handles enable_heartbeat command
func handleEnableHeartbeat(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error {
// Parse duration parameter (default to 10 minutes)
durationMinutes := 10
if duration, ok := params["duration_minutes"]; ok {
if durationFloat, ok := duration.(float64); ok {
durationMinutes = int(durationFloat)
}
}
// Calculate when heartbeat should expire
expiryTime := time.Now().Add(time.Duration(durationMinutes) * time.Minute)
log.Printf("[Heartbeat] Enabling rapid polling for %d minutes (expires: %s)", durationMinutes, expiryTime.Format(time.RFC3339))
// Update agent config to enable rapid polling
cfg.RapidPollingEnabled = true
cfg.RapidPollingUntil = expiryTime
// Save config to persist heartbeat settings
if err := cfg.Save(getConfigPath()); err != nil {
log.Printf("[Heartbeat] Warning: Failed to save config: %v", err)
}
// Create log report for heartbeat enable
logReport := client.LogReport{
CommandID: commandID,
Action: "enable_heartbeat",
Result: "success",
Stdout: fmt.Sprintf("Heartbeat enabled for %d minutes", durationMinutes),
Stderr: "",
ExitCode: 0,
DurationSeconds: 0,
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("[Heartbeat] Failed to report heartbeat enable: %v", reportErr)
}
// Send immediate check-in to update heartbeat status in UI
log.Printf("[Heartbeat] Sending immediate check-in to update status")
sysMetrics, err := system.GetLightweightMetrics()
if err == nil {
metrics := &client.SystemMetrics{
CPUPercent: sysMetrics.CPUPercent,
MemoryPercent: sysMetrics.MemoryPercent,
MemoryUsedGB: sysMetrics.MemoryUsedGB,
MemoryTotalGB: sysMetrics.MemoryTotalGB,
DiskUsedGB: sysMetrics.DiskUsedGB,
DiskTotalGB: sysMetrics.DiskTotalGB,
DiskPercent: sysMetrics.DiskPercent,
Uptime: sysMetrics.Uptime,
Version: AgentVersion,
}
// Include heartbeat metadata to show enabled state
metrics.Metadata = map[string]interface{}{
"rapid_polling_enabled": true,
"rapid_polling_until": expiryTime.Format(time.RFC3339),
}
// Send immediate check-in with updated heartbeat status
_, checkinErr := apiClient.GetCommands(cfg.AgentID, metrics)
if checkinErr != nil {
log.Printf("[Heartbeat] Failed to send immediate check-in: %v", checkinErr)
} else {
log.Printf("[Heartbeat] Immediate check-in sent successfully")
}
} else {
log.Printf("[Heartbeat] Failed to get system metrics for immediate check-in: %v", err)
}
log.Printf("[Heartbeat] Rapid polling enabled successfully")
return nil
}
// handleDisableHeartbeat handles disable_heartbeat command
func handleDisableHeartbeat(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string) error {
log.Printf("[Heartbeat] Disabling rapid polling")
// Update agent config to disable rapid polling
cfg.RapidPollingEnabled = false
cfg.RapidPollingUntil = time.Time{} // Zero value
// Save config to persist heartbeat settings
if err := cfg.Save(getConfigPath()); err != nil {
log.Printf("[Heartbeat] Warning: Failed to save config: %v", err)
}
// Create log report for heartbeat disable
logReport := client.LogReport{
CommandID: commandID,
Action: "disable_heartbeat",
Result: "success",
Stdout: "Heartbeat disabled",
Stderr: "",
ExitCode: 0,
DurationSeconds: 0,
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("[Heartbeat] Failed to report heartbeat disable: %v", reportErr)
}
// Send immediate check-in to update heartbeat status in UI
log.Printf("[Heartbeat] Sending immediate check-in to update status")
sysMetrics, err := system.GetLightweightMetrics()
if err == nil {
metrics := &client.SystemMetrics{
CPUPercent: sysMetrics.CPUPercent,
MemoryPercent: sysMetrics.MemoryPercent,
MemoryUsedGB: sysMetrics.MemoryUsedGB,
MemoryTotalGB: sysMetrics.MemoryTotalGB,
DiskUsedGB: sysMetrics.DiskUsedGB,
DiskTotalGB: sysMetrics.DiskTotalGB,
DiskPercent: sysMetrics.DiskPercent,
Uptime: sysMetrics.Uptime,
Version: AgentVersion,
}
// Include empty heartbeat metadata to explicitly show disabled state
metrics.Metadata = map[string]interface{}{
"rapid_polling_enabled": false,
"rapid_polling_until": "",
}
// Send immediate check-in with updated heartbeat status
_, checkinErr := apiClient.GetCommands(cfg.AgentID, metrics)
if checkinErr != nil {
log.Printf("[Heartbeat] Failed to send immediate check-in: %v", checkinErr)
} else {
log.Printf("[Heartbeat] Immediate check-in sent successfully")
}
} else {
log.Printf("[Heartbeat] Failed to get system metrics for immediate check-in: %v", err)
}
log.Printf("[Heartbeat] Rapid polling disabled successfully")
return nil
}
// reportSystemInfo collects and reports detailed system information to the server
func reportSystemInfo(apiClient *client.Client, cfg *config.Config) error {
// Collect detailed system information
sysInfo, err := system.GetSystemInfo(AgentVersion)
if err != nil {
return fmt.Errorf("failed to get system info: %w", err)
}
// Create system info report
report := client.SystemInfoReport{
Timestamp: time.Now(),
CPUModel: sysInfo.CPUInfo.ModelName,
CPUCores: sysInfo.CPUInfo.Cores,
CPUThreads: sysInfo.CPUInfo.Threads,
MemoryTotal: sysInfo.MemoryInfo.Total,
DiskTotal: uint64(0),
DiskUsed: uint64(0),
IPAddress: sysInfo.IPAddress,
Processes: sysInfo.RunningProcesses,
Uptime: sysInfo.Uptime,
Metadata: make(map[string]interface{}),
}
// Add primary disk info
if len(sysInfo.DiskInfo) > 0 {
primaryDisk := sysInfo.DiskInfo[0]
report.DiskTotal = primaryDisk.Total
report.DiskUsed = primaryDisk.Used
report.Metadata["disk_mount"] = primaryDisk.Mountpoint
report.Metadata["disk_filesystem"] = primaryDisk.Filesystem
}
// Add collection timestamp and additional metadata
report.Metadata["collected_at"] = time.Now().Format(time.RFC3339)
report.Metadata["hostname"] = sysInfo.Hostname
report.Metadata["os_type"] = sysInfo.OSType
report.Metadata["os_version"] = sysInfo.OSVersion
report.Metadata["os_architecture"] = sysInfo.OSArchitecture
// Add any existing metadata from system info
for key, value := range sysInfo.Metadata {
report.Metadata[key] = value
}
// Report to server
if err := apiClient.ReportSystemInfo(cfg.AgentID, report); err != nil {
return fmt.Errorf("failed to report system info: %w", err)
}
return nil
}
// handleReboot handles reboot command
func handleReboot(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error {
log.Println("[Reboot] Processing reboot request...")
// Parse parameters
delayMinutes := 1 // Default to 1 minute
message := "System reboot requested by RedFlag"
if delay, ok := params["delay_minutes"]; ok {
if delayFloat, ok := delay.(float64); ok {
delayMinutes = int(delayFloat)
}
}
if msg, ok := params["message"].(string); ok && msg != "" {
message = msg
}
log.Printf("[Reboot] Scheduling system reboot in %d minute(s): %s", delayMinutes, message)
var cmd *exec.Cmd
// Execute platform-specific reboot command
if runtime.GOOS == "linux" {
// Linux: shutdown -r +MINUTES "message"
cmd = exec.Command("shutdown", "-r", fmt.Sprintf("+%d", delayMinutes), message)
} else if runtime.GOOS == "windows" {
// Windows: shutdown /r /t SECONDS /c "message"
delaySeconds := delayMinutes * 60
cmd = exec.Command("shutdown", "/r", "/t", fmt.Sprintf("%d", delaySeconds), "/c", message)
} else {
err := fmt.Errorf("reboot not supported on platform: %s", runtime.GOOS)
log.Printf("[Reboot] Error: %v", err)
// Report failure
logReport := client.LogReport{
CommandID: commandID,
Action: "reboot",
Result: "failed",
Stdout: "",
Stderr: err.Error(),
ExitCode: 1,
DurationSeconds: 0,
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("[Reboot] Failed to report unsupported-platform result: %v", reportErr)
}
return err
}
// Execute reboot command
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("[Reboot] Failed to schedule reboot: %v", err)
log.Printf("[Reboot] Output: %s", string(output))
// Report failure
logReport := client.LogReport{
CommandID: commandID,
Action: "reboot",
Result: "failed",
Stdout: string(output),
Stderr: err.Error(),
ExitCode: 1,
DurationSeconds: 0,
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("[Reboot] Failed to report reboot scheduling failure: %v", reportErr)
}
return err
}
log.Printf("[Reboot] System reboot scheduled successfully")
log.Printf("[Reboot] The system will reboot in %d minute(s)", delayMinutes)
// Report success
logReport := client.LogReport{
CommandID: commandID,
Action: "reboot",
Result: "success",
Stdout: fmt.Sprintf("System reboot scheduled for %d minute(s) from now. Message: %s", delayMinutes, message),
Stderr: "",
ExitCode: 0,
DurationSeconds: 0,
}
if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil {
log.Printf("[Reboot] Failed to report reboot command result: %v", reportErr)
}
return nil
}
// formatTimeSince formats a duration as "X time ago"
func formatTimeSince(t time.Time) string {
duration := time.Since(t)
if duration < time.Minute {
return fmt.Sprintf("%d seconds ago", int(duration.Seconds()))
} else if duration < time.Hour {
return fmt.Sprintf("%d minutes ago", int(duration.Minutes()))
} else if duration < 24*time.Hour {
return fmt.Sprintf("%d hours ago", int(duration.Hours()))
} else {
return fmt.Sprintf("%d days ago", int(duration.Hours()/24))
}
}