Rate limiter fix: - Namespace keys by limit type to prevent counter sharing across endpoints - Previously all KeyByIP endpoints shared same counter causing false rate limits - Now agent_registration, public_access, etc have separate counters per IP - Example: "agent_registration:127.0.0.1" vs "public_access:127.0.0.1" Session loop fix: - Remove wasInSetupMode from SetupCompletionChecker dependency array - Use local variable instead of state to prevent interval multiplication - Prevents rapid refresh loop during server restart after setup - (turns out useEffect dependency arrays actually matter, who knew) Tested: - First agent registration now succeeds without rate limit (was 429) - Public access requests don't affect agent registration quota - No UI flashing during server restart - Rate limit API endpoints functional (Settings UI needs work)
282 lines
7.1 KiB
Go
282 lines
7.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// RateLimitConfig holds configuration for rate limiting
|
|
type RateLimitConfig struct {
|
|
Requests int `json:"requests"`
|
|
Window time.Duration `json:"window"`
|
|
Enabled bool `json:"enabled"`
|
|
}
|
|
|
|
// RateLimitEntry tracks requests for a specific key
|
|
type RateLimitEntry struct {
|
|
Requests []time.Time
|
|
mutex sync.RWMutex
|
|
}
|
|
|
|
// RateLimiter implements in-memory rate limiting with user-configurable settings
|
|
type RateLimiter struct {
|
|
entries sync.Map // map[string]*RateLimitEntry
|
|
configs map[string]RateLimitConfig
|
|
mutex sync.RWMutex
|
|
}
|
|
|
|
// RateLimitSettings holds all user-configurable rate limit settings
|
|
type RateLimitSettings struct {
|
|
AgentRegistration RateLimitConfig `json:"agent_registration"`
|
|
AgentCheckIn RateLimitConfig `json:"agent_checkin"`
|
|
AgentReports RateLimitConfig `json:"agent_reports"`
|
|
AdminTokenGen RateLimitConfig `json:"admin_token_generation"`
|
|
AdminOperations RateLimitConfig `json:"admin_operations"`
|
|
PublicAccess RateLimitConfig `json:"public_access"`
|
|
}
|
|
|
|
// DefaultRateLimitSettings provides sensible defaults
|
|
func DefaultRateLimitSettings() RateLimitSettings {
|
|
return RateLimitSettings{
|
|
AgentRegistration: RateLimitConfig{
|
|
Requests: 5,
|
|
Window: time.Minute,
|
|
Enabled: true,
|
|
},
|
|
AgentCheckIn: RateLimitConfig{
|
|
Requests: 60,
|
|
Window: time.Minute,
|
|
Enabled: true,
|
|
},
|
|
AgentReports: RateLimitConfig{
|
|
Requests: 30,
|
|
Window: time.Minute,
|
|
Enabled: true,
|
|
},
|
|
AdminTokenGen: RateLimitConfig{
|
|
Requests: 10,
|
|
Window: time.Minute,
|
|
Enabled: true,
|
|
},
|
|
AdminOperations: RateLimitConfig{
|
|
Requests: 100,
|
|
Window: time.Minute,
|
|
Enabled: true,
|
|
},
|
|
PublicAccess: RateLimitConfig{
|
|
Requests: 20,
|
|
Window: time.Minute,
|
|
Enabled: true,
|
|
},
|
|
}
|
|
}
|
|
|
|
// NewRateLimiter creates a new rate limiter with default settings
|
|
func NewRateLimiter() *RateLimiter {
|
|
rl := &RateLimiter{
|
|
entries: sync.Map{},
|
|
}
|
|
|
|
// Load default settings
|
|
defaults := DefaultRateLimitSettings()
|
|
rl.UpdateSettings(defaults)
|
|
|
|
return rl
|
|
}
|
|
|
|
// UpdateSettings updates rate limit configurations
|
|
func (rl *RateLimiter) UpdateSettings(settings RateLimitSettings) {
|
|
rl.mutex.Lock()
|
|
defer rl.mutex.Unlock()
|
|
|
|
rl.configs = map[string]RateLimitConfig{
|
|
"agent_registration": settings.AgentRegistration,
|
|
"agent_checkin": settings.AgentCheckIn,
|
|
"agent_reports": settings.AgentReports,
|
|
"admin_token_gen": settings.AdminTokenGen,
|
|
"admin_operations": settings.AdminOperations,
|
|
"public_access": settings.PublicAccess,
|
|
}
|
|
}
|
|
|
|
// GetSettings returns current rate limit settings
|
|
func (rl *RateLimiter) GetSettings() RateLimitSettings {
|
|
rl.mutex.RLock()
|
|
defer rl.mutex.RUnlock()
|
|
|
|
return RateLimitSettings{
|
|
AgentRegistration: rl.configs["agent_registration"],
|
|
AgentCheckIn: rl.configs["agent_checkin"],
|
|
AgentReports: rl.configs["agent_reports"],
|
|
AdminTokenGen: rl.configs["admin_token_gen"],
|
|
AdminOperations: rl.configs["admin_operations"],
|
|
PublicAccess: rl.configs["public_access"],
|
|
}
|
|
}
|
|
|
|
// RateLimit creates middleware for a specific rate limit type
|
|
func (rl *RateLimiter) RateLimit(limitType string, keyFunc func(*gin.Context) string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
rl.mutex.RLock()
|
|
config, exists := rl.configs[limitType]
|
|
rl.mutex.RUnlock()
|
|
|
|
if !exists || !config.Enabled {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
key := keyFunc(c)
|
|
if key == "" {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// Namespace the key by limit type to prevent different endpoints from sharing counters
|
|
namespacedKey := limitType + ":" + key
|
|
|
|
// Check rate limit
|
|
allowed, resetTime := rl.checkRateLimit(namespacedKey, config)
|
|
if !allowed {
|
|
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
|
|
c.Header("X-RateLimit-Remaining", "0")
|
|
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
|
|
c.Header("Retry-After", fmt.Sprintf("%d", int(resetTime.Sub(time.Now()).Seconds())))
|
|
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
|
"error": "Rate limit exceeded",
|
|
"limit": config.Requests,
|
|
"window": config.Window.String(),
|
|
"reset_time": resetTime,
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Add rate limit headers
|
|
remaining := rl.getRemainingRequests(namespacedKey, config)
|
|
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", config.Requests))
|
|
c.Header("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
|
c.Header("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(config.Window).Unix()))
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// checkRateLimit checks if the request is allowed
|
|
func (rl *RateLimiter) checkRateLimit(key string, config RateLimitConfig) (bool, time.Time) {
|
|
now := time.Now()
|
|
|
|
// Get or create entry
|
|
entryInterface, _ := rl.entries.LoadOrStore(key, &RateLimitEntry{
|
|
Requests: []time.Time{},
|
|
})
|
|
entry := entryInterface.(*RateLimitEntry)
|
|
|
|
entry.mutex.Lock()
|
|
defer entry.mutex.Unlock()
|
|
|
|
// Clean old requests outside the window
|
|
cutoff := now.Add(-config.Window)
|
|
validRequests := make([]time.Time, 0)
|
|
for _, reqTime := range entry.Requests {
|
|
if reqTime.After(cutoff) {
|
|
validRequests = append(validRequests, reqTime)
|
|
}
|
|
}
|
|
|
|
// Check if under limit
|
|
if len(validRequests) >= config.Requests {
|
|
// Find when the oldest request expires
|
|
oldestRequest := validRequests[0]
|
|
resetTime := oldestRequest.Add(config.Window)
|
|
return false, resetTime
|
|
}
|
|
|
|
// Add current request
|
|
entry.Requests = append(validRequests, now)
|
|
|
|
// Clean up expired entries periodically
|
|
if len(entry.Requests) == 0 {
|
|
rl.entries.Delete(key)
|
|
}
|
|
|
|
return true, time.Time{}
|
|
}
|
|
|
|
// getRemainingRequests calculates remaining requests for the key
|
|
func (rl *RateLimiter) getRemainingRequests(key string, config RateLimitConfig) int {
|
|
entryInterface, ok := rl.entries.Load(key)
|
|
if !ok {
|
|
return config.Requests
|
|
}
|
|
|
|
entry := entryInterface.(*RateLimitEntry)
|
|
entry.mutex.RLock()
|
|
defer entry.mutex.RUnlock()
|
|
|
|
now := time.Now()
|
|
cutoff := now.Add(-config.Window)
|
|
count := 0
|
|
|
|
for _, reqTime := range entry.Requests {
|
|
if reqTime.After(cutoff) {
|
|
count++
|
|
}
|
|
}
|
|
|
|
remaining := config.Requests - count
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
|
|
return remaining
|
|
}
|
|
|
|
// CleanupExpiredEntries removes expired entries to prevent memory leaks
|
|
func (rl *RateLimiter) CleanupExpiredEntries() {
|
|
rl.entries.Range(func(key, value interface{}) bool {
|
|
entry := value.(*RateLimitEntry)
|
|
entry.mutex.Lock()
|
|
|
|
now := time.Now()
|
|
validRequests := make([]time.Time, 0)
|
|
for _, reqTime := range entry.Requests {
|
|
if reqTime.After(now.Add(-time.Hour)) { // Keep requests from last hour
|
|
validRequests = append(validRequests, reqTime)
|
|
}
|
|
}
|
|
|
|
if len(validRequests) == 0 {
|
|
rl.entries.Delete(key)
|
|
} else {
|
|
entry.Requests = validRequests
|
|
}
|
|
|
|
entry.mutex.Unlock()
|
|
return true
|
|
})
|
|
}
|
|
|
|
// Key generation functions
|
|
func KeyByIP(c *gin.Context) string {
|
|
return c.ClientIP()
|
|
}
|
|
|
|
func KeyByAgentID(c *gin.Context) string {
|
|
return c.Param("id")
|
|
}
|
|
|
|
func KeyByUserID(c *gin.Context) string {
|
|
// This would extract user ID from JWT or session
|
|
// For now, use IP as fallback
|
|
return c.ClientIP()
|
|
}
|
|
|
|
func KeyByIPAndPath(c *gin.Context) string {
|
|
return c.ClientIP() + ":" + c.Request.URL.Path
|
|
} |