fix: agent acknowledgment recursion and subsystem UI improvements
- Fix recursive call in reportLogWithAck that caused infinite loop - Add machine binding and security API endpoints - Enhance AgentScanners component with security status display - Update scheduler and timeout service reliability - Remove deprecated install.sh script - Add subsystem configuration and logging improvements
This commit is contained in:
@@ -106,7 +106,8 @@ func (h *AgentHandler) RegisterAgent(c *gin.Context) {
|
||||
|
||||
// Save to database
|
||||
if err := h.agentQueries.CreateAgent(agent); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to register agent"})
|
||||
log.Printf("ERROR: Failed to create agent in database: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to register agent - database error"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -163,16 +164,17 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
|
||||
// Try to parse optional system metrics from request body
|
||||
var metrics struct {
|
||||
CPUPercent float64 `json:"cpu_percent,omitempty"`
|
||||
MemoryPercent float64 `json:"memory_percent,omitempty"`
|
||||
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
|
||||
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
|
||||
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
|
||||
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
|
||||
DiskPercent float64 `json:"disk_percent,omitempty"`
|
||||
Uptime string `json:"uptime,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
CPUPercent float64 `json:"cpu_percent,omitempty"`
|
||||
MemoryPercent float64 `json:"memory_percent,omitempty"`
|
||||
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
|
||||
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
|
||||
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
|
||||
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
|
||||
DiskPercent float64 `json:"disk_percent,omitempty"`
|
||||
Uptime string `json:"uptime,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
PendingAcknowledgments []string `json:"pending_acknowledgments,omitempty"`
|
||||
}
|
||||
|
||||
// Parse metrics if provided (optional, won't fail if empty)
|
||||
@@ -449,10 +451,27 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Process command acknowledgments from agent
|
||||
var acknowledgedIDs []string
|
||||
if len(metrics.PendingAcknowledgments) > 0 {
|
||||
log.Printf("DEBUG: Processing %d pending acknowledgments for agent %s: %v", len(metrics.PendingAcknowledgments), agentID, metrics.PendingAcknowledgments)
|
||||
// Verify which commands from agent's pending list have been recorded
|
||||
verified, err := h.commandQueries.VerifyCommandsCompleted(metrics.PendingAcknowledgments)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to verify command acknowledgments for agent %s: %v", agentID, err)
|
||||
} else {
|
||||
acknowledgedIDs = verified
|
||||
log.Printf("DEBUG: Verified %d completed commands out of %d pending for agent %s", len(acknowledgedIDs), len(metrics.PendingAcknowledgments), agentID)
|
||||
if len(acknowledgedIDs) > 0 {
|
||||
log.Printf("Acknowledged %d command results for agent %s", len(acknowledgedIDs), agentID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response := models.CommandsResponse{
|
||||
Commands: commandItems,
|
||||
RapidPolling: rapidPolling,
|
||||
AcknowledgedIDs: []string{}, // No acknowledgments in current implementation
|
||||
AcknowledgedIDs: acknowledgedIDs,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
@@ -465,7 +484,8 @@ func (h *AgentHandler) ListAgents(c *gin.Context) {
|
||||
|
||||
agents, err := h.agentQueries.ListAgentsWithLastScan(status, osType)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list agents"})
|
||||
log.Printf("ERROR: Failed to list agents: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list agents - database error"})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -31,28 +31,24 @@ func (h *DownloadHandler) getServerURL(c *gin.Context) string {
|
||||
return h.config.Server.PublicURL
|
||||
}
|
||||
|
||||
// Priority 2: Detect from request with TLS/proxy awareness
|
||||
// Priority 2: Construct API server URL from configuration
|
||||
scheme := "http"
|
||||
host := h.config.Server.Host
|
||||
port := h.config.Server.Port
|
||||
|
||||
// Check if TLS is enabled in config
|
||||
// Use HTTPS if TLS is enabled in config
|
||||
if h.config.Server.TLS.Enabled {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
// Check if request came through HTTPS (direct or via proxy)
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
// For default host (0.0.0.0), use localhost for client connections
|
||||
if host == "0.0.0.0" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
// Check X-Forwarded-Proto for reverse proxy setups
|
||||
if forwardedProto := c.GetHeader("X-Forwarded-Proto"); forwardedProto == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
// Use the Host header exactly as received (includes port if present)
|
||||
host := c.GetHeader("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = c.Request.Host
|
||||
// Only include port if it's not the default for the protocol
|
||||
if (scheme == "http" && port != 80) || (scheme == "https" && port != 443) {
|
||||
return fmt.Sprintf("%s://%s:%d", scheme, host, port)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s", scheme, host)
|
||||
@@ -155,6 +151,7 @@ AGENT_BINARY="/usr/local/bin/redflag-agent"
|
||||
SUDOERS_FILE="/etc/sudoers.d/redflag-agent"
|
||||
SERVICE_FILE="/etc/systemd/system/redflag-agent.service"
|
||||
CONFIG_DIR="/etc/aggregator"
|
||||
STATE_DIR="/var/lib/aggregator"
|
||||
|
||||
echo "=== RedFlag Agent Installation ==="
|
||||
echo ""
|
||||
@@ -301,19 +298,24 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Step 4: Create configuration directory
|
||||
# Step 4: Create configuration and state directories
|
||||
echo ""
|
||||
echo "Step 4: Creating configuration directory..."
|
||||
echo "Step 4: Creating configuration and state directories..."
|
||||
mkdir -p "$CONFIG_DIR"
|
||||
chown "$AGENT_USER:$AGENT_USER" "$CONFIG_DIR"
|
||||
chmod 755 "$CONFIG_DIR"
|
||||
echo "✓ Configuration directory created"
|
||||
|
||||
# Set SELinux context for config directory if SELinux is enabled
|
||||
# Create state directory for acknowledgment tracking (v0.1.19+)
|
||||
mkdir -p "$STATE_DIR"
|
||||
chown "$AGENT_USER:$AGENT_USER" "$STATE_DIR"
|
||||
chmod 755 "$STATE_DIR"
|
||||
echo "✓ Configuration and state directories created"
|
||||
|
||||
# Set SELinux context for directories if SELinux is enabled
|
||||
if command -v getenforce >/dev/null 2>&1 && [ "$(getenforce)" != "Disabled" ]; then
|
||||
echo "Setting SELinux context for config directory..."
|
||||
restorecon -Rv "$CONFIG_DIR" 2>/dev/null || true
|
||||
echo "✓ SELinux context set for config directory"
|
||||
echo "Setting SELinux context for directories..."
|
||||
restorecon -Rv "$CONFIG_DIR" "$STATE_DIR" 2>/dev/null || true
|
||||
echo "✓ SELinux context set for directories"
|
||||
fi
|
||||
|
||||
# Step 5: Install systemd service
|
||||
@@ -338,7 +340,7 @@ RestartSec=30
|
||||
# NoNewPrivileges=true - DISABLED: Prevents sudo from working, which agent needs for package management
|
||||
ProtectSystem=strict
|
||||
ProtectHome=true
|
||||
ReadWritePaths=$AGENT_HOME /var/log $CONFIG_DIR
|
||||
ReadWritePaths=$AGENT_HOME /var/log $CONFIG_DIR $STATE_DIR
|
||||
PrivateTmp=true
|
||||
|
||||
# Logging
|
||||
|
||||
@@ -387,9 +387,15 @@ func (h *SetupHandler) ConfigureServer(c *gin.Context) {
|
||||
fmt.Println("Updating PostgreSQL password from bootstrap to user-provided password...")
|
||||
bootstrapPassword := "redflag_bootstrap" // This matches our bootstrap .env
|
||||
if err := updatePostgresPassword(req.DBHost, req.DBPort, req.DBUser, bootstrapPassword, req.DBPassword); err != nil {
|
||||
fmt.Printf("Warning: Failed to update PostgreSQL password: %v\n", err)
|
||||
fmt.Println("Will proceed with configuration anyway...")
|
||||
fmt.Printf("CRITICAL ERROR: Failed to update PostgreSQL password: %v\n", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Failed to update database password. Setup cannot continue.",
|
||||
"details": err.Error(),
|
||||
"help": "Ensure PostgreSQL is accessible and the bootstrap password is correct. Check Docker logs for details.",
|
||||
})
|
||||
return
|
||||
}
|
||||
fmt.Println("PostgreSQL password successfully updated from bootstrap to user-provided password")
|
||||
|
||||
// Step 2: Generate configuration content for manual update
|
||||
fmt.Println("Generating configuration content for manual .env file update...")
|
||||
@@ -414,6 +420,11 @@ func (h *SetupHandler) ConfigureServer(c *gin.Context) {
|
||||
|
||||
// GenerateSigningKeys generates Ed25519 keypair for agent update signing
|
||||
func (h *SetupHandler) GenerateSigningKeys(c *gin.Context) {
|
||||
// Prevent caching of generated keys (security critical)
|
||||
c.Header("Cache-Control", "no-store, no-cache, must-revalidate, private")
|
||||
c.Header("Pragma", "no-cache")
|
||||
c.Header("Expires", "0")
|
||||
|
||||
// Generate Ed25519 keypair
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
@@ -428,6 +439,9 @@ func (h *SetupHandler) GenerateSigningKeys(c *gin.Context) {
|
||||
// Generate fingerprint (first 16 chars)
|
||||
fingerprint := publicKeyHex[:16]
|
||||
|
||||
// Log key generation for security audit trail (only fingerprint, not full key)
|
||||
fmt.Printf("Generated new Ed25519 keypair - Fingerprint: %s\n", fingerprint)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"public_key": publicKeyHex,
|
||||
"private_key": privateKeyHex,
|
||||
|
||||
@@ -13,6 +13,16 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// isValidResult checks if the result value complies with the database constraint
|
||||
func isValidResult(result string) bool {
|
||||
validResults := map[string]bool{
|
||||
"success": true,
|
||||
"failed": true,
|
||||
"partial": true,
|
||||
}
|
||||
return validResults[result]
|
||||
}
|
||||
|
||||
type UpdateHandler struct {
|
||||
updateQueries *queries.UpdateQueries
|
||||
agentQueries *queries.AgentQueries
|
||||
@@ -199,11 +209,22 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and map result to comply with database constraint
|
||||
validResult := req.Result
|
||||
if !isValidResult(validResult) {
|
||||
// Map invalid results to valid ones (e.g., "timed_out" -> "failed")
|
||||
if validResult == "timed_out" || validResult == "timeout" || validResult == "cancelled" {
|
||||
validResult = "failed"
|
||||
} else {
|
||||
validResult = "failed" // Default to failed for any unknown status
|
||||
}
|
||||
}
|
||||
|
||||
logEntry := &models.UpdateLog{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
Action: req.Action,
|
||||
Result: req.Result,
|
||||
Result: validResult,
|
||||
Stdout: req.Stdout,
|
||||
Stderr: req.Stderr,
|
||||
ExitCode: req.ExitCode,
|
||||
@@ -831,8 +852,8 @@ func (h *UpdateHandler) ClearFailedCommands(c *gin.Context) {
|
||||
|
||||
// Build the appropriate cleanup query based on parameters
|
||||
if allFailed {
|
||||
// Clear ALL failed commands (most aggressive)
|
||||
count, err = h.commandQueries.ClearAllFailedCommands(olderThanDays)
|
||||
// Clear ALL failed commands regardless of age (most aggressive)
|
||||
count, err = h.commandQueries.ClearAllFailedCommandsRegardlessOfAge()
|
||||
} else if onlyRetried {
|
||||
// Clear only failed commands that have been retried
|
||||
count, err = h.commandQueries.ClearRetriedFailedCommands(olderThanDays)
|
||||
|
||||
@@ -88,6 +88,14 @@ func Load() (*Config, error) {
|
||||
cfg.Timezone = getEnv("TIMEZONE", "UTC")
|
||||
cfg.LatestAgentVersion = getEnv("LATEST_AGENT_VERSION", "0.1.22")
|
||||
cfg.MinAgentVersion = getEnv("MIN_AGENT_VERSION", "0.1.22")
|
||||
cfg.SigningPrivateKey = getEnv("REDFLAG_SIGNING_PRIVATE_KEY", "")
|
||||
|
||||
// Debug: Log signing key status
|
||||
if cfg.SigningPrivateKey != "" {
|
||||
fmt.Printf("[CONFIG] ✅ Ed25519 signing private key configured (%d characters)\n", len(cfg.SigningPrivateKey))
|
||||
} else {
|
||||
fmt.Printf("[CONFIG] ❌ No Ed25519 signing private key found in REDFLAG_SIGNING_PRIVATE_KEY\n")
|
||||
}
|
||||
|
||||
// Handle missing secrets
|
||||
if cfg.Admin.Password == "" || cfg.Admin.JWTSecret == "" || cfg.Database.Password == "" {
|
||||
|
||||
@@ -23,14 +23,19 @@ func (q *AgentQueries) CreateAgent(agent *models.Agent) error {
|
||||
query := `
|
||||
INSERT INTO agents (
|
||||
id, hostname, os_type, os_version, os_architecture,
|
||||
agent_version, last_seen, status, metadata
|
||||
agent_version, current_version, machine_id, public_key_fingerprint,
|
||||
last_seen, status, metadata
|
||||
) VALUES (
|
||||
:id, :hostname, :os_type, :os_version, :os_architecture,
|
||||
:agent_version, :last_seen, :status, :metadata
|
||||
:agent_version, :current_version, :machine_id, :public_key_fingerprint,
|
||||
:last_seen, :status, :metadata
|
||||
)
|
||||
`
|
||||
_, err := q.db.NamedExec(query, agent)
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create agent %s (version %s): %w", agent.Hostname, agent.CurrentVersion, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAgentByID retrieves an agent by ID
|
||||
|
||||
@@ -2,6 +2,7 @@ package queries
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
|
||||
@@ -31,13 +32,14 @@ func (q *CommandQueries) CreateCommand(cmd *models.AgentCommand) error {
|
||||
}
|
||||
|
||||
// GetPendingCommands retrieves pending commands for an agent
|
||||
// Only returns 'pending' status - 'sent' commands are handled by timeout service
|
||||
func (q *CommandQueries) GetPendingCommands(agentID uuid.UUID) ([]models.AgentCommand, error) {
|
||||
var commands []models.AgentCommand
|
||||
query := `
|
||||
SELECT * FROM agent_commands
|
||||
WHERE agent_id = $1 AND status = 'pending'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 10
|
||||
LIMIT 100
|
||||
`
|
||||
err := q.db.Select(&commands, query, agentID)
|
||||
return commands, err
|
||||
@@ -338,6 +340,23 @@ func (q *CommandQueries) ClearAllFailedCommands(days int) (int64, error) {
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// ClearAllFailedCommandsRegardlessOfAge archives ALL failed/timed_out commands regardless of age
|
||||
// This is used when all_failed=true is passed to truly clear all failed commands
|
||||
func (q *CommandQueries) ClearAllFailedCommandsRegardlessOfAge() (int64, error) {
|
||||
query := `
|
||||
UPDATE agent_commands
|
||||
SET status = 'archived_failed'
|
||||
WHERE status IN ('failed', 'timed_out')
|
||||
`
|
||||
|
||||
result, err := q.db.Exec(query)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to archive all failed commands regardless of age: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// CountPendingCommandsForAgent returns the number of pending commands for a specific agent
|
||||
// Used by scheduler for backpressure detection
|
||||
func (q *CommandQueries) CountPendingCommandsForAgent(agentID uuid.UUID) (int, error) {
|
||||
@@ -373,16 +392,30 @@ func (q *CommandQueries) VerifyCommandsCompleted(commandIDs []string) ([]string,
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// Convert UUIDs back to strings for SQL query
|
||||
uuidStrs := make([]string, len(uuidIDs))
|
||||
for i, id := range uuidIDs {
|
||||
uuidStrs[i] = id.String()
|
||||
}
|
||||
|
||||
// Query for commands that are completed or failed
|
||||
query := `
|
||||
// Use ANY with proper array literal for PostgreSQL
|
||||
placeholders := make([]string, len(uuidStrs))
|
||||
args := make([]interface{}, len(uuidStrs))
|
||||
for i, id := range uuidStrs {
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+1)
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id
|
||||
FROM agent_commands
|
||||
WHERE id = ANY($1)
|
||||
AND status IN ('completed', 'failed')
|
||||
`
|
||||
WHERE id::text = ANY(%s)
|
||||
AND status IN ('completed', 'failed', 'timed_out')
|
||||
`, fmt.Sprintf("ARRAY[%s]", strings.Join(placeholders, ",")))
|
||||
|
||||
var completedUUIDs []uuid.UUID
|
||||
err := q.db.Select(&completedUUIDs, query, uuidIDs)
|
||||
err := q.db.Select(&completedUUIDs, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify command completion: %w", err)
|
||||
}
|
||||
|
||||
@@ -45,6 +45,11 @@ type AgentWithLastScan struct {
|
||||
CurrentVersion string `json:"current_version" db:"current_version"` // Current running version
|
||||
UpdateAvailable bool `json:"update_available" db:"update_available"` // Whether update is available
|
||||
LastVersionCheck time.Time `json:"last_version_check" db:"last_version_check"` // Last time version was checked
|
||||
MachineID *string `json:"machine_id,omitempty" db:"machine_id"` // Unique machine identifier
|
||||
PublicKeyFingerprint *string `json:"public_key_fingerprint,omitempty" db:"public_key_fingerprint"` // Public key fingerprint
|
||||
IsUpdating bool `json:"is_updating" db:"is_updating"` // Whether agent is currently updating
|
||||
UpdatingToVersion *string `json:"updating_to_version,omitempty" db:"updating_to_version"` // Target version for ongoing update
|
||||
UpdateInitiatedAt *time.Time `json:"update_initiated_at,omitempty" db:"update_initiated_at"` // When update process started
|
||||
LastSeen time.Time `json:"last_seen" db:"last_seen"`
|
||||
Status string `json:"status" db:"status"`
|
||||
Metadata JSONB `json:"metadata" db:"metadata"`
|
||||
|
||||
@@ -53,8 +53,9 @@ type Scheduler struct {
|
||||
queue *PriorityQueue
|
||||
|
||||
// Database queries
|
||||
agentQueries *queries.AgentQueries
|
||||
commandQueries *queries.CommandQueries
|
||||
agentQueries *queries.AgentQueries
|
||||
commandQueries *queries.CommandQueries
|
||||
subsystemQueries *queries.SubsystemQueries
|
||||
|
||||
// Worker pool
|
||||
jobChan chan *SubsystemJob
|
||||
@@ -88,19 +89,20 @@ type Stats struct {
|
||||
}
|
||||
|
||||
// NewScheduler creates a new scheduler instance
|
||||
func NewScheduler(config Config, agentQueries *queries.AgentQueries, commandQueries *queries.CommandQueries) *Scheduler {
|
||||
func NewScheduler(config Config, agentQueries *queries.AgentQueries, commandQueries *queries.CommandQueries, subsystemQueries *queries.SubsystemQueries) *Scheduler {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
s := &Scheduler{
|
||||
config: config,
|
||||
queue: NewPriorityQueue(),
|
||||
agentQueries: agentQueries,
|
||||
commandQueries: commandQueries,
|
||||
jobChan: make(chan *SubsystemJob, 1000), // Buffer 1000 jobs
|
||||
workers: make([]*worker, config.NumWorkers),
|
||||
shutdown: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: config,
|
||||
queue: NewPriorityQueue(),
|
||||
agentQueries: agentQueries,
|
||||
commandQueries: commandQueries,
|
||||
subsystemQueries: subsystemQueries,
|
||||
jobChan: make(chan *SubsystemJob, 1000), // Buffer 1000 jobs
|
||||
workers: make([]*worker, config.NumWorkers),
|
||||
shutdown: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Initialize rate limiter if configured
|
||||
@@ -130,16 +132,6 @@ func (s *Scheduler) LoadSubsystems(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to get agents: %w", err)
|
||||
}
|
||||
|
||||
// For now, we'll create default subsystems for each agent
|
||||
// In full implementation, this would read from agent_subsystems table
|
||||
subsystems := []string{"updates", "storage", "system", "docker"}
|
||||
intervals := map[string]int{
|
||||
"updates": 15, // 15 minutes
|
||||
"storage": 15,
|
||||
"system": 30,
|
||||
"docker": 15,
|
||||
}
|
||||
|
||||
loaded := 0
|
||||
for _, agent := range agents {
|
||||
// Skip offline agents (haven't checked in for 10+ minutes)
|
||||
@@ -147,28 +139,70 @@ func (s *Scheduler) LoadSubsystems(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, subsystem := range subsystems {
|
||||
// TODO: Check agent metadata for subsystem enablement
|
||||
// For now, assume all subsystems are enabled
|
||||
// Get subsystems from database (respect user settings)
|
||||
dbSubsystems, err := s.subsystemQueries.GetSubsystems(agent.ID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] Failed to get subsystems for agent %s: %v", agent.Hostname, err)
|
||||
continue
|
||||
}
|
||||
|
||||
job := &SubsystemJob{
|
||||
AgentID: agent.ID,
|
||||
AgentHostname: agent.Hostname,
|
||||
Subsystem: subsystem,
|
||||
IntervalMinutes: intervals[subsystem],
|
||||
NextRunAt: time.Now().Add(time.Duration(intervals[subsystem]) * time.Minute),
|
||||
Enabled: true,
|
||||
// Create jobs only for enabled subsystems with auto_run=true
|
||||
for _, dbSub := range dbSubsystems {
|
||||
if dbSub.Enabled && dbSub.AutoRun {
|
||||
// Use database interval, fallback to default
|
||||
intervalMinutes := dbSub.IntervalMinutes
|
||||
if intervalMinutes <= 0 {
|
||||
intervalMinutes = s.getDefaultInterval(dbSub.Subsystem)
|
||||
}
|
||||
|
||||
var nextRun time.Time
|
||||
if dbSub.NextRunAt != nil {
|
||||
nextRun = *dbSub.NextRunAt
|
||||
} else {
|
||||
// If no next run is set, schedule it with default interval
|
||||
nextRun = time.Now().Add(time.Duration(intervalMinutes) * time.Minute)
|
||||
}
|
||||
|
||||
job := &SubsystemJob{
|
||||
AgentID: agent.ID,
|
||||
AgentHostname: agent.Hostname,
|
||||
Subsystem: dbSub.Subsystem,
|
||||
IntervalMinutes: intervalMinutes,
|
||||
NextRunAt: nextRun,
|
||||
Enabled: dbSub.Enabled,
|
||||
}
|
||||
|
||||
s.queue.Push(job)
|
||||
loaded++
|
||||
}
|
||||
|
||||
s.queue.Push(job)
|
||||
loaded++
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[Scheduler] Loaded %d subsystem jobs for %d agents\n", loaded, len(agents))
|
||||
log.Printf("[Scheduler] Loaded %d subsystem jobs for %d agents (respecting database settings)\n", loaded, len(agents))
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDefaultInterval returns default interval minutes for a subsystem
|
||||
// TODO: These intervals need to correlate with agent health scanning settings
|
||||
// Each subsystem should be variable based on user-configurable agent health policies
|
||||
func (s *Scheduler) getDefaultInterval(subsystem string) int {
|
||||
defaults := map[string]int{
|
||||
"apt": 30, // 30 minutes
|
||||
"dnf": 240, // 4 hours
|
||||
"docker": 120, // 2 hours
|
||||
"storage": 360, // 6 hours
|
||||
"windows": 480, // 8 hours
|
||||
"winget": 360, // 6 hours
|
||||
"updates": 15, // 15 minutes
|
||||
"system": 30, // 30 minutes
|
||||
}
|
||||
|
||||
if interval, exists := defaults[subsystem]; exists {
|
||||
return interval
|
||||
}
|
||||
return 30 // Default fallback
|
||||
}
|
||||
|
||||
// Start begins the scheduler main loop and workers
|
||||
func (s *Scheduler) Start() error {
|
||||
log.Printf("[Scheduler] Starting with %d workers, check interval %v\n",
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
func TestScheduler_NewScheduler(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
s := NewScheduler(config, nil, nil)
|
||||
s := NewScheduler(config, nil, nil, nil)
|
||||
|
||||
if s == nil {
|
||||
t.Fatal("NewScheduler returned nil")
|
||||
@@ -58,7 +58,7 @@ func TestScheduler_DefaultConfig(t *testing.T) {
|
||||
|
||||
func TestScheduler_QueueIntegration(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
s := NewScheduler(config, nil, nil)
|
||||
s := NewScheduler(config, nil, nil, nil)
|
||||
|
||||
// Add jobs to queue
|
||||
agent1 := uuid.New()
|
||||
@@ -96,7 +96,7 @@ func TestScheduler_QueueIntegration(t *testing.T) {
|
||||
|
||||
func TestScheduler_GetStats(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
s := NewScheduler(config, nil, nil)
|
||||
s := NewScheduler(config, nil, nil, nil)
|
||||
|
||||
// Initial stats should be zero
|
||||
stats := s.GetStats()
|
||||
@@ -145,7 +145,7 @@ func TestScheduler_StartStop(t *testing.T) {
|
||||
RateLimitPerSecond: 0, // Disable rate limiting for test
|
||||
}
|
||||
|
||||
s := NewScheduler(config, nil, nil)
|
||||
s := NewScheduler(config, nil, nil, nil)
|
||||
|
||||
// Start scheduler
|
||||
err := s.Start()
|
||||
@@ -167,7 +167,7 @@ func TestScheduler_StartStop(t *testing.T) {
|
||||
|
||||
func TestScheduler_ProcessQueueEmpty(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
s := NewScheduler(config, nil, nil)
|
||||
s := NewScheduler(config, nil, nil, nil)
|
||||
|
||||
// Process empty queue should not panic
|
||||
s.processQueue()
|
||||
@@ -188,7 +188,7 @@ func TestScheduler_ProcessQueueWithJobs(t *testing.T) {
|
||||
RateLimitPerSecond: 0, // Disable for test
|
||||
}
|
||||
|
||||
s := NewScheduler(config, nil, nil)
|
||||
s := NewScheduler(config, nil, nil, nil)
|
||||
|
||||
// Add jobs that are due now
|
||||
for i := 0; i < 5; i++ {
|
||||
@@ -229,7 +229,7 @@ func TestScheduler_RateLimiterRefill(t *testing.T) {
|
||||
RateLimitPerSecond: 10, // 10 tokens per second
|
||||
}
|
||||
|
||||
s := NewScheduler(config, nil, nil)
|
||||
s := NewScheduler(config, nil, nil, nil)
|
||||
|
||||
if s.rateLimiter == nil {
|
||||
t.Fatal("rate limiter not initialized")
|
||||
@@ -264,7 +264,7 @@ func TestScheduler_RateLimiterRefill(t *testing.T) {
|
||||
|
||||
func TestScheduler_ConcurrentQueueAccess(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
s := NewScheduler(config, nil, nil)
|
||||
s := NewScheduler(config, nil, nil, nil)
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
@@ -303,7 +303,7 @@ func TestScheduler_ConcurrentQueueAccess(t *testing.T) {
|
||||
|
||||
func BenchmarkScheduler_ProcessQueue(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
s := NewScheduler(config, nil, nil)
|
||||
s := NewScheduler(config, nil, nil, nil)
|
||||
|
||||
// Pre-fill queue with jobs
|
||||
for i := 0; i < 1000; i++ {
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
|
||||
// TimeoutService handles timeout management for long-running operations
|
||||
type TimeoutService struct {
|
||||
commandQueries *queries.CommandQueries
|
||||
updateQueries *queries.UpdateQueries
|
||||
ticker *time.Ticker
|
||||
stopChan chan bool
|
||||
timeoutDuration time.Duration
|
||||
commandQueries *queries.CommandQueries
|
||||
updateQueries *queries.UpdateQueries
|
||||
ticker *time.Ticker
|
||||
stopChan chan bool
|
||||
sentTimeout time.Duration // For commands already sent to agents
|
||||
pendingTimeout time.Duration // For commands stuck in queue
|
||||
}
|
||||
|
||||
// NewTimeoutService creates a new timeout service
|
||||
@@ -24,14 +25,16 @@ func NewTimeoutService(cq *queries.CommandQueries, uq *queries.UpdateQueries) *T
|
||||
return &TimeoutService{
|
||||
commandQueries: cq,
|
||||
updateQueries: uq,
|
||||
timeoutDuration: 2 * time.Hour, // 2 hours timeout - allows for system upgrades and large operations
|
||||
sentTimeout: 2 * time.Hour, // 2 hours for commands already sent to agents
|
||||
pendingTimeout: 30 * time.Minute, // 30 minutes for commands stuck in queue
|
||||
// TODO: Make these timeout durations user-adjustable in settings
|
||||
stopChan: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the timeout monitoring service
|
||||
func (ts *TimeoutService) Start() {
|
||||
log.Printf("Starting timeout service with %v timeout duration", ts.timeoutDuration)
|
||||
log.Printf("Starting timeout service with %v sent timeout, %v pending timeout", ts.sentTimeout, ts.pendingTimeout)
|
||||
|
||||
// Create a ticker that runs every 5 minutes
|
||||
ts.ticker = time.NewTicker(5 * time.Minute)
|
||||
@@ -59,25 +62,41 @@ func (ts *TimeoutService) Stop() {
|
||||
func (ts *TimeoutService) checkForTimeouts() {
|
||||
log.Println("Checking for timed out operations...")
|
||||
|
||||
// Get all commands that are in 'sent' status
|
||||
commands, err := ts.commandQueries.GetCommandsByStatus(models.CommandStatusSent)
|
||||
if err != nil {
|
||||
log.Printf("Error getting sent commands: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
timeoutThreshold := time.Now().Add(-ts.timeoutDuration)
|
||||
sentTimeoutThreshold := time.Now().Add(-ts.sentTimeout)
|
||||
pendingTimeoutThreshold := time.Now().Add(-ts.pendingTimeout)
|
||||
timedOutCommands := make([]models.AgentCommand, 0)
|
||||
|
||||
for _, command := range commands {
|
||||
// Check if command has been sent and is older than timeout threshold
|
||||
if command.SentAt != nil && command.SentAt.Before(timeoutThreshold) {
|
||||
timedOutCommands = append(timedOutCommands, command)
|
||||
// Check 'sent' commands (traditional timeout - 2 hours)
|
||||
sentCommands, err := ts.commandQueries.GetCommandsByStatus(models.CommandStatusSent)
|
||||
if err != nil {
|
||||
log.Printf("Error getting sent commands: %v", err)
|
||||
} else {
|
||||
for _, command := range sentCommands {
|
||||
// Check if command has been sent and is older than sent timeout threshold
|
||||
if command.SentAt != nil && command.SentAt.Before(sentTimeoutThreshold) {
|
||||
timedOutCommands = append(timedOutCommands, command)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check 'pending' commands (stuck in queue timeout - 30 minutes)
|
||||
pendingCommands, err := ts.commandQueries.GetCommandsByStatus(models.CommandStatusPending)
|
||||
if err != nil {
|
||||
log.Printf("Error getting pending commands: %v", err)
|
||||
} else {
|
||||
for _, command := range pendingCommands {
|
||||
// Check if command has been pending longer than pending timeout threshold
|
||||
if command.CreatedAt.Before(pendingTimeoutThreshold) {
|
||||
timedOutCommands = append(timedOutCommands, command)
|
||||
log.Printf("Found stuck pending command %s (type: %s, created: %s, age: %v)",
|
||||
command.ID, command.CommandType, command.CreatedAt.Format(time.RFC3339), time.Since(command.CreatedAt))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(timedOutCommands) > 0 {
|
||||
log.Printf("Found %d timed out commands", len(timedOutCommands))
|
||||
log.Printf("Found %d timed out commands (%d sent >2h, %d stuck pending >30m)",
|
||||
len(timedOutCommands), len(sentCommands), len(pendingCommands))
|
||||
|
||||
for _, command := range timedOutCommands {
|
||||
if err := ts.timeoutCommand(&command); err != nil {
|
||||
@@ -91,6 +110,14 @@ func (ts *TimeoutService) checkForTimeouts() {
|
||||
|
||||
// timeoutCommand marks a specific command as timed out and updates related entities
|
||||
func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
|
||||
// Determine which timeout duration was applied
|
||||
var appliedTimeout time.Duration
|
||||
if command.Status == models.CommandStatusSent {
|
||||
appliedTimeout = ts.sentTimeout
|
||||
} else {
|
||||
appliedTimeout = ts.pendingTimeout
|
||||
}
|
||||
|
||||
log.Printf("Timing out command %s (type: %s, agent: %s)",
|
||||
command.ID, command.CommandType, command.AgentID)
|
||||
|
||||
@@ -103,7 +130,7 @@ func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
|
||||
result := models.JSONB{
|
||||
"error": "operation timed out",
|
||||
"timeout_at": time.Now(),
|
||||
"duration": ts.timeoutDuration.String(),
|
||||
"duration": appliedTimeout.String(),
|
||||
"command_id": command.ID.String(),
|
||||
}
|
||||
|
||||
@@ -112,7 +139,7 @@ func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
|
||||
}
|
||||
|
||||
// Update related update package status if applicable
|
||||
if err := ts.updateRelatedPackageStatus(command); err != nil {
|
||||
if err := ts.updateRelatedPackageStatus(command, appliedTimeout); err != nil {
|
||||
log.Printf("Warning: failed to update related package status: %v", err)
|
||||
// Don't return error here as the main timeout operation succeeded
|
||||
}
|
||||
@@ -123,11 +150,11 @@ func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
|
||||
AgentID: command.AgentID,
|
||||
UpdatePackageID: ts.extractUpdatePackageID(command),
|
||||
Action: command.CommandType,
|
||||
Result: "timed_out",
|
||||
Result: "failed", // Use 'failed' to comply with database constraint
|
||||
Stdout: "",
|
||||
Stderr: fmt.Sprintf("Command %s timed out after %v", command.CommandType, ts.timeoutDuration),
|
||||
Stderr: fmt.Sprintf("Command %s timed out after %v (timeout_id: %s)", command.CommandType, appliedTimeout, command.ID),
|
||||
ExitCode: 124, // Standard timeout exit code
|
||||
DurationSeconds: int(ts.timeoutDuration.Seconds()),
|
||||
DurationSeconds: int(appliedTimeout.Seconds()),
|
||||
ExecutedAt: time.Now(),
|
||||
}
|
||||
|
||||
@@ -141,7 +168,7 @@ func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
|
||||
}
|
||||
|
||||
// updateRelatedPackageStatus updates the status of related update packages when a command times out
|
||||
func (ts *TimeoutService) updateRelatedPackageStatus(command *models.AgentCommand) error {
|
||||
func (ts *TimeoutService) updateRelatedPackageStatus(command *models.AgentCommand, appliedTimeout time.Duration) error {
|
||||
// Extract update_id from command params if it exists
|
||||
_, ok := command.Params["update_id"].(string)
|
||||
if !ok {
|
||||
@@ -153,7 +180,7 @@ func (ts *TimeoutService) updateRelatedPackageStatus(command *models.AgentComman
|
||||
metadata := models.JSONB{
|
||||
"timeout": true,
|
||||
"timeout_at": time.Now(),
|
||||
"timeout_duration": ts.timeoutDuration.String(),
|
||||
"timeout_duration": appliedTimeout.String(),
|
||||
"command_id": command.ID.String(),
|
||||
"failure_reason": "operation timed out",
|
||||
}
|
||||
@@ -196,7 +223,7 @@ func (ts *TimeoutService) GetTimeoutStatus() (map[string]interface{}, error) {
|
||||
}
|
||||
|
||||
// Count commands approaching timeout (within 5 minutes of timeout)
|
||||
timeoutThreshold := time.Now().Add(-ts.timeoutDuration + 5*time.Minute)
|
||||
timeoutThreshold := time.Now().Add(-ts.sentTimeout + 5*time.Minute)
|
||||
approachingTimeout := 0
|
||||
for _, command := range activeCommands {
|
||||
if command.SentAt != nil && command.SentAt.Before(timeoutThreshold) {
|
||||
@@ -205,16 +232,30 @@ func (ts *TimeoutService) GetTimeoutStatus() (map[string]interface{}, error) {
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_timed_out": len(timedOutCommands),
|
||||
"total_active": len(activeCommands),
|
||||
"approaching_timeout": approachingTimeout,
|
||||
"timeout_duration": ts.timeoutDuration.String(),
|
||||
"last_check": time.Now(),
|
||||
"total_timed_out": len(timedOutCommands),
|
||||
"total_active": len(activeCommands),
|
||||
"approaching_timeout": approachingTimeout,
|
||||
"sent_timeout_duration": ts.sentTimeout.String(),
|
||||
"pending_timeout_duration": ts.pendingTimeout.String(),
|
||||
"last_check": time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetTimeoutDuration allows changing the timeout duration
|
||||
// SetTimeoutDuration allows changing the timeout duration for sent commands
|
||||
// TODO: This should be deprecated in favor of SetSentTimeout and SetPendingTimeout
|
||||
func (ts *TimeoutService) SetTimeoutDuration(duration time.Duration) {
|
||||
ts.timeoutDuration = duration
|
||||
log.Printf("Timeout duration updated to %v", duration)
|
||||
ts.sentTimeout = duration
|
||||
log.Printf("Sent timeout duration updated to %v", duration)
|
||||
}
|
||||
|
||||
// SetSentTimeout allows changing the timeout duration for sent commands
|
||||
func (ts *TimeoutService) SetSentTimeout(duration time.Duration) {
|
||||
ts.sentTimeout = duration
|
||||
log.Printf("Sent timeout duration updated to %v", duration)
|
||||
}
|
||||
|
||||
// SetPendingTimeout allows changing the timeout duration for pending commands
|
||||
func (ts *TimeoutService) SetPendingTimeout(duration time.Duration) {
|
||||
ts.pendingTimeout = duration
|
||||
log.Printf("Pending timeout duration updated to %v", duration)
|
||||
}
|
||||
Reference in New Issue
Block a user