v0.1.16: Security overhaul and systematic deployment preparation

Breaking changes for clean alpha releases:
- JWT authentication with user-provided secrets (no more development defaults)
- Registration token system for secure agent enrollment
- Rate limiting with user-adjustable settings
- Enhanced agent configuration with proxy support
- Interactive server setup wizard (--setup flag)
- Heartbeat architecture separation for better UX
- Package status synchronization fixes
- Accurate timestamp tracking for RMM features

Setup process for new installations:
1. docker-compose up -d postgres
2. ./redflag-server --setup
3. ./redflag-server --migrate
4. ./redflag-server
5. Generate tokens via admin UI
6. Deploy agents with registration tokens
This commit is contained in:
Fimeg
2025-10-29 10:38:18 -04:00
parent b3e1b9e52f
commit 03fee29760
50 changed files with 5807 additions and 466 deletions

View File

@@ -1,6 +1,7 @@
package main
import (
"flag"
"fmt"
"log"
"path/filepath"
@@ -16,6 +17,29 @@ import (
)
func main() {
// Parse command line flags
var setup bool
var migrate bool
var version bool
flag.BoolVar(&setup, "setup", false, "Run setup wizard")
flag.BoolVar(&migrate, "migrate", false, "Run database migrations only")
flag.BoolVar(&version, "version", false, "Show version information")
flag.Parse()
// Handle special commands
if version {
fmt.Printf("RedFlag Server v0.1.0-alpha\n")
fmt.Printf("Self-hosted update management platform\n")
return
}
if setup {
if err := config.RunSetupWizard(); err != nil {
log.Fatal("Setup failed:", err)
}
return
}
// Load configuration
cfg, err := config.Load()
if err != nil {
@@ -23,15 +47,29 @@ func main() {
}
// Set JWT secret
middleware.JWTSecret = cfg.JWTSecret
middleware.JWTSecret = cfg.Admin.JWTSecret
// Build database URL from new config structure
databaseURL := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable",
cfg.Database.Username, cfg.Database.Password, cfg.Database.Host, cfg.Database.Port, cfg.Database.Database)
// Connect to database
db, err := database.Connect(cfg.DatabaseURL)
db, err := database.Connect(databaseURL)
if err != nil {
log.Fatal("Failed to connect to database:", err)
}
defer db.Close()
// Handle migrate-only flag
if migrate {
migrationsPath := filepath.Join("internal", "database", "migrations")
if err := db.Migrate(migrationsPath); err != nil {
log.Fatal("Migration failed:", err)
}
fmt.Printf("✅ Database migrations completed\n")
return
}
// Run migrations
migrationsPath := filepath.Join("internal", "database", "migrations")
if err := db.Migrate(migrationsPath); err != nil {
@@ -45,18 +83,24 @@ func main() {
updateQueries := queries.NewUpdateQueries(db.DB)
commandQueries := queries.NewCommandQueries(db.DB)
refreshTokenQueries := queries.NewRefreshTokenQueries(db.DB)
registrationTokenQueries := queries.NewRegistrationTokenQueries(db.DB)
// Initialize services
timezoneService := services.NewTimezoneService(cfg)
timeoutService := services.NewTimeoutService(commandQueries, updateQueries)
// Initialize rate limiter
rateLimiter := middleware.NewRateLimiter()
// Initialize handlers
agentHandler := handlers.NewAgentHandler(agentQueries, commandQueries, refreshTokenQueries, cfg.CheckInInterval, cfg.LatestAgentVersion)
updateHandler := handlers.NewUpdateHandler(updateQueries, agentQueries, commandQueries)
authHandler := handlers.NewAuthHandler(cfg.JWTSecret)
updateHandler := handlers.NewUpdateHandler(updateQueries, agentQueries, commandQueries, agentHandler)
authHandler := handlers.NewAuthHandler(cfg.Admin.JWTSecret)
statsHandler := handlers.NewStatsHandler(agentQueries, updateQueries)
settingsHandler := handlers.NewSettingsHandler(timezoneService)
dockerHandler := handlers.NewDockerHandler(updateQueries, agentQueries, commandQueries)
registrationTokenHandler := handlers.NewRegistrationTokenHandler(registrationTokenQueries, agentQueries, cfg)
rateLimitHandler := handlers.NewRateLimitHandler(rateLimiter)
// Setup router
router := gin.Default()
@@ -72,24 +116,26 @@ func main() {
// API routes
api := router.Group("/api/v1")
{
// Authentication routes
api.POST("/auth/login", authHandler.Login)
// Authentication routes (with rate limiting)
api.POST("/auth/login", rateLimiter.RateLimit("public_access", middleware.KeyByIP), authHandler.Login)
api.POST("/auth/logout", authHandler.Logout)
api.GET("/auth/verify", authHandler.VerifyToken)
// Public routes (no authentication required)
api.POST("/agents/register", agentHandler.RegisterAgent)
api.POST("/agents/renew", agentHandler.RenewToken)
// Public routes (no authentication required, with rate limiting)
api.POST("/agents/register", rateLimiter.RateLimit("agent_registration", middleware.KeyByIP), agentHandler.RegisterAgent)
api.POST("/agents/renew", rateLimiter.RateLimit("public_access", middleware.KeyByIP), agentHandler.RenewToken)
// Protected agent routes
agents := api.Group("/agents")
agents.Use(middleware.AuthMiddleware())
{
agents.GET("/:id/commands", agentHandler.GetCommands)
agents.POST("/:id/updates", updateHandler.ReportUpdates)
agents.POST("/:id/logs", updateHandler.ReportLog)
agents.POST("/:id/dependencies", updateHandler.ReportDependencies)
agents.POST("/:id/system-info", agentHandler.ReportSystemInfo)
agents.POST("/:id/updates", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportUpdates)
agents.POST("/:id/logs", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportLog)
agents.POST("/:id/dependencies", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportDependencies)
agents.POST("/:id/system-info", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.ReportSystemInfo)
agents.POST("/:id/rapid-mode", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.SetRapidPollingMode)
agents.DELETE("/:id", agentHandler.UnregisterAgent)
}
// Dashboard/Web routes (protected by web auth)
@@ -101,7 +147,8 @@ func main() {
dashboard.GET("/agents/:id", agentHandler.GetAgent)
dashboard.POST("/agents/:id/scan", agentHandler.TriggerScan)
dashboard.POST("/agents/:id/update", agentHandler.TriggerUpdate)
dashboard.DELETE("/agents/:id", agentHandler.UnregisterAgent)
dashboard.POST("/agents/:id/heartbeat", agentHandler.TriggerHeartbeat)
dashboard.GET("/agents/:id/heartbeat", agentHandler.GetHeartbeatStatus)
dashboard.GET("/updates", updateHandler.ListUpdates)
dashboard.GET("/updates/:id", updateHandler.GetUpdate)
dashboard.GET("/updates/:id/logs", updateHandler.GetUpdateLogs)
@@ -120,6 +167,7 @@ func main() {
dashboard.GET("/commands/recent", updateHandler.GetRecentCommands)
dashboard.POST("/commands/:id/retry", updateHandler.RetryCommand)
dashboard.POST("/commands/:id/cancel", updateHandler.CancelCommand)
dashboard.DELETE("/commands/failed", updateHandler.ClearFailedCommands)
// Settings routes
dashboard.GET("/settings/timezone", settingsHandler.GetTimezone)
@@ -132,6 +180,25 @@ func main() {
dashboard.POST("/docker/containers/:container_id/images/:image_id/approve", dockerHandler.ApproveUpdate)
dashboard.POST("/docker/containers/:container_id/images/:image_id/reject", dockerHandler.RejectUpdate)
dashboard.POST("/docker/containers/:container_id/images/:image_id/install", dockerHandler.InstallUpdate)
// Admin/Registration Token routes (for agent enrollment management)
admin := dashboard.Group("/admin")
{
admin.POST("/registration-tokens", rateLimiter.RateLimit("admin_token_gen", middleware.KeyByUserID), registrationTokenHandler.GenerateRegistrationToken)
admin.GET("/registration-tokens", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.ListRegistrationTokens)
admin.GET("/registration-tokens/active", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.GetActiveRegistrationTokens)
admin.DELETE("/registration-tokens/:token", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.RevokeRegistrationToken)
admin.POST("/registration-tokens/cleanup", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.CleanupExpiredTokens)
admin.GET("/registration-tokens/stats", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.GetTokenStats)
admin.GET("/registration-tokens/validate", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.ValidateRegistrationToken)
// Rate Limit Management
admin.GET("/rate-limits", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.GetRateLimitSettings)
admin.PUT("/rate-limits", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.UpdateRateLimitSettings)
admin.POST("/rate-limits/reset", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.ResetRateLimitSettings)
admin.GET("/rate-limits/stats", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.GetRateLimitStats)
admin.POST("/rate-limits/cleanup", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.CleanupRateLimitEntries)
}
}
}
@@ -166,8 +233,11 @@ func main() {
}()
// Start server
addr := ":" + cfg.ServerPort
fmt.Printf("\n🚩 RedFlag Aggregator Server starting on %s\n\n", addr)
addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
fmt.Printf("\nRedFlag Aggregator Server starting on %s\n", addr)
fmt.Printf("Admin interface: http://%s:%d/admin\n", cfg.Server.Host, cfg.Server.Port)
fmt.Printf("Dashboard: http://%s:%d\n\n", cfg.Server.Host, cfg.Server.Port)
if err := router.Run(addr); err != nil {
log.Fatal("Failed to start server:", err)
}

View File

@@ -9,6 +9,7 @@ require (
github.com/jmoiron/sqlx v1.4.0
github.com/joho/godotenv v1.5.1
github.com/lib/pq v1.10.9
golang.org/x/term v0.33.0
)
require (

View File

@@ -92,6 +92,8 @@ golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=

View File

@@ -1,6 +1,7 @@
package handlers
import (
"fmt"
"log"
"net/http"
"time"
@@ -107,15 +108,16 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
// Try to parse optional system metrics from request body
var metrics struct {
CPUPercent float64 `json:"cpu_percent,omitempty"`
MemoryPercent float64 `json:"memory_percent,omitempty"`
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
DiskPercent float64 `json:"disk_percent,omitempty"`
Uptime string `json:"uptime,omitempty"`
Version string `json:"version,omitempty"`
CPUPercent float64 `json:"cpu_percent,omitempty"`
MemoryPercent float64 `json:"memory_percent,omitempty"`
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
DiskPercent float64 `json:"disk_percent,omitempty"`
Uptime string `json:"uptime,omitempty"`
Version string `json:"version,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// Parse metrics if provided (optional, won't fail if empty)
@@ -130,21 +132,21 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
// Always handle version information if provided
if metrics.Version != "" {
// Get current agent to preserve existing metadata
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil && agent.Metadata != nil {
// Update agent's current version
if err := h.agentQueries.UpdateAgentVersion(agentID, metrics.Version); err != nil {
log.Printf("Warning: Failed to update agent version: %v", err)
} else {
// Check if update is available
updateAvailable := utils.IsNewerVersion(h.latestAgentVersion, metrics.Version)
// Update agent's current version in database (primary source of truth)
if err := h.agentQueries.UpdateAgentVersion(agentID, metrics.Version); err != nil {
log.Printf("Warning: Failed to update agent version: %v", err)
} else {
// Check if update is available
updateAvailable := utils.IsNewerVersion(h.latestAgentVersion, metrics.Version)
// Update agent's update availability status
if err := h.agentQueries.UpdateAgentUpdateAvailable(agentID, updateAvailable); err != nil {
log.Printf("Warning: Failed to update agent update availability: %v", err)
}
// Update agent's update availability status
if err := h.agentQueries.UpdateAgentUpdateAvailable(agentID, updateAvailable); err != nil {
log.Printf("Warning: Failed to update agent update availability: %v", err)
}
// Get current agent for logging and metadata update
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil {
// Log version check
if updateAvailable {
log.Printf("🔄 Agent %s (%s) version %s has update available: %s",
@@ -154,11 +156,20 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
agent.Hostname, agentID, metrics.Version)
}
// Store version in metadata as well
// Store version in metadata as well (for backwards compatibility)
// Initialize metadata if nil
if agent.Metadata == nil {
agent.Metadata = make(models.JSONB)
}
agent.Metadata["reported_version"] = metrics.Version
agent.Metadata["latest_version"] = h.latestAgentVersion
agent.Metadata["update_available"] = updateAvailable
agent.Metadata["version_checked_at"] = time.Now().Format(time.RFC3339)
// Update agent metadata
if err := h.agentQueries.UpdateAgent(agent); err != nil {
log.Printf("Warning: Failed to update agent metadata: %v", err)
}
}
}
}
@@ -179,6 +190,29 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
agent.Metadata["uptime"] = metrics.Uptime
agent.Metadata["metrics_updated_at"] = time.Now().Format(time.RFC3339)
// Process heartbeat metadata from agent check-ins
if metrics.Metadata != nil {
if rapidPollingEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
if rapidPollingUntil, exists := metrics.Metadata["rapid_polling_until"]; exists {
// Parse the until timestamp
if untilTime, err := time.Parse(time.RFC3339, rapidPollingUntil.(string)); err == nil {
// Validate if rapid polling is still active (not expired)
isActive := rapidPollingEnabled.(bool) && time.Now().Before(untilTime)
// Store heartbeat status in agent metadata
agent.Metadata["rapid_polling_enabled"] = rapidPollingEnabled
agent.Metadata["rapid_polling_until"] = rapidPollingUntil
agent.Metadata["rapid_polling_active"] = isActive
log.Printf("[Heartbeat] Agent %s heartbeat status: enabled=%v, until=%v, active=%v",
agentID, rapidPollingEnabled, rapidPollingUntil, isActive)
} else {
log.Printf("[Heartbeat] Failed to parse rapid_polling_until timestamp for agent %s: %v", agentID, err)
}
}
}
}
// Update agent with new metadata
if err := h.agentQueries.UpdateAgent(agent); err != nil {
log.Printf("Warning: Failed to update agent metrics: %v", err)
@@ -192,6 +226,37 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
return
}
// Process heartbeat metadata from agent check-ins
if metrics.Metadata != nil {
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil && agent.Metadata != nil {
if rapidPollingEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
if rapidPollingUntil, exists := metrics.Metadata["rapid_polling_until"]; exists {
// Parse the until timestamp
if untilTime, err := time.Parse(time.RFC3339, rapidPollingUntil.(string)); err == nil {
// Validate if rapid polling is still active (not expired)
isActive := rapidPollingEnabled.(bool) && time.Now().Before(untilTime)
// Store heartbeat status in agent metadata
agent.Metadata["rapid_polling_enabled"] = rapidPollingEnabled
agent.Metadata["rapid_polling_until"] = rapidPollingUntil
agent.Metadata["rapid_polling_active"] = isActive
log.Printf("[Heartbeat] Agent %s heartbeat status: enabled=%v, until=%v, active=%v",
agentID, rapidPollingEnabled, rapidPollingUntil, isActive)
// Update agent with new metadata
if err := h.agentQueries.UpdateAgent(agent); err != nil {
log.Printf("[Heartbeat] Warning: Failed to update agent heartbeat metadata: %v", err)
}
} else {
log.Printf("[Heartbeat] Failed to parse rapid_polling_until timestamp for agent %s: %v", agentID, err)
}
}
}
}
}
// Check for version updates for agents that don't send version in metrics
// This ensures agents like Metis that don't report version still get update checks
if metrics.Version == "" {
@@ -239,8 +304,97 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
h.commandQueries.MarkCommandSent(cmd.ID)
}
// Check if rapid polling should be enabled
var rapidPolling *models.RapidPollingConfig
// Enable rapid polling if there are commands to process
if len(commandItems) > 0 {
rapidPolling = &models.RapidPollingConfig{
Enabled: true,
Until: time.Now().Add(10 * time.Minute).Format(time.RFC3339), // 10 minutes default
}
} else {
// Check if agent has rapid polling already configured in metadata
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil && agent.Metadata != nil {
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
if until, err := time.Parse(time.RFC3339, untilStr); err == nil && time.Now().Before(until) {
rapidPolling = &models.RapidPollingConfig{
Enabled: true,
Until: untilStr,
}
}
}
}
}
}
// Detect stale heartbeat state: Server thinks it's active, but agent didn't report it
// This happens when agent restarts without heartbeat mode
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil && agent.Metadata != nil {
// Check if server metadata shows heartbeat active
if serverEnabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && serverEnabled {
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
if until, err := time.Parse(time.RFC3339, untilStr); err == nil && time.Now().Before(until) {
// Server thinks heartbeat is active and not expired
// Check if agent is reporting heartbeat in this check-in
agentReportingHeartbeat := false
if metrics.Metadata != nil {
if agentEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
agentReportingHeartbeat = agentEnabled.(bool)
}
}
// If agent is NOT reporting heartbeat but server expects it → stale state
if !agentReportingHeartbeat {
log.Printf("[Heartbeat] Stale heartbeat detected for agent %s - server expected active until %s, but agent not reporting heartbeat (likely restarted)",
agentID, until.Format(time.RFC3339))
// Clear stale heartbeat state
agent.Metadata["rapid_polling_enabled"] = false
delete(agent.Metadata, "rapid_polling_until")
if err := h.agentQueries.UpdateAgent(agent); err != nil {
log.Printf("[Heartbeat] Warning: Failed to clear stale heartbeat state: %v", err)
} else {
log.Printf("[Heartbeat] Cleared stale heartbeat state for agent %s", agentID)
// Create audit command to show in history
now := time.Now()
auditCmd := &models.AgentCommand{
ID: uuid.New(),
AgentID: agentID,
CommandType: models.CommandTypeDisableHeartbeat,
Params: models.JSONB{},
Status: models.CommandStatusCompleted,
Result: models.JSONB{
"message": "Heartbeat cleared - agent restarted without active heartbeat mode",
},
CreatedAt: now,
SentAt: &now,
CompletedAt: &now,
}
if err := h.commandQueries.CreateCommand(auditCmd); err != nil {
log.Printf("[Heartbeat] Warning: Failed to create audit command for stale heartbeat: %v", err)
} else {
log.Printf("[Heartbeat] Created audit trail for stale heartbeat cleanup (agent %s)", agentID)
}
}
// Clear rapidPolling response since we just disabled it
rapidPolling = nil
}
}
}
}
}
response := models.CommandsResponse{
Commands: commandItems,
Commands: commandItems,
RapidPolling: rapidPolling,
}
c.JSON(http.StatusOK, response)
@@ -312,6 +466,124 @@ func (h *AgentHandler) TriggerScan(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "scan triggered", "command_id": cmd.ID})
}
// TriggerHeartbeat creates a heartbeat toggle command for an agent
func (h *AgentHandler) TriggerHeartbeat(c *gin.Context) {
idStr := c.Param("id")
agentID, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
return
}
var request struct {
Enabled bool `json:"enabled"`
DurationMinutes int `json:"duration_minutes"`
}
if err := c.ShouldBindJSON(&request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Determine command type based on enabled flag
commandType := models.CommandTypeDisableHeartbeat
if request.Enabled {
commandType = models.CommandTypeEnableHeartbeat
}
// Create heartbeat command with duration parameter
cmd := &models.AgentCommand{
ID: uuid.New(),
AgentID: agentID,
CommandType: commandType,
Params: models.JSONB{
"duration_minutes": request.DurationMinutes,
},
Status: models.CommandStatusPending,
}
if err := h.commandQueries.CreateCommand(cmd); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create heartbeat command"})
return
}
// TODO: Clean up previous heartbeat commands for this agent (only for enable commands)
// if request.Enabled {
// // Mark previous heartbeat commands as 'replaced' to clean up Live Operations view
// if err := h.commandQueries.MarkPreviousHeartbeatCommandsReplaced(agentID, cmd.ID); err != nil {
// log.Printf("Warning: Failed to mark previous heartbeat commands as replaced: %v", err)
// // Don't fail the request, just log the warning
// } else {
// log.Printf("[Heartbeat] Cleaned up previous heartbeat commands for agent %s", agentID)
// }
// }
action := "disabled"
if request.Enabled {
action = "enabled"
}
log.Printf("💓 Heartbeat %s command created for agent %s (duration: %d minutes)",
action, agentID, request.DurationMinutes)
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("heartbeat %s command sent", action),
"command_id": cmd.ID,
"enabled": request.Enabled,
})
}
// GetHeartbeatStatus returns the current heartbeat status for an agent
func (h *AgentHandler) GetHeartbeatStatus(c *gin.Context) {
idStr := c.Param("id")
agentID, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
return
}
// Get agent and their heartbeat metadata
agent, err := h.agentQueries.GetAgentByID(agentID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
return
}
// Extract heartbeat information from metadata
response := gin.H{
"enabled": false,
"until": nil,
"active": false,
"duration_minutes": 0,
}
if agent.Metadata != nil {
// Check if heartbeat is enabled in metadata
if enabled, exists := agent.Metadata["rapid_polling_enabled"]; exists {
response["enabled"] = enabled.(bool)
// If enabled, get the until time and check if still active
if enabled.(bool) {
if untilStr, exists := agent.Metadata["rapid_polling_until"]; exists {
response["until"] = untilStr.(string)
// Parse the until timestamp to check if still active
if untilTime, err := time.Parse(time.RFC3339, untilStr.(string)); err == nil {
response["active"] = time.Now().Before(untilTime)
}
}
// Get duration if available
if duration, exists := agent.Metadata["rapid_polling_duration_minutes"]; exists {
response["duration_minutes"] = duration.(float64)
}
}
}
}
c.JSON(http.StatusOK, response)
}
// TriggerUpdate creates an update command for an agent
func (h *AgentHandler) TriggerUpdate(c *gin.Context) {
idStr := c.Param("id")
@@ -541,3 +813,114 @@ func (h *AgentHandler) ReportSystemInfo(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "system info updated successfully"})
}
// EnableRapidPollingMode enables rapid polling for an agent by updating metadata
func (h *AgentHandler) EnableRapidPollingMode(agentID uuid.UUID, durationMinutes int) error {
// Get current agent
agent, err := h.agentQueries.GetAgentByID(agentID)
if err != nil {
return fmt.Errorf("failed to get agent: %w", err)
}
// Calculate new rapid polling end time
newRapidPollingUntil := time.Now().Add(time.Duration(durationMinutes) * time.Minute)
// Update agent metadata with rapid polling settings
if agent.Metadata == nil {
agent.Metadata = models.JSONB{}
}
// Check if rapid polling is already active
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
if currentUntil, err := time.Parse(time.RFC3339, untilStr); err == nil {
// If current heartbeat expires later than the new duration, keep the longer duration
if currentUntil.After(newRapidPollingUntil) {
log.Printf("💓 Heartbeat already active for agent %s (%s), keeping longer duration (expires: %s)",
agent.Hostname, agentID, currentUntil.Format(time.RFC3339))
return nil
}
// Otherwise extend the heartbeat
log.Printf("💓 Extending heartbeat for agent %s (%s) from %s to %s",
agent.Hostname, agentID,
currentUntil.Format(time.RFC3339),
newRapidPollingUntil.Format(time.RFC3339))
}
}
} else {
log.Printf("💓 Enabling heartbeat mode for agent %s (%s) for %d minutes",
agent.Hostname, agentID, durationMinutes)
}
// Set/update rapid polling settings
agent.Metadata["rapid_polling_enabled"] = true
agent.Metadata["rapid_polling_until"] = newRapidPollingUntil.Format(time.RFC3339)
// Update agent in database
if err := h.agentQueries.UpdateAgent(agent); err != nil {
return fmt.Errorf("failed to update agent with rapid polling: %w", err)
}
return nil
}
// SetRapidPollingMode enables rapid polling mode for an agent
// TODO: Rate limiting should be implemented for rapid polling endpoints to prevent abuse (technical debt)
func (h *AgentHandler) SetRapidPollingMode(c *gin.Context) {
idStr := c.Param("id")
agentID, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
return
}
// Check if agent exists
agent, err := h.agentQueries.GetAgentByID(agentID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
return
}
var req struct {
DurationMinutes int `json:"duration_minutes" binding:"required,min=1,max=60"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Calculate rapid polling end time
rapidPollingUntil := time.Now().Add(time.Duration(req.DurationMinutes) * time.Minute)
// Update agent metadata with rapid polling settings
if agent.Metadata == nil {
agent.Metadata = models.JSONB{}
}
agent.Metadata["rapid_polling_enabled"] = req.Enabled
agent.Metadata["rapid_polling_until"] = rapidPollingUntil.Format(time.RFC3339)
// Update agent in database
if err := h.agentQueries.UpdateAgent(agent); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update agent"})
return
}
status := "disabled"
duration := 0
if req.Enabled {
status = "enabled"
duration = req.DurationMinutes
}
log.Printf("🚀 Rapid polling mode %s for agent %s (%s) for %d minutes",
status, agent.Hostname, agentID, duration)
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("Rapid polling mode %s", status),
"enabled": req.Enabled,
"duration_minutes": req.DurationMinutes,
"rapid_polling_until": rapidPollingUntil,
})
}

View File

@@ -378,7 +378,7 @@ func (h *DockerHandler) RejectUpdate(c *gin.Context) {
}
// For now, we'll mark as rejected (this would need a proper reject method in queries)
if err := h.updateQueries.UpdatePackageStatus(update.AgentID, "docker", update.PackageName, "rejected", nil); err != nil {
if err := h.updateQueries.UpdatePackageStatus(update.AgentID, "docker", update.PackageName, "rejected", nil, nil); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to reject Docker update"})
return
}

View File

@@ -0,0 +1,146 @@
package handlers
import (
"fmt"
"net/http"
"time"
"github.com/aggregator-project/aggregator-server/internal/api/middleware"
"github.com/gin-gonic/gin"
)
type RateLimitHandler struct {
rateLimiter *middleware.RateLimiter
}
func NewRateLimitHandler(rateLimiter *middleware.RateLimiter) *RateLimitHandler {
return &RateLimitHandler{
rateLimiter: rateLimiter,
}
}
// GetRateLimitSettings returns current rate limit configuration
func (h *RateLimitHandler) GetRateLimitSettings(c *gin.Context) {
settings := h.rateLimiter.GetSettings()
c.JSON(http.StatusOK, gin.H{
"settings": settings,
"updated_at": time.Now(),
})
}
// UpdateRateLimitSettings updates rate limit configuration
func (h *RateLimitHandler) UpdateRateLimitSettings(c *gin.Context) {
var settings middleware.RateLimitSettings
if err := c.ShouldBindJSON(&settings); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
return
}
// Validate settings
if err := h.validateRateLimitSettings(settings); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Update rate limiter settings
h.rateLimiter.UpdateSettings(settings)
c.JSON(http.StatusOK, gin.H{
"message": "Rate limit settings updated successfully",
"settings": settings,
"updated_at": time.Now(),
})
}
// ResetRateLimitSettings resets to default values
func (h *RateLimitHandler) ResetRateLimitSettings(c *gin.Context) {
defaultSettings := middleware.DefaultRateLimitSettings()
h.rateLimiter.UpdateSettings(defaultSettings)
c.JSON(http.StatusOK, gin.H{
"message": "Rate limit settings reset to defaults",
"settings": defaultSettings,
"updated_at": time.Now(),
})
}
// GetRateLimitStats returns current rate limit statistics
func (h *RateLimitHandler) GetRateLimitStats(c *gin.Context) {
settings := h.rateLimiter.GetSettings()
// Calculate total requests and windows
stats := gin.H{
"total_configured_limits": 6,
"enabled_limits": 0,
"total_requests_per_minute": 0,
"settings": settings,
}
// Count enabled limits and total requests
for _, config := range []middleware.RateLimitConfig{
settings.AgentRegistration,
settings.AgentCheckIn,
settings.AgentReports,
settings.AdminTokenGen,
settings.AdminOperations,
settings.PublicAccess,
} {
if config.Enabled {
stats["enabled_limits"] = stats["enabled_limits"].(int) + 1
}
stats["total_requests_per_minute"] = stats["total_requests_per_minute"].(int) + config.Requests
}
c.JSON(http.StatusOK, stats)
}
// CleanupRateLimitEntries manually triggers cleanup of expired entries
func (h *RateLimitHandler) CleanupRateLimitEntries(c *gin.Context) {
h.rateLimiter.CleanupExpiredEntries()
c.JSON(http.StatusOK, gin.H{
"message": "Rate limit entries cleanup completed",
"timestamp": time.Now(),
})
}
// validateRateLimitSettings validates the provided rate limit settings
func (h *RateLimitHandler) validateRateLimitSettings(settings middleware.RateLimitSettings) error {
// Validate each configuration
validations := []struct {
name string
config middleware.RateLimitConfig
}{
{"agent_registration", settings.AgentRegistration},
{"agent_checkin", settings.AgentCheckIn},
{"agent_reports", settings.AgentReports},
{"admin_token_generation", settings.AdminTokenGen},
{"admin_operations", settings.AdminOperations},
{"public_access", settings.PublicAccess},
}
for _, validation := range validations {
if validation.config.Requests <= 0 {
return fmt.Errorf("%s: requests must be greater than 0", validation.name)
}
if validation.config.Window <= 0 {
return fmt.Errorf("%s: window must be greater than 0", validation.name)
}
if validation.config.Window > 24*time.Hour {
return fmt.Errorf("%s: window cannot exceed 24 hours", validation.name)
}
if validation.config.Requests > 1000 {
return fmt.Errorf("%s: requests cannot exceed 1000 per window", validation.name)
}
}
// Specific validations for different endpoint types
if settings.AgentRegistration.Requests > 10 {
return fmt.Errorf("agent_registration: requests should not exceed 10 per minute for security")
}
if settings.PublicAccess.Requests > 50 {
return fmt.Errorf("public_access: requests should not exceed 50 per minute for security")
}
return nil
}

View File

@@ -0,0 +1,284 @@
package handlers
import (
"net/http"
"strconv"
"time"
"github.com/aggregator-project/aggregator-server/internal/config"
"github.com/aggregator-project/aggregator-server/internal/database/queries"
"github.com/gin-gonic/gin"
)
type RegistrationTokenHandler struct {
tokenQueries *queries.RegistrationTokenQueries
agentQueries *queries.AgentQueries
config *config.Config
}
func NewRegistrationTokenHandler(tokenQueries *queries.RegistrationTokenQueries, agentQueries *queries.AgentQueries, config *config.Config) *RegistrationTokenHandler {
return &RegistrationTokenHandler{
tokenQueries: tokenQueries,
agentQueries: agentQueries,
config: config,
}
}
// GenerateRegistrationToken creates a new registration token
func (h *RegistrationTokenHandler) GenerateRegistrationToken(c *gin.Context) {
var request struct {
Label string `json:"label" binding:"required"`
ExpiresIn string `json:"expires_in"` // e.g., "24h", "7d", "168h"
Metadata map[string]interface{} `json:"metadata"`
}
if err := c.ShouldBindJSON(&request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
return
}
// Check agent seat limit (security, not licensing)
activeAgents, err := h.agentQueries.GetActiveAgentCount()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check agent count"})
return
}
if activeAgents >= h.config.AgentRegistration.MaxSeats {
c.JSON(http.StatusForbidden, gin.H{
"error": "Maximum agent seats reached",
"limit": h.config.AgentRegistration.MaxSeats,
"current": activeAgents,
})
return
}
// Parse expiration duration
expiresIn := request.ExpiresIn
if expiresIn == "" {
expiresIn = h.config.AgentRegistration.TokenExpiry
}
duration, err := time.ParseDuration(expiresIn)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid expiration format. Use formats like '24h', '7d', '168h'"})
return
}
expiresAt := time.Now().Add(duration)
if duration > 168*time.Hour { // Max 7 days
c.JSON(http.StatusBadRequest, gin.H{"error": "Token expiration cannot exceed 7 days"})
return
}
// Generate secure token
token, err := config.GenerateSecureToken()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
return
}
// Create metadata with default values
metadata := request.Metadata
if metadata == nil {
metadata = make(map[string]interface{})
}
metadata["server_url"] = c.Request.Host
metadata["expires_in"] = expiresIn
// Store token in database
err = h.tokenQueries.CreateRegistrationToken(token, request.Label, expiresAt, metadata)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create token"})
return
}
// Build install command
serverURL := c.Request.Host
if serverURL == "" {
serverURL = "localhost:8080" // Fallback for development
}
installCommand := "curl -sfL https://" + serverURL + "/install | bash -s -- " + token
response := gin.H{
"token": token,
"label": request.Label,
"expires_at": expiresAt,
"install_command": installCommand,
"metadata": metadata,
}
c.JSON(http.StatusCreated, response)
}
// ListRegistrationTokens returns all registration tokens with pagination
func (h *RegistrationTokenHandler) ListRegistrationTokens(c *gin.Context) {
// Parse pagination parameters
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50"))
status := c.Query("status")
// Validate pagination
if limit > 100 {
limit = 100
}
if page < 1 {
page = 1
}
offset := (page - 1) * limit
var tokens []queries.RegistrationToken
var err error
if status != "" {
// TODO: Add filtered queries by status
tokens, err = h.tokenQueries.GetAllRegistrationTokens(limit, offset)
} else {
tokens, err = h.tokenQueries.GetAllRegistrationTokens(limit, offset)
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list tokens"})
return
}
// Get token usage stats
stats, err := h.tokenQueries.GetTokenUsageStats()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get token stats"})
return
}
response := gin.H{
"tokens": tokens,
"pagination": gin.H{
"page": page,
"limit": limit,
"offset": offset,
},
"stats": stats,
"seat_usage": gin.H{
"current": func() int {
count, _ := h.agentQueries.GetActiveAgentCount()
return count
}(),
"max": h.config.AgentRegistration.MaxSeats,
},
}
c.JSON(http.StatusOK, response)
}
// GetActiveRegistrationTokens returns only active tokens
func (h *RegistrationTokenHandler) GetActiveRegistrationTokens(c *gin.Context) {
tokens, err := h.tokenQueries.GetActiveRegistrationTokens()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get active tokens"})
return
}
c.JSON(http.StatusOK, gin.H{"tokens": tokens})
}
// RevokeRegistrationToken revokes a registration token
func (h *RegistrationTokenHandler) RevokeRegistrationToken(c *gin.Context) {
token := c.Param("token")
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Token is required"})
return
}
var request struct {
Reason string `json:"reason"`
}
c.ShouldBindJSON(&request) // Reason is optional
reason := request.Reason
if reason == "" {
reason = "Revoked via API"
}
err := h.tokenQueries.RevokeRegistrationToken(token, reason)
if err != nil {
if err.Error() == "token not found or already used/revoked" {
c.JSON(http.StatusNotFound, gin.H{"error": "Token not found or already used/revoked"})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to revoke token"})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Token revoked successfully"})
}
// ValidateRegistrationToken checks if a token is valid (for testing/debugging)
func (h *RegistrationTokenHandler) ValidateRegistrationToken(c *gin.Context) {
token := c.Query("token")
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Token query parameter is required"})
return
}
tokenInfo, err := h.tokenQueries.ValidateRegistrationToken(token)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"valid": false,
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"valid": true,
"token": tokenInfo,
})
}
// CleanupExpiredTokens performs cleanup of expired tokens
func (h *RegistrationTokenHandler) CleanupExpiredTokens(c *gin.Context) {
count, err := h.tokenQueries.CleanupExpiredTokens()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to cleanup expired tokens"})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "Cleanup completed",
"cleaned": count,
})
}
// GetTokenStats returns comprehensive token usage statistics
func (h *RegistrationTokenHandler) GetTokenStats(c *gin.Context) {
// Get token stats
tokenStats, err := h.tokenQueries.GetTokenUsageStats()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get token stats"})
return
}
// Get agent count
activeAgentCount, err := h.agentQueries.GetActiveAgentCount()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get agent count"})
return
}
response := gin.H{
"token_stats": tokenStats,
"agent_usage": gin.H{
"active_agents": activeAgentCount,
"max_seats": h.config.AgentRegistration.MaxSeats,
"available": h.config.AgentRegistration.MaxSeats - activeAgentCount,
},
"security_limits": gin.H{
"max_tokens_per_request": h.config.AgentRegistration.MaxTokens,
"max_token_duration": "7 days",
"token_expiry_default": h.config.AgentRegistration.TokenExpiry,
},
}
c.JSON(http.StatusOK, response)
}

