Problem: The version-check middleware blocked agents below the minimum version from checking in. This created a deadlock: an outdated agent could not upgrade because it was blocked from checking in to receive the update command.

Solution: MachineBindingMiddleware now allows an outdated agent that is checking in for commands to proceed IF it has a pending update_agent command. This lets the agent receive the update command even while it is below the minimum version.

Changes:
- Added grace-period logic in the middleware for command endpoints.
- Check whether the agent has a pending update command before blocking it.
- If an update is pending, allow the check-in and log it.
- Added HasPendingUpdateCommand() to AgentQueries for checking pending updates.
- Also added the same method to CommandQueries for completeness.

This prevents the version-tracking deadlock while maintaining security for agents without pending updates.

NOTE: Still need to test that old agents can actually receive and execute update commands when allowed through this path.
451 lines
12 KiB
Go
451 lines
12 KiB
Go
package queries
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
	"github.com/google/uuid"
	"github.com/jmoiron/sqlx"
)
|
|
|
|
type AgentQueries struct {
|
|
db *sqlx.DB
|
|
DB *sqlx.DB // Public field for access by config_builder
|
|
}
|
|
|
|
func NewAgentQueries(db *sqlx.DB) *AgentQueries {
|
|
return &AgentQueries{
|
|
db: db,
|
|
DB: db, // Expose for external use
|
|
}
|
|
}
|
|
|
|
// CreateAgent inserts a new agent into the database
|
|
func (q *AgentQueries) CreateAgent(agent *models.Agent) error {
|
|
query := `
|
|
INSERT INTO agents (
|
|
id, hostname, os_type, os_version, os_architecture,
|
|
agent_version, current_version, machine_id, public_key_fingerprint,
|
|
last_seen, status, metadata
|
|
) VALUES (
|
|
:id, :hostname, :os_type, :os_version, :os_architecture,
|
|
:agent_version, :current_version, :machine_id, :public_key_fingerprint,
|
|
:last_seen, :status, :metadata
|
|
)
|
|
`
|
|
_, err := q.db.NamedExec(query, agent)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create agent %s (version %s): %w", agent.Hostname, agent.CurrentVersion, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetAgentByID retrieves an agent by ID
|
|
func (q *AgentQueries) GetAgentByID(id uuid.UUID) (*models.Agent, error) {
|
|
var agent models.Agent
|
|
query := `SELECT * FROM agents WHERE id = $1`
|
|
err := q.db.Get(&agent, query, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &agent, nil
|
|
}
|
|
|
|
// UpdateAgentLastSeen updates the agent's last_seen timestamp
|
|
func (q *AgentQueries) UpdateAgentLastSeen(id uuid.UUID) error {
|
|
query := `UPDATE agents SET last_seen = $1, status = 'online' WHERE id = $2`
|
|
_, err := q.db.Exec(query, time.Now().UTC(), id)
|
|
return err
|
|
}
|
|
|
|
// UpdateAgent updates an agent's full record including metadata
|
|
func (q *AgentQueries) UpdateAgent(agent *models.Agent) error {
|
|
query := `
|
|
UPDATE agents SET
|
|
hostname = :hostname,
|
|
os_type = :os_type,
|
|
os_version = :os_version,
|
|
os_architecture = :os_architecture,
|
|
agent_version = :agent_version,
|
|
last_seen = :last_seen,
|
|
status = :status,
|
|
metadata = :metadata
|
|
WHERE id = :id
|
|
`
|
|
_, err := q.db.NamedExec(query, agent)
|
|
return err
|
|
}
|
|
|
|
// UpdateAgentMetadata updates only the metadata, last_seen, and status fields
|
|
// Used for metrics updates to avoid overwriting version tracking
|
|
func (q *AgentQueries) UpdateAgentMetadata(id uuid.UUID, metadata models.JSONB, status string, lastSeen time.Time) error {
|
|
query := `
|
|
UPDATE agents SET
|
|
last_seen = $1,
|
|
status = $2,
|
|
metadata = $3
|
|
WHERE id = $4
|
|
`
|
|
_, err := q.db.Exec(query, lastSeen, status, metadata, id)
|
|
return err
|
|
}
|
|
|
|
// ListAgents returns all agents with optional filtering
|
|
func (q *AgentQueries) ListAgents(status, osType string) ([]models.Agent, error) {
|
|
var agents []models.Agent
|
|
query := `SELECT * FROM agents WHERE 1=1`
|
|
args := []interface{}{}
|
|
argIdx := 1
|
|
|
|
if status != "" {
|
|
query += ` AND status = $` + string(rune(argIdx+'0'))
|
|
args = append(args, status)
|
|
argIdx++
|
|
}
|
|
if osType != "" {
|
|
query += ` AND os_type = $` + string(rune(argIdx+'0'))
|
|
args = append(args, osType)
|
|
argIdx++
|
|
}
|
|
|
|
query += ` ORDER BY last_seen DESC`
|
|
err := q.db.Select(&agents, query, args...)
|
|
return agents, err
|
|
}
|
|
|
|
// MarkOfflineAgents marks agents as offline if they haven't checked in recently
|
|
func (q *AgentQueries) MarkOfflineAgents(threshold time.Duration) error {
|
|
query := `
|
|
UPDATE agents
|
|
SET status = 'offline'
|
|
WHERE last_seen < $1 AND status = 'online'
|
|
`
|
|
_, err := q.db.Exec(query, time.Now().Add(-threshold))
|
|
return err
|
|
}
|
|
|
|
// GetAgentLastScan gets the last scan time from update events
|
|
func (q *AgentQueries) GetAgentLastScan(id uuid.UUID) (*time.Time, error) {
|
|
var lastScan time.Time
|
|
query := `SELECT MAX(created_at) FROM update_events WHERE agent_id = $1`
|
|
err := q.db.Get(&lastScan, query, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &lastScan, nil
|
|
}
|
|
|
|
// GetAgentWithLastScan gets agent information including last scan time
|
|
func (q *AgentQueries) GetAgentWithLastScan(id uuid.UUID) (*models.AgentWithLastScan, error) {
|
|
var agent models.AgentWithLastScan
|
|
query := `
|
|
SELECT
|
|
a.*,
|
|
(SELECT MAX(created_at) FROM update_events WHERE agent_id = a.id) as last_scan
|
|
FROM agents a
|
|
WHERE a.id = $1`
|
|
err := q.db.Get(&agent, query, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &agent, nil
|
|
}
|
|
|
|
// ListAgentsWithLastScan returns all agents with their last scan times
|
|
func (q *AgentQueries) ListAgentsWithLastScan(status, osType string) ([]models.AgentWithLastScan, error) {
|
|
var agents []models.AgentWithLastScan
|
|
query := `
|
|
SELECT
|
|
a.*,
|
|
(SELECT MAX(created_at) FROM update_events WHERE agent_id = a.id) as last_scan
|
|
FROM agents a
|
|
WHERE 1=1`
|
|
args := []interface{}{}
|
|
argIdx := 1
|
|
|
|
if status != "" {
|
|
query += ` AND a.status = $` + string(rune(argIdx+'0'))
|
|
args = append(args, status)
|
|
argIdx++
|
|
}
|
|
if osType != "" {
|
|
query += ` AND a.os_type = $` + string(rune(argIdx+'0'))
|
|
args = append(args, osType)
|
|
argIdx++
|
|
}
|
|
|
|
query += ` ORDER BY a.last_seen DESC`
|
|
err := q.db.Select(&agents, query, args...)
|
|
return agents, err
|
|
}
|
|
|
|
// UpdateAgentVersion updates the agent's version information and checks for updates
|
|
func (q *AgentQueries) UpdateAgentVersion(id uuid.UUID, currentVersion string) error {
|
|
query := `
|
|
UPDATE agents SET
|
|
current_version = $1,
|
|
last_version_check = $2
|
|
WHERE id = $3
|
|
`
|
|
_, err := q.db.Exec(query, currentVersion, time.Now().UTC(), id)
|
|
return err
|
|
}
|
|
|
|
// UpdateAgentUpdateAvailable sets whether an update is available for an agent
|
|
func (q *AgentQueries) UpdateAgentUpdateAvailable(id uuid.UUID, updateAvailable bool) error {
|
|
query := `
|
|
UPDATE agents SET
|
|
update_available = $1
|
|
WHERE id = $2
|
|
`
|
|
_, err := q.db.Exec(query, updateAvailable, id)
|
|
return err
|
|
}
|
|
|
|
// DeleteAgent removes an agent and all associated data
|
|
func (q *AgentQueries) DeleteAgent(id uuid.UUID) error {
|
|
// Start a transaction for atomic deletion
|
|
tx, err := q.db.Beginx()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Delete the agent (CASCADE will handle related records)
|
|
_, err = tx.Exec("DELETE FROM agents WHERE id = $1", id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Commit the transaction
|
|
return tx.Commit()
|
|
}
|
|
|
|
// GetActiveAgentCount returns the count of active (online) agents
|
|
func (q *AgentQueries) GetActiveAgentCount() (int, error) {
|
|
var count int
|
|
query := `SELECT COUNT(*) FROM agents WHERE status = 'online'`
|
|
err := q.db.Get(&count, query)
|
|
return count, err
|
|
}
|
|
|
|
// GetTotalAgentCount returns the total count of registered agents
|
|
func (q *AgentQueries) GetTotalAgentCount() (int, error) {
|
|
var count int
|
|
query := `SELECT COUNT(*) FROM agents`
|
|
err := q.db.Get(&count, query)
|
|
return count, err
|
|
}
|
|
|
|
// GetAgentCountByVersion returns the count of agents by version (for version compliance)
|
|
func (q *AgentQueries) GetAgentCountByVersion(minVersion string) (int, error) {
|
|
var count int
|
|
query := `SELECT COUNT(*) FROM agents WHERE current_version >= $1`
|
|
err := q.db.Get(&count, query, minVersion)
|
|
return count, err
|
|
}
|
|
|
|
// GetAgentsWithMachineBinding returns count of agents that have machine IDs set
|
|
func (q *AgentQueries) GetAgentsWithMachineBinding() (int, error) {
|
|
var count int
|
|
query := `SELECT COUNT(*) FROM agents WHERE machine_id IS NOT NULL AND machine_id != ''`
|
|
err := q.db.Get(&count, query)
|
|
return count, err
|
|
}
|
|
|
|
// UpdateAgentRebootStatus updates the reboot status for an agent
|
|
func (q *AgentQueries) UpdateAgentRebootStatus(id uuid.UUID, required bool, reason string) error {
|
|
query := `
|
|
UPDATE agents
|
|
SET reboot_required = $1,
|
|
reboot_reason = $2,
|
|
updated_at = $3
|
|
WHERE id = $4
|
|
`
|
|
_, err := q.db.Exec(query, required, reason, time.Now(), id)
|
|
return err
|
|
}
|
|
|
|
// UpdateAgentLastReboot updates the last reboot timestamp for an agent
|
|
func (q *AgentQueries) UpdateAgentLastReboot(id uuid.UUID, rebootTime time.Time) error {
|
|
query := `
|
|
UPDATE agents
|
|
SET last_reboot_at = $1,
|
|
reboot_required = FALSE,
|
|
reboot_reason = '',
|
|
updated_at = $2
|
|
WHERE id = $3
|
|
`
|
|
_, err := q.db.Exec(query, rebootTime, time.Now(), id)
|
|
return err
|
|
}
|
|
|
|
// GetAgentByMachineID retrieves an agent by its machine ID
|
|
func (q *AgentQueries) GetAgentByMachineID(machineID string) (*models.Agent, error) {
|
|
query := `
|
|
SELECT id, hostname, os_type, os_version, os_architecture, agent_version,
|
|
current_version, update_available, last_version_check, machine_id,
|
|
public_key_fingerprint, is_updating, updating_to_version,
|
|
update_initiated_at, last_seen, status, metadata, reboot_required,
|
|
last_reboot_at, reboot_reason, created_at, updated_at
|
|
FROM agents
|
|
WHERE machine_id = $1
|
|
`
|
|
|
|
var agent models.Agent
|
|
err := q.db.Get(&agent, query, machineID)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil // Return nil if not found (not an error)
|
|
}
|
|
return nil, fmt.Errorf("failed to get agent by machine ID: %w", err)
|
|
}
|
|
|
|
return &agent, nil
|
|
}
|
|
|
|
// UpdateAgentUpdatingStatus updates the agent's update status
|
|
func (q *AgentQueries) UpdateAgentUpdatingStatus(id uuid.UUID, isUpdating bool, updatingToVersion *string) error {
|
|
query := `
|
|
UPDATE agents
|
|
SET
|
|
is_updating = $1,
|
|
updating_to_version = $2,
|
|
update_initiated_at = CASE
|
|
WHEN $1 = true THEN $3
|
|
ELSE NULL
|
|
END,
|
|
updated_at = $3
|
|
WHERE id = $4
|
|
`
|
|
|
|
var versionPtr *string
|
|
if updatingToVersion != nil {
|
|
versionPtr = updatingToVersion
|
|
}
|
|
|
|
_, err := q.db.Exec(query, isUpdating, versionPtr, time.Now(), id)
|
|
return err
|
|
}
|
|
|
|
// CompleteAgentUpdate marks an agent update as successful and updates version
|
|
func (q *AgentQueries) CompleteAgentUpdate(agentID string, newVersion string) error {
|
|
query := `
|
|
UPDATE agents
|
|
SET
|
|
current_version = $2,
|
|
is_updating = false,
|
|
updated_at = CURRENT_TIMESTAMP
|
|
WHERE id = $1
|
|
`
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
result, err := q.db.ExecContext(ctx, query, agentID, newVersion)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to complete agent update: %w", err)
|
|
}
|
|
|
|
rows, err := result.RowsAffected()
|
|
if err != nil || rows == 0 {
|
|
return fmt.Errorf("agent not found or version not updated")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CreateSystemEvent creates a new system event entry in the system_events table
|
|
func (q *AgentQueries) CreateSystemEvent(event *models.SystemEvent) error {
|
|
query := `
|
|
INSERT INTO system_events (
|
|
id, agent_id, event_type, event_subtype, severity, component, message, metadata, created_at
|
|
) VALUES (
|
|
:id, :agent_id, :event_type, :event_subtype, :severity, :component, :message, :metadata, :created_at
|
|
)
|
|
`
|
|
_, err := q.db.NamedExec(query, event)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create system event: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetAgentEvents retrieves system events for an agent with optional severity filtering
|
|
func (q *AgentQueries) GetAgentEvents(agentID uuid.UUID, severity string, limit int) ([]models.SystemEvent, error) {
|
|
query := `
|
|
SELECT id, agent_id, event_type, event_subtype, severity, component,
|
|
message, metadata, created_at
|
|
FROM system_events
|
|
WHERE agent_id = $1
|
|
ORDER BY created_at DESC
|
|
LIMIT $2
|
|
`
|
|
args := []interface{}{agentID, limit}
|
|
|
|
if severity != "" {
|
|
query = `
|
|
SELECT id, agent_id, event_type, event_subtype, severity, component,
|
|
message, metadata, created_at
|
|
FROM system_events
|
|
WHERE agent_id = $1 AND severity = ANY(string_to_array($2, ','))
|
|
ORDER BY created_at DESC
|
|
LIMIT $3
|
|
`
|
|
args = []interface{}{agentID, severity, limit}
|
|
}
|
|
|
|
var events []models.SystemEvent
|
|
err := q.db.Select(&events, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch agent events: %w", err)
|
|
}
|
|
|
|
return events, nil
|
|
}
|
|
|
|
// SetAgentUpdating marks an agent as updating with nonce
|
|
func (q *AgentQueries) SetAgentUpdating(agentID string, isUpdating bool, targetVersion string) error {
|
|
query := `
|
|
UPDATE agents
|
|
SET is_updating = $2, updating_to_version = $3, updated_at = CURRENT_TIMESTAMP
|
|
WHERE id = $1
|
|
`
|
|
|
|
_, err := q.db.Exec(query, agentID, isUpdating, targetVersion)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to set agent updating state: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// HasPendingUpdateCommand checks if an agent has a pending update_agent command
|
|
// This is used to allow old agents to check in and receive updates even if they're below minimum version
|
|
func (q *AgentQueries) HasPendingUpdateCommand(agentID string) (bool, error) {
|
|
// Check if agent_id is a valid UUID
|
|
agentUUID, err := uuid.Parse(agentID)
|
|
if err != nil {
|
|
return false, fmt.Errorf("invalid agent ID: %w", err)
|
|
}
|
|
|
|
var count int
|
|
query := `
|
|
SELECT COUNT(*)
|
|
FROM agent_commands
|
|
WHERE agent_id = $1
|
|
AND command_type = 'update_agent'
|
|
AND status = 'pending'
|
|
`
|
|
|
|
err = q.db.Get(&count, query, agentUUID)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to check for pending update commands: %w", err)
|
|
}
|
|
|
|
return count > 0, nil
|
|
}
|