diff --git a/README.md b/README.md index 7e78b65..491dc60 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ **Self-hosted update management for homelabs** Cross-platform agents • Web dashboard • Single binary deployment • No enterprise BS +No MacOS yet - need real hardware, not hackintosh hopes and prayers ``` v0.1.18 - Alpha Release @@ -50,13 +51,13 @@ RedFlag lets you manage software updates across all your servers from one dashbo |------------------|---------------------|---------------| | ![Heartbeat](Screenshots/RedFlag%20Heartbeat%20System.png) | ![Tokens](Screenshots/RedFlag%20Registration%20Tokens.jpg) | ![Settings](Screenshots/RedFlag%20Settings%20Page.jpg) | -| Linux Update History | Windows Agent Details | Agent List | +| Linux Update Details | Linux Health Details | Agent List | |---------------------|----------------------|------------| -| ![Linux History](Screenshots/RedFlag%20Linux%20Agent%20History%20Extended.png) | ![Windows Agent](Screenshots/RedFlag%20Windows%20Agent%20Details.png) | ![Agent List](Screenshots/RedFlag%20Agent%20List.png) | +| ![Update Details](Screenshots/RedFlag%20Linux%20Agent%20Update%20Details.png) | ![Health Details](Screenshots/RedFlag%20Linux%20Agent%20Health%20Details.png) | ![Agent List](Screenshots/RedFlag%20Agent%20List.png) | -| Windows Update History | -|------------------------| -| ![Windows History](Screenshots/RedFlag%20Windows%20Agent%20History%20Extended.png) | +| Linux Update History | Windows Agent Details | Windows Update History | +|---------------------|----------------------|------------------------| +| ![Linux History](Screenshots/RedFlag%20Linux%20Agent%20History%20Extended.png) | ![Windows Agent](Screenshots/RedFlag%20Windows%20Agent%20Details.png) | ![Windows History](Screenshots/RedFlag%20Windows%20Agent%20History%20Extended.png) | diff --git a/Screenshots/AgentMgmt.jpg b/Screenshots/AgentMgmt.jpg deleted file mode 100644 index 64285e0..0000000 Binary files a/Screenshots/AgentMgmt.jpg and /dev/null differ diff --git a/Screenshots/AgentMgmt.png b/Screenshots/AgentMgmt.png new file mode 100644 index 0000000..ff9660d Binary files /dev/null and b/Screenshots/AgentMgmt.png differ diff --git a/Screenshots/RedFlag Linux Agent Details.png b/Screenshots/RedFlag Linux Agent Details.png index 8050b12..f9cf0c0 100644 Binary files a/Screenshots/RedFlag Linux Agent Details.png and b/Screenshots/RedFlag Linux Agent Details.png differ diff --git a/Screenshots/RedFlag Linux Agent Health Details.png b/Screenshots/RedFlag Linux Agent Health Details.png new file mode 100644 index 0000000..f48c5f2 Binary files /dev/null and b/Screenshots/RedFlag Linux Agent Health Details.png differ diff --git a/Screenshots/RedFlag Linux Agent Update Details.png b/Screenshots/RedFlag Linux Agent Update Details.png new file mode 100644 index 0000000..f8f4235 Binary files /dev/null and b/Screenshots/RedFlag Linux Agent Update Details.png differ diff --git a/aggregator-agent/agent b/aggregator-agent/agent new file mode 100755 index 0000000..df69cce Binary files /dev/null and b/aggregator-agent/agent differ diff --git a/aggregator-agent/agent-test b/aggregator-agent/agent-test new file mode 100755 index 0000000..c1717d6 Binary files /dev/null and b/aggregator-agent/agent-test differ diff --git a/aggregator-agent/cmd/agent/main.go b/aggregator-agent/cmd/agent/main.go index 88a1431..6886de4 100644 --- a/aggregator-agent/cmd/agent/main.go +++ b/aggregator-agent/cmd/agent/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -11,7 +12,9 @@ import ( "strings" "time" + "github.com/Fimeg/RedFlag/aggregator-agent/internal/acknowledgment" "github.com/Fimeg/RedFlag/aggregator-agent/internal/cache" + "github.com/Fimeg/RedFlag/aggregator-agent/internal/circuitbreaker" "github.com/Fimeg/RedFlag/aggregator-agent/internal/client" "github.com/Fimeg/RedFlag/aggregator-agent/internal/config" "github.com/Fimeg/RedFlag/aggregator-agent/internal/display" @@ -23,7 +26,7 @@ import ( ) const ( - AgentVersion = "0.1.18" // Enhanced disk detection with comprehensive partition reporting + AgentVersion = "0.1.19" // Phase 0: Circuit breakers, timeouts, and subsystem resilience ) // getConfigPath returns the platform-specific config path @@ -34,6 +37,34 @@ func getConfigPath() string { return "/etc/aggregator/config.json" } +// getStatePath returns the platform-specific state directory path +func getStatePath() string { + if runtime.GOOS == "windows" { + return "C:\\ProgramData\\RedFlag\\state" + } + return "/var/lib/aggregator" +} + +// reportLogWithAck reports a command log to the server and tracks it for acknowledgment +func reportLogWithAck(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, logReport client.LogReport) error { + // Track this command result as pending acknowledgment + ackTracker.Add(logReport.CommandID) + + // Save acknowledgment state immediately + if err := ackTracker.Save(); err != nil { + log.Printf("Warning: Failed to save acknowledgment for command %s: %v", logReport.CommandID, err) + } + + // Report the log to the server + if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil { + // If reporting failed, increment retry count but don't remove from pending + ackTracker.IncrementRetry(logReport.CommandID) + return err + } + + return nil +} + // getCurrentPollingInterval returns the appropriate polling interval based on rapid mode func getCurrentPollingInterval(cfg *config.Config) int { // Check if rapid polling mode is active and not expired @@ -403,6 +434,64 @@ func runAgent(cfg *config.Config) error { windowsUpdateScanner := scanner.NewWindowsUpdateScanner() wingetScanner := scanner.NewWingetScanner() + // Initialize circuit breakers for each subsystem + aptCB := circuitbreaker.New("APT", circuitbreaker.Config{ + FailureThreshold: cfg.Subsystems.APT.CircuitBreaker.FailureThreshold, + FailureWindow: cfg.Subsystems.APT.CircuitBreaker.FailureWindow, + OpenDuration: cfg.Subsystems.APT.CircuitBreaker.OpenDuration, + HalfOpenAttempts: cfg.Subsystems.APT.CircuitBreaker.HalfOpenAttempts, + }) + dnfCB := circuitbreaker.New("DNF", circuitbreaker.Config{ + FailureThreshold: cfg.Subsystems.DNF.CircuitBreaker.FailureThreshold, + FailureWindow: cfg.Subsystems.DNF.CircuitBreaker.FailureWindow, + OpenDuration: cfg.Subsystems.DNF.CircuitBreaker.OpenDuration, + HalfOpenAttempts: cfg.Subsystems.DNF.CircuitBreaker.HalfOpenAttempts, + }) + dockerCB := circuitbreaker.New("Docker", circuitbreaker.Config{ + FailureThreshold: cfg.Subsystems.Docker.CircuitBreaker.FailureThreshold, + FailureWindow: cfg.Subsystems.Docker.CircuitBreaker.FailureWindow, + OpenDuration: cfg.Subsystems.Docker.CircuitBreaker.OpenDuration, + HalfOpenAttempts: cfg.Subsystems.Docker.CircuitBreaker.HalfOpenAttempts, + }) + windowsCB := circuitbreaker.New("Windows Update", circuitbreaker.Config{ + FailureThreshold: cfg.Subsystems.Windows.CircuitBreaker.FailureThreshold, + FailureWindow: cfg.Subsystems.Windows.CircuitBreaker.FailureWindow, + OpenDuration: cfg.Subsystems.Windows.CircuitBreaker.OpenDuration, + HalfOpenAttempts: cfg.Subsystems.Windows.CircuitBreaker.HalfOpenAttempts, + }) + wingetCB := circuitbreaker.New("Winget", circuitbreaker.Config{ + FailureThreshold: cfg.Subsystems.Winget.CircuitBreaker.FailureThreshold, + FailureWindow: cfg.Subsystems.Winget.CircuitBreaker.FailureWindow, + OpenDuration: cfg.Subsystems.Winget.CircuitBreaker.OpenDuration, + HalfOpenAttempts: cfg.Subsystems.Winget.CircuitBreaker.HalfOpenAttempts, + }) + + // Initialize acknowledgment tracker for command result reliability + ackTracker := acknowledgment.NewTracker(getStatePath()) + if err := ackTracker.Load(); err != nil { + log.Printf("Warning: Failed to load pending acknowledgments: %v", err) + } else { + pendingCount := len(ackTracker.GetPending()) + if pendingCount > 0 { + log.Printf("Loaded %d pending command acknowledgments from previous session", pendingCount) + } + } + + // Periodic cleanup of old/stale acknowledgments + go func() { + cleanupTicker := time.NewTicker(1 * time.Hour) + defer cleanupTicker.Stop() + for range cleanupTicker.C { + removed := ackTracker.Cleanup() + if removed > 0 { + log.Printf("Cleaned up %d stale acknowledgments", removed) + if err := ackTracker.Save(); err != nil { + log.Printf("Warning: Failed to save acknowledgments after cleanup: %v", err) + } + } + } + }() + // System info tracking var lastSystemInfoUpdate time.Time const systemInfoUpdateInterval = 1 * time.Hour // Update detailed system info every hour @@ -461,8 +550,16 @@ func runAgent(cfg *config.Config) error { } } + // Add pending acknowledgments to metrics for reliability + if metrics != nil { + pendingAcks := ackTracker.GetPending() + if len(pendingAcks) > 0 { + metrics.PendingAcknowledgments = pendingAcks + } + } + // Get commands from server (with optional metrics) - commands, err := apiClient.GetCommands(cfg.AgentID, metrics) + response, err := apiClient.GetCommands(cfg.AgentID, metrics) if err != nil { // Try to renew token if we got a 401 error newClient, renewErr := renewTokenIfNeeded(apiClient, cfg, err) @@ -476,7 +573,7 @@ func runAgent(cfg *config.Config) error { if newClient != apiClient { log.Printf("🔄 Retrying check-in with renewed token...") apiClient = newClient - commands, err = apiClient.GetCommands(cfg.AgentID, metrics) + response, err = apiClient.GetCommands(cfg.AgentID, metrics) if err != nil { log.Printf("Check-in unsuccessful even after token renewal: %v\n", err) time.Sleep(time.Duration(getCurrentPollingInterval(cfg)) * time.Second) @@ -489,6 +586,18 @@ func runAgent(cfg *config.Config) error { } } + // Process acknowledged command results + if response != nil && len(response.AcknowledgedIDs) > 0 { + ackTracker.Acknowledge(response.AcknowledgedIDs) + log.Printf("Server acknowledged %d command result(s)", len(response.AcknowledgedIDs)) + + // Save acknowledgment state + if err := ackTracker.Save(); err != nil { + log.Printf("Warning: Failed to save acknowledgment state: %v", err) + } + } + + commands := response.Commands if len(commands) == 0 { log.Printf("Check-in successful - no new commands") } else { @@ -501,7 +610,7 @@ func runAgent(cfg *config.Config) error { switch cmd.Type { case "scan_updates": - if err := handleScanUpdates(apiClient, cfg, aptScanner, dnfScanner, dockerScanner, windowsUpdateScanner, wingetScanner, cmd.ID); err != nil { + if err := handleScanUpdates(apiClient, cfg, ackTracker, aptScanner, dnfScanner, dockerScanner, windowsUpdateScanner, wingetScanner, aptCB, dnfCB, dockerCB, windowsCB, wingetCB, cmd.ID); err != nil { log.Printf("Error scanning updates: %v\n", err) } @@ -509,33 +618,33 @@ func runAgent(cfg *config.Config) error { log.Println("Spec collection not yet implemented") case "dry_run_update": - if err := handleDryRunUpdate(apiClient, cfg, cmd.ID, cmd.Params); err != nil { + if err := handleDryRunUpdate(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil { log.Printf("Error dry running update: %v\n", err) } case "install_updates": - if err := handleInstallUpdates(apiClient, cfg, cmd.ID, cmd.Params); err != nil { + if err := handleInstallUpdates(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil { log.Printf("Error installing updates: %v\n", err) } case "confirm_dependencies": - if err := handleConfirmDependencies(apiClient, cfg, cmd.ID, cmd.Params); err != nil { + if err := handleConfirmDependencies(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil { log.Printf("Error confirming dependencies: %v\n", err) } case "enable_heartbeat": - if err := handleEnableHeartbeat(apiClient, cfg, cmd.ID, cmd.Params); err != nil { + if err := handleEnableHeartbeat(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil { log.Printf("[Heartbeat] Error enabling heartbeat: %v\n", err) } case "disable_heartbeat": - if err := handleDisableHeartbeat(apiClient, cfg, cmd.ID); err != nil { + if err := handleDisableHeartbeat(apiClient, cfg, ackTracker, cmd.ID); err != nil { log.Printf("[Heartbeat] Error disabling heartbeat: %v\n", err) } case "reboot": - if err := handleReboot(apiClient, cfg, cmd.ID, cmd.Params); err != nil { + if err := handleReboot(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil { log.Printf("[Reboot] Error processing reboot command: %v\n", err) } default: @@ -548,7 +657,46 @@ func runAgent(cfg *config.Config) error { } } -func handleScanUpdates(apiClient *client.Client, cfg *config.Config, aptScanner *scanner.APTScanner, dnfScanner *scanner.DNFScanner, dockerScanner *scanner.DockerScanner, windowsUpdateScanner *scanner.WindowsUpdateScanner, wingetScanner *scanner.WingetScanner, commandID string) error { +// subsystemScan executes a scanner function with circuit breaker and timeout protection +func subsystemScan(name string, cb *circuitbreaker.CircuitBreaker, timeout time.Duration, scanFn func() ([]client.UpdateReportItem, error)) ([]client.UpdateReportItem, error) { + var updates []client.UpdateReportItem + var scanErr error + + err := cb.Call(func() error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + type result struct { + updates []client.UpdateReportItem + err error + } + resultChan := make(chan result, 1) + + go func() { + u, e := scanFn() + resultChan <- result{u, e} + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("%s scan timeout after %v", name, timeout) + case res := <-resultChan: + if res.err != nil { + return res.err + } + updates = res.updates + return nil + } + }) + + if err != nil { + scanErr = err + } + + return updates, scanErr +} + +func handleScanUpdates(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, aptScanner *scanner.APTScanner, dnfScanner *scanner.DNFScanner, dockerScanner *scanner.DockerScanner, windowsUpdateScanner *scanner.WindowsUpdateScanner, wingetScanner *scanner.WingetScanner, aptCB, dnfCB, dockerCB, windowsCB, wingetCB *circuitbreaker.CircuitBreaker, commandID string) error { log.Println("Scanning for updates...") var allUpdates []client.UpdateReportItem @@ -556,9 +704,9 @@ func handleScanUpdates(apiClient *client.Client, cfg *config.Config, aptScanner var scanResults []string // Scan APT updates - if aptScanner.IsAvailable() { + if aptScanner.IsAvailable() && cfg.Subsystems.APT.Enabled { log.Println(" - Scanning APT packages...") - updates, err := aptScanner.Scan() + updates, err := subsystemScan("APT", aptCB, cfg.Subsystems.APT.Timeout, aptScanner.Scan) if err != nil { errorMsg := fmt.Sprintf("APT scan failed: %v", err) log.Printf(" %s\n", errorMsg) @@ -569,14 +717,16 @@ func handleScanUpdates(apiClient *client.Client, cfg *config.Config, aptScanner scanResults = append(scanResults, resultMsg) allUpdates = append(allUpdates, updates...) } + } else if !cfg.Subsystems.APT.Enabled { + scanResults = append(scanResults, "APT scanner disabled") } else { scanResults = append(scanResults, "APT scanner not available") } // Scan DNF updates - if dnfScanner.IsAvailable() { + if dnfScanner.IsAvailable() && cfg.Subsystems.DNF.Enabled { log.Println(" - Scanning DNF packages...") - updates, err := dnfScanner.Scan() + updates, err := subsystemScan("DNF", dnfCB, cfg.Subsystems.DNF.Timeout, dnfScanner.Scan) if err != nil { errorMsg := fmt.Sprintf("DNF scan failed: %v", err) log.Printf(" %s\n", errorMsg) @@ -587,14 +737,16 @@ func handleScanUpdates(apiClient *client.Client, cfg *config.Config, aptScanner scanResults = append(scanResults, resultMsg) allUpdates = append(allUpdates, updates...) } + } else if !cfg.Subsystems.DNF.Enabled { + scanResults = append(scanResults, "DNF scanner disabled") } else { scanResults = append(scanResults, "DNF scanner not available") } // Scan Docker updates - if dockerScanner != nil && dockerScanner.IsAvailable() { + if dockerScanner != nil && dockerScanner.IsAvailable() && cfg.Subsystems.Docker.Enabled { log.Println(" - Scanning Docker images...") - updates, err := dockerScanner.Scan() + updates, err := subsystemScan("Docker", dockerCB, cfg.Subsystems.Docker.Timeout, dockerScanner.Scan) if err != nil { errorMsg := fmt.Sprintf("Docker scan failed: %v", err) log.Printf(" %s\n", errorMsg) @@ -605,14 +757,16 @@ func handleScanUpdates(apiClient *client.Client, cfg *config.Config, aptScanner scanResults = append(scanResults, resultMsg) allUpdates = append(allUpdates, updates...) } + } else if !cfg.Subsystems.Docker.Enabled { + scanResults = append(scanResults, "Docker scanner disabled") } else { scanResults = append(scanResults, "Docker scanner not available") } // Scan Windows updates - if windowsUpdateScanner.IsAvailable() { + if windowsUpdateScanner.IsAvailable() && cfg.Subsystems.Windows.Enabled { log.Println(" - Scanning Windows updates...") - updates, err := windowsUpdateScanner.Scan() + updates, err := subsystemScan("Windows Update", windowsCB, cfg.Subsystems.Windows.Timeout, windowsUpdateScanner.Scan) if err != nil { errorMsg := fmt.Sprintf("Windows Update scan failed: %v", err) log.Printf(" %s\n", errorMsg) @@ -623,14 +777,16 @@ func handleScanUpdates(apiClient *client.Client, cfg *config.Config, aptScanner scanResults = append(scanResults, resultMsg) allUpdates = append(allUpdates, updates...) } + } else if !cfg.Subsystems.Windows.Enabled { + scanResults = append(scanResults, "Windows Update scanner disabled") } else { scanResults = append(scanResults, "Windows Update scanner not available") } // Scan Winget packages - if wingetScanner.IsAvailable() { + if wingetScanner.IsAvailable() && cfg.Subsystems.Winget.Enabled { log.Println(" - Scanning Winget packages...") - updates, err := wingetScanner.Scan() + updates, err := subsystemScan("Winget", wingetCB, cfg.Subsystems.Winget.Timeout, wingetScanner.Scan) if err != nil { errorMsg := fmt.Sprintf("Winget scan failed: %v", err) log.Printf(" %s\n", errorMsg) @@ -641,6 +797,8 @@ func handleScanUpdates(apiClient *client.Client, cfg *config.Config, aptScanner scanResults = append(scanResults, resultMsg) allUpdates = append(allUpdates, updates...) } + } else if !cfg.Subsystems.Winget.Enabled { + scanResults = append(scanResults, "Winget scanner disabled") } else { scanResults = append(scanResults, "Winget scanner not available") } @@ -678,7 +836,7 @@ func handleScanUpdates(apiClient *client.Client, cfg *config.Config, aptScanner } // Report the scan log - if err := apiClient.ReportLog(cfg.AgentID, logReport); err != nil { + if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil { log.Printf("Failed to report scan log: %v\n", err) // Continue anyway - updates are more important } @@ -871,7 +1029,7 @@ func handleListUpdatesCommand(cfg *config.Config, exportFormat string) error { } // handleInstallUpdates handles install_updates command -func handleInstallUpdates(apiClient *client.Client, cfg *config.Config, commandID string, params map[string]interface{}) error { +func handleInstallUpdates(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error { log.Println("Installing updates...") // Parse parameters @@ -948,7 +1106,7 @@ func handleInstallUpdates(apiClient *client.Client, cfg *config.Config, commandI DurationSeconds: result.DurationSeconds, } - if reportErr := apiClient.ReportLog(cfg.AgentID, logReport); reportErr != nil { + if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil { log.Printf("Failed to report installation failure: %v\n", reportErr) } @@ -971,7 +1129,7 @@ func handleInstallUpdates(apiClient *client.Client, cfg *config.Config, commandI logReport.Stdout += fmt.Sprintf("\nPackages installed: %v", result.PackagesInstalled) } - if reportErr := apiClient.ReportLog(cfg.AgentID, logReport); reportErr != nil { + if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil { log.Printf("Failed to report installation success: %v\n", reportErr) } @@ -989,7 +1147,7 @@ func handleInstallUpdates(apiClient *client.Client, cfg *config.Config, commandI } // handleDryRunUpdate handles dry_run_update command -func handleDryRunUpdate(apiClient *client.Client, cfg *config.Config, commandID string, params map[string]interface{}) error { +func handleDryRunUpdate(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error { log.Println("Performing dry run update...") // Parse parameters @@ -1034,7 +1192,7 @@ func handleDryRunUpdate(apiClient *client.Client, cfg *config.Config, commandID DurationSeconds: 0, } - if reportErr := apiClient.ReportLog(cfg.AgentID, logReport); reportErr != nil { + if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil { log.Printf("Failed to report dry run failure: %v\n", reportErr) } @@ -1085,7 +1243,7 @@ func handleDryRunUpdate(apiClient *client.Client, cfg *config.Config, commandID logReport.Stdout += fmt.Sprintf("\nDependencies found: %v", result.Dependencies) } - if reportErr := apiClient.ReportLog(cfg.AgentID, logReport); reportErr != nil { + if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil { log.Printf("Failed to report dry run success: %v\n", reportErr) } @@ -1105,7 +1263,7 @@ func handleDryRunUpdate(apiClient *client.Client, cfg *config.Config, commandID } // handleConfirmDependencies handles confirm_dependencies command -func handleConfirmDependencies(apiClient *client.Client, cfg *config.Config, commandID string, params map[string]interface{}) error { +func handleConfirmDependencies(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error { log.Println("Installing update with confirmed dependencies...") // Parse parameters @@ -1172,7 +1330,7 @@ func handleConfirmDependencies(apiClient *client.Client, cfg *config.Config, com DurationSeconds: result.DurationSeconds, } - if reportErr := apiClient.ReportLog(cfg.AgentID, logReport); reportErr != nil { + if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil { log.Printf("Failed to report installation failure: %v\n", reportErr) } @@ -1198,7 +1356,7 @@ func handleConfirmDependencies(apiClient *client.Client, cfg *config.Config, com logReport.Stdout += fmt.Sprintf("\nDependencies included: %v", dependencies) } - if reportErr := apiClient.ReportLog(cfg.AgentID, logReport); reportErr != nil { + if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil { log.Printf("Failed to report installation success: %v\n", reportErr) } @@ -1216,7 +1374,7 @@ func handleConfirmDependencies(apiClient *client.Client, cfg *config.Config, com } // handleEnableHeartbeat handles enable_heartbeat command -func handleEnableHeartbeat(apiClient *client.Client, cfg *config.Config, commandID string, params map[string]interface{}) error { +func handleEnableHeartbeat(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error { // Parse duration parameter (default to 10 minutes) durationMinutes := 10 if duration, ok := params["duration_minutes"]; ok { @@ -1250,7 +1408,7 @@ func handleEnableHeartbeat(apiClient *client.Client, cfg *config.Config, command DurationSeconds: 0, } - if reportErr := apiClient.ReportLog(cfg.AgentID, logReport); reportErr != nil { + if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil { log.Printf("[Heartbeat] Failed to report heartbeat enable: %v", reportErr) } @@ -1291,7 +1449,7 @@ func handleEnableHeartbeat(apiClient *client.Client, cfg *config.Config, command } // handleDisableHeartbeat handles disable_heartbeat command -func handleDisableHeartbeat(apiClient *client.Client, cfg *config.Config, commandID string) error { +func handleDisableHeartbeat(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string) error { log.Printf("[Heartbeat] Disabling rapid polling") // Update agent config to disable rapid polling @@ -1314,7 +1472,7 @@ func handleDisableHeartbeat(apiClient *client.Client, cfg *config.Config, comman DurationSeconds: 0, } - if reportErr := apiClient.ReportLog(cfg.AgentID, logReport); reportErr != nil { + if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil { log.Printf("[Heartbeat] Failed to report heartbeat disable: %v", reportErr) } @@ -1407,7 +1565,7 @@ func reportSystemInfo(apiClient *client.Client, cfg *config.Config) error { } // handleReboot handles reboot command -func handleReboot(apiClient *client.Client, cfg *config.Config, commandID string, params map[string]interface{}) error { +func handleReboot(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, commandID string, params map[string]interface{}) error { log.Println("[Reboot] Processing reboot request...") // Parse parameters @@ -1449,7 +1607,7 @@ func handleReboot(apiClient *client.Client, cfg *config.Config, commandID string ExitCode: 1, DurationSeconds: 0, } - apiClient.ReportLog(cfg.AgentID, logReport) + reportLogWithAck(apiClient, cfg, ackTracker, logReport) return err } @@ -1469,7 +1627,7 @@ func handleReboot(apiClient *client.Client, cfg *config.Config, commandID string ExitCode: 1, DurationSeconds: 0, } - apiClient.ReportLog(cfg.AgentID, logReport) + reportLogWithAck(apiClient, cfg, ackTracker, logReport) return err } @@ -1487,7 +1645,7 @@ func handleReboot(apiClient *client.Client, cfg *config.Config, commandID string DurationSeconds: 0, } - if reportErr := apiClient.ReportLog(cfg.AgentID, logReport); reportErr != nil { + if reportErr := reportLogWithAck(apiClient, cfg, ackTracker, logReport); reportErr != nil { log.Printf("[Reboot] Failed to report reboot command result: %v", reportErr) } diff --git a/aggregator-agent/internal/acknowledgment/tracker.go b/aggregator-agent/internal/acknowledgment/tracker.go new file mode 100644 index 0000000..29f3e75 --- /dev/null +++ b/aggregator-agent/internal/acknowledgment/tracker.go @@ -0,0 +1,193 @@ +package acknowledgment + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" +) + +// PendingResult represents a command result awaiting acknowledgment +type PendingResult struct { + CommandID string `json:"command_id"` + SentAt time.Time `json:"sent_at"` + RetryCount int `json:"retry_count"` +} + +// Tracker manages pending acknowledgments for command results +type Tracker struct { + pending map[string]*PendingResult + mu sync.RWMutex + filePath string + maxAge time.Duration // Max time to keep pending (default 24h) + maxRetries int // Max retries before giving up (default 10) +} + +// NewTracker creates a new acknowledgment tracker +func NewTracker(statePath string) *Tracker { + return &Tracker{ + pending: make(map[string]*PendingResult), + filePath: filepath.Join(statePath, "pending_acks.json"), + maxAge: 24 * time.Hour, + maxRetries: 10, + } +} + +// Load restores pending acknowledgments from disk +func (t *Tracker) Load() error { + t.mu.Lock() + defer t.mu.Unlock() + + // If file doesn't exist, that's fine (fresh start) + if _, err := os.Stat(t.filePath); os.IsNotExist(err) { + return nil + } + + data, err := os.ReadFile(t.filePath) + if err != nil { + return fmt.Errorf("failed to read pending acks: %w", err) + } + + if len(data) == 0 { + return nil // Empty file + } + + var pending map[string]*PendingResult + if err := json.Unmarshal(data, &pending); err != nil { + return fmt.Errorf("failed to parse pending acks: %w", err) + } + + t.pending = pending + return nil +} + +// Save persists pending acknowledgments to disk +func (t *Tracker) Save() error { + t.mu.RLock() + defer t.mu.RUnlock() + + // Ensure directory exists + dir := filepath.Dir(t.filePath) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create ack directory: %w", err) + } + + data, err := json.MarshalIndent(t.pending, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal pending acks: %w", err) + } + + if err := os.WriteFile(t.filePath, data, 0600); err != nil { + return fmt.Errorf("failed to write pending acks: %w", err) + } + + return nil +} + +// Add marks a command result as pending acknowledgment +func (t *Tracker) Add(commandID string) { + t.mu.Lock() + defer t.mu.Unlock() + + t.pending[commandID] = &PendingResult{ + CommandID: commandID, + SentAt: time.Now(), + RetryCount: 0, + } +} + +// Acknowledge marks command results as acknowledged and removes them +func (t *Tracker) Acknowledge(commandIDs []string) { + t.mu.Lock() + defer t.mu.Unlock() + + for _, id := range commandIDs { + delete(t.pending, id) + } +} + +// GetPending returns list of command IDs awaiting acknowledgment +func (t *Tracker) GetPending() []string { + t.mu.RLock() + defer t.mu.RUnlock() + + ids := make([]string, 0, len(t.pending)) + for id := range t.pending { + ids = append(ids, id) + } + return ids +} + +// IncrementRetry increments retry count for a command +func (t *Tracker) IncrementRetry(commandID string) { + t.mu.Lock() + defer t.mu.Unlock() + + if result, exists := t.pending[commandID]; exists { + result.RetryCount++ + } +} + +// Cleanup removes old or over-retried pending results +func (t *Tracker) Cleanup() int { + t.mu.Lock() + defer t.mu.Unlock() + + now := time.Now() + removed := 0 + + for id, result := range t.pending { + // Remove if too old + if now.Sub(result.SentAt) > t.maxAge { + delete(t.pending, id) + removed++ + continue + } + + // Remove if retried too many times + if result.RetryCount >= t.maxRetries { + delete(t.pending, id) + removed++ + continue + } + } + + return removed +} + +// Stats returns statistics about pending acknowledgments +func (t *Tracker) Stats() Stats { + t.mu.RLock() + defer t.mu.RUnlock() + + stats := Stats{ + Total: len(t.pending), + } + + now := time.Now() + for _, result := range t.pending { + age := now.Sub(result.SentAt) + + if age > 1*time.Hour { + stats.OlderThan1Hour++ + } + if result.RetryCount > 0 { + stats.WithRetries++ + } + if result.RetryCount >= 5 { + stats.HighRetries++ + } + } + + return stats +} + +// Stats holds statistics about pending acknowledgments +type Stats struct { + Total int + OlderThan1Hour int + WithRetries int + HighRetries int +} diff --git a/aggregator-agent/internal/circuitbreaker/circuitbreaker.go b/aggregator-agent/internal/circuitbreaker/circuitbreaker.go new file mode 100644 index 0000000..5ed1d79 --- /dev/null +++ b/aggregator-agent/internal/circuitbreaker/circuitbreaker.go @@ -0,0 +1,233 @@ +package circuitbreaker + +import ( + "fmt" + "sync" + "time" +) + +// State represents the circuit breaker state +type State int + +const ( + StateClosed State = iota // Normal operation + StateOpen // Circuit is open, failing fast + StateHalfOpen // Testing if service recovered +) + +func (s State) String() string { + switch s { + case StateClosed: + return "closed" + case StateOpen: + return "open" + case StateHalfOpen: + return "half-open" + default: + return "unknown" + } +} + +// Config holds circuit breaker configuration +type Config struct { + FailureThreshold int // Number of failures before opening + FailureWindow time.Duration // Time window to track failures + OpenDuration time.Duration // How long circuit stays open + HalfOpenAttempts int // Successful attempts needed to close from half-open +} + +// CircuitBreaker implements the circuit breaker pattern for subsystems +type CircuitBreaker struct { + name string + config Config + + mu sync.RWMutex + state State + failures []time.Time // Timestamps of recent failures + consecutiveSuccess int // Consecutive successes in half-open state + openedAt time.Time // When circuit was opened +} + +// New creates a new circuit breaker +func New(name string, config Config) *CircuitBreaker { + return &CircuitBreaker{ + name: name, + config: config, + state: StateClosed, + failures: make([]time.Time, 0), + } +} + +// Call executes the given function with circuit breaker protection +func (cb *CircuitBreaker) Call(fn func() error) error { + // Check if we can execute + if err := cb.beforeCall(); err != nil { + return err + } + + // Execute the function + err := fn() + + // Record the result + cb.afterCall(err) + + return err +} + +// beforeCall checks if the call should be allowed +func (cb *CircuitBreaker) beforeCall() error { + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case StateClosed: + // Normal operation, allow call + return nil + + case StateOpen: + // Check if enough time has passed to try half-open + if time.Since(cb.openedAt) >= cb.config.OpenDuration { + cb.state = StateHalfOpen + cb.consecutiveSuccess = 0 + return nil + } + // Circuit is still open, fail fast + return fmt.Errorf("circuit breaker [%s] is OPEN (will retry at %s)", + cb.name, cb.openedAt.Add(cb.config.OpenDuration).Format("15:04:05")) + + case StateHalfOpen: + // In half-open state, allow limited attempts + return nil + + default: + return fmt.Errorf("unknown circuit breaker state") + } +} + +// afterCall records the result and updates state +func (cb *CircuitBreaker) afterCall(err error) { + cb.mu.Lock() + defer cb.mu.Unlock() + + now := time.Now() + + if err != nil { + // Record failure + cb.recordFailure(now) + + // If in half-open, go back to open on any failure + if cb.state == StateHalfOpen { + cb.state = StateOpen + cb.openedAt = now + cb.consecutiveSuccess = 0 + return + } + + // Check if we should open the circuit + if cb.shouldOpen(now) { + cb.state = StateOpen + cb.openedAt = now + cb.consecutiveSuccess = 0 + } + } else { + // Success + switch cb.state { + case StateHalfOpen: + // Count consecutive successes + cb.consecutiveSuccess++ + if cb.consecutiveSuccess >= cb.config.HalfOpenAttempts { + // Enough successes, close the circuit + cb.state = StateClosed + cb.failures = make([]time.Time, 0) + cb.consecutiveSuccess = 0 + } + + case StateClosed: + // Clean up old failures on success + cb.cleanupOldFailures(now) + } + } +} + +// recordFailure adds a failure timestamp +func (cb *CircuitBreaker) recordFailure(now time.Time) { + cb.failures = append(cb.failures, now) + cb.cleanupOldFailures(now) +} + +// cleanupOldFailures removes failures outside the window +func (cb *CircuitBreaker) cleanupOldFailures(now time.Time) { + cutoff := now.Add(-cb.config.FailureWindow) + validFailures := make([]time.Time, 0) + + for _, failTime := range cb.failures { + if failTime.After(cutoff) { + validFailures = append(validFailures, failTime) + } + } + + cb.failures = validFailures +} + +// shouldOpen determines if circuit should open based on failures +func (cb *CircuitBreaker) shouldOpen(now time.Time) bool { + cb.cleanupOldFailures(now) + return len(cb.failures) >= cb.config.FailureThreshold +} + +// State returns the current circuit breaker state (thread-safe) +func (cb *CircuitBreaker) State() State { + cb.mu.RLock() + defer cb.mu.RUnlock() + return cb.state +} + +// GetStats returns current circuit breaker statistics +func (cb *CircuitBreaker) GetStats() Stats { + cb.mu.RLock() + defer cb.mu.RUnlock() + + stats := Stats{ + Name: cb.name, + State: cb.state.String(), + RecentFailures: len(cb.failures), + ConsecutiveSuccess: cb.consecutiveSuccess, + } + + if cb.state == StateOpen && !cb.openedAt.IsZero() { + nextAttempt := cb.openedAt.Add(cb.config.OpenDuration) + stats.NextAttempt = &nextAttempt + } + + return stats +} + +// Reset manually resets the circuit breaker to closed state +func (cb *CircuitBreaker) Reset() { + cb.mu.Lock() + defer cb.mu.Unlock() + + cb.state = StateClosed + cb.failures = make([]time.Time, 0) + cb.consecutiveSuccess = 0 + cb.openedAt = time.Time{} +} + +// Stats holds circuit breaker statistics +type Stats struct { + Name string + State string + RecentFailures int + ConsecutiveSuccess int + NextAttempt *time.Time +} + +// String returns a human-readable representation of the stats +func (s Stats) String() string { + if s.NextAttempt != nil { + return fmt.Sprintf("[%s] state=%s, failures=%d, next_attempt=%s", + s.Name, s.State, s.RecentFailures, s.NextAttempt.Format("15:04:05")) + } + return fmt.Sprintf("[%s] state=%s, failures=%d, success=%d", + s.Name, s.State, s.RecentFailures, s.ConsecutiveSuccess) +} diff --git a/aggregator-agent/internal/circuitbreaker/circuitbreaker_test.go b/aggregator-agent/internal/circuitbreaker/circuitbreaker_test.go new file mode 100644 index 0000000..2b9a27f --- /dev/null +++ b/aggregator-agent/internal/circuitbreaker/circuitbreaker_test.go @@ -0,0 +1,138 @@ +package circuitbreaker + +import ( + "errors" + "testing" + "time" +) + +func TestCircuitBreaker_NormalOperation(t *testing.T) { + cb := New("test", Config{ + FailureThreshold: 3, + FailureWindow: 1 * time.Minute, + OpenDuration: 1 * time.Minute, + HalfOpenAttempts: 2, + }) + + // Should allow calls in closed state + err := cb.Call(func() error { return nil }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if cb.State() != StateClosed { + t.Fatalf("expected state closed, got %v", cb.State()) + } +} + +func TestCircuitBreaker_OpensAfterFailures(t *testing.T) { + cb := New("test", Config{ + FailureThreshold: 3, + FailureWindow: 1 * time.Minute, + OpenDuration: 100 * time.Millisecond, + HalfOpenAttempts: 2, + }) + + testErr := errors.New("test error") + + // Record 3 failures + for i := 0; i < 3; i++ { + cb.Call(func() error { return testErr }) + } + + // Should now be open + if cb.State() != StateOpen { + t.Fatalf("expected state open after %d failures, got %v", 3, cb.State()) + } + + // Next call should fail fast + err := cb.Call(func() error { return nil }) + if err == nil { + t.Fatal("expected circuit breaker to reject call, but it succeeded") + } +} + +func TestCircuitBreaker_HalfOpenRecovery(t *testing.T) { + cb := New("test", Config{ + FailureThreshold: 2, + FailureWindow: 1 * time.Minute, + OpenDuration: 50 * time.Millisecond, + HalfOpenAttempts: 2, + }) + + testErr := errors.New("test error") + + // Open the circuit + cb.Call(func() error { return testErr }) + cb.Call(func() error { return testErr }) + + if cb.State() != StateOpen { + t.Fatal("circuit should be open") + } + + // Wait for open duration + time.Sleep(60 * time.Millisecond) + + // Should transition to half-open and allow call + err := cb.Call(func() error { return nil }) + if err != nil { + t.Fatalf("expected call to succeed in half-open state, got %v", err) + } + + if cb.State() != StateHalfOpen { + t.Fatalf("expected half-open state, got %v", cb.State()) + } + + // One more success should close it + cb.Call(func() error { return nil }) + + if cb.State() != StateClosed { + t.Fatalf("expected closed state after %d successes, got %v", 2, cb.State()) + } +} + +func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { + cb := New("test", Config{ + FailureThreshold: 2, + FailureWindow: 1 * time.Minute, + OpenDuration: 50 * time.Millisecond, + HalfOpenAttempts: 2, + }) + + testErr := errors.New("test error") + + // Open the circuit + cb.Call(func() error { return testErr }) + cb.Call(func() error { return testErr }) + + // Wait and attempt in half-open + time.Sleep(60 * time.Millisecond) + cb.Call(func() error { return nil }) // Half-open + + // Fail in half-open - should go back to open + cb.Call(func() error { return testErr }) + + if cb.State() != StateOpen { + t.Fatalf("expected open state after half-open failure, got %v", cb.State()) + } +} + +func TestCircuitBreaker_Stats(t *testing.T) { + cb := New("test-subsystem", Config{ + FailureThreshold: 3, + FailureWindow: 1 * time.Minute, + OpenDuration: 1 * time.Minute, + HalfOpenAttempts: 2, + }) + + stats := cb.GetStats() + if stats.Name != "test-subsystem" { + t.Fatalf("expected name 'test-subsystem', got %s", stats.Name) + } + if stats.State != "closed" { + t.Fatalf("expected state 'closed', got %s", stats.State) + } + if stats.RecentFailures != 0 { + t.Fatalf("expected 0 failures, got %d", stats.RecentFailures) + } +} diff --git a/aggregator-agent/internal/client/client.go b/aggregator-agent/internal/client/client.go index afef686..75565f3 100644 --- a/aggregator-agent/internal/client/client.go +++ b/aggregator-agent/internal/client/client.go @@ -174,8 +174,9 @@ type Command struct { // CommandsResponse contains pending commands type CommandsResponse struct { - Commands []Command `json:"commands"` - RapidPolling *RapidPollingConfig `json:"rapid_polling,omitempty"` + Commands []Command `json:"commands"` + RapidPolling *RapidPollingConfig `json:"rapid_polling,omitempty"` + AcknowledgedIDs []string `json:"acknowledged_ids,omitempty"` // IDs server has received } // RapidPollingConfig contains rapid polling configuration from server @@ -196,11 +197,15 @@ type SystemMetrics struct { Uptime string `json:"uptime,omitempty"` Version string `json:"version,omitempty"` // Agent version Metadata map[string]interface{} `json:"metadata,omitempty"` // Additional metadata + + // Command acknowledgment tracking + PendingAcknowledgments []string `json:"pending_acknowledgments,omitempty"` // Command IDs awaiting ACK } // GetCommands retrieves pending commands from the server // Optionally sends lightweight system metrics in the request -func (c *Client) GetCommands(agentID uuid.UUID, metrics *SystemMetrics) ([]Command, error) { +// Returns the full response including commands and acknowledged IDs +func (c *Client) GetCommands(agentID uuid.UUID, metrics *SystemMetrics) (*CommandsResponse, error) { url := fmt.Sprintf("%s/api/v1/agents/%s/commands", c.baseURL, agentID) var req *http.Request @@ -252,7 +257,7 @@ func (c *Client) GetCommands(agentID uuid.UUID, metrics *SystemMetrics) ([]Comma } } - return result.Commands, nil + return &result, nil } // UpdateReport represents discovered updates diff --git a/aggregator-agent/internal/config/config.go b/aggregator-agent/internal/config/config.go index 3a16461..0c82c2a 100644 --- a/aggregator-agent/internal/config/config.go +++ b/aggregator-agent/internal/config/config.go @@ -80,6 +80,9 @@ type Config struct { Metadata map[string]string `json:"metadata,omitempty"` // Custom metadata DisplayName string `json:"display_name,omitempty"` // Human-readable name Organization string `json:"organization,omitempty"` // Organization/group + + // Subsystem Configuration + Subsystems SubsystemsConfig `json:"subsystems,omitempty"` // Scanner subsystem configs } // Load reads configuration from multiple sources with priority order: @@ -144,8 +147,9 @@ func getDefaultConfig() *Config { MaxBackups: 3, MaxAge: 28, // 28 days }, - Tags: []string{}, - Metadata: make(map[string]string), + Subsystems: GetDefaultSubsystemsConfig(), + Tags: []string{}, + Metadata: make(map[string]string), } } @@ -311,6 +315,11 @@ func mergeConfig(target, source *Config) { if !source.RapidPollingUntil.IsZero() { target.RapidPollingUntil = source.RapidPollingUntil } + + // Merge subsystems config + if source.Subsystems != (SubsystemsConfig{}) { + target.Subsystems = source.Subsystems + } } // validateConfig validates configuration values diff --git a/aggregator-agent/internal/config/subsystems.go b/aggregator-agent/internal/config/subsystems.go new file mode 100644 index 0000000..4ee06f5 --- /dev/null +++ b/aggregator-agent/internal/config/subsystems.go @@ -0,0 +1,95 @@ +package config + +import "time" + +// SubsystemConfig holds configuration for individual subsystems +type SubsystemConfig struct { + // Execution settings + Enabled bool `json:"enabled"` + Timeout time.Duration `json:"timeout"` // Timeout for this subsystem + + // Circuit breaker settings + CircuitBreaker CircuitBreakerConfig `json:"circuit_breaker"` +} + +// CircuitBreakerConfig holds circuit breaker settings for subsystems +type CircuitBreakerConfig struct { + // Enabled controls whether circuit breaker is active + Enabled bool `json:"enabled"` + + // FailureThreshold is the number of consecutive failures before opening the circuit + FailureThreshold int `json:"failure_threshold"` + + // FailureWindow is the time window to track failures (e.g., 3 failures in 10 minutes) + FailureWindow time.Duration `json:"failure_window"` + + // OpenDuration is how long the circuit stays open before attempting recovery + OpenDuration time.Duration `json:"open_duration"` + + // HalfOpenAttempts is the number of test attempts in half-open state before fully closing + HalfOpenAttempts int `json:"half_open_attempts"` +} + +// SubsystemsConfig holds all subsystem configurations +type SubsystemsConfig struct { + APT SubsystemConfig `json:"apt"` + DNF SubsystemConfig `json:"dnf"` + Docker SubsystemConfig `json:"docker"` + Windows SubsystemConfig `json:"windows"` + Winget SubsystemConfig `json:"winget"` + Storage SubsystemConfig `json:"storage"` +} + +// GetDefaultSubsystemsConfig returns default subsystem configurations +func GetDefaultSubsystemsConfig() SubsystemsConfig { + // Default circuit breaker config + defaultCB := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 3, // 3 consecutive failures + FailureWindow: 10 * time.Minute, // within 10 minutes + OpenDuration: 30 * time.Minute, // circuit open for 30 min + HalfOpenAttempts: 2, // 2 successful attempts to close circuit + } + + // Aggressive circuit breaker for Windows Update (known to be slow/problematic) + windowsCB := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 2, // Only 2 failures + FailureWindow: 15 * time.Minute, + OpenDuration: 60 * time.Minute, // Open for 1 hour + HalfOpenAttempts: 3, + } + + return SubsystemsConfig{ + APT: SubsystemConfig{ + Enabled: true, + Timeout: 30 * time.Second, + CircuitBreaker: defaultCB, + }, + DNF: SubsystemConfig{ + Enabled: true, + Timeout: 45 * time.Second, // DNF can be slower + CircuitBreaker: defaultCB, + }, + Docker: SubsystemConfig{ + Enabled: true, + Timeout: 60 * time.Second, // Registry queries can be slow + CircuitBreaker: defaultCB, + }, + Windows: SubsystemConfig{ + Enabled: true, + Timeout: 10 * time.Minute, // Windows Update can be VERY slow + CircuitBreaker: windowsCB, + }, + Winget: SubsystemConfig{ + Enabled: true, + Timeout: 2 * time.Minute, // Winget has multiple retry strategies + CircuitBreaker: defaultCB, + }, + Storage: SubsystemConfig{ + Enabled: true, + Timeout: 10 * time.Second, // Disk info should be fast + CircuitBreaker: defaultCB, + }, + } +} diff --git a/aggregator-agent/test_disk.go b/aggregator-agent/test_disk.go new file mode 100644 index 0000000..ee24b72 --- /dev/null +++ b/aggregator-agent/test_disk.go @@ -0,0 +1,55 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/Fimeg/RedFlag/aggregator-agent/internal/system" +) + +func main() { + // Test lightweight metrics (most common use case) + fmt.Println("=== Enhanced Lightweight Metrics Test ===") + metrics, err := system.GetLightweightMetrics() + if err != nil { + log.Printf("Error getting lightweight metrics: %v", err) + } else { + // Pretty print the JSON + jsonData, _ := json.MarshalIndent(metrics, "", " ") + fmt.Printf("LightweightMetrics:\n%s\n\n", jsonData) + + // Show key findings + fmt.Printf("Root Disk: %.1fGB used / %.1fGB total (%.1f%%)\n", + metrics.DiskUsedGB, metrics.DiskTotalGB, metrics.DiskPercent) + + if metrics.LargestDiskTotalGB > 0 { + fmt.Printf("Largest Disk (%s): %.1fGB used / %.1fGB total (%.1f%%)\n", + metrics.LargestDiskMount, metrics.LargestDiskUsedGB, metrics.LargestDiskTotalGB, metrics.LargestDiskPercent) + } else { + fmt.Printf("No largest disk detected (this might be the issue!)\n") + } + } + + // Test full system info (detailed disk inventory) + fmt.Println("\n=== Enhanced System Info Test ===") + sysInfo, err := system.GetSystemInfo("test-v0.1.5") + if err != nil { + log.Printf("Error getting system info: %v", err) + } else { + fmt.Printf("Found %d disks:\n", len(sysInfo.DiskInfo)) + for i, disk := range sysInfo.DiskInfo { + fmt.Printf(" Disk %d: %s (%s) - %s, %.1fGB used / %.1fGB total (%.1f%%)", + i+1, disk.Mountpoint, disk.Filesystem, disk.DiskType, + float64(disk.Used)/(1024*1024*1024), float64(disk.Total)/(1024*1024*1024), disk.UsedPercent) + + if disk.IsRoot { + fmt.Printf(" [ROOT]") + } + if disk.IsLargest { + fmt.Printf(" [LARGEST]") + } + fmt.Printf("\n") + } + } +} \ No newline at end of file diff --git a/aggregator-server/cmd/server/main.go b/aggregator-server/cmd/server/main.go index 8e181a6..c13924c 100644 --- a/aggregator-server/cmd/server/main.go +++ b/aggregator-server/cmd/server/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -12,6 +13,7 @@ import ( "github.com/Fimeg/RedFlag/aggregator-server/internal/config" "github.com/Fimeg/RedFlag/aggregator-server/internal/database" "github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries" + "github.com/Fimeg/RedFlag/aggregator-server/internal/scheduler" "github.com/Fimeg/RedFlag/aggregator-server/internal/services" "github.com/gin-gonic/gin" ) @@ -285,10 +287,45 @@ func main() { timeoutService.Start() log.Println("Timeout service started") - // Add graceful shutdown for timeout service + // Initialize and start scheduler + schedulerConfig := scheduler.DefaultConfig() + subsystemScheduler := scheduler.NewScheduler(schedulerConfig, agentQueries, commandQueries) + + // Load subsystems into queue + ctx := context.Background() + if err := subsystemScheduler.LoadSubsystems(ctx); err != nil { + log.Printf("Warning: Failed to load subsystems: %v", err) + } else { + log.Println("Subsystems loaded into scheduler") + } + + // Start scheduler + if err := subsystemScheduler.Start(); err != nil { + log.Printf("Warning: Failed to start scheduler: %v", err) + } + + // Add scheduler stats endpoint (after scheduler is initialized) + router.GET("/api/v1/scheduler/stats", middleware.AuthMiddleware(), func(c *gin.Context) { + stats := subsystemScheduler.GetStats() + queueStats := subsystemScheduler.GetQueueStats() + c.JSON(200, gin.H{ + "scheduler": stats, + "queue": queueStats, + }) + }) + + // Add graceful shutdown for services defer func() { + log.Println("Shutting down services...") + + // Stop scheduler first + if err := subsystemScheduler.Stop(); err != nil { + log.Printf("Error stopping scheduler: %v", err) + } + + // Stop timeout service timeoutService.Stop() - log.Println("Timeout service stopped") + log.Println("Services stopped") }() // Start server diff --git a/aggregator-server/internal/api/handlers/agents.go b/aggregator-server/internal/api/handlers/agents.go index 4cd8dc2..1c244b7 100644 --- a/aggregator-server/internal/api/handlers/agents.go +++ b/aggregator-server/internal/api/handlers/agents.go @@ -269,6 +269,21 @@ func (h *AgentHandler) GetCommands(c *gin.Context) { return } + // Process command acknowledgments if agent sent any + var acknowledgedIDs []string + if metrics != nil && len(metrics.PendingAcknowledgments) > 0 { + // Verify which commands from the agent's pending list have been recorded + verified, err := h.commandQueries.VerifyCommandsCompleted(metrics.PendingAcknowledgments) + if err != nil { + log.Printf("Warning: Failed to verify command acknowledgments for agent %s: %v", agentID, err) + } else { + acknowledgedIDs = verified + if len(acknowledgedIDs) > 0 { + log.Printf("Acknowledged %d command results for agent %s", len(acknowledgedIDs), agentID) + } + } + } + // Process heartbeat metadata from agent check-ins if metrics.Metadata != nil { agent, err := h.agentQueries.GetAgentByID(agentID) @@ -437,8 +452,9 @@ func (h *AgentHandler) GetCommands(c *gin.Context) { } response := models.CommandsResponse{ - Commands: commandItems, - RapidPolling: rapidPolling, + Commands: commandItems, + RapidPolling: rapidPolling, + AcknowledgedIDs: acknowledgedIDs, } c.JSON(http.StatusOK, response) diff --git a/aggregator-server/internal/database/queries/commands.go b/aggregator-server/internal/database/queries/commands.go index 8bb43ce..9877281 100644 --- a/aggregator-server/internal/database/queries/commands.go +++ b/aggregator-server/internal/database/queries/commands.go @@ -337,3 +337,61 @@ func (q *CommandQueries) ClearAllFailedCommands(days int) (int64, error) { return result.RowsAffected() } + +// CountPendingCommandsForAgent returns the number of pending commands for a specific agent +// Used by scheduler for backpressure detection +func (q *CommandQueries) CountPendingCommandsForAgent(agentID uuid.UUID) (int, error) { + var count int + query := ` + SELECT COUNT(*) + FROM agent_commands + WHERE agent_id = $1 AND status = 'pending' + ` + err := q.db.Get(&count, query, agentID) + return count, err +} + +// VerifyCommandsCompleted checks which command IDs from the provided list have been completed or failed +// Returns the list of command IDs that have been successfully recorded (completed or failed status) +func (q *CommandQueries) VerifyCommandsCompleted(commandIDs []string) ([]string, error) { + if len(commandIDs) == 0 { + return []string{}, nil + } + + // Convert string IDs to UUIDs + uuidIDs := make([]uuid.UUID, 0, len(commandIDs)) + for _, idStr := range commandIDs { + id, err := uuid.Parse(idStr) + if err != nil { + // Skip invalid UUIDs + continue + } + uuidIDs = append(uuidIDs, id) + } + + if len(uuidIDs) == 0 { + return []string{}, nil + } + + // Query for commands that are completed or failed + query := ` + SELECT id + FROM agent_commands + WHERE id = ANY($1) + AND status IN ('completed', 'failed') + ` + + var completedUUIDs []uuid.UUID + err := q.db.Select(&completedUUIDs, query, uuidIDs) + if err != nil { + return nil, fmt.Errorf("failed to verify command completion: %w", err) + } + + // Convert back to strings + completedIDs := make([]string, len(completedUUIDs)) + for i, id := range completedUUIDs { + completedIDs[i] = id.String() + } + + return completedIDs, nil +} diff --git a/aggregator-server/internal/models/command.go b/aggregator-server/internal/models/command.go index f31ec94..03805ac 100644 --- a/aggregator-server/internal/models/command.go +++ b/aggregator-server/internal/models/command.go @@ -23,8 +23,9 @@ type AgentCommand struct { // CommandsResponse is returned when an agent checks in for commands type CommandsResponse struct { - Commands []CommandItem `json:"commands"` - RapidPolling *RapidPollingConfig `json:"rapid_polling,omitempty"` + Commands []CommandItem `json:"commands"` + RapidPolling *RapidPollingConfig `json:"rapid_polling,omitempty"` + AcknowledgedIDs []string `json:"acknowledged_ids,omitempty"` // IDs server has received } // RapidPollingConfig contains rapid polling configuration for the agent diff --git a/aggregator-server/internal/scheduler/queue.go b/aggregator-server/internal/scheduler/queue.go new file mode 100644 index 0000000..33f5c9f --- /dev/null +++ b/aggregator-server/internal/scheduler/queue.go @@ -0,0 +1,286 @@ +package scheduler + +import ( + "container/heap" + "fmt" + "sync" + "time" + + "github.com/google/uuid" +) + +// SubsystemJob represents a scheduled subsystem scan +type SubsystemJob struct { + AgentID uuid.UUID + AgentHostname string // For logging/debugging + Subsystem string + IntervalMinutes int + NextRunAt time.Time + Enabled bool + index int // Heap index (managed by heap.Interface) +} + +// String returns a human-readable representation of the job +func (j *SubsystemJob) String() string { + return fmt.Sprintf("[%s/%s] next_run=%s interval=%dm", + j.AgentHostname, j.Subsystem, + j.NextRunAt.Format("15:04:05"), j.IntervalMinutes) +} + +// jobHeap implements heap.Interface for SubsystemJob priority queue +// Jobs are ordered by NextRunAt (earliest first) +type jobHeap []*SubsystemJob + +func (h jobHeap) Len() int { return len(h) } + +func (h jobHeap) Less(i, j int) bool { + return h[i].NextRunAt.Before(h[j].NextRunAt) +} + +func (h jobHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].index = i + h[j].index = j +} + +func (h *jobHeap) Push(x interface{}) { + n := len(*h) + job := x.(*SubsystemJob) + job.index = n + *h = append(*h, job) +} + +func (h *jobHeap) Pop() interface{} { + old := *h + n := len(old) + job := old[n-1] + old[n-1] = nil // Avoid memory leak + job.index = -1 // Mark as removed + *h = old[0 : n-1] + return job +} + +// PriorityQueue is a thread-safe priority queue for subsystem jobs +// Jobs are ordered by their NextRunAt timestamp (earliest first) +type PriorityQueue struct { + heap jobHeap + mu sync.RWMutex + + // Index for fast lookups by agent_id + subsystem + index map[string]*SubsystemJob // key: "agent_id:subsystem" +} + +// NewPriorityQueue creates a new empty priority queue +func NewPriorityQueue() *PriorityQueue { + pq := &PriorityQueue{ + heap: make(jobHeap, 0), + index: make(map[string]*SubsystemJob), + } + heap.Init(&pq.heap) + return pq +} + +// Push adds a job to the queue +// If a job with the same agent_id + subsystem already exists, it's updated +func (pq *PriorityQueue) Push(job *SubsystemJob) { + pq.mu.Lock() + defer pq.mu.Unlock() + + key := makeKey(job.AgentID, job.Subsystem) + + // Check if job already exists + if existing, exists := pq.index[key]; exists { + // Update existing job + existing.NextRunAt = job.NextRunAt + existing.IntervalMinutes = job.IntervalMinutes + existing.Enabled = job.Enabled + existing.AgentHostname = job.AgentHostname + heap.Fix(&pq.heap, existing.index) + return + } + + // Add new job + heap.Push(&pq.heap, job) + pq.index[key] = job +} + +// Pop removes and returns the job with the earliest NextRunAt +// Returns nil if queue is empty +func (pq *PriorityQueue) Pop() *SubsystemJob { + pq.mu.Lock() + defer pq.mu.Unlock() + + if pq.heap.Len() == 0 { + return nil + } + + job := heap.Pop(&pq.heap).(*SubsystemJob) + key := makeKey(job.AgentID, job.Subsystem) + delete(pq.index, key) + + return job +} + +// Peek returns the job with the earliest NextRunAt without removing it +// Returns nil if queue is empty +func (pq *PriorityQueue) Peek() *SubsystemJob { + pq.mu.RLock() + defer pq.mu.RUnlock() + + if pq.heap.Len() == 0 { + return nil + } + + return pq.heap[0] +} + +// Remove removes a specific job from the queue +// Returns true if job was found and removed, false otherwise +func (pq *PriorityQueue) Remove(agentID uuid.UUID, subsystem string) bool { + pq.mu.Lock() + defer pq.mu.Unlock() + + key := makeKey(agentID, subsystem) + job, exists := pq.index[key] + if !exists { + return false + } + + heap.Remove(&pq.heap, job.index) + delete(pq.index, key) + + return true +} + +// Get retrieves a specific job without removing it +// Returns nil if not found +func (pq *PriorityQueue) Get(agentID uuid.UUID, subsystem string) *SubsystemJob { + pq.mu.RLock() + defer pq.mu.RUnlock() + + key := makeKey(agentID, subsystem) + return pq.index[key] +} + +// PopBefore returns all jobs with NextRunAt <= before, up to limit +// Jobs are removed from the queue +// If limit <= 0, all matching jobs are returned +func (pq *PriorityQueue) PopBefore(before time.Time, limit int) []*SubsystemJob { + pq.mu.Lock() + defer pq.mu.Unlock() + + var jobs []*SubsystemJob + + for pq.heap.Len() > 0 { + // Peek at next job + next := pq.heap[0] + + // Stop if next job is after our cutoff + if next.NextRunAt.After(before) { + break + } + + // Stop if we've hit the limit + if limit > 0 && len(jobs) >= limit { + break + } + + // Pop and collect the job + job := heap.Pop(&pq.heap).(*SubsystemJob) + key := makeKey(job.AgentID, job.Subsystem) + delete(pq.index, key) + + jobs = append(jobs, job) + } + + return jobs +} + +// PeekBefore returns all jobs with NextRunAt <= before without removing them +// If limit <= 0, all matching jobs are returned +func (pq *PriorityQueue) PeekBefore(before time.Time, limit int) []*SubsystemJob { + pq.mu.RLock() + defer pq.mu.RUnlock() + + var jobs []*SubsystemJob + + for i := 0; i < pq.heap.Len(); i++ { + job := pq.heap[i] + + if job.NextRunAt.After(before) { + // Since heap is sorted by NextRunAt, we can break early + // Note: This is only valid because we peek in order + break + } + + if limit > 0 && len(jobs) >= limit { + break + } + + jobs = append(jobs, job) + } + + return jobs +} + +// Len returns the number of jobs in the queue +func (pq *PriorityQueue) Len() int { + pq.mu.RLock() + defer pq.mu.RUnlock() + return pq.heap.Len() +} + +// Clear removes all jobs from the queue +func (pq *PriorityQueue) Clear() { + pq.mu.Lock() + defer pq.mu.Unlock() + + pq.heap = make(jobHeap, 0) + pq.index = make(map[string]*SubsystemJob) + heap.Init(&pq.heap) +} + +// GetStats returns statistics about the queue +func (pq *PriorityQueue) GetStats() QueueStats { + pq.mu.RLock() + defer pq.mu.RUnlock() + + stats := QueueStats{ + Size: pq.heap.Len(), + } + + if pq.heap.Len() > 0 { + stats.NextRunAt = &pq.heap[0].NextRunAt + stats.OldestJob = pq.heap[0].String() + } + + // Count jobs by subsystem + stats.JobsBySubsystem = make(map[string]int) + for _, job := range pq.heap { + stats.JobsBySubsystem[job.Subsystem]++ + } + + return stats +} + +// QueueStats holds statistics about the priority queue +type QueueStats struct { + Size int + NextRunAt *time.Time + OldestJob string + JobsBySubsystem map[string]int +} + +// String returns a human-readable representation of stats +func (s QueueStats) String() string { + nextRun := "empty" + if s.NextRunAt != nil { + nextRun = s.NextRunAt.Format("15:04:05") + } + return fmt.Sprintf("size=%d next=%s oldest=%s", s.Size, nextRun, s.OldestJob) +} + +// makeKey creates a unique key for agent_id + subsystem +func makeKey(agentID uuid.UUID, subsystem string) string { + return agentID.String() + ":" + subsystem +} diff --git a/aggregator-server/internal/scheduler/queue_test.go b/aggregator-server/internal/scheduler/queue_test.go new file mode 100644 index 0000000..a3e1923 --- /dev/null +++ b/aggregator-server/internal/scheduler/queue_test.go @@ -0,0 +1,539 @@ +package scheduler + +import ( + "sync" + "testing" + "time" + + "github.com/google/uuid" +) + +func TestPriorityQueue_BasicOperations(t *testing.T) { + pq := NewPriorityQueue() + + // Test empty queue + if pq.Len() != 0 { + t.Fatalf("expected empty queue, got len=%d", pq.Len()) + } + + if pq.Peek() != nil { + t.Fatal("Peek on empty queue should return nil") + } + + if pq.Pop() != nil { + t.Fatal("Pop on empty queue should return nil") + } + + // Push a job + agent1 := uuid.New() + job1 := &SubsystemJob{ + AgentID: agent1, + AgentHostname: "agent-01", + Subsystem: "updates", + IntervalMinutes: 15, + NextRunAt: time.Now().Add(10 * time.Minute), + } + pq.Push(job1) + + if pq.Len() != 1 { + t.Fatalf("expected len=1 after push, got %d", pq.Len()) + } + + // Peek should return the job without removing it + peeked := pq.Peek() + if peeked == nil { + t.Fatal("Peek should return job") + } + if peeked.AgentID != agent1 { + t.Fatal("Peek returned wrong job") + } + if pq.Len() != 1 { + t.Fatal("Peek should not remove job") + } + + // Pop should return and remove the job + popped := pq.Pop() + if popped == nil { + t.Fatal("Pop should return job") + } + if popped.AgentID != agent1 { + t.Fatal("Pop returned wrong job") + } + if pq.Len() != 0 { + t.Fatal("Pop should remove job") + } +} + +func TestPriorityQueue_Ordering(t *testing.T) { + pq := NewPriorityQueue() + now := time.Now() + + // Push jobs in random order + jobs := []*SubsystemJob{ + { + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: now.Add(30 * time.Minute), // Third + }, + { + AgentID: uuid.New(), + Subsystem: "storage", + NextRunAt: now.Add(5 * time.Minute), // First + }, + { + AgentID: uuid.New(), + Subsystem: "docker", + NextRunAt: now.Add(15 * time.Minute), // Second + }, + } + + for _, job := range jobs { + pq.Push(job) + } + + // Pop should return jobs in NextRunAt order + first := pq.Pop() + if first.Subsystem != "storage" { + t.Fatalf("expected 'storage' first, got '%s'", first.Subsystem) + } + + second := pq.Pop() + if second.Subsystem != "docker" { + t.Fatalf("expected 'docker' second, got '%s'", second.Subsystem) + } + + third := pq.Pop() + if third.Subsystem != "updates" { + t.Fatalf("expected 'updates' third, got '%s'", third.Subsystem) + } +} + +func TestPriorityQueue_UpdateExisting(t *testing.T) { + pq := NewPriorityQueue() + agentID := uuid.New() + now := time.Now() + + // Push initial job + job1 := &SubsystemJob{ + AgentID: agentID, + Subsystem: "updates", + IntervalMinutes: 15, + NextRunAt: now.Add(15 * time.Minute), + } + pq.Push(job1) + + if pq.Len() != 1 { + t.Fatalf("expected len=1, got %d", pq.Len()) + } + + // Push same agent+subsystem with different NextRunAt + job2 := &SubsystemJob{ + AgentID: agentID, + Subsystem: "updates", + IntervalMinutes: 30, + NextRunAt: now.Add(30 * time.Minute), + } + pq.Push(job2) + + // Should still be 1 job (updated, not added) + if pq.Len() != 1 { + t.Fatalf("expected len=1 after update, got %d", pq.Len()) + } + + // Verify the job was updated + job := pq.Pop() + if job.IntervalMinutes != 30 { + t.Fatalf("expected interval=30, got %d", job.IntervalMinutes) + } + if !job.NextRunAt.Equal(now.Add(30 * time.Minute)) { + t.Fatal("NextRunAt was not updated") + } +} + +func TestPriorityQueue_Remove(t *testing.T) { + pq := NewPriorityQueue() + + agent1 := uuid.New() + agent2 := uuid.New() + + pq.Push(&SubsystemJob{ + AgentID: agent1, + Subsystem: "updates", + NextRunAt: time.Now(), + }) + pq.Push(&SubsystemJob{ + AgentID: agent2, + Subsystem: "storage", + NextRunAt: time.Now(), + }) + + if pq.Len() != 2 { + t.Fatalf("expected len=2, got %d", pq.Len()) + } + + // Remove existing job + removed := pq.Remove(agent1, "updates") + if !removed { + t.Fatal("Remove should return true for existing job") + } + if pq.Len() != 1 { + t.Fatalf("expected len=1 after remove, got %d", pq.Len()) + } + + // Remove non-existent job + removed = pq.Remove(agent1, "updates") + if removed { + t.Fatal("Remove should return false for non-existent job") + } + if pq.Len() != 1 { + t.Fatal("Remove of non-existent job should not affect queue") + } +} + +func TestPriorityQueue_Get(t *testing.T) { + pq := NewPriorityQueue() + + agentID := uuid.New() + job := &SubsystemJob{ + AgentID: agentID, + Subsystem: "updates", + NextRunAt: time.Now(), + } + pq.Push(job) + + // Get existing job + retrieved := pq.Get(agentID, "updates") + if retrieved == nil { + t.Fatal("Get should return job") + } + if retrieved.AgentID != agentID { + t.Fatal("Get returned wrong job") + } + if pq.Len() != 1 { + t.Fatal("Get should not remove job") + } + + // Get non-existent job + retrieved = pq.Get(uuid.New(), "storage") + if retrieved != nil { + t.Fatal("Get should return nil for non-existent job") + } +} + +func TestPriorityQueue_PopBefore(t *testing.T) { + pq := NewPriorityQueue() + now := time.Now() + + // Add jobs with different NextRunAt times + for i := 0; i < 5; i++ { + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: now.Add(time.Duration(i*10) * time.Minute), + }) + } + + if pq.Len() != 5 { + t.Fatalf("expected len=5, got %d", pq.Len()) + } + + // Pop jobs before now+25min (should get 3 jobs: 0, 10, 20 minutes) + cutoff := now.Add(25 * time.Minute) + jobs := pq.PopBefore(cutoff, 0) // no limit + + if len(jobs) != 3 { + t.Fatalf("expected 3 jobs, got %d", len(jobs)) + } + + // Verify all returned jobs are before cutoff + for _, job := range jobs { + if job.NextRunAt.After(cutoff) { + t.Fatalf("job NextRunAt %v is after cutoff %v", job.NextRunAt, cutoff) + } + } + + // Queue should have 2 jobs left + if pq.Len() != 2 { + t.Fatalf("expected len=2 after PopBefore, got %d", pq.Len()) + } +} + +func TestPriorityQueue_PopBeforeWithLimit(t *testing.T) { + pq := NewPriorityQueue() + now := time.Now() + + // Add 5 jobs all due now + for i := 0; i < 5; i++ { + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: now, + }) + } + + // Pop with limit of 3 + jobs := pq.PopBefore(now.Add(1*time.Hour), 3) + + if len(jobs) != 3 { + t.Fatalf("expected 3 jobs (limit), got %d", len(jobs)) + } + + if pq.Len() != 2 { + t.Fatalf("expected 2 jobs remaining, got %d", pq.Len()) + } +} + +func TestPriorityQueue_PeekBefore(t *testing.T) { + pq := NewPriorityQueue() + now := time.Now() + + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: now.Add(5 * time.Minute), + }) + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "storage", + NextRunAt: now.Add(15 * time.Minute), + }) + + // Peek before 10 minutes (should see 1 job) + jobs := pq.PeekBefore(now.Add(10*time.Minute), 0) + + if len(jobs) != 1 { + t.Fatalf("expected 1 job, got %d", len(jobs)) + } + + // Queue should still have both jobs + if pq.Len() != 2 { + t.Fatalf("expected len=2 after PeekBefore, got %d", pq.Len()) + } +} + +func TestPriorityQueue_Clear(t *testing.T) { + pq := NewPriorityQueue() + + // Add some jobs + for i := 0; i < 10; i++ { + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: time.Now(), + }) + } + + if pq.Len() != 10 { + t.Fatalf("expected len=10, got %d", pq.Len()) + } + + // Clear the queue + pq.Clear() + + if pq.Len() != 0 { + t.Fatalf("expected len=0 after clear, got %d", pq.Len()) + } + + if pq.Peek() != nil { + t.Fatal("Peek should return nil after clear") + } +} + +func TestPriorityQueue_GetStats(t *testing.T) { + pq := NewPriorityQueue() + now := time.Now() + + // Empty queue stats + stats := pq.GetStats() + if stats.Size != 0 { + t.Fatalf("expected size=0, got %d", stats.Size) + } + if stats.NextRunAt != nil { + t.Fatal("empty queue should have nil NextRunAt") + } + + // Add jobs + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + AgentHostname: "agent-01", + Subsystem: "updates", + NextRunAt: now.Add(5 * time.Minute), + IntervalMinutes: 15, + }) + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "storage", + NextRunAt: now.Add(10 * time.Minute), + }) + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: now.Add(15 * time.Minute), + }) + + stats = pq.GetStats() + + if stats.Size != 3 { + t.Fatalf("expected size=3, got %d", stats.Size) + } + + if stats.NextRunAt == nil { + t.Fatal("NextRunAt should not be nil") + } + + // Should be the earliest job (5 minutes) + expectedNext := now.Add(5 * time.Minute) + if !stats.NextRunAt.Equal(expectedNext) { + t.Fatalf("expected NextRunAt=%v, got %v", expectedNext, stats.NextRunAt) + } + + // Check subsystem counts + if stats.JobsBySubsystem["updates"] != 2 { + t.Fatalf("expected 2 updates jobs, got %d", stats.JobsBySubsystem["updates"]) + } + if stats.JobsBySubsystem["storage"] != 1 { + t.Fatalf("expected 1 storage job, got %d", stats.JobsBySubsystem["storage"]) + } +} + +func TestPriorityQueue_Concurrency(t *testing.T) { + pq := NewPriorityQueue() + var wg sync.WaitGroup + + // Concurrent pushes + numGoroutines := 100 + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(idx int) { + defer wg.Done() + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: time.Now().Add(time.Duration(idx) * time.Second), + }) + }(i) + } + + wg.Wait() + + if pq.Len() != numGoroutines { + t.Fatalf("expected len=%d after concurrent pushes, got %d", numGoroutines, pq.Len()) + } + + // Concurrent pops + wg.Add(numGoroutines) + popped := make(chan *SubsystemJob, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + if job := pq.Pop(); job != nil { + popped <- job + } + }() + } + + wg.Wait() + close(popped) + + // Count popped jobs + count := 0 + for range popped { + count++ + } + + if count != numGoroutines { + t.Fatalf("expected %d popped jobs, got %d", numGoroutines, count) + } + + if pq.Len() != 0 { + t.Fatalf("expected empty queue after concurrent pops, got len=%d", pq.Len()) + } +} + +func TestPriorityQueue_ConcurrentReadWrite(t *testing.T) { + pq := NewPriorityQueue() + done := make(chan bool) + + // Writer goroutine + go func() { + for i := 0; i < 1000; i++ { + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: time.Now(), + }) + time.Sleep(1 * time.Microsecond) + } + done <- true + }() + + // Reader goroutine + go func() { + for i := 0; i < 1000; i++ { + pq.Peek() + pq.GetStats() + time.Sleep(1 * time.Microsecond) + } + done <- true + }() + + // Wait for both to complete + <-done + <-done + + // Should not panic and queue should be consistent + if pq.Len() < 0 { + t.Fatal("queue length became negative (race condition)") + } +} + +func BenchmarkPriorityQueue_Push(b *testing.B) { + pq := NewPriorityQueue() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: time.Now().Add(time.Duration(i) * time.Second), + }) + } +} + +func BenchmarkPriorityQueue_Pop(b *testing.B) { + pq := NewPriorityQueue() + + // Pre-fill the queue + for i := 0; i < b.N; i++ { + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: time.Now().Add(time.Duration(i) * time.Second), + }) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Pop() + } +} + +func BenchmarkPriorityQueue_Peek(b *testing.B) { + pq := NewPriorityQueue() + + // Pre-fill with 10000 jobs + for i := 0; i < 10000; i++ { + pq.Push(&SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + NextRunAt: time.Now().Add(time.Duration(i) * time.Second), + }) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Peek() + } +} diff --git a/aggregator-server/internal/scheduler/scheduler.go b/aggregator-server/internal/scheduler/scheduler.go new file mode 100644 index 0000000..56bf962 --- /dev/null +++ b/aggregator-server/internal/scheduler/scheduler.go @@ -0,0 +1,406 @@ +package scheduler + +import ( + "context" + "fmt" + "log" + "math/rand" + "sync" + "time" + + "github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries" + "github.com/Fimeg/RedFlag/aggregator-server/internal/models" + "github.com/google/uuid" +) + +// Config holds scheduler configuration +type Config struct { + // CheckInterval is how often to check the queue for due jobs + CheckInterval time.Duration + + // LookaheadWindow is how far ahead to look for jobs + // Jobs due within this window will be batched and jittered + LookaheadWindow time.Duration + + // MaxJitter is the maximum random delay added to job execution + MaxJitter time.Duration + + // NumWorkers is the number of parallel workers for command creation + NumWorkers int + + // BackpressureThreshold is max pending commands per agent before skipping + BackpressureThreshold int + + // RateLimitPerSecond is max commands created per second (0 = unlimited) + RateLimitPerSecond int +} + +// DefaultConfig returns production-ready default configuration +func DefaultConfig() Config { + return Config{ + CheckInterval: 10 * time.Second, + LookaheadWindow: 60 * time.Second, + MaxJitter: 30 * time.Second, + NumWorkers: 10, + BackpressureThreshold: 5, + RateLimitPerSecond: 100, + } +} + +// Scheduler manages subsystem job scheduling with priority queue and worker pool +type Scheduler struct { + config Config + queue *PriorityQueue + + // Database queries + agentQueries *queries.AgentQueries + commandQueries *queries.CommandQueries + + // Worker pool + jobChan chan *SubsystemJob + workers []*worker + + // Rate limiting + rateLimiter chan struct{} + + // Lifecycle management + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + shutdown chan struct{} + + // Metrics + mu sync.RWMutex + stats Stats +} + +// Stats holds scheduler statistics +type Stats struct { + JobsProcessed int64 + JobsSkipped int64 + CommandsCreated int64 + CommandsFailed int64 + BackpressureSkips int64 + LastProcessedAt time.Time + QueueSize int + WorkerPoolUtilized int + AverageProcessingMS int64 +} + +// NewScheduler creates a new scheduler instance +func NewScheduler(config Config, agentQueries *queries.AgentQueries, commandQueries *queries.CommandQueries) *Scheduler { + ctx, cancel := context.WithCancel(context.Background()) + + s := &Scheduler{ + config: config, + queue: NewPriorityQueue(), + agentQueries: agentQueries, + commandQueries: commandQueries, + jobChan: make(chan *SubsystemJob, 1000), // Buffer 1000 jobs + workers: make([]*worker, config.NumWorkers), + shutdown: make(chan struct{}), + ctx: ctx, + cancel: cancel, + } + + // Initialize rate limiter if configured + if config.RateLimitPerSecond > 0 { + s.rateLimiter = make(chan struct{}, config.RateLimitPerSecond) + go s.refillRateLimiter() + } + + // Initialize workers + for i := 0; i < config.NumWorkers; i++ { + s.workers[i] = &worker{ + id: i, + scheduler: s, + } + } + + return s +} + +// LoadSubsystems loads all enabled auto-run subsystems from database into queue +func (s *Scheduler) LoadSubsystems(ctx context.Context) error { + log.Println("[Scheduler] Loading subsystems from database...") + + // Get all agents (pass empty strings to get all agents regardless of status/os) + agents, err := s.agentQueries.ListAgents("", "") + if err != nil { + return fmt.Errorf("failed to get agents: %w", err) + } + + // For now, we'll create default subsystems for each agent + // In full implementation, this would read from agent_subsystems table + subsystems := []string{"updates", "storage", "system", "docker"} + intervals := map[string]int{ + "updates": 15, // 15 minutes + "storage": 15, + "system": 30, + "docker": 15, + } + + loaded := 0 + for _, agent := range agents { + // Skip offline agents (haven't checked in for 10+ minutes) + if time.Since(agent.LastSeen) > 10*time.Minute { + continue + } + + for _, subsystem := range subsystems { + // TODO: Check agent metadata for subsystem enablement + // For now, assume all subsystems are enabled + + job := &SubsystemJob{ + AgentID: agent.ID, + AgentHostname: agent.Hostname, + Subsystem: subsystem, + IntervalMinutes: intervals[subsystem], + NextRunAt: time.Now().Add(time.Duration(intervals[subsystem]) * time.Minute), + Enabled: true, + } + + s.queue.Push(job) + loaded++ + } + } + + log.Printf("[Scheduler] Loaded %d subsystem jobs for %d agents\n", loaded, len(agents)) + return nil +} + +// Start begins the scheduler main loop and workers +func (s *Scheduler) Start() error { + log.Printf("[Scheduler] Starting with %d workers, check interval %v\n", + s.config.NumWorkers, s.config.CheckInterval) + + // Start workers + for _, w := range s.workers { + s.wg.Add(1) + go w.run() + } + + // Start main loop + s.wg.Add(1) + go s.mainLoop() + + log.Println("[Scheduler] Started successfully") + return nil +} + +// Stop gracefully shuts down the scheduler +func (s *Scheduler) Stop() error { + log.Println("[Scheduler] Shutting down...") + + // Signal shutdown + s.cancel() + close(s.shutdown) + + // Close job channel (workers will drain and exit) + close(s.jobChan) + + // Wait for all goroutines with timeout + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Println("[Scheduler] Shutdown complete") + return nil + case <-time.After(30 * time.Second): + log.Println("[Scheduler] Shutdown timeout - forcing exit") + return fmt.Errorf("shutdown timeout") + } +} + +// mainLoop is the scheduler's main processing loop +func (s *Scheduler) mainLoop() { + defer s.wg.Done() + + ticker := time.NewTicker(s.config.CheckInterval) + defer ticker.Stop() + + log.Printf("[Scheduler] Main loop started (check every %v)\n", s.config.CheckInterval) + + for { + select { + case <-s.shutdown: + log.Println("[Scheduler] Main loop shutting down") + return + + case <-ticker.C: + s.processQueue() + } + } +} + +// processQueue checks for due jobs and dispatches them to workers +func (s *Scheduler) processQueue() { + start := time.Now() + + // Get all jobs due within lookahead window + cutoff := time.Now().Add(s.config.LookaheadWindow) + dueJobs := s.queue.PopBefore(cutoff, 0) // No limit, get all + + if len(dueJobs) == 0 { + // No jobs due, just update stats + s.mu.Lock() + s.stats.QueueSize = s.queue.Len() + s.mu.Unlock() + return + } + + log.Printf("[Scheduler] Processing %d jobs due before %s\n", + len(dueJobs), cutoff.Format("15:04:05")) + + // Add jitter to each job and dispatch to workers + dispatched := 0 + for _, job := range dueJobs { + // Add random jitter (0 to MaxJitter) + jitter := time.Duration(rand.Intn(int(s.config.MaxJitter.Seconds()))) * time.Second + job.NextRunAt = job.NextRunAt.Add(jitter) + + // Dispatch to worker pool (non-blocking) + select { + case s.jobChan <- job: + dispatched++ + default: + // Worker pool full, re-queue job + log.Printf("[Scheduler] Worker pool full, re-queueing %s\n", job.String()) + s.queue.Push(job) + + s.mu.Lock() + s.stats.JobsSkipped++ + s.mu.Unlock() + } + } + + // Update stats + duration := time.Since(start) + s.mu.Lock() + s.stats.JobsProcessed += int64(dispatched) + s.stats.LastProcessedAt = time.Now() + s.stats.QueueSize = s.queue.Len() + s.stats.WorkerPoolUtilized = len(s.jobChan) + s.stats.AverageProcessingMS = duration.Milliseconds() + s.mu.Unlock() + + log.Printf("[Scheduler] Dispatched %d jobs in %v (queue: %d remaining)\n", + dispatched, duration, s.queue.Len()) +} + +// refillRateLimiter continuously refills the rate limiter token bucket +func (s *Scheduler) refillRateLimiter() { + ticker := time.NewTicker(time.Second / time.Duration(s.config.RateLimitPerSecond)) + defer ticker.Stop() + + for { + select { + case <-s.shutdown: + return + case <-ticker.C: + // Try to add token (non-blocking) + select { + case s.rateLimiter <- struct{}{}: + default: + // Bucket full, skip + } + } + } +} + +// GetStats returns current scheduler statistics (thread-safe) +func (s *Scheduler) GetStats() Stats { + s.mu.RLock() + defer s.mu.RUnlock() + return s.stats +} + +// GetQueueStats returns current queue statistics +func (s *Scheduler) GetQueueStats() QueueStats { + return s.queue.GetStats() +} + +// worker processes jobs from the job channel +type worker struct { + id int + scheduler *Scheduler +} + +func (w *worker) run() { + defer w.scheduler.wg.Done() + + log.Printf("[Worker %d] Started\n", w.id) + + for job := range w.scheduler.jobChan { + if err := w.processJob(job); err != nil { + log.Printf("[Worker %d] Failed to process %s: %v\n", w.id, job.String(), err) + + w.scheduler.mu.Lock() + w.scheduler.stats.CommandsFailed++ + w.scheduler.mu.Unlock() + } else { + w.scheduler.mu.Lock() + w.scheduler.stats.CommandsCreated++ + w.scheduler.mu.Unlock() + } + + // Re-queue job for next execution + job.NextRunAt = time.Now().Add(time.Duration(job.IntervalMinutes) * time.Minute) + w.scheduler.queue.Push(job) + } + + log.Printf("[Worker %d] Stopped\n", w.id) +} + +func (w *worker) processJob(job *SubsystemJob) error { + // Apply rate limiting if configured + if w.scheduler.rateLimiter != nil { + select { + case <-w.scheduler.rateLimiter: + // Token acquired + case <-w.scheduler.shutdown: + return fmt.Errorf("shutdown during rate limit wait") + } + } + + // Check backpressure: skip if agent has too many pending commands + pendingCount, err := w.scheduler.commandQueries.CountPendingCommandsForAgent(job.AgentID) + if err != nil { + return fmt.Errorf("failed to check pending commands: %w", err) + } + + if pendingCount >= w.scheduler.config.BackpressureThreshold { + log.Printf("[Worker %d] Backpressure: agent %s has %d pending commands, skipping %s\n", + w.id, job.AgentHostname, pendingCount, job.Subsystem) + + w.scheduler.mu.Lock() + w.scheduler.stats.BackpressureSkips++ + w.scheduler.mu.Unlock() + + return nil // Not an error, just skipped + } + + // Create command + cmd := &models.AgentCommand{ + ID: uuid.New(), + AgentID: job.AgentID, + CommandType: fmt.Sprintf("scan_%s", job.Subsystem), + Params: models.JSONB{}, + Status: models.CommandStatusPending, + Source: models.CommandSourceSystem, + CreatedAt: time.Now(), + } + + if err := w.scheduler.commandQueries.CreateCommand(cmd); err != nil { + return fmt.Errorf("failed to create command: %w", err) + } + + log.Printf("[Worker %d] Created %s command for %s\n", + w.id, job.Subsystem, job.AgentHostname) + + return nil +} diff --git a/aggregator-server/internal/scheduler/scheduler_test.go b/aggregator-server/internal/scheduler/scheduler_test.go new file mode 100644 index 0000000..7bfc149 --- /dev/null +++ b/aggregator-server/internal/scheduler/scheduler_test.go @@ -0,0 +1,323 @@ +package scheduler + +import ( + "testing" + "time" + + "github.com/google/uuid" +) + +func TestScheduler_NewScheduler(t *testing.T) { + config := DefaultConfig() + s := NewScheduler(config, nil, nil) + + if s == nil { + t.Fatal("NewScheduler returned nil") + } + + if s.config.NumWorkers != 10 { + t.Fatalf("expected 10 workers, got %d", s.config.NumWorkers) + } + + if s.queue == nil { + t.Fatal("queue not initialized") + } + + if len(s.workers) != config.NumWorkers { + t.Fatalf("expected %d workers, got %d", config.NumWorkers, len(s.workers)) + } +} + +func TestScheduler_DefaultConfig(t *testing.T) { + config := DefaultConfig() + + if config.CheckInterval != 10*time.Second { + t.Fatalf("expected check interval 10s, got %v", config.CheckInterval) + } + + if config.LookaheadWindow != 60*time.Second { + t.Fatalf("expected lookahead 60s, got %v", config.LookaheadWindow) + } + + if config.MaxJitter != 30*time.Second { + t.Fatalf("expected max jitter 30s, got %v", config.MaxJitter) + } + + if config.NumWorkers != 10 { + t.Fatalf("expected 10 workers, got %d", config.NumWorkers) + } + + if config.BackpressureThreshold != 5 { + t.Fatalf("expected backpressure threshold 5, got %d", config.BackpressureThreshold) + } + + if config.RateLimitPerSecond != 100 { + t.Fatalf("expected rate limit 100/s, got %d", config.RateLimitPerSecond) + } +} + +func TestScheduler_QueueIntegration(t *testing.T) { + config := DefaultConfig() + s := NewScheduler(config, nil, nil) + + // Add jobs to queue + agent1 := uuid.New() + agent2 := uuid.New() + + job1 := &SubsystemJob{ + AgentID: agent1, + AgentHostname: "agent-01", + Subsystem: "updates", + IntervalMinutes: 15, + NextRunAt: time.Now().Add(5 * time.Minute), + } + + job2 := &SubsystemJob{ + AgentID: agent2, + AgentHostname: "agent-02", + Subsystem: "storage", + IntervalMinutes: 15, + NextRunAt: time.Now().Add(10 * time.Minute), + } + + s.queue.Push(job1) + s.queue.Push(job2) + + if s.queue.Len() != 2 { + t.Fatalf("expected queue len 2, got %d", s.queue.Len()) + } + + // Get stats + stats := s.GetQueueStats() + if stats.Size != 2 { + t.Fatalf("expected stats size 2, got %d", stats.Size) + } +} + +func TestScheduler_GetStats(t *testing.T) { + config := DefaultConfig() + s := NewScheduler(config, nil, nil) + + // Initial stats should be zero + stats := s.GetStats() + + if stats.JobsProcessed != 0 { + t.Fatalf("expected 0 jobs processed, got %d", stats.JobsProcessed) + } + + if stats.CommandsCreated != 0 { + t.Fatalf("expected 0 commands created, got %d", stats.CommandsCreated) + } + + if stats.BackpressureSkips != 0 { + t.Fatalf("expected 0 backpressure skips, got %d", stats.BackpressureSkips) + } + + // Manually update stats (simulating processing) + s.mu.Lock() + s.stats.JobsProcessed = 100 + s.stats.CommandsCreated = 95 + s.stats.BackpressureSkips = 5 + s.mu.Unlock() + + stats = s.GetStats() + + if stats.JobsProcessed != 100 { + t.Fatalf("expected 100 jobs processed, got %d", stats.JobsProcessed) + } + + if stats.CommandsCreated != 95 { + t.Fatalf("expected 95 commands created, got %d", stats.CommandsCreated) + } + + if stats.BackpressureSkips != 5 { + t.Fatalf("expected 5 backpressure skips, got %d", stats.BackpressureSkips) + } +} + +func TestScheduler_StartStop(t *testing.T) { + config := Config{ + CheckInterval: 100 * time.Millisecond, // Fast for testing + LookaheadWindow: 60 * time.Second, + MaxJitter: 1 * time.Second, + NumWorkers: 2, + BackpressureThreshold: 5, + RateLimitPerSecond: 0, // Disable rate limiting for test + } + + s := NewScheduler(config, nil, nil) + + // Start scheduler + err := s.Start() + if err != nil { + t.Fatalf("failed to start scheduler: %v", err) + } + + // Let it run for a bit + time.Sleep(500 * time.Millisecond) + + // Stop scheduler + err = s.Stop() + if err != nil { + t.Fatalf("failed to stop scheduler: %v", err) + } + + // Should stop cleanly +} + +func TestScheduler_ProcessQueueEmpty(t *testing.T) { + config := DefaultConfig() + s := NewScheduler(config, nil, nil) + + // Process empty queue should not panic + s.processQueue() + + stats := s.GetStats() + if stats.JobsProcessed != 0 { + t.Fatalf("expected 0 jobs processed on empty queue, got %d", stats.JobsProcessed) + } +} + +func TestScheduler_ProcessQueueWithJobs(t *testing.T) { + config := Config{ + CheckInterval: 1 * time.Second, + LookaheadWindow: 60 * time.Second, + MaxJitter: 5 * time.Second, + NumWorkers: 2, + BackpressureThreshold: 5, + RateLimitPerSecond: 0, // Disable for test + } + + s := NewScheduler(config, nil, nil) + + // Add jobs that are due now + for i := 0; i < 5; i++ { + job := &SubsystemJob{ + AgentID: uuid.New(), + AgentHostname: "test-agent", + Subsystem: "updates", + IntervalMinutes: 15, + NextRunAt: time.Now(), // Due now + } + s.queue.Push(job) + } + + if s.queue.Len() != 5 { + t.Fatalf("expected 5 jobs in queue, got %d", s.queue.Len()) + } + + // Process the queue + s.processQueue() + + // Jobs should be dispatched to job channel + // Note: Without database, workers can't actually process them + // But we can verify they were dispatched + + stats := s.GetStats() + if stats.JobsProcessed == 0 { + t.Fatal("expected some jobs to be processed") + } +} + +func TestScheduler_RateLimiterRefill(t *testing.T) { + config := Config{ + CheckInterval: 1 * time.Second, + LookaheadWindow: 60 * time.Second, + MaxJitter: 1 * time.Second, + NumWorkers: 2, + BackpressureThreshold: 5, + RateLimitPerSecond: 10, // 10 tokens per second + } + + s := NewScheduler(config, nil, nil) + + if s.rateLimiter == nil { + t.Fatal("rate limiter not initialized") + } + + // Start refill goroutine + go s.refillRateLimiter() + + // Wait for some tokens to be added + time.Sleep(200 * time.Millisecond) + + // Should have some tokens available + tokensAvailable := 0 + for i := 0; i < 15; i++ { + select { + case <-s.rateLimiter: + tokensAvailable++ + default: + break + } + } + + if tokensAvailable == 0 { + t.Fatal("expected some tokens to be available after refill") + } + + // Should not exceed buffer size (10) + if tokensAvailable > 10 { + t.Fatalf("token bucket overflowed: got %d tokens, max is 10", tokensAvailable) + } +} + +func TestScheduler_ConcurrentQueueAccess(t *testing.T) { + config := DefaultConfig() + s := NewScheduler(config, nil, nil) + + done := make(chan bool) + + // Concurrent pushes + go func() { + for i := 0; i < 100; i++ { + job := &SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + IntervalMinutes: 15, + NextRunAt: time.Now(), + } + s.queue.Push(job) + } + done <- true + }() + + // Concurrent stats reads + go func() { + for i := 0; i < 100; i++ { + s.GetStats() + s.GetQueueStats() + } + done <- true + }() + + // Wait for both + <-done + <-done + + // Should not panic and should have queued jobs + if s.queue.Len() <= 0 { + t.Fatal("expected jobs in queue after concurrent pushes") + } +} + +func BenchmarkScheduler_ProcessQueue(b *testing.B) { + config := DefaultConfig() + s := NewScheduler(config, nil, nil) + + // Pre-fill queue with jobs + for i := 0; i < 1000; i++ { + job := &SubsystemJob{ + AgentID: uuid.New(), + Subsystem: "updates", + IntervalMinutes: 15, + NextRunAt: time.Now(), + } + s.queue.Push(job) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.processQueue() + } +} diff --git a/aggregator-web/src/components/AgentStorage.tsx b/aggregator-web/src/components/AgentStorage.tsx index e033063..a9aa931 100644 --- a/aggregator-web/src/components/AgentStorage.tsx +++ b/aggregator-web/src/components/AgentStorage.tsx @@ -12,6 +12,7 @@ import { Info, TrendingUp, Server, + MemoryStick, } from 'lucide-react'; import { formatBytes, formatRelativeTime } from '@/lib/utils'; import { agentApi } from '@/lib/api'; @@ -160,120 +161,85 @@ export function AgentStorage({ agentId }: AgentStorageProps) { - {/* Simple list - no boxes, just clean rows */} -
- {/* Memory */} + {/* Memory & Disk - matching Overview styling */} +
+ {/* Memory - GREEN to differentiate from disks */} {storageMetrics && storageMetrics.memory_total_gb > 0 && ( -
-
- Memory - - {storageMetrics.memory_used_gb.toFixed(1)} / {storageMetrics.memory_total_gb.toFixed(1)} GB - ({storageMetrics.memory_percent.toFixed(0)}%) - +
+
+

+ + Memory +

+

+ {storageMetrics.memory_used_gb.toFixed(1)} GB / {storageMetrics.memory_total_gb.toFixed(1)} GB +

-
+
+

+ {storageMetrics.memory_percent.toFixed(0)}% used +

)} - {/* Root Disk */} - {storageMetrics && storageMetrics.disk_total_gb > 0 && ( -
-
- Root filesystem - - {storageMetrics.disk_used_gb.toFixed(1)} / {storageMetrics.disk_total_gb.toFixed(1)} GB - ({storageMetrics.disk_percent.toFixed(0)}%) - + {/* All Disks from system_info.disk_info - BLUE matching Overview */} + {disks.length > 0 && disks.map((disk, index) => ( +
+
+

+ + Disk ({disk.mountpoint}) +

+

+ {formatBytes(disk.used)} / {formatBytes(disk.total)} +

-
+
+
+

+ {disk.used_percent.toFixed(0)}% used +

+
+ ))} + + {/* Fallback if no disk array but we have metadata */} + {disks.length === 0 && storageMetrics && storageMetrics.disk_total_gb > 0 && ( +
+
+

+ + Disk (/) +

+

+ {storageMetrics.disk_used_gb.toFixed(1)} GB / {storageMetrics.disk_total_gb.toFixed(1)} GB +

+
+
+
-
- )} - - {/* Largest disk if different */} - {storageMetrics && storageMetrics.largest_disk_total_gb > 0 && storageMetrics.largest_disk_mount !== '/' && ( -
-
- {storageMetrics.largest_disk_mount} - - {storageMetrics.largest_disk_used_gb.toFixed(1)} / {storageMetrics.largest_disk_total_gb.toFixed(1)} GB - ({storageMetrics.largest_disk_percent.toFixed(0)}%) - -
-
-
-
+

+ {storageMetrics.disk_percent.toFixed(0)}% used +

)}
- {/* All partitions - minimal table */} - {disks.length > 0 && ( -
-

All partitions

-
- - - - - - - - - - - - - {disks.map((disk, index) => ( - - - - - - - - - ))} - -
MountDeviceTypeUsedTotalUsage
-
- {disk.mountpoint} - {disk.is_root && root} -
-
{disk.device}{disk.disk_type}{formatBytes(disk.used)}{formatBytes(disk.total)} -
- {disk.used_percent.toFixed(0)}% -
-
-
-
-
-
-
- )} - - {/* Last updated - minimal */} - {agentData && ( -
- Last updated {agentData.last_seen ? formatRelativeTime(agentData.last_seen) : 'unknown'} -
- )} + {/* Refresh info */} +
+ Auto-refreshes every 30 seconds • Last updated {agentData?.last_seen ? formatRelativeTime(agentData.last_seen) : 'unknown'} +
); } diff --git a/test_disk_detection.go b/test_disk_detection.go new file mode 100644 index 0000000..4a5d09b --- /dev/null +++ b/test_disk_detection.go @@ -0,0 +1,62 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "path/filepath" + "runtime" + + "github.com/redflag-aggregator/aggregator-agent/internal/system" +) + +func main() { + // Get the absolute path to this file's directory + _, filename, _, _ := runtime.Caller(0) + dir := filepath.Dir(filename) + + // Change to the project root to find the go.mod file + projectRoot := filepath.Dir(dir) + + // Test lightweight metrics (most common use case) + fmt.Println("=== Enhanced Lightweight Metrics Test ===") + metrics, err := system.GetLightweightMetrics() + if err != nil { + log.Printf("Error getting lightweight metrics: %v", err) + } else { + // Pretty print the JSON + jsonData, _ := json.MarshalIndent(metrics, "", " ") + fmt.Printf("LightweightMetrics:\n%s\n\n", jsonData) + + // Show key findings + fmt.Printf("Root Disk: %.1fGB used / %.1fGB total (%.1f%%)\n", + metrics.DiskUsedGB, metrics.DiskTotalGB, metrics.DiskPercent) + + if metrics.LargestDiskTotalGB > 0 { + fmt.Printf("Largest Disk (%s): %.1fGB used / %.1fGB total (%.1f%%)\n", + metrics.LargestDiskMount, metrics.LargestDiskUsedGB, metrics.LargestDiskTotalGB, metrics.LargestDiskPercent) + } + } + + // Test full system info (detailed disk inventory) + fmt.Println("\n=== Enhanced System Info Test ===") + sysInfo, err := system.GetSystemInfo("test-v0.1.5") + if err != nil { + log.Printf("Error getting system info: %v", err) + } else { + fmt.Printf("Found %d disks:\n", len(sysInfo.DiskInfo)) + for i, disk := range sysInfo.DiskInfo { + fmt.Printf(" Disk %d: %s (%s) - %s, %.1fGB used / %.1fGB total (%.1f%%)", + i+1, disk.Mountpoint, disk.Filesystem, disk.DiskType, + float64(disk.Used)/(1024*1024*1024), float64(disk.Total)/(1024*1024*1024), disk.UsedPercent) + + if disk.IsRoot { + fmt.Printf(" [ROOT]") + } + if disk.IsLargest { + fmt.Printf(" [LARGEST]") + } + fmt.Printf("\n") + } + } +} \ No newline at end of file