View File

@@ -2,6 +2,7 @@ package handlers
import (
"fmt"
"log"
"net/http"
"strconv"
"time"
@@ -16,16 +17,42 @@ type UpdateHandler struct {
updateQueries *queries.UpdateQueries
agentQueries *queries.AgentQueries
commandQueries *queries.CommandQueries
agentHandler *AgentHandler
}
func NewUpdateHandler(uq *queries.UpdateQueries, aq *queries.AgentQueries, cq *queries.CommandQueries) *UpdateHandler {
func NewUpdateHandler(uq *queries.UpdateQueries, aq *queries.AgentQueries, cq *queries.CommandQueries, ah *AgentHandler) *UpdateHandler {
return &UpdateHandler{
updateQueries: uq,
agentQueries: aq,
commandQueries: cq,
agentHandler: ah,
}
}
// shouldEnableHeartbeat checks if heartbeat is already active for an agent
// Returns true if heartbeat should be enabled (i.e., not already active or expired)
func (h *UpdateHandler) shouldEnableHeartbeat(agentID uuid.UUID, durationMinutes int) (bool, error) {
agent, err := h.agentQueries.GetAgentByID(agentID)
if err != nil {
log.Printf("Warning: Failed to get agent %s for heartbeat check: %v", agentID, err)
return true, nil // Enable heartbeat by default if we can't check
}
// Check if rapid polling is already enabled and not expired
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
until, err := time.Parse(time.RFC3339, untilStr)
if err == nil && until.After(time.Now().Add(5*time.Minute)) {
// Heartbeat is already active for sufficient time
log.Printf("[Heartbeat] Agent %s already has active heartbeat until %s (skipping)", agentID, untilStr)
return false, nil
}
}
}
return true, nil
}
// ReportUpdates handles update reports from agents using event sourcing
func (h *UpdateHandler) ReportUpdates(c *gin.Context) {
agentID := c.MustGet("agent_id").(uuid.UUID)
@@ -172,7 +199,7 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
return
}
log := &models.UpdateLog{
logEntry := &models.UpdateLog{
ID: uuid.New(),
AgentID: agentID,
Action: req.Action,
@@ -185,7 +212,7 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
}
// Store the log entry
if err := h.updateQueries.CreateUpdateLog(log); err != nil {
if err := h.updateQueries.CreateUpdateLog(logEntry); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save log"})
return
}
@@ -207,10 +234,34 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
}
// Update command status based on log result
if req.Result == "success" {
if req.Result == "success" || req.Result == "completed" {
if err := h.commandQueries.MarkCommandCompleted(commandID, result); err != nil {
fmt.Printf("Warning: Failed to mark command %s as completed: %v\n", commandID, err)
}
// NEW: If this was a successful confirm_dependencies command, mark the package as updated
command, err := h.commandQueries.GetCommandByID(commandID)
if err == nil && command.CommandType == models.CommandTypeConfirmDependencies {
// Extract package info from command params
if packageName, ok := command.Params["package_name"].(string); ok {
if packageType, ok := command.Params["package_type"].(string); ok {
// Extract actual completion timestamp from command result for accurate audit trail
var completionTime *time.Time
if loggedAtStr, ok := command.Result["logged_at"].(string); ok {
if parsed, err := time.Parse(time.RFC3339Nano, loggedAtStr); err == nil {
completionTime = &parsed
}
}
// Update package status to 'updated' with actual completion timestamp
if err := h.updateQueries.UpdatePackageStatus(agentID, packageType, packageName, "updated", nil, completionTime); err != nil {
log.Printf("Warning: Failed to update package status for %s/%s: %v", packageType, packageName, err)
} else {
log.Printf("✅ Package %s (%s) marked as updated after successful installation", packageName, packageType)
}
}
}
}
} else if req.Result == "failed" || req.Result == "dry_run_failed" {
if err := h.commandQueries.MarkCommandFailed(commandID, result); err != nil {
fmt.Printf("Warning: Failed to mark command %s as failed: %v\n", commandID, err)
@@ -304,7 +355,7 @@ func (h *UpdateHandler) UpdatePackageStatus(c *gin.Context) {
return
}
if err := h.updateQueries.UpdatePackageStatus(agentID, req.PackageType, req.PackageName, req.Status, req.Metadata); err != nil {
if err := h.updateQueries.UpdatePackageStatus(agentID, req.PackageType, req.PackageName, req.Status, req.Metadata, nil); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update package status"})
return
}
@@ -395,7 +446,29 @@ func (h *UpdateHandler) InstallUpdate(c *gin.Context) {
CreatedAt: time.Now(),
}
// Store the command in database
// Check if heartbeat should be enabled (avoid duplicates)
if shouldEnable, err := h.shouldEnableHeartbeat(update.AgentID, 10); err == nil && shouldEnable {
heartbeatCmd := &models.AgentCommand{
ID: uuid.New(),
AgentID: update.AgentID,
CommandType: models.CommandTypeEnableHeartbeat,
Params: models.JSONB{
"duration_minutes": 10,
},
Status: models.CommandStatusPending,
CreatedAt: time.Now(),
}
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", update.AgentID, err)
} else {
log.Printf("[Heartbeat] Command created for agent %s before dry run", update.AgentID)
}
} else {
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", update.AgentID)
}
// Store the dry run command in database
if err := h.commandQueries.CreateCommand(command); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create dry run command"})
return
@@ -478,6 +551,28 @@ func (h *UpdateHandler) ReportDependencies(c *gin.Context) {
CreatedAt: time.Now(),
}
// Check if heartbeat should be enabled (avoid duplicates)
if shouldEnable, err := h.shouldEnableHeartbeat(agentID, 10); err == nil && shouldEnable {
heartbeatCmd := &models.AgentCommand{
ID: uuid.New(),
AgentID: agentID,
CommandType: models.CommandTypeEnableHeartbeat,
Params: models.JSONB{
"duration_minutes": 10,
},
Status: models.CommandStatusPending,
CreatedAt: time.Now(),
}
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", agentID, err)
} else {
log.Printf("[Heartbeat] Command created for agent %s before installation", agentID)
}
} else {
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", agentID)
}
if err := h.commandQueries.CreateCommand(command); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create installation command"})
return
@@ -536,6 +631,28 @@ func (h *UpdateHandler) ConfirmDependencies(c *gin.Context) {
CreatedAt: time.Now(),
}
// Check if heartbeat should be enabled (avoid duplicates)
if shouldEnable, err := h.shouldEnableHeartbeat(update.AgentID, 10); err == nil && shouldEnable {
heartbeatCmd := &models.AgentCommand{
ID: uuid.New(),
AgentID: update.AgentID,
CommandType: models.CommandTypeEnableHeartbeat,
Params: models.JSONB{
"duration_minutes": 10,
},
Status: models.CommandStatusPending,
CreatedAt: time.Now(),
}
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", update.AgentID, err)
} else {
log.Printf("[Heartbeat] Command created for agent %s before confirm dependencies", update.AgentID)
}
} else {
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", update.AgentID)
}
// Store the command in database
if err := h.commandQueries.CreateCommand(command); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create confirmation command"})
@@ -684,3 +801,60 @@ func (h *UpdateHandler) GetRecentCommands(c *gin.Context) {
"limit": limit,
})
}
// ClearFailedCommands manually removes failed/timed_out commands with cheeky warning
func (h *UpdateHandler) ClearFailedCommands(c *gin.Context) {
// Get query parameters for filtering
olderThanDaysStr := c.Query("older_than_days")
onlyRetriedStr := c.Query("only_retried")
allFailedStr := c.Query("all_failed")
var count int64
var err error
// Parse parameters
olderThanDays := 7 // default
if olderThanDaysStr != "" {
if days, err := strconv.Atoi(olderThanDaysStr); err == nil && days > 0 {
olderThanDays = days
}
}
onlyRetried := onlyRetriedStr == "true"
allFailed := allFailedStr == "true"
// Build the appropriate cleanup query based on parameters
if allFailed {
// Clear ALL failed commands (most aggressive)
count, err = h.commandQueries.ClearAllFailedCommands(olderThanDays)
} else if onlyRetried {
// Clear only failed commands that have been retried
count, err = h.commandQueries.ClearRetriedFailedCommands(olderThanDays)
} else {
// Clear failed commands older than specified days (default behavior)
count, err = h.commandQueries.ClearOldFailedCommands(olderThanDays)
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "failed to clear failed commands",
"details": err.Error(),
})
return
}
// Return success with cheeky message
message := fmt.Sprintf("Archived %d failed commands", count)
if count > 0 {
message += ". WARNING: This shouldn't be necessary if the retry logic is working properly - you might want to check what's causing commands to fail in the first place!"
message += " (History preserved - commands moved to archived status)"
} else {
message += ". No failed commands found matching your criteria. SUCCESS!"
}
c.JSON(http.StatusOK, gin.H{
"message": message,
"count": count,
"cheeky_warning": "Consider this a developer experience enhancement - the system should clean up after itself automatically!",
})
}

