v0.1.16: Security overhaul and systematic deployment preparation

Breaking changes for clean alpha releases:
- JWT authentication with user-provided secrets (no more development defaults)
- Registration token system for secure agent enrollment
- Rate limiting with user-adjustable settings
- Enhanced agent configuration with proxy support
- Interactive server setup wizard (--setup flag)
- Heartbeat architecture separation for better UX
- Package status synchronization fixes
- Accurate timestamp tracking for RMM features

Setup process for new installations:
1. docker-compose up -d postgres
2. ./redflag-server --setup
3. ./redflag-server --migrate
4. ./redflag-server
5. Generate tokens via admin UI
6. Deploy agents with registration tokens
This commit is contained in:
Fimeg
2025-10-29 10:38:18 -04:00
parent b3e1b9e52f
commit 03fee29760
50 changed files with 5807 additions and 466 deletions

View File

@@ -1,6 +1,7 @@
package handlers
import (
"fmt"
"log"
"net/http"
"time"
@@ -107,15 +108,16 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
// Try to parse optional system metrics from request body
var metrics struct {
CPUPercent float64 `json:"cpu_percent,omitempty"`
MemoryPercent float64 `json:"memory_percent,omitempty"`
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
DiskPercent float64 `json:"disk_percent,omitempty"`
Uptime string `json:"uptime,omitempty"`
Version string `json:"version,omitempty"`
CPUPercent float64 `json:"cpu_percent,omitempty"`
MemoryPercent float64 `json:"memory_percent,omitempty"`
MemoryUsedGB float64 `json:"memory_used_gb,omitempty"`
MemoryTotalGB float64 `json:"memory_total_gb,omitempty"`
DiskUsedGB float64 `json:"disk_used_gb,omitempty"`
DiskTotalGB float64 `json:"disk_total_gb,omitempty"`
DiskPercent float64 `json:"disk_percent,omitempty"`
Uptime string `json:"uptime,omitempty"`
Version string `json:"version,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// Parse metrics if provided (optional, won't fail if empty)
@@ -130,21 +132,21 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
// Always handle version information if provided
if metrics.Version != "" {
// Get current agent to preserve existing metadata
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil && agent.Metadata != nil {
// Update agent's current version
if err := h.agentQueries.UpdateAgentVersion(agentID, metrics.Version); err != nil {
log.Printf("Warning: Failed to update agent version: %v", err)
} else {
// Check if update is available
updateAvailable := utils.IsNewerVersion(h.latestAgentVersion, metrics.Version)
// Update agent's current version in database (primary source of truth)
if err := h.agentQueries.UpdateAgentVersion(agentID, metrics.Version); err != nil {
log.Printf("Warning: Failed to update agent version: %v", err)
} else {
// Check if update is available
updateAvailable := utils.IsNewerVersion(h.latestAgentVersion, metrics.Version)
// Update agent's update availability status
if err := h.agentQueries.UpdateAgentUpdateAvailable(agentID, updateAvailable); err != nil {
log.Printf("Warning: Failed to update agent update availability: %v", err)
}
// Update agent's update availability status
if err := h.agentQueries.UpdateAgentUpdateAvailable(agentID, updateAvailable); err != nil {
log.Printf("Warning: Failed to update agent update availability: %v", err)
}
// Get current agent for logging and metadata update
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil {
// Log version check
if updateAvailable {
log.Printf("🔄 Agent %s (%s) version %s has update available: %s",
@@ -154,11 +156,20 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
agent.Hostname, agentID, metrics.Version)
}
// Store version in metadata as well
// Store version in metadata as well (for backwards compatibility)
// Initialize metadata if nil
if agent.Metadata == nil {
agent.Metadata = make(models.JSONB)
}
agent.Metadata["reported_version"] = metrics.Version
agent.Metadata["latest_version"] = h.latestAgentVersion
agent.Metadata["update_available"] = updateAvailable
agent.Metadata["version_checked_at"] = time.Now().Format(time.RFC3339)
// Update agent metadata
if err := h.agentQueries.UpdateAgent(agent); err != nil {
log.Printf("Warning: Failed to update agent metadata: %v", err)
}
}
}
}
@@ -179,6 +190,29 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
agent.Metadata["uptime"] = metrics.Uptime
agent.Metadata["metrics_updated_at"] = time.Now().Format(time.RFC3339)
// Process heartbeat metadata from agent check-ins
if metrics.Metadata != nil {
if rapidPollingEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
if rapidPollingUntil, exists := metrics.Metadata["rapid_polling_until"]; exists {
// Parse the until timestamp
if untilTime, err := time.Parse(time.RFC3339, rapidPollingUntil.(string)); err == nil {
// Validate if rapid polling is still active (not expired)
isActive := rapidPollingEnabled.(bool) && time.Now().Before(untilTime)
// Store heartbeat status in agent metadata
agent.Metadata["rapid_polling_enabled"] = rapidPollingEnabled
agent.Metadata["rapid_polling_until"] = rapidPollingUntil
agent.Metadata["rapid_polling_active"] = isActive
log.Printf("[Heartbeat] Agent %s heartbeat status: enabled=%v, until=%v, active=%v",
agentID, rapidPollingEnabled, rapidPollingUntil, isActive)
} else {
log.Printf("[Heartbeat] Failed to parse rapid_polling_until timestamp for agent %s: %v", agentID, err)
}
}
}
}
// Update agent with new metadata
if err := h.agentQueries.UpdateAgent(agent); err != nil {
log.Printf("Warning: Failed to update agent metrics: %v", err)
@@ -192,6 +226,37 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
return
}
// Process heartbeat metadata from agent check-ins
if metrics.Metadata != nil {
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil && agent.Metadata != nil {
if rapidPollingEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
if rapidPollingUntil, exists := metrics.Metadata["rapid_polling_until"]; exists {
// Parse the until timestamp
if untilTime, err := time.Parse(time.RFC3339, rapidPollingUntil.(string)); err == nil {
// Validate if rapid polling is still active (not expired)
isActive := rapidPollingEnabled.(bool) && time.Now().Before(untilTime)
// Store heartbeat status in agent metadata
agent.Metadata["rapid_polling_enabled"] = rapidPollingEnabled
agent.Metadata["rapid_polling_until"] = rapidPollingUntil
agent.Metadata["rapid_polling_active"] = isActive
log.Printf("[Heartbeat] Agent %s heartbeat status: enabled=%v, until=%v, active=%v",
agentID, rapidPollingEnabled, rapidPollingUntil, isActive)
// Update agent with new metadata
if err := h.agentQueries.UpdateAgent(agent); err != nil {
log.Printf("[Heartbeat] Warning: Failed to update agent heartbeat metadata: %v", err)
}
} else {
log.Printf("[Heartbeat] Failed to parse rapid_polling_until timestamp for agent %s: %v", agentID, err)
}
}
}
}
}
// Check for version updates for agents that don't send version in metrics
// This ensures agents like Metis that don't report version still get update checks
if metrics.Version == "" {
@@ -239,8 +304,97 @@ func (h *AgentHandler) GetCommands(c *gin.Context) {
h.commandQueries.MarkCommandSent(cmd.ID)
}
// Check if rapid polling should be enabled
var rapidPolling *models.RapidPollingConfig
// Enable rapid polling if there are commands to process
if len(commandItems) > 0 {
rapidPolling = &models.RapidPollingConfig{
Enabled: true,
Until: time.Now().Add(10 * time.Minute).Format(time.RFC3339), // 10 minutes default
}
} else {
// Check if agent has rapid polling already configured in metadata
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil && agent.Metadata != nil {
if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
if until, err := time.Parse(time.RFC3339, untilStr); err == nil && time.Now().Before(until) {
rapidPolling = &models.RapidPollingConfig{
Enabled: true,
Until: untilStr,
}
}
}
}
}
}
// Detect stale heartbeat state: Server thinks it's active, but agent didn't report it
// This happens when agent restarts without heartbeat mode
agent, err := h.agentQueries.GetAgentByID(agentID)
if err == nil && agent.Metadata != nil {
// Check if server metadata shows heartbeat active
if serverEnabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && serverEnabled {
if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
if until, err := time.Parse(time.RFC3339, untilStr); err == nil && time.Now().Before(until) {
// Server thinks heartbeat is active and not expired
// Check if agent is reporting heartbeat in this check-in
agentReportingHeartbeat := false
if metrics.Metadata != nil {
if agentEnabled, exists := metrics.Metadata["rapid_polling_enabled"]; exists {
agentReportingHeartbeat = agentEnabled.(bool)
}
}
// If agent is NOT reporting heartbeat but server expects it → stale state
if !agentReportingHeartbeat {
log.Printf("[Heartbeat] Stale heartbeat detected for agent %s - server expected active until %s, but agent not reporting heartbeat (likely restarted)",
agentID, until.Format(time.RFC3339))
// Clear stale heartbeat state
agent.Metadata["rapid_polling_enabled"] = false
delete(agent.Metadata, "rapid_polling_until")
if err := h.agentQueries.UpdateAgent(agent); err != nil {
log.Printf("[Heartbeat] Warning: Failed to clear stale heartbeat state: %v", err)
} else {
log.Printf("[Heartbeat] Cleared stale heartbeat state for agent %s", agentID)
// Create audit command to show in history
now := time.Now()
auditCmd := &models.AgentCommand{
ID: uuid.New(),
AgentID: agentID,
CommandType: models.CommandTypeDisableHeartbeat,
Params: models.JSONB{},
Status: models.CommandStatusCompleted,
Result: models.JSONB{
"message": "Heartbeat cleared - agent restarted without active heartbeat mode",
},
CreatedAt: now,
SentAt: &now,
CompletedAt: &now,
}
if err := h.commandQueries.CreateCommand(auditCmd); err != nil {
log.Printf("[Heartbeat] Warning: Failed to create audit command for stale heartbeat: %v", err)
} else {
log.Printf("[Heartbeat] Created audit trail for stale heartbeat cleanup (agent %s)", agentID)
}
}
// Clear rapidPolling response since we just disabled it
rapidPolling = nil
}
}
}
}
}
response := models.CommandsResponse{
Commands: commandItems,
Commands: commandItems,
RapidPolling: rapidPolling,
}
c.JSON(http.StatusOK, response)
@@ -312,6 +466,124 @@ func (h *AgentHandler) TriggerScan(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "scan triggered", "command_id": cmd.ID})
}
// TriggerHeartbeat creates a heartbeat toggle command for an agent.
//
// Expects the agent UUID in the :id path parameter and a JSON body of the
// form {"enabled": bool, "duration_minutes": int}. Queues an enable- or
// disable-heartbeat command that the agent picks up on its next check-in.
func (h *AgentHandler) TriggerHeartbeat(c *gin.Context) {
	idStr := c.Param("id")
	agentID, err := uuid.Parse(idStr)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
		return
	}

	var request struct {
		Enabled         bool `json:"enabled"`
		DurationMinutes int  `json:"duration_minutes"`
	}
	if err := c.ShouldBindJSON(&request); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Reject nonsensical durations up front instead of queueing a command the
	// agent cannot honor. SetRapidPollingMode enforces a similar bound via its
	// binding tags; this endpoint previously accepted any integer.
	if request.DurationMinutes < 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "duration_minutes must not be negative"})
		return
	}

	// Determine command type based on enabled flag
	commandType := models.CommandTypeDisableHeartbeat
	if request.Enabled {
		commandType = models.CommandTypeEnableHeartbeat
	}

	// Create heartbeat command with duration parameter
	cmd := &models.AgentCommand{
		ID:          uuid.New(),
		AgentID:     agentID,
		CommandType: commandType,
		Params: models.JSONB{
			"duration_minutes": request.DurationMinutes,
		},
		Status: models.CommandStatusPending,
	}

	if err := h.commandQueries.CreateCommand(cmd); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create heartbeat command"})
		return
	}

	// TODO: Clean up previous heartbeat commands for this agent (only for enable commands)
	// if request.Enabled {
	//	// Mark previous heartbeat commands as 'replaced' to clean up Live Operations view
	//	if err := h.commandQueries.MarkPreviousHeartbeatCommandsReplaced(agentID, cmd.ID); err != nil {
	//		log.Printf("Warning: Failed to mark previous heartbeat commands as replaced: %v", err)
	//		// Don't fail the request, just log the warning
	//	} else {
	//		log.Printf("[Heartbeat] Cleaned up previous heartbeat commands for agent %s", agentID)
	//	}
	// }

	action := "disabled"
	if request.Enabled {
		action = "enabled"
	}
	log.Printf("💓 Heartbeat %s command created for agent %s (duration: %d minutes)",
		action, agentID, request.DurationMinutes)

	c.JSON(http.StatusOK, gin.H{
		"message":    fmt.Sprintf("heartbeat %s command sent", action),
		"command_id": cmd.ID,
		"enabled":    request.Enabled,
	})
}
// GetHeartbeatStatus returns the current heartbeat status for an agent.
//
// Reads the agent's JSONB metadata and reports whether rapid polling
// (heartbeat mode) is enabled, until when, whether it is still active
// (until is in the future), and the requested duration if recorded.
func (h *AgentHandler) GetHeartbeatStatus(c *gin.Context) {
	idStr := c.Param("id")
	agentID, err := uuid.Parse(idStr)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
		return
	}

	// Get agent and their heartbeat metadata
	agent, err := h.agentQueries.GetAgentByID(agentID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
		return
	}

	// Defaults for agents with no heartbeat metadata.
	response := gin.H{
		"enabled":          false,
		"until":            nil,
		"active":           false,
		"duration_minutes": 0,
	}

	if agent.Metadata != nil {
		// Use comma-ok type assertions throughout: metadata values come from a
		// JSONB blob, so their dynamic types are not guaranteed. The previous
		// bare assertions (enabled.(bool), untilStr.(string), duration.(float64))
		// would panic on an unexpected type instead of degrading gracefully.
		if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok {
			response["enabled"] = enabled

			// If enabled, get the until time and check if still active
			if enabled {
				if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
					response["until"] = untilStr
					// Parse the until timestamp to check if still active
					if untilTime, err := time.Parse(time.RFC3339, untilStr); err == nil {
						response["active"] = time.Now().Before(untilTime)
					}
				}
				// Get duration if available (stored as a JSON number → float64)
				if duration, ok := agent.Metadata["rapid_polling_duration_minutes"].(float64); ok {
					response["duration_minutes"] = duration
				}
			}
		}
	}

	c.JSON(http.StatusOK, response)
}
// TriggerUpdate creates an update command for an agent
func (h *AgentHandler) TriggerUpdate(c *gin.Context) {
idStr := c.Param("id")
@@ -541,3 +813,114 @@ func (h *AgentHandler) ReportSystemInfo(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "system info updated successfully"})
}
// EnableRapidPollingMode enables rapid polling for an agent by updating metadata.
//
// If a heartbeat window is already active and expires later than the newly
// requested window, the longer window is kept (no-op). Otherwise the window
// is set/extended to now + durationMinutes. Returns an error only when the
// agent cannot be loaded or persisted.
func (h *AgentHandler) EnableRapidPollingMode(agentID uuid.UUID, durationMinutes int) error {
	// Get current agent
	agent, err := h.agentQueries.GetAgentByID(agentID)
	if err != nil {
		return fmt.Errorf("failed to get agent: %w", err)
	}

	// Calculate new rapid polling end time
	newRapidPollingUntil := time.Now().Add(time.Duration(durationMinutes) * time.Minute)

	// Update agent metadata with rapid polling settings
	if agent.Metadata == nil {
		agent.Metadata = models.JSONB{}
	}

	// Check if rapid polling is already active
	if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
		if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
			if currentUntil, err := time.Parse(time.RFC3339, untilStr); err == nil {
				// If current heartbeat expires later than the new duration, keep the longer duration
				if currentUntil.After(newRapidPollingUntil) {
					log.Printf("💓 Heartbeat already active for agent %s (%s), keeping longer duration (expires: %s)",
						agent.Hostname, agentID, currentUntil.Format(time.RFC3339))
					return nil
				}
				// Otherwise extend the heartbeat
				log.Printf("💓 Extending heartbeat for agent %s (%s) from %s to %s",
					agent.Hostname, agentID,
					currentUntil.Format(time.RFC3339),
					newRapidPollingUntil.Format(time.RFC3339))
			}
		}
	} else {
		log.Printf("💓 Enabling heartbeat mode for agent %s (%s) for %d minutes",
			agent.Hostname, agentID, durationMinutes)
	}

	// Set/update rapid polling settings.
	// Fix: also record the requested duration — GetHeartbeatStatus exposes
	// "rapid_polling_duration_minutes" but nothing previously wrote it. Stored
	// as float64 to match the JSON-number type it is read back as.
	agent.Metadata["rapid_polling_enabled"] = true
	agent.Metadata["rapid_polling_until"] = newRapidPollingUntil.Format(time.RFC3339)
	agent.Metadata["rapid_polling_duration_minutes"] = float64(durationMinutes)

	// Update agent in database
	if err := h.agentQueries.UpdateAgent(agent); err != nil {
		return fmt.Errorf("failed to update agent with rapid polling: %w", err)
	}

	return nil
}
// SetRapidPollingMode enables rapid polling mode for an agent.
//
// Expects {"enabled": bool, "duration_minutes": 1..60} in the JSON body.
// Writes the rapid-polling flags into the agent's metadata and persists them.
// TODO: Rate limiting should be implemented for rapid polling endpoints to prevent abuse (technical debt)
func (h *AgentHandler) SetRapidPollingMode(c *gin.Context) {
	idStr := c.Param("id")
	agentID, err := uuid.Parse(idStr)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent ID"})
		return
	}

	// Check if agent exists
	agent, err := h.agentQueries.GetAgentByID(agentID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
		return
	}

	var req struct {
		DurationMinutes int  `json:"duration_minutes" binding:"required,min=1,max=60"`
		Enabled         bool `json:"enabled"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Calculate rapid polling end time
	rapidPollingUntil := time.Now().Add(time.Duration(req.DurationMinutes) * time.Minute)

	// Update agent metadata with rapid polling settings
	if agent.Metadata == nil {
		agent.Metadata = models.JSONB{}
	}
	agent.Metadata["rapid_polling_enabled"] = req.Enabled
	agent.Metadata["rapid_polling_until"] = rapidPollingUntil.Format(time.RFC3339)
	// Fix: record the requested duration so GetHeartbeatStatus can report
	// "duration_minutes" (it reads this key but nothing previously wrote it).
	// Stored as float64 to match the JSON-number type it is read back as.
	if req.Enabled {
		agent.Metadata["rapid_polling_duration_minutes"] = float64(req.DurationMinutes)
	}

	// Update agent in database
	if err := h.agentQueries.UpdateAgent(agent); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update agent"})
		return
	}

	status := "disabled"
	duration := 0
	if req.Enabled {
		status = "enabled"
		duration = req.DurationMinutes
	}
	log.Printf("🚀 Rapid polling mode %s for agent %s (%s) for %d minutes",
		status, agent.Hostname, agentID, duration)

	c.JSON(http.StatusOK, gin.H{
		"message":             fmt.Sprintf("Rapid polling mode %s", status),
		"enabled":             req.Enabled,
		"duration_minutes":    req.DurationMinutes,
		"rapid_polling_until": rapidPollingUntil,
	})
}

