Files
Redflag/aggregator-server/internal/api/handlers/agent_updates.go
Fimeg ec3ba88459 feat: machine binding and version enforcement
migration 017 adds machine_id to agents table
middleware validates X-Machine-ID header on authed routes
agent client sends machine ID with requests
MIN_AGENT_VERSION config defaults 0.1.22
version utils added for comparison

blocks config copying attacks via hardware fingerprint
old agents get 426 upgrade required
breaking: <0.1.22 agents rejected
2025-11-02 09:30:04 -05:00

401 lines
12 KiB
Go

package handlers
import (
	"fmt"
	"log"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries"
	"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
	"github.com/Fimeg/RedFlag/aggregator-server/internal/services"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)
// AgentUpdateHandler serves the HTTP endpoints that start agent self-updates
// and that list and sign update packages.
type AgentUpdateHandler struct {
	agentQueries       *queries.AgentQueries       // agent lookup and "updating" flag
	agentUpdateQueries *queries.AgentUpdateQueries // update package storage
	commandQueries     *queries.CommandQueries     // queue for update_agent commands
	signingService     *services.SigningService    // may be nil; methods check before use
	agentHandler       *AgentHandler               // not referenced by the methods shown in this file
}

// NewAgentUpdateHandler builds an AgentUpdateHandler from its query layers,
// an optional signing service (nil disables nonce signing and package
// signing), and an AgentHandler.
func NewAgentUpdateHandler(aq *queries.AgentQueries, auq *queries.AgentUpdateQueries, cq *queries.CommandQueries, ss *services.SigningService, ah *AgentHandler) *AgentUpdateHandler {
	handler := new(AgentUpdateHandler)
	handler.agentQueries = aq
	handler.agentUpdateQueries = auq
	handler.commandQueries = cq
	handler.signingService = ss
	handler.agentHandler = ah
	return handler
}
// UpdateAgent handles POST /api/v1/agents/:id/update (manual agent update).
//
// It validates the request, marks the agent as updating, and queues an
// "update_agent" command. The command carries the package download URL,
// signature, checksum, file size, and a nonce (signed when a signing service
// is configured) for replay protection. If any step fails after the agent has
// been marked as updating, that flag is rolled back.
//
// Responses: 400 for bad JSON, a bad schedule, or an incompatible platform;
// 404 when the agent or package cannot be found; 409 when an update is
// already in progress; 500 on persistence or signing failures.
//
// NOTE(review): the target agent comes from the JSON body (req.AgentID). The
// :id path parameter is not consulted here; confirm this is intended.
func (h *AgentUpdateHandler) UpdateAgent(c *gin.Context) {
	var req models.AgentUpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Validate the optional schedule before touching any agent state, so a
	// malformed request never flips the agent into "updating".
	var scheduledAt *time.Time
	if req.Scheduled != nil {
		t, err := time.Parse(time.RFC3339, *req.Scheduled)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid scheduled time format"})
			return
		}
		scheduledAt = &t
	}

	// Verify the agent exists.
	agent, err := h.agentQueries.GetAgentByID(req.AgentID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
		return
	}

	// Refuse to stack a second update on top of one already in flight.
	if agent.IsUpdating {
		c.JSON(http.StatusConflict, gin.H{
			"error":          "agent is already updating",
			"current_update": agent.UpdatingToVersion,
			"initiated_at":   agent.UpdateInitiatedAt,
		})
		return
	}

	if !h.isPlatformCompatible(agent, req.Platform) {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": fmt.Sprintf("platform %s is not compatible with agent %s/%s",
				req.Platform, agent.OSType, agent.OSArchitecture),
		})
		return
	}

	pkg, err := h.agentUpdateQueries.GetUpdatePackageByVersion(req.Version, req.Platform, agent.OSArchitecture)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("update package not found: %v", err)})
		return
	}
	downloadURL := fmt.Sprintf("/api/v1/downloads/updates/%s", pkg.ID)

	// Mark the agent as updating. Every failure path below must undo this.
	if err := h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, true, &req.Version); err != nil {
		log.Printf("Failed to update agent %s status to updating: %v", req.AgentID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate update"})
		return
	}

	// rollback clears the updating flag set above. A failed rollback is
	// logged rather than swallowed, since the agent would otherwise stay
	// stuck in "updating" with no trace.
	rollback := func() {
		if rbErr := h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, false, nil); rbErr != nil {
			log.Printf("Failed to roll back updating status for agent %s: %v", req.AgentID, rbErr)
		}
	}

	// Generate a nonce for replay protection. It is signed only when a
	// signing service is configured; otherwise the signature stays empty.
	nonceUUID := uuid.New()
	nonceTimestamp := time.Now()
	var nonceSignature string
	if h.signingService != nil {
		nonceSignature, err = h.signingService.SignNonce(nonceUUID, nonceTimestamp)
		if err != nil {
			log.Printf("Failed to sign nonce: %v", err)
			rollback()
			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to sign nonce"})
			return
		}
	}

	commandParams := map[string]interface{}{
		"version":         req.Version,
		"platform":        req.Platform,
		"download_url":    downloadURL,
		"signature":       pkg.Signature,
		"checksum":        pkg.Checksum,
		"file_size":       pkg.FileSize,
		"nonce_uuid":      nonceUUID.String(),
		"nonce_timestamp": nonceTimestamp.Format(time.RFC3339),
		"nonce_signature": nonceSignature,
	}
	if scheduledAt != nil {
		commandParams["scheduled_at"] = *scheduledAt
	}

	command := &models.AgentCommand{
		ID:          uuid.New(),
		AgentID:     req.AgentID,
		CommandType: "update_agent",
		Params:      commandParams,
		Status:      models.CommandStatusPending,
		Source:      "web_ui",
		CreatedAt:   time.Now(),
	}
	if err := h.commandQueries.CreateCommand(command); err != nil {
		rollback()
		log.Printf("Failed to create update command for agent %s: %v", req.AgentID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create command"})
		return
	}

	log.Printf("✅ Agent update initiated for %s: %s (%s)", agent.Hostname, req.Version, req.Platform)
	c.JSON(http.StatusOK, models.AgentUpdateResponse{
		Message:       "Update initiated successfully",
		UpdateID:      command.ID.String(),
		DownloadURL:   downloadURL,
		Signature:     pkg.Signature,
		Checksum:      pkg.Checksum,
		FileSize:      pkg.FileSize,
		EstimatedTime: h.estimateUpdateTime(pkg.FileSize),
	})
}
// BulkUpdateAgents handles POST /api/v1/agents/bulk-update (bulk agent update).
//
// It accepts between 1 and 50 agent IDs and resolves a single update package
// for the requested version/platform. Each agent is then processed on its
// own: it must exist, must not already be updating, and must be
// platform-compatible. Each accepted agent is marked as updating and gets an
// "update_agent" command with its own nonce.
//
// Per-agent failures are collected in "failed" instead of aborting the batch.
// Once the request itself validates, the response is 200, with "updated" and
// "failed" always encoded as JSON arrays (never null).
func (h *AgentUpdateHandler) BulkUpdateAgents(c *gin.Context) {
	var req models.BulkAgentUpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	if len(req.AgentIDs) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "no agent IDs provided"})
		return
	}
	if len(req.AgentIDs) > 50 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "too many agents in bulk update (max 50)"})
		return
	}

	// Validate the optional schedule once, before any agent is touched. Store
	// it in the same form UpdateAgent uses (a parsed time.Time), so agents see
	// one format regardless of which endpoint queued the command.
	var scheduledAt *time.Time
	if req.Scheduled != nil {
		t, err := time.Parse(time.RFC3339, *req.Scheduled)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid scheduled time format"})
			return
		}
		scheduledAt = &t
	}

	// Get the update package first to validate it exists.
	// NOTE(review): architecture is passed as "" here, whereas UpdateAgent
	// passes the agent's OSArchitecture. Confirm the query treats "" as
	// "any architecture" and that one package suits every selected agent.
	pkg, err := h.agentUpdateQueries.GetUpdatePackageByVersion(req.Version, req.Platform, "")
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("update package not found: %v", err)})
		return
	}
	downloadURL := fmt.Sprintf("/api/v1/downloads/updates/%s", pkg.ID)

	// Non-nil so the JSON response renders [] rather than null when empty.
	results := make([]map[string]interface{}, 0, len(req.AgentIDs))
	failures := make([]string, 0)

	for _, agentID := range req.AgentIDs {
		agent, err := h.agentQueries.GetAgentByID(agentID)
		if err != nil {
			failures = append(failures, fmt.Sprintf("Agent %s: not found", agentID))
			continue
		}
		if agent.IsUpdating {
			failures = append(failures, fmt.Sprintf("Agent %s: already updating", agentID))
			continue
		}
		if !h.isPlatformCompatible(agent, req.Platform) {
			failures = append(failures, fmt.Sprintf("Agent %s: platform incompatible", agentID))
			continue
		}

		// Mark the agent as updating. Later failures for this agent undo it.
		if err := h.agentQueries.UpdateAgentUpdatingStatus(agentID, true, &req.Version); err != nil {
			failures = append(failures, fmt.Sprintf("Agent %s: failed to update status", agentID))
			continue
		}

		// rollback clears the updating flag set just above. A failed rollback
		// is logged so a stuck "updating" agent leaves a trace.
		rollback := func() {
			if rbErr := h.agentQueries.UpdateAgentUpdatingStatus(agentID, false, nil); rbErr != nil {
				log.Printf("Failed to roll back updating status for agent %s: %v", agentID, rbErr)
			}
		}

		// Each agent gets its own nonce for replay protection. It is signed
		// only when a signing service is configured.
		nonceUUID := uuid.New()
		nonceTimestamp := time.Now()
		var nonceSignature string
		if h.signingService != nil {
			nonceSignature, err = h.signingService.SignNonce(nonceUUID, nonceTimestamp)
			if err != nil {
				failures = append(failures, fmt.Sprintf("Agent %s: failed to sign nonce", agentID))
				rollback()
				continue
			}
		}

		params := map[string]interface{}{
			"version":         req.Version,
			"platform":        req.Platform,
			"download_url":    downloadURL,
			"signature":       pkg.Signature,
			"checksum":        pkg.Checksum,
			"file_size":       pkg.FileSize,
			"nonce_uuid":      nonceUUID.String(),
			"nonce_timestamp": nonceTimestamp.Format(time.RFC3339),
			"nonce_signature": nonceSignature,
		}
		if scheduledAt != nil {
			params["scheduled_at"] = *scheduledAt
		}

		command := &models.AgentCommand{
			ID:          uuid.New(),
			AgentID:     agentID,
			CommandType: "update_agent",
			Params:      params,
			Status:      models.CommandStatusPending,
			Source:      "web_ui_bulk",
			CreatedAt:   time.Now(),
		}
		if err := h.commandQueries.CreateCommand(command); err != nil {
			rollback()
			failures = append(failures, fmt.Sprintf("Agent %s: failed to create command", agentID))
			continue
		}

		results = append(results, map[string]interface{}{
			"agent_id":  agentID,
			"hostname":  agent.Hostname,
			"update_id": command.ID.String(),
			"status":    "initiated",
		})
		log.Printf("✅ Bulk update initiated for %s: %s (%s)", agent.Hostname, req.Version, req.Platform)
	}

	c.JSON(http.StatusOK, gin.H{
		"message":      fmt.Sprintf("Bulk update completed with %d successes and %d failures", len(results), len(failures)),
		"updated":      results,
		"failed":       failures,
		"total_agents": len(req.AgentIDs),
		"package_info": gin.H{
			"version":   pkg.Version,
			"platform":  pkg.Platform,
			"file_size": pkg.FileSize,
			"checksum":  pkg.Checksum,
		},
	})
}
// ListUpdatePackages handles GET /api/v1/updates/packages (list available update packages).
//
// Optional query parameters:
//   - version, platform: passed through as filters to the package query
//   - limit: used only if it parses as a positive integer, otherwise 0
//   - offset: used only if it parses as a non-negative integer, otherwise 0
//
// Invalid paging values are ignored rather than rejected.
//
// NOTE(review): "total" is the number of packages in this page, not the
// overall number of matching packages.
func (h *AgentUpdateHandler) ListUpdatePackages(c *gin.Context) {
	pageLimit, pageOffset := 0, 0
	// strconv.Atoi("") fails, so an absent parameter keeps its zero default.
	if n, convErr := strconv.Atoi(c.Query("limit")); convErr == nil && n > 0 {
		pageLimit = n
	}
	if n, convErr := strconv.Atoi(c.Query("offset")); convErr == nil && n >= 0 {
		pageOffset = n
	}

	found, listErr := h.agentUpdateQueries.ListUpdatePackages(c.Query("version"), c.Query("platform"), pageLimit, pageOffset)
	if listErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list update packages"})
		return
	}

	body := gin.H{}
	body["packages"] = found
	body["total"] = len(found)
	body["limit"] = pageLimit
	body["offset"] = pageOffset
	c.JSON(http.StatusOK, body)
}
// SignUpdatePackage handles POST /api/v1/updates/packages/sign (sign a new update package).
//
// The request names a binary already on the server's filesystem
// (binary_path), plus the version, platform and architecture it targets.
// The handler signs the binary with the signing service, sets those three
// fields on the returned package record, persists it, and returns it.
//
// Responses: 400 for bad JSON or a binary_path that is not an existing
// regular file; 503 when no signing service is configured; 500 when signing
// or saving fails.
func (h *AgentUpdateHandler) SignUpdatePackage(c *gin.Context) {
	var req struct {
		Version      string `json:"version" binding:"required"`
		Platform     string `json:"platform" binding:"required"`
		Architecture string `json:"architecture" binding:"required"`
		BinaryPath   string `json:"binary_path" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	if h.signingService == nil {
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "signing service not available"})
		return
	}

	// Only regular files can be signed. Directories would just fail, but a
	// device node or FIFO (e.g. /dev/zero) could keep the signer reading
	// indefinitely if SignFile reads the whole file (presumably, to hash it).
	// os.Stat follows symlinks, so a link to a regular file is still accepted.
	// NOTE(review): binary_path is not confined to any directory. A caller of
	// this endpoint can therefore have any server-readable file signed and
	// stored as an update package; consider restricting it to a configured
	// build directory.
	info, err := os.Stat(req.BinaryPath)
	if err != nil || !info.Mode().IsRegular() {
		c.JSON(http.StatusBadRequest, gin.H{"error": "binary_path must reference an existing regular file"})
		return
	}

	pkg, err := h.signingService.SignFile(req.BinaryPath)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to sign binary: %v", err)})
		return
	}

	// SignFile only knows about the file; the targeting metadata comes from
	// the request.
	pkg.Version = req.Version
	pkg.Platform = req.Platform
	pkg.Architecture = req.Architecture

	if err := h.agentUpdateQueries.CreateUpdatePackage(pkg); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save update package: %v", err)})
		return
	}

	log.Printf("✅ Update package signed and saved: %s %s/%s (ID: %s)",
		pkg.Version, pkg.Platform, pkg.Architecture, pkg.ID)
	c.JSON(http.StatusOK, gin.H{
		"message": "Update package signed successfully",
		"package": pkg,
	})
}
// isPlatformCompatible reports whether an update built for updatePlatform can
// run on agent. updatePlatform is assumed to be a string such as
// "linux-amd64" (TODO confirm against package metadata).
//
// All comparisons are case-insensitive. The rules are:
//   - OS: the agent's OSType must appear within updatePlatform. An agent with
//     no recorded OSType is rejected, because strings.Contains(s, "") is
//     always true and would otherwise match every platform.
//   - Architecture: for each known token (amd64, arm64, 386) named in
//     updatePlatform, the agent's OSArchitecture must contain the same token.
//     A platform that names none of them passes the architecture check.
func (h *AgentUpdateHandler) isPlatformCompatible(agent *models.Agent, updatePlatform string) bool {
	agentOS := strings.ToLower(agent.OSType)
	platform := strings.ToLower(updatePlatform)

	if agentOS == "" || !strings.Contains(platform, agentOS) {
		return false
	}

	agentArch := strings.ToLower(agent.OSArchitecture)
	for _, arch := range []string{"amd64", "arm64", "386"} {
		if strings.Contains(platform, arch) && !strings.Contains(agentArch, arch) {
			return false
		}
	}
	return true
}
// estimateUpdateTime returns a rough estimate, in seconds, of how long an
// agent needs to apply an update of fileSize bytes. The estimate is a fixed
// 30-second base plus one second per whole MiB, capped at five minutes.
// Negative sizes are treated as zero.
func (h *AgentUpdateHandler) estimateUpdateTime(fileSize int64) int {
	const (
		mib         = 1024 * 1024 // one second is budgeted per MiB
		baseSeconds = 30
		maxSeconds  = 300
	)
	if fileSize < 0 {
		fileSize = 0
	}
	// Stay in int64 until after the cap. Converting first could overflow a
	// 32-bit int for very large sizes and slip past the ceiling.
	seconds := fileSize/mib + baseSeconds
	if seconds > maxSeconds {
		seconds = maxSeconds
	}
	return int(seconds)
}