View File

@@ -0,0 +1,279 @@
package middleware
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimitConfig holds configuration for rate limiting
type RateLimitConfig struct {
Requests int `json:"requests"`
Window time.Duration `json:"window"`
Enabled bool `json:"enabled"`
}
// RateLimitEntry tracks requests for a specific key
type RateLimitEntry struct {
Requests []time.Time
mutex sync.RWMutex
}
// RateLimiter implements in-memory rate limiting with user-configurable settings
type RateLimiter struct {
entries sync.Map // map[string]*RateLimitEntry
configs map[string]RateLimitConfig
mutex sync.RWMutex
}
// RateLimitSettings holds all user-configurable rate limit settings
type RateLimitSettings struct {
AgentRegistration RateLimitConfig `json:"agent_registration"`
AgentCheckIn RateLimitConfig `json:"agent_checkin"`
AgentReports RateLimitConfig `json:"agent_reports"`
AdminTokenGen RateLimitConfig `json:"admin_token_generation"`
AdminOperations RateLimitConfig `json:"admin_operations"`
PublicAccess RateLimitConfig `json:"public_access"`
}
// DefaultRateLimitSettings provides sensible defaults
func DefaultRateLimitSettings() RateLimitSettings {
return RateLimitSettings{
AgentRegistration: RateLimitConfig{
Requests: 5,
Window: time.Minute,
Enabled: true,
},
AgentCheckIn: RateLimitConfig{
Requests: 60,
Window: time.Minute,
Enabled: true,
},
AgentReports: RateLimitConfig{
Requests: 30,
Window: time.Minute,
Enabled: true,
},
AdminTokenGen: RateLimitConfig{
Requests: 10,
Window: time.Minute,
Enabled: true,
},
AdminOperations: RateLimitConfig{
Requests: 100,
Window: time.Minute,
Enabled: true,
},
PublicAccess: RateLimitConfig{
Requests: 20,
Window: time.Minute,
Enabled: true,
},
}
}
// NewRateLimiter creates a new rate limiter with default settings
func NewRateLimiter() *RateLimiter {
rl := &RateLimiter{
entries: sync.Map{},
}
// Load default settings
defaults := DefaultRateLimitSettings()
rl.UpdateSettings(defaults)
return rl
}
// UpdateSettings updates rate limit configurations
func (rl *RateLimiter) UpdateSettings(settings RateLimitSettings) {
rl.mutex.Lock()
defer rl.mutex.Unlock()
rl.configs = map[string]RateLimitConfig{
"agent_registration": settings.AgentRegistration,
"agent_checkin": settings.AgentCheckIn,
"agent_reports": settings.AgentReports,
"admin_token_gen": settings.AdminTokenGen,
"admin_operations": settings.AdminOperations,
"public_access": settings.PublicAccess,
}
}
// GetSettings returns current rate limit settings
func (rl *RateLimiter) GetSettings() RateLimitSettings {
rl.mutex.RLock()
defer rl.mutex.RUnlock()
return RateLimitSettings{
AgentRegistration: rl.configs["agent_registration"],
AgentCheckIn: rl.configs["agent_checkin"],
AgentReports: rl.configs["agent_reports"],
AdminTokenGen: rl.configs["admin_token_gen"],
AdminOperations: rl.configs["admin_operations"],
PublicAccess: rl.configs["public_access"],
}
}
// RateLimit creates middleware for a specific rate limit type
func (rl *RateLimiter) RateLimit(limitType string, keyFunc func(*gin.Context) string) gin.HandlerFunc {
return func(c *gin.Context) {
rl.mutex.RLock()
config, exists := rl.configs[limitType]
rl.mutex.RUnlock()
if !exists || !config.Enabled {
c.Next()
return
}
key := keyFunc(c)
if key == "" {
c.Next()
return
}
// Check rate limit
allowed, resetTime := rl.checkRateLimit(key, config)
if !allowed {
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
c.Header("X-RateLimit-Remaining", "0")
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
c.Header("Retry-After", fmt.Sprintf("%d", int(resetTime.Sub(time.Now()).Seconds())))
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "Rate limit exceeded",
"limit": config.Requests,
"window": config.Window.String(),
"reset_time": resetTime,
})
c.Abort()
return
}
// Add rate limit headers
remaining := rl.getRemainingRequests(key, config)
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(config.Window).Unix()))
c.Next()
}
}
// checkRateLimit checks if the request is allowed
func (rl *RateLimiter) checkRateLimit(key string, config RateLimitConfig) (bool, time.Time) {
now := time.Now()
// Get or create entry
entryInterface, _ := rl.entries.LoadOrStore(key, &RateLimitEntry{
Requests: []time.Time{},
})
entry := entryInterface.(*RateLimitEntry)
entry.mutex.Lock()
defer entry.mutex.Unlock()
// Clean old requests outside the window
cutoff := now.Add(-config.Window)
validRequests := make([]time.Time, 0)
for _, reqTime := range entry.Requests {
if reqTime.After(cutoff) {
validRequests = append(validRequests, reqTime)
}
}
// Check if under limit
if len(validRequests) >= config.Requests {
// Find when the oldest request expires
oldestRequest := validRequests[0]
resetTime := oldestRequest.Add(config.Window)
return false, resetTime
}
// Add current request
entry.Requests = append(validRequests, now)
// Clean up expired entries periodically
if len(entry.Requests) == 0 {
rl.entries.Delete(key)
}
return true, time.Time{}
}
// getRemainingRequests calculates remaining requests for the key
func (rl *RateLimiter) getRemainingRequests(key string, config RateLimitConfig) int {
entryInterface, ok := rl.entries.Load(key)
if !ok {
return config.Requests
}
entry := entryInterface.(*RateLimitEntry)
entry.mutex.RLock()
defer entry.mutex.RUnlock()
now := time.Now()
cutoff := now.Add(-config.Window)
count := 0
for _, reqTime := range entry.Requests {
if reqTime.After(cutoff) {
count++
}
}
remaining := config.Requests - count
if remaining < 0 {
remaining = 0
}
return remaining
}
// CleanupExpiredEntries removes expired entries to prevent memory leaks
func (rl *RateLimiter) CleanupExpiredEntries() {
rl.entries.Range(func(key, value interface{}) bool {
entry := value.(*RateLimitEntry)
entry.mutex.Lock()
now := time.Now()
validRequests := make([]time.Time, 0)
for _, reqTime := range entry.Requests {
if reqTime.After(now.Add(-time.Hour)) { // Keep requests from last hour
validRequests = append(validRequests, reqTime)
}
}
if len(validRequests) == 0 {
rl.entries.Delete(key)
} else {
entry.Requests = validRequests
}
entry.mutex.Unlock()
return true
})
}
// Key generation functions
func KeyByIP(c *gin.Context) string {
return c.ClientIP()
}
func KeyByAgentID(c *gin.Context) string {
return c.Param("id")
}
func KeyByUserID(c *gin.Context) string {
// This would extract user ID from JWT or session
// For now, use IP as fallback
return c.ClientIP()
}
func KeyByIPAndPath(c *gin.Context) string {
return c.ClientIP() + ":" + c.Request.URL.Path
}

