feat(security): A-1 Ed25519 key rotation + A-2 replay attack fixes
Complete RedFlag codebase with two major security audit implementations.
== A-1: Ed25519 Key Rotation Support ==
Server:
- SignCommand sets SignedAt timestamp and KeyID on every signature
- signing_keys database table (migration 020) for multi-key rotation
- InitializePrimaryKey registers active key at startup
- /api/v1/public-keys endpoint for rotation-aware agents
- SigningKeyQueries for key lifecycle management
Agent:
- Key-ID-aware verification via CheckKeyRotation
- FetchAndCacheAllActiveKeys for rotation pre-caching
- Cache metadata with TTL and staleness fallback
- SecurityLogger events for key rotation and command signing
== A-2: Replay Attack Fixes (F-1 through F-7) ==
F-5 CRITICAL - RetryCommand now signs via signAndCreateCommand
F-1 HIGH - v3 format: "{agent_id}:{cmd_id}:{type}:{hash}:{ts}"
F-7 HIGH - Migration 026: expires_at column with partial index
F-6 HIGH - GetPendingCommands/GetStuckCommands filter by expires_at
F-2 HIGH - Agent-side executedIDs dedup map with cleanup
F-4 HIGH - commandMaxAge reduced from 24h to 4h
F-3 CRITICAL - Old-format commands rejected after 48h via CreatedAt
Verification fixes: migration idempotency (ETHOS #4), log format
compliance (ETHOS #1), stale comments updated.
All 24 tests passing. Docker --no-cache build verified.
See docs/ for full audit reports and deviation log (DEV-001 to DEV-019).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
71
aggregator-server/internal/api/middleware/auth.go
Normal file
71
aggregator-server/internal/api/middleware/auth.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AgentClaims represents JWT claims for agent authentication
|
||||
type AgentClaims struct {
|
||||
AgentID uuid.UUID `json:"agent_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// JWTSecret is set by the server at initialization
|
||||
var JWTSecret string
|
||||
|
||||
// GenerateAgentToken creates a new JWT token for an agent
|
||||
func GenerateAgentToken(agentID uuid.UUID) (string, error) {
|
||||
claims := AgentClaims{
|
||||
AgentID: agentID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(JWTSecret))
|
||||
}
|
||||
|
||||
// AuthMiddleware validates JWT tokens from agents
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == authHeader {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid authorization format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &AgentClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(JWTSecret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*AgentClaims); ok {
|
||||
c.Set("agent_id", claims.AgentID)
|
||||
c.Next()
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
}
|
||||
26
aggregator-server/internal/api/middleware/cors.go
Normal file
26
aggregator-server/internal/api/middleware/cors.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORSMiddleware handles Cross-Origin Resource Sharing
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "http://localhost:3000")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// Handle preflight requests
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
245
aggregator-server/internal/api/middleware/machine_binding.go
Normal file
245
aggregator-server/internal/api/middleware/machine_binding.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries"
|
||||
"github.com/Fimeg/RedFlag/aggregator-server/internal/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MachineBindingMiddleware validates machine ID matches database record
|
||||
// This prevents agent impersonation via config file copying to different machines
|
||||
func MachineBindingMiddleware(agentQueries *queries.AgentQueries, minAgentVersion string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip if not authenticated (handled by auth middleware)
|
||||
agentIDVal, exists := c.Get("agent_id")
|
||||
if !exists {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
agentID, ok := agentIDVal.(uuid.UUID)
|
||||
if !ok {
|
||||
log.Printf("[MachineBinding] Invalid agent_id type in context")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "invalid agent ID"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Get agent from database
|
||||
agent, err := agentQueries.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
log.Printf("[MachineBinding] Agent %s not found: %v", agentID, err)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "agent not found"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if agent is reporting an update completion
|
||||
reportedVersion := c.GetHeader("X-Agent-Version")
|
||||
updateNonce := c.GetHeader("X-Update-Nonce")
|
||||
|
||||
if agent.IsUpdating && updateNonce != "" {
|
||||
// Validate the nonce first (proves server authorized this update)
|
||||
if agent.PublicKeyFingerprint == nil {
|
||||
log.Printf("[SECURITY] Agent %s has no public key fingerprint for nonce validation", agentID)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "server public key not configured"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if err := validateUpdateNonceMiddleware(updateNonce, *agent.PublicKeyFingerprint); err != nil {
|
||||
log.Printf("[SECURITY] Invalid update nonce for agent %s: %v", agentID, err)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "invalid update nonce"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check for downgrade attempt (security boundary)
|
||||
if !isVersionUpgrade(reportedVersion, agent.CurrentVersion) {
|
||||
log.Printf("[SECURITY] Downgrade attempt detected: agent %s %s → %s",
|
||||
agentID, agent.CurrentVersion, reportedVersion)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "downgrade not allowed"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Valid upgrade - complete it in database
|
||||
go func() {
|
||||
if err := agentQueries.CompleteAgentUpdate(agentID.String(), reportedVersion); err != nil {
|
||||
log.Printf("[ERROR] Failed to complete agent update: %v", err)
|
||||
} else {
|
||||
log.Printf("[system] Agent %s updated: %s → %s", agentID, agent.CurrentVersion, reportedVersion)
|
||||
}
|
||||
}()
|
||||
|
||||
// Allow this request through
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Check minimum version (hard cutoff for legacy de-support)
|
||||
if agent.CurrentVersion != "" && minAgentVersion != "" {
|
||||
if !utils.IsNewerOrEqualVersion(agent.CurrentVersion, minAgentVersion) {
|
||||
// Allow old agents to check in if they have pending update commands
|
||||
// This prevents deadlock where agent can't check in to receive the update
|
||||
if c.Request.Method == "GET" && strings.HasSuffix(c.Request.URL.Path, "/commands") {
|
||||
// Check if agent has pending update command
|
||||
hasPendingUpdate, err := agentQueries.HasPendingUpdateCommand(agentID.String())
|
||||
if err != nil {
|
||||
log.Printf("[MachineBinding] Error checking pending updates for agent %s: %v", agentID, err)
|
||||
}
|
||||
|
||||
if hasPendingUpdate {
|
||||
log.Printf("[MachineBinding] Allowing old agent %s (%s) to check in for update delivery (v%s < v%s)",
|
||||
agent.Hostname, agentID, agent.CurrentVersion, minAgentVersion)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[MachineBinding] Agent %s version %s below minimum %s - rejecting",
|
||||
agent.Hostname, agent.CurrentVersion, minAgentVersion)
|
||||
c.JSON(http.StatusUpgradeRequired, gin.H{
|
||||
"error": "agent version too old - upgrade required for security",
|
||||
"current_version": agent.CurrentVersion,
|
||||
"minimum_version": minAgentVersion,
|
||||
"upgrade_instructions": "Please upgrade to the latest agent version and re-register",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Extract X-Machine-ID header
|
||||
reportedMachineID := c.GetHeader("X-Machine-ID")
|
||||
if reportedMachineID == "" {
|
||||
log.Printf("[MachineBinding] Agent %s (%s) missing X-Machine-ID header",
|
||||
agent.Hostname, agentID)
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "missing machine ID header - agent version too old or tampered",
|
||||
"hint": "Please upgrade to the latest agent version (v0.1.22+)",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate machine ID matches database
|
||||
if agent.MachineID == nil {
|
||||
log.Printf("[MachineBinding] Agent %s (%s) has no machine_id in database - legacy agent",
|
||||
agent.Hostname, agentID)
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "agent not bound to machine - re-registration required",
|
||||
"hint": "This agent was registered before v0.1.22. Please re-register with a new registration token.",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if *agent.MachineID != reportedMachineID {
|
||||
log.Printf("[MachineBinding] ⚠️ SECURITY ALERT: Agent %s (%s) machine ID mismatch! DB=%s, Reported=%s",
|
||||
agent.Hostname, agentID, *agent.MachineID, reportedMachineID)
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "machine ID mismatch - config file copied to different machine",
|
||||
"hint": "Agent configuration is bound to the original machine. Please register this machine with a new registration token.",
|
||||
"security_note": "This prevents agent impersonation attacks",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Machine ID validated - allow request
|
||||
log.Printf("[MachineBinding] ✓ Agent %s (%s) machine ID validated: %s",
|
||||
agent.Hostname, agentID, reportedMachineID[:16]+"...")
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func validateUpdateNonceMiddleware(nonceB64, serverPublicKey string) error {
|
||||
// Decode base64 nonce
|
||||
data, err := base64.StdEncoding.DecodeString(nonceB64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid base64: %w", err)
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
var nonce struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
TargetVersion string `json:"target_version"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &nonce); err != nil {
|
||||
return fmt.Errorf("invalid format: %w", err)
|
||||
}
|
||||
|
||||
// Check freshness
|
||||
if time.Now().Unix()-nonce.Timestamp > 600 { // 10 minutes
|
||||
return fmt.Errorf("nonce expired (age: %d seconds)", time.Now().Unix()-nonce.Timestamp)
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
signature, err := base64.StdEncoding.DecodeString(nonce.Signature)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid signature encoding: %w", err)
|
||||
}
|
||||
|
||||
// Parse server's public key
|
||||
pubKeyBytes, err := hex.DecodeString(serverPublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid server public key: %w", err)
|
||||
}
|
||||
|
||||
// Remove signature for verification
|
||||
originalSig := nonce.Signature
|
||||
nonce.Signature = ""
|
||||
verifyData, err := json.Marshal(nonce)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal verify data: %w", err)
|
||||
}
|
||||
|
||||
if !ed25519.Verify(ed25519.PublicKey(pubKeyBytes), verifyData, signature) {
|
||||
return fmt.Errorf("signature verification failed")
|
||||
}
|
||||
|
||||
// Restore signature (not needed but good practice)
|
||||
nonce.Signature = originalSig
|
||||
return nil
|
||||
}
|
||||
|
||||
func isVersionUpgrade(new, current string) bool {
|
||||
// Parse semantic versions
|
||||
newParts := strings.Split(new, ".")
|
||||
curParts := strings.Split(current, ".")
|
||||
|
||||
// Convert to integers for comparison
|
||||
newMajor, _ := strconv.Atoi(newParts[0])
|
||||
newMinor, _ := strconv.Atoi(newParts[1])
|
||||
newPatch, _ := strconv.Atoi(newParts[2])
|
||||
|
||||
curMajor, _ := strconv.Atoi(curParts[0])
|
||||
curMinor, _ := strconv.Atoi(curParts[1])
|
||||
curPatch, _ := strconv.Atoi(curParts[2])
|
||||
|
||||
// Check if new > current (not equal, not less)
|
||||
if newMajor > curMajor {
|
||||
return true
|
||||
}
|
||||
if newMajor == curMajor && newMinor > curMinor {
|
||||
return true
|
||||
}
|
||||
if newMajor == curMajor && newMinor == curMinor && newPatch > curPatch {
|
||||
return true
|
||||
}
|
||||
return false // Equal or downgrade
|
||||
}
|
||||
282
aggregator-server/internal/api/middleware/rate_limiter.go
Normal file
282
aggregator-server/internal/api/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimitConfig holds configuration for rate limiting
|
||||
type RateLimitConfig struct {
|
||||
Requests int `json:"requests"`
|
||||
Window time.Duration `json:"window"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// RateLimitEntry tracks requests for a specific key
|
||||
type RateLimitEntry struct {
|
||||
Requests []time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// RateLimiter implements in-memory rate limiting with user-configurable settings
|
||||
type RateLimiter struct {
|
||||
entries sync.Map // map[string]*RateLimitEntry
|
||||
configs map[string]RateLimitConfig
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// RateLimitSettings holds all user-configurable rate limit settings
|
||||
type RateLimitSettings struct {
|
||||
AgentRegistration RateLimitConfig `json:"agent_registration"`
|
||||
AgentCheckIn RateLimitConfig `json:"agent_checkin"`
|
||||
AgentReports RateLimitConfig `json:"agent_reports"`
|
||||
AdminTokenGen RateLimitConfig `json:"admin_token_generation"`
|
||||
AdminOperations RateLimitConfig `json:"admin_operations"`
|
||||
PublicAccess RateLimitConfig `json:"public_access"`
|
||||
}
|
||||
|
||||
// DefaultRateLimitSettings provides sensible defaults
|
||||
func DefaultRateLimitSettings() RateLimitSettings {
|
||||
return RateLimitSettings{
|
||||
AgentRegistration: RateLimitConfig{
|
||||
Requests: 5,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AgentCheckIn: RateLimitConfig{
|
||||
Requests: 60,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AgentReports: RateLimitConfig{
|
||||
Requests: 30,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AdminTokenGen: RateLimitConfig{
|
||||
Requests: 10,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
AdminOperations: RateLimitConfig{
|
||||
Requests: 100,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
PublicAccess: RateLimitConfig{
|
||||
Requests: 20,
|
||||
Window: time.Minute,
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter with default settings
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
entries: sync.Map{},
|
||||
}
|
||||
|
||||
// Load default settings
|
||||
defaults := DefaultRateLimitSettings()
|
||||
rl.UpdateSettings(defaults)
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// UpdateSettings updates rate limit configurations
|
||||
func (rl *RateLimiter) UpdateSettings(settings RateLimitSettings) {
|
||||
rl.mutex.Lock()
|
||||
defer rl.mutex.Unlock()
|
||||
|
||||
rl.configs = map[string]RateLimitConfig{
|
||||
"agent_registration": settings.AgentRegistration,
|
||||
"agent_checkin": settings.AgentCheckIn,
|
||||
"agent_reports": settings.AgentReports,
|
||||
"admin_token_gen": settings.AdminTokenGen,
|
||||
"admin_operations": settings.AdminOperations,
|
||||
"public_access": settings.PublicAccess,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSettings returns current rate limit settings
|
||||
func (rl *RateLimiter) GetSettings() RateLimitSettings {
|
||||
rl.mutex.RLock()
|
||||
defer rl.mutex.RUnlock()
|
||||
|
||||
return RateLimitSettings{
|
||||
AgentRegistration: rl.configs["agent_registration"],
|
||||
AgentCheckIn: rl.configs["agent_checkin"],
|
||||
AgentReports: rl.configs["agent_reports"],
|
||||
AdminTokenGen: rl.configs["admin_token_gen"],
|
||||
AdminOperations: rl.configs["admin_operations"],
|
||||
PublicAccess: rl.configs["public_access"],
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit creates middleware for a specific rate limit type
|
||||
func (rl *RateLimiter) RateLimit(limitType string, keyFunc func(*gin.Context) string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
rl.mutex.RLock()
|
||||
config, exists := rl.configs[limitType]
|
||||
rl.mutex.RUnlock()
|
||||
|
||||
if !exists || !config.Enabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
key := keyFunc(c)
|
||||
if key == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Namespace the key by limit type to prevent different endpoints from sharing counters
|
||||
namespacedKey := limitType + ":" + key
|
||||
|
||||
// Check rate limit
|
||||
allowed, resetTime := rl.checkRateLimit(namespacedKey, config)
|
||||
if !allowed {
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
|
||||
c.Header("X-RateLimit-Remaining", "0")
|
||||
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
|
||||
c.Header("Retry-After", fmt.Sprintf("%d", int(resetTime.Sub(time.Now()).Seconds())))
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Rate limit exceeded",
|
||||
"limit": config.Requests,
|
||||
"window": config.Window.String(),
|
||||
"reset_time": resetTime,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Add rate limit headers
|
||||
remaining := rl.getRemainingRequests(namespacedKey, config)
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
|
||||
c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
||||
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(config.Window).Unix()))
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// checkRateLimit checks if the request is allowed
|
||||
func (rl *RateLimiter) checkRateLimit(key string, config RateLimitConfig) (bool, time.Time) {
|
||||
now := time.Now()
|
||||
|
||||
// Get or create entry
|
||||
entryInterface, _ := rl.entries.LoadOrStore(key, &RateLimitEntry{
|
||||
Requests: []time.Time{},
|
||||
})
|
||||
entry := entryInterface.(*RateLimitEntry)
|
||||
|
||||
entry.mutex.Lock()
|
||||
defer entry.mutex.Unlock()
|
||||
|
||||
// Clean old requests outside the window
|
||||
cutoff := now.Add(-config.Window)
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, reqTime := range entry.Requests {
|
||||
if reqTime.After(cutoff) {
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if under limit
|
||||
if len(validRequests) >= config.Requests {
|
||||
// Find when the oldest request expires
|
||||
oldestRequest := validRequests[0]
|
||||
resetTime := oldestRequest.Add(config.Window)
|
||||
return false, resetTime
|
||||
}
|
||||
|
||||
// Add current request
|
||||
entry.Requests = append(validRequests, now)
|
||||
|
||||
// Clean up expired entries periodically
|
||||
if len(entry.Requests) == 0 {
|
||||
rl.entries.Delete(key)
|
||||
}
|
||||
|
||||
return true, time.Time{}
|
||||
}
|
||||
|
||||
// getRemainingRequests calculates remaining requests for the key
|
||||
func (rl *RateLimiter) getRemainingRequests(key string, config RateLimitConfig) int {
|
||||
entryInterface, ok := rl.entries.Load(key)
|
||||
if !ok {
|
||||
return config.Requests
|
||||
}
|
||||
|
||||
entry := entryInterface.(*RateLimitEntry)
|
||||
entry.mutex.RLock()
|
||||
defer entry.mutex.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-config.Window)
|
||||
count := 0
|
||||
|
||||
for _, reqTime := range entry.Requests {
|
||||
if reqTime.After(cutoff) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
remaining := config.Requests - count
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
return remaining
|
||||
}
|
||||
|
||||
// CleanupExpiredEntries removes expired entries to prevent memory leaks
|
||||
func (rl *RateLimiter) CleanupExpiredEntries() {
|
||||
rl.entries.Range(func(key, value interface{}) bool {
|
||||
entry := value.(*RateLimitEntry)
|
||||
entry.mutex.Lock()
|
||||
|
||||
now := time.Now()
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, reqTime := range entry.Requests {
|
||||
if reqTime.After(now.Add(-time.Hour)) { // Keep requests from last hour
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validRequests) == 0 {
|
||||
rl.entries.Delete(key)
|
||||
} else {
|
||||
entry.Requests = validRequests
|
||||
}
|
||||
|
||||
entry.mutex.Unlock()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Key generation functions
|
||||
func KeyByIP(c *gin.Context) string {
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
func KeyByAgentID(c *gin.Context) string {
|
||||
return c.Param("id")
|
||||
}
|
||||
|
||||
func KeyByUserID(c *gin.Context) string {
|
||||
// This would extract user ID from JWT or session
|
||||
// For now, use IP as fallback
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
func KeyByIPAndPath(c *gin.Context) string {
|
||||
return c.ClientIP() + ":" + c.Request.URL.Path
|
||||
}
|
||||
Reference in New Issue
Block a user