WIP: Save current state - security subsystems, migrations, logging
This commit is contained in:
@@ -8,7 +8,10 @@ import (
|
||||
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/api/middleware"
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries"
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/logging"
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/scheduler"
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/services"
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -20,22 +23,59 @@ type AgentHandler struct {
|
||||
refreshTokenQueries *queries.RefreshTokenQueries
|
||||
registrationTokenQueries *queries.RegistrationTokenQueries
|
||||
subsystemQueries *queries.SubsystemQueries
|
||||
scheduler *scheduler.Scheduler
|
||||
signingService *services.SigningService
|
||||
securityLogger *logging.SecurityLogger
|
||||
checkInInterval int
|
||||
latestAgentVersion string
|
||||
}
|
||||
|
||||
func NewAgentHandler(aq *queries.AgentQueries, cq *queries.CommandQueries, rtq *queries.RefreshTokenQueries, regTokenQueries *queries.RegistrationTokenQueries, sq *queries.SubsystemQueries, checkInInterval int, latestAgentVersion string) *AgentHandler {
|
||||
func NewAgentHandler(aq *queries.AgentQueries, cq *queries.CommandQueries, rtq *queries.RefreshTokenQueries, regTokenQueries *queries.RegistrationTokenQueries, sq *queries.SubsystemQueries, scheduler *scheduler.Scheduler, signingService *services.SigningService, securityLogger *logging.SecurityLogger, checkInInterval int, latestAgentVersion string) *AgentHandler {
|
||||
return &AgentHandler{
|
||||
agentQueries: aq,
|
||||
commandQueries: cq,
|
||||
refreshTokenQueries: rtq,
|
||||
registrationTokenQueries: regTokenQueries,
|
||||
subsystemQueries: sq,
|
||||
scheduler: scheduler,
|
||||
signingService: signingService,
|
||||
securityLogger: securityLogger,
|
||||
checkInInterval: checkInInterval,
|
||||
latestAgentVersion: latestAgentVersion,
|
||||
}
|
||||
}
|
||||
|
||||
// signAndCreateCommand signs a command if signing service is enabled, then stores it in the database
|
||||
func (h *AgentHandler) signAndCreateCommand(cmd *models.AgentCommand) error {
|
||||
// Sign the command before storing
|
||||
if h.signingService != nil && h.signingService.IsEnabled() {
|
||||
signature, err := h.signingService.SignCommand(cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign command: %w", err)
|
||||
}
|
||||
cmd.Signature = signature
|
||||
|
||||
// Log successful signing
|
||||
if h.securityLogger != nil {
|
||||
h.securityLogger.LogCommandSigned(cmd)
|
||||
}
|
||||
} else {
|
||||
// Log warning if signing disabled
|
||||
log.Printf("[WARNING] Command signing disabled, storing unsigned command")
|
||||
if h.securityLogger != nil {
|
||||
h.securityLogger.LogPrivateKeyNotConfigured()
|
||||
}
|
||||
}
|
||||
|
||||
// Store in database
|
||||
err := h.commandQueries.CreateCommand(cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create command: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterAgent handles agent registration
|
||||
func (h *AgentHandler) RegisterAgent(c *gin.Context) {
|
||||
var req models.AgentRegistrationRequest
|
||||
@@ -185,6 +225,47 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
log.Printf("DEBUG: Failed to parse metrics JSON: %v", err)
|
||||
}
|
||||
|
||||
// Process buffered events from agent if present
|
||||
if metrics.Metadata != nil {
|
||||
if bufferedEvents, exists := metrics.Metadata["buffered_events"]; exists {
|
||||
if events, ok := bufferedEvents.([]interface{}); ok && len(events) > 0 {
|
||||
stored := 0
|
||||
for _, e := range events {
|
||||
if eventMap, ok := e.(map[string]interface{}); ok {
|
||||
// Extract event fields with type safety
|
||||
eventType := getStringFromMap(eventMap, "event_type")
|
||||
eventSubtype := getStringFromMap(eventMap, "event_subtype")
|
||||
severity := getStringFromMap(eventMap, "severity")
|
||||
component := getStringFromMap(eventMap, "component")
|
||||
message := getStringFromMap(eventMap, "message")
|
||||
|
||||
if eventType != "" && eventSubtype != "" && severity != "" {
|
||||
event := &models.SystemEvent{
|
||||
AgentID: &agentID,
|
||||
EventType: eventType,
|
||||
EventSubtype: eventSubtype,
|
||||
Severity: severity,
|
||||
Component: component,
|
||||
Message: message,
|
||||
Metadata: eventMap["metadata"].(map[string]interface{}),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := h.agentQueries.CreateSystemEvent(event); err != nil {
|
||||
log.Printf("Warning: Failed to store buffered event: %v", err)
|
||||
} else {
|
||||
stored++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if stored > 0 {
|
||||
log.Printf("Stored %d buffered events from agent %s", stored, agentID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logging to see what we received
|
||||
log.Printf("DEBUG: Received metrics - Version: '%s', CPU: %.2f, Memory: %.2f",
|
||||
metrics.Version, metrics.CPUPercent, metrics.MemoryPercent)
|
||||
@@ -355,9 +436,10 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
commandItems := make([]models.CommandItem, 0, len(commands))
|
||||
for _, cmd := range commands {
|
||||
commandItems = append(commandItems, models.CommandItem{
|
||||
ID: cmd.ID.String(),
|
||||
Type: cmd.CommandType,
|
||||
Params: cmd.Params,
|
||||
ID: cmd.ID.String(),
|
||||
Type: cmd.CommandType,
|
||||
Params: cmd.Params,
|
||||
Signature: cmd.Signature,
|
||||
})
|
||||
|
||||
// Mark as sent
|
||||
@@ -438,7 +520,7 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
CompletedAt: &now,
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(auditCmd); err != nil {
|
||||
if err := h.signAndCreateCommand(auditCmd); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to create audit command for stale heartbeat: %v", err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Created audit trail for stale heartbeat cleanup (agent %s)", agentID)
|
||||
@@ -456,6 +538,19 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
// Process command acknowledgments from agent
|
||||
var acknowledgedIDs []string
|
||||
if len(metrics.PendingAcknowledgments) > 0 {
|
||||
// Debug: Check what commands exist for this agent
|
||||
agentCommands, err := h.commandQueries.GetCommandsByAgentID(agentID)
|
||||
if err != nil {
|
||||
log.Printf("DEBUG: Failed to get commands for agent %s: %v", agentID, err)
|
||||
} else {
|
||||
log.Printf("DEBUG: Agent %s has %d total commands in database", agentID, len(agentCommands))
|
||||
for _, cmd := range agentCommands {
|
||||
if cmd.Status == "completed" || cmd.Status == "failed" || cmd.Status == "timed_out" {
|
||||
log.Printf("DEBUG: Completed command found - ID: %s, Status: %s, Type: %s", cmd.ID, cmd.Status, cmd.CommandType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("DEBUG: Processing %d pending acknowledgments for agent %s: %v", len(metrics.PendingAcknowledgments), agentID, metrics.PendingAcknowledgments)
|
||||
// Verify which commands from agent's pending list have been recorded
|
||||
verified, err := h.commandQueries.VerifyCommandsCompleted(metrics.PendingAcknowledgments)
|
||||
@@ -470,6 +565,19 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Hybrid Heartbeat: Check for scheduled subsystem jobs during heartbeat mode
|
||||
// This ensures that even in heartbeat mode, scheduled scans can be triggered
|
||||
if h.scheduler != nil {
|
||||
// Only check for scheduled jobs if agent is in heartbeat mode (rapid polling enabled)
|
||||
isHeartbeatMode := rapidPolling != nil && rapidPolling.Enabled
|
||||
if isHeartbeatMode {
|
||||
if err := h.checkAndCreateScheduledCommands(agentID); err != nil {
|
||||
// Log error but don't fail the request - this is enhancement, not core functionality
|
||||
log.Printf("[Heartbeat] Failed to check scheduled commands for agent %s: %v", agentID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response := models.CommandsResponse{
|
||||
Commands: commandItems,
|
||||
RapidPolling: rapidPolling,
|
||||
@@ -479,6 +587,94 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// checkAndCreateScheduledCommands checks if any subsystem jobs are due for the agent
|
||||
// and creates commands for them using the scheduler (following Option A approach)
|
||||
func (h *AgentHandler) checkAndCreateScheduledCommands(agentID uuid.UUID) error {
|
||||
// Get current subsystems for this agent from database
|
||||
subsystems, err := h.subsystemQueries.GetSubsystems(agentID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get subsystems: %w", err)
|
||||
}
|
||||
|
||||
// Check each enabled subsystem with auto_run=true
|
||||
now := time.Now()
|
||||
jobsCreated := 0
|
||||
|
||||
for _, subsystem := range subsystems {
|
||||
if !subsystem.Enabled || !subsystem.AutoRun {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this subsystem job is due
|
||||
var isDue bool
|
||||
if subsystem.NextRunAt == nil {
|
||||
// No next run time set, it's due
|
||||
isDue = true
|
||||
} else {
|
||||
// Check if next run time has passed
|
||||
isDue = subsystem.NextRunAt.Before(now) || subsystem.NextRunAt.Equal(now)
|
||||
}
|
||||
|
||||
if isDue {
|
||||
// Create the command using scheduler logic (reusing existing safeguards)
|
||||
if err := h.createSubsystemCommand(agentID, subsystem); err != nil {
|
||||
log.Printf("[Heartbeat] Failed to create command for %s subsystem: %v", subsystem.Subsystem, err)
|
||||
continue
|
||||
}
|
||||
jobsCreated++
|
||||
|
||||
// Update next run time in database ONLY after successful command creation
|
||||
if err := h.updateNextRunTime(agentID, subsystem); err != nil {
|
||||
log.Printf("[Heartbeat] Failed to update next run time for %s subsystem: %v", subsystem.Subsystem, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if jobsCreated > 0 {
|
||||
log.Printf("[Heartbeat] Created %d scheduled commands for agent %s", jobsCreated, agentID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSubsystemCommand creates a subsystem scan command using scheduler's logic
|
||||
func (h *AgentHandler) createSubsystemCommand(agentID uuid.UUID, subsystem models.AgentSubsystem) error {
|
||||
// Check backpressure: skip if agent has too many pending commands
|
||||
pendingCount, err := h.commandQueries.CountPendingCommandsForAgent(agentID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check pending commands: %w", err)
|
||||
}
|
||||
|
||||
// Backpressure threshold (same as scheduler)
|
||||
const backpressureThreshold = 10
|
||||
if pendingCount >= backpressureThreshold {
|
||||
return fmt.Errorf("agent has %d pending commands (threshold: %d), skipping", pendingCount, backpressureThreshold)
|
||||
}
|
||||
|
||||
// Create the command using same format as scheduler
|
||||
cmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
CommandType: fmt.Sprintf("scan_%s", subsystem.Subsystem),
|
||||
Params: models.JSONB{},
|
||||
Status: models.CommandStatusPending,
|
||||
Source: models.CommandSourceSystem,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := h.signAndCreateCommand(cmd); err != nil {
|
||||
return fmt.Errorf("failed to create command: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateNextRunTime updates the last_run_at and next_run_at for a subsystem after creating a command
|
||||
func (h *AgentHandler) updateNextRunTime(agentID uuid.UUID, subsystem models.AgentSubsystem) error {
|
||||
// Use the existing UpdateLastRun method which handles next_run_at calculation
|
||||
return h.subsystemQueries.UpdateLastRun(agentID, subsystem.Subsystem)
|
||||
}
|
||||
|
||||
// ListAgents returns all agents with last scan information
|
||||
func (h *AgentHandler) ListAgents(c *gin.Context) {
|
||||
status := c.Query("status")
|
||||
@@ -546,7 +742,7 @@ func (h *AgentHandler) TriggerScan(c *gin.Context) {
|
||||
Source: models.CommandSourceManual,
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(cmd); err != nil {
|
||||
if err := h.signAndCreateCommand(cmd); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create command"})
|
||||
return
|
||||
}
|
||||
@@ -591,7 +787,7 @@ func (h *AgentHandler) TriggerHeartbeat(c *gin.Context) {
|
||||
Source: models.CommandSourceManual,
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(cmd); err != nil {
|
||||
if err := h.signAndCreateCommand(cmd); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create heartbeat command"})
|
||||
return
|
||||
}
|
||||
@@ -786,7 +982,7 @@ func (h *AgentHandler) TriggerUpdate(c *gin.Context) {
|
||||
Source: models.CommandSourceManual,
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(cmd); err != nil {
|
||||
if err := h.signAndCreateCommand(cmd); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create update command"})
|
||||
return
|
||||
}
|
||||
@@ -827,6 +1023,15 @@ func (h *AgentHandler) RenewToken(c *gin.Context) {
|
||||
log.Printf("Warning: Failed to update last_seen for agent %s: %v", req.AgentID, err)
|
||||
}
|
||||
|
||||
// Update agent version if provided (for upgrade tracking)
|
||||
if req.AgentVersion != "" {
|
||||
if err := h.agentQueries.UpdateAgentVersion(req.AgentID, req.AgentVersion); err != nil {
|
||||
log.Printf("Warning: Failed to update agent version during token renewal for agent %s: %v", req.AgentID, err)
|
||||
} else {
|
||||
log.Printf("Agent %s version updated to %s during token renewal", req.AgentID, req.AgentVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// Update refresh token expiration (sliding window - reset to 90 days from now)
|
||||
// This ensures active agents never need to re-register
|
||||
newExpiry := time.Now().Add(90 * 24 * time.Hour)
|
||||
@@ -1123,7 +1328,7 @@ func (h *AgentHandler) TriggerReboot(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Save command to database
|
||||
if err := h.commandQueries.CreateCommand(cmd); err != nil {
|
||||
if err := h.signAndCreateCommand(cmd); err != nil {
|
||||
log.Printf("Failed to create reboot command: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create reboot command"})
|
||||
return
|
||||
@@ -1179,3 +1384,13 @@ func (h *AgentHandler) GetAgentConfig(c *gin.Context) {
|
||||
"version": time.Now().Unix(), // Simple version timestamp
|
||||
})
|
||||
}
|
||||
|
||||
// getStringFromMap safely extracts a string value from a map
|
||||
func getStringFromMap(m map[string]interface{}, key string) string {
|
||||
if val, exists := m[key]; exists {
|
||||
if str, ok := val.(string); ok {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user