v0.1.16: Security overhaul and systematic deployment preparation
Breaking changes for clean alpha releases: - JWT authentication with user-provided secrets (no more development defaults) - Registration token system for secure agent enrollment - Rate limiting with user-adjustable settings - Enhanced agent configuration with proxy support - Interactive server setup wizard (--setup flag) - Heartbeat architecture separation for better UX - Package status synchronization fixes - Accurate timestamp tracking for RMM features Setup process for new installations: 1. docker-compose up -d postgres 2. ./redflag-server --setup 3. ./redflag-server --migrate 4. ./redflag-server 5. Generate tokens via admin UI 6. Deploy agents with registration tokens
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"path/filepath"
|
||||
@@ -16,6 +17,29 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Parse command line flags
|
||||
var setup bool
|
||||
var migrate bool
|
||||
var version bool
|
||||
flag.BoolVar(&setup, "setup", false, "Run setup wizard")
|
||||
flag.BoolVar(&migrate, "migrate", false, "Run database migrations only")
|
||||
flag.BoolVar(&version, "version", false, "Show version information")
|
||||
flag.Parse()
|
||||
|
||||
// Handle special commands
|
||||
if version {
|
||||
fmt.Printf("RedFlag Server v0.1.0-alpha\n")
|
||||
fmt.Printf("Self-hosted update management platform\n")
|
||||
return
|
||||
}
|
||||
|
||||
if setup {
|
||||
if err := config.RunSetupWizard(); err != nil {
|
||||
log.Fatal("Setup failed:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Load configuration
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
@@ -23,15 +47,29 @@ func main() {
|
||||
}
|
||||
|
||||
// Set JWT secret
|
||||
middleware.JWTSecret = cfg.JWTSecret
|
||||
middleware.JWTSecret = cfg.Admin.JWTSecret
|
||||
|
||||
// Build database URL from new config structure
|
||||
databaseURL := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable",
|
||||
cfg.Database.Username, cfg.Database.Password, cfg.Database.Host, cfg.Database.Port, cfg.Database.Database)
|
||||
|
||||
// Connect to database
|
||||
db, err := database.Connect(cfg.DatabaseURL)
|
||||
db, err := database.Connect(databaseURL)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to connect to database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Handle migrate-only flag
|
||||
if migrate {
|
||||
migrationsPath := filepath.Join("internal", "database", "migrations")
|
||||
if err := db.Migrate(migrationsPath); err != nil {
|
||||
log.Fatal("Migration failed:", err)
|
||||
}
|
||||
fmt.Printf("✅ Database migrations completed\n")
|
||||
return
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
migrationsPath := filepath.Join("internal", "database", "migrations")
|
||||
if err := db.Migrate(migrationsPath); err != nil {
|
||||
@@ -45,18 +83,24 @@ func main() {
|
||||
updateQueries := queries.NewUpdateQueries(db.DB)
|
||||
commandQueries := queries.NewCommandQueries(db.DB)
|
||||
refreshTokenQueries := queries.NewRefreshTokenQueries(db.DB)
|
||||
registrationTokenQueries := queries.NewRegistrationTokenQueries(db.DB)
|
||||
|
||||
// Initialize services
|
||||
timezoneService := services.NewTimezoneService(cfg)
|
||||
timeoutService := services.NewTimeoutService(commandQueries, updateQueries)
|
||||
|
||||
// Initialize rate limiter
|
||||
rateLimiter := middleware.NewRateLimiter()
|
||||
|
||||
// Initialize handlers
|
||||
agentHandler := handlers.NewAgentHandler(agentQueries, commandQueries, refreshTokenQueries, cfg.CheckInInterval, cfg.LatestAgentVersion)
|
||||
updateHandler := handlers.NewUpdateHandler(updateQueries, agentQueries, commandQueries)
|
||||
authHandler := handlers.NewAuthHandler(cfg.JWTSecret)
|
||||
updateHandler := handlers.NewUpdateHandler(updateQueries, agentQueries, commandQueries, agentHandler)
|
||||
authHandler := handlers.NewAuthHandler(cfg.Admin.JWTSecret)
|
||||
statsHandler := handlers.NewStatsHandler(agentQueries, updateQueries)
|
||||
settingsHandler := handlers.NewSettingsHandler(timezoneService)
|
||||
dockerHandler := handlers.NewDockerHandler(updateQueries, agentQueries, commandQueries)
|
||||
registrationTokenHandler := handlers.NewRegistrationTokenHandler(registrationTokenQueries, agentQueries, cfg)
|
||||
rateLimitHandler := handlers.NewRateLimitHandler(rateLimiter)
|
||||
|
||||
// Setup router
|
||||
router := gin.Default()
|
||||
@@ -72,24 +116,26 @@ func main() {
|
||||
// API routes
|
||||
api := router.Group("/api/v1")
|
||||
{
|
||||
// Authentication routes
|
||||
api.POST("/auth/login", authHandler.Login)
|
||||
// Authentication routes (with rate limiting)
|
||||
api.POST("/auth/login", rateLimiter.RateLimit("public_access", middleware.KeyByIP), authHandler.Login)
|
||||
api.POST("/auth/logout", authHandler.Logout)
|
||||
api.GET("/auth/verify", authHandler.VerifyToken)
|
||||
|
||||
// Public routes (no authentication required)
|
||||
api.POST("/agents/register", agentHandler.RegisterAgent)
|
||||
api.POST("/agents/renew", agentHandler.RenewToken)
|
||||
// Public routes (no authentication required, with rate limiting)
|
||||
api.POST("/agents/register", rateLimiter.RateLimit("agent_registration", middleware.KeyByIP), agentHandler.RegisterAgent)
|
||||
api.POST("/agents/renew", rateLimiter.RateLimit("public_access", middleware.KeyByIP), agentHandler.RenewToken)
|
||||
|
||||
// Protected agent routes
|
||||
agents := api.Group("/agents")
|
||||
agents.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
agents.GET("/:id/commands", agentHandler.GetCommands)
|
||||
agents.POST("/:id/updates", updateHandler.ReportUpdates)
|
||||
agents.POST("/:id/logs", updateHandler.ReportLog)
|
||||
agents.POST("/:id/dependencies", updateHandler.ReportDependencies)
|
||||
agents.POST("/:id/system-info", agentHandler.ReportSystemInfo)
|
||||
agents.POST("/:id/updates", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportUpdates)
|
||||
agents.POST("/:id/logs", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportLog)
|
||||
agents.POST("/:id/dependencies", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportDependencies)
|
||||
agents.POST("/:id/system-info", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.ReportSystemInfo)
|
||||
agents.POST("/:id/rapid-mode", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.SetRapidPollingMode)
|
||||
agents.DELETE("/:id", agentHandler.UnregisterAgent)
|
||||
}
|
||||
|
||||
// Dashboard/Web routes (protected by web auth)
|
||||
@@ -101,7 +147,8 @@ func main() {
|
||||
dashboard.GET("/agents/:id", agentHandler.GetAgent)
|
||||
dashboard.POST("/agents/:id/scan", agentHandler.TriggerScan)
|
||||
dashboard.POST("/agents/:id/update", agentHandler.TriggerUpdate)
|
||||
dashboard.DELETE("/agents/:id", agentHandler.UnregisterAgent)
|
||||
dashboard.POST("/agents/:id/heartbeat", agentHandler.TriggerHeartbeat)
|
||||
dashboard.GET("/agents/:id/heartbeat", agentHandler.GetHeartbeatStatus)
|
||||
dashboard.GET("/updates", updateHandler.ListUpdates)
|
||||
dashboard.GET("/updates/:id", updateHandler.GetUpdate)
|
||||
dashboard.GET("/updates/:id/logs", updateHandler.GetUpdateLogs)
|
||||
@@ -120,6 +167,7 @@ func main() {
|
||||
dashboard.GET("/commands/recent", updateHandler.GetRecentCommands)
|
||||
dashboard.POST("/commands/:id/retry", updateHandler.RetryCommand)
|
||||
dashboard.POST("/commands/:id/cancel", updateHandler.CancelCommand)
|
||||
dashboard.DELETE("/commands/failed", updateHandler.ClearFailedCommands)
|
||||
|
||||
// Settings routes
|
||||
dashboard.GET("/settings/timezone", settingsHandler.GetTimezone)
|
||||
@@ -132,6 +180,25 @@ func main() {
|
||||
dashboard.POST("/docker/containers/:container_id/images/:image_id/approve", dockerHandler.ApproveUpdate)
|
||||
dashboard.POST("/docker/containers/:container_id/images/:image_id/reject", dockerHandler.RejectUpdate)
|
||||
dashboard.POST("/docker/containers/:container_id/images/:image_id/install", dockerHandler.InstallUpdate)
|
||||
|
||||
// Admin/Registration Token routes (for agent enrollment management)
|
||||
admin := dashboard.Group("/admin")
|
||||
{
|
||||
admin.POST("/registration-tokens", rateLimiter.RateLimit("admin_token_gen", middleware.KeyByUserID), registrationTokenHandler.GenerateRegistrationToken)
|
||||
admin.GET("/registration-tokens", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.ListRegistrationTokens)
|
||||
admin.GET("/registration-tokens/active", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.GetActiveRegistrationTokens)
|
||||
admin.DELETE("/registration-tokens/:token", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.RevokeRegistrationToken)
|
||||
admin.POST("/registration-tokens/cleanup", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.CleanupExpiredTokens)
|
||||
admin.GET("/registration-tokens/stats", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.GetTokenStats)
|
||||
admin.GET("/registration-tokens/validate", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), registrationTokenHandler.ValidateRegistrationToken)
|
||||
|
||||
// Rate Limit Management
|
||||
admin.GET("/rate-limits", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.GetRateLimitSettings)
|
||||
admin.PUT("/rate-limits", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.UpdateRateLimitSettings)
|
||||
admin.POST("/rate-limits/reset", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.ResetRateLimitSettings)
|
||||
admin.GET("/rate-limits/stats", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.GetRateLimitStats)
|
||||
admin.POST("/rate-limits/cleanup", rateLimiter.RateLimit("admin_operations", middleware.KeyByUserID), rateLimitHandler.CleanupRateLimitEntries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,8 +233,11 @@ func main() {
|
||||
}()
|
||||
|
||||
// Start server
|
||||
addr := ":" + cfg.ServerPort
|
||||
fmt.Printf("\n🚩 RedFlag Aggregator Server starting on %s\n\n", addr)
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
|
||||
fmt.Printf("\nRedFlag Aggregator Server starting on %s\n", addr)
|
||||
fmt.Printf("Admin interface: http://%s:%d/admin\n", cfg.Server.Host, cfg.Server.Port)
|
||||
fmt.Printf("Dashboard: http://%s:%d\n\n", cfg.Server.Host, cfg.Server.Port)
|
||||
|
||||
if err := router.Run(addr); err != nil {
|
||||
log.Fatal("Failed to start server:", err)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ require (
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lib/pq v1.10.9
|
||||
golang.org/x/term v0.33.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
@@ -92,6 +92,8 @@ golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
|
||||
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
|
||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -107,15 +108,16 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
|
||||
// Try to parse optional system metrics from request body
|
||||
var metrics struct {
|
||||
CPUPercent float64 `json:"cpu_percent,omitempty"`
|
||||
MemoryPercent float64 `json:"memory_percent,omitempty"`
|
||||
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
|
||||
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
|
||||
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
|
||||
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
|
||||
DiskPercent float64 `json:"disk_percent,omitempty"`
|
||||
Uptime string `json:"uptime,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
CPUPercent float64 `json:"cpu_percent,omitempty"`
|
||||
MemoryPercent float64 `json:"memory_percent,omitempty"`
|
||||
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
|
||||
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
|
||||
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
|
||||
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
|
||||
DiskPercent float64 `json:"disk_percent,omitempty"`
|
||||
Uptime string `json:"uptime,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// Parse metrics if provided (optional, won't fail if empty)
|
||||
@@ -130,21 +132,21 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
|
||||
// Always handle version information if provided
|
||||
if metrics.Version != "" {
|
||||
// Get current agent to preserve existing metadata
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil && agent.Metadata != nil {
|
||||
// Update agent's current version
|
||||
if err := h.agentQueries.UpdateAgentVersion(agentID, metrics.Version); err != nil {
|
||||
log.Printf("Warning: Failed to update agent version: %v", err)
|
||||
} else {
|
||||
// Check if update is available
|
||||
updateAvailable := utils.IsNewerVersion(h.latestAgentVersion, metrics.Version)
|
||||
// Update agent's current version in database (primary source of truth)
|
||||
if err := h.agentQueries.UpdateAgentVersion(agentID, metrics.Version); err != nil {
|
||||
log.Printf("Warning: Failed to update agent version: %v", err)
|
||||
} else {
|
||||
// Check if update is available
|
||||
updateAvailable := utils.IsNewerVersion(h.latestAgentVersion, metrics.Version)
|
||||
|
||||
// Update agent's update availability status
|
||||
if err := h.agentQueries.UpdateAgentUpdateAvailable(agentID, updateAvailable); err != nil {
|
||||
log.Printf("Warning: Failed to update agent update availability: %v", err)
|
||||
}
|
||||
// Update agent's update availability status
|
||||
if err := h.agentQueries.UpdateAgentUpdateAvailable(agentID, updateAvailable); err != nil {
|
||||
log.Printf("Warning: Failed to update agent update availability: %v", err)
|
||||
}
|
||||
|
||||
// Get current agent for logging and metadata update
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil {
|
||||
// Log version check
|
||||
if updateAvailable {
|
||||
log.Printf("🔄 Agent %s (%s) version %s has update available: %s",
|
||||
@@ -154,11 +156,20 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
agent.Hostname, agentID, metrics.Version)
|
||||
}
|
||||
|
||||
// Store version in metadata as well
|
||||
// Store version in metadata as well (for backwards compatibility)
|
||||
// Initialize metadata if nil
|
||||
if agent.Metadata == nil {
|
||||
agent.Metadata = make(models.JSONB)
|
||||
}
|
||||
agent.Metadata["reported_version"] = metrics.Version
|
||||
agent.Metadata["latest_version"] = h.latestAgentVersion
|
||||
agent.Metadata["update_available"] = updateAvailable
|
||||
agent.Metadata["version_checked_at"] = time.Now().Format(time.RFC3339)
|
||||
|
||||
// Update agent metadata
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
log.Printf("Warning: Failed to update agent metadata: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -179,6 +190,29 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
agent.Metadata["uptime"] = metrics.Uptime
|
||||
agent.Metadata["metrics_updated_at"] = time.Now().Format(time.RFC3339)
|
||||
|
||||
// Process heartbeat metadata from agent check-ins
|
||||
if metrics.Metadata != nil {
|
||||
if rapidPollingEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
|
||||
if rapidPollingUntil, exists := metrics.Metadata["rapid_polling_until"]; exists {
|
||||
// Parse the until timestamp
|
||||
if untilTime, err := time.Parse(time.RFC3339, rapidPollingUntil.(string)); err == nil {
|
||||
// Validate if rapid polling is still active (not expired)
|
||||
isActive := rapidPollingEnabled.(bool) && time.Now().Before(untilTime)
|
||||
|
||||
// Store heartbeat status in agent metadata
|
||||
agent.Metadata["rapid_polling_enabled"] = rapidPollingEnabled
|
||||
agent.Metadata["rapid_polling_until"] = rapidPollingUntil
|
||||
agent.Metadata["rapid_polling_active"] = isActive
|
||||
|
||||
log.Printf("[Heartbeat] Agent %s heartbeat status: enabled=%v, until=%v, active=%v",
|
||||
agentID, rapidPollingEnabled, rapidPollingUntil, isActive)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Failed to parse rapid_polling_until timestamp for agent %s: %v", agentID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update agent with new metadata
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
log.Printf("Warning: Failed to update agent metrics: %v", err)
|
||||
@@ -192,6 +226,37 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Process heartbeat metadata from agent check-ins
|
||||
if metrics.Metadata != nil {
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil && agent.Metadata != nil {
|
||||
if rapidPollingEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
|
||||
if rapidPollingUntil, exists := metrics.Metadata["rapid_polling_until"]; exists {
|
||||
// Parse the until timestamp
|
||||
if untilTime, err := time.Parse(time.RFC3339, rapidPollingUntil.(string)); err == nil {
|
||||
// Validate if rapid polling is still active (not expired)
|
||||
isActive := rapidPollingEnabled.(bool) && time.Now().Before(untilTime)
|
||||
|
||||
// Store heartbeat status in agent metadata
|
||||
agent.Metadata["rapid_polling_enabled"] = rapidPollingEnabled
|
||||
agent.Metadata["rapid_polling_until"] = rapidPollingUntil
|
||||
agent.Metadata["rapid_polling_active"] = isActive
|
||||
|
||||
log.Printf("[Heartbeat] Agent %s heartbeat status: enabled=%v, until=%v, active=%v",
|
||||
agentID, rapidPollingEnabled, rapidPollingUntil, isActive)
|
||||
|
||||
// Update agent with new metadata
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to update agent heartbeat metadata: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Failed to parse rapid_polling_until timestamp for agent %s: %v", agentID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for version updates for agents that don't send version in metrics
|
||||
// This ensures agents like Metis that don't report version still get update checks
|
||||
if metrics.Version == "" {
|
||||
@@ -239,8 +304,97 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
|
||||
h.commandQueries.MarkCommandSent(cmd.ID)
|
||||
}
|
||||
|
||||
// Check if rapid polling should be enabled
|
||||
var rapidPolling *models.RapidPollingConfig
|
||||
|
||||
// Enable rapid polling if there are commands to process
|
||||
if len(commandItems) > 0 {
|
||||
rapidPolling = &models.RapidPollingConfig{
|
||||
Enabled: true,
|
||||
Until: time.Now().Add(10 * time.Minute).Format(time.RFC3339), // 10 minutes default
|
||||
}
|
||||
} else {
|
||||
// Check if agent has rapid polling already configured in metadata
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil && agent.Metadata != nil {
|
||||
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
|
||||
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
|
||||
if until, err := time.Parse(time.RFC3339, untilStr); err == nil && time.Now().Before(until) {
|
||||
rapidPolling = &models.RapidPollingConfig{
|
||||
Enabled: true,
|
||||
Until: untilStr,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Detect stale heartbeat state: Server thinks it's active, but agent didn't report it
|
||||
// This happens when agent restarts without heartbeat mode
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err == nil && agent.Metadata != nil {
|
||||
// Check if server metadata shows heartbeat active
|
||||
if serverEnabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && serverEnabled {
|
||||
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
|
||||
if until, err := time.Parse(time.RFC3339, untilStr); err == nil && time.Now().Before(until) {
|
||||
// Server thinks heartbeat is active and not expired
|
||||
// Check if agent is reporting heartbeat in this check-in
|
||||
agentReportingHeartbeat := false
|
||||
if metrics.Metadata != nil {
|
||||
if agentEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
|
||||
agentReportingHeartbeat = agentEnabled.(bool)
|
||||
}
|
||||
}
|
||||
|
||||
// If agent is NOT reporting heartbeat but server expects it → stale state
|
||||
if !agentReportingHeartbeat {
|
||||
log.Printf("[Heartbeat] Stale heartbeat detected for agent %s - server expected active until %s, but agent not reporting heartbeat (likely restarted)",
|
||||
agentID, until.Format(time.RFC3339))
|
||||
|
||||
// Clear stale heartbeat state
|
||||
agent.Metadata["rapid_polling_enabled"] = false
|
||||
delete(agent.Metadata, "rapid_polling_until")
|
||||
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to clear stale heartbeat state: %v", err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Cleared stale heartbeat state for agent %s", agentID)
|
||||
|
||||
// Create audit command to show in history
|
||||
now := time.Now()
|
||||
auditCmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
CommandType: models.CommandTypeDisableHeartbeat,
|
||||
Params: models.JSONB{},
|
||||
Status: models.CommandStatusCompleted,
|
||||
Result: models.JSONB{
|
||||
"message": "Heartbeat cleared - agent restarted without active heartbeat mode",
|
||||
},
|
||||
CreatedAt: now,
|
||||
SentAt: &now,
|
||||
CompletedAt: &now,
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(auditCmd); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to create audit command for stale heartbeat: %v", err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Created audit trail for stale heartbeat cleanup (agent %s)", agentID)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear rapidPolling response since we just disabled it
|
||||
rapidPolling = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response := models.CommandsResponse{
|
||||
Commands: commandItems,
|
||||
Commands: commandItems,
|
||||
RapidPolling: rapidPolling,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
@@ -312,6 +466,124 @@ func (h *AgentHandler) TriggerScan(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "scan triggered", "command_id": cmd.ID})
|
||||
}
|
||||
|
||||
// TriggerHeartbeat creates a heartbeat toggle command for an agent
|
||||
func (h *AgentHandler) TriggerHeartbeat(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
agentID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
DurationMinutes int `json:"duration_minutes"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Determine command type based on enabled flag
|
||||
commandType := models.CommandTypeDisableHeartbeat
|
||||
if request.Enabled {
|
||||
commandType = models.CommandTypeEnableHeartbeat
|
||||
}
|
||||
|
||||
// Create heartbeat command with duration parameter
|
||||
cmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
CommandType: commandType,
|
||||
Params: models.JSONB{
|
||||
"duration_minutes": request.DurationMinutes,
|
||||
},
|
||||
Status: models.CommandStatusPending,
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(cmd); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create heartbeat command"})
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Clean up previous heartbeat commands for this agent (only for enable commands)
|
||||
// if request.Enabled {
|
||||
// // Mark previous heartbeat commands as 'replaced' to clean up Live Operations view
|
||||
// if err := h.commandQueries.MarkPreviousHeartbeatCommandsReplaced(agentID, cmd.ID); err != nil {
|
||||
// log.Printf("Warning: Failed to mark previous heartbeat commands as replaced: %v", err)
|
||||
// // Don't fail the request, just log the warning
|
||||
// } else {
|
||||
// log.Printf("[Heartbeat] Cleaned up previous heartbeat commands for agent %s", agentID)
|
||||
// }
|
||||
// }
|
||||
|
||||
action := "disabled"
|
||||
if request.Enabled {
|
||||
action = "enabled"
|
||||
}
|
||||
|
||||
log.Printf("💓 Heartbeat %s command created for agent %s (duration: %d minutes)",
|
||||
action, agentID, request.DurationMinutes)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("heartbeat %s command sent", action),
|
||||
"command_id": cmd.ID,
|
||||
"enabled": request.Enabled,
|
||||
})
|
||||
}
|
||||
|
||||
// GetHeartbeatStatus returns the current heartbeat status for an agent
|
||||
func (h *AgentHandler) GetHeartbeatStatus(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
agentID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get agent and their heartbeat metadata
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract heartbeat information from metadata
|
||||
response := gin.H{
|
||||
"enabled": false,
|
||||
"until": nil,
|
||||
"active": false,
|
||||
"duration_minutes": 0,
|
||||
}
|
||||
|
||||
if agent.Metadata != nil {
|
||||
// Check if heartbeat is enabled in metadata
|
||||
if enabled, exists := agent.Metadata["rapid_polling_enabled"]; exists {
|
||||
response["enabled"] = enabled.(bool)
|
||||
|
||||
// If enabled, get the until time and check if still active
|
||||
if enabled.(bool) {
|
||||
if untilStr, exists := agent.Metadata["rapid_polling_until"]; exists {
|
||||
response["until"] = untilStr.(string)
|
||||
|
||||
// Parse the until timestamp to check if still active
|
||||
if untilTime, err := time.Parse(time.RFC3339, untilStr.(string)); err == nil {
|
||||
response["active"] = time.Now().Before(untilTime)
|
||||
}
|
||||
}
|
||||
|
||||
// Get duration if available
|
||||
if duration, exists := agent.Metadata["rapid_polling_duration_minutes"]; exists {
|
||||
response["duration_minutes"] = duration.(float64)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// TriggerUpdate creates an update command for an agent
|
||||
func (h *AgentHandler) TriggerUpdate(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
@@ -541,3 +813,114 @@ func (h *AgentHandler) ReportSystemInfo(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "system info updated successfully"})
|
||||
}
|
||||
|
||||
// EnableRapidPollingMode enables rapid polling for an agent by updating metadata
|
||||
func (h *AgentHandler) EnableRapidPollingMode(agentID uuid.UUID, durationMinutes int) error {
|
||||
// Get current agent
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get agent: %w", err)
|
||||
}
|
||||
|
||||
// Calculate new rapid polling end time
|
||||
newRapidPollingUntil := time.Now().Add(time.Duration(durationMinutes) * time.Minute)
|
||||
|
||||
// Update agent metadata with rapid polling settings
|
||||
if agent.Metadata == nil {
|
||||
agent.Metadata = models.JSONB{}
|
||||
}
|
||||
|
||||
// Check if rapid polling is already active
|
||||
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
|
||||
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
|
||||
if currentUntil, err := time.Parse(time.RFC3339, untilStr); err == nil {
|
||||
// If current heartbeat expires later than the new duration, keep the longer duration
|
||||
if currentUntil.After(newRapidPollingUntil) {
|
||||
log.Printf("💓 Heartbeat already active for agent %s (%s), keeping longer duration (expires: %s)",
|
||||
agent.Hostname, agentID, currentUntil.Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
// Otherwise extend the heartbeat
|
||||
log.Printf("💓 Extending heartbeat for agent %s (%s) from %s to %s",
|
||||
agent.Hostname, agentID,
|
||||
currentUntil.Format(time.RFC3339),
|
||||
newRapidPollingUntil.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Printf("💓 Enabling heartbeat mode for agent %s (%s) for %d minutes",
|
||||
agent.Hostname, agentID, durationMinutes)
|
||||
}
|
||||
|
||||
// Set/update rapid polling settings
|
||||
agent.Metadata["rapid_polling_enabled"] = true
|
||||
agent.Metadata["rapid_polling_until"] = newRapidPollingUntil.Format(time.RFC3339)
|
||||
|
||||
// Update agent in database
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
return fmt.Errorf("failed to update agent with rapid polling: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRapidPollingMode enables rapid polling mode for an agent
|
||||
// TODO: Rate limiting should be implemented for rapid polling endpoints to prevent abuse (technical debt)
|
||||
func (h *AgentHandler) SetRapidPollingMode(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
agentID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if agent exists
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
DurationMinutes int `json:"duration_minutes" binding:"required,min=1,max=60"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate rapid polling end time
|
||||
rapidPollingUntil := time.Now().Add(time.Duration(req.DurationMinutes) * time.Minute)
|
||||
|
||||
// Update agent metadata with rapid polling settings
|
||||
if agent.Metadata == nil {
|
||||
agent.Metadata = models.JSONB{}
|
||||
}
|
||||
agent.Metadata["rapid_polling_enabled"] = req.Enabled
|
||||
agent.Metadata["rapid_polling_until"] = rapidPollingUntil.Format(time.RFC3339)
|
||||
|
||||
// Update agent in database
|
||||
if err := h.agentQueries.UpdateAgent(agent); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update agent"})
|
||||
return
|
||||
}
|
||||
|
||||
status := "disabled"
|
||||
duration := 0
|
||||
if req.Enabled {
|
||||
status = "enabled"
|
||||
duration = req.DurationMinutes
|
||||
}
|
||||
|
||||
log.Printf("🚀 Rapid polling mode %s for agent %s (%s) for %d minutes",
|
||||
status, agent.Hostname, agentID, duration)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("Rapid polling mode %s", status),
|
||||
"enabled": req.Enabled,
|
||||
"duration_minutes": req.DurationMinutes,
|
||||
"rapid_polling_until": rapidPollingUntil,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -378,7 +378,7 @@ func (h *DockerHandler) RejectUpdate(c *gin.Context) {
|
||||
}
|
||||
|
||||
// For now, we'll mark as rejected (this would need a proper reject method in queries)
|
||||
if err := h.updateQueries.UpdatePackageStatus(update.AgentID, "docker", update.PackageName, "rejected", nil); err != nil {
|
||||
if err := h.updateQueries.UpdatePackageStatus(update.AgentID, "docker", update.PackageName, "rejected", nil, nil); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to reject Docker update"})
|
||||
return
|
||||
}
|
||||
|
||||
146
aggregator-server/internal/api/handlers/rate_limits.go
Normal file
146
aggregator-server/internal/api/handlers/rate_limits.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/aggregator-project/aggregator-server/internal/api/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimitHandler exposes admin HTTP endpoints for inspecting and tuning the
// in-memory rate limiter.
type RateLimitHandler struct {
	rateLimiter *middleware.RateLimiter // shared limiter; settings updates take effect immediately
}
|
||||
|
||||
func NewRateLimitHandler(rateLimiter *middleware.RateLimiter) *RateLimitHandler {
|
||||
return &RateLimitHandler{
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
}
|
||||
|
||||
// GetRateLimitSettings returns current rate limit configuration
|
||||
func (h *RateLimitHandler) GetRateLimitSettings(c *gin.Context) {
|
||||
settings := h.rateLimiter.GetSettings()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"settings": settings,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateLimitSettings updates rate limit configuration
|
||||
func (h *RateLimitHandler) UpdateRateLimitSettings(c *gin.Context) {
|
||||
var settings middleware.RateLimitSettings
|
||||
if err := c.ShouldBindJSON(&settings); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate settings
|
||||
if err := h.validateRateLimitSettings(settings); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Update rate limiter settings
|
||||
h.rateLimiter.UpdateSettings(settings)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Rate limit settings updated successfully",
|
||||
"settings": settings,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// ResetRateLimitSettings resets to default values
|
||||
func (h *RateLimitHandler) ResetRateLimitSettings(c *gin.Context) {
|
||||
defaultSettings := middleware.DefaultRateLimitSettings()
|
||||
h.rateLimiter.UpdateSettings(defaultSettings)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Rate limit settings reset to defaults",
|
||||
"settings": defaultSettings,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetRateLimitStats returns current rate limit statistics
|
||||
func (h *RateLimitHandler) GetRateLimitStats(c *gin.Context) {
|
||||
settings := h.rateLimiter.GetSettings()
|
||||
|
||||
// Calculate total requests and windows
|
||||
stats := gin.H{
|
||||
"total_configured_limits": 6,
|
||||
"enabled_limits": 0,
|
||||
"total_requests_per_minute": 0,
|
||||
"settings": settings,
|
||||
}
|
||||
|
||||
// Count enabled limits and total requests
|
||||
for _, config := range []middleware.RateLimitConfig{
|
||||
settings.AgentRegistration,
|
||||
settings.AgentCheckIn,
|
||||
settings.AgentReports,
|
||||
settings.AdminTokenGen,
|
||||
settings.AdminOperations,
|
||||
settings.PublicAccess,
|
||||
} {
|
||||
if config.Enabled {
|
||||
stats["enabled_limits"] = stats["enabled_limits"].(int) + 1
|
||||
}
|
||||
stats["total_requests_per_minute"] = stats["total_requests_per_minute"].(int) + config.Requests
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// CleanupRateLimitEntries manually triggers cleanup of expired entries
|
||||
func (h *RateLimitHandler) CleanupRateLimitEntries(c *gin.Context) {
|
||||
h.rateLimiter.CleanupExpiredEntries()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Rate limit entries cleanup completed",
|
||||
"timestamp": time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// validateRateLimitSettings validates the provided rate limit settings
|
||||
func (h *RateLimitHandler) validateRateLimitSettings(settings middleware.RateLimitSettings) error {
|
||||
// Validate each configuration
|
||||
validations := []struct {
|
||||
name string
|
||||
config middleware.RateLimitConfig
|
||||
}{
|
||||
{"agent_registration", settings.AgentRegistration},
|
||||
{"agent_checkin", settings.AgentCheckIn},
|
||||
{"agent_reports", settings.AgentReports},
|
||||
{"admin_token_generation", settings.AdminTokenGen},
|
||||
{"admin_operations", settings.AdminOperations},
|
||||
{"public_access", settings.PublicAccess},
|
||||
}
|
||||
|
||||
for _, validation := range validations {
|
||||
if validation.config.Requests <= 0 {
|
||||
return fmt.Errorf("%s: requests must be greater than 0", validation.name)
|
||||
}
|
||||
if validation.config.Window <= 0 {
|
||||
return fmt.Errorf("%s: window must be greater than 0", validation.name)
|
||||
}
|
||||
if validation.config.Window > 24*time.Hour {
|
||||
return fmt.Errorf("%s: window cannot exceed 24 hours", validation.name)
|
||||
}
|
||||
if validation.config.Requests > 1000 {
|
||||
return fmt.Errorf("%s: requests cannot exceed 1000 per window", validation.name)
|
||||
}
|
||||
}
|
||||
|
||||
// Specific validations for different endpoint types
|
||||
if settings.AgentRegistration.Requests > 10 {
|
||||
return fmt.Errorf("agent_registration: requests should not exceed 10 per minute for security")
|
||||
}
|
||||
if settings.PublicAccess.Requests > 50 {
|
||||
return fmt.Errorf("public_access: requests should not exceed 50 per minute for security")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
284
aggregator-server/internal/api/handlers/registration_tokens.go
Normal file
284
aggregator-server/internal/api/handlers/registration_tokens.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/aggregator-project/aggregator-server/internal/config"
|
||||
"github.com/aggregator-project/aggregator-server/internal/database/queries"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegistrationTokenHandler exposes admin endpoints for creating, listing,
// validating, and revoking agent registration tokens.
type RegistrationTokenHandler struct {
	tokenQueries *queries.RegistrationTokenQueries // token persistence and lookups
	agentQueries *queries.AgentQueries             // used for seat-limit checks
	config       *config.Config                    // registration limits and token defaults
}
|
||||
|
||||
func NewRegistrationTokenHandler(tokenQueries *queries.RegistrationTokenQueries, agentQueries *queries.AgentQueries, config *config.Config) *RegistrationTokenHandler {
|
||||
return &RegistrationTokenHandler{
|
||||
tokenQueries: tokenQueries,
|
||||
agentQueries: agentQueries,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRegistrationToken creates a new registration token
|
||||
func (h *RegistrationTokenHandler) GenerateRegistrationToken(c *gin.Context) {
|
||||
var request struct {
|
||||
Label string `json:"label" binding:"required"`
|
||||
ExpiresIn string `json:"expires_in"` // e.g., "24h", "7d", "168h"
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Check agent seat limit (security, not licensing)
|
||||
activeAgents, err := h.agentQueries.GetActiveAgentCount()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check agent count"})
|
||||
return
|
||||
}
|
||||
|
||||
if activeAgents >= h.config.AgentRegistration.MaxSeats {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "Maximum agent seats reached",
|
||||
"limit": h.config.AgentRegistration.MaxSeats,
|
||||
"current": activeAgents,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse expiration duration
|
||||
expiresIn := request.ExpiresIn
|
||||
if expiresIn == "" {
|
||||
expiresIn = h.config.AgentRegistration.TokenExpiry
|
||||
}
|
||||
|
||||
duration, err := time.ParseDuration(expiresIn)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid expiration format. Use formats like '24h', '7d', '168h'"})
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(duration)
|
||||
if duration > 168*time.Hour { // Max 7 days
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Token expiration cannot exceed 7 days"})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate secure token
|
||||
token, err := config.GenerateSecureToken()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create metadata with default values
|
||||
metadata := request.Metadata
|
||||
if metadata == nil {
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
metadata["server_url"] = c.Request.Host
|
||||
metadata["expires_in"] = expiresIn
|
||||
|
||||
// Store token in database
|
||||
err = h.tokenQueries.CreateRegistrationToken(token, request.Label, expiresAt, metadata)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Build install command
|
||||
serverURL := c.Request.Host
|
||||
if serverURL == "" {
|
||||
serverURL = "localhost:8080" // Fallback for development
|
||||
}
|
||||
installCommand := "curl -sfL https://" + serverURL + "/install | bash -s -- " + token
|
||||
|
||||
response := gin.H{
|
||||
"token": token,
|
||||
"label": request.Label,
|
||||
"expires_at": expiresAt,
|
||||
"install_command": installCommand,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// ListRegistrationTokens returns all registration tokens with pagination
|
||||
func (h *RegistrationTokenHandler) ListRegistrationTokens(c *gin.Context) {
|
||||
// Parse pagination parameters
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50"))
|
||||
status := c.Query("status")
|
||||
|
||||
// Validate pagination
|
||||
if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
offset := (page - 1) * limit
|
||||
|
||||
var tokens []queries.RegistrationToken
|
||||
var err error
|
||||
|
||||
if status != "" {
|
||||
// TODO: Add filtered queries by status
|
||||
tokens, err = h.tokenQueries.GetAllRegistrationTokens(limit, offset)
|
||||
} else {
|
||||
tokens, err = h.tokenQueries.GetAllRegistrationTokens(limit, offset)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list tokens"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get token usage stats
|
||||
stats, err := h.tokenQueries.GetTokenUsageStats()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get token stats"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"tokens": tokens,
|
||||
"pagination": gin.H{
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
},
|
||||
"stats": stats,
|
||||
"seat_usage": gin.H{
|
||||
"current": func() int {
|
||||
count, _ := h.agentQueries.GetActiveAgentCount()
|
||||
return count
|
||||
}(),
|
||||
"max": h.config.AgentRegistration.MaxSeats,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetActiveRegistrationTokens returns only active tokens
|
||||
func (h *RegistrationTokenHandler) GetActiveRegistrationTokens(c *gin.Context) {
|
||||
tokens, err := h.tokenQueries.GetActiveRegistrationTokens()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get active tokens"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"tokens": tokens})
|
||||
}
|
||||
|
||||
// RevokeRegistrationToken revokes a registration token
|
||||
func (h *RegistrationTokenHandler) RevokeRegistrationToken(c *gin.Context) {
|
||||
token := c.Param("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
c.ShouldBindJSON(&request) // Reason is optional
|
||||
|
||||
reason := request.Reason
|
||||
if reason == "" {
|
||||
reason = "Revoked via API"
|
||||
}
|
||||
|
||||
err := h.tokenQueries.RevokeRegistrationToken(token, reason)
|
||||
if err != nil {
|
||||
if err.Error() == "token not found or already used/revoked" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Token not found or already used/revoked"})
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to revoke token"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Token revoked successfully"})
|
||||
}
|
||||
|
||||
// ValidateRegistrationToken checks if a token is valid (for testing/debugging)
|
||||
func (h *RegistrationTokenHandler) ValidateRegistrationToken(c *gin.Context) {
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Token query parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.tokenQueries.ValidateRegistrationToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"valid": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"valid": true,
|
||||
"token": tokenInfo,
|
||||
})
|
||||
}
|
||||
|
||||
// CleanupExpiredTokens performs cleanup of expired tokens
|
||||
func (h *RegistrationTokenHandler) CleanupExpiredTokens(c *gin.Context) {
|
||||
count, err := h.tokenQueries.CleanupExpiredTokens()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to cleanup expired tokens"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Cleanup completed",
|
||||
"cleaned": count,
|
||||
})
|
||||
}
|
||||
|
||||
// GetTokenStats returns comprehensive token usage statistics
|
||||
func (h *RegistrationTokenHandler) GetTokenStats(c *gin.Context) {
|
||||
// Get token stats
|
||||
tokenStats, err := h.tokenQueries.GetTokenUsageStats()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get token stats"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get agent count
|
||||
activeAgentCount, err := h.agentQueries.GetActiveAgentCount()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get agent count"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"token_stats": tokenStats,
|
||||
"agent_usage": gin.H{
|
||||
"active_agents": activeAgentCount,
|
||||
"max_seats": h.config.AgentRegistration.MaxSeats,
|
||||
"available": h.config.AgentRegistration.MaxSeats - activeAgentCount,
|
||||
},
|
||||
"security_limits": gin.H{
|
||||
"max_tokens_per_request": h.config.AgentRegistration.MaxTokens,
|
||||
"max_token_duration": "7 days",
|
||||
"token_expiry_default": h.config.AgentRegistration.TokenExpiry,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -16,16 +17,42 @@ type UpdateHandler struct {
|
||||
updateQueries *queries.UpdateQueries
|
||||
agentQueries *queries.AgentQueries
|
||||
commandQueries *queries.CommandQueries
|
||||
agentHandler *AgentHandler
|
||||
}
|
||||
|
||||
func NewUpdateHandler(uq *queries.UpdateQueries, aq *queries.AgentQueries, cq *queries.CommandQueries) *UpdateHandler {
|
||||
func NewUpdateHandler(uq *queries.UpdateQueries, aq *queries.AgentQueries, cq *queries.CommandQueries, ah *AgentHandler) *UpdateHandler {
|
||||
return &UpdateHandler{
|
||||
updateQueries: uq,
|
||||
agentQueries: aq,
|
||||
commandQueries: cq,
|
||||
agentHandler: ah,
|
||||
}
|
||||
}
|
||||
|
||||
// shouldEnableHeartbeat checks if heartbeat is already active for an agent
|
||||
// Returns true if heartbeat should be enabled (i.e., not already active or expired)
|
||||
func (h *UpdateHandler) shouldEnableHeartbeat(agentID uuid.UUID, durationMinutes int) (bool, error) {
|
||||
agent, err := h.agentQueries.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to get agent %s for heartbeat check: %v", agentID, err)
|
||||
return true, nil // Enable heartbeat by default if we can't check
|
||||
}
|
||||
|
||||
// Check if rapid polling is already enabled and not expired
|
||||
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
|
||||
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
|
||||
until, err := time.Parse(time.RFC3339, untilStr)
|
||||
if err == nil && until.After(time.Now().Add(5*time.Minute)) {
|
||||
// Heartbeat is already active for sufficient time
|
||||
log.Printf("[Heartbeat] Agent %s already has active heartbeat until %s (skipping)", agentID, untilStr)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ReportUpdates handles update reports from agents using event sourcing
|
||||
func (h *UpdateHandler) ReportUpdates(c *gin.Context) {
|
||||
agentID := c.MustGet("agent_id").(uuid.UUID)
|
||||
@@ -172,7 +199,7 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
log := &models.UpdateLog{
|
||||
logEntry := &models.UpdateLog{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
Action: req.Action,
|
||||
@@ -185,7 +212,7 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Store the log entry
|
||||
if err := h.updateQueries.CreateUpdateLog(log); err != nil {
|
||||
if err := h.updateQueries.CreateUpdateLog(logEntry); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save log"})
|
||||
return
|
||||
}
|
||||
@@ -207,10 +234,34 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Update command status based on log result
|
||||
if req.Result == "success" {
|
||||
if req.Result == "success" || req.Result == "completed" {
|
||||
if err := h.commandQueries.MarkCommandCompleted(commandID, result); err != nil {
|
||||
fmt.Printf("Warning: Failed to mark command %s as completed: %v\n", commandID, err)
|
||||
}
|
||||
|
||||
// NEW: If this was a successful confirm_dependencies command, mark the package as updated
|
||||
command, err := h.commandQueries.GetCommandByID(commandID)
|
||||
if err == nil && command.CommandType == models.CommandTypeConfirmDependencies {
|
||||
// Extract package info from command params
|
||||
if packageName, ok := command.Params["package_name"].(string); ok {
|
||||
if packageType, ok := command.Params["package_type"].(string); ok {
|
||||
// Extract actual completion timestamp from command result for accurate audit trail
|
||||
var completionTime *time.Time
|
||||
if loggedAtStr, ok := command.Result["logged_at"].(string); ok {
|
||||
if parsed, err := time.Parse(time.RFC3339Nano, loggedAtStr); err == nil {
|
||||
completionTime = &parsed
|
||||
}
|
||||
}
|
||||
|
||||
// Update package status to 'updated' with actual completion timestamp
|
||||
if err := h.updateQueries.UpdatePackageStatus(agentID, packageType, packageName, "updated", nil, completionTime); err != nil {
|
||||
log.Printf("Warning: Failed to update package status for %s/%s: %v", packageType, packageName, err)
|
||||
} else {
|
||||
log.Printf("✅ Package %s (%s) marked as updated after successful installation", packageName, packageType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if req.Result == "failed" || req.Result == "dry_run_failed" {
|
||||
if err := h.commandQueries.MarkCommandFailed(commandID, result); err != nil {
|
||||
fmt.Printf("Warning: Failed to mark command %s as failed: %v\n", commandID, err)
|
||||
@@ -304,7 +355,7 @@ func (h *UpdateHandler) UpdatePackageStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.updateQueries.UpdatePackageStatus(agentID, req.PackageType, req.PackageName, req.Status, req.Metadata); err != nil {
|
||||
if err := h.updateQueries.UpdatePackageStatus(agentID, req.PackageType, req.PackageName, req.Status, req.Metadata, nil); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update package status"})
|
||||
return
|
||||
}
|
||||
@@ -395,7 +446,29 @@ func (h *UpdateHandler) InstallUpdate(c *gin.Context) {
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Store the command in database
|
||||
// Check if heartbeat should be enabled (avoid duplicates)
|
||||
if shouldEnable, err := h.shouldEnableHeartbeat(update.AgentID, 10); err == nil && shouldEnable {
|
||||
heartbeatCmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: update.AgentID,
|
||||
CommandType: models.CommandTypeEnableHeartbeat,
|
||||
Params: models.JSONB{
|
||||
"duration_minutes": 10,
|
||||
},
|
||||
Status: models.CommandStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", update.AgentID, err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Command created for agent %s before dry run", update.AgentID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", update.AgentID)
|
||||
}
|
||||
|
||||
// Store the dry run command in database
|
||||
if err := h.commandQueries.CreateCommand(command); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create dry run command"})
|
||||
return
|
||||
@@ -478,6 +551,28 @@ func (h *UpdateHandler) ReportDependencies(c *gin.Context) {
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Check if heartbeat should be enabled (avoid duplicates)
|
||||
if shouldEnable, err := h.shouldEnableHeartbeat(agentID, 10); err == nil && shouldEnable {
|
||||
heartbeatCmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
CommandType: models.CommandTypeEnableHeartbeat,
|
||||
Params: models.JSONB{
|
||||
"duration_minutes": 10,
|
||||
},
|
||||
Status: models.CommandStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", agentID, err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Command created for agent %s before installation", agentID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", agentID)
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(command); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create installation command"})
|
||||
return
|
||||
@@ -536,6 +631,28 @@ func (h *UpdateHandler) ConfirmDependencies(c *gin.Context) {
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Check if heartbeat should be enabled (avoid duplicates)
|
||||
if shouldEnable, err := h.shouldEnableHeartbeat(update.AgentID, 10); err == nil && shouldEnable {
|
||||
heartbeatCmd := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: update.AgentID,
|
||||
CommandType: models.CommandTypeEnableHeartbeat,
|
||||
Params: models.JSONB{
|
||||
"duration_minutes": 10,
|
||||
},
|
||||
Status: models.CommandStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
|
||||
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", update.AgentID, err)
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Command created for agent %s before confirm dependencies", update.AgentID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", update.AgentID)
|
||||
}
|
||||
|
||||
// Store the command in database
|
||||
if err := h.commandQueries.CreateCommand(command); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create confirmation command"})
|
||||
@@ -684,3 +801,60 @@ func (h *UpdateHandler) GetRecentCommands(c *gin.Context) {
|
||||
"limit": limit,
|
||||
})
|
||||
}
|
||||
|
||||
// ClearFailedCommands manually removes failed/timed_out commands with cheeky warning
|
||||
func (h *UpdateHandler) ClearFailedCommands(c *gin.Context) {
|
||||
// Get query parameters for filtering
|
||||
olderThanDaysStr := c.Query("older_than_days")
|
||||
onlyRetriedStr := c.Query("only_retried")
|
||||
allFailedStr := c.Query("all_failed")
|
||||
|
||||
var count int64
|
||||
var err error
|
||||
|
||||
// Parse parameters
|
||||
olderThanDays := 7 // default
|
||||
if olderThanDaysStr != "" {
|
||||
if days, err := strconv.Atoi(olderThanDaysStr); err == nil && days > 0 {
|
||||
olderThanDays = days
|
||||
}
|
||||
}
|
||||
|
||||
onlyRetried := onlyRetriedStr == "true"
|
||||
allFailed := allFailedStr == "true"
|
||||
|
||||
// Build the appropriate cleanup query based on parameters
|
||||
if allFailed {
|
||||
// Clear ALL failed commands (most aggressive)
|
||||
count, err = h.commandQueries.ClearAllFailedCommands(olderThanDays)
|
||||
} else if onlyRetried {
|
||||
// Clear only failed commands that have been retried
|
||||
count, err = h.commandQueries.ClearRetriedFailedCommands(olderThanDays)
|
||||
} else {
|
||||
// Clear failed commands older than specified days (default behavior)
|
||||
count, err = h.commandQueries.ClearOldFailedCommands(olderThanDays)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "failed to clear failed commands",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Return success with cheeky message
|
||||
message := fmt.Sprintf("Archived %d failed commands", count)
|
||||
if count > 0 {
|
||||
message += ". WARNING: This shouldn't be necessary if the retry logic is working properly - you might want to check what's causing commands to fail in the first place!"
|
||||
message += " (History preserved - commands moved to archived status)"
|
||||
} else {
|
||||
message += ". No failed commands found matching your criteria. SUCCESS!"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": message,
|
||||
"count": count,
|
||||
"cheeky_warning": "Consider this a developer experience enhancement - the system should clean up after itself automatically!",
|
||||
})
|
||||
}
|
||||
|
||||
279
aggregator-server/internal/api/middleware/rate_limiter.go
Normal file
279
aggregator-server/internal/api/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimitConfig holds configuration for rate limiting
type RateLimitConfig struct {
	Requests int           `json:"requests"` // maximum requests allowed per Window
	Window   time.Duration `json:"window"`   // NOTE(review): JSON-bound as an int64 of nanoseconds — confirm clients send nanosecond values
	Enabled  bool          `json:"enabled"`  // when false, this limit is skipped entirely
}
|
||||
|
||||
// RateLimitEntry tracks requests for a specific key
type RateLimitEntry struct {
	Requests []time.Time  // timestamps of requests observed for this key
	mutex    sync.RWMutex // guards Requests; entries are shared across request goroutines
}
|
||||
|
||||
// RateLimiter implements in-memory rate limiting with user-configurable settings
type RateLimiter struct {
	entries sync.Map // map[string]*RateLimitEntry
	configs map[string]RateLimitConfig // keyed by limit type (e.g. "agent_checkin")
	mutex   sync.RWMutex // guards configs; each entry carries its own lock
}
|
||||
|
||||
// RateLimitSettings holds all user-configurable rate limit settings,
// one RateLimitConfig per endpoint class.
type RateLimitSettings struct {
	AgentRegistration RateLimitConfig `json:"agent_registration"`
	AgentCheckIn      RateLimitConfig `json:"agent_checkin"`
	AgentReports      RateLimitConfig `json:"agent_reports"`
	AdminTokenGen     RateLimitConfig `json:"admin_token_generation"`
	AdminOperations   RateLimitConfig `json:"admin_operations"`
	PublicAccess      RateLimitConfig `json:"public_access"`
}
|
||||
|
||||
// DefaultRateLimitSettings provides sensible defaults
|
||||
func DefaultRateLimitSettings() RateLimitSettings {
|
||||
return RateLimitSettings{
|
||||
AgentRegistration: RateLimitConfig{
|
||||
Requests: 5,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AgentCheckIn: RateLimitConfig{
|
||||
Requests: 60,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AgentReports: RateLimitConfig{
|
||||
Requests: 30,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AdminTokenGen: RateLimitConfig{
|
||||
Requests: 10,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AdminOperations: RateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
PublicAccess: RateLimitConfig{
|
||||
Requests: 20,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter with default settings
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
entries: sync.Map{},
|
||||
}
|
||||
|
||||
// Load default settings
|
||||
defaults := DefaultRateLimitSettings()
|
||||
rl.UpdateSettings(defaults)
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// UpdateSettings updates rate limit configurations
|
||||
func (rl *RateLimiter) UpdateSettings(settings RateLimitSettings) {
|
||||
rl.mutex.Lock()
|
||||
defer rl.mutex.Unlock()
|
||||
|
||||
rl.configs = map[string]RateLimitConfig{
|
||||
"agent_registration": settings.AgentRegistration,
|
||||
"agent_checkin": settings.AgentCheckIn,
|
||||
"agent_reports": settings.AgentReports,
|
||||
"admin_token_gen": settings.AdminTokenGen,
|
||||
"admin_operations": settings.AdminOperations,
|
||||
"public_access": settings.PublicAccess,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSettings returns current rate limit settings
|
||||
func (rl *RateLimiter) GetSettings() RateLimitSettings {
|
||||
rl.mutex.RLock()
|
||||
defer rl.mutex.RUnlock()
|
||||
|
||||
return RateLimitSettings{
|
||||
AgentRegistration: rl.configs["agent_registration"],
|
||||
AgentCheckIn: rl.configs["agent_checkin"],
|
||||
AgentReports: rl.configs["agent_reports"],
|
||||
AdminTokenGen: rl.configs["admin_token_gen"],
|
||||
AdminOperations: rl.configs["admin_operations"],
|
||||
PublicAccess: rl.configs["public_access"],
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit creates middleware for a specific rate limit type
|
||||
func (rl *RateLimiter) RateLimit(limitType string, keyFunc func(*gin.Context) string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
rl.mutex.RLock()
|
||||
config, exists := rl.configs[limitType]
|
||||
rl.mutex.RUnlock()
|
||||
|
||||
if !exists || !config.Enabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
key := keyFunc(c)
|
||||
if key == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
allowed, resetTime := rl.checkRateLimit(key, config)
|
||||
if !allowed {
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
|
||||
c.Header("X-RateLimit-Remaining", "0")
|
||||
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
|
||||
c.Header("Retry-After", fmt.Sprintf("%d", int(resetTime.Sub(time.Now()).Seconds())))
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Rate limit exceeded",
|
||||
"limit": config.Requests,
|
||||
"window": config.Window.String(),
|
||||
"reset_time": resetTime,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Add rate limit headers
|
||||
remaining := rl.getRemainingRequests(key, config)
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
|
||||
c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
||||
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(config.Window).Unix()))
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// checkRateLimit checks if the request is allowed
|
||||
func (rl *RateLimiter) checkRateLimit(key string, config RateLimitConfig) (bool, time.Time) {
|
||||
now := time.Now()
|
||||
|
||||
// Get or create entry
|
||||
entryInterface, _ := rl.entries.LoadOrStore(key, &RateLimitEntry{
|
||||
Requests: []time.Time{},
|
||||
})
|
||||
entry := entryInterface.(*RateLimitEntry)
|
||||
|
||||
entry.mutex.Lock()
|
||||
defer entry.mutex.Unlock()
|
||||
|
||||
// Clean old requests outside the window
|
||||
cutoff := now.Add(-config.Window)
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, reqTime := range entry.Requests {
|
||||
if reqTime.After(cutoff) {
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if under limit
|
||||
if len(validRequests) >= config.Requests {
|
||||
// Find when the oldest request expires
|
||||
oldestRequest := validRequests[0]
|
||||
resetTime := oldestRequest.Add(config.Window)
|
||||
return false, resetTime
|
||||
}
|
||||
|
||||
// Add current request
|
||||
entry.Requests = append(validRequests, now)
|
||||
|
||||
// Clean up expired entries periodically
|
||||
if len(entry.Requests) == 0 {
|
||||
rl.entries.Delete(key)
|
||||
}
|
||||
|
||||
return true, time.Time{}
|
||||
}
|
||||
|
||||
// getRemainingRequests calculates remaining requests for the key
|
||||
func (rl *RateLimiter) getRemainingRequests(key string, config RateLimitConfig) int {
|
||||
entryInterface, ok := rl.entries.Load(key)
|
||||
if !ok {
|
||||
return config.Requests
|
||||
}
|
||||
|
||||
entry := entryInterface.(*RateLimitEntry)
|
||||
entry.mutex.RLock()
|
||||
defer entry.mutex.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-config.Window)
|
||||
count := 0
|
||||
|
||||
for _, reqTime := range entry.Requests {
|
||||
if reqTime.After(cutoff) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
remaining := config.Requests - count
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
return remaining
|
||||
}
|
||||
|
||||
// CleanupExpiredEntries removes expired entries to prevent memory leaks
|
||||
func (rl *RateLimiter) CleanupExpiredEntries() {
|
||||
rl.entries.Range(func(key, value interface{}) bool {
|
||||
entry := value.(*RateLimitEntry)
|
||||
entry.mutex.Lock()
|
||||
|
||||
now := time.Now()
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, reqTime := range entry.Requests {
|
||||
if reqTime.After(now.Add(-time.Hour)) { // Keep requests from last hour
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validRequests) == 0 {
|
||||
rl.entries.Delete(key)
|
||||
} else {
|
||||
entry.Requests = validRequests
|
||||
}
|
||||
|
||||
entry.mutex.Unlock()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Key generation functions: each derives the rate-limit bucket key for a
// request; RateLimit skips limiting entirely when the key is empty.

// KeyByIP buckets requests by client IP address.
func KeyByIP(c *gin.Context) string {
	return c.ClientIP()
}

// KeyByAgentID buckets requests by the ":id" URL path parameter.
// Routes without that parameter yield "" and are therefore not limited.
func KeyByAgentID(c *gin.Context) string {
	return c.Param("id")
}

// KeyByUserID buckets requests by user identity.
func KeyByUserID(c *gin.Context) string {
	// This would extract user ID from JWT or session
	// For now, use IP as fallback
	return c.ClientIP()
}

// KeyByIPAndPath buckets requests by client IP plus request path, giving each
// endpoint its own budget per client.
func KeyByIPAndPath(c *gin.Context) string {
	return c.ClientIP() + ":" + c.Request.URL.Path
}
|
||||
@@ -1,18 +1,47 @@
|
||||
package config
|
||||
|
||||
import (
	"bufio"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/joho/godotenv"
	"golang.org/x/term"
)
|
||||
|
||||
// Config holds the application configuration
|
||||
type Config struct {
|
||||
ServerPort string
|
||||
DatabaseURL string
|
||||
JWTSecret string
|
||||
Server struct {
|
||||
Host string `env:"REDFLAG_SERVER_HOST" default:"0.0.0.0"`
|
||||
Port int `env:"REDFLAG_SERVER_PORT" default:"8080"`
|
||||
TLS struct {
|
||||
Enabled bool `env:"REDFLAG_TLS_ENABLED" default:"false"`
|
||||
CertFile string `env:"REDFLAG_TLS_CERT_FILE"`
|
||||
KeyFile string `env:"REDFLAG_TLS_KEY_FILE"`
|
||||
}
|
||||
}
|
||||
Database struct {
|
||||
Host string `env:"REDFLAG_DB_HOST" default:"localhost"`
|
||||
Port int `env:"REDFLAG_DB_PORT" default:"5432"`
|
||||
Database string `env:"REDFLAG_DB_NAME" default:"redflag"`
|
||||
Username string `env:"REDFLAG_DB_USER" default:"redflag"`
|
||||
Password string `env:"REDFLAG_DB_PASSWORD"`
|
||||
}
|
||||
Admin struct {
|
||||
Username string `env:"REDFLAG_ADMIN_USER" default:"admin"`
|
||||
Password string `env:"REDFLAG_ADMIN_PASSWORD"`
|
||||
JWTSecret string `env:"REDFLAG_JWT_SECRET"`
|
||||
}
|
||||
AgentRegistration struct {
|
||||
TokenExpiry string `env:"REDFLAG_TOKEN_EXPIRY" default:"24h"`
|
||||
MaxTokens int `env:"REDFLAG_MAX_TOKENS" default:"100"`
|
||||
MaxSeats int `env:"REDFLAG_MAX_SEATS" default:"50"`
|
||||
}
|
||||
CheckInInterval int
|
||||
OfflineThreshold int
|
||||
Timezone string
|
||||
@@ -24,30 +53,195 @@ func Load() (*Config, error) {
|
||||
// Load .env file if it exists (for development)
|
||||
_ = godotenv.Load()
|
||||
|
||||
cfg := &Config{}
|
||||
|
||||
// Parse server configuration
|
||||
cfg.Server.Host = getEnv("REDFLAG_SERVER_HOST", "0.0.0.0")
|
||||
serverPort, _ := strconv.Atoi(getEnv("REDFLAG_SERVER_PORT", "8080"))
|
||||
cfg.Server.Port = serverPort
|
||||
cfg.Server.TLS.Enabled = getEnv("REDFLAG_TLS_ENABLED", "false") == "true"
|
||||
cfg.Server.TLS.CertFile = getEnv("REDFLAG_TLS_CERT_FILE", "")
|
||||
cfg.Server.TLS.KeyFile = getEnv("REDFLAG_TLS_KEY_FILE", "")
|
||||
|
||||
// Parse database configuration
|
||||
cfg.Database.Host = getEnv("REDFLAG_DB_HOST", "localhost")
|
||||
dbPort, _ := strconv.Atoi(getEnv("REDFLAG_DB_PORT", "5432"))
|
||||
cfg.Database.Port = dbPort
|
||||
cfg.Database.Database = getEnv("REDFLAG_DB_NAME", "redflag")
|
||||
cfg.Database.Username = getEnv("REDFLAG_DB_USER", "redflag")
|
||||
cfg.Database.Password = getEnv("REDFLAG_DB_PASSWORD", "")
|
||||
|
||||
// Parse admin configuration
|
||||
cfg.Admin.Username = getEnv("REDFLAG_ADMIN_USER", "admin")
|
||||
cfg.Admin.Password = getEnv("REDFLAG_ADMIN_PASSWORD", "")
|
||||
cfg.Admin.JWTSecret = getEnv("REDFLAG_JWT_SECRET", "")
|
||||
|
||||
// Parse agent registration configuration
|
||||
cfg.AgentRegistration.TokenExpiry = getEnv("REDFLAG_TOKEN_EXPIRY", "24h")
|
||||
maxTokens, _ := strconv.Atoi(getEnv("REDFLAG_MAX_TOKENS", "100"))
|
||||
cfg.AgentRegistration.MaxTokens = maxTokens
|
||||
maxSeats, _ := strconv.Atoi(getEnv("REDFLAG_MAX_SEATS", "50"))
|
||||
cfg.AgentRegistration.MaxSeats = maxSeats
|
||||
|
||||
// Parse legacy configuration for backwards compatibility
|
||||
checkInInterval, _ := strconv.Atoi(getEnv("CHECK_IN_INTERVAL", "300"))
|
||||
offlineThreshold, _ := strconv.Atoi(getEnv("OFFLINE_THRESHOLD", "600"))
|
||||
cfg.CheckInInterval = checkInInterval
|
||||
cfg.OfflineThreshold = offlineThreshold
|
||||
cfg.Timezone = getEnv("TIMEZONE", "UTC")
|
||||
cfg.LatestAgentVersion = getEnv("LATEST_AGENT_VERSION", "0.1.16")
|
||||
|
||||
cfg := &Config{
|
||||
ServerPort: getEnv("SERVER_PORT", "8080"),
|
||||
DatabaseURL: getEnv("DATABASE_URL", "postgres://aggregator:aggregator@localhost:5432/aggregator?sslmode=disable"),
|
||||
JWTSecret: getEnv("JWT_SECRET", "test-secret-for-development-only"),
|
||||
CheckInInterval: checkInInterval,
|
||||
OfflineThreshold: offlineThreshold,
|
||||
Timezone: getEnv("TIMEZONE", "UTC"),
|
||||
LatestAgentVersion: getEnv("LATEST_AGENT_VERSION", "0.1.4"),
|
||||
// Handle missing secrets
|
||||
if cfg.Admin.Password == "" || cfg.Admin.JWTSecret == "" || cfg.Database.Password == "" {
|
||||
fmt.Printf("[WARNING] Missing required configuration (admin password, JWT secret, or database password)\n")
|
||||
fmt.Printf("[INFO] Run: ./redflag-server --setup to configure\n")
|
||||
return nil, fmt.Errorf("missing required configuration")
|
||||
}
|
||||
|
||||
// Debug: Log what JWT secret we're using (remove in production)
|
||||
if cfg.JWTSecret == "test-secret-for-development-only" {
|
||||
fmt.Printf("🔓 Using development JWT secret\n")
|
||||
// Validate JWT secret is not the development default
|
||||
if cfg.Admin.JWTSecret == "test-secret-for-development-only" {
|
||||
fmt.Printf("[SECURITY WARNING] Using development JWT secret\n")
|
||||
fmt.Printf("[INFO] Run: ./redflag-server --setup to configure production secrets\n")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// RunSetupWizard guides the user through initial configuration: it prompts
// for admin credentials, database connection details, server bind settings,
// and agent seat limits, then writes a .env file (mode 0600) containing both
// the new REDFLAG_* variables and the legacy variable names for backwards
// compatibility. Returns an error only when the .env file cannot be written.
func RunSetupWizard() error {
	fmt.Printf("RedFlag Server Setup Wizard\n")
	fmt.Printf("===========================\n\n")

	// Admin credentials
	fmt.Printf("Admin Account Setup\n")
	fmt.Printf("--------------------\n")
	username := promptForInput("Admin username", "admin")
	password := promptForPassword("Admin password")

	// Database configuration
	fmt.Printf("\nDatabase Configuration\n")
	fmt.Printf("----------------------\n")
	dbHost := promptForInput("Database host", "localhost")
	dbPort, _ := strconv.Atoi(promptForInput("Database port", "5432"))
	dbName := promptForInput("Database name", "redflag")
	dbUser := promptForInput("Database user", "redflag")
	dbPassword := promptForPassword("Database password")

	// Server configuration
	fmt.Printf("\nServer Configuration\n")
	fmt.Printf("--------------------\n")
	serverHost := promptForInput("Server bind address", "0.0.0.0")
	serverPort, _ := strconv.Atoi(promptForInput("Server port", "8080"))

	// Agent limits
	fmt.Printf("\nAgent Registration\n")
	fmt.Printf("------------------\n")
	maxSeats, _ := strconv.Atoi(promptForInput("Maximum agent seats (security limit)", "50"))

	// Generate JWT secret from admin password.
	// NOTE(review): the secret is deterministically derived from the admin
	// credentials (see deriveJWTSecret); rotating the password rotates the
	// secret and invalidates existing tokens — confirm that is intended.
	jwtSecret := deriveJWTSecret(username, password)

	// Create .env file. The admin and database passwords are stored in
	// plaintext here; the 0600 file mode below is the only protection.
	envContent := fmt.Sprintf(`# RedFlag Server Configuration
# Generated on %s

# Server Configuration
REDFLAG_SERVER_HOST=%s
REDFLAG_SERVER_PORT=%d
REDFLAG_TLS_ENABLED=false
# REDFLAG_TLS_CERT_FILE=
# REDFLAG_TLS_KEY_FILE=

# Database Configuration
REDFLAG_DB_HOST=%s
REDFLAG_DB_PORT=%d
REDFLAG_DB_NAME=%s
REDFLAG_DB_USER=%s
REDFLAG_DB_PASSWORD=%s

# Admin Configuration
REDFLAG_ADMIN_USER=%s
REDFLAG_ADMIN_PASSWORD=%s
REDFLAG_JWT_SECRET=%s

# Agent Registration
REDFLAG_TOKEN_EXPIRY=24h
REDFLAG_MAX_TOKENS=100
REDFLAG_MAX_SEATS=%d

# Legacy Configuration (for backwards compatibility)
SERVER_PORT=%d
DATABASE_URL=postgres://%s:%s@%s:%d/%s?sslmode=disable
JWT_SECRET=%s
CHECK_IN_INTERVAL=300
OFFLINE_THRESHOLD=600
TIMEZONE=UTC
LATEST_AGENT_VERSION=0.1.8
`, time.Now().Format("2006-01-02 15:04:05"), serverHost, serverPort,
		dbHost, dbPort, dbName, dbUser, dbPassword,
		username, password, jwtSecret, maxSeats,
		serverPort, dbUser, dbPassword, dbHost, dbPort, dbName, jwtSecret)

	// Write .env file
	if err := os.WriteFile(".env", []byte(envContent), 0600); err != nil {
		return fmt.Errorf("failed to write .env file: %w", err)
	}

	fmt.Printf("\n[OK] Configuration saved to .env file\n")
	fmt.Printf("[SECURITY] File permissions set to 0600 (owner read/write only)\n")
	fmt.Printf("\nNext steps:\n")
	fmt.Printf(" 1. Start database: %s:%d\n", dbHost, dbPort)
	fmt.Printf(" 2. Create database: CREATE DATABASE %s;\n", dbName)
	fmt.Printf(" 3. Run migrations: ./redflag-server --migrate\n")
	fmt.Printf(" 4. Start server: ./redflag-server\n")
	fmt.Printf("\nServer will be available at: http://%s:%d\n", serverHost, serverPort)
	fmt.Printf("Admin interface: http://%s:%d/admin\n", serverHost, serverPort)

	return nil
}
|
||||
|
||||
// getEnv returns the value of the environment variable key, falling back to
// defaultValue when the variable is unset or empty.
func getEnv(key, defaultValue string) string {
	if v, ok := os.LookupEnv(key); ok && v != "" {
		return v
	}
	return defaultValue
}
|
||||
|
||||
// promptForInput prints "prompt [defaultValue]: " and reads one line from
// stdin. It returns defaultValue when the user enters nothing (or input
// fails), otherwise the whitespace-trimmed input.
//
// Fix: the previous implementation used fmt.Scanln, which stops at the first
// space — a value containing spaces was truncated and the remainder leaked
// into the next prompt. Reading the whole line handles spaces correctly.
// (A fresh reader per call is fine for interactive terminal input, which is
// line-buffered; piped multi-line input could be over-buffered and dropped.)
func promptForInput(prompt, defaultValue string) string {
	fmt.Printf("%s [%s]: ", prompt, defaultValue)
	line, err := bufio.NewReader(os.Stdin).ReadString('\n')
	if err != nil && line == "" {
		return defaultValue
	}
	trimmed := strings.TrimSpace(line)
	if trimmed == "" {
		return defaultValue
	}
	return trimmed
}
|
||||
|
||||
// promptForPassword prints the prompt and reads a password from stdin with
// terminal echo disabled via golang.org/x/term. If stdin is not a terminal
// (or ReadPassword fails), it falls back to plain, visible input.
// The returned value is whitespace-trimmed.
func promptForPassword(prompt string) string {
	fmt.Printf("%s: ", prompt)
	password, err := term.ReadPassword(int(os.Stdin.Fd()))
	if err != nil {
		// Fallback to non-hidden input
		// NOTE(review): fmt.Scanln stops at the first space, so a password
		// containing spaces is truncated on this fallback path — confirm
		// whether that is acceptable.
		var input string
		fmt.Scanln(&input)
		return strings.TrimSpace(input)
	}
	// ReadPassword suppresses the user's newline; emit one for clean output.
	fmt.Printf("\n")
	return strings.TrimSpace(string(password))
}
|
||||
|
||||
func deriveJWTSecret(username, password string) string {
|
||||
// Derive JWT secret from admin credentials
|
||||
// This ensures JWT secret changes if admin password changes
|
||||
hash := sha256.Sum256([]byte(username + password + "redflag-jwt-2024"))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// GenerateSecureToken generates a cryptographically secure random token
|
||||
func GenerateSecureToken() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate secure token: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
-- Add retry tracking to agent_commands table
-- This allows us to track command retry chains and display retry indicators in the UI

-- Add retried_from_id column to link retries to their original commands.
-- ON DELETE SET NULL: deleting the original command keeps the retry but
-- severs the link, so retries never block deletion.
ALTER TABLE agent_commands
ADD COLUMN retried_from_id UUID REFERENCES agent_commands(id) ON DELETE SET NULL;

-- Partial index for efficient retry chain lookups (covers only rows that
-- actually are retries, keeping the index small).
CREATE INDEX idx_commands_retried_from ON agent_commands(retried_from_id) WHERE retried_from_id IS NOT NULL;
|
||||
@@ -0,0 +1,9 @@
|
||||
-- Add 'archived_failed' status to agent_commands status constraint
-- This allows archiving failed/timed_out commands to clean up the active list

-- Drop the existing constraint (IF EXISTS keeps the migration safe to re-run
-- against databases where the constraint was already removed)
ALTER TABLE agent_commands DROP CONSTRAINT IF EXISTS agent_commands_status_check;

-- Add the new constraint with 'archived_failed' included
ALTER TABLE agent_commands ADD CONSTRAINT agent_commands_status_check
CHECK (status IN ('pending', 'sent', 'running', 'completed', 'failed', 'timed_out', 'cancelled', 'archived_failed'));
|
||||
@@ -0,0 +1,85 @@
|
||||
-- Registration tokens for secure agent enrollment
-- Tokens are one-time use and have configurable expiration

CREATE TABLE registration_tokens (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    token VARCHAR(64) UNIQUE NOT NULL, -- One-time use token
    label VARCHAR(255), -- Optional label for token identification
    expires_at TIMESTAMP NOT NULL, -- Token expiration time
    created_at TIMESTAMP DEFAULT NOW(), -- When token was created
    used_at TIMESTAMP NULL, -- When token was used (NULL if unused)
    used_by_agent_id UUID NULL, -- Which agent used this token (FK added below)
    revoked BOOLEAN DEFAULT FALSE, -- Manual revocation
    revoked_at TIMESTAMP NULL, -- When token was revoked
    revoked_reason VARCHAR(255) NULL, -- Reason for revocation

    -- Token status tracking
    -- NOTE(review): status duplicates information already carried by
    -- revoked/used_at/expires_at; every writer must keep the two
    -- representations in sync.
    status VARCHAR(20) DEFAULT 'active'
        CHECK (status IN ('active', 'used', 'expired', 'revoked')),

    -- Additional metadata
    created_by VARCHAR(100) DEFAULT 'setup_wizard', -- Who created the token
    metadata JSONB DEFAULT '{}'::jsonb -- Additional token metadata
);

-- Indexes for performance
CREATE INDEX idx_registration_tokens_token ON registration_tokens(token);
CREATE INDEX idx_registration_tokens_expires_at ON registration_tokens(expires_at);
CREATE INDEX idx_registration_tokens_status ON registration_tokens(status);
CREATE INDEX idx_registration_tokens_used_by_agent ON registration_tokens(used_by_agent_id) WHERE used_by_agent_id IS NOT NULL;

-- Foreign key constraint for used_by_agent_id (added separately so the table
-- can be created even if migration ordering relative to agents changes)
ALTER TABLE registration_tokens
ADD CONSTRAINT fk_registration_tokens_agent
FOREIGN KEY (used_by_agent_id) REFERENCES agents(id) ON DELETE SET NULL;
|
||||
|
||||
-- Function to clean up expired tokens (called by periodic cleanup job).
-- Flips still-active, past-expiry, unused tokens to status 'expired' and
-- returns how many rows were updated (the name deleted_count is historical;
-- rows are updated, not deleted).
-- NOTE(review): this sets used_at = NOW() on tokens that were never used;
-- confirm used_at is intended to double as the expiry-processing timestamp,
-- otherwise drop that assignment (used_at is documented as "when used").
CREATE OR REPLACE FUNCTION cleanup_expired_registration_tokens()
RETURNS INTEGER AS $$
DECLARE
    deleted_count INTEGER;
BEGIN
    UPDATE registration_tokens
    SET status = 'expired',
        used_at = NOW()
    WHERE status = 'active'
    AND expires_at < NOW()
    AND used_at IS NULL;

    GET DIAGNOSTICS deleted_count = ROW_COUNT;
    RETURN deleted_count;
END;
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Function to check if a token is valid: TRUE only when the token exists,
-- is still in 'active' status, and has not yet expired. Unknown tokens
-- return FALSE (same as the original's COALESCE(..., FALSE) behavior).
CREATE OR REPLACE FUNCTION is_registration_token_valid(token_input VARCHAR)
RETURNS BOOLEAN AS $$
BEGIN
    RETURN EXISTS (
        SELECT 1
        FROM registration_tokens
        WHERE token = token_input
          AND status = 'active'
          AND expires_at > NOW()
    );
END;
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Function to mark token as used (one-time consumption).
-- Atomically flips an active, unexpired token to 'used', stamping when and by
-- which agent. Returns TRUE when a token was consumed, FALSE otherwise.
--
-- Fix: GET DIAGNOSTICS ... = ROW_COUNT yields an integer, but the original
-- declared the target variable BOOLEAN and then evaluated "boolean > 0";
-- neither is valid plpgsql, so the function failed at runtime. The variable
-- is now an INTEGER compared against zero.
CREATE OR REPLACE FUNCTION mark_registration_token_used(token_input VARCHAR, agent_id UUID)
RETURNS BOOLEAN AS $$
DECLARE
    updated_count INTEGER;
BEGIN
    UPDATE registration_tokens
    SET status = 'used',
        used_at = NOW(),
        used_by_agent_id = agent_id
    WHERE token = token_input
    AND status = 'active'
    AND expires_at > NOW();

    GET DIAGNOSTICS updated_count = ROW_COUNT;
    RETURN updated_count > 0;
END;
$$ LANGUAGE plpgsql;
|
||||
@@ -196,3 +196,11 @@ func (q *AgentQueries) DeleteAgent(id uuid.UUID) error {
|
||||
// Commit the transaction
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetActiveAgentCount returns the count of active (online) agents, i.e. rows
// in agents whose status column equals 'online'. On query failure the count
// is zero-valued and the database error is returned.
func (q *AgentQueries) GetActiveAgentCount() (int, error) {
	var count int
	query := `SELECT COUNT(*) FROM agents WHERE status = 'online'`
	err := q.db.Get(&count, query)
	return count, err
}
|
||||
|
||||
@@ -21,9 +21,9 @@ func NewCommandQueries(db *sqlx.DB) *CommandQueries {
|
||||
func (q *CommandQueries) CreateCommand(cmd *models.AgentCommand) error {
|
||||
query := `
|
||||
INSERT INTO agent_commands (
|
||||
id, agent_id, command_type, params, status
|
||||
id, agent_id, command_type, params, status, retried_from_id
|
||||
) VALUES (
|
||||
:id, :agent_id, :command_type, :params, :status
|
||||
:id, :agent_id, :command_type, :params, :status, :retried_from_id
|
||||
)
|
||||
`
|
||||
_, err := q.db.NamedExec(query, cmd)
|
||||
@@ -152,14 +152,15 @@ func (q *CommandQueries) RetryCommand(originalID uuid.UUID) (*models.AgentComman
|
||||
return nil, fmt.Errorf("command must be failed, timed_out, or cancelled to retry")
|
||||
}
|
||||
|
||||
// Create new command with same parameters
|
||||
// Create new command with same parameters, linking it to the original
|
||||
newCommand := &models.AgentCommand{
|
||||
ID: uuid.New(),
|
||||
AgentID: original.AgentID,
|
||||
CommandType: original.CommandType,
|
||||
Params: original.Params,
|
||||
Status: models.CommandStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
ID: uuid.New(),
|
||||
AgentID: original.AgentID,
|
||||
CommandType: original.CommandType,
|
||||
Params: original.Params,
|
||||
Status: models.CommandStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
RetriedFromID: &originalID,
|
||||
}
|
||||
|
||||
// Store the new command
|
||||
@@ -180,20 +181,44 @@ func (q *CommandQueries) GetActiveCommands() ([]models.ActiveCommandInfo, error)
|
||||
c.id,
|
||||
c.agent_id,
|
||||
c.command_type,
|
||||
c.params,
|
||||
c.status,
|
||||
c.created_at,
|
||||
c.sent_at,
|
||||
c.result,
|
||||
c.retried_from_id,
|
||||
a.hostname as agent_hostname,
|
||||
COALESCE(ups.package_name, 'N/A') as package_name,
|
||||
COALESCE(ups.package_type, 'N/A') as package_type
|
||||
COALESCE(ups.package_type, 'N/A') as package_type,
|
||||
(c.retried_from_id IS NOT NULL) as is_retry,
|
||||
EXISTS(SELECT 1 FROM agent_commands WHERE retried_from_id = c.id) as has_been_retried,
|
||||
COALESCE((
|
||||
WITH RECURSIVE retry_chain AS (
|
||||
SELECT id, retried_from_id, 1 as depth
|
||||
FROM agent_commands
|
||||
WHERE id = c.id
|
||||
UNION ALL
|
||||
SELECT ac.id, ac.retried_from_id, rc.depth + 1
|
||||
FROM agent_commands ac
|
||||
JOIN retry_chain rc ON ac.id = rc.retried_from_id
|
||||
)
|
||||
SELECT MAX(depth) FROM retry_chain
|
||||
), 1) - 1 as retry_count
|
||||
FROM agent_commands c
|
||||
LEFT JOIN agents a ON c.agent_id = a.id
|
||||
LEFT JOIN current_package_state ups ON (
|
||||
c.params->>'update_id' = ups.id::text OR
|
||||
(c.params->>'package_name' = ups.package_name AND c.params->>'package_type' = ups.package_type)
|
||||
)
|
||||
WHERE c.status NOT IN ('completed', 'cancelled')
|
||||
WHERE c.status NOT IN ('completed', 'cancelled', 'archived_failed')
|
||||
AND NOT (
|
||||
c.status IN ('failed', 'timed_out')
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM agent_commands retry
|
||||
WHERE retry.retried_from_id = c.id
|
||||
AND retry.status = 'completed'
|
||||
)
|
||||
)
|
||||
ORDER BY c.created_at DESC
|
||||
`
|
||||
|
||||
@@ -223,9 +248,24 @@ func (q *CommandQueries) GetRecentCommands(limit int) ([]models.ActiveCommandInf
|
||||
c.sent_at,
|
||||
c.completed_at,
|
||||
c.result,
|
||||
c.retried_from_id,
|
||||
a.hostname as agent_hostname,
|
||||
COALESCE(ups.package_name, 'N/A') as package_name,
|
||||
COALESCE(ups.package_type, 'N/A') as package_type
|
||||
COALESCE(ups.package_type, 'N/A') as package_type,
|
||||
(c.retried_from_id IS NOT NULL) as is_retry,
|
||||
EXISTS(SELECT 1 FROM agent_commands WHERE retried_from_id = c.id) as has_been_retried,
|
||||
COALESCE((
|
||||
WITH RECURSIVE retry_chain AS (
|
||||
SELECT id, retried_from_id, 1 as depth
|
||||
FROM agent_commands
|
||||
WHERE id = c.id
|
||||
UNION ALL
|
||||
SELECT ac.id, ac.retried_from_id, rc.depth + 1
|
||||
FROM agent_commands ac
|
||||
JOIN retry_chain rc ON ac.id = rc.retried_from_id
|
||||
)
|
||||
SELECT MAX(depth) FROM retry_chain
|
||||
), 1) - 1 as retry_count
|
||||
FROM agent_commands c
|
||||
LEFT JOIN agents a ON c.agent_id = a.id
|
||||
LEFT JOIN current_package_state ups ON (
|
||||
@@ -243,3 +283,55 @@ func (q *CommandQueries) GetRecentCommands(limit int) ([]models.ActiveCommandInf
|
||||
|
||||
return commands, nil
|
||||
}
|
||||
|
||||
// ClearOldFailedCommands archives failed commands older than specified days by changing status to 'archived_failed'
|
||||
func (q *CommandQueries) ClearOldFailedCommands(days int) (int64, error) {
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE agent_commands
|
||||
SET status = 'archived_failed'
|
||||
WHERE status IN ('failed', 'timed_out')
|
||||
AND created_at < NOW() - INTERVAL '%d days'
|
||||
`, days)
|
||||
|
||||
result, err := q.db.Exec(query)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to archive old failed commands: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// ClearRetriedFailedCommands archives failed commands that have been retried and are older than specified days
|
||||
func (q *CommandQueries) ClearRetriedFailedCommands(days int) (int64, error) {
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE agent_commands
|
||||
SET status = 'archived_failed'
|
||||
WHERE status IN ('failed', 'timed_out')
|
||||
AND EXISTS (SELECT 1 FROM agent_commands WHERE retried_from_id = agent_commands.id)
|
||||
AND created_at < NOW() - INTERVAL '%d days'
|
||||
`, days)
|
||||
|
||||
result, err := q.db.Exec(query)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to archive retried failed commands: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// ClearAllFailedCommands archives all failed commands older than specified days (most aggressive)
|
||||
func (q *CommandQueries) ClearAllFailedCommands(days int) (int64, error) {
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE agent_commands
|
||||
SET status = 'archived_failed'
|
||||
WHERE status IN ('failed', 'timed_out')
|
||||
AND created_at < NOW() - INTERVAL '%d days'
|
||||
`, days)
|
||||
|
||||
result, err := q.db.Exec(query)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to archive all failed commands: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
package queries
|
||||
|
||||
import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/jmoiron/sqlx"
)
||||
|
||||
// RegistrationTokenQueries bundles database access for registration tokens.
type RegistrationTokenQueries struct {
	db *sqlx.DB
}

// RegistrationToken mirrors one row of the registration_tokens table.
// Pointer fields correspond to nullable columns.
type RegistrationToken struct {
	ID            uuid.UUID  `json:"id" db:"id"`
	Token         string     `json:"token" db:"token"`
	Label         *string    `json:"label" db:"label"`
	ExpiresAt     time.Time  `json:"expires_at" db:"expires_at"`
	CreatedAt     time.Time  `json:"created_at" db:"created_at"`
	UsedAt        *time.Time `json:"used_at" db:"used_at"`
	UsedByAgentID *uuid.UUID `json:"used_by_agent_id" db:"used_by_agent_id"`
	Revoked       bool       `json:"revoked" db:"revoked"`
	RevokedAt     *time.Time `json:"revoked_at" db:"revoked_at"`
	RevokedReason *string    `json:"revoked_reason" db:"revoked_reason"`
	Status        string     `json:"status" db:"status"`
	CreatedBy     string     `json:"created_by" db:"created_by"`
	// NOTE(review): database/sql cannot scan a JSONB column into a
	// map[string]interface{} without a custom Scanner — confirm that the
	// Get/Select calls reading this field actually work, or switch the field
	// to json.RawMessage / sqlx types.JSONText.
	Metadata map[string]interface{} `json:"metadata" db:"metadata"`
}

// TokenRequest is the payload accepted when an admin creates a token.
type TokenRequest struct {
	Label     string                 `json:"label"`
	ExpiresIn string                 `json:"expires_in"` // e.g., "24h", "7d"
	Metadata  map[string]interface{} `json:"metadata"`
}

// TokenResponse is returned to the admin UI after a token is created.
type TokenResponse struct {
	Token          string    `json:"token"`
	Label          string    `json:"label"`
	ExpiresAt      time.Time `json:"expires_at"`
	InstallCommand string    `json:"install_command"`
}

// NewRegistrationTokenQueries constructs the query helper around db.
func NewRegistrationTokenQueries(db *sqlx.DB) *RegistrationTokenQueries {
	return &RegistrationTokenQueries{db: db}
}
|
||||
|
||||
// CreateRegistrationToken creates a new one-time use registration token
|
||||
func (q *RegistrationTokenQueries) CreateRegistrationToken(token, label string, expiresAt time.Time, metadata map[string]interface{}) error {
|
||||
metadataJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO registration_tokens (token, label, expires_at, metadata)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
|
||||
_, err = q.db.Exec(query, token, label, expiresAt, metadataJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create registration token: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRegistrationToken checks if a token is valid and unused
|
||||
func (q *RegistrationTokenQueries) ValidateRegistrationToken(token string) (*RegistrationToken, error) {
|
||||
var regToken RegistrationToken
|
||||
query := `
|
||||
SELECT id, token, label, expires_at, created_at, used_at, used_by_agent_id,
|
||||
revoked, revoked_at, revoked_reason, status, created_by, metadata
|
||||
FROM registration_tokens
|
||||
WHERE token = $1 AND status = 'active' AND expires_at > NOW()
|
||||
`
|
||||
|
||||
err := q.db.Get(®Token, query, token)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("invalid or expired token")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to validate token: %w", err)
|
||||
}
|
||||
|
||||
return ®Token, nil
|
||||
}
|
||||
|
||||
// MarkTokenUsed marks a token as used by an agent. The UPDATE's WHERE clause
// only matches tokens that are still active and unexpired, so consumption is
// atomic: a token that was already used, revoked, or expired affects zero
// rows and yields an error instead of being double-spent.
func (q *RegistrationTokenQueries) MarkTokenUsed(token string, agentID uuid.UUID) error {
	query := `
		UPDATE registration_tokens
		SET status = 'used',
		    used_at = NOW(),
		    used_by_agent_id = $1
		WHERE token = $2 AND status = 'active' AND expires_at > NOW()
	`

	result, err := q.db.Exec(query, agentID, token)
	if err != nil {
		return fmt.Errorf("failed to mark token as used: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}

	// Zero affected rows means the token did not match the guarded WHERE
	// clause (unknown, consumed, revoked, or expired).
	if rowsAffected == 0 {
		return fmt.Errorf("token not found or already used")
	}

	return nil
}
|
||||
|
||||
// GetActiveRegistrationTokens returns all active tokens
|
||||
func (q *RegistrationTokenQueries) GetActiveRegistrationTokens() ([]RegistrationToken, error) {
|
||||
var tokens []RegistrationToken
|
||||
query := `
|
||||
SELECT id, token, label, expires_at, created_at, used_at, used_by_agent_id,
|
||||
revoked, revoked_at, revoked_reason, status, created_by, metadata
|
||||
FROM registration_tokens
|
||||
WHERE status = 'active'
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
err := q.db.Select(&tokens, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active tokens: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// GetAllRegistrationTokens returns all tokens with pagination
|
||||
func (q *RegistrationTokenQueries) GetAllRegistrationTokens(limit, offset int) ([]RegistrationToken, error) {
|
||||
var tokens []RegistrationToken
|
||||
query := `
|
||||
SELECT id, token, label, expires_at, created_at, used_at, used_by_agent_id,
|
||||
revoked, revoked_at, revoked_reason, status, created_by, metadata
|
||||
FROM registration_tokens
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
`
|
||||
|
||||
err := q.db.Select(&tokens, query, limit, offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get all tokens: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// RevokeRegistrationToken revokes a token, recording the time and the given
// reason. Only tokens still in 'active' status can be revoked; revoking an
// unknown, used, expired, or already-revoked token affects zero rows and
// yields an error.
func (q *RegistrationTokenQueries) RevokeRegistrationToken(token, reason string) error {
	query := `
		UPDATE registration_tokens
		SET status = 'revoked',
		    revoked = true,
		    revoked_at = NOW(),
		    revoked_reason = $1
		WHERE token = $2 AND status = 'active'
	`

	result, err := q.db.Exec(query, reason, token)
	if err != nil {
		return fmt.Errorf("failed to revoke token: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}

	if rowsAffected == 0 {
		return fmt.Errorf("token not found or already used/revoked")
	}

	return nil
}
|
||||
|
||||
// CleanupExpiredTokens marks expired tokens as expired
|
||||
func (q *RegistrationTokenQueries) CleanupExpiredTokens() (int, error) {
|
||||
query := `
|
||||
UPDATE registration_tokens
|
||||
SET status = 'expired',
|
||||
used_at = NOW()
|
||||
WHERE status = 'active' AND expires_at < NOW() AND used_at IS NULL
|
||||
`
|
||||
|
||||
result, err := q.db.Exec(query)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to cleanup expired tokens: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// GetTokenUsageStats returns statistics about token usage
|
||||
func (q *RegistrationTokenQueries) GetTokenUsageStats() (map[string]int, error) {
|
||||
stats := make(map[string]int)
|
||||
|
||||
query := `
|
||||
SELECT status, COUNT(*) as count
|
||||
FROM registration_tokens
|
||||
GROUP BY status
|
||||
`
|
||||
|
||||
rows, err := q.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token stats: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var status string
|
||||
var count int
|
||||
if err := rows.Scan(&status, &count); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan token stats row: %w", err)
|
||||
}
|
||||
stats[status] = count
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
@@ -527,7 +527,8 @@ func (q *UpdateQueries) GetPackageHistory(agentID uuid.UUID, packageType, packag
|
||||
}
|
||||
|
||||
// UpdatePackageStatus updates the status of a package and records history
|
||||
func (q *UpdateQueries) UpdatePackageStatus(agentID uuid.UUID, packageType, packageName, status string, metadata models.JSONB) error {
|
||||
// completedAt is optional - if nil, uses time.Now(). Pass actual completion time for accurate audit trails.
|
||||
func (q *UpdateQueries) UpdatePackageStatus(agentID uuid.UUID, packageType, packageName, status string, metadata models.JSONB, completedAt *time.Time) error {
|
||||
tx, err := q.db.Beginx()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
@@ -542,13 +543,19 @@ func (q *UpdateQueries) UpdatePackageStatus(agentID uuid.UUID, packageType, pack
|
||||
return fmt.Errorf("failed to get current state: %w", err)
|
||||
}
|
||||
|
||||
// Use provided timestamp or fall back to server time
|
||||
timestamp := time.Now()
|
||||
if completedAt != nil {
|
||||
timestamp = *completedAt
|
||||
}
|
||||
|
||||
// Update status
|
||||
updateQuery := `
|
||||
UPDATE current_package_state
|
||||
SET status = $1, last_updated_at = $2
|
||||
WHERE agent_id = $3 AND package_type = $4 AND package_name = $5
|
||||
`
|
||||
_, err = tx.Exec(updateQuery, status, time.Now(), agentID, packageType, packageName)
|
||||
_, err = tx.Exec(updateQuery, status, timestamp, agentID, packageType, packageName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update package status: %w", err)
|
||||
}
|
||||
@@ -564,7 +571,7 @@ func (q *UpdateQueries) UpdatePackageStatus(agentID uuid.UUID, packageType, pack
|
||||
_, err = tx.Exec(historyQuery,
|
||||
agentID, packageType, packageName, currentState.CurrentVersion,
|
||||
currentState.AvailableVersion, currentState.Severity,
|
||||
currentState.RepositorySource, metadata, time.Now(), status)
|
||||
currentState.RepositorySource, metadata, timestamp, status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record version history: %w", err)
|
||||
}
|
||||
|
||||
@@ -8,20 +8,28 @@ import (
|
||||
|
||||
// AgentCommand represents a command to be executed by an agent
|
||||
type AgentCommand struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
AgentID uuid.UUID `json:"agent_id" db:"agent_id"`
|
||||
CommandType string `json:"command_type" db:"command_type"`
|
||||
Params JSONB `json:"params" db:"params"`
|
||||
Status string `json:"status" db:"status"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
SentAt *time.Time `json:"sent_at,omitempty" db:"sent_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
|
||||
Result JSONB `json:"result,omitempty" db:"result"`
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
AgentID uuid.UUID `json:"agent_id" db:"agent_id"`
|
||||
CommandType string `json:"command_type" db:"command_type"`
|
||||
Params JSONB `json:"params" db:"params"`
|
||||
Status string `json:"status" db:"status"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
SentAt *time.Time `json:"sent_at,omitempty" db:"sent_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
|
||||
Result JSONB `json:"result,omitempty" db:"result"`
|
||||
RetriedFromID *uuid.UUID `json:"retried_from_id,omitempty" db:"retried_from_id"`
|
||||
}
|
||||
|
||||
// CommandsResponse is returned when an agent checks in for commands
|
||||
type CommandsResponse struct {
|
||||
Commands []CommandItem `json:"commands"`
|
||||
Commands []CommandItem `json:"commands"`
|
||||
RapidPolling *RapidPollingConfig `json:"rapid_polling,omitempty"`
|
||||
}
|
||||
|
||||
// RapidPollingConfig contains rapid polling configuration for the agent
|
||||
type RapidPollingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Until string `json:"until"` // ISO 8601 timestamp
|
||||
}
|
||||
|
||||
// CommandItem represents a command in the response
|
||||
@@ -40,6 +48,8 @@ const (
|
||||
CommandTypeConfirmDependencies = "confirm_dependencies"
|
||||
CommandTypeRollback = "rollback_update"
|
||||
CommandTypeUpdateAgent = "update_agent"
|
||||
CommandTypeEnableHeartbeat = "enable_heartbeat"
|
||||
CommandTypeDisableHeartbeat = "disable_heartbeat"
|
||||
)
|
||||
|
||||
// Command statuses
|
||||
@@ -55,15 +65,20 @@ const (
|
||||
|
||||
// ActiveCommandInfo represents information about an active command for UI display
|
||||
type ActiveCommandInfo struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
AgentID uuid.UUID `json:"agent_id" db:"agent_id"`
|
||||
CommandType string `json:"command_type" db:"command_type"`
|
||||
Status string `json:"status" db:"status"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
SentAt *time.Time `json:"sent_at,omitempty" db:"sent_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
|
||||
Result JSONB `json:"result,omitempty" db:"result"`
|
||||
AgentHostname string `json:"agent_hostname" db:"agent_hostname"`
|
||||
PackageName string `json:"package_name" db:"package_name"`
|
||||
PackageType string `json:"package_type" db:"package_type"`
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
AgentID uuid.UUID `json:"agent_id" db:"agent_id"`
|
||||
CommandType string `json:"command_type" db:"command_type"`
|
||||
Params JSONB `json:"params" db:"params"`
|
||||
Status string `json:"status" db:"status"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
SentAt *time.Time `json:"sent_at,omitempty" db:"sent_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
|
||||
Result JSONB `json:"result,omitempty" db:"result"`
|
||||
AgentHostname string `json:"agent_hostname" db:"agent_hostname"`
|
||||
PackageName string `json:"package_name" db:"package_name"`
|
||||
PackageType string `json:"package_type" db:"package_type"`
|
||||
RetriedFromID *uuid.UUID `json:"retried_from_id,omitempty" db:"retried_from_id"`
|
||||
IsRetry bool `json:"is_retry" db:"is_retry"`
|
||||
HasBeenRetried bool `json:"has_been_retried" db:"has_been_retried"`
|
||||
RetryCount int `json:"retry_count" db:"retry_count"`
|
||||
}
|
||||
|
||||
@@ -162,7 +162,8 @@ func (ts *TimeoutService) updateRelatedPackageStatus(command *models.AgentComman
|
||||
command.Params["package_type"].(string),
|
||||
command.Params["package_name"].(string),
|
||||
"failed",
|
||||
metadata)
|
||||
metadata,
|
||||
nil) // nil = use time.Now() for timeout operations
|
||||
}
|
||||
|
||||
// extractUpdatePackageID extracts the update package ID from command params
|
||||
|
||||
Reference in New Issue
Block a user