View File

@@ -378,7 +378,7 @@ func (h *DockerHandler) RejectUpdate(c *gin.Context) {
}
// For now, we'll mark as rejected (this would need a proper reject method in queries)
if err := h.updateQueries.UpdatePackageStatus(update.AgentID, "docker", update.PackageName, "rejected", nil); err != nil {
if err := h.updateQueries.UpdatePackageStatus(update.AgentID, "docker", update.PackageName, "rejected", nil, nil); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to reject Docker update"})
return
}

View File

@@ -0,0 +1,146 @@
package handlers
import (
"fmt"
"net/http"
"time"
"github.com/aggregator-project/aggregator-server/internal/api/middleware"
"github.com/gin-gonic/gin"
)
// RateLimitHandler exposes HTTP endpoints for inspecting, tuning, resetting,
// and cleaning up the server's runtime-adjustable rate limiter.
type RateLimitHandler struct {
	rateLimiter *middleware.RateLimiter // shared limiter whose settings these handlers read and update
}
// NewRateLimitHandler builds a handler bound to the given shared rate limiter.
func NewRateLimitHandler(rateLimiter *middleware.RateLimiter) *RateLimitHandler {
	h := &RateLimitHandler{rateLimiter: rateLimiter}
	return h
}
// GetRateLimitSettings returns the rate limit configuration currently in effect.
func (h *RateLimitHandler) GetRateLimitSettings(c *gin.Context) {
	// Snapshot the live settings and stamp the response with the read time.
	current := h.rateLimiter.GetSettings()
	c.JSON(http.StatusOK, gin.H{
		"settings":   current,
		"updated_at": time.Now(),
	})
}
// UpdateRateLimitSettings replaces the rate limit configuration with the
// JSON payload from the request body, after validating it.
func (h *RateLimitHandler) UpdateRateLimitSettings(c *gin.Context) {
	// Decode the incoming configuration.
	var incoming middleware.RateLimitSettings
	if err := c.ShouldBindJSON(&incoming); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
		return
	}

	// Reject out-of-bounds or insecure configurations before applying anything.
	if err := h.validateRateLimitSettings(incoming); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Apply atomically via the limiter's own setter.
	h.rateLimiter.UpdateSettings(incoming)

	c.JSON(http.StatusOK, gin.H{
		"message":    "Rate limit settings updated successfully",
		"settings":   incoming,
		"updated_at": time.Now(),
	})
}
// ResetRateLimitSettings discards the current configuration and restores
// the compiled-in defaults.
func (h *RateLimitHandler) ResetRateLimitSettings(c *gin.Context) {
	defaults := middleware.DefaultRateLimitSettings()
	h.rateLimiter.UpdateSettings(defaults)

	c.JSON(http.StatusOK, gin.H{
		"message":    "Rate limit settings reset to defaults",
		"settings":   defaults,
		"updated_at": time.Now(),
	})
}
// GetRateLimitStats returns summary statistics over the configured limits:
// how many categories exist, how many are enabled, and the combined request
// budget, alongside the full settings.
func (h *RateLimitHandler) GetRateLimitStats(c *gin.Context) {
	settings := h.rateLimiter.GetSettings()

	// All configured endpoint categories, in a fixed order.
	configs := []middleware.RateLimitConfig{
		settings.AgentRegistration,
		settings.AgentCheckIn,
		settings.AgentReports,
		settings.AdminTokenGen,
		settings.AdminOperations,
		settings.PublicAccess,
	}

	// Tally with plain ints rather than round-tripping through gin.H assertions.
	enabledCount := 0
	totalRequests := 0
	for _, cfg := range configs {
		if cfg.Enabled {
			enabledCount++
		}
		totalRequests += cfg.Requests
	}

	c.JSON(http.StatusOK, gin.H{
		"total_configured_limits":   len(configs),
		"enabled_limits":            enabledCount,
		"total_requests_per_minute": totalRequests,
		"settings":                  settings,
	})
}
// CleanupRateLimitEntries triggers an on-demand purge of expired rate limit
// entries (normally handled by periodic cleanup).
func (h *RateLimitHandler) CleanupRateLimitEntries(c *gin.Context) {
	h.rateLimiter.CleanupExpiredEntries()

	c.JSON(http.StatusOK, gin.H{
		"message":   "Rate limit entries cleanup completed",
		"timestamp": time.Now(),
	})
}
// validateRateLimitSettings validates the provided rate limit settings.
//
// Generic bounds applied to every category: requests > 0, 0 < window <= 24h,
// requests <= 1000 per window. Stricter security caps apply to the agent
// registration and public-access categories. Returns a descriptive error
// naming the offending category, or nil when everything is acceptable.
func (h *RateLimitHandler) validateRateLimitSettings(settings middleware.RateLimitSettings) error {
	// Validate each configuration
	validations := []struct {
		name   string
		config middleware.RateLimitConfig
	}{
		{"agent_registration", settings.AgentRegistration},
		{"agent_checkin", settings.AgentCheckIn},
		{"agent_reports", settings.AgentReports},
		{"admin_token_generation", settings.AdminTokenGen},
		{"admin_operations", settings.AdminOperations},
		{"public_access", settings.PublicAccess},
	}

	for _, validation := range validations {
		if validation.config.Requests <= 0 {
			return fmt.Errorf("%s: requests must be greater than 0", validation.name)
		}
		if validation.config.Window <= 0 {
			return fmt.Errorf("%s: window must be greater than 0", validation.name)
		}
		if validation.config.Window > 24*time.Hour {
			return fmt.Errorf("%s: window cannot exceed 24 hours", validation.name)
		}
		if validation.config.Requests > 1000 {
			return fmt.Errorf("%s: requests cannot exceed 1000 per window", validation.name)
		}
	}

	// Specific validations for different endpoint types.
	// Fix: these caps are per configured window, not per minute — the old
	// messages said "per minute", which was wrong whenever Window != 1m.
	if settings.AgentRegistration.Requests > 10 {
		return fmt.Errorf("agent_registration: requests should not exceed 10 per window for security")
	}
	if settings.PublicAccess.Requests > 50 {
		return fmt.Errorf("public_access: requests should not exceed 50 per window for security")
	}

	return nil
}

