- Fix recursive call in reportLogWithAck that caused an infinite loop
- Add machine-binding and security API endpoints
- Enhance the AgentScanners component with a security status display
- Improve scheduler and timeout-service reliability
- Remove the deprecated install.sh script
- Add subsystem configuration and logging improvements
441 lines
11 KiB
Go
441 lines
11 KiB
Go
package scheduler
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"math/rand"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries"
|
|
"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// Config holds scheduler configuration.
type Config struct {
	// CheckInterval is how often to check the queue for due jobs.
	CheckInterval time.Duration

	// LookaheadWindow is how far ahead to look for jobs.
	// Jobs due within this window will be batched and jittered.
	LookaheadWindow time.Duration

	// MaxJitter is the maximum random delay added to job execution,
	// used to spread dispatches so batched jobs do not all fire at once.
	MaxJitter time.Duration

	// NumWorkers is the number of parallel workers for command creation.
	NumWorkers int

	// BackpressureThreshold is the max pending commands per agent before skipping.
	BackpressureThreshold int

	// RateLimitPerSecond is the max commands created per second (0 = unlimited).
	RateLimitPerSecond int
}
|
|
|
|
// DefaultConfig returns production-ready default configuration
|
|
func DefaultConfig() Config {
|
|
return Config{
|
|
CheckInterval: 10 * time.Second,
|
|
LookaheadWindow: 60 * time.Second,
|
|
MaxJitter: 30 * time.Second,
|
|
NumWorkers: 10,
|
|
BackpressureThreshold: 5,
|
|
RateLimitPerSecond: 100,
|
|
}
|
|
}
|
|
|
|
// Scheduler manages subsystem job scheduling with a priority queue and a
// worker pool. Jobs are loaded from the database, dispatched with jitter,
// and re-queued for their next interval after each run.
type Scheduler struct {
	config Config
	queue  *PriorityQueue

	// Database queries used to enumerate agents, create commands, and
	// read per-agent subsystem settings.
	agentQueries     *queries.AgentQueries
	commandQueries   *queries.CommandQueries
	subsystemQueries *queries.SubsystemQueries

	// Worker pool: due jobs are pushed onto jobChan and drained by workers.
	jobChan chan *SubsystemJob
	workers []*worker

	// Rate limiting: token bucket refilled by refillRateLimiter; nil when
	// RateLimitPerSecond is 0 (unlimited).
	rateLimiter chan struct{}

	// Lifecycle management: ctx/cancel and the shutdown channel both signal
	// stop; wg tracks the main loop and all workers.
	ctx      context.Context
	cancel   context.CancelFunc
	wg       sync.WaitGroup
	shutdown chan struct{}

	// Metrics, guarded by mu.
	mu    sync.RWMutex
	stats Stats
}
|
|
|
|
// Stats holds scheduler statistics. All counters are cumulative since start
// except QueueSize, WorkerPoolUtilized, and AverageProcessingMS, which are
// snapshots from the most recent processQueue pass.
type Stats struct {
	JobsProcessed       int64     // jobs dispatched to the worker pool
	JobsSkipped         int64     // jobs re-queued because the pool was full
	CommandsCreated     int64     // commands successfully written by workers
	CommandsFailed      int64     // worker job-processing failures
	BackpressureSkips   int64     // jobs skipped due to pending-command cap
	LastProcessedAt     time.Time // wall time of the last processQueue pass
	QueueSize           int       // queue length after the last pass
	WorkerPoolUtilized  int       // jobChan occupancy after the last pass
	AverageProcessingMS int64     // duration of the last pass in milliseconds
}
|
|
|
|
// NewScheduler creates a new scheduler instance
|
|
func NewScheduler(config Config, agentQueries *queries.AgentQueries, commandQueries *queries.CommandQueries, subsystemQueries *queries.SubsystemQueries) *Scheduler {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
s := &Scheduler{
|
|
config: config,
|
|
queue: NewPriorityQueue(),
|
|
agentQueries: agentQueries,
|
|
commandQueries: commandQueries,
|
|
subsystemQueries: subsystemQueries,
|
|
jobChan: make(chan *SubsystemJob, 1000), // Buffer 1000 jobs
|
|
workers: make([]*worker, config.NumWorkers),
|
|
shutdown: make(chan struct{}),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
// Initialize rate limiter if configured
|
|
if config.RateLimitPerSecond > 0 {
|
|
s.rateLimiter = make(chan struct{}, config.RateLimitPerSecond)
|
|
go s.refillRateLimiter()
|
|
}
|
|
|
|
// Initialize workers
|
|
for i := 0; i < config.NumWorkers; i++ {
|
|
s.workers[i] = &worker{
|
|
id: i,
|
|
scheduler: s,
|
|
}
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
// LoadSubsystems loads all enabled auto-run subsystems from database into queue
|
|
func (s *Scheduler) LoadSubsystems(ctx context.Context) error {
|
|
log.Println("[Scheduler] Loading subsystems from database...")
|
|
|
|
// Get all agents (pass empty strings to get all agents regardless of status/os)
|
|
agents, err := s.agentQueries.ListAgents("", "")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get agents: %w", err)
|
|
}
|
|
|
|
loaded := 0
|
|
for _, agent := range agents {
|
|
// Skip offline agents (haven't checked in for 10+ minutes)
|
|
if time.Since(agent.LastSeen) > 10*time.Minute {
|
|
continue
|
|
}
|
|
|
|
// Get subsystems from database (respect user settings)
|
|
dbSubsystems, err := s.subsystemQueries.GetSubsystems(agent.ID)
|
|
if err != nil {
|
|
log.Printf("[Scheduler] Failed to get subsystems for agent %s: %v", agent.Hostname, err)
|
|
continue
|
|
}
|
|
|
|
// Create jobs only for enabled subsystems with auto_run=true
|
|
for _, dbSub := range dbSubsystems {
|
|
if dbSub.Enabled && dbSub.AutoRun {
|
|
// Use database interval, fallback to default
|
|
intervalMinutes := dbSub.IntervalMinutes
|
|
if intervalMinutes <= 0 {
|
|
intervalMinutes = s.getDefaultInterval(dbSub.Subsystem)
|
|
}
|
|
|
|
var nextRun time.Time
|
|
if dbSub.NextRunAt != nil {
|
|
nextRun = *dbSub.NextRunAt
|
|
} else {
|
|
// If no next run is set, schedule it with default interval
|
|
nextRun = time.Now().Add(time.Duration(intervalMinutes) * time.Minute)
|
|
}
|
|
|
|
job := &SubsystemJob{
|
|
AgentID: agent.ID,
|
|
AgentHostname: agent.Hostname,
|
|
Subsystem: dbSub.Subsystem,
|
|
IntervalMinutes: intervalMinutes,
|
|
NextRunAt: nextRun,
|
|
Enabled: dbSub.Enabled,
|
|
}
|
|
|
|
s.queue.Push(job)
|
|
loaded++
|
|
}
|
|
}
|
|
}
|
|
|
|
log.Printf("[Scheduler] Loaded %d subsystem jobs for %d agents (respecting database settings)\n", loaded, len(agents))
|
|
return nil
|
|
}
|
|
|
|
// getDefaultInterval returns default interval minutes for a subsystem
|
|
// TODO: These intervals need to correlate with agent health scanning settings
|
|
// Each subsystem should be variable based on user-configurable agent health policies
|
|
func (s *Scheduler) getDefaultInterval(subsystem string) int {
|
|
defaults := map[string]int{
|
|
"apt": 30, // 30 minutes
|
|
"dnf": 240, // 4 hours
|
|
"docker": 120, // 2 hours
|
|
"storage": 360, // 6 hours
|
|
"windows": 480, // 8 hours
|
|
"winget": 360, // 6 hours
|
|
"updates": 15, // 15 minutes
|
|
"system": 30, // 30 minutes
|
|
}
|
|
|
|
if interval, exists := defaults[subsystem]; exists {
|
|
return interval
|
|
}
|
|
return 30 // Default fallback
|
|
}
|
|
|
|
// Start begins the scheduler main loop and workers
|
|
func (s *Scheduler) Start() error {
|
|
log.Printf("[Scheduler] Starting with %d workers, check interval %v\n",
|
|
s.config.NumWorkers, s.config.CheckInterval)
|
|
|
|
// Start workers
|
|
for _, w := range s.workers {
|
|
s.wg.Add(1)
|
|
go w.run()
|
|
}
|
|
|
|
// Start main loop
|
|
s.wg.Add(1)
|
|
go s.mainLoop()
|
|
|
|
log.Println("[Scheduler] Started successfully")
|
|
return nil
|
|
}
|
|
|
|
// Stop gracefully shuts down the scheduler: it signals all goroutines,
// closes the job channel so workers drain and exit, and waits up to 30
// seconds for everything to finish.
func (s *Scheduler) Stop() error {
	log.Println("[Scheduler] Shutting down...")

	// Signal shutdown
	s.cancel()
	close(s.shutdown)

	// Close job channel (workers will drain and exit)
	// NOTE(review): jobChan is closed here while mainLoop may still be inside
	// processQueue, which performs a non-blocking send on jobChan; a send on
	// a closed channel panics. Looks like a shutdown race — confirm mainLoop
	// has observed the shutdown signal before closing, or guard the dispatch.
	close(s.jobChan)

	// Wait for all goroutines with timeout
	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		log.Println("[Scheduler] Shutdown complete")
		return nil
	case <-time.After(30 * time.Second):
		// Goroutines may still be running past this point; we only stop waiting.
		log.Println("[Scheduler] Shutdown timeout - forcing exit")
		return fmt.Errorf("shutdown timeout")
	}
}
|
|
|
|
// mainLoop is the scheduler's main processing loop
|
|
func (s *Scheduler) mainLoop() {
|
|
defer s.wg.Done()
|
|
|
|
ticker := time.NewTicker(s.config.CheckInterval)
|
|
defer ticker.Stop()
|
|
|
|
log.Printf("[Scheduler] Main loop started (check every %v)\n", s.config.CheckInterval)
|
|
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
log.Println("[Scheduler] Main loop shutting down")
|
|
return
|
|
|
|
case <-ticker.C:
|
|
s.processQueue()
|
|
}
|
|
}
|
|
}
|
|
|
|
// processQueue checks for due jobs and dispatches them to workers
|
|
func (s *Scheduler) processQueue() {
|
|
start := time.Now()
|
|
|
|
// Get all jobs due within lookahead window
|
|
cutoff := time.Now().Add(s.config.LookaheadWindow)
|
|
dueJobs := s.queue.PopBefore(cutoff, 0) // No limit, get all
|
|
|
|
if len(dueJobs) == 0 {
|
|
// No jobs due, just update stats
|
|
s.mu.Lock()
|
|
s.stats.QueueSize = s.queue.Len()
|
|
s.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
log.Printf("[Scheduler] Processing %d jobs due before %s\n",
|
|
len(dueJobs), cutoff.Format("15:04:05"))
|
|
|
|
// Add jitter to each job and dispatch to workers
|
|
dispatched := 0
|
|
for _, job := range dueJobs {
|
|
// Add random jitter (0 to MaxJitter)
|
|
jitter := time.Duration(rand.Intn(int(s.config.MaxJitter.Seconds()))) * time.Second
|
|
job.NextRunAt = job.NextRunAt.Add(jitter)
|
|
|
|
// Dispatch to worker pool (non-blocking)
|
|
select {
|
|
case s.jobChan <- job:
|
|
dispatched++
|
|
default:
|
|
// Worker pool full, re-queue job
|
|
log.Printf("[Scheduler] Worker pool full, re-queueing %s\n", job.String())
|
|
s.queue.Push(job)
|
|
|
|
s.mu.Lock()
|
|
s.stats.JobsSkipped++
|
|
s.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// Update stats
|
|
duration := time.Since(start)
|
|
s.mu.Lock()
|
|
s.stats.JobsProcessed += int64(dispatched)
|
|
s.stats.LastProcessedAt = time.Now()
|
|
s.stats.QueueSize = s.queue.Len()
|
|
s.stats.WorkerPoolUtilized = len(s.jobChan)
|
|
s.stats.AverageProcessingMS = duration.Milliseconds()
|
|
s.mu.Unlock()
|
|
|
|
log.Printf("[Scheduler] Dispatched %d jobs in %v (queue: %d remaining)\n",
|
|
dispatched, duration, s.queue.Len())
|
|
}
|
|
|
|
// refillRateLimiter continuously refills the rate limiter token bucket
|
|
func (s *Scheduler) refillRateLimiter() {
|
|
ticker := time.NewTicker(time.Second / time.Duration(s.config.RateLimitPerSecond))
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
return
|
|
case <-ticker.C:
|
|
// Try to add token (non-blocking)
|
|
select {
|
|
case s.rateLimiter <- struct{}{}:
|
|
default:
|
|
// Bucket full, skip
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetStats returns current scheduler statistics (thread-safe)
|
|
func (s *Scheduler) GetStats() Stats {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.stats
|
|
}
|
|
|
|
// GetQueueStats returns current queue statistics, delegating to the
// underlying priority queue.
func (s *Scheduler) GetQueueStats() QueueStats {
	return s.queue.GetStats()
}
|
|
|
|
// worker processes jobs from the scheduler's job channel.
type worker struct {
	id        int        // worker index, used only for log prefixes
	scheduler *Scheduler // owning scheduler (queue, stats, config, channels)
}
|
|
|
|
func (w *worker) run() {
|
|
defer w.scheduler.wg.Done()
|
|
|
|
log.Printf("[Worker %d] Started\n", w.id)
|
|
|
|
for job := range w.scheduler.jobChan {
|
|
if err := w.processJob(job); err != nil {
|
|
log.Printf("[Worker %d] Failed to process %s: %v\n", w.id, job.String(), err)
|
|
|
|
w.scheduler.mu.Lock()
|
|
w.scheduler.stats.CommandsFailed++
|
|
w.scheduler.mu.Unlock()
|
|
} else {
|
|
w.scheduler.mu.Lock()
|
|
w.scheduler.stats.CommandsCreated++
|
|
w.scheduler.mu.Unlock()
|
|
}
|
|
|
|
// Re-queue job for next execution
|
|
job.NextRunAt = time.Now().Add(time.Duration(job.IntervalMinutes) * time.Minute)
|
|
w.scheduler.queue.Push(job)
|
|
}
|
|
|
|
log.Printf("[Worker %d] Stopped\n", w.id)
|
|
}
|
|
|
|
func (w *worker) processJob(job *SubsystemJob) error {
|
|
// Apply rate limiting if configured
|
|
if w.scheduler.rateLimiter != nil {
|
|
select {
|
|
case <-w.scheduler.rateLimiter:
|
|
// Token acquired
|
|
case <-w.scheduler.shutdown:
|
|
return fmt.Errorf("shutdown during rate limit wait")
|
|
}
|
|
}
|
|
|
|
// Check backpressure: skip if agent has too many pending commands
|
|
pendingCount, err := w.scheduler.commandQueries.CountPendingCommandsForAgent(job.AgentID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check pending commands: %w", err)
|
|
}
|
|
|
|
if pendingCount >= w.scheduler.config.BackpressureThreshold {
|
|
log.Printf("[Worker %d] Backpressure: agent %s has %d pending commands, skipping %s\n",
|
|
w.id, job.AgentHostname, pendingCount, job.Subsystem)
|
|
|
|
w.scheduler.mu.Lock()
|
|
w.scheduler.stats.BackpressureSkips++
|
|
w.scheduler.mu.Unlock()
|
|
|
|
return nil // Not an error, just skipped
|
|
}
|
|
|
|
// Create command
|
|
cmd := &models.AgentCommand{
|
|
ID: uuid.New(),
|
|
AgentID: job.AgentID,
|
|
CommandType: fmt.Sprintf("scan_%s", job.Subsystem),
|
|
Params: models.JSONB{},
|
|
Status: models.CommandStatusPending,
|
|
Source: models.CommandSourceSystem,
|
|
CreatedAt: time.Now(),
|
|
}
|
|
|
|
if err := w.scheduler.commandQueries.CreateCommand(cmd); err != nil {
|
|
return fmt.Errorf("failed to create command: %w", err)
|
|
}
|
|
|
|
log.Printf("[Worker %d] Created %s command for %s\n",
|
|
w.id, job.Subsystem, job.AgentHostname)
|
|
|
|
return nil
|
|
}
|