fix: agent acknowledgment recursion and subsystem UI improvements

- Fix recursive call in reportLogWithAck that caused an infinite loop
- Add machine binding and security API endpoints
- Enhance AgentScanners component with security status display
- Improve scheduler and timeout service reliability
- Remove deprecated install.sh script
- Add subsystem configuration and logging improvements
Fimeg
2025-11-03 21:02:57 -05:00
parent d0f13e5da7
commit 57be3754c6
19 changed files with 665 additions and 409 deletions

View File

@@ -106,7 +106,8 @@ func (h *AgentHandler) RegisterAgent(c *gin.Context) {
// Save to database
if err := h.agentQueries.CreateAgent(agent); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to register agent"})
log.Printf("ERROR: Failed to create agent in database: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to register agent - database error"})
return
}
@@ -163,16 +164,17 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
// Try to parse optional system metrics from request body
var metrics struct {
CPUPercent float64 `json:"cpu_percent,omitempty"`
MemoryPercent float64 `json:"memory_percent,omitempty"`
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
DiskPercent float64 `json:"disk_percent,omitempty"`
Uptime string `json:"uptime,omitempty"`
Version string `json:"version,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
CPUPercent float64 `json:"cpu_percent,omitempty"`
MemoryPercent float64 `json:"memory_percent,omitempty"`
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
DiskPercent float64 `json:"disk_percent,omitempty"`
Uptime string `json:"uptime,omitempty"`
Version string `json:"version,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
PendingAcknowledgments []string `json:"pending_acknowledgments,omitempty"`
}
// Parse metrics if provided (optional, won't fail if empty)
@@ -449,10 +451,27 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
}
}
// Process command acknowledgments from agent
var acknowledgedIDs []string
if len(metrics.PendingAcknowledgments) > 0 {
log.Printf("DEBUG: Processing %d pending acknowledgments for agent %s: %v", len(metrics.PendingAcknowledgments), agentID, metrics.PendingAcknowledgments)
// Verify which commands from agent's pending list have been recorded
verified, err := h.commandQueries.VerifyCommandsCompleted(metrics.PendingAcknowledgments)
if err != nil {
log.Printf("Warning: Failed to verify command acknowledgments for agent %s: %v", agentID, err)
} else {
acknowledgedIDs = verified
log.Printf("DEBUG: Verified %d completed commands out of %d pending for agent %s", len(acknowledgedIDs), len(metrics.PendingAcknowledgments), agentID)
if len(acknowledgedIDs) > 0 {
log.Printf("Acknowledged %d command results for agent %s", len(acknowledgedIDs), agentID)
}
}
}
response := models.CommandsResponse{
Commands: commandItems,
RapidPolling: rapidPolling,
AcknowledgedIDs: []string{}, // No acknowledgments in current implementation
AcknowledgedIDs: acknowledgedIDs,
}
c.JSON(http.StatusOK, response)
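
The handshake implemented here is at-least-once: the agent keeps resending command IDs in pending_acknowledgments until the aggregator echoes them back, then drops them locally. A minimal agent-side sketch of that drop step, assuming the response field serializes as acknowledged_ids (the surrounding type names are illustrative):

package agent

// Illustrative agent-side types; pending_acknowledgments matches this commit,
// acknowledged_ids is assumed to be the JSON tag of AcknowledgedIDs.
type checkinRequest struct {
    PendingAcknowledgments []string `json:"pending_acknowledgments,omitempty"`
}

type commandsResponse struct {
    AcknowledgedIDs []string `json:"acknowledged_ids"`
}

// dropAcknowledged removes IDs the server has verified, so they are not
// re-sent on the next poll; unverified IDs stay pending (at-least-once).
func dropAcknowledged(pending, acked []string) []string {
    ackSet := make(map[string]bool, len(acked))
    for _, id := range acked {
        ackSet[id] = true
    }
    remaining := pending[:0]
    for _, id := range pending {
        if !ackSet[id] {
            remaining = append(remaining, id)
        }
    }
    return remaining
}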
@@ -465,7 +484,8 @@ func (h *AgentHandler) ListAgents(c *gin.Context) {
agents, err := h.agentQueries.ListAgentsWithLastScan(status, osType)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list agents"})
log.Printf("ERROR: Failed to list agents: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list agents - database error"})
return
}

View File

@@ -31,28 +31,24 @@ func (h *DownloadHandler) getServerURL(c *gin.Context) string {
return h.config.Server.PublicURL
}
// Priority 2: Detect from request with TLS/proxy awareness
// Priority 2: Construct API server URL from configuration
scheme := "http"
host := h.config.Server.Host
port := h.config.Server.Port
// Check if TLS is enabled in config
// Use HTTPS if TLS is enabled in config
if h.config.Server.TLS.Enabled {
scheme = "https"
}
// Check if request came through HTTPS (direct or via proxy)
if c.Request.TLS != nil {
scheme = "https"
// For default host (0.0.0.0), use localhost for client connections
if host == "0.0.0.0" {
host = "localhost"
}
// Check X-Forwarded-Proto for reverse proxy setups
if forwardedProto := c.GetHeader("X-Forwarded-Proto"); forwardedProto == "https" {
scheme = "https"
}
// Use the Host header exactly as received (includes port if present)
host := c.GetHeader("X-Forwarded-Host")
if host == "" {
host = c.Request.Host
// Only include port if it's not the default for the protocol
if (scheme == "http" && port != 80) || (scheme == "https" && port != 443) {
return fmt.Sprintf("%s://%s:%d", scheme, host, port)
}
return fmt.Sprintf("%s://%s", scheme, host)
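
The rewritten getServerURL is easiest to read as input/output pairs; a small self-contained sketch mirroring the logic above, with example results in the comments:

package download

import "fmt"

// buildServerURL mirrors the configuration-driven logic above: scheme from the
// TLS flag, 0.0.0.0 mapped to localhost, and default ports omitted.
// Illustrative only; the real handler reads these values from h.config.
func buildServerURL(host string, port int, tlsEnabled bool) string {
    scheme := "http"
    if tlsEnabled {
        scheme = "https"
    }
    if host == "0.0.0.0" {
        host = "localhost"
    }
    if (scheme == "http" && port != 80) || (scheme == "https" && port != 443) {
        return fmt.Sprintf("%s://%s:%d", scheme, host, port)
    }
    return fmt.Sprintf("%s://%s", scheme, host)
}

// buildServerURL("0.0.0.0", 8080, false)           -> "http://localhost:8080"
// buildServerURL("redflag.example.com", 443, true) -> "https://redflag.example.com"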
@@ -155,6 +151,7 @@ AGENT_BINARY="/usr/local/bin/redflag-agent"
SUDOERS_FILE="/etc/sudoers.d/redflag-agent"
SERVICE_FILE="/etc/systemd/system/redflag-agent.service"
CONFIG_DIR="/etc/aggregator"
STATE_DIR="/var/lib/aggregator"
echo "=== RedFlag Agent Installation ==="
echo ""
@@ -301,19 +298,24 @@ else
exit 1
fi
# Step 4: Create configuration directory
# Step 4: Create configuration and state directories
echo ""
echo "Step 4: Creating configuration directory..."
echo "Step 4: Creating configuration and state directories..."
mkdir -p "$CONFIG_DIR"
chown "$AGENT_USER:$AGENT_USER" "$CONFIG_DIR"
chmod 755 "$CONFIG_DIR"
echo "✓ Configuration directory created"
# Set SELinux context for config directory if SELinux is enabled
# Create state directory for acknowledgment tracking (v0.1.19+)
mkdir -p "$STATE_DIR"
chown "$AGENT_USER:$AGENT_USER" "$STATE_DIR"
chmod 755 "$STATE_DIR"
echo "✓ Configuration and state directories created"
# Set SELinux context for directories if SELinux is enabled
if command -v getenforce >/dev/null 2>&1 && [ "$(getenforce)" != "Disabled" ]; then
echo "Setting SELinux context for config directory..."
restorecon -Rv "$CONFIG_DIR" 2>/dev/null || true
echo "✓ SELinux context set for config directory"
echo "Setting SELinux context for directories..."
restorecon -Rv "$CONFIG_DIR" "$STATE_DIR" 2>/dev/null || true
echo "✓ SELinux context set for directories"
fi
# Step 5: Install systemd service
@@ -338,7 +340,7 @@ RestartSec=30
# NoNewPrivileges=true - DISABLED: Prevents sudo from working, which agent needs for package management
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=$AGENT_HOME /var/log $CONFIG_DIR
ReadWritePaths=$AGENT_HOME /var/log $CONFIG_DIR $STATE_DIR
PrivateTmp=true
# Logging

View File

@@ -387,9 +387,15 @@ func (h *SetupHandler) ConfigureServer(c *gin.Context) {
fmt.Println("Updating PostgreSQL password from bootstrap to user-provided password...")
bootstrapPassword := "redflag_bootstrap" // This matches our bootstrap .env
if err := updatePostgresPassword(req.DBHost, req.DBPort, req.DBUser, bootstrapPassword, req.DBPassword); err != nil {
fmt.Printf("Warning: Failed to update PostgreSQL password: %v\n", err)
fmt.Println("Will proceed with configuration anyway...")
fmt.Printf("CRITICAL ERROR: Failed to update PostgreSQL password: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to update database password. Setup cannot continue.",
"details": err.Error(),
"help": "Ensure PostgreSQL is accessible and the bootstrap password is correct. Check Docker logs for details.",
})
return
}
fmt.Println("PostgreSQL password successfully updated from bootstrap to user-provided password")
// Step 2: Generate configuration content for manual update
fmt.Println("Generating configuration content for manual .env file update...")
@@ -414,6 +420,11 @@ func (h *SetupHandler) ConfigureServer(c *gin.Context) {
// GenerateSigningKeys generates Ed25519 keypair for agent update signing
func (h *SetupHandler) GenerateSigningKeys(c *gin.Context) {
// Prevent caching of generated keys (security critical)
c.Header("Cache-Control", "no-store, no-cache, must-revalidate, private")
c.Header("Pragma", "no-cache")
c.Header("Expires", "0")
// Generate Ed25519 keypair
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
@@ -428,6 +439,9 @@ func (h *SetupHandler) GenerateSigningKeys(c *gin.Context) {
// Generate fingerprint (first 16 chars)
fingerprint := publicKeyHex[:16]
// Log key generation for security audit trail (only fingerprint, not full key)
fmt.Printf("Generated new Ed25519 keypair - Fingerprint: %s\n", fingerprint)
c.JSON(http.StatusOK, gin.H{
"public_key": publicKeyHex,
"private_key": privateKeyHex,

View File

@@ -13,6 +13,16 @@ import (
"github.com/google/uuid"
)
// isValidResult checks if the result value complies with the database constraint
func isValidResult(result string) bool {
validResults := map[string]bool{
"success": true,
"failed": true,
"partial": true,
}
return validResults[result]
}
type UpdateHandler struct {
updateQueries *queries.UpdateQueries
agentQueries *queries.AgentQueries
@@ -199,11 +209,22 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
return
}
// Validate and map result to comply with database constraint
validResult := req.Result
if !isValidResult(validResult) {
// Map invalid results to valid ones (e.g., "timed_out" -> "failed")
if validResult == "timed_out" || validResult == "timeout" || validResult == "cancelled" {
validResult = "failed"
} else {
validResult = "failed" // Default to failed for any unknown status
}
}
logEntry := &models.UpdateLog{
ID: uuid.New(),
AgentID: agentID,
Action: req.Action,
Result: req.Result,
Result: validResult,
Stdout: req.Stdout,
Stderr: req.Stderr,
ExitCode: req.ExitCode,
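
The coercion above exists because update_logs.result is constrained (per isValidResult) to success, failed, or partial; anything else the agent reports is stored as failed. A compact restatement of that rule as a sketch:

package handlers

// normalizeResult mirrors the validation above: values outside the set
// accepted by isValidResult are coerced to "failed" so the database
// constraint on update_logs.result is never violated.
func normalizeResult(result string) string {
    switch result {
    case "success", "failed", "partial":
        return result
    default: // "timed_out", "timeout", "cancelled", or anything unknown
        return "failed"
    }
}

// normalizeResult("timed_out") == "failed"
// normalizeResult("partial")   == "partial"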
@@ -831,8 +852,8 @@ func (h *UpdateHandler) ClearFailedCommands(c *gin.Context) {
// Build the appropriate cleanup query based on parameters
if allFailed {
// Clear ALL failed commands (most aggressive)
count, err = h.commandQueries.ClearAllFailedCommands(olderThanDays)
// Clear ALL failed commands regardless of age (most aggressive)
count, err = h.commandQueries.ClearAllFailedCommandsRegardlessOfAge()
} else if onlyRetried {
// Clear only failed commands that have been retried
count, err = h.commandQueries.ClearRetriedFailedCommands(olderThanDays)

View File

@@ -88,6 +88,14 @@ func Load() (*Config, error) {
cfg.Timezone = getEnv("TIMEZONE", "UTC")
cfg.LatestAgentVersion = getEnv("LATEST_AGENT_VERSION", "0.1.22")
cfg.MinAgentVersion = getEnv("MIN_AGENT_VERSION", "0.1.22")
cfg.SigningPrivateKey = getEnv("REDFLAG_SIGNING_PRIVATE_KEY", "")
// Debug: Log signing key status
if cfg.SigningPrivateKey != "" {
fmt.Printf("[CONFIG] ✅ Ed25519 signing private key configured (%d characters)\n", len(cfg.SigningPrivateKey))
} else {
fmt.Printf("[CONFIG] ❌ No Ed25519 signing private key found in REDFLAG_SIGNING_PRIVATE_KEY\n")
}
// Handle missing secrets
if cfg.Admin.Password == "" || cfg.Admin.JWTSecret == "" || cfg.Database.Password == "" {
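
Beyond the presence check logged here, the hex key can also be length-checked at startup; a hedged sketch (an Ed25519 private key is 64 bytes, i.e. 128 hex characters), with the helper name chosen for illustration:

package config

import (
    "crypto/ed25519"
    "encoding/hex"
    "fmt"
)

// validateSigningKey is an illustrative extra check that could run after
// Load: it confirms REDFLAG_SIGNING_PRIVATE_KEY decodes to a key of the
// expected Ed25519 size before the server ever tries to sign a release.
func validateSigningKey(hexKey string) error {
    raw, err := hex.DecodeString(hexKey)
    if err != nil {
        return fmt.Errorf("signing key is not valid hex: %w", err)
    }
    if len(raw) != ed25519.PrivateKeySize { // 64 bytes
        return fmt.Errorf("signing key is %d bytes, expected %d", len(raw), ed25519.PrivateKeySize)
    }
    return nil
}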

View File

@@ -23,14 +23,19 @@ func (q *AgentQueries) CreateAgent(agent *models.Agent) error {
query := `
INSERT INTO agents (
id, hostname, os_type, os_version, os_architecture,
agent_version, last_seen, status, metadata
agent_version, current_version, machine_id, public_key_fingerprint,
last_seen, status, metadata
) VALUES (
:id, :hostname, :os_type, :os_version, :os_architecture,
:agent_version, :last_seen, :status, :metadata
:agent_version, :current_version, :machine_id, :public_key_fingerprint,
:last_seen, :status, :metadata
)
`
_, err := q.db.NamedExec(query, agent)
return err
if err != nil {
return fmt.Errorf("failed to create agent %s (version %s): %w", agent.Hostname, agent.CurrentVersion, err)
}
return nil
}
// GetAgentByID retrieves an agent by ID
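
The new machine_id column gives each agent an identity that survives hostname changes; on Linux the usual source is /etc/machine-id. A sketch of how an agent might read it (the helper name is hypothetical, and the shipped agent may derive its identifier differently):

package agent

import (
    "os"
    "strings"
)

// readMachineID returns the systemd machine ID, a stable per-host value an
// agent could report as machine_id when registering. Illustrative only.
func readMachineID() (string, error) {
    data, err := os.ReadFile("/etc/machine-id")
    if err != nil {
        return "", err
    }
    return strings.TrimSpace(string(data)), nil
}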

View File

@@ -2,6 +2,7 @@ package queries
import (
"fmt"
"strings"
"time"
"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
@@ -31,13 +32,14 @@ func (q *CommandQueries) CreateCommand(cmd *models.AgentCommand) error {
}
// GetPendingCommands retrieves pending commands for an agent
// Only returns 'pending' status - 'sent' commands are handled by timeout service
func (q *CommandQueries) GetPendingCommands(agentID uuid.UUID) ([]models.AgentCommand, error) {
var commands []models.AgentCommand
query := `
SELECT * FROM agent_commands
WHERE agent_id = $1 AND status = 'pending'
ORDER BY created_at ASC
LIMIT 10
LIMIT 100
`
err := q.db.Select(&commands, query, agentID)
return commands, err
@@ -338,6 +340,23 @@ func (q *CommandQueries) ClearAllFailedCommands(days int) (int64, error) {
return result.RowsAffected()
}
// ClearAllFailedCommandsRegardlessOfAge archives ALL failed/timed_out commands regardless of age
// This is used when all_failed=true is passed to truly clear all failed commands
func (q *CommandQueries) ClearAllFailedCommandsRegardlessOfAge() (int64, error) {
query := `
UPDATE agent_commands
SET status = 'archived_failed'
WHERE status IN ('failed', 'timed_out')
`
result, err := q.db.Exec(query)
if err != nil {
return 0, fmt.Errorf("failed to archive all failed commands regardless of age: %w", err)
}
return result.RowsAffected()
}
// CountPendingCommandsForAgent returns the number of pending commands for a specific agent
// Used by scheduler for backpressure detection
func (q *CommandQueries) CountPendingCommandsForAgent(agentID uuid.UUID) (int, error) {
@@ -373,16 +392,30 @@ func (q *CommandQueries) VerifyCommandsCompleted(commandIDs []string) ([]string,
return []string{}, nil
}
// Convert UUIDs back to strings for SQL query
uuidStrs := make([]string, len(uuidIDs))
for i, id := range uuidIDs {
uuidStrs[i] = id.String()
}
// Query for commands that are completed or failed
query := `
// Use ANY with proper array literal for PostgreSQL
placeholders := make([]string, len(uuidStrs))
args := make([]interface{}, len(uuidStrs))
for i, id := range uuidStrs {
placeholders[i] = fmt.Sprintf("$%d", i+1)
args[i] = id
}
query := fmt.Sprintf(`
SELECT id
FROM agent_commands
WHERE id = ANY($1)
AND status IN ('completed', 'failed')
`
WHERE id::text = ANY(%s)
AND status IN ('completed', 'failed', 'timed_out')
`, fmt.Sprintf("ARRAY[%s]", strings.Join(placeholders, ",")))
var completedUUIDs []uuid.UUID
err := q.db.Select(&completedUUIDs, query, uuidIDs)
err := q.db.Select(&completedUUIDs, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to verify command completion: %w", err)
}
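
Building an ARRAY[...] literal from numbered placeholders works with any driver; if the project happens to use lib/pq, the slice can instead bind to a single parameter. A hedged alternative sketch (the function name is illustrative):

package queries

import (
    "github.com/jmoiron/sqlx"
    "github.com/lib/pq"
)

// verifyCompletedWithPqArray shows the same query as above with pq.Array
// binding the whole slice to $1, assuming the lib/pq driver is in use.
func verifyCompletedWithPqArray(db *sqlx.DB, ids []string) ([]string, error) {
    var completed []string
    query := `
        SELECT id::text
        FROM agent_commands
        WHERE id::text = ANY($1)
          AND status IN ('completed', 'failed', 'timed_out')
    `
    err := db.Select(&completed, query, pq.Array(ids))
    return completed, err
}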

View File

@@ -45,6 +45,11 @@ type AgentWithLastScan struct {
CurrentVersion string `json:"current_version" db:"current_version"` // Current running version
UpdateAvailable bool `json:"update_available" db:"update_available"` // Whether update is available
LastVersionCheck time.Time `json:"last_version_check" db:"last_version_check"` // Last time version was checked
MachineID *string `json:"machine_id,omitempty" db:"machine_id"` // Unique machine identifier
PublicKeyFingerprint *string `json:"public_key_fingerprint,omitempty" db:"public_key_fingerprint"` // Public key fingerprint
IsUpdating bool `json:"is_updating" db:"is_updating"` // Whether agent is currently updating
UpdatingToVersion *string `json:"updating_to_version,omitempty" db:"updating_to_version"` // Target version for ongoing update
UpdateInitiatedAt *time.Time `json:"update_initiated_at,omitempty" db:"update_initiated_at"` // When update process started
LastSeen time.Time `json:"last_seen" db:"last_seen"`
Status string `json:"status" db:"status"`
Metadata JSONB `json:"metadata" db:"metadata"`

View File

@@ -53,8 +53,9 @@ type Scheduler struct {
queue *PriorityQueue
// Database queries
agentQueries *queries.AgentQueries
commandQueries *queries.CommandQueries
agentQueries *queries.AgentQueries
commandQueries *queries.CommandQueries
subsystemQueries *queries.SubsystemQueries
// Worker pool
jobChan chan *SubsystemJob
@@ -88,19 +89,20 @@ type Stats struct {
}
// NewScheduler creates a new scheduler instance
func NewScheduler(config Config, agentQueries *queries.AgentQueries, commandQueries *queries.CommandQueries) *Scheduler {
func NewScheduler(config Config, agentQueries *queries.AgentQueries, commandQueries *queries.CommandQueries, subsystemQueries *queries.SubsystemQueries) *Scheduler {
ctx, cancel := context.WithCancel(context.Background())
s := &Scheduler{
config: config,
queue: NewPriorityQueue(),
agentQueries: agentQueries,
commandQueries: commandQueries,
jobChan: make(chan *SubsystemJob, 1000), // Buffer 1000 jobs
workers: make([]*worker, config.NumWorkers),
shutdown: make(chan struct{}),
ctx: ctx,
cancel: cancel,
config: config,
queue: NewPriorityQueue(),
agentQueries: agentQueries,
commandQueries: commandQueries,
subsystemQueries: subsystemQueries,
jobChan: make(chan *SubsystemJob, 1000), // Buffer 1000 jobs
workers: make([]*worker, config.NumWorkers),
shutdown: make(chan struct{}),
ctx: ctx,
cancel: cancel,
}
// Initialize rate limiter if configured
@@ -130,16 +132,6 @@ func (s *Scheduler) LoadSubsystems(ctx context.Context) error {
return fmt.Errorf("failed to get agents: %w", err)
}
// For now, we'll create default subsystems for each agent
// In full implementation, this would read from agent_subsystems table
subsystems := []string{"updates", "storage", "system", "docker"}
intervals := map[string]int{
"updates": 15, // 15 minutes
"storage": 15,
"system": 30,
"docker": 15,
}
loaded := 0
for _, agent := range agents {
// Skip offline agents (haven't checked in for 10+ minutes)
@@ -147,28 +139,70 @@ func (s *Scheduler) LoadSubsystems(ctx context.Context) error {
continue
}
for _, subsystem := range subsystems {
// TODO: Check agent metadata for subsystem enablement
// For now, assume all subsystems are enabled
// Get subsystems from database (respect user settings)
dbSubsystems, err := s.subsystemQueries.GetSubsystems(agent.ID)
if err != nil {
log.Printf("[Scheduler] Failed to get subsystems for agent %s: %v", agent.Hostname, err)
continue
}
job := &SubsystemJob{
AgentID: agent.ID,
AgentHostname: agent.Hostname,
Subsystem: subsystem,
IntervalMinutes: intervals[subsystem],
NextRunAt: time.Now().Add(time.Duration(intervals[subsystem]) * time.Minute),
Enabled: true,
// Create jobs only for enabled subsystems with auto_run=true
for _, dbSub := range dbSubsystems {
if dbSub.Enabled && dbSub.AutoRun {
// Use database interval, fallback to default
intervalMinutes := dbSub.IntervalMinutes
if intervalMinutes <= 0 {
intervalMinutes = s.getDefaultInterval(dbSub.Subsystem)
}
var nextRun time.Time
if dbSub.NextRunAt != nil {
nextRun = *dbSub.NextRunAt
} else {
// If no next run is set, schedule it with default interval
nextRun = time.Now().Add(time.Duration(intervalMinutes) * time.Minute)
}
job := &SubsystemJob{
AgentID: agent.ID,
AgentHostname: agent.Hostname,
Subsystem: dbSub.Subsystem,
IntervalMinutes: intervalMinutes,
NextRunAt: nextRun,
Enabled: dbSub.Enabled,
}
s.queue.Push(job)
loaded++
}
s.queue.Push(job)
loaded++
}
}
log.Printf("[Scheduler] Loaded %d subsystem jobs for %d agents\n", loaded, len(agents))
log.Printf("[Scheduler] Loaded %d subsystem jobs for %d agents (respecting database settings)\n", loaded, len(agents))
return nil
}
// getDefaultInterval returns default interval minutes for a subsystem
// TODO: These intervals need to correlate with agent health scanning settings
// Each subsystem should be variable based on user-configurable agent health policies
func (s *Scheduler) getDefaultInterval(subsystem string) int {
defaults := map[string]int{
"apt": 30, // 30 minutes
"dnf": 240, // 4 hours
"docker": 120, // 2 hours
"storage": 360, // 6 hours
"windows": 480, // 8 hours
"winget": 360, // 6 hours
"updates": 15, // 15 minutes
"system": 30, // 30 minutes
}
if interval, exists := defaults[subsystem]; exists {
return interval
}
return 30 // Default fallback
}
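
Because the constructor now takes SubsystemQueries, every call site gains a fourth argument. A wiring sketch under assumed names (NewSubsystemQueries, NewAgentQueries, NewCommandQueries and the import paths are not shown in this diff):

package main

import (
    "context"
    "log"

    "github.com/Fimeg/RedFlag/aggregator-server/internal/queries"
    "github.com/Fimeg/RedFlag/aggregator-server/internal/scheduler"
    "github.com/jmoiron/sqlx"
)

// wireScheduler sketches the updated call site for the new signature.
func wireScheduler(ctx context.Context, db *sqlx.DB) *scheduler.Scheduler {
    s := scheduler.NewScheduler(
        scheduler.DefaultConfig(),
        queries.NewAgentQueries(db),
        queries.NewCommandQueries(db),
        queries.NewSubsystemQueries(db), // new dependency added by this commit
    )
    if err := s.LoadSubsystems(ctx); err != nil {
        log.Printf("[Scheduler] failed to load subsystems: %v", err)
    }
    return s
}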
// Start begins the scheduler main loop and workers
func (s *Scheduler) Start() error {
log.Printf("[Scheduler] Starting with %d workers, check interval %v\n",

View File

@@ -9,7 +9,7 @@ import (
func TestScheduler_NewScheduler(t *testing.T) {
config := DefaultConfig()
s := NewScheduler(config, nil, nil)
s := NewScheduler(config, nil, nil, nil)
if s == nil {
t.Fatal("NewScheduler returned nil")
@@ -58,7 +58,7 @@ func TestScheduler_DefaultConfig(t *testing.T) {
func TestScheduler_QueueIntegration(t *testing.T) {
config := DefaultConfig()
s := NewScheduler(config, nil, nil)
s := NewScheduler(config, nil, nil, nil)
// Add jobs to queue
agent1 := uuid.New()
@@ -96,7 +96,7 @@ func TestScheduler_QueueIntegration(t *testing.T) {
func TestScheduler_GetStats(t *testing.T) {
config := DefaultConfig()
s := NewScheduler(config, nil, nil)
s := NewScheduler(config, nil, nil, nil)
// Initial stats should be zero
stats := s.GetStats()
@@ -145,7 +145,7 @@ func TestScheduler_StartStop(t *testing.T) {
RateLimitPerSecond: 0, // Disable rate limiting for test
}
s := NewScheduler(config, nil, nil)
s := NewScheduler(config, nil, nil, nil)
// Start scheduler
err := s.Start()
@@ -167,7 +167,7 @@ func TestScheduler_StartStop(t *testing.T) {
func TestScheduler_ProcessQueueEmpty(t *testing.T) {
config := DefaultConfig()
s := NewScheduler(config, nil, nil)
s := NewScheduler(config, nil, nil, nil)
// Process empty queue should not panic
s.processQueue()
@@ -188,7 +188,7 @@ func TestScheduler_ProcessQueueWithJobs(t *testing.T) {
RateLimitPerSecond: 0, // Disable for test
}
s := NewScheduler(config, nil, nil)
s := NewScheduler(config, nil, nil, nil)
// Add jobs that are due now
for i := 0; i < 5; i++ {
@@ -229,7 +229,7 @@ func TestScheduler_RateLimiterRefill(t *testing.T) {
RateLimitPerSecond: 10, // 10 tokens per second
}
s := NewScheduler(config, nil, nil)
s := NewScheduler(config, nil, nil, nil)
if s.rateLimiter == nil {
t.Fatal("rate limiter not initialized")
@@ -264,7 +264,7 @@ func TestScheduler_RateLimiterRefill(t *testing.T) {
func TestScheduler_ConcurrentQueueAccess(t *testing.T) {
config := DefaultConfig()
s := NewScheduler(config, nil, nil)
s := NewScheduler(config, nil, nil, nil)
done := make(chan bool)
@@ -303,7 +303,7 @@ func TestScheduler_ConcurrentQueueAccess(t *testing.T) {
func BenchmarkScheduler_ProcessQueue(b *testing.B) {
config := DefaultConfig()
s := NewScheduler(config, nil, nil)
s := NewScheduler(config, nil, nil, nil)
// Pre-fill queue with jobs
for i := 0; i < 1000; i++ {

View File

@@ -12,11 +12,12 @@ import (
// TimeoutService handles timeout management for long-running operations
type TimeoutService struct {
commandQueries *queries.CommandQueries
updateQueries *queries.UpdateQueries
ticker *time.Ticker
stopChan chan bool
timeoutDuration time.Duration
commandQueries *queries.CommandQueries
updateQueries *queries.UpdateQueries
ticker *time.Ticker
stopChan chan bool
sentTimeout time.Duration // For commands already sent to agents
pendingTimeout time.Duration // For commands stuck in queue
}
// NewTimeoutService creates a new timeout service
@@ -24,14 +25,16 @@ func NewTimeoutService(cq *queries.CommandQueries, uq *queries.UpdateQueries) *T
return &TimeoutService{
commandQueries: cq,
updateQueries: uq,
timeoutDuration: 2 * time.Hour, // 2 hours timeout - allows for system upgrades and large operations
sentTimeout: 2 * time.Hour, // 2 hours for commands already sent to agents
pendingTimeout: 30 * time.Minute, // 30 minutes for commands stuck in queue
// TODO: Make these timeout durations user-adjustable in settings
stopChan: make(chan bool),
}
}
// Start begins the timeout monitoring service
func (ts *TimeoutService) Start() {
log.Printf("Starting timeout service with %v timeout duration", ts.timeoutDuration)
log.Printf("Starting timeout service with %v sent timeout, %v pending timeout", ts.sentTimeout, ts.pendingTimeout)
// Create a ticker that runs every 5 minutes
ts.ticker = time.NewTicker(5 * time.Minute)
@@ -59,25 +62,41 @@ func (ts *TimeoutService) Stop() {
func (ts *TimeoutService) checkForTimeouts() {
log.Println("Checking for timed out operations...")
// Get all commands that are in 'sent' status
commands, err := ts.commandQueries.GetCommandsByStatus(models.CommandStatusSent)
if err != nil {
log.Printf("Error getting sent commands: %v", err)
return
}
timeoutThreshold := time.Now().Add(-ts.timeoutDuration)
sentTimeoutThreshold := time.Now().Add(-ts.sentTimeout)
pendingTimeoutThreshold := time.Now().Add(-ts.pendingTimeout)
timedOutCommands := make([]models.AgentCommand, 0)
for _, command := range commands {
// Check if command has been sent and is older than timeout threshold
if command.SentAt != nil && command.SentAt.Before(timeoutThreshold) {
timedOutCommands = append(timedOutCommands, command)
// Check 'sent' commands (traditional timeout - 2 hours)
sentCommands, err := ts.commandQueries.GetCommandsByStatus(models.CommandStatusSent)
if err != nil {
log.Printf("Error getting sent commands: %v", err)
} else {
for _, command := range sentCommands {
// Check if command has been sent and is older than sent timeout threshold
if command.SentAt != nil && command.SentAt.Before(sentTimeoutThreshold) {
timedOutCommands = append(timedOutCommands, command)
}
}
}
// Check 'pending' commands (stuck in queue timeout - 30 minutes)
pendingCommands, err := ts.commandQueries.GetCommandsByStatus(models.CommandStatusPending)
if err != nil {
log.Printf("Error getting pending commands: %v", err)
} else {
for _, command := range pendingCommands {
// Check if command has been pending longer than pending timeout threshold
if command.CreatedAt.Before(pendingTimeoutThreshold) {
timedOutCommands = append(timedOutCommands, command)
log.Printf("Found stuck pending command %s (type: %s, created: %s, age: %v)",
command.ID, command.CommandType, command.CreatedAt.Format(time.RFC3339), time.Since(command.CreatedAt))
}
}
}
if len(timedOutCommands) > 0 {
log.Printf("Found %d timed out commands", len(timedOutCommands))
log.Printf("Found %d timed out commands (%d sent >2h, %d stuck pending >30m)",
len(timedOutCommands), len(sentCommands), len(pendingCommands))
for _, command := range timedOutCommands {
if err := ts.timeoutCommand(&command); err != nil {
@@ -91,6 +110,14 @@ func (ts *TimeoutService) checkForTimeouts() {
// timeoutCommand marks a specific command as timed out and updates related entities
func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
// Determine which timeout duration was applied
var appliedTimeout time.Duration
if command.Status == models.CommandStatusSent {
appliedTimeout = ts.sentTimeout
} else {
appliedTimeout = ts.pendingTimeout
}
log.Printf("Timing out command %s (type: %s, agent: %s)",
command.ID, command.CommandType, command.AgentID)
@@ -103,7 +130,7 @@ func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
result := models.JSONB{
"error": "operation timed out",
"timeout_at": time.Now(),
"duration": ts.timeoutDuration.String(),
"duration": appliedTimeout.String(),
"command_id": command.ID.String(),
}
@@ -112,7 +139,7 @@ func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
}
// Update related update package status if applicable
if err := ts.updateRelatedPackageStatus(command); err != nil {
if err := ts.updateRelatedPackageStatus(command, appliedTimeout); err != nil {
log.Printf("Warning: failed to update related package status: %v", err)
// Don't return error here as the main timeout operation succeeded
}
@@ -123,11 +150,11 @@ func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
AgentID: command.AgentID,
UpdatePackageID: ts.extractUpdatePackageID(command),
Action: command.CommandType,
Result: "timed_out",
Result: "failed", // Use 'failed' to comply with database constraint
Stdout: "",
Stderr: fmt.Sprintf("Command %s timed out after %v", command.CommandType, ts.timeoutDuration),
Stderr: fmt.Sprintf("Command %s timed out after %v (timeout_id: %s)", command.CommandType, appliedTimeout, command.ID),
ExitCode: 124, // Standard timeout exit code
DurationSeconds: int(ts.timeoutDuration.Seconds()),
DurationSeconds: int(appliedTimeout.Seconds()),
ExecutedAt: time.Now(),
}
@@ -141,7 +168,7 @@ func (ts *TimeoutService) timeoutCommand(command *models.AgentCommand) error {
}
// updateRelatedPackageStatus updates the status of related update packages when a command times out
func (ts *TimeoutService) updateRelatedPackageStatus(command *models.AgentCommand) error {
func (ts *TimeoutService) updateRelatedPackageStatus(command *models.AgentCommand, appliedTimeout time.Duration) error {
// Extract update_id from command params if it exists
_, ok := command.Params["update_id"].(string)
if !ok {
@@ -153,7 +180,7 @@ func (ts *TimeoutService) updateRelatedPackageStatus(command *models.AgentComman
metadata := models.JSONB{
"timeout": true,
"timeout_at": time.Now(),
"timeout_duration": ts.timeoutDuration.String(),
"timeout_duration": appliedTimeout.String(),
"command_id": command.ID.String(),
"failure_reason": "operation timed out",
}
@@ -196,7 +223,7 @@ func (ts *TimeoutService) GetTimeoutStatus() (map[string]interface{}, error) {
}
// Count commands approaching timeout (within 5 minutes of timeout)
timeoutThreshold := time.Now().Add(-ts.timeoutDuration + 5*time.Minute)
timeoutThreshold := time.Now().Add(-ts.sentTimeout + 5*time.Minute)
approachingTimeout := 0
for _, command := range activeCommands {
if command.SentAt != nil && command.SentAt.Before(timeoutThreshold) {
@@ -205,16 +232,30 @@ func (ts *TimeoutService) GetTimeoutStatus() (map[string]interface{}, error) {
}
return map[string]interface{}{
"total_timed_out": len(timedOutCommands),
"total_active": len(activeCommands),
"approaching_timeout": approachingTimeout,
"timeout_duration": ts.timeoutDuration.String(),
"last_check": time.Now(),
"total_timed_out": len(timedOutCommands),
"total_active": len(activeCommands),
"approaching_timeout": approachingTimeout,
"sent_timeout_duration": ts.sentTimeout.String(),
"pending_timeout_duration": ts.pendingTimeout.String(),
"last_check": time.Now(),
}, nil
}
// SetTimeoutDuration allows changing the timeout duration
// SetTimeoutDuration allows changing the timeout duration for sent commands
// TODO: This should be deprecated in favor of SetSentTimeout and SetPendingTimeout
func (ts *TimeoutService) SetTimeoutDuration(duration time.Duration) {
ts.timeoutDuration = duration
log.Printf("Timeout duration updated to %v", duration)
ts.sentTimeout = duration
log.Printf("Sent timeout duration updated to %v", duration)
}
// SetSentTimeout allows changing the timeout duration for sent commands
func (ts *TimeoutService) SetSentTimeout(duration time.Duration) {
ts.sentTimeout = duration
log.Printf("Sent timeout duration updated to %v", duration)
}
// SetPendingTimeout allows changing the timeout duration for pending commands
func (ts *TimeoutService) SetPendingTimeout(duration time.Duration) {
ts.pendingTimeout = duration
log.Printf("Pending timeout duration updated to %v", duration)
}
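
A usage sketch for the split timeouts: the constructor keeps the 2h/30m defaults shown above, and the setters override them before starting the service (the package name and whether Start blocks are assumptions):

package services

import (
    "time"

    "github.com/Fimeg/RedFlag/aggregator-server/internal/queries"
)

// startTimeoutService constructs the service with its defaults, applies
// example overrides, and starts it. Launching Start in a goroutine is a
// guess; this hunk does not show whether Start blocks.
func startTimeoutService(cq *queries.CommandQueries, uq *queries.UpdateQueries) *TimeoutService {
    ts := NewTimeoutService(cq, uq)
    ts.SetSentTimeout(4 * time.Hour)       // example: allow longer full-system upgrades
    ts.SetPendingTimeout(45 * time.Minute) // example: looser queue timeout
    go ts.Start()
    return ts
}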