View File

@@ -0,0 +1,284 @@
package handlers
import (
"net/http"
"strconv"
"time"
"github.com/aggregator-project/aggregator-server/internal/config"
"github.com/aggregator-project/aggregator-server/internal/database/queries"
"github.com/gin-gonic/gin"
)
// RegistrationTokenHandler serves the admin API for agent registration
// tokens: generation, listing, revocation, validation, cleanup, and stats.
type RegistrationTokenHandler struct {
	tokenQueries *queries.RegistrationTokenQueries // token persistence and lookup
	agentQueries *queries.AgentQueries             // used for active-agent seat counting
	config       *config.Config                    // registration limits (max seats, default token expiry)
}
// NewRegistrationTokenHandler wires the token and agent query layers plus
// server configuration into a registration-token API handler.
func NewRegistrationTokenHandler(tokenQueries *queries.RegistrationTokenQueries, agentQueries *queries.AgentQueries, config *config.Config) *RegistrationTokenHandler {
	h := &RegistrationTokenHandler{
		config:       config,
		tokenQueries: tokenQueries,
		agentQueries: agentQueries,
	}
	return h
}
// GenerateRegistrationToken creates a new registration token.
//
// Request body: {"label": "...", "expires_in": "24h", "metadata": {...}}.
// expires_in must be a Go duration string accepted by time.ParseDuration;
// note that "d" (days) is NOT a valid unit, so seven days is written "168h".
// Responds 201 with the token, its expiry, and a ready-to-run install command.
func (h *RegistrationTokenHandler) GenerateRegistrationToken(c *gin.Context) {
	var request struct {
		Label     string                 `json:"label" binding:"required"`
		ExpiresIn string                 `json:"expires_in"` // Go duration string, e.g. "24h", "168h"
		Metadata  map[string]interface{} `json:"metadata"`
	}

	if err := c.ShouldBindJSON(&request); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
		return
	}

	// Check agent seat limit (security, not licensing)
	activeAgents, err := h.agentQueries.GetActiveAgentCount()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check agent count"})
		return
	}
	if activeAgents >= h.config.AgentRegistration.MaxSeats {
		c.JSON(http.StatusForbidden, gin.H{
			"error":   "Maximum agent seats reached",
			"limit":   h.config.AgentRegistration.MaxSeats,
			"current": activeAgents,
		})
		return
	}

	// Parse expiration duration, falling back to the configured default.
	expiresIn := request.ExpiresIn
	if expiresIn == "" {
		expiresIn = h.config.AgentRegistration.TokenExpiry
	}
	duration, err := time.ParseDuration(expiresIn)
	if err != nil {
		// Fix: the old message suggested "7d", which time.ParseDuration rejects
		// (it has no day unit) — advertise formats that actually parse.
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid expiration format. Use Go duration formats such as '24h' or '168h' (7 days)"})
		return
	}
	// Validate the duration fully before deriving the expiry timestamp.
	if duration <= 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Token expiration must be positive"})
		return
	}
	if duration > 168*time.Hour { // Max 7 days
		c.JSON(http.StatusBadRequest, gin.H{"error": "Token expiration cannot exceed 7 days"})
		return
	}
	expiresAt := time.Now().Add(duration)

	// Generate secure token
	token, err := config.GenerateSecureToken()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
		return
	}

	// Create metadata with default values
	metadata := request.Metadata
	if metadata == nil {
		metadata = make(map[string]interface{})
	}
	metadata["server_url"] = c.Request.Host
	metadata["expires_in"] = expiresIn

	// Store token in database
	err = h.tokenQueries.CreateRegistrationToken(token, request.Label, expiresAt, metadata)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create token"})
		return
	}

	// Build install command
	serverURL := c.Request.Host
	if serverURL == "" {
		serverURL = "localhost:8080" // Fallback for development
	}
	installCommand := "curl -sfL https://" + serverURL + "/install | bash -s -- " + token

	response := gin.H{
		"token":           token,
		"label":           request.Label,
		"expires_at":      expiresAt,
		"install_command": installCommand,
		"metadata":        metadata,
	}

	c.JSON(http.StatusCreated, response)
}
// ListRegistrationTokens returns all registration tokens with pagination,
// usage statistics, and current seat consumption.
func (h *RegistrationTokenHandler) ListRegistrationTokens(c *gin.Context) {
	// Pagination inputs (invalid values fall back to the defaults).
	page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
	limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50"))
	status := c.Query("status")

	// Clamp pagination to sane bounds.
	if page < 1 {
		page = 1
	}
	if limit > 100 {
		limit = 100
	}
	offset := (page - 1) * limit

	// TODO: status filtering is not implemented yet — all tokens are returned
	// regardless of the ?status= parameter.
	_ = status
	tokens, err := h.tokenQueries.GetAllRegistrationTokens(limit, offset)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list tokens"})
		return
	}

	// Aggregate usage statistics for the listing view.
	stats, err := h.tokenQueries.GetTokenUsageStats()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get token stats"})
		return
	}

	// Seat usage is best-effort; a count failure just reports zero.
	currentSeats, _ := h.agentQueries.GetActiveAgentCount()

	c.JSON(http.StatusOK, gin.H{
		"tokens": tokens,
		"pagination": gin.H{
			"page":   page,
			"limit":  limit,
			"offset": offset,
		},
		"stats": stats,
		"seat_usage": gin.H{
			"current": currentSeats,
			"max":     h.config.AgentRegistration.MaxSeats,
		},
	})
}
// GetActiveRegistrationTokens returns only the tokens that are still active
// (not used, revoked, or expired).
func (h *RegistrationTokenHandler) GetActiveRegistrationTokens(c *gin.Context) {
	active, err := h.tokenQueries.GetActiveRegistrationTokens()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get active tokens"})
		return
	}

	c.JSON(http.StatusOK, gin.H{"tokens": active})
}
// RevokeRegistrationToken revokes a registration token identified by the
// :token path parameter. An optional JSON body may supply a revocation reason.
func (h *RegistrationTokenHandler) RevokeRegistrationToken(c *gin.Context) {
	token := c.Param("token")
	if token == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Token is required"})
		return
	}

	// The reason is optional; a generic note is recorded when omitted.
	var request struct {
		Reason string `json:"reason"`
	}
	c.ShouldBindJSON(&request) // Reason is optional
	reason := request.Reason
	if reason == "" {
		reason = "Revoked via API"
	}

	if err := h.tokenQueries.RevokeRegistrationToken(token, reason); err != nil {
		// NOTE(review): matching on the error string is fragile, but it mirrors
		// the query layer's current sentinel message.
		if err.Error() == "token not found or already used/revoked" {
			c.JSON(http.StatusNotFound, gin.H{"error": "Token not found or already used/revoked"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to revoke token"})
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "Token revoked successfully"})
}
// ValidateRegistrationToken checks whether a token (passed as ?token=) is
// currently valid. Intended for testing/debugging.
func (h *RegistrationTokenHandler) ValidateRegistrationToken(c *gin.Context) {
	token := c.Query("token")
	if token == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Token query parameter is required"})
		return
	}

	tokenInfo, err := h.tokenQueries.ValidateRegistrationToken(token)
	if err != nil {
		// Invalid/unknown token: report why validation failed.
		c.JSON(http.StatusNotFound, gin.H{"valid": false, "error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{"valid": true, "token": tokenInfo})
}
// CleanupExpiredTokens removes expired registration tokens and reports how
// many entries were cleaned.
func (h *RegistrationTokenHandler) CleanupExpiredTokens(c *gin.Context) {
	cleaned, err := h.tokenQueries.CleanupExpiredTokens()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to cleanup expired tokens"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "Cleanup completed",
		"cleaned": cleaned,
	})
}
// GetTokenStats returns comprehensive token usage statistics: per-token
// stats, agent seat usage, and the configured security limits.
func (h *RegistrationTokenHandler) GetTokenStats(c *gin.Context) {
	tokenStats, err := h.tokenQueries.GetTokenUsageStats()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get token stats"})
		return
	}

	activeAgents, err := h.agentQueries.GetActiveAgentCount()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get agent count"})
		return
	}

	maxSeats := h.config.AgentRegistration.MaxSeats
	c.JSON(http.StatusOK, gin.H{
		"token_stats": tokenStats,
		"agent_usage": gin.H{
			"active_agents": activeAgents,
			"max_seats":     maxSeats,
			"available":     maxSeats - activeAgents,
		},
		"security_limits": gin.H{
			"max_tokens_per_request": h.config.AgentRegistration.MaxTokens,
			"max_token_duration":     "7 days",
			"token_expiry_default":   h.config.AgentRegistration.TokenExpiry,
		},
	})
}

