// Package scheduler manages recurring subsystem scan jobs using an
// in-memory priority queue, a worker pool that turns due jobs into agent
// commands, and optional token-bucket rate limiting.
package scheduler

import (
	"context"
	"fmt"
	"log"
	"math/rand"
	"sync"
	"time"

	"github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries"
	"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
	"github.com/google/uuid"
)

// Config holds scheduler configuration
type Config struct {
	// CheckInterval is how often to check the queue for due jobs
	CheckInterval time.Duration

	// LookaheadWindow is how far ahead to look for jobs
	// Jobs due within this window will be batched and jittered
	LookaheadWindow time.Duration

	// MaxJitter is the maximum random delay added to job execution
	MaxJitter time.Duration

	// NumWorkers is the number of parallel workers for command creation
	NumWorkers int

	// BackpressureThreshold is max pending commands per agent before skipping
	BackpressureThreshold int

	// RateLimitPerSecond is max commands created per second (0 = unlimited)
	RateLimitPerSecond int
}

// DefaultConfig returns production-ready default configuration
func DefaultConfig() Config {
	return Config{
		CheckInterval:         10 * time.Second,
		LookaheadWindow:       60 * time.Second,
		MaxJitter:             30 * time.Second,
		NumWorkers:            10,
		BackpressureThreshold: 5,
		RateLimitPerSecond:    100,
	}
}

// Scheduler manages subsystem job scheduling with priority queue and worker pool
type Scheduler struct {
	config Config
	queue  *PriorityQueue

	// Database queries
	agentQueries   *queries.AgentQueries
	commandQueries *queries.CommandQueries

	// Worker pool. jobChan is owned by mainLoop: it is the only sender
	// and therefore the only goroutine allowed to close it.
	jobChan chan *SubsystemJob
	workers []*worker

	// Rate limiting (token bucket; nil when RateLimitPerSecond == 0)
	rateLimiter chan struct{}

	// Lifecycle management
	ctx      context.Context
	cancel   context.CancelFunc
	wg       sync.WaitGroup
	shutdown chan struct{}

	// Metrics. mu guards stats only.
	mu    sync.RWMutex
	stats Stats
}

// Stats holds scheduler statistics
type Stats struct {
	JobsProcessed       int64
	JobsSkipped         int64
	CommandsCreated     int64
	CommandsFailed      int64
	BackpressureSkips   int64
	LastProcessedAt     time.Time
	QueueSize           int
	WorkerPoolUtilized  int
	AverageProcessingMS int64
}

// NewScheduler creates a new scheduler instance. The scheduler is idle
// until Start is called; the rate-limiter refill goroutine (if enabled)
// starts immediately and exits when the shutdown channel closes.
func NewScheduler(config Config,
	agentQueries *queries.AgentQueries, commandQueries *queries.CommandQueries) *Scheduler {
	ctx, cancel := context.WithCancel(context.Background())

	s := &Scheduler{
		config:         config,
		queue:          NewPriorityQueue(),
		agentQueries:   agentQueries,
		commandQueries: commandQueries,
		jobChan:        make(chan *SubsystemJob, 1000), // Buffer 1000 jobs
		workers:        make([]*worker, config.NumWorkers),
		shutdown:       make(chan struct{}),
		ctx:            ctx,
		cancel:         cancel,
	}

	// Initialize rate limiter if configured
	if config.RateLimitPerSecond > 0 {
		s.rateLimiter = make(chan struct{}, config.RateLimitPerSecond)
		go s.refillRateLimiter()
	}

	// Initialize workers
	for i := 0; i < config.NumWorkers; i++ {
		s.workers[i] = &worker{
			id:        i,
			scheduler: s,
		}
	}

	return s
}