View File

@@ -1,18 +1,47 @@
package config
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/joho/godotenv"
"golang.org/x/term"
)
// Config holds the application configuration
type Config struct {
ServerPort string
DatabaseURL string
JWTSecret string
Server struct {
Host string `env:"REDFLAG_SERVER_HOST" default:"0.0.0.0"`
Port int `env:"REDFLAG_SERVER_PORT" default:"8080"`
TLS struct {
Enabled bool `env:"REDFLAG_TLS_ENABLED" default:"false"`
CertFile string `env:"REDFLAG_TLS_CERT_FILE"`
KeyFile string `env:"REDFLAG_TLS_KEY_FILE"`
}
}
Database struct {
Host string `env:"REDFLAG_DB_HOST" default:"localhost"`
Port int `env:"REDFLAG_DB_PORT" default:"5432"`
Database string `env:"REDFLAG_DB_NAME" default:"redflag"`
Username string `env:"REDFLAG_DB_USER" default:"redflag"`
Password string `env:"REDFLAG_DB_PASSWORD"`
}
Admin struct {
Username string `env:"REDFLAG_ADMIN_USER" default:"admin"`
Password string `env:"REDFLAG_ADMIN_PASSWORD"`
JWTSecret string `env:"REDFLAG_JWT_SECRET"`
}
AgentRegistration struct {
TokenExpiry string `env:"REDFLAG_TOKEN_EXPIRY" default:"24h"`
MaxTokens int `env:"REDFLAG_MAX_TOKENS" default:"100"`
MaxSeats int `env:"REDFLAG_MAX_SEATS" default:"50"`
}
CheckInInterval int
OfflineThreshold int
Timezone string
@@ -24,30 +53,195 @@ func Load() (*Config, error) {
// Load .env file if it exists (for development)
_ = godotenv.Load()
cfg := &Config{}
// Parse server configuration
cfg.Server.Host = getEnv("REDFLAG_SERVER_HOST", "0.0.0.0")
serverPort, _ := strconv.Atoi(getEnv("REDFLAG_SERVER_PORT", "8080"))
cfg.Server.Port = serverPort
cfg.Server.TLS.Enabled = getEnv("REDFLAG_TLS_ENABLED", "false") == "true"
cfg.Server.TLS.CertFile = getEnv("REDFLAG_TLS_CERT_FILE", "")
cfg.Server.TLS.KeyFile = getEnv("REDFLAG_TLS_KEY_FILE", "")
// Parse database configuration
cfg.Database.Host = getEnv("REDFLAG_DB_HOST", "localhost")
dbPort, _ := strconv.Atoi(getEnv("REDFLAG_DB_PORT", "5432"))
cfg.Database.Port = dbPort
cfg.Database.Database = getEnv("REDFLAG_DB_NAME", "redflag")
cfg.Database.Username = getEnv("REDFLAG_DB_USER", "redflag")
cfg.Database.Password = getEnv("REDFLAG_DB_PASSWORD", "")
// Parse admin configuration
cfg.Admin.Username = getEnv("REDFLAG_ADMIN_USER", "admin")
cfg.Admin.Password = getEnv("REDFLAG_ADMIN_PASSWORD", "")
cfg.Admin.JWTSecret = getEnv("REDFLAG_JWT_SECRET", "")
// Parse agent registration configuration
cfg.AgentRegistration.TokenExpiry = getEnv("REDFLAG_TOKEN_EXPIRY", "24h")
maxTokens, _ := strconv.Atoi(getEnv("REDFLAG_MAX_TOKENS", "100"))
cfg.AgentRegistration.MaxTokens = maxTokens
maxSeats, _ := strconv.Atoi(getEnv("REDFLAG_MAX_SEATS", "50"))
cfg.AgentRegistration.MaxSeats = maxSeats
// Parse legacy configuration for backwards compatibility
checkInInterval, _ := strconv.Atoi(getEnv("CHECK_IN_INTERVAL", "300"))
offlineThreshold, _ := strconv.Atoi(getEnv("OFFLINE_THRESHOLD", "600"))
cfg.CheckInInterval = checkInInterval
cfg.OfflineThreshold = offlineThreshold
cfg.Timezone = getEnv("TIMEZONE", "UTC")
cfg.LatestAgentVersion = getEnv("LATEST_AGENT_VERSION", "0.1.16")
cfg := &Config{
ServerPort: getEnv("SERVER_PORT", "8080"),
DatabaseURL: getEnv("DATABASE_URL", "postgres://aggregator:aggregator@localhost:5432/aggregator?sslmode=disable"),
JWTSecret: getEnv("JWT_SECRET", "test-secret-for-development-only"),
CheckInInterval: checkInInterval,
OfflineThreshold: offlineThreshold,
Timezone: getEnv("TIMEZONE", "UTC"),
LatestAgentVersion: getEnv("LATEST_AGENT_VERSION", "0.1.4"),
// Handle missing secrets
if cfg.Admin.Password == "" || cfg.Admin.JWTSecret == "" || cfg.Database.Password == "" {
fmt.Printf("[WARNING] Missing required configuration (admin password, JWT secret, or database password)\n")
fmt.Printf("[INFO] Run: ./redflag-server --setup to configure\n")
return nil, fmt.Errorf("missing required configuration")
}
// Debug: Log what JWT secret we're using (remove in production)
if cfg.JWTSecret == "test-secret-for-development-only" {
fmt.Printf("🔓 Using development JWT secret\n")
// Validate JWT secret is not the development default
if cfg.Admin.JWTSecret == "test-secret-for-development-only" {
fmt.Printf("[SECURITY WARNING] Using development JWT secret\n")
fmt.Printf("[INFO] Run: ./redflag-server --setup to configure production secrets\n")
}
return cfg, nil
}
// RunSetupWizard guides user through initial configuration
func RunSetupWizard() error {
fmt.Printf("RedFlag Server Setup Wizard\n")
fmt.Printf("===========================\n\n")
// Admin credentials
fmt.Printf("Admin Account Setup\n")
fmt.Printf("--------------------\n")
username := promptForInput("Admin username", "admin")
password := promptForPassword("Admin password")
// Database configuration
fmt.Printf("\nDatabase Configuration\n")
fmt.Printf("----------------------\n")
dbHost := promptForInput("Database host", "localhost")
dbPort, _ := strconv.Atoi(promptForInput("Database port", "5432"))
dbName := promptForInput("Database name", "redflag")
dbUser := promptForInput("Database user", "redflag")
dbPassword := promptForPassword("Database password")
// Server configuration
fmt.Printf("\nServer Configuration\n")
fmt.Printf("--------------------\n")
serverHost := promptForInput("Server bind address", "0.0.0.0")
serverPort, _ := strconv.Atoi(promptForInput("Server port", "8080"))
// Agent limits
fmt.Printf("\nAgent Registration\n")
fmt.Printf("------------------\n")
maxSeats, _ := strconv.Atoi(promptForInput("Maximum agent seats (security limit)", "50"))
// Generate JWT secret from admin password
jwtSecret := deriveJWTSecret(username, password)
// Create .env file
envContent := fmt.Sprintf(`# RedFlag Server Configuration
# Generated on %s
# Server Configuration
REDFLAG_SERVER_HOST=%s
REDFLAG_SERVER_PORT=%d
REDFLAG_TLS_ENABLED=false
# REDFLAG_TLS_CERT_FILE=
# REDFLAG_TLS_KEY_FILE=
# Database Configuration
REDFLAG_DB_HOST=%s
REDFLAG_DB_PORT=%d
REDFLAG_DB_NAME=%s
REDFLAG_DB_USER=%s
REDFLAG_DB_PASSWORD=%s
# Admin Configuration
REDFLAG_ADMIN_USER=%s
REDFLAG_ADMIN_PASSWORD=%s
REDFLAG_JWT_SECRET=%s
# Agent Registration
REDFLAG_TOKEN_EXPIRY=24h
REDFLAG_MAX_TOKENS=100
REDFLAG_MAX_SEATS=%d
# Legacy Configuration (for backwards compatibility)
SERVER_PORT=%d
DATABASE_URL=postgres://%s:%s@%s:%d/%s?sslmode=disable
JWT_SECRET=%s
CHECK_IN_INTERVAL=300
OFFLINE_THRESHOLD=600
TIMEZONE=UTC
LATEST_AGENT_VERSION=0.1.8
`, time.Now().Format("2006-01-02 15:04:05"), serverHost, serverPort,
dbHost, dbPort, dbName, dbUser, dbPassword,
username, password, jwtSecret, maxSeats,
serverPort, dbUser, dbPassword, dbHost, dbPort, dbName, jwtSecret)
// Write .env file
if err := os.WriteFile(".env", []byte(envContent), 0600); err != nil {
return fmt.Errorf("failed to write .env file: %w", err)
}
fmt.Printf("\n[OK] Configuration saved to .env file\n")
fmt.Printf("[SECURITY] File permissions set to 0600 (owner read/write only)\n")
fmt.Printf("\nNext steps:\n")
fmt.Printf(" 1. Start database: %s:%d\n", dbHost, dbPort)
fmt.Printf(" 2. Create database: CREATE DATABASE %s;\n", dbName)
fmt.Printf(" 3. Run migrations: ./redflag-server --migrate\n")
fmt.Printf(" 4. Start server: ./redflag-server\n")
fmt.Printf("\nServer will be available at: http://%s:%d\n", serverHost, serverPort)
fmt.Printf("Admin interface: http://%s:%d/admin\n", serverHost, serverPort)
return nil
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func promptForInput(prompt, defaultValue string) string {
fmt.Printf("%s [%s]: ", prompt, defaultValue)
var input string
fmt.Scanln(&input)
if strings.TrimSpace(input) == "" {
return defaultValue
}
return strings.TrimSpace(input)
}
func promptForPassword(prompt string) string {
fmt.Printf("%s: ", prompt)
password, err := term.ReadPassword(int(os.Stdin.Fd()))
if err != nil {
// Fallback to non-hidden input
var input string
fmt.Scanln(&input)
return strings.TrimSpace(input)
}
fmt.Printf("\n")
return strings.TrimSpace(string(password))
}
func deriveJWTSecret(username, password string) string {
// Derive JWT secret from admin credentials
// This ensures JWT secret changes if admin password changes
hash := sha256.Sum256([]byte(username + password + "redflag-jwt-2024"))
return hex.EncodeToString(hash[:])
}
// GenerateSecureToken generates a cryptographically secure random token
func GenerateSecureToken() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate secure token: %w", err)
}
return hex.EncodeToString(bytes), nil
}

View File

@@ -0,0 +1,9 @@
-- Add retry tracking to agent_commands table
-- This allows us to track command retry chains and display retry indicators in the UI
-- Add retried_from_id column to link retries to their original commands
ALTER TABLE agent_commands
ADD COLUMN retried_from_id UUID REFERENCES agent_commands(id) ON DELETE SET NULL;
-- Add index for efficient retry chain lookups
CREATE INDEX idx_commands_retried_from ON agent_commands(retried_from_id) WHERE retried_from_id IS NOT NULL;

View File

@@ -0,0 +1,9 @@
-- Add 'archived_failed' status to agent_commands status constraint
-- This allows archiving failed/timed_out commands to clean up the active list
-- Drop the existing constraint
ALTER TABLE agent_commands DROP CONSTRAINT IF EXISTS agent_commands_status_check;
-- Add the new constraint with 'archived_failed' included
ALTER TABLE agent_commands ADD CONSTRAINT agent_commands_status_check
CHECK (status IN ('pending', 'sent', 'running', 'completed', 'failed', 'timed_out', 'cancelled', 'archived_failed'));

View File

@@ -0,0 +1,85 @@
-- Registration tokens for secure agent enrollment
-- Tokens are one-time use and have configurable expiration
CREATE TABLE registration_tokens (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
token VARCHAR(64) UNIQUE NOT NULL, -- One-time use token
label VARCHAR(255), -- Optional label for token identification
expires_at TIMESTAMP NOT NULL, -- Token expiration time
created_at TIMESTAMP DEFAULT NOW(), -- When token was created
used_at TIMESTAMP NULL, -- When token was used (NULL if unused)
used_by_agent_id UUID NULL, -- Which agent used this token (foreign key)
revoked BOOLEAN DEFAULT FALSE, -- Manual revocation
revoked_at TIMESTAMP NULL, -- When token was revoked
revoked_reason VARCHAR(255) NULL, -- Reason for revocation
-- Token status tracking
status VARCHAR(20) DEFAULT 'active'
CHECK (status IN ('active', 'used', 'expired', 'revoked')),
-- Additional metadata
created_by VARCHAR(100) DEFAULT 'setup_wizard', -- Who created the token
metadata JSONB DEFAULT '{}'::jsonb -- Additional token metadata
);
-- Indexes for performance
CREATE INDEX idx_registration_tokens_token ON registration_tokens(token);
CREATE INDEX idx_registration_tokens_expires_at ON registration_tokens(expires_at);
CREATE INDEX idx_registration_tokens_status ON registration_tokens(status);
CREATE INDEX idx_registration_tokens_used_by_agent ON registration_tokens(used_by_agent_id) WHERE used_by_agent_id IS NOT NULL;
-- Foreign key constraint for used_by_agent_id
ALTER TABLE registration_tokens
ADD CONSTRAINT fk_registration_tokens_agent
FOREIGN KEY (used_by_agent_id) REFERENCES agents(id) ON DELETE SET NULL;
-- Function to clean up expired tokens (called by periodic cleanup job)
CREATE OR REPLACE FUNCTION cleanup_expired_registration_tokens()
RETURNS INTEGER AS $$
DECLARE
deleted_count INTEGER;
BEGIN
UPDATE registration_tokens
SET status = 'expired',
used_at = NOW()
WHERE status = 'active'
AND expires_at < NOW()
AND used_at IS NULL;
GET DIAGNOSTICS deleted_count = ROW_COUNT;
RETURN deleted_count;
END;
$$ LANGUAGE plpgsql;
-- Function to check if a token is valid
CREATE OR REPLACE FUNCTION is_registration_token_valid(token_input VARCHAR)
RETURNS BOOLEAN AS $$
DECLARE
token_valid BOOLEAN;
BEGIN
SELECT (status = 'active' AND expires_at > NOW()) INTO token_valid
FROM registration_tokens
WHERE token = token_input;
RETURN COALESCE(token_valid, FALSE);
END;
$$ LANGUAGE plpgsql;
-- Function to mark token as used
CREATE OR REPLACE function mark_registration_token_used(token_input VARCHAR, agent_id UUID)
RETURNS BOOLEAN AS $$
DECLARE
updated BOOLEAN;
BEGIN
UPDATE registration_tokens
SET status = 'used',
used_at = NOW(),
used_by_agent_id = agent_id
WHERE token = token_input
AND status = 'active'
AND expires_at > NOW();
GET DIAGNOSTICS updated = ROW_COUNT;
RETURN updated > 0;
END;
$$ LANGUAGE plpgsql;

View File

@@ -196,3 +196,11 @@ func (q *AgentQueries) DeleteAgent(id uuid.UUID) error {
// Commit the transaction
return tx.Commit()
}
// GetActiveAgentCount returns the count of active (online) agents
func (q *AgentQueries) GetActiveAgentCount() (int, error) {
var count int
query := `SELECT COUNT(*) FROM agents WHERE status = 'online'`
err := q.db.Get(&count, query)
return count, err
}

View File

@@ -21,9 +21,9 @@ func NewCommandQueries(db *sqlx.DB) *CommandQueries {
func (q *CommandQueries) CreateCommand(cmd *models.AgentCommand) error {
query := `
INSERT INTO agent_commands (
id, agent_id, command_type, params, status
id, agent_id, command_type, params, status, retried_from_id
) VALUES (
:id, :agent_id, :command_type, :params, :status
:id, :agent_id, :command_type, :params, :status, :retried_from_id
)
`
_, err := q.db.NamedExec(query, cmd)
@@ -152,14 +152,15 @@ func (q *CommandQueries) RetryCommand(originalID uuid.UUID) (*models.AgentComman
return nil, fmt.Errorf("command must be failed, timed_out, or cancelled to retry")
}
// Create new command with same parameters
// Create new command with same parameters, linking it to the original
newCommand := &models.AgentCommand{
ID: uuid.New(),
AgentID: original.AgentID,
CommandType: original.CommandType,
Params: original.Params,
Status: models.CommandStatusPending,
CreatedAt: time.Now(),
ID: uuid.New(),
AgentID: original.AgentID,
CommandType: original.CommandType,
Params: original.Params,
Status: models.CommandStatusPending,
CreatedAt: time.Now(),
RetriedFromID: &originalID,
}
// Store the new command
@@ -180,20 +181,44 @@ func (q *CommandQueries) GetActiveCommands() ([]models.ActiveCommandInfo, error)
c.id,
c.agent_id,
c.command_type,
c.params,
c.status,
c.created_at,
c.sent_at,
c.result,
c.retried_from_id,
a.hostname as agent_hostname,
COALESCE(ups.package_name, 'N/A') as package_name,
COALESCE(ups.package_type, 'N/A') as package_type
COALESCE(ups.package_type, 'N/A') as package_type,
(c.retried_from_id IS NOT NULL) as is_retry,
EXISTS(SELECT 1 FROM agent_commands WHERE retried_from_id = c.id) as has_been_retried,
COALESCE((
WITH RECURSIVE retry_chain AS (
SELECT id, retried_from_id, 1 as depth
FROM agent_commands
WHERE id = c.id
UNION ALL
SELECT ac.id, ac.retried_from_id, rc.depth + 1
FROM agent_commands ac
JOIN retry_chain rc ON ac.id = rc.retried_from_id
)
SELECT MAX(depth) FROM retry_chain
), 1) - 1 as retry_count
FROM agent_commands c
LEFT JOIN agents a ON c.agent_id = a.id
LEFT JOIN current_package_state ups ON (
c.params->>'update_id' = ups.id::text OR
(c.params->>'package_name' = ups.package_name AND c.params->>'package_type' = ups.package_type)
)
WHERE c.status NOT IN ('completed', 'cancelled')
WHERE c.status NOT IN ('completed', 'cancelled', 'archived_failed')
AND NOT (
c.status IN ('failed', 'timed_out')
AND EXISTS (
SELECT 1 FROM agent_commands retry
WHERE retry.retried_from_id = c.id
AND retry.status = 'completed'
)
)
ORDER BY c.created_at DESC
`
@@ -223,9 +248,24 @@ func (q *CommandQueries) GetRecentCommands(limit int) ([]models.ActiveCommandInf
c.sent_at,
c.completed_at,
c.result,
c.retried_from_id,
a.hostname as agent_hostname,
COALESCE(ups.package_name, 'N/A') as package_name,
COALESCE(ups.package_type, 'N/A') as package_type
COALESCE(ups.package_type, 'N/A') as package_type,
(c.retried_from_id IS NOT NULL) as is_retry,
EXISTS(SELECT 1 FROM agent_commands WHERE retried_from_id = c.id) as has_been_retried,
COALESCE((
WITH RECURSIVE retry_chain AS (
SELECT id, retried_from_id, 1 as depth
FROM agent_commands
WHERE id = c.id
UNION ALL
SELECT ac.id, ac.retried_from_id, rc.depth + 1
FROM agent_commands ac
JOIN retry_chain rc ON ac.id = rc.retried_from_id
)
SELECT MAX(depth) FROM retry_chain
), 1) - 1 as retry_count
FROM agent_commands c
LEFT JOIN agents a ON c.agent_id = a.id
LEFT JOIN current_package_state ups ON (
@@ -243,3 +283,55 @@ func (q *CommandQueries) GetRecentCommands(limit int) ([]models.ActiveCommandInf
return commands, nil
}
// ClearOldFailedCommands archives failed commands older than specified days by changing status to 'archived_failed'
func (q *CommandQueries) ClearOldFailedCommands(days int) (int64, error) {
query := fmt.Sprintf(`
UPDATE agent_commands
SET status = 'archived_failed'
WHERE status IN ('failed', 'timed_out')
AND created_at < NOW() - INTERVAL '%d days'
`, days)
result, err := q.db.Exec(query)
if err != nil {
return 0, fmt.Errorf("failed to archive old failed commands: %w", err)
}
return result.RowsAffected()
}
// ClearRetriedFailedCommands archives failed commands that have been retried and are older than specified days
func (q *CommandQueries) ClearRetriedFailedCommands(days int) (int64, error) {
query := fmt.Sprintf(`
UPDATE agent_commands
SET status = 'archived_failed'
WHERE status IN ('failed', 'timed_out')
AND EXISTS (SELECT 1 FROM agent_commands WHERE retried_from_id = agent_commands.id)
AND created_at < NOW() - INTERVAL '%d days'
`, days)
result, err := q.db.Exec(query)
if err != nil {
return 0, fmt.Errorf("failed to archive retried failed commands: %w", err)
}
return result.RowsAffected()
}
// ClearAllFailedCommands archives all failed commands older than specified days (most aggressive)
func (q *CommandQueries) ClearAllFailedCommands(days int) (int64, error) {
query := fmt.Sprintf(`
UPDATE agent_commands
SET status = 'archived_failed'
WHERE status IN ('failed', 'timed_out')
AND created_at < NOW() - INTERVAL '%d days'
`, days)
result, err := q.db.Exec(query)
if err != nil {
return 0, fmt.Errorf("failed to archive all failed commands: %w", err)
}
return result.RowsAffected()
}

View File

@@ -0,0 +1,232 @@
package queries
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
)
type RegistrationTokenQueries struct {
db *sqlx.DB
}
type RegistrationToken struct {
ID uuid.UUID `json:"id" db:"id"`
Token string `json:"token" db:"token"`
Label *string `json:"label" db:"label"`
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UsedAt *time.Time `json:"used_at" db:"used_at"`
UsedByAgentID *uuid.UUID `json:"used_by_agent_id" db:"used_by_agent_id"`
Revoked bool `json:"revoked" db:"revoked"`
RevokedAt *time.Time `json:"revoked_at" db:"revoked_at"`
RevokedReason *string `json:"revoked_reason" db:"revoked_reason"`
Status string `json:"status" db:"status"`
CreatedBy string `json:"created_by" db:"created_by"`
Metadata map[string]interface{} `json:"metadata" db:"metadata"`
}
type TokenRequest struct {
Label string `json:"label"`
ExpiresIn string `json:"expires_in"` // e.g., "24h", "7d"
Metadata map[string]interface{} `json:"metadata"`
}
type TokenResponse struct {
Token string `json:"token"`
Label string `json:"label"`
ExpiresAt time.Time `json:"expires_at"`
InstallCommand string `json:"install_command"`
}
func NewRegistrationTokenQueries(db *sqlx.DB) *RegistrationTokenQueries {
return &RegistrationTokenQueries{db: db}
}
// CreateRegistrationToken creates a new one-time use registration token
func (q *RegistrationTokenQueries) CreateRegistrationToken(token, label string, expiresAt time.Time, metadata map[string]interface{}) error {
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
}
query := `
INSERT INTO registration_tokens (token, label, expires_at, metadata)
VALUES ($1, $2, $3, $4)
`
_, err = q.db.Exec(query, token, label, expiresAt, metadataJSON)
if err != nil {
return fmt.Errorf("failed to create registration token: %w", err)
}
return nil
}
// ValidateRegistrationToken checks if a token is valid and unused
func (q *RegistrationTokenQueries) ValidateRegistrationToken(token string) (*RegistrationToken, error) {
var regToken RegistrationToken
query := `
SELECT id, token, label, expires_at, created_at, used_at, used_by_agent_id,
revoked, revoked_at, revoked_reason, status, created_by, metadata
FROM registration_tokens
WHERE token = $1 AND status = 'active' AND expires_at > NOW()
`
err := q.db.Get(&regToken, query, token)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("invalid or expired token")
}
return nil, fmt.Errorf("failed to validate token: %w", err)
}
return &regToken, nil
}
// MarkTokenUsed marks a token as used by an agent
func (q *RegistrationTokenQueries) MarkTokenUsed(token string, agentID uuid.UUID) error {
query := `
UPDATE registration_tokens
SET status = 'used',
used_at = NOW(),
used_by_agent_id = $1
WHERE token = $2 AND status = 'active' AND expires_at > NOW()
`
result, err := q.db.Exec(query, agentID, token)
if err != nil {
return fmt.Errorf("failed to mark token as used: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("token not found or already used")
}
return nil
}
// GetActiveRegistrationTokens returns all active tokens
func (q *RegistrationTokenQueries) GetActiveRegistrationTokens() ([]RegistrationToken, error) {
var tokens []RegistrationToken
query := `
SELECT id, token, label, expires_at, created_at, used_at, used_by_agent_id,
revoked, revoked_at, revoked_reason, status, created_by, metadata
FROM registration_tokens
WHERE status = 'active'
ORDER BY created_at DESC
`
err := q.db.Select(&tokens, query)
if err != nil {
return nil, fmt.Errorf("failed to get active tokens: %w", err)
}
return tokens, nil
}
// GetAllRegistrationTokens returns all tokens with pagination
func (q *RegistrationTokenQueries) GetAllRegistrationTokens(limit, offset int) ([]RegistrationToken, error) {
var tokens []RegistrationToken
query := `
SELECT id, token, label, expires_at, created_at, used_at, used_by_agent_id,
revoked, revoked_at, revoked_reason, status, created_by, metadata
FROM registration_tokens
ORDER BY created_at DESC
LIMIT $1 OFFSET $2
`
err := q.db.Select(&tokens, query, limit, offset)
if err != nil {
return nil, fmt.Errorf("failed to get all tokens: %w", err)
}
return tokens, nil
}
// RevokeRegistrationToken revokes a token
func (q *RegistrationTokenQueries) RevokeRegistrationToken(token, reason string) error {
query := `
UPDATE registration_tokens
SET status = 'revoked',
revoked = true,
revoked_at = NOW(),
revoked_reason = $1
WHERE token = $2 AND status = 'active'
`
result, err := q.db.Exec(query, reason, token)
if err != nil {
return fmt.Errorf("failed to revoke token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("token not found or already used/revoked")
}
return nil
}
// CleanupExpiredTokens marks expired tokens as expired
func (q *RegistrationTokenQueries) CleanupExpiredTokens() (int, error) {
query := `
UPDATE registration_tokens
SET status = 'expired',
used_at = NOW()
WHERE status = 'active' AND expires_at < NOW() AND used_at IS NULL
`
result, err := q.db.Exec(query)
if err != nil {
return 0, fmt.Errorf("failed to cleanup expired tokens: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// GetTokenUsageStats returns statistics about token usage
func (q *RegistrationTokenQueries) GetTokenUsageStats() (map[string]int, error) {
stats := make(map[string]int)
query := `
SELECT status, COUNT(*) as count
FROM registration_tokens
GROUP BY status
`
rows, err := q.db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to get token stats: %w", err)
}
defer rows.Close()
for rows.Next() {
var status string
var count int
if err := rows.Scan(&status, &count); err != nil {
return nil, fmt.Errorf("failed to scan token stats row: %w", err)
}
stats[status] = count
}
return stats, nil
}

View File

@@ -527,7 +527,8 @@ func (q *UpdateQueries) GetPackageHistory(agentID uuid.UUID, packageType, packag
}
// UpdatePackageStatus updates the status of a package and records history
func (q *UpdateQueries) UpdatePackageStatus(agentID uuid.UUID, packageType, packageName, status string, metadata models.JSONB) error {
// completedAt is optional - if nil, uses time.Now(). Pass actual completion time for accurate audit trails.
func (q *UpdateQueries) UpdatePackageStatus(agentID uuid.UUID, packageType, packageName, status string, metadata models.JSONB, completedAt *time.Time) error {
tx, err := q.db.Beginx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
@@ -542,13 +543,19 @@ func (q *UpdateQueries) UpdatePackageStatus(agentID uuid.UUID, packageType, pack
return fmt.Errorf("failed to get current state: %w", err)
}
// Use provided timestamp or fall back to server time
timestamp := time.Now()
if completedAt != nil {
timestamp = *completedAt
}
// Update status
updateQuery := `
UPDATE current_package_state
SET status = $1, last_updated_at = $2
WHERE agent_id = $3 AND package_type = $4 AND package_name = $5
`
_, err = tx.Exec(updateQuery, status, time.Now(), agentID, packageType, packageName)
_, err = tx.Exec(updateQuery, status, timestamp, agentID, packageType, packageName)
if err != nil {
return fmt.Errorf("failed to update package status: %w", err)
}
@@ -564,7 +571,7 @@ func (q *UpdateQueries) UpdatePackageStatus(agentID uuid.UUID, packageType, pack
_, err = tx.Exec(historyQuery,
agentID, packageType, packageName, currentState.CurrentVersion,
currentState.AvailableVersion, currentState.Severity,
currentState.RepositorySource, metadata, time.Now(), status)
currentState.RepositorySource, metadata, timestamp, status)
if err != nil {
return fmt.Errorf("failed to record version history: %w", err)
}

View File

@@ -8,20 +8,28 @@ import (
// AgentCommand represents a command to be executed by an agent
type AgentCommand struct {
ID uuid.UUID `json:"id" db:"id"`
AgentID uuid.UUID `json:"agent_id" db:"agent_id"`
CommandType string `json:"command_type" db:"command_type"`
Params JSONB `json:"params" db:"params"`
Status string `json:"status" db:"status"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
SentAt *time.Time `json:"sent_at,omitempty" db:"sent_at"`
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
Result JSONB `json:"result,omitempty" db:"result"`
ID uuid.UUID `json:"id" db:"id"`
AgentID uuid.UUID `json:"agent_id" db:"agent_id"`
CommandType string `json:"command_type" db:"command_type"`
Params JSONB `json:"params" db:"params"`
Status string `json:"status" db:"status"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
SentAt *time.Time `json:"sent_at,omitempty" db:"sent_at"`
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
Result JSONB `json:"result,omitempty" db:"result"`
RetriedFromID *uuid.UUID `json:"retried_from_id,omitempty" db:"retried_from_id"`
}
// CommandsResponse is returned when an agent checks in for commands
type CommandsResponse struct {
Commands []CommandItem `json:"commands"`
Commands []CommandItem `json:"commands"`
RapidPolling *RapidPollingConfig `json:"rapid_polling,omitempty"`
}
// RapidPollingConfig contains rapid polling configuration for the agent
type RapidPollingConfig struct {
Enabled bool `json:"enabled"`
Until string `json:"until"` // ISO 8601 timestamp
}
// CommandItem represents a command in the response
@@ -40,6 +48,8 @@ const (
CommandTypeConfirmDependencies = "confirm_dependencies"
CommandTypeRollback = "rollback_update"
CommandTypeUpdateAgent = "update_agent"
CommandTypeEnableHeartbeat = "enable_heartbeat"
CommandTypeDisableHeartbeat = "disable_heartbeat"
)
// Command statuses
@@ -55,15 +65,20 @@ const (
// ActiveCommandInfo represents information about an active command for UI display
type ActiveCommandInfo struct {
ID uuid.UUID `json:"id" db:"id"`
AgentID uuid.UUID `json:"agent_id" db:"agent_id"`
CommandType string `json:"command_type" db:"command_type"`
Status string `json:"status" db:"status"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
SentAt *time.Time `json:"sent_at,omitempty" db:"sent_at"`
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
Result JSONB `json:"result,omitempty" db:"result"`
AgentHostname string `json:"agent_hostname" db:"agent_hostname"`
PackageName string `json:"package_name" db:"package_name"`
PackageType string `json:"package_type" db:"package_type"`
ID uuid.UUID `json:"id" db:"id"`
AgentID uuid.UUID `json:"agent_id" db:"agent_id"`
CommandType string `json:"command_type" db:"command_type"`
Params JSONB `json:"params" db:"params"`
Status string `json:"status" db:"status"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
SentAt *time.Time `json:"sent_at,omitempty" db:"sent_at"`
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
Result JSONB `json:"result,omitempty" db:"result"`
AgentHostname string `json:"agent_hostname" db:"agent_hostname"`
PackageName string `json:"package_name" db:"package_name"`
PackageType string `json:"package_type" db:"package_type"`
RetriedFromID *uuid.UUID `json:"retried_from_id,omitempty" db:"retried_from_id"`
IsRetry bool `json:"is_retry" db:"is_retry"`
HasBeenRetried bool `json:"has_been_retried" db:"has_been_retried"`
RetryCount int `json:"retry_count" db:"retry_count"`
}

View File

@@ -162,7 +162,8 @@ func (ts *TimeoutService) updateRelatedPackageStatus(command *models.AgentComman
command.Params["package_type"].(string),
command.Params["package_name"].(string),
"failed",
metadata)
metadata,
nil) // nil = use time.Now() for timeout operations
}
// extractUpdatePackageID extracts the update package ID from command params