View File

@@ -2,6 +2,7 @@ package handlers
import (
"fmt"
"log"
"net/http"
"strconv"
"time"
@@ -16,16 +17,42 @@ type UpdateHandler struct {
updateQueries *queries.UpdateQueries
agentQueries *queries.AgentQueries
commandQueries *queries.CommandQueries
agentHandler *AgentHandler
}
func NewUpdateHandler(uq *queries.UpdateQueries, aq *queries.AgentQueries, cq *queries.CommandQueries) *UpdateHandler {
func NewUpdateHandler(uq *queries.UpdateQueries, aq *queries.AgentQueries, cq *queries.CommandQueries, ah *AgentHandler) *UpdateHandler {
return &UpdateHandler{
updateQueries: uq,
agentQueries: aq,
commandQueries: cq,
agentHandler: ah,
}
}
// shouldEnableHeartbeat checks if heartbeat is already active for an agent
// Returns true if heartbeat should be enabled (i.e., not already active or expired)
//
// NOTE(review): durationMinutes is currently unused — the "sufficient time
// remaining" threshold below is a fixed 5 minutes. Confirm whether it should
// instead be derived from the requested duration.
func (h *UpdateHandler) shouldEnableHeartbeat(agentID uuid.UUID, durationMinutes int) (bool, error) {
	agent, err := h.agentQueries.GetAgentByID(agentID)
	if err != nil {
		// Fail open: a lookup error must not block heartbeat activation.
		log.Printf("Warning: Failed to get agent %s for heartbeat check: %v", agentID, err)
		return true, nil // Enable heartbeat by default if we can't check
	}

	// Check if rapid polling is already enabled and not expired.
	// (Reading from a nil map is safe in Go; a nil Metadata just yields
	// zero values, so both comma-ok assertions fail and we fall through.)
	if enabled, ok := agent.Metadata["rapid_polling_enabled"].(bool); ok && enabled {
		if untilStr, ok := agent.Metadata["rapid_polling_until"].(string); ok {
			until, err := time.Parse(time.RFC3339, untilStr)
			// Skip only when the existing window still has >5 minutes left.
			if err == nil && until.After(time.Now().Add(5*time.Minute)) {
				// Heartbeat is already active for sufficient time
				log.Printf("[Heartbeat] Agent %s already has active heartbeat until %s (skipping)", agentID, untilStr)
				return false, nil
			}
		}
	}

	return true, nil
}
// ReportUpdates handles update reports from agents using event sourcing
func (h *UpdateHandler) ReportUpdates(c *gin.Context) {
agentID := c.MustGet("agent_id").(uuid.UUID)
@@ -172,7 +199,7 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
return
}
log := &models.UpdateLog{
logEntry := &models.UpdateLog{
ID: uuid.New(),
AgentID: agentID,
Action: req.Action,
@@ -185,7 +212,7 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
}
// Store the log entry
if err := h.updateQueries.CreateUpdateLog(log); err != nil {
if err := h.updateQueries.CreateUpdateLog(logEntry); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save log"})
return
}
@@ -207,10 +234,34 @@ func (h *UpdateHandler) ReportLog(c *gin.Context) {
}
// Update command status based on log result
if req.Result == "success" {
if req.Result == "success" || req.Result == "completed" {
if err := h.commandQueries.MarkCommandCompleted(commandID, result); err != nil {
fmt.Printf("Warning: Failed to mark command %s as completed: %v\n", commandID, err)
}
// NEW: If this was a successful confirm_dependencies command, mark the package as updated
command, err := h.commandQueries.GetCommandByID(commandID)
if err == nil && command.CommandType == models.CommandTypeConfirmDependencies {
// Extract package info from command params
if packageName, ok := command.Params["package_name"].(string); ok {
if packageType, ok := command.Params["package_type"].(string); ok {
// Extract actual completion timestamp from command result for accurate audit trail
var completionTime *time.Time
if loggedAtStr, ok := command.Result["logged_at"].(string); ok {
if parsed, err := time.Parse(time.RFC3339Nano, loggedAtStr); err == nil {
completionTime = &parsed
}
}
// Update package status to 'updated' with actual completion timestamp
if err := h.updateQueries.UpdatePackageStatus(agentID, packageType, packageName, "updated", nil, completionTime); err != nil {
log.Printf("Warning: Failed to update package status for %s/%s: %v", packageType, packageName, err)
} else {
log.Printf("✅ Package %s (%s) marked as updated after successful installation", packageName, packageType)
}
}
}
}
} else if req.Result == "failed" || req.Result == "dry_run_failed" {
if err := h.commandQueries.MarkCommandFailed(commandID, result); err != nil {
fmt.Printf("Warning: Failed to mark command %s as failed: %v\n", commandID, err)
@@ -304,7 +355,7 @@ func (h *UpdateHandler) UpdatePackageStatus(c *gin.Context) {
return
}
if err := h.updateQueries.UpdatePackageStatus(agentID, req.PackageType, req.PackageName, req.Status, req.Metadata); err != nil {
if err := h.updateQueries.UpdatePackageStatus(agentID, req.PackageType, req.PackageName, req.Status, req.Metadata, nil); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update package status"})
return
}
@@ -395,7 +446,29 @@ func (h *UpdateHandler) InstallUpdate(c *gin.Context) {
CreatedAt: time.Now(),
}
// Store the command in database
// Check if heartbeat should be enabled (avoid duplicates)
if shouldEnable, err := h.shouldEnableHeartbeat(update.AgentID, 10); err == nil && shouldEnable {
heartbeatCmd := &models.AgentCommand{
ID: uuid.New(),
AgentID: update.AgentID,
CommandType: models.CommandTypeEnableHeartbeat,
Params: models.JSONB{
"duration_minutes": 10,
},
Status: models.CommandStatusPending,
CreatedAt: time.Now(),
}
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", update.AgentID, err)
} else {
log.Printf("[Heartbeat] Command created for agent %s before dry run", update.AgentID)
}
} else {
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", update.AgentID)
}
// Store the dry run command in database
if err := h.commandQueries.CreateCommand(command); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create dry run command"})
return
@@ -478,6 +551,28 @@ func (h *UpdateHandler) ReportDependencies(c *gin.Context) {
CreatedAt: time.Now(),
}
// Check if heartbeat should be enabled (avoid duplicates)
if shouldEnable, err := h.shouldEnableHeartbeat(agentID, 10); err == nil && shouldEnable {
heartbeatCmd := &models.AgentCommand{
ID: uuid.New(),
AgentID: agentID,
CommandType: models.CommandTypeEnableHeartbeat,
Params: models.JSONB{
"duration_minutes": 10,
},
Status: models.CommandStatusPending,
CreatedAt: time.Now(),
}
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", agentID, err)
} else {
log.Printf("[Heartbeat] Command created for agent %s before installation", agentID)
}
} else {
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", agentID)
}
if err := h.commandQueries.CreateCommand(command); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create installation command"})
return
@@ -536,6 +631,28 @@ func (h *UpdateHandler) ConfirmDependencies(c *gin.Context) {
CreatedAt: time.Now(),
}
// Check if heartbeat should be enabled (avoid duplicates)
if shouldEnable, err := h.shouldEnableHeartbeat(update.AgentID, 10); err == nil && shouldEnable {
heartbeatCmd := &models.AgentCommand{
ID: uuid.New(),
AgentID: update.AgentID,
CommandType: models.CommandTypeEnableHeartbeat,
Params: models.JSONB{
"duration_minutes": 10,
},
Status: models.CommandStatusPending,
CreatedAt: time.Now(),
}
if err := h.commandQueries.CreateCommand(heartbeatCmd); err != nil {
log.Printf("[Heartbeat] Warning: Failed to create heartbeat command for agent %s: %v", update.AgentID, err)
} else {
log.Printf("[Heartbeat] Command created for agent %s before confirm dependencies", update.AgentID)
}
} else {
log.Printf("[Heartbeat] Skipping heartbeat command for agent %s (already active)", update.AgentID)
}
// Store the command in database
if err := h.commandQueries.CreateCommand(command); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create confirmation command"})
@@ -684,3 +801,60 @@ func (h *UpdateHandler) GetRecentCommands(c *gin.Context) {
"limit": limit,
})
}
// ClearFailedCommands manually archives failed/timed_out commands, with a
// cheeky warning reminding the operator that retry logic should normally
// make this unnecessary.
//
// Query parameters:
//   - older_than_days: positive integer age threshold (default 7)
//   - only_retried:    "true" limits cleanup to already-retried failures
//   - all_failed:      "true" clears every failed command (most aggressive)
func (h *UpdateHandler) ClearFailedCommands(c *gin.Context) {
	// Read the optional filters from the query string.
	rawDays := c.Query("older_than_days")
	retriedOnly := c.Query("only_retried") == "true"
	everyFailed := c.Query("all_failed") == "true"

	// Age threshold defaults to one week; only positive integers override it.
	ageDays := 7
	if rawDays != "" {
		if parsed, convErr := strconv.Atoi(rawDays); convErr == nil && parsed > 0 {
			ageDays = parsed
		}
	}

	// Pick the cleanup strategy, from most to least aggressive.
	var (
		count int64
		err   error
	)
	switch {
	case everyFailed:
		// Clear ALL failed commands regardless of retry state.
		count, err = h.commandQueries.ClearAllFailedCommands(ageDays)
	case retriedOnly:
		// Clear only failed commands that have already been retried.
		count, err = h.commandQueries.ClearRetriedFailedCommands(ageDays)
	default:
		// Default behavior: clear failed commands older than the threshold.
		count, err = h.commandQueries.ClearOldFailedCommands(ageDays)
	}
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "failed to clear failed commands",
			"details": err.Error(),
		})
		return
	}

	// Build the response message (wording preserved verbatim).
	message := fmt.Sprintf("Archived %d failed commands", count)
	if count > 0 {
		message += ". WARNING: This shouldn't be necessary if the retry logic is working properly - you might want to check what's causing commands to fail in the first place!"
		message += " (History preserved - commands moved to archived status)"
	} else {
		message += ". No failed commands found matching your criteria. SUCCESS!"
	}
	c.JSON(http.StatusOK, gin.H{
		"message":        message,
		"count":          count,
		"cheeky_warning": "Consider this a developer experience enhancement - the system should clean up after itself automatically!",
	})
}