// LoadSubsystems loads all enabled auto-run subsystems from database into queue.
// It honors ctx cancellation between agents, returning ctx.Err() if cancelled.
func (s *Scheduler) LoadSubsystems(ctx context.Context) error {
	log.Println("[Scheduler] Loading subsystems from database...")

	// Get all agents (pass empty strings to get all agents regardless of status/os)
	agents, err := s.agentQueries.ListAgents("", "")
	if err != nil {
		return fmt.Errorf("failed to get agents: %w", err)
	}

	// For now, we'll create default subsystems for each agent
	// In full implementation, this would read from agent_subsystems table
	subsystems := []string{"updates", "storage", "system", "docker"}
	intervals := map[string]int{
		"updates": 15, // 15 minutes
		"storage": 15,
		"system":  30,
		"docker":  15,
	}

	loaded := 0
	for _, agent := range agents {
		// Respect caller cancellation; this loop may be large on big fleets.
		if err := ctx.Err(); err != nil {
			return err
		}

		// Skip offline agents (haven't checked in for 10+ minutes)
		if time.Since(agent.LastSeen) > 10*time.Minute {
			continue
		}

		for _, subsystem := range subsystems {
			// TODO: Check agent metadata for subsystem enablement
			// For now, assume all subsystems are enabled
			job := &SubsystemJob{
				AgentID:         agent.ID,
				AgentHostname:   agent.Hostname,
				Subsystem:       subsystem,
				IntervalMinutes: intervals[subsystem],
				// First run is one full interval out from load time.
				NextRunAt: time.Now().Add(time.Duration(intervals[subsystem]) * time.Minute),
				Enabled:   true,
			}
			s.queue.Push(job)
			loaded++
		}
	}

	log.Printf("[Scheduler] Loaded %d subsystem jobs for %d agents\n", loaded, len(agents))
	return nil
}

// Start begins the scheduler main loop and workers
func (s *Scheduler) Start() error {
	log.Printf("[Scheduler] Starting with %d workers, check interval %v\n",
		s.config.NumWorkers, s.config.CheckInterval)

	// Start workers first so they are draining jobChan before dispatch begins.
	for _, w := range s.workers {
		s.wg.Add(1)
		go w.run()
	}

	// Start main loop. It is the sole sender on jobChan and closes it on
	// exit, which in turn lets the workers drain and stop.
	s.wg.Add(1)
	go s.mainLoop()

	log.Println("[Scheduler] Started successfully")
	return nil
}

// Stop gracefully shuts down the scheduler.
//
// It signals shutdown, then waits up to 30 seconds for the main loop and
// all workers to exit. jobChan is deliberately NOT closed here: closing it
// from Stop raced with processQueue's concurrent sends and could panic
// with "send on closed channel". mainLoop (the only sender) closes it.
func (s *Scheduler) Stop() error {
	log.Println("[Scheduler] Shutting down...")

	// Signal shutdown
	s.cancel()
	close(s.shutdown)

	// Wait for all goroutines with timeout
	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		log.Println("[Scheduler] Shutdown complete")
		return nil
	case <-time.After(30 * time.Second):
		log.Println("[Scheduler] Shutdown timeout - forcing exit")
		return fmt.Errorf("shutdown timeout")
	}
}

// mainLoop is the scheduler's main processing loop
func (s *Scheduler) mainLoop() {
	defer s.wg.Done()
	// Only the sender may close a channel: once this loop exits no more
	// jobs can be dispatched, so closing jobChan here safely stops workers.
	defer close(s.jobChan)

	ticker := time.NewTicker(s.config.CheckInterval)
	defer ticker.Stop()

	log.Printf("[Scheduler] Main loop started (check every %v)\n", s.config.CheckInterval)

	for {
		select {
		case <-s.shutdown:
			log.Println("[Scheduler] Main loop shutting down")
			return
		case <-ticker.C:
			s.processQueue()
		}
	}
}

// processQueue checks for due jobs and dispatches them to workers
func (s *Scheduler) processQueue() {
	start := time.Now()

	// Get all jobs due within lookahead window
	cutoff := time.Now().Add(s.config.LookaheadWindow)
	dueJobs := s.queue.PopBefore(cutoff, 0) // No limit, get all

	if len(dueJobs) == 0 {
		// No jobs due, just update stats
		s.mu.Lock()
		s.stats.QueueSize = s.queue.Len()
		s.mu.Unlock()
		return
	}

	log.Printf("[Scheduler] Processing %d jobs due before %s\n",
		len(dueJobs), cutoff.Format("15:04:05"))

	// Add jitter to each job and dispatch to workers.
	// rand.Intn panics on n <= 0, so skip jitter entirely when MaxJitter
	// is zero or under one second.
	maxJitterSec := int(s.config.MaxJitter.Seconds())

	dispatched := 0
	for _, job := range dueJobs {
		// Add random jitter (0 to MaxJitter)
		// NOTE(review): workers overwrite NextRunAt after processing, so
		// this jitter does not currently delay execution — confirm intent.
		if maxJitterSec > 0 {
			jitter := time.Duration(rand.Intn(maxJitterSec)) * time.Second
			job.NextRunAt = job.NextRunAt.Add(jitter)
		}

		// Dispatch to worker pool (non-blocking)
		select {
		case s.jobChan <- job:
			dispatched++
		default:
			// Worker pool full, re-queue job
			log.Printf("[Scheduler] Worker pool full, re-queueing %s\n", job.String())
			s.queue.Push(job)
			s.mu.Lock()
			s.stats.JobsSkipped++
			s.mu.Unlock()
		}
	}

	// Update stats
	duration := time.Since(start)
	s.mu.Lock()
	s.stats.JobsProcessed += int64(dispatched)
	s.stats.LastProcessedAt = time.Now()
	s.stats.QueueSize = s.queue.Len()
	s.stats.WorkerPoolUtilized = len(s.jobChan)
	s.stats.AverageProcessingMS = duration.Milliseconds()
	s.mu.Unlock()

	log.Printf("[Scheduler] Dispatched %d jobs in %v (queue: %d remaining)\n",
		dispatched, duration, s.queue.Len())
}

