Files
Redflag/aggregator-server/internal/scheduler/scheduler.go
jpetree331 f97d4845af feat(security): A-1 Ed25519 key rotation + A-2 replay attack fixes
Complete RedFlag codebase with two major security audit implementations.

== A-1: Ed25519 Key Rotation Support ==

Server:
- SignCommand sets SignedAt timestamp and KeyID on every signature
- signing_keys database table (migration 020) for multi-key rotation
- InitializePrimaryKey registers active key at startup
- /api/v1/public-keys endpoint for rotation-aware agents
- SigningKeyQueries for key lifecycle management

Agent:
- Key-ID-aware verification via CheckKeyRotation
- FetchAndCacheAllActiveKeys for rotation pre-caching
- Cache metadata with TTL and staleness fallback
- SecurityLogger events for key rotation and command signing

== A-2: Replay Attack Fixes (F-1 through F-7) ==

F-5 CRITICAL - RetryCommand now signs via signAndCreateCommand
F-1 HIGH     - v3 format: "{agent_id}:{cmd_id}:{type}:{hash}:{ts}"
F-7 HIGH     - Migration 026: expires_at column with partial index
F-6 HIGH     - GetPendingCommands/GetStuckCommands filter by expires_at
F-2 HIGH     - Agent-side executedIDs dedup map with cleanup
F-4 HIGH     - commandMaxAge reduced from 24h to 4h
F-3 CRITICAL - Old-format commands rejected after 48h via CreatedAt