View File

@@ -0,0 +1,279 @@
package middleware
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimitConfig holds configuration for rate limiting
type RateLimitConfig struct {
	Requests int           `json:"requests"` // maximum requests allowed inside one Window
	Window   time.Duration `json:"window"`   // length of the sliding window
	Enabled  bool          `json:"enabled"`  // when false, this limit is skipped entirely
}
// RateLimitEntry tracks requests for a specific key
type RateLimitEntry struct {
	Requests []time.Time  // timestamps of requests; trimmed to the window on each check
	mutex    sync.RWMutex // guards Requests (entries are shared via RateLimiter.entries)
}
// RateLimiter implements in-memory rate limiting with user-configurable settings
type RateLimiter struct {
	entries sync.Map                   // map[string]*RateLimitEntry, keyed by the middleware's key function
	configs map[string]RateLimitConfig // per-category limits, keyed by limit type name
	mutex   sync.RWMutex               // guards configs (entries is concurrency-safe on its own)
}
// RateLimitSettings holds all user-configurable rate limit settings,
// one RateLimitConfig per traffic category. Field-to-category mapping
// mirrors the keys used in RateLimiter.UpdateSettings.
type RateLimitSettings struct {
	AgentRegistration RateLimitConfig `json:"agent_registration"`
	AgentCheckIn      RateLimitConfig `json:"agent_checkin"`
	AgentReports      RateLimitConfig `json:"agent_reports"`
	AdminTokenGen     RateLimitConfig `json:"admin_token_generation"`
	AdminOperations   RateLimitConfig `json:"admin_operations"`
	PublicAccess      RateLimitConfig `json:"public_access"`
}
// DefaultRateLimitSettings provides sensible defaults: every category is
// enabled with a one-minute sliding window; only the request budget differs.
func DefaultRateLimitSettings() RateLimitSettings {
	perMinute := func(requests int) RateLimitConfig {
		return RateLimitConfig{
			Requests: requests,
			Window:   time.Minute,
			Enabled:  true,
		}
	}
	return RateLimitSettings{
		AgentRegistration: perMinute(5),
		AgentCheckIn:      perMinute(60),
		AgentReports:      perMinute(30),
		AdminTokenGen:     perMinute(10),
		AdminOperations:   perMinute(100),
		PublicAccess:      perMinute(20),
	}
}
// NewRateLimiter creates a rate limiter pre-populated with the default
// settings from DefaultRateLimitSettings.
func NewRateLimiter() *RateLimiter {
	// The zero-value sync.Map is ready to use; only configs needs seeding.
	limiter := &RateLimiter{}
	limiter.UpdateSettings(DefaultRateLimitSettings())
	return limiter
}
// UpdateSettings replaces the per-category rate limit configurations.
// The category keys here must match the limitType values passed to RateLimit.
func (rl *RateLimiter) UpdateSettings(settings RateLimitSettings) {
	// Build the new table outside the lock; swap it in atomically.
	updated := map[string]RateLimitConfig{
		"agent_registration": settings.AgentRegistration,
		"agent_checkin":      settings.AgentCheckIn,
		"agent_reports":      settings.AgentReports,
		"admin_token_gen":    settings.AdminTokenGen,
		"admin_operations":   settings.AdminOperations,
		"public_access":      settings.PublicAccess,
	}
	rl.mutex.Lock()
	rl.configs = updated
	rl.mutex.Unlock()
}
// GetSettings returns a snapshot of the current rate limit settings.
func (rl *RateLimiter) GetSettings() RateLimitSettings {
	rl.mutex.RLock()
	defer rl.mutex.RUnlock()

	var snapshot RateLimitSettings
	snapshot.AgentRegistration = rl.configs["agent_registration"]
	snapshot.AgentCheckIn = rl.configs["agent_checkin"]
	snapshot.AgentReports = rl.configs["agent_reports"]
	snapshot.AdminTokenGen = rl.configs["admin_token_gen"]
	snapshot.AdminOperations = rl.configs["admin_operations"]
	snapshot.PublicAccess = rl.configs["public_access"]
	return snapshot
}
// RateLimit creates gin middleware for a specific rate limit type.
// limitType selects the RateLimitConfig registered via UpdateSettings;
// keyFunc derives the bucketing key (IP, agent ID, ...) from the request.
// Requests over the limit receive 429 with standard X-RateLimit-* and
// Retry-After headers; unknown/disabled limit types and empty keys pass through.
func (rl *RateLimiter) RateLimit(limitType string, keyFunc func(*gin.Context) string) gin.HandlerFunc {
	return func(c *gin.Context) {
		rl.mutex.RLock()
		config, exists := rl.configs[limitType]
		rl.mutex.RUnlock()

		// Unknown or disabled limit types are a no-op.
		if !exists || !config.Enabled {
			c.Next()
			return
		}

		// An empty key (e.g. missing route param) cannot be rate limited.
		key := keyFunc(c)
		if key == "" {
			c.Next()
			return
		}

		// Check (and record) this request against the sliding window.
		allowed, resetTime := rl.checkRateLimit(key, config)
		if !allowed {
			c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
			c.Header("X-RateLimit-Remaining", "0")
			c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
			// time.Until is the idiomatic form of resetTime.Sub(time.Now()).
			c.Header("Retry-After", fmt.Sprintf("%d", int(time.Until(resetTime).Seconds())))
			c.JSON(http.StatusTooManyRequests, gin.H{
				"error":      "Rate limit exceeded",
				"limit":      config.Requests,
				"window":     config.Window.String(),
				"reset_time": resetTime,
			})
			c.Abort()
			return
		}

		// Request allowed: expose informational rate limit headers.
		remaining := rl.getRemainingRequests(key, config)
		c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
		c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
		c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(config.Window).Unix()))
		c.Next()
	}
}
// checkRateLimit records the request for key if it is under config's limit.
// It returns (true, zero time) when allowed, or (false, resetTime) when the
// limit is exceeded, where resetTime is the moment the oldest in-window
// request ages out of the window.
func (rl *RateLimiter) checkRateLimit(key string, config RateLimitConfig) (bool, time.Time) {
	now := time.Now()

	// Get or create the per-key entry.
	entryInterface, _ := rl.entries.LoadOrStore(key, &RateLimitEntry{
		Requests: []time.Time{},
	})
	entry := entryInterface.(*RateLimitEntry)

	entry.mutex.Lock()
	defer entry.mutex.Unlock()

	// Drop requests that have aged out of the sliding window.
	cutoff := now.Add(-config.Window)
	validRequests := make([]time.Time, 0)
	for _, reqTime := range entry.Requests {
		if reqTime.After(cutoff) {
			validRequests = append(validRequests, reqTime)
		}
	}

	if len(validRequests) >= config.Requests {
		// Keep the trimmed list even on rejection so stale timestamps
		// don't accumulate for a key that is constantly over limit.
		entry.Requests = validRequests
		// With a zero (or negative) budget there is no request whose expiry
		// frees a slot — report the end of the current window instead of
		// indexing an empty slice (the previous code panicked here).
		if len(validRequests) == 0 {
			return false, now.Add(config.Window)
		}
		return false, validRequests[0].Add(config.Window)
	}

	// Under the limit: record this request.
	// (The old "cleanup when empty" branch after this append was unreachable
	// — the list always contains at least `now` — so it has been removed;
	// CleanupExpiredEntries handles idle-key eviction.)
	entry.Requests = append(validRequests, now)
	return true, time.Time{}
}
// getRemainingRequests reports how many more requests key may make in the
// current window under config, clamped at zero.
func (rl *RateLimiter) getRemainingRequests(key string, config RateLimitConfig) int {
	value, found := rl.entries.Load(key)
	if !found {
		// No recorded traffic: the full budget is available.
		return config.Requests
	}
	entry := value.(*RateLimitEntry)

	entry.mutex.RLock()
	defer entry.mutex.RUnlock()

	// Count only timestamps still inside the sliding window.
	cutoff := time.Now().Add(-config.Window)
	inWindow := 0
	for _, ts := range entry.Requests {
		if ts.After(cutoff) {
			inWindow++
		}
	}
	if left := config.Requests - inWindow; left > 0 {
		return left
	}
	return 0
}
// CleanupExpiredEntries removes expired entries to prevent memory leaks.
// It retains requests from the last hour regardless of any per-type window,
// so it is a coarse background sweep rather than an exact trim; keys with no
// recent traffic are deleted outright.
func (rl *RateLimiter) CleanupExpiredEntries() {
	rl.entries.Range(func(key, value interface{}) bool {
		entry := value.(*RateLimitEntry)

		entry.mutex.Lock()
		horizon := time.Now().Add(-time.Hour)
		kept := make([]time.Time, 0)
		for _, ts := range entry.Requests {
			if ts.After(horizon) {
				kept = append(kept, ts)
			}
		}
		if len(kept) == 0 {
			// Nothing recent for this key — drop the whole entry.
			rl.entries.Delete(key)
		} else {
			entry.Requests = kept
		}
		entry.mutex.Unlock()

		return true // keep iterating over all keys
	})
}
// Key generation functions
// KeyByIP buckets rate limiting by the client's IP address.
func KeyByIP(c *gin.Context) string {
return c.ClientIP()
}
// KeyByAgentID buckets rate limiting by the ":id" route parameter.
// Returns "" (which disables limiting in RateLimit) when the route
// has no such parameter.
func KeyByAgentID(c *gin.Context) string {
return c.Param("id")
}
// KeyByUserID is intended to bucket rate limiting by authenticated user.
// TODO(review): currently identical to KeyByIP — it falls back to the
// client IP until JWT/session extraction is implemented.
func KeyByUserID(c *gin.Context) string {
// This would extract user ID from JWT or session
// For now, use IP as fallback
return c.ClientIP()
}
// KeyByIPAndPath buckets rate limiting by client IP combined with the
// request path, so distinct endpoints get independent budgets per client.
func KeyByIPAndPath(c *gin.Context) string {
return c.ClientIP() + ":" + c.Request.URL.Path
}