v0.1.16: Security overhaul and systematic deployment preparation
Breaking changes for clean alpha releases: - JWT authentication with user-provided secrets (no more development defaults) - Registration token system for secure agent enrollment - Rate limiting with user-adjustable settings - Enhanced agent configuration with proxy support - Interactive server setup wizard (--setup flag) - Heartbeat architecture separation for better UX - Package status synchronization fixes - Accurate timestamp tracking for RMM features Setup process for new installations: 1. docker-compose up -d postgres 2. ./redflag-server --setup 3. ./redflag-server --migrate 4. ./redflag-server 5. Generate tokens via admin UI 6. Deploy agents with registration tokens
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -107,15 +108,16 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
|
||||
// Try to parse optional system metrics from request body
|
||||
var metrics struct {
|
||||
CPUPercent float64 `json:"cpu_percent,omitempty"`
|
||||
MemoryPercent float64 `json:"memory_percent,omitempty"`
|
||||
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
|
||||
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
|
||||
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
|
||||
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
|
||||
DiskPercent float64 `json:"disk_percent,omitempty"`
|
||||
Uptime string `json:"uptime,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
CPUPercent float64 `json:"cpu_percent,omitempty"`
|
||||
MemoryPercent float64 `json:"memory_percent,omitempty"`
|
||||
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
|
||||
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
|
||||
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
|
||||
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
|
||||
DiskPercent float64 `json:"disk_percent,omitempty"`
|
||||
Uptime string `json:"uptime,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// Parse metrics if provided (optional, won't fail if empty)
|
||||
@@ -130,21 +132,21 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
|
||||
// Always handle version information if provided
|
||||
if metrics.Version != "" {
|
||||
// Get current agent to preserve existing metadata
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil && agent.Metadata != nil {
|
||||
// Update agent's current version
|
||||
if err := h.agentQueries.UpdateAgentVersion(agentID, metrics.Version); err != nil {
|
||||
log.Printf("Warning: Failed to update agent version: %v", err)
|
||||
} else {
|
||||
// Check if update is available
|
||||
updateAvailable := utils.IsNewerVersion(h.latestAgentVersion, metrics.Version)
|
||||
// Update agent's current version in database (primary source of truth)
|
||||
if err := h.agentQueries.UpdateAgentVersion(agentID, metrics.Version); err != nil {
|
||||
log.Printf("Warning: Failed to update agent version: %v", err)
|
||||
} else {
|
||||
// Check if update is available
|
||||
updateAvailable := utils.IsNewerVersion(h.latestAgentVersion, metrics.Version)
|
||||
|
||||
// Update agent's update availability status
|
||||
if err := h.agentQueries.UpdateAgentUpdateAvailable(agentID, updateAvailable); err != nil {
|
||||
log.Printf("Warning: Failed to update agent update availability: %v", err)
|
||||
}
|
||||
// Update agent's update availability status
|
||||
if err := h.agentQueries.UpdateAgentUpdateAvailable(agentID, updateAvailable); err != nil {
|
||||
log.Printf("Warning: Failed to update agent update availability: %v", err)
|
||||
}
|
||||
|
||||
// Get current agent for logging and metadata update
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil {
|
||||
// Log version check
|
||||
if updateAvailable {
|
||||
log.Printf("🔄 Agent %s (%s) version %s has update available: %s",
|
||||
@@ -154,11 +156,20 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
agent.Hostname, agentID, metrics.Version)
|
||||
}
|
||||
|
||||
// Store version in metadata as well
|
||||
// Store version in metadata as well (for backwards compatibility)
|
||||
// Initialize metadata if nil
|
||||
if agent.Metadata == nil {
|
||||
agent.Metadata = make(models.JSONB)
|
||||
}
|
||||
agent.Metadata["reported_version"] = metrics.Version
|
||||
agent.Metadata["latest_version"] = h.latestAgentVersion
|
||||
agent.Metadata["update_available"] = updateAvailable
|
||||
agent.Metadata["version_checked_at"] = time.Now().Format(time.RFC3339)
|
||||
|
||||
// Update agent metadata
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
log.Printf("Warning: Failed to update agent metadata: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -179,6 +190,29 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
agent.Metadata["uptime"] = metrics.Uptime
|
||||
agent.Metadata["metrics_updated_at"] = time.Now().Format(time.RFC3339)
|
||||
|
||||
// Process heartbeat metadata from agent check-ins
|
||||
if metrics.Metadata != nil {
|
||||
if rapidPollingEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
|
||||
if rapidPollingUntil, exists := metrics.Metadata["rapid_polling_until"]; exists {
|
||||
// Parse the until timestamp
|
||||
if untilTime, err := time.Parse(time.RFC3339, rapidPollingUntil.(string)); err == nil {
|
||||
// Validate if rapid polling is still active (not expired)
|
||||
isActive := rapidPollingEnabled.(bool) && time.Now().Before(untilTime)
|
||||
|
||||
// Store heartbeat status in agent metadata
|
||||
agent.Metadata["rapid_polling_enabled"] = rapidPollingEnabled
|
||||
agent.Metadata["rapid_polling_until"] = rapidPollingUntil
|
||||
agent.Metadata["rapid_polling_active"] = isActive
|
||||
|
||||
log.Printf("[Heartbeat] Agent %s heartbeat status: enabled=%v, until=%v, active=%v",
|
||||
agentID, rapidPollingEnabled, rapidPollingUntil, isActive)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Failed to parse rapid_polling_until timestamp for agent %s: %v", agentID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update agent with new metadata
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
log.Printf("Warning: Failed to update agent metrics: %v", err)
|
||||
@@ -192,6 +226,37 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Process heartbeat metadata from agent check-ins
|
||||
if metrics.Metadata != nil {
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil && agent.Metadata != nil {
|
||||
if rapidPollingEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
|
||||
if rapidPollingUntil, exists := metrics.Metadata["rapid_polling_until"]; exists {
|
||||
// Parse the until timestamp
|
||||
if untilTime, err := time.Parse(time.RFC3339, rapidPollingUntil.(string)); err == nil {
|
||||
// Validate if rapid polling is still active (not expired)
|
||||
isActive := rapidPollingEnabled.(bool) && time.Now().Before(untilTime)
|
||||
|
||||
// Store heartbeat status in agent metadata
|
||||
agent.Metadata["rapid_polling_enabled"] = rapidPollingEnabled
|
||||
agent.Metadata["rapid_polling_until"] = rapidPollingUntil
|
||||
agent.Metadata["rapid_polling_active"] = isActive
|
||||
|
||||
log.Printf("[Heartbeat] Agent %s heartbeat status: enabled=%v, until=%v, active=%v",
|
||||
agentID, rapidPollingEnabled, rapidPollingUntil, isActive)
|
||||
|
||||
// Update agent with new metadata
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to update agent heartbeat metadata: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Failed to parse rapid_polling_until timestamp for agent %s: %v", agentID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for version updates for agents that don't send version in metrics
|
||||
// This ensures agents like Metis that don't report version still get update checks
|
||||
if metrics.Version == "" {
|
||||
@@ -239,8 +304,97 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
h.commandQueries.MarkCommandSent(cmd.ID)
|
||||
}
|
||||
|
||||
// Check if rapid polling should be enabled
|
||||
var rapidPolling *models.RapidPollingConfig
|
||||
|
||||
// Enable rapid polling if there are commands to process
|
||||
if len(commandItems) > 0 {
|
||||
rapidPolling = &models.RapidPollingConfig{
|
||||
Enabled: true,
|
||||
Until: time.Now().Add(10 * time.Minute).Format(time.RFC3339), // 10 minutes default
|
||||
}
|
||||
} else {
|
||||
// Check if agent has rapid polling already configured in metadata
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil && agent.Metadata != nil {
|
||||
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
|
||||
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
|
||||
if until, err := time.Parse(time.RFC3339, untilStr); err == nil && time.Now().Before(until) {
|
||||
rapidPolling = &models.RapidPollingConfig{
|
||||
Enabled: true,
|
||||
Until: untilStr,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Detect stale heartbeat state: Server thinks it's active, but agent didn't report it
|
||||
// This happens when agent restarts without heartbeat mode
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil && agent.Metadata != nil {
|
||||
// Check if server metadata shows heartbeat active
|
||||
if serverEnabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && serverEnabled {
|
||||
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
|
||||
if until, err := time.Parse(time.RFC3339, untilStr); err == nil && time.Now().Before(until) {
|
||||
// Server thinks heartbeat is active and not expired
|
||||
// Check if agent is reporting heartbeat in this check-in
|
||||
agentReportingHeartbeat := false
|
||||
if metrics.Metadata != nil {
|
||||
if agentEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
|
||||
agentReportingHeartbeat = agentEnabled.(bool)
|
||||
}
|
||||
}
|
||||
|
||||
// If agent is NOT reporting heartbeat but server expects it → stale state
|
||||
if !agentReportingHeartbeat {
|
||||
log.Printf("[Heartbeat] Stale heartbeat detected for agent %s - server expected active until %s, but agent not reporting heartbeat (likely restarted)",
|
||||
agentID, until.Format(time.RFC3339))
|
||||
|
||||
// Clear stale heartbeat state
|
||||
agent.Metadata["rapid_polling_enabled"] = false
|
||||
delete(agent.Metadata, "rapid_polling_until")
|
||||
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to clear stale heartbeat state: %v", err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Cleared stale heartbeat state for agent %s", agentID)
|
||||
|
||||
// Create audit command to show in history
|
||||
now := time.Now()
|
||||
auditCmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
CommandType: models.CommandTypeDisableHeartbeat,
|
||||
Params: models.JSONB{},
|
||||
Status: models.CommandStatusCompleted,
|
||||
Result: models.JSONB{
|
||||
"message": "Heartbeat cleared - agent restarted without active heartbeat mode",
|
||||
},
|
||||
CreatedAt: now,
|
||||
SentAt: &now,
|
||||
CompletedAt: &now,
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(auditCmd); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to create audit command for stale heartbeat: %v", err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Created audit trail for stale heartbeat cleanup (agent %s)", agentID)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear rapidPolling response since we just disabled it
|
||||
rapidPolling = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response := models.CommandsResponse{
|
||||
Commands: commandItems,
|
||||
Commands: commandItems,
|
||||
RapidPolling: rapidPolling,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
@@ -312,6 +466,124 @@ func (h *AgentHandler) TriggerScan(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "scan triggered", "command_id": cmd.ID})
|
||||
}
|
||||
|
||||
// TriggerHeartbeat creates a heartbeat toggle command for an agent
|
||||
func (h *AgentHandler) TriggerHeartbeat(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
agentID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
DurationMinutes int `json:"duration_minutes"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Determine command type based on enabled flag
|
||||
commandType := models.CommandTypeDisableHeartbeat
|
||||
if request.Enabled {
|
||||
commandType = models.CommandTypeEnableHeartbeat
|
||||
}
|
||||
|
||||
// Create heartbeat command with duration parameter
|
||||
cmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
CommandType: commandType,
|
||||
Params: models.JSONB{
|
||||
"duration_minutes": request.DurationMinutes,
|
||||
},
|
||||
Status: models.CommandStatusPending,
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(cmd); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create heartbeat command"})
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Clean up previous heartbeat commands for this agent (only for enable commands)
|
||||
// if request.Enabled {
|
||||
// // Mark previous heartbeat commands as 'replaced' to clean up Live Operations view
|
||||
// if err := h.commandQueries.MarkPreviousHeartbeatCommandsReplaced(agentID, cmd.ID); err != nil {
|
||||
// log.Printf("Warning: Failed to mark previous heartbeat commands as replaced: %v", err)
|
||||
// // Don't fail the request, just log the warning
|
||||
// } else {
|
||||
// log.Printf("[Heartbeat] Cleaned up previous heartbeat commands for agent %s", agentID)
|
||||
// }
|
||||
// }
|
||||
|
||||
action := "disabled"
|
||||
if request.Enabled {
|
||||
action = "enabled"
|
||||
}
|
||||
|
||||
log.Printf("💓 Heartbeat %s command created for agent %s (duration: %d minutes)",
|
||||
action, agentID, request.DurationMinutes)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("heartbeat %s command sent", action),
|
||||
"command_id": cmd.ID,
|
||||
"enabled": request.Enabled,
|
||||
})
|
||||
}
|
||||
|
||||
// GetHeartbeatStatus returns the current heartbeat status for an agent
|
||||
func (h *AgentHandler) GetHeartbeatStatus(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
agentID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get agent and their heartbeat metadata
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract heartbeat information from metadata
|
||||
response := gin.H{
|
||||
"enabled": false,
|
||||
"until": nil,
|
||||
"active": false,
|
||||
"duration_minutes": 0,
|
||||
}
|
||||
|
||||
if agent.Metadata != nil {
|
||||
// Check if heartbeat is enabled in metadata
|
||||
if enabled, exists := agent.Metadata["rapid_polling_enabled"]; exists {
|
||||
response["enabled"] = enabled.(bool)
|
||||
|
||||
// If enabled, get the until time and check if still active
|
||||
if enabled.(bool) {
|
||||
if untilStr, exists := agent.Metadata["rapid_polling_until"]; exists {
|
||||
response["until"] = untilStr.(string)
|
||||
|
||||
// Parse the until timestamp to check if still active
|
||||
if untilTime, err := time.Parse(time.RFC3339, untilStr.(string)); err == nil {
|
||||
response["active"] = time.Now().Before(untilTime)
|
||||
}
|
||||
}
|
||||
|
||||
// Get duration if available
|
||||
if duration, exists := agent.Metadata["rapid_polling_duration_minutes"]; exists {
|
||||
response["duration_minutes"] = duration.(float64)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// TriggerUpdate creates an update command for an agent
|
||||
func (h *AgentHandler) TriggerUpdate(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
@@ -541,3 +813,114 @@ func (h *AgentHandler) ReportSystemInfo(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "system info updated successfully"})
|
||||
}
|
||||
|
||||
// EnableRapidPollingMode enables rapid polling for an agent by updating metadata
|
||||
func (h *AgentHandler) EnableRapidPollingMode(agentID uuid.UUID, durationMinutes int) error {
|
||||
// Get current agent
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent: %w", err)
|
||||
}
|
||||
|
||||
// Calculate new rapid polling end time
|
||||
newRapidPollingUntil := time.Now().Add(time.Duration(durationMinutes) * time.Minute)
|
||||
|
||||
// Update agent metadata with rapid polling settings
|
||||
if agent.Metadata == nil {
|
||||
agent.Metadata = models.JSONB{}
|
||||
}
|
||||
|
||||
// Check if rapid polling is already active
|
||||
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
|
||||
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
|
||||
if currentUntil, err := time.Parse(time.RFC3339, untilStr); err == nil {
|
||||
// If current heartbeat expires later than the new duration, keep the longer duration
|
||||
if currentUntil.After(newRapidPollingUntil) {
|
||||
log.Printf("💓 Heartbeat already active for agent %s (%s), keeping longer duration (expires: %s)",
|
||||
agent.Hostname, agentID, currentUntil.Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
// Otherwise extend the heartbeat
|
||||
log.Printf("💓 Extending heartbeat for agent %s (%s) from %s to %s",
|
||||
agent.Hostname, agentID,
|
||||
currentUntil.Format(time.RFC3339),
|
||||
newRapidPollingUntil.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Printf("💓 Enabling heartbeat mode for agent %s (%s) for %d minutes",
|
||||
agent.Hostname, agentID, durationMinutes)
|
||||
}
|
||||
|
||||
// Set/update rapid polling settings
|
||||
agent.Metadata["rapid_polling_enabled"] = true
|
||||
agent.Metadata["rapid_polling_until"] = newRapidPollingUntil.Format(time.RFC3339)
|
||||
|
||||
// Update agent in database
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
return fmt.Errorf("failed to update agent with rapid polling: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRapidPollingMode enables rapid polling mode for an agent
|
||||
// TODO: Rate limiting should be implemented for rapid polling endpoints to prevent abuse (technical debt)
|
||||
func (h *AgentHandler) SetRapidPollingMode(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
agentID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if agent exists
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
DurationMinutes int `json:"duration_minutes" binding:"required,min=1,max=60"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate rapid polling end time
|
||||
rapidPollingUntil := time.Now().Add(time.Duration(req.DurationMinutes) * time.Minute)
|
||||
|
||||
// Update agent metadata with rapid polling settings
|
||||
if agent.Metadata == nil {
|
||||
agent.Metadata = models.JSONB{}
|
||||
}
|
||||
agent.Metadata["rapid_polling_enabled"] = req.Enabled
|
||||
agent.Metadata["rapid_polling_until"] = rapidPollingUntil.Format(time.RFC3339)
|
||||
|
||||
// Update agent in database
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update agent"})
|
||||
return
|
||||
}
|
||||
|
||||
status := "disabled"
|
||||
duration := 0
|
||||
if req.Enabled {
|
||||
status = "enabled"
|
||||
duration = req.DurationMinutes
|
||||
}
|
||||
|
||||
log.Printf("🚀 Rapid polling mode %s for agent %s (%s) for %d minutes",
|
||||
status, agent.Hostname, agentID, duration)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("Rapid polling mode %s", status),
|
||||
"enabled": req.Enabled,
|
||||
"duration_minutes": req.DurationMinutes,
|
||||
"rapid_polling_until": rapidPollingUntil,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -378,7 +378,7 @@ func (h *DockerHandler) RejectUpdate(c *gin.Context) {
|
||||
}
|
||||
|
||||
// For now, we'll mark as rejected (this would need a proper reject method in queries)
|
||||
if err := h.updateQueries.UpdatePackageStatus(update.AgentID, "docker", update.PackageName, "rejected", nil); err != nil {
|
||||
if err := h.updateQueries.UpdatePackageStatus(update.AgentID, "docker", update.PackageName, "rejected", nil, nil); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to reject Docker update"})
|
||||
return
|
||||
}
|
||||
|
||||
146
aggregator-server/internal/api/handlers/rate_limits.go
Normal file
146
aggregator-server/internal/api/handlers/rate_limits.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/aggregator-project/aggregator-server/internal/api/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type RateLimitHandler struct {
|
||||
rateLimiter *middleware.RateLimiter
|
||||
}
|
||||
|
||||
func NewRateLimitHandler(rateLimiter *middleware.RateLimiter) *RateLimitHandler {
|
||||
return &RateLimitHandler{
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
}
|
||||
|
||||
// GetRateLimitSettings returns current rate limit configuration
|
||||
func (h *RateLimitHandler) GetRateLimitSettings(c *gin.Context) {
|
||||
settings := h.rateLimiter.GetSettings()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"settings": settings,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateLimitSettings updates rate limit configuration
|
||||
func (h *RateLimitHandler) UpdateRateLimitSettings(c *gin.Context) {
|
||||
var settings middleware.RateLimitSettings
|
||||
if err := c.ShouldBindJSON(&settings); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate settings
|
||||
if err := h.validateRateLimitSettings(settings); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Update rate limiter settings
|
||||
h.rateLimiter.UpdateSettings(settings)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Rate limit settings updated successfully",
|
||||
"settings": settings,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// ResetRateLimitSettings resets to default values
|
||||
func (h *RateLimitHandler) ResetRateLimitSettings(c *gin.Context) {
|
||||
defaultSettings := middleware.DefaultRateLimitSettings()
|
||||
h.rateLimiter.UpdateSettings(defaultSettings)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Rate limit settings reset to defaults",
|
||||
"settings": defaultSettings,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetRateLimitStats returns current rate limit statistics
|
||||
func (h *RateLimitHandler) GetRateLimitStats(c *gin.Context) {
|
||||
settings := h.rateLimiter.GetSettings()
|
||||
|
||||
// Calculate total requests and windows
|
||||
stats := gin.H{
|
||||
"total_configured_limits": 6,
|
||||
"enabled_limits": 0,
|
||||
"total_requests_per_minute": 0,
|
||||
"settings": settings,
|
||||
}
|
||||
|
||||
// Count enabled limits and total requests
|
||||
for _, config := range []middleware.RateLimitConfig{
|
||||
settings.AgentRegistration,
|
||||
settings.AgentCheckIn,
|
||||
settings.AgentReports,
|
||||
settings.AdminTokenGen,
|
||||
settings.AdminOperations,
|
||||
settings.PublicAccess,
|
||||
} {
|
||||
if config.Enabled {
|
||||
stats["enabled_limits"] = stats["enabled_limits"].(int) + 1
|
||||
}
|
||||
stats["total_requests_per_minute"] = stats["total_requests_per_minute"].(int) + config.Requests
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// CleanupRateLimitEntries manually triggers cleanup of expired entries
|
||||
func (h *RateLimitHandler) CleanupRateLimitEntries(c *gin.Context) {
|
||||
h.rateLimiter.CleanupExpiredEntries()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Rate limit entries cleanup completed",
|
||||
"timestamp": time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// validateRateLimitSettings validates the provided rate limit settings
|
||||
func (h *RateLimitHandler) validateRateLimitSettings(settings middleware.RateLimitSettings) error {
|
||||
// Validate each configuration
|
||||
validations := []struct {
|
||||
name string
|
||||
config middleware.RateLimitConfig
|
||||
}{
|
||||
{"agent_registration", settings.AgentRegistration},
|
||||
{"agent_checkin", settings.AgentCheckIn},
|
||||
{"agent_reports", settings.AgentReports},
|
||||
{"admin_token_generation", settings.AdminTokenGen},
|
||||
{"admin_operations", settings.AdminOperations},
|
||||
{"public_access", settings.PublicAccess},
|
||||
}
|
||||
|
||||
for _, validation := range validations {
|
||||
if validation.config.Requests <= 0 {
|
||||
return fmt.Errorf("%s: requests must be greater than 0", validation.name)
|
||||
}
|
||||
if validation.config.Window <= 0 {
|
||||
return fmt.Errorf("%s: window must be greater than 0", validation.name)
|
||||
}
|
||||
if validation.config.Window > 24*time.Hour {
|
||||
return fmt.Errorf("%s: window cannot exceed 24 hours", validation.name)
|
||||
}
|
||||
if validation.config.Requests > 1000 {
|
||||
return fmt.Errorf("%s: requests cannot exceed 1000 per window", validation.name)
|
||||
}
|
||||
}
|
||||
|
||||
// Specific validations for different endpoint types
|
||||
if settings.AgentRegistration.Requests > 10 {
|
||||
return fmt.Errorf("agent_registration: requests should not exceed 10 per minute for security")
|
||||
}
|
||||
if settings.PublicAccess.Requests > 50 {
|
||||
return fmt.Errorf("public_access: requests should not exceed 50 per minute for security")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
284
aggregator-server/internal/api/handlers/registration_tokens.go
Normal file
284
aggregator-server/internal/api/handlers/registration_tokens.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/aggregator-project/aggregator-server/internal/config"
|
||||
"github.com/aggregator-project/aggregator-server/internal/database/queries"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type RegistrationTokenHandler struct {
|
||||
tokenQueries *queries.RegistrationTokenQueries
|
||||
agentQueries *queries.AgentQueries
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewRegistrationTokenHandler(tokenQueries *queries.RegistrationTokenQueries, agentQueries *queries.AgentQueries, config *config.Config) *RegistrationTokenHandler {
|
||||
return &RegistrationTokenHandler{
|
||||
tokenQueries: tokenQueries,
|
||||
agentQueries: agentQueries,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRegistrationToken creates a new registration token
|
||||
func (h *RegistrationTokenHandler) GenerateRegistrationToken(c *gin.Context) {
|
||||
var request struct {
|
||||
Label string `json:"label" binding:"required"`
|
||||
ExpiresIn string `json:"expires_in"` // e.g., "24h", "7d", "168h"
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Check agent seat limit (security, not licensing)
|
||||
activeAgents, err := h.agentQueries.GetActiveAgentCount()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check agent count"})
|
||||
return
|
||||
}
|
||||
|
||||
if activeAgents >= h.config.AgentRegistration.MaxSeats {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "Maximum agent seats reached",
|
||||
"limit": h.config.AgentRegistration.MaxSeats,
|
||||
"current": activeAgents,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse expiration duration
|
||||
expiresIn := request.ExpiresIn
|
||||
if expiresIn == "" {
|
||||
expiresIn = h.config.AgentRegistration.TokenExpiry
|
||||
}
|
||||
|
||||
duration, err := time.ParseDuration(expiresIn)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid expiration format. Use formats like '24h', '7d', '168h'"})
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(duration)
|
||||
if duration > 168*time.Hour { // Max 7 days
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Token expiration cannot exceed 7 days"})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate secure token
|
||||
token, err := config.GenerateSecureToken()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create metadata with default values
|
||||
metadata := request.Metadata
|
||||
if metadata == nil {
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
metadata["server_url"] = c.Request.Host
|
||||
metadata["expires_in"] = expiresIn
|
||||
|
||||
// Store token in database
|
||||
err = h.tokenQueries.CreateRegistrationToken(token, request.Label, expiresAt, metadata)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Build install command
|
||||
serverURL := c.Request.Host
|
||||
if serverURL == "" {
|
||||
serverURL = "localhost:8080" // Fallback for development
|
||||
}
|
||||
installCommand := "curl -sfL https://" + serverURL + "/install | bash -s -- " + token
|
||||
|
||||
response := gin.H{
|
||||
"token": token,
|
||||
"label": request.Label,
|
||||
"expires_at": expiresAt,
|
||||
"install_command": installCommand,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// ListRegistrationTokens returns all registration tokens with pagination
|
||||
func (h *RegistrationTokenHandler) ListRegistrationTokens(c *gin.Context) {
|
||||
// Parse pagination parameters
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50"))
|
||||
status := c.Query("status")
|
||||
|
||||
// Validate pagination
|
||||
if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
offset := (page - 1) * limit
|
||||
|
||||
var tokens []queries.RegistrationToken
|
||||
var err error
|
||||
|
||||
if status != "" {
|
||||
// TODO: Add filtered queries by status
|
||||
tokens, err = h.tokenQueries.GetAllRegistrationTokens(limit, offset)
|
||||
} else {
|
||||
tokens, err = h.tokenQueries.GetAllRegistrationTokens(limit, offset)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list tokens"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get token usage stats
|
||||
stats, err := h.tokenQueries.GetTokenUsageStats()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get token stats"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"tokens": tokens,
|
||||
"pagination": gin.H{
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
},
|
||||
"stats": stats,
|
||||
"seat_usage": gin.H{
|
||||
"current": func() int {
|
||||
count, _ := h.agentQueries.GetActiveAgentCount()
|
||||
return count
|
||||
}(),
|
||||
"max": h.config.AgentRegistration.MaxSeats,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetActiveRegistrationTokens returns only active tokens
|
||||
func (h *RegistrationTokenHandler) GetActiveRegistrationTokens(c *gin.Context) {
|
||||
tokens, err := h.tokenQueries.GetActiveRegistrationTokens()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get active tokens"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"tokens": tokens})
|
||||
}
|
||||
|
||||
// RevokeRegistrationToken revokes a registration token
|
||||
func (h *RegistrationTokenHandler) RevokeRegistrationToken(c *gin.Context) {
|
||||
token := c.Param("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
c.ShouldBindJSON(&request) // Reason is optional
|
||||
|
||||
reason := request.Reason
|
||||
if reason == "" {
|
||||
reason = "Revoked via API"
|
||||
}
|
||||
|
||||
err := h.tokenQueries.RevokeRegistrationToken(token, reason)
|
||||
if err != nil {
|
||||
if err.Error() == "token not found or already used/revoked" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Token not found or already used/revoked"})
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to revoke token"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Token revoked successfully"})
|
||||
}
|
||||
|
||||
// ValidateRegistrationToken checks if a token is valid (for testing/debugging)
|
||||
func (h *RegistrationTokenHandler) ValidateRegistrationToken(c *gin.Context) {
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Token query parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.tokenQueries.ValidateRegistrationToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"valid": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"valid": true,
|
||||
"token": tokenInfo,
|
||||
})
|
||||
}
|
||||
|
||||
// CleanupExpiredTokens performs cleanup of expired tokens
|
||||
func (h *RegistrationTokenHandler) CleanupExpiredTokens(c *gin.Context) {
|
||||
count, err := h.tokenQueries.CleanupExpiredTokens()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to cleanup expired tokens"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Cleanup completed",
|
||||
"cleaned": count,
|
||||
})
|
||||
}
|
||||
|
||||
// GetTokenStats returns comprehensive token usage statistics
|
||||
func (h *RegistrationTokenHandler) GetTokenStats(c *gin.Context) {
|
||||
// Get token stats
|
||||
tokenStats, err := h.tokenQueries.GetTokenUsageStats()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get token stats"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get agent count
|
||||
activeAgentCount, err := h.agentQueries.GetActiveAgentCount()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get agent count"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"token_stats": tokenStats,
|
||||
"agent_usage": gin.H{
|
||||
"active_agents": activeAgentCount,
|
||||
"max_seats": h.config.AgentRegistration.MaxSeats,
|
||||
"available": h.config.AgentRegistration.MaxSeats - activeAgentCount,
|
||||
},
|
||||
"security_limits": gin.H{
|
||||
"max_tokens_per_request": h.config.AgentRegistration.MaxTokens,
|
||||
"max_token_duration": "7 days",
|
||||
"token_expiry_default": h.config.AgentRegistration.TokenExpiry,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -16,16 +17,42 @@ type UpdateHandler struct {
|
||||
updateQueries *queries.UpdateQueries
|
||||
agentQueries *queries.AgentQueries
|
||||
commandQueries *queries.CommandQueries
|
||||
agentHandler *AgentHandler
|
||||
}
|
||||
|
||||
func NewUpdateHandler(uq *queries.UpdateQueries, aq *queries.AgentQueries, cq *queries.CommandQueries) *UpdateHandler {
|
||||
func NewUpdateHandler(uq *queries.UpdateQueries, aq *queries.AgentQueries, cq *queries.CommandQueries, ah *AgentHandler) *UpdateHandler {
|
||||
return &UpdateHandler{
|
||||
updateQueries: uq,
|
||||
agentQueries: aq,
|
||||
commandQueries: cq,
|
||||
agentHandler: ah,
|
||||
}
|
||||
}
|
||||
|
||||
// shouldEnableHeartbeat checks if heartbeat is already active for an agent
|
||||
// Returns true if heartbeat should be enabled (i.e., not already active or expired)
|
||||
func (h *UpdateHandler) shouldEnableHeartbeat(agentID uuid.UUID, durationMinutes int) (bool, error) {
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to get agent %s for heartbeat check: %v", agentID, err)
|
||||
return true, nil // Enable heartbeat by default if we can't check
|
||||
}
|
||||
|
||||
// Check if rapid polling is already enabled and not expired
|
||||
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
|
||||
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
|
||||
until, err := time.Parse(time.RFC3339, untilStr)
|
||||
if err == nil && until.After(time.Now().Add(5*time.Minute)) {
|
||||
// Heartbeat is already active for sufficient time
|
||||
log.Printf("[Heartbeat] Agent %s already has active heartbeat until %s (skipping)", agentID, untilStr)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ReportUpdates handles update reports from agents using event sourcing
|
||||
func (h *UpdateHandler) ReportUpdates(c *gin.Context) {
|
||||
agentID := c.MustGet("agent_id").(uuid.UUID)
|
||||
@@ -172,7 +199,7 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
log := &models.UpdateLog{
|
||||
logEntry := &models.UpdateLog{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
Action: req.Action,
|
||||
@@ -185,7 +212,7 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Store the log entry
|
||||
if err := h.updateQueries.CreateUpdateLog(log); err != nil {
|
||||
if err := h.updateQueries.CreateUpdateLog(logEntry); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save log"})
|
||||
return
|
||||
}
|
||||
@@ -207,10 +234,34 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Update command status based on log result
|
||||
if req.Result == "success" {
|
||||
if req.Result == "success" || req.Result == "completed" {
|
||||
if err := h.commandQueries.MarkCommandCompleted(commandID, result); err != nil {
|
||||
fmt.Printf("Warning: Failed to mark command %s as completed: %v\n", commandID, err)
|
||||
}
|
||||
|
||||
// NEW: If this was a successful confirm_dependencies command, mark the package as updated
|
||||
command, err := h.commandQueries.GetCommandByID(commandID)
|
||||
if err == nil && command.CommandType == models.CommandTypeConfirmDependencies {
|
||||
// Extract package info from command params
|
||||
if packageName, ok := command.Params["package_name"].(string); ok {
|
||||
if packageType, ok := command.Params["package_type"].(string); ok {
|
||||
// Extract actual completion timestamp from command result for accurate audit trail
|
||||
var completionTime *time.Time
|
||||
if loggedAtStr, ok := command.Result["logged_at"].(string); ok {
|
||||
if parsed, err := time.Parse(time.RFC3339Nano, loggedAtStr); err == nil {
|
||||
completionTime = &parsed
|
||||
}
|
||||
}
|
||||
|
||||
// Update package status to 'updated' with actual completion timestamp
|
||||
if err := h.updateQueries.UpdatePackageStatus(agentID, packageType, packageName, "updated", nil, completionTime); err != nil {
|
||||
log.Printf("Warning: Failed to update package status for %s/%s: %v", packageType, packageName, err)
|
||||
} else {
|
||||
log.Printf("✅ Package %s (%s) marked as updated after successful installation", packageName, packageType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if req.Result == "failed" || req.Result == "dry_run_failed" {
|
||||
if err := h.commandQueries.MarkCommandFailed(commandID, result); err != nil {
|
||||
fmt.Printf("Warning: Failed to mark command %s as failed: %v\n", commandID, err)
|
||||
@@ -304,7 +355,7 @@ func (h *UpdateHandler) UpdatePackageStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.updateQueries.UpdatePackageStatus(agentID, req.PackageType, req.PackageName, req.Status, req.Metadata); err != nil {
|
||||
if err := h.updateQueries.UpdatePackageStatus(agentID, req.PackageType, req.PackageName, req.Status, req.Metadata, nil); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update package status"})
|
||||
return
|
||||
}
|
||||
@@ -395,7 +446,29 @@ func (h *UpdateHandler) InstallUpdate(c *gin.Context) {
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Store the command in database
|
||||
// Check if heartbeat should be enabled (avoid duplicates)
|
||||
if shouldEnable, err := h.shouldEnableHeartbeat(update.AgentID, 10); err == nil && shouldEnable {
|
||||
heartbeatCmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: update.AgentID,
|
||||
CommandType: models.CommandTypeEnableHeartbeat,
|
||||
Params: models.JSONB{
|
||||
"duration_minutes": 10,
|
||||
},
|
||||
Status: models.CommandStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", update.AgentID, err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Command created for agent %s before dry run", update.AgentID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", update.AgentID)
|
||||
}
|
||||
|
||||
// Store the dry run command in database
|
||||
if err := h.commandQueries.CreateCommand(command); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create dry run command"})
|
||||
return
|
||||
@@ -478,6 +551,28 @@ func (h *UpdateHandler) ReportDependencies(c *gin.Context) {
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Check if heartbeat should be enabled (avoid duplicates)
|
||||
if shouldEnable, err := h.shouldEnableHeartbeat(agentID, 10); err == nil && shouldEnable {
|
||||
heartbeatCmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
CommandType: models.CommandTypeEnableHeartbeat,
|
||||
Params: models.JSONB{
|
||||
"duration_minutes": 10,
|
||||
},
|
||||
Status: models.CommandStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", agentID, err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Command created for agent %s before installation", agentID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", agentID)
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(command); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create installation command"})
|
||||
return
|
||||
@@ -536,6 +631,28 @@ func (h *UpdateHandler) ConfirmDependencies(c *gin.Context) {
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Check if heartbeat should be enabled (avoid duplicates)
|
||||
if shouldEnable, err := h.shouldEnableHeartbeat(update.AgentID, 10); err == nil && shouldEnable {
|
||||
heartbeatCmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: update.AgentID,
|
||||
CommandType: models.CommandTypeEnableHeartbeat,
|
||||
Params: models.JSONB{
|
||||
"duration_minutes": 10,
|
||||
},
|
||||
Status: models.CommandStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", update.AgentID, err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Command created for agent %s before confirm dependencies", update.AgentID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", update.AgentID)
|
||||
}
|
||||
|
||||
// Store the command in database
|
||||
if err := h.commandQueries.CreateCommand(command); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create confirmation command"})
|
||||
@@ -684,3 +801,60 @@ func (h *UpdateHandler) GetRecentCommands(c *gin.Context) {
|
||||
"limit": limit,
|
||||
})
|
||||
}
|
||||
|
||||
// ClearFailedCommands manually removes failed/timed_out commands with cheeky warning
|
||||
func (h *UpdateHandler) ClearFailedCommands(c *gin.Context) {
|
||||
// Get query parameters for filtering
|
||||
olderThanDaysStr := c.Query("older_than_days")
|
||||
onlyRetriedStr := c.Query("only_retried")
|
||||
allFailedStr := c.Query("all_failed")
|
||||
|
||||
var count int64
|
||||
var err error
|
||||
|
||||
// Parse parameters
|
||||
olderThanDays := 7 // default
|
||||
if olderThanDaysStr != "" {
|
||||
if days, err := strconv.Atoi(olderThanDaysStr); err == nil && days > 0 {
|
||||
olderThanDays = days
|
||||
}
|
||||
}
|
||||
|
||||
onlyRetried := onlyRetriedStr == "true"
|
||||
allFailed := allFailedStr == "true"
|
||||
|
||||
// Build the appropriate cleanup query based on parameters
|
||||
if allFailed {
|
||||
// Clear ALL failed commands (most aggressive)
|
||||
count, err = h.commandQueries.ClearAllFailedCommands(olderThanDays)
|
||||
} else if onlyRetried {
|
||||
// Clear only failed commands that have been retried
|
||||
count, err = h.commandQueries.ClearRetriedFailedCommands(olderThanDays)
|
||||
} else {
|
||||
// Clear failed commands older than specified days (default behavior)
|
||||
count, err = h.commandQueries.ClearOldFailedCommands(olderThanDays)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "failed to clear failed commands",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Return success with cheeky message
|
||||
message := fmt.Sprintf("Archived %d failed commands", count)
|
||||
if count > 0 {
|
||||
message += ". WARNING: This shouldn't be necessary if the retry logic is working properly - you might want to check what's causing commands to fail in the first place!"
|
||||
message += " (History preserved - commands moved to archived status)"
|
||||
} else {
|
||||
message += ". No failed commands found matching your criteria. SUCCESS!"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": message,
|
||||
"count": count,
|
||||
"cheeky_warning": "Consider this a developer experience enhancement - the system should clean up after itself automatically!",
|
||||
})
|
||||
}
|
||||
|
||||
279
aggregator-server/internal/api/middleware/rate_limiter.go
Normal file
279
aggregator-server/internal/api/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimitConfig holds configuration for rate limiting
|
||||
type RateLimitConfig struct {
|
||||
Requests int `json:"requests"`
|
||||
Window time.Duration `json:"window"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// RateLimitEntry tracks requests for a specific key
|
||||
type RateLimitEntry struct {
|
||||
Requests []time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// RateLimiter implements in-memory rate limiting with user-configurable settings
|
||||
type RateLimiter struct {
|
||||
entries sync.Map // map[string]*RateLimitEntry
|
||||
configs map[string]RateLimitConfig
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// RateLimitSettings holds all user-configurable rate limit settings
|
||||
type RateLimitSettings struct {
|
||||
AgentRegistration RateLimitConfig `json:"agent_registration"`
|
||||
AgentCheckIn RateLimitConfig `json:"agent_checkin"`
|
||||
AgentReports RateLimitConfig `json:"agent_reports"`
|
||||
AdminTokenGen RateLimitConfig `json:"admin_token_generation"`
|
||||
AdminOperations RateLimitConfig `json:"admin_operations"`
|
||||
PublicAccess RateLimitConfig `json:"public_access"`
|
||||
}
|
||||
|
||||
// DefaultRateLimitSettings provides sensible defaults
|
||||
func DefaultRateLimitSettings() RateLimitSettings {
|
||||
return RateLimitSettings{
|
||||
AgentRegistration: RateLimitConfig{
|
||||
Requests: 5,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AgentCheckIn: RateLimitConfig{
|
||||
Requests: 60,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AgentReports: RateLimitConfig{
|
||||
Requests: 30,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AdminTokenGen: RateLimitConfig{
|
||||
Requests: 10,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AdminOperations: RateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
PublicAccess: RateLimitConfig{
|
||||
Requests: 20,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter with default settings
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
entries: sync.Map{},
|
||||
}
|
||||
|
||||
// Load default settings
|
||||
defaults := DefaultRateLimitSettings()
|
||||
rl.UpdateSettings(defaults)
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// UpdateSettings updates rate limit configurations
|
||||
func (rl *RateLimiter) UpdateSettings(settings RateLimitSettings) {
|
||||
rl.mutex.Lock()
|
||||
defer rl.mutex.Unlock()
|
||||
|
||||
rl.configs = map[string]RateLimitConfig{
|
||||
"agent_registration": settings.AgentRegistration,
|
||||
"agent_checkin": settings.AgentCheckIn,
|
||||
"agent_reports": settings.AgentReports,
|
||||
"admin_token_gen": settings.AdminTokenGen,
|
||||
"admin_operations": settings.AdminOperations,
|
||||
"public_access": settings.PublicAccess,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSettings returns current rate limit settings
|
||||
func (rl *RateLimiter) GetSettings() RateLimitSettings {
|
||||
rl.mutex.RLock()
|
||||
defer rl.mutex.RUnlock()
|
||||
|
||||
return RateLimitSettings{
|
||||
AgentRegistration: rl.configs["agent_registration"],
|
||||
AgentCheckIn: rl.configs["agent_checkin"],
|
||||
AgentReports: rl.configs["agent_reports"],
|
||||
AdminTokenGen: rl.configs["admin_token_gen"],
|
||||
AdminOperations: rl.configs["admin_operations"],
|
||||
PublicAccess: rl.configs["public_access"],
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit creates middleware for a specific rate limit type
|
||||
func (rl *RateLimiter) RateLimit(limitType string, keyFunc func(*gin.Context) string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
rl.mutex.RLock()
|
||||
config, exists := rl.configs[limitType]
|
||||
rl.mutex.RUnlock()
|
||||
|
||||
if !exists || !config.Enabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
key := keyFunc(c)
|
||||
if key == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
allowed, resetTime := rl.checkRateLimit(key, config)
|
||||
if !allowed {
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
|
||||
c.Header("X-RateLimit-Remaining", "0")
|
||||
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
|
||||
c.Header("Retry-After", fmt.Sprintf("%d", int(resetTime.Sub(time.Now()).Seconds())))
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Rate limit exceeded",
|
||||
"limit": config.Requests,
|
||||
"window": config.Window.String(),
|
||||
"reset_time": resetTime,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Add rate limit headers
|
||||
remaining := rl.getRemainingRequests(key, config)
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
|
||||
c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
||||
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(config.Window).Unix()))
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// checkRateLimit checks if the request is allowed
|
||||
func (rl *RateLimiter) checkRateLimit(key string, config RateLimitConfig) (bool, time.Time) {
|
||||
now := time.Now()
|
||||
|
||||
// Get or create entry
|
||||
entryInterface, _ := rl.entries.LoadOrStore(key, &RateLimitEntry{
|
||||
Requests: []time.Time{},
|
||||
})
|
||||
entry := entryInterface.(*RateLimitEntry)
|
||||
|
||||
entry.mutex.Lock()
|
||||
defer entry.mutex.Unlock()
|
||||
|
||||
// Clean old requests outside the window
|
||||
cutoff := now.Add(-config.Window)
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, reqTime := range entry.Requests {
|
||||
if reqTime.After(cutoff) {
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if under limit
|
||||
if len(validRequests) >= config.Requests {
|
||||
// Find when the oldest request expires
|
||||
oldestRequest := validRequests[0]
|
||||
resetTime := oldestRequest.Add(config.Window)
|
||||
return false, resetTime
|
||||
}
|
||||
|
||||
// Add current request
|
||||
entry.Requests = append(validRequests, now)
|
||||
|
||||
// Clean up expired entries periodically
|
||||
if len(entry.Requests) == 0 {
|
||||
rl.entries.Delete(key)
|
||||
}
|
||||
|
||||
return true, time.Time{}
|
||||
}
|
||||
|
||||
// getRemainingRequests calculates remaining requests for the key
|
||||
func (rl *RateLimiter) getRemainingRequests(key string, config RateLimitConfig) int {
|
||||
entryInterface, ok := rl.entries.Load(key)
|
||||
if !ok {
|
||||
return config.Requests
|
||||
}
|
||||
|
||||
entry := entryInterface.(*RateLimitEntry)
|
||||
entry.mutex.RLock()
|
||||
defer entry.mutex.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-config.Window)
|
||||
count := 0
|
||||
|
||||
for _, reqTime := range entry.Requests {
|
||||
if reqTime.After(cutoff) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
remaining := config.Requests - count
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
return remaining
|
||||
}
|
||||
|
||||
// CleanupExpiredEntries removes expired entries to prevent memory leaks
|
||||
func (rl *RateLimiter) CleanupExpiredEntries() {
|
||||
rl.entries.Range(func(key, value interface{}) bool {
|
||||
entry := value.(*RateLimitEntry)
|
||||
entry.mutex.Lock()
|
||||
|
||||
now := time.Now()
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, reqTime := range entry.Requests {
|
||||
if reqTime.After(now.Add(-time.Hour)) { // Keep requests from last hour
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validRequests) == 0 {
|
||||
rl.entries.Delete(key)
|
||||
} else {
|
||||
entry.Requests = validRequests
|
||||
}
|
||||
|
||||
entry.mutex.Unlock()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Key generation functions
|
||||
func KeyByIP(c *gin.Context) string {
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
func KeyByAgentID(c *gin.Context) string {
|
||||
return c.Param("id")
|
||||
}
|
||||
|
||||
func KeyByUserID(c *gin.Context) string {
|
||||
// This would extract user ID from JWT or session
|
||||
// For now, use IP as fallback
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
func KeyByIPAndPath(c *gin.Context) string {
|
||||
return c.ClientIP() + ":" + c.Request.URL.Path
|
||||
}
|
||||
Reference in New Issue
Block a user