// refillRateLimiter continuously refills the rate limiter token bucket.
// It exits when the shutdown channel closes.
func (s *Scheduler) refillRateLimiter() {
	ticker := time.NewTicker(time.Second / time.Duration(s.config.RateLimitPerSecond))
	defer ticker.Stop()

	for {
		select {
		case <-s.shutdown:
			return
		case <-ticker.C:
			// Try to add token (non-blocking)
			select {
			case s.rateLimiter <- struct{}{}:
			default:
				// Bucket full, skip
			}
		}
	}
}

// GetStats returns current scheduler statistics (thread-safe)
func (s *Scheduler) GetStats() Stats {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.stats
}

// GetQueueStats returns current queue statistics
func (s *Scheduler) GetQueueStats() QueueStats {
	return s.queue.GetStats()
}

// worker processes jobs from the job channel
type worker struct {
	id        int
	scheduler *Scheduler
}

// run drains the job channel until it is closed, processing each job and
// re-queueing it for its next interval.
func (w *worker) run() {
	defer w.scheduler.wg.Done()

	log.Printf("[Worker %d] Started\n", w.id)

	for job := range w.scheduler.jobChan {
		if err := w.processJob(job); err != nil {
			log.Printf("[Worker %d] Failed to process %s: %v\n", w.id, job.String(), err)
			w.scheduler.mu.Lock()
			w.scheduler.stats.CommandsFailed++
			w.scheduler.mu.Unlock()
		} else {
			w.scheduler.mu.Lock()
			w.scheduler.stats.CommandsCreated++
			w.scheduler.mu.Unlock()
		}

		// Re-queue job for next execution
		job.NextRunAt = time.Now().Add(time.Duration(job.IntervalMinutes) * time.Minute)
		w.scheduler.queue.Push(job)
	}

	log.Printf("[Worker %d] Stopped\n", w.id)
}

// processJob rate-limits, checks per-agent backpressure, then creates a
// scan command for the job's subsystem. A backpressure skip is not an error.
func (w *worker) processJob(job *SubsystemJob) error {
	// Apply rate limiting if configured
	if w.scheduler.rateLimiter != nil {
		select {
		case <-w.scheduler.rateLimiter:
			// Token acquired
		case <-w.scheduler.shutdown:
			return fmt.Errorf("shutdown during rate limit wait")
		}
	}

	// Check backpressure: skip if agent has too many pending commands
	pendingCount, err := w.scheduler.commandQueries.CountPendingCommandsForAgent(job.AgentID)
	if err != nil {
		return fmt.Errorf("failed to check pending commands: %w", err)
	}

	if pendingCount >= w.scheduler.config.BackpressureThreshold {
		log.Printf("[Worker %d] Backpressure: agent %s has %d pending commands, skipping %s\n",
			w.id, job.AgentHostname, pendingCount, job.Subsystem)
		w.scheduler.mu.Lock()
		w.scheduler.stats.BackpressureSkips++
		w.scheduler.mu.Unlock()
		return nil // Not an error, just skipped
	}

	// Create command
	cmd := &models.AgentCommand{
		ID:          uuid.New(),
		AgentID:     job.AgentID,
		CommandType: fmt.Sprintf("scan_%s", job.Subsystem),
		Params:      models.JSONB{},
		Status:      models.CommandStatusPending,
		Source:      models.CommandSourceSystem,
		CreatedAt:   time.Now(),
	}

	if err := w.scheduler.commandQueries.CreateCommand(cmd); err != nil {
		return fmt.Errorf("failed to create command: %w", err)
	}

	log.Printf("[Worker %d] Created %s command for %s\n", w.id, job.Subsystem, job.AgentHostname)
	return nil
}