Verification fixes: migration idempotency (ETHOS #4), log format
compliance (ETHOS #1), stale comments updated.

All 24 tests passing. Docker --no-cache build verified.
See docs/ for full audit reports and deviation log (DEV-001 to DEV-019).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-28 21:25:47 -04:00

441 lines
11 KiB
Go

package scheduler
import (
"context"
"fmt"
"log"
"math/rand"
"sync"
"time"
"github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries"
"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
"github.com/google/uuid"
)
// Config holds scheduler configuration. Zero values are not useful;
// obtain sensible production defaults from DefaultConfig.
type Config struct {
	// CheckInterval is how often to check the queue for due jobs.
	CheckInterval time.Duration

	// LookaheadWindow is how far ahead to look for jobs.
	// Jobs due within this window will be batched and jittered.
	LookaheadWindow time.Duration

	// MaxJitter is the maximum random delay added to job execution,
	// used to spread batched jobs and avoid thundering herds.
	MaxJitter time.Duration

	// NumWorkers is the number of parallel workers for command creation.
	NumWorkers int

	// BackpressureThreshold is max pending commands per agent before skipping.
	BackpressureThreshold int

	// RateLimitPerSecond is max commands created per second (0 = unlimited).
	RateLimitPerSecond int
}
// DefaultConfig returns production-ready default configuration:
// a 10s queue check, 60s lookahead, up to 30s of jitter, 10 workers,
// a 5-command per-agent backpressure cap, and 100 commands/sec.
func DefaultConfig() Config {
	var cfg Config
	cfg.CheckInterval = 10 * time.Second
	cfg.LookaheadWindow = 60 * time.Second
	cfg.MaxJitter = 30 * time.Second
	cfg.NumWorkers = 10
	cfg.BackpressureThreshold = 5
	cfg.RateLimitPerSecond = 100
	return cfg
}
// Scheduler manages subsystem job scheduling with priority queue and worker pool.
// Construct with NewScheduler; drive with Start/Stop.
type Scheduler struct {
	config Config
	queue  *PriorityQueue

	// Database queries
	agentQueries     *queries.AgentQueries
	commandQueries   *queries.CommandQueries
	subsystemQueries *queries.SubsystemQueries

	// Worker pool: jobs flow from the main loop into jobChan,
	// where the fixed worker set drains them.
	jobChan chan *SubsystemJob
	workers []*worker

	// Rate limiting: token bucket; nil when rate limiting is disabled.
	rateLimiter chan struct{}

	// Lifecycle management
	ctx      context.Context
	cancel   context.CancelFunc
	wg       sync.WaitGroup
	shutdown chan struct{}

	// Metrics: stats is guarded by mu.
	mu    sync.RWMutex
	stats Stats
}
// Stats holds scheduler statistics. Values are maintained under the
// Scheduler's mutex; read a consistent snapshot via GetStats.
type Stats struct {
	JobsProcessed       int64     // jobs dispatched to the worker pool
	JobsSkipped         int64     // jobs re-queued because the pool was full
	CommandsCreated     int64     // worker job runs that returned no error
	CommandsFailed      int64     // worker job runs that returned an error
	BackpressureSkips   int64     // jobs skipped due to per-agent pending-command cap
	LastProcessedAt     time.Time // when the dispatch loop last ran with due jobs
	QueueSize           int       // queue length at last update
	WorkerPoolUtilized  int       // jobChan occupancy at last dispatch
	AverageProcessingMS int64     // duration of the most recent dispatch pass, in ms
}
// NewScheduler creates a new scheduler instance wired to the given query
// layers. It builds (but does not start) the worker pool and, when the
// config sets a positive rate limit, launches the token-bucket refiller.
func NewScheduler(config Config, agentQueries *queries.AgentQueries, commandQueries *queries.CommandQueries, subsystemQueries *queries.SubsystemQueries) *Scheduler {
	ctx, cancel := context.WithCancel(context.Background())

	sched := &Scheduler{
		config:           config,
		queue:            NewPriorityQueue(),
		agentQueries:     agentQueries,
		commandQueries:   commandQueries,
		subsystemQueries: subsystemQueries,
		jobChan:          make(chan *SubsystemJob, 1000), // buffer 1000 jobs
		workers:          make([]*worker, config.NumWorkers),
		shutdown:         make(chan struct{}),
		ctx:              ctx,
		cancel:           cancel,
	}

	// Rate limiter is only active when a positive limit is configured.
	if config.RateLimitPerSecond > 0 {
		sched.rateLimiter = make(chan struct{}, config.RateLimitPerSecond)
		go sched.refillRateLimiter()
	}

	// Construct the worker pool; workers are launched later by Start().
	for i := range sched.workers {
		sched.workers[i] = &worker{id: i, scheduler: sched}
	}

	return sched
}
// LoadSubsystems loads all enabled auto-run subsystems from database into queue.
//
// Agents that have not checked in for 10+ minutes are skipped, as are
// subsystems the user has disabled or not marked auto_run. Jobs use the
// database-configured interval when positive, otherwise the subsystem's
// default. The supplied ctx now aborts the load between agents (it was
// previously accepted but ignored).
func (s *Scheduler) LoadSubsystems(ctx context.Context) error {
	log.Println("[Scheduler] Loading subsystems from database...")

	// Get all agents (pass empty strings to get all agents regardless of status/os).
	agents, err := s.agentQueries.ListAgents("", "")
	if err != nil {
		return fmt.Errorf("failed to get agents: %w", err)
	}

	loaded := 0
	for _, agent := range agents {
		// Honor caller cancellation between agents.
		if err := ctx.Err(); err != nil {
			return err
		}
		// Skip offline agents (haven't checked in for 10+ minutes).
		if time.Since(agent.LastSeen) > 10*time.Minute {
			continue
		}
		// Get subsystems from database (respect user settings).
		dbSubsystems, err := s.subsystemQueries.GetSubsystems(agent.ID)
		if err != nil {
			log.Printf("[Scheduler] Failed to get subsystems for agent %s: %v", agent.Hostname, err)
			continue
		}
		// Create jobs only for enabled subsystems with auto_run=true.
		for _, dbSub := range dbSubsystems {
			if !dbSub.Enabled || !dbSub.AutoRun {
				continue
			}
			// Use database interval, fallback to default.
			intervalMinutes := dbSub.IntervalMinutes
			if intervalMinutes <= 0 {
				intervalMinutes = s.getDefaultInterval(dbSub.Subsystem)
			}
			var nextRun time.Time
			if dbSub.NextRunAt != nil {
				nextRun = *dbSub.NextRunAt
			} else {
				// If no next run is set, schedule one interval from now.
				nextRun = time.Now().Add(time.Duration(intervalMinutes) * time.Minute)
			}
			s.queue.Push(&SubsystemJob{
				AgentID:         agent.ID,
				AgentHostname:   agent.Hostname,
				Subsystem:       dbSub.Subsystem,
				IntervalMinutes: intervalMinutes,
				NextRunAt:       nextRun,
				Enabled:         dbSub.Enabled,
			})
			loaded++
		}
	}
	log.Printf("[Scheduler] Loaded %d subsystem jobs for %d agents (respecting database settings)\n", loaded, len(agents))
	return nil
}
// getDefaultInterval returns default interval minutes for a subsystem,
// falling back to 30 minutes for unknown subsystems.
//
// A switch over constant keys replaces the previous per-call map literal,
// which allocated and populated a fresh map on every invocation.
//
// TODO: These intervals need to correlate with agent health scanning settings.
// Each subsystem should be variable based on user-configurable agent health policies.
func (s *Scheduler) getDefaultInterval(subsystem string) int {
	switch subsystem {
	case "apt":
		return 30 // 30 minutes
	case "dnf":
		return 240 // 4 hours
	case "docker":
		return 120 // 2 hours
	case "storage":
		return 360 // 6 hours
	case "windows":
		return 480 // 8 hours
	case "winget":
		return 360 // 6 hours
	case "updates":
		return 15 // 15 minutes
	case "system":
		return 30 // 30 minutes
	default:
		return 30 // Default fallback
	}
}
// Start begins the scheduler main loop and workers. It always returns nil;
// the error return is kept for interface stability.
func (s *Scheduler) Start() error {
	log.Printf("[Scheduler] Starting with %d workers, check interval %v\n",
		s.config.NumWorkers, s.config.CheckInterval)

	// Account for every worker goroutine plus the main loop up front.
	s.wg.Add(len(s.workers) + 1)
	for _, w := range s.workers {
		go w.run()
	}
	go s.mainLoop()

	log.Println("[Scheduler] Started successfully")
	return nil
}
// Stop gracefully shuts down the scheduler: it cancels the context, signals
// the shutdown channel, closes the job channel so workers drain and exit,
// then waits up to 30 seconds for all goroutines. Returns an error only if
// the timeout elapses before all goroutines finish.
//
// NOTE(review): close(s.jobChan) here can race with mainLoop/processQueue,
// which may still be sending on jobChan when this runs; a send on a closed
// channel panics. Consider having mainLoop close jobChan on exit instead —
// confirm against observed shutdown behavior.
func (s *Scheduler) Stop() error {
	log.Println("[Scheduler] Shutting down...")
	// Signal shutdown to the main loop and the rate-limiter refiller.
	s.cancel()
	close(s.shutdown)
	// Close job channel (workers will drain and exit)
	close(s.jobChan)
	// Wait for all goroutines with timeout
	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(done)
	}()
	select {
	case <-done:
		log.Println("[Scheduler] Shutdown complete")
		return nil
	case <-time.After(30 * time.Second):
		// Goroutines did not finish in time; report rather than block forever.
		log.Println("[Scheduler] Shutdown timeout - forcing exit")
		return fmt.Errorf("shutdown timeout")
	}
}
// mainLoop is the scheduler's main processing loop: it wakes on a fixed
// ticker and dispatches due jobs until shutdown is signalled.
func (s *Scheduler) mainLoop() {
	defer s.wg.Done()

	ticker := time.NewTicker(s.config.CheckInterval)
	defer ticker.Stop()

	log.Printf("[Scheduler] Main loop started (check every %v)\n", s.config.CheckInterval)
	for {
		select {
		case <-ticker.C:
			s.processQueue()
		case <-s.shutdown:
			log.Println("[Scheduler] Main loop shutting down")
			return
		}
	}
}
// processQueue checks for due jobs and dispatches them to workers.
//
// All jobs due within the lookahead window are popped as one batch, each
// given a random jitter to spread load, and handed to the worker pool via
// a non-blocking send. Jobs that cannot be dispatched (pool full) are
// re-queued and counted as skipped.
func (s *Scheduler) processQueue() {
	start := time.Now()

	// Get all jobs due within lookahead window.
	cutoff := time.Now().Add(s.config.LookaheadWindow)
	dueJobs := s.queue.PopBefore(cutoff, 0) // No limit, get all
	if len(dueJobs) == 0 {
		// No jobs due, just update stats.
		s.mu.Lock()
		s.stats.QueueSize = s.queue.Len()
		s.mu.Unlock()
		return
	}

	log.Printf("[Scheduler] Processing %d jobs due before %s\n",
		len(dueJobs), cutoff.Format("15:04:05"))

	// Add jitter to each job and dispatch to workers.
	dispatched := 0
	for _, job := range dueJobs {
		// Add random jitter in [0, MaxJitter). Drawing nanoseconds with
		// rand.Int63n fixes the previous rand.Intn(int(MaxJitter.Seconds()))
		// form, which panicked (rand.Intn(0)) whenever MaxJitter < 1s and
		// truncated all sub-second jitter.
		if s.config.MaxJitter > 0 {
			jitter := time.Duration(rand.Int63n(int64(s.config.MaxJitter)))
			job.NextRunAt = job.NextRunAt.Add(jitter)
		}
		// Dispatch to worker pool (non-blocking).
		select {
		case s.jobChan <- job:
			dispatched++
		default:
			// Worker pool full, re-queue job for a later tick.
			log.Printf("[Scheduler] Worker pool full, re-queueing %s\n", job.String())
			s.queue.Push(job)
			s.mu.Lock()
			s.stats.JobsSkipped++
			s.mu.Unlock()
		}
	}

	// Update stats.
	duration := time.Since(start)
	s.mu.Lock()
	s.stats.JobsProcessed += int64(dispatched)
	s.stats.LastProcessedAt = time.Now()
	s.stats.QueueSize = s.queue.Len()
	s.stats.WorkerPoolUtilized = len(s.jobChan)
	s.stats.AverageProcessingMS = duration.Milliseconds()
	s.mu.Unlock()

	log.Printf("[Scheduler] Dispatched %d jobs in %v (queue: %d remaining)\n",
		dispatched, duration, s.queue.Len())
}
// refillRateLimiter continuously refills the rate limiter token bucket at
// RateLimitPerSecond tokens per second until shutdown is signalled.
// It is only launched when the configured limit is positive.
func (s *Scheduler) refillRateLimiter() {
	interval := time.Second / time.Duration(s.config.RateLimitPerSecond)
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// Deposit one token unless the bucket is already full.
			select {
			case s.rateLimiter <- struct{}{}:
			default:
				// Bucket full, skip this token.
			}
		case <-s.shutdown:
			return
		}
	}
}
// GetStats returns a snapshot copy of the current scheduler statistics
// (thread-safe; taken under the read lock).
func (s *Scheduler) GetStats() Stats {
	s.mu.RLock()
	snapshot := s.stats
	s.mu.RUnlock()
	return snapshot
}
// GetQueueStats returns current queue statistics, delegating to the
// underlying priority queue.
func (s *Scheduler) GetQueueStats() QueueStats {
	return s.queue.GetStats()
}
// worker processes jobs from the scheduler's job channel.
type worker struct {
	id        int        // worker index, used only in log prefixes
	scheduler *Scheduler // owning scheduler (queue, stats, config, queries)
}
// run is the worker goroutine body: it drains jobs from the scheduler's
// job channel until the channel is closed, updating command stats and
// re-queueing each job for its next execution.
func (w *worker) run() {
	sched := w.scheduler
	defer sched.wg.Done()

	log.Printf("[Worker %d] Started\n", w.id)
	for job := range sched.jobChan {
		err := w.processJob(job)
		if err != nil {
			log.Printf("[Worker %d] Failed to process %s: %v\n", w.id, job.String(), err)
		}
		sched.mu.Lock()
		if err != nil {
			sched.stats.CommandsFailed++
		} else {
			// NOTE(review): backpressure skips also return nil and land
			// here as CommandsCreated — confirm whether that is intended.
			sched.stats.CommandsCreated++
		}
		sched.mu.Unlock()

		// Re-queue job for next execution, one interval from now.
		job.NextRunAt = time.Now().Add(time.Duration(job.IntervalMinutes) * time.Minute)
		sched.queue.Push(job)
	}
	log.Printf("[Worker %d] Stopped\n", w.id)
}
// processJob handles a single subsystem job: it waits for a rate-limit
// token (if limiting is configured), skips the job when the agent already
// has too many pending commands, and otherwise creates a scan_<subsystem>
// command for the agent. A backpressure skip is not an error and returns nil.
func (w *worker) processJob(job *SubsystemJob) error {
	sched := w.scheduler

	// Apply rate limiting if configured: block for a token or shutdown.
	if sched.rateLimiter != nil {
		select {
		case <-sched.rateLimiter:
			// Token acquired.
		case <-sched.shutdown:
			return fmt.Errorf("shutdown during rate limit wait")
		}
	}

	// Check backpressure: skip if agent has too many pending commands.
	pending, err := sched.commandQueries.CountPendingCommandsForAgent(job.AgentID)
	if err != nil {
		return fmt.Errorf("failed to check pending commands: %w", err)
	}
	if pending >= sched.config.BackpressureThreshold {
		log.Printf("[Worker %d] Backpressure: agent %s has %d pending commands, skipping %s\n",
			w.id, job.AgentHostname, pending, job.Subsystem)
		sched.mu.Lock()
		sched.stats.BackpressureSkips++
		sched.mu.Unlock()
		return nil // Not an error, just skipped
	}

	// Create the scan command for this subsystem.
	cmd := models.AgentCommand{
		ID:          uuid.New(),
		AgentID:     job.AgentID,
		CommandType: fmt.Sprintf("scan_%s", job.Subsystem),
		Params:      models.JSONB{},
		Status:      models.CommandStatusPending,
		Source:      models.CommandSourceSystem,
		CreatedAt:   time.Now(),
	}
	if err := sched.commandQueries.CreateCommand(&cmd); err != nil {
		return fmt.Errorf("failed to create command: %w", err)
	}
	log.Printf("[Worker %d] Created %s command for %s\n",
		w.id, job.Subsystem, job.AgentHostname)
	return nil
}