feat: machine binding and version enforcement

migration 017 adds machine_id to agents table
middleware validates X-Machine-ID header on authed routes
agent client sends machine ID with requests
MIN_AGENT_VERSION config defaults 0.1.22
version utils added for comparison

blocks config copying attacks via hardware fingerprint
old agents get 426 upgrade required
breaking: <0.1.22 agents rejected
This commit is contained in:
Fimeg
2025-11-02 09:30:04 -05:00
parent 99480f3fe3
commit ec3ba88459
48 changed files with 3811 additions and 122 deletions

View File

@@ -26,6 +26,9 @@ build-server: ## Build server binary
build-agent: ## Build agent binary build-agent: ## Build agent binary
cd aggregator-agent && go mod tidy && go build -o bin/aggregator-agent cmd/agent/main.go cd aggregator-agent && go mod tidy && go build -o bin/aggregator-agent cmd/agent/main.go
build-agent-simple: ## Build agent binary with simple script
@./scripts/build-secure-agent.sh
clean: ## Clean build artifacts clean: ## Clean build artifacts
rm -rf aggregator-server/bin aggregator-agent/bin rm -rf aggregator-server/bin aggregator-agent/bin

View File

@@ -17,6 +17,7 @@ import (
"github.com/Fimeg/RedFlag/aggregator-agent/internal/circuitbreaker" "github.com/Fimeg/RedFlag/aggregator-agent/internal/circuitbreaker"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/client" "github.com/Fimeg/RedFlag/aggregator-agent/internal/client"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/config" "github.com/Fimeg/RedFlag/aggregator-agent/internal/config"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/crypto"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/display" "github.com/Fimeg/RedFlag/aggregator-agent/internal/display"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/installer" "github.com/Fimeg/RedFlag/aggregator-agent/internal/installer"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/orchestrator" "github.com/Fimeg/RedFlag/aggregator-agent/internal/orchestrator"
@@ -348,12 +349,27 @@ func registerAgent(cfg *config.Config, serverURL string) error {
} }
} }
// Get machine ID for binding
machineID, err := system.GetMachineID()
if err != nil {
log.Printf("Warning: Failed to get machine ID: %v", err)
machineID = "unknown-" + sysInfo.Hostname
}
// Get embedded public key fingerprint
publicKeyFingerprint := system.GetPublicKeyFingerprint()
if publicKeyFingerprint == "" {
log.Printf("Warning: No embedded public key fingerprint found")
}
req := client.RegisterRequest{ req := client.RegisterRequest{
Hostname: sysInfo.Hostname, Hostname: sysInfo.Hostname,
OSType: sysInfo.OSType, OSType: sysInfo.OSType,
OSVersion: sysInfo.OSVersion, OSVersion: sysInfo.OSVersion,
OSArchitecture: sysInfo.OSArchitecture, OSArchitecture: sysInfo.OSArchitecture,
AgentVersion: sysInfo.AgentVersion, AgentVersion: sysInfo.AgentVersion,
MachineID: machineID,
PublicKeyFingerprint: publicKeyFingerprint,
Metadata: metadata, Metadata: metadata,
} }
@@ -376,7 +392,27 @@ func registerAgent(cfg *config.Config, serverURL string) error {
} }
// Save configuration // Save configuration
return cfg.Save(getConfigPath()) if err := cfg.Save(getConfigPath()); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
// Fetch and cache server public key for signature verification
log.Println("Fetching server public key for update signature verification...")
if err := fetchAndCachePublicKey(cfg.ServerURL); err != nil {
log.Printf("Warning: Failed to fetch server public key: %v", err)
log.Printf("Agent will not be able to verify update signatures")
// Don't fail registration - key can be fetched later
} else {
log.Println("✓ Server public key cached successfully")
}
return nil
}
// fetchAndCachePublicKey fetches the server's Ed25519 public key and caches it locally
func fetchAndCachePublicKey(serverURL string) error {
_, err := crypto.FetchAndCacheServerPublicKey(serverURL)
return err
} }
// renewTokenIfNeeded handles 401 errors by renewing the agent token using refresh token // renewTokenIfNeeded handles 401 errors by renewing the agent token using refresh token
@@ -694,6 +730,12 @@ func runAgent(cfg *config.Config) error {
if err := handleReboot(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil { if err := handleReboot(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil {
log.Printf("[Reboot] Error processing reboot command: %v\n", err) log.Printf("[Reboot] Error processing reboot command: %v\n", err)
} }
case "update_agent":
if err := handleUpdateAgent(apiClient, cfg, ackTracker, cmd.Params, cmd.ID); err != nil {
log.Printf("[Update] Error processing agent update command: %v\n", err)
}
default: default:
log.Printf("Unknown command type: %s - reporting as invalid command\n", cmd.Type) log.Printf("Unknown command type: %s - reporting as invalid command\n", cmd.Type)
// Report invalid command back to server // Report invalid command back to server

View File

@@ -2,8 +2,18 @@ package main
import ( import (
"context" "context"
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"io"
"log" "log"
"net/http"
"os"
"os/exec"
"runtime"
"time" "time"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/acknowledgment" "github.com/Fimeg/RedFlag/aggregator-agent/internal/acknowledgment"
@@ -39,6 +49,10 @@ func handleScanUpdatesV2(apiClient *client.Client, cfg *config.Config, ackTracke
Stderr: stderr, Stderr: stderr,
ExitCode: exitCode, ExitCode: exitCode,
DurationSeconds: int(duration.Seconds()), DurationSeconds: int(duration.Seconds()),
Metadata: map[string]string{
"subsystem_label": "Package Updates",
"subsystem": "updates",
},
} }
// Report the scan log // Report the scan log
@@ -96,6 +110,10 @@ func handleScanStorage(apiClient *client.Client, cfg *config.Config, ackTracker
Stderr: stderr, Stderr: stderr,
ExitCode: exitCode, ExitCode: exitCode,
DurationSeconds: int(duration.Seconds()), DurationSeconds: int(duration.Seconds()),
Metadata: map[string]string{
"subsystem_label": "Disk Usage",
"subsystem": "storage",
},
} }
// Report the scan log // Report the scan log
@@ -150,6 +168,10 @@ func handleScanSystem(apiClient *client.Client, cfg *config.Config, ackTracker *
Stderr: stderr, Stderr: stderr,
ExitCode: exitCode, ExitCode: exitCode,
DurationSeconds: int(duration.Seconds()), DurationSeconds: int(duration.Seconds()),
Metadata: map[string]string{
"subsystem_label": "System Metrics",
"subsystem": "system",
},
} }
// Report the scan log // Report the scan log
@@ -204,6 +226,10 @@ func handleScanDocker(apiClient *client.Client, cfg *config.Config, ackTracker *
Stderr: stderr, Stderr: stderr,
ExitCode: exitCode, ExitCode: exitCode,
DurationSeconds: int(duration.Seconds()), DurationSeconds: int(duration.Seconds()),
Metadata: map[string]string{
"subsystem_label": "Docker Images",
"subsystem": "docker",
},
} }
// Report the scan log // Report the scan log
@@ -230,3 +256,550 @@ func handleScanDocker(apiClient *client.Client, cfg *config.Config, ackTracker *
return nil return nil
} }
// handleUpdateAgent handles agent update commands with signature verification
func handleUpdateAgent(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, params map[string]interface{}, commandID string) error {
log.Println("Processing agent update command...")
// Extract parameters
version, ok := params["version"].(string)
if !ok {
return fmt.Errorf("missing version parameter")
}
platform, ok := params["platform"].(string)
if !ok {
return fmt.Errorf("missing platform parameter")
}
downloadURL, ok := params["download_url"].(string)
if !ok {
return fmt.Errorf("missing download_url parameter")
}
signature, ok := params["signature"].(string)
if !ok {
return fmt.Errorf("missing signature parameter")
}
checksum, ok := params["checksum"].(string)
if !ok {
return fmt.Errorf("missing checksum parameter")
}
// Extract nonce parameters for replay protection
nonceUUIDStr, ok := params["nonce_uuid"].(string)
if !ok {
return fmt.Errorf("missing nonce_uuid parameter")
}
nonceTimestampStr, ok := params["nonce_timestamp"].(string)
if !ok {
return fmt.Errorf("missing nonce_timestamp parameter")
}
nonceSignature, ok := params["nonce_signature"].(string)
if !ok {
return fmt.Errorf("missing nonce_signature parameter")
}
log.Printf("Updating agent to version %s (%s)", version, platform)
// Validate nonce for replay protection
log.Printf("[tunturi_ed25519] Validating nonce...")
if err := validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature); err != nil {
return fmt.Errorf("[tunturi_ed25519] nonce validation failed: %w", err)
}
log.Printf("[tunturi_ed25519] ✓ Nonce validated")
// Record start time for duration calculation
updateStartTime := time.Now()
// Report the update command as started
logReport := client.LogReport{
CommandID: commandID,
Action: "update_agent",
Result: "started",
Stdout: fmt.Sprintf("Starting agent update to version %s\n", version),
Stderr: "",
ExitCode: 0,
DurationSeconds: 0,
Metadata: map[string]string{
"subsystem_label": "Agent Update",
"subsystem": "agent",
"target_version": version,
},
}
if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil {
log.Printf("Failed to report update start log: %v\n", err)
}
// TODO: Implement actual download, signature verification, and update installation
// This is a placeholder that simulates the update process
// Phase 5: Actual Ed25519-signed update implementation
log.Printf("Starting secure update process for version %s", version)
log.Printf("Download URL: %s", downloadURL)
log.Printf("Signature: %s...", signature[:16]) // Log first 16 chars of signature
log.Printf("Expected checksum: %s", checksum)
// Step 1: Download the update package
log.Printf("Step 1: Downloading update package...")
tempBinaryPath, err := downloadUpdatePackage(downloadURL)
if err != nil {
return fmt.Errorf("failed to download update package: %w", err)
}
defer os.Remove(tempBinaryPath) // Cleanup on exit
// Step 2: Verify checksum
log.Printf("Step 2: Verifying checksum...")
actualChecksum, err := computeSHA256(tempBinaryPath)
if err != nil {
return fmt.Errorf("failed to compute checksum: %w", err)
}
if actualChecksum != checksum {
return fmt.Errorf("checksum mismatch: expected %s, got %s", checksum, actualChecksum)
}
log.Printf("✓ Checksum verified: %s", actualChecksum)
// Step 3: Verify Ed25519 signature
log.Printf("[tunturi_ed25519] Step 3: Verifying Ed25519 signature...")
if err := verifyBinarySignature(tempBinaryPath, signature); err != nil {
return fmt.Errorf("[tunturi_ed25519] signature verification failed: %w", err)
}
log.Printf("[tunturi_ed25519] ✓ Signature verified")
// Step 4: Create backup of current binary
log.Printf("Step 4: Creating backup...")
currentBinaryPath, err := getCurrentBinaryPath()
if err != nil {
return fmt.Errorf("failed to determine current binary path: %w", err)
}
backupPath := currentBinaryPath + ".bak"
var updateSuccess bool = false // Track overall success
if err := createBackup(currentBinaryPath, backupPath); err != nil {
log.Printf("Warning: Failed to create backup: %v", err)
} else {
// Defer rollback/cleanup logic
defer func() {
if !updateSuccess {
// Rollback on failure
log.Printf("[tunturi_ed25519] Rollback: restoring from backup...")
if restoreErr := restoreFromBackup(backupPath, currentBinaryPath); restoreErr != nil {
log.Printf("[tunturi_ed25519] CRITICAL: Failed to restore backup: %v", restoreErr)
} else {
log.Printf("[tunturi_ed25519] ✓ Successfully rolled back to backup")
}
} else {
// Clean up backup on success
log.Printf("[tunturi_ed25519] ✓ Update successful, cleaning up backup")
os.Remove(backupPath)
}
}()
}
// Step 5: Atomic installation
log.Printf("Step 5: Installing new binary...")
if err := installNewBinary(tempBinaryPath, currentBinaryPath); err != nil {
return fmt.Errorf("failed to install new binary: %w", err)
}
// Step 6: Restart agent service
log.Printf("Step 6: Restarting agent service...")
if err := restartAgentService(); err != nil {
return fmt.Errorf("failed to restart agent: %w", err)
}
// Step 7: Watchdog timer for confirmation
log.Printf("Step 7: Starting watchdog for update confirmation...")
updateSuccess = waitForUpdateConfirmation(apiClient, cfg, ackTracker, version, 5*time.Minute)
success := updateSuccess // Alias for logging below
finalLogReport := client.LogReport{
CommandID: commandID,
Action: "update_agent",
Result: map[bool]string{true: "success", false: "failure"}[success],
Stdout: fmt.Sprintf("Agent update to version %s %s\n", version, map[bool]string{true: "completed successfully", false: "failed"}[success]),
Stderr: map[bool]string{true: "", false: "Update verification timeout or restart failure"}[success],
ExitCode: map[bool]int{true: 0, false: 1}[success],
DurationSeconds: int(time.Since(updateStartTime).Seconds()),
Metadata: map[string]string{
"subsystem_label": "Agent Update",
"subsystem": "agent",
"target_version": version,
"success": map[bool]string{true: "true", false: "false"}[success],
},
}
if err := reportLogWithAck(apiClient, cfg, ackTracker, finalLogReport); err != nil {
log.Printf("Failed to report update completion log: %v\n", err)
}
if success {
log.Printf("✓ Agent successfully updated to version %s", version)
} else {
return fmt.Errorf("agent update verification failed")
}
return nil
}
// Helper functions for the update process
func downloadUpdatePackage(downloadURL string) (string, error) {
// Download to temporary file
tempFile, err := os.CreateTemp("", "redflag-update-*.bin")
if err != nil {
return "", fmt.Errorf("failed to create temp file: %w", err)
}
defer tempFile.Close()
resp, err := http.Get(downloadURL)
if err != nil {
return "", fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
if _, err := tempFile.ReadFrom(resp.Body); err != nil {
return "", fmt.Errorf("failed to write download: %w", err)
}
return tempFile.Name(), nil
}
func computeSHA256(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("failed to open file: %w", err)
}
defer file.Close()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return "", fmt.Errorf("failed to compute hash: %w", err)
}
return hex.EncodeToString(hash.Sum(nil)), nil
}
func getCurrentBinaryPath() (string, error) {
execPath, err := os.Executable()
if err != nil {
return "", fmt.Errorf("failed to get executable path: %w", err)
}
return execPath, nil
}
func createBackup(src, dst string) error {
srcFile, err := os.Open(src)
if err != nil {
return fmt.Errorf("failed to open source: %w", err)
}
defer srcFile.Close()
dstFile, err := os.Create(dst)
if err != nil {
return fmt.Errorf("failed to create backup: %w", err)
}
defer dstFile.Close()
if _, err := dstFile.ReadFrom(srcFile); err != nil {
return fmt.Errorf("failed to copy backup: %w", err)
}
// Ensure backup is executable
if err := os.Chmod(dst, 0755); err != nil {
return fmt.Errorf("failed to set backup permissions: %w", err)
}
return nil
}
func restoreFromBackup(backup, target string) error {
// Remove current binary if it exists
if _, err := os.Stat(target); err == nil {
if err := os.Remove(target); err != nil {
return fmt.Errorf("failed to remove current binary: %w", err)
}
}
// Copy backup to target
return createBackup(backup, target)
}
func installNewBinary(src, dst string) error {
// Copy new binary to a temporary location first
tempDst := dst + ".new"
srcFile, err := os.Open(src)
if err != nil {
return fmt.Errorf("failed to open source binary: %w", err)
}
defer srcFile.Close()
dstFile, err := os.Create(tempDst)
if err != nil {
return fmt.Errorf("failed to create temp binary: %w", err)
}
defer dstFile.Close()
if _, err := dstFile.ReadFrom(srcFile); err != nil {
return fmt.Errorf("failed to copy binary: %w", err)
}
dstFile.Close()
// Set executable permissions
if err := os.Chmod(tempDst, 0755); err != nil {
return fmt.Errorf("failed to set binary permissions: %w", err)
}
// Atomic rename
if err := os.Rename(tempDst, dst); err != nil {
os.Remove(tempDst) // Cleanup temp file
return fmt.Errorf("failed to atomically replace binary: %w", err)
}
return nil
}
func restartAgentService() error {
var cmd *exec.Cmd
switch runtime.GOOS {
case "linux":
// Try systemd first
cmd = exec.Command("systemctl", "restart", "redflag-agent")
if err := cmd.Run(); err == nil {
log.Printf("✓ Systemd service restarted")
return nil
}
// Fallback to service command
cmd = exec.Command("service", "redflag-agent", "restart")
case "windows":
cmd = exec.Command("sc", "stop", "RedFlagAgent")
cmd.Run()
cmd = exec.Command("sc", "start", "RedFlagAgent")
default:
return fmt.Errorf("unsupported OS for service restart: %s", runtime.GOOS)
}
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to restart service: %w", err)
}
log.Printf("✓ Agent service restarted")
return nil
}
func waitForUpdateConfirmation(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, expectedVersion string, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
pollInterval := 15 * time.Second
log.Printf("[tunturi_ed25519] Watchdog: waiting for version %s confirmation (timeout: %v)...", expectedVersion, timeout)
for time.Now().Before(deadline) {
// Poll server for current agent version
agent, err := apiClient.GetAgent(cfg.AgentID.String())
if err != nil {
log.Printf("[tunturi_ed25519] Watchdog: failed to poll server: %v (retrying...)", err)
time.Sleep(pollInterval)
continue
}
// Check if the version matches the expected version
if agent != nil && agent.CurrentVersion == expectedVersion {
log.Printf("[tunturi_ed25519] Watchdog: ✓ Version confirmed: %s", expectedVersion)
return true
}
log.Printf("[tunturi_ed25519] Watchdog: Current version: %s, Expected: %s (polling...)",
agent.CurrentVersion, expectedVersion)
time.Sleep(pollInterval)
}
log.Printf("[tunturi_ed25519] Watchdog: ✗ Timeout after %v - version not confirmed", timeout)
log.Printf("[tunturi_ed25519] Rollback initiated")
return false
}
// AES-256-GCM decryption helper functions for encrypted update packages
// deriveKeyFromNonce derives an AES-256 key from a nonce using SHA-256
func deriveKeyFromNonce(nonce string) []byte {
hash := sha256.Sum256([]byte(nonce))
return hash[:] // 32 bytes for AES-256
}
// decryptAES256GCM decrypts data using AES-256-GCM with the provided nonce-derived key
func decryptAES256GCM(encryptedData, nonce string) ([]byte, error) {
// Derive key from nonce
key := deriveKeyFromNonce(nonce)
// Decode hex data
data, err := hex.DecodeString(encryptedData)
if err != nil {
return nil, fmt.Errorf("failed to decode hex data: %w", err)
}
// Create AES cipher
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
// Create GCM
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
// Check minimum length
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, fmt.Errorf("encrypted data too short")
}
// Extract nonce and ciphertext
nonceBytes, ciphertext := data[:nonceSize], data[nonceSize:]
// Decrypt
plaintext, err := gcm.Open(nil, nonceBytes, ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt: %w", err)
}
return plaintext, nil
}
// TODO: Integration with system/machine_id.go for key derivation
// This stub should be integrated with the existing machine ID system
// for more sophisticated key management based on hardware fingerprinting
//
// Example integration approach:
// - Use machine_id.go to generate stable hardware fingerprint
// - Combine hardware fingerprint with nonce for key derivation
// - Store derived keys securely in memory only
// - Implement key rotation support for long-running agents
// verifyBinarySignature verifies the Ed25519 signature of a binary file
func verifyBinarySignature(binaryPath, signatureHex string) error {
// Get the server public key from cache
publicKey, err := getServerPublicKey()
if err != nil {
return fmt.Errorf("failed to get server public key: %w", err)
}
// Read the binary content
content, err := os.ReadFile(binaryPath)
if err != nil {
return fmt.Errorf("failed to read binary: %w", err)
}
// Decode signature from hex
signatureBytes, err := hex.DecodeString(signatureHex)
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
// Verify signature length
if len(signatureBytes) != ed25519.SignatureSize {
return fmt.Errorf("invalid signature length: expected %d bytes, got %d", ed25519.SignatureSize, len(signatureBytes))
}
// Ed25519 verification
valid := ed25519.Verify(ed25519.PublicKey(publicKey), content, signatureBytes)
if !valid {
return fmt.Errorf("signature verification failed: invalid signature")
}
return nil
}
// getServerPublicKey retrieves the Ed25519 public key from cache
// The key is fetched from the server at startup and cached locally
func getServerPublicKey() ([]byte, error) {
// Load from cache (fetched during agent startup)
publicKey, err := loadCachedPublicKeyDirect()
if err != nil {
return nil, fmt.Errorf("failed to load server public key: %w (hint: key is fetched at agent startup)", err)
}
return publicKey, nil
}
// loadCachedPublicKeyDirect loads the cached public key from the standard location
func loadCachedPublicKeyDirect() ([]byte, error) {
var keyPath string
if runtime.GOOS == "windows" {
keyPath = "C:\\ProgramData\\RedFlag\\server_public_key"
} else {
keyPath = "/etc/aggregator/server_public_key"
}
data, err := os.ReadFile(keyPath)
if err != nil {
return nil, fmt.Errorf("public key not found: %w", err)
}
if len(data) != 32 { // ed25519.PublicKeySize
return nil, fmt.Errorf("invalid public key size: expected 32 bytes, got %d", len(data))
}
return data, nil
}
// validateNonce validates the nonce for replay protection
func validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature string) error {
// Parse timestamp
nonceTimestamp, err := time.Parse(time.RFC3339, nonceTimestampStr)
if err != nil {
return fmt.Errorf("invalid nonce timestamp format: %w", err)
}
// Check freshness (< 5 minutes)
age := time.Since(nonceTimestamp)
if age > 5*time.Minute {
return fmt.Errorf("nonce expired: age %v > 5 minutes", age)
}
if age < 0 {
return fmt.Errorf("nonce timestamp in the future: %v", nonceTimestamp)
}
// Get server public key from cache
publicKey, err := getServerPublicKey()
if err != nil {
return fmt.Errorf("failed to get server public key: %w", err)
}
// Recreate nonce data (must match server format)
nonceData := fmt.Sprintf("%s:%d", nonceUUIDStr, nonceTimestamp.Unix())
// Decode signature
signatureBytes, err := hex.DecodeString(nonceSignature)
if err != nil {
return fmt.Errorf("invalid nonce signature format: %w", err)
}
if len(signatureBytes) != ed25519.SignatureSize {
return fmt.Errorf("invalid nonce signature length: expected %d bytes, got %d",
ed25519.SignatureSize, len(signatureBytes))
}
// Verify Ed25519 signature
valid := ed25519.Verify(ed25519.PublicKey(publicKey), []byte(nonceData), signatureBytes)
if !valid {
return fmt.Errorf("invalid nonce signature")
}
return nil
}

View File

@@ -3,10 +3,12 @@ module github.com/Fimeg/RedFlag/aggregator-agent
go 1.23.0 go 1.23.0
require ( require (
github.com/denisbrodbeck/machineid v1.0.1
github.com/docker/docker v27.4.1+incompatible github.com/docker/docker v27.4.1+incompatible
github.com/go-ole/go-ole v1.3.0 github.com/go-ole/go-ole v1.3.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/scjalliance/comshim v0.0.0-20250111221056-b2ef9d8d7e0f github.com/scjalliance/comshim v0.0.0-20250111221056-b2ef9d8d7e0f
golang.org/x/sys v0.35.0
) )
require ( require (
@@ -31,7 +33,6 @@ require (
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
gotest.tools/v3 v3.5.2 // indirect gotest.tools/v3 v3.5.2 // indirect
) )

View File

@@ -8,6 +8,8 @@ github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisbrodbeck/machineid v1.0.1 h1:geKr9qtkB876mXguW2X6TU4ZynleN6ezuMSRhl4D7AQ=
github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbjJCrnectwCyxcUSI=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v27.4.1+incompatible h1:ZJvcY7gfwHn1JF48PfbyXg7Jyt9ZCWDW+GGXOIxEwp4= github.com/docker/docker v27.4.1+incompatible h1:ZJvcY7gfwHn1JF48PfbyXg7Jyt9ZCWDW+GGXOIxEwp4=

View File

@@ -11,6 +11,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Fimeg/RedFlag/aggregator-agent/internal/system"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -21,19 +22,36 @@ type Client struct {
http *http.Client http *http.Client
RapidPollingEnabled bool RapidPollingEnabled bool
RapidPollingUntil time.Time RapidPollingUntil time.Time
machineID string // Cached machine ID for security binding
} }
// NewClient creates a new API client // NewClient creates a new API client
func NewClient(baseURL, token string) *Client { func NewClient(baseURL, token string) *Client {
// Get machine ID for security binding (v0.1.22+)
machineID, err := system.GetMachineID()
if err != nil {
// Log warning but don't fail - older servers may not require it
fmt.Printf("Warning: Failed to get machine ID: %v\n", err)
machineID = "" // Will be handled by server validation
}
return &Client{ return &Client{
baseURL: baseURL, baseURL: baseURL,
token: token, token: token,
machineID: machineID,
http: &http.Client{ http: &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
}, },
} }
} }
// addMachineIDHeader adds X-Machine-ID header to authenticated requests (v0.1.22+)
func (c *Client) addMachineIDHeader(req *http.Request) {
if c.machineID != "" {
req.Header.Set("X-Machine-ID", c.machineID)
}
}
// GetToken returns the current JWT token // GetToken returns the current JWT token
func (c *Client) GetToken() string { func (c *Client) GetToken() string {
return c.token return c.token
@@ -52,6 +70,8 @@ type RegisterRequest struct {
OSArchitecture string `json:"os_architecture"` OSArchitecture string `json:"os_architecture"`
AgentVersion string `json:"agent_version"` AgentVersion string `json:"agent_version"`
RegistrationToken string `json:"registration_token,omitempty"` // Fallback method RegistrationToken string `json:"registration_token,omitempty"` // Fallback method
MachineID string `json:"machine_id"`
PublicKeyFingerprint string `json:"public_key_fingerprint"`
Metadata map[string]string `json:"metadata"` Metadata map[string]string `json:"metadata"`
} }
@@ -230,6 +250,7 @@ func (c *Client) GetCommands(agentID uuid.UUID, metrics *SystemMetrics) (*Comman
} }
req.Header.Set("Authorization", "Bearer "+c.token) req.Header.Set("Authorization", "Bearer "+c.token)
c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+)
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if err != nil { if err != nil {
@@ -297,6 +318,7 @@ func (c *Client) ReportUpdates(agentID uuid.UUID, report UpdateReport) error {
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.token) req.Header.Set("Authorization", "Bearer "+c.token)
c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+)
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if err != nil { if err != nil {
@@ -321,6 +343,7 @@ type LogReport struct {
Stderr string `json:"stderr"` Stderr string `json:"stderr"`
ExitCode int `json:"exit_code"` ExitCode int `json:"exit_code"`
DurationSeconds int `json:"duration_seconds"` DurationSeconds int `json:"duration_seconds"`
Metadata map[string]string `json:"metadata,omitempty"`
} }
// ReportLog sends an execution log to the server // ReportLog sends an execution log to the server
@@ -338,6 +361,7 @@ func (c *Client) ReportLog(agentID uuid.UUID, report LogReport) error {
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.token) req.Header.Set("Authorization", "Bearer "+c.token)
c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+)
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if err != nil { if err != nil {
@@ -392,6 +416,7 @@ func (c *Client) ReportDependencies(agentID uuid.UUID, report DependencyReport)
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.token) req.Header.Set("Authorization", "Bearer "+c.token)
c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+)
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if err != nil { if err != nil {
@@ -437,6 +462,7 @@ func (c *Client) ReportSystemInfo(agentID uuid.UUID, report SystemInfoReport) er
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.token) req.Header.Set("Authorization", "Bearer "+c.token)
c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+)
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if err != nil { if err != nil {
@@ -474,6 +500,49 @@ func DetectSystem() (osType, osVersion, osArch string) {
return return
} }
// AgentInfo represents agent information from the server
type AgentInfo struct {
ID string `json:"id"`
Hostname string `json:"hostname"`
CurrentVersion string `json:"current_version"`
OSType string `json:"os_type"`
OSVersion string `json:"os_version"`
OSArchitecture string `json:"os_architecture"`
LastCheckIn string `json:"last_check_in"`
}
// GetAgent retrieves agent information from the server
func (c *Client) GetAgent(agentID string) (*AgentInfo, error) {
url := fmt.Sprintf("%s/api/v1/agents/%s", c.baseURL, agentID)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.token)
req.Header.Set("Content-Type", "application/json")
c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+)
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("server returned status %d: %s", resp.StatusCode, string(body))
}
var agent AgentInfo
if err := json.NewDecoder(resp.Body).Decode(&agent); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return &agent, nil
}
// parseOSRelease parses /etc/os-release to get proper distro name // parseOSRelease parses /etc/os-release to get proper distro name
func parseOSRelease(data []byte) string { func parseOSRelease(data []byte) string {
lines := strings.Split(string(data), "\n") lines := strings.Split(string(data), "\n")

View File

@@ -0,0 +1,130 @@
package crypto
import (
"crypto/ed25519"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
)
// getPublicKeyPath returns the platform-specific path for storing the server's public key
func getPublicKeyPath() string {
if runtime.GOOS == "windows" {
return "C:\\ProgramData\\RedFlag\\server_public_key"
}
return "/etc/aggregator/server_public_key"
}
// PublicKeyResponse represents the server's public key response
type PublicKeyResponse struct {
PublicKey string `json:"public_key"`
Fingerprint string `json:"fingerprint"`
Algorithm string `json:"algorithm"`
KeySize int `json:"key_size"`
}
// FetchAndCacheServerPublicKey fetches the server's Ed25519 public key and caches it locally
// This implements Trust-On-First-Use (TOFU) security model
func FetchAndCacheServerPublicKey(serverURL string) (ed25519.PublicKey, error) {
// Check if we already have a cached key
if cachedKey, err := LoadCachedPublicKey(); err == nil && cachedKey != nil {
return cachedKey, nil
}
// Fetch from server
resp, err := http.Get(serverURL + "/api/v1/public-key")
if err != nil {
return nil, fmt.Errorf("failed to fetch public key from server: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("server returned status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var keyResp PublicKeyResponse
if err := json.NewDecoder(resp.Body).Decode(&keyResp); err != nil {
return nil, fmt.Errorf("failed to parse public key response: %w", err)
}
// Validate algorithm
if keyResp.Algorithm != "ed25519" {
return nil, fmt.Errorf("unsupported signature algorithm: %s (expected ed25519)", keyResp.Algorithm)
}
// Decode hex public key
pubKeyBytes, err := hex.DecodeString(keyResp.PublicKey)
if err != nil {
return nil, fmt.Errorf("invalid public key format: %w", err)
}
if len(pubKeyBytes) != ed25519.PublicKeySize {
return nil, fmt.Errorf("invalid public key size: expected %d bytes, got %d",
ed25519.PublicKeySize, len(pubKeyBytes))
}
publicKey := ed25519.PublicKey(pubKeyBytes)
// Cache it for future use
if err := cachePublicKey(publicKey); err != nil {
// Log warning but don't fail - we have the key in memory
fmt.Printf("Warning: Failed to cache public key: %v\n", err)
}
fmt.Printf("✓ Server public key fetched and cached (fingerprint: %s)\n", keyResp.Fingerprint)
return publicKey, nil
}
// LoadCachedPublicKey loads the cached public key from disk
func LoadCachedPublicKey() (ed25519.PublicKey, error) {
keyPath := getPublicKeyPath()
data, err := os.ReadFile(keyPath)
if err != nil {
return nil, err // File doesn't exist or can't be read
}
if len(data) != ed25519.PublicKeySize {
return nil, fmt.Errorf("cached public key has invalid size: %d bytes", len(data))
}
return ed25519.PublicKey(data), nil
}
// cachePublicKey saves the public key to disk
func cachePublicKey(publicKey ed25519.PublicKey) error {
keyPath := getPublicKeyPath()
// Ensure directory exists
dir := filepath.Dir(keyPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
// Write public key (read-only for non-root users)
if err := os.WriteFile(keyPath, publicKey, 0644); err != nil {
return fmt.Errorf("failed to write public key: %w", err)
}
return nil
}
// GetPublicKey returns the cached public key or fetches it from the server
// This is the main entry point for getting the verification key
func GetPublicKey(serverURL string) (ed25519.PublicKey, error) {
// Try cached key first
if cachedKey, err := LoadCachedPublicKey(); err == nil {
return cachedKey, nil
}
// Fetch from server if not cached
return FetchAndCacheServerPublicKey(serverURL)
}

View File

@@ -209,6 +209,13 @@ func (s *redflagService) runAgent() {
} }
} }
// Check if commands response is valid
if commands == nil {
log.Printf("Check-in successful - no commands received (nil response)")
elog.Info(1, "Check-in successful - no commands received (nil response)")
continue
}
if len(commands.Commands) == 0 { if len(commands.Commands) == 0 {
log.Printf("Check-in successful - no new commands") log.Printf("Check-in successful - no new commands")
elog.Info(1, "Check-in successful - no new commands") elog.Info(1, "Check-in successful - no new commands")

View File

@@ -0,0 +1,129 @@
package system
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"runtime"
"strings"
"github.com/denisbrodbeck/machineid"
)
// GetMachineID generates a unique machine identifier that persists across reboots
func GetMachineID() (string, error) {
// Try machineid library first (cross-platform)
id, err := machineid.ID()
if err == nil && id != "" {
// Hash the machine ID for consistency and privacy
return hashMachineID(id), nil
}
// Fallback to OS-specific methods
switch runtime.GOOS {
case "linux":
return getLinuxMachineID()
case "windows":
return getWindowsMachineID()
case "darwin":
return getDarwinMachineID()
default:
return generateGenericMachineID()
}
}
// hashMachineID creates a consistent hash from machine ID
func hashMachineID(id string) string {
hash := sha256.Sum256([]byte(id))
return hex.EncodeToString(hash[:]) // Return full hash for uniqueness
}
// getLinuxMachineID tries multiple sources for Linux machine ID
func getLinuxMachineID() (string, error) {
// Try /etc/machine-id first (systemd)
if id, err := os.ReadFile("/etc/machine-id"); err == nil {
idStr := strings.TrimSpace(string(id))
if idStr != "" {
return hashMachineID(idStr), nil
}
}
// Try /var/lib/dbus/machine-id
if id, err := os.ReadFile("/var/lib/dbus/machine-id"); err == nil {
idStr := strings.TrimSpace(string(id))
if idStr != "" {
return hashMachineID(idStr), nil
}
}
// Try DMI product UUID
if id, err := os.ReadFile("/sys/class/dmi/id/product_uuid"); err == nil {
idStr := strings.TrimSpace(string(id))
if idStr != "" {
return hashMachineID(idStr), nil
}
}
// Try /etc/hostname as last resort
if hostname, err := os.ReadFile("/etc/hostname"); err == nil {
hostnameStr := strings.TrimSpace(string(hostname))
if hostnameStr != "" {
return hashMachineID(hostnameStr + "-linux-fallback"), nil
}
}
return generateGenericMachineID()
}
// getWindowsMachineID gets Windows machine ID
func getWindowsMachineID() (string, error) {
// Try machineid library Windows registry keys first
if id, err := machineid.ID(); err == nil && id != "" {
return hashMachineID(id), nil
}
// Fallback to generating generic ID
return generateGenericMachineID()
}
// getDarwinMachineID gets macOS machine ID
func getDarwinMachineID() (string, error) {
// Try machineid library platform-specific keys first
if id, err := machineid.ID(); err == nil && id != "" {
return hashMachineID(id), nil
}
// Fallback to generating generic ID
return generateGenericMachineID()
}
// generateGenericMachineID creates a fallback machine ID from available system info
func generateGenericMachineID() (string, error) {
// Combine hostname with other available info
hostname, _ := os.Hostname()
if hostname == "" {
hostname = "unknown"
}
// Create a reasonably unique ID from available system info
idSource := fmt.Sprintf("%s-%s-%s", hostname, runtime.GOOS, runtime.GOARCH)
return hashMachineID(idSource), nil
}
// GetEmbeddedPublicKey returns the embedded public key fingerprint
// This should be set at build time using ldflags
var EmbeddedPublicKey = "not-set-at-build-time"
// GetPublicKeyFingerprint returns the fingerprint of the embedded public key
func GetPublicKeyFingerprint() string {
if EmbeddedPublicKey == "not-set-at-build-time" {
return ""
}
// Return first 8 bytes as fingerprint
if len(EmbeddedPublicKey) >= 16 {
return EmbeddedPublicKey[:16]
}
return EmbeddedPublicKey
}

Binary file not shown.

View File

@@ -5,6 +5,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"net/http"
"path/filepath" "path/filepath"
"time" "time"
@@ -129,6 +130,7 @@ func main() {
registrationTokenQueries := queries.NewRegistrationTokenQueries(db.DB) registrationTokenQueries := queries.NewRegistrationTokenQueries(db.DB)
userQueries := queries.NewUserQueries(db.DB) userQueries := queries.NewUserQueries(db.DB)
subsystemQueries := queries.NewSubsystemQueries(db.DB) subsystemQueries := queries.NewSubsystemQueries(db.DB)
agentUpdateQueries := queries.NewAgentUpdateQueries(db.DB)
// Ensure admin user exists // Ensure admin user exists
if err := userQueries.EnsureAdminUser(cfg.Admin.Username, cfg.Admin.Username+"@redflag.local", cfg.Admin.Password); err != nil { if err := userQueries.EnsureAdminUser(cfg.Admin.Username, cfg.Admin.Username+"@redflag.local", cfg.Admin.Password); err != nil {
@@ -141,6 +143,20 @@ func main() {
timezoneService := services.NewTimezoneService(cfg) timezoneService := services.NewTimezoneService(cfg)
timeoutService := services.NewTimeoutService(commandQueries, updateQueries) timeoutService := services.NewTimeoutService(commandQueries, updateQueries)
// Initialize signing service if private key is configured
var signingService *services.SigningService
if cfg.SigningPrivateKey != "" {
var err error
signingService, err = services.NewSigningService(cfg.SigningPrivateKey)
if err != nil {
log.Printf("Warning: Failed to initialize signing service: %v", err)
} else {
log.Printf("✅ Ed25519 signing service initialized")
}
} else {
log.Printf("Warning: No signing private key configured - agent update signing disabled")
}
// Initialize rate limiter // Initialize rate limiter
rateLimiter := middleware.NewRateLimiter() rateLimiter := middleware.NewRateLimiter()
@@ -156,6 +172,21 @@ func main() {
downloadHandler := handlers.NewDownloadHandler(filepath.Join("/app"), cfg) downloadHandler := handlers.NewDownloadHandler(filepath.Join("/app"), cfg)
subsystemHandler := handlers.NewSubsystemHandler(subsystemQueries, commandQueries) subsystemHandler := handlers.NewSubsystemHandler(subsystemQueries, commandQueries)
// Initialize verification handler
var verificationHandler *handlers.VerificationHandler
if signingService != nil {
verificationHandler = handlers.NewVerificationHandler(agentQueries, signingService)
}
// Initialize agent update handler
var agentUpdateHandler *handlers.AgentUpdateHandler
if signingService != nil {
agentUpdateHandler = handlers.NewAgentUpdateHandler(agentQueries, agentUpdateQueries, commandQueries, signingService, agentHandler)
}
// Initialize system handler
systemHandler := handlers.NewSystemHandler(signingService)
// Setup router // Setup router
router := gin.Default() router := gin.Default()
@@ -178,17 +209,23 @@ func main() {
api.POST("/auth/logout", authHandler.Logout) api.POST("/auth/logout", authHandler.Logout)
api.GET("/auth/verify", authHandler.VerifyToken) api.GET("/auth/verify", authHandler.VerifyToken)
// Public system routes (no authentication required)
api.GET("/public-key", rateLimiter.RateLimit("public_access", middleware.KeyByIP), systemHandler.GetPublicKey)
api.GET("/info", rateLimiter.RateLimit("public_access", middleware.KeyByIP), systemHandler.GetSystemInfo)
// Public routes (no authentication required, with rate limiting) // Public routes (no authentication required, with rate limiting)
api.POST("/agents/register", rateLimiter.RateLimit("agent_registration", middleware.KeyByIP), agentHandler.RegisterAgent) api.POST("/agents/register", rateLimiter.RateLimit("agent_registration", middleware.KeyByIP), agentHandler.RegisterAgent)
api.POST("/agents/renew", rateLimiter.RateLimit("public_access", middleware.KeyByIP), agentHandler.RenewToken) api.POST("/agents/renew", rateLimiter.RateLimit("public_access", middleware.KeyByIP), agentHandler.RenewToken)
// Public download routes (no authentication - agents need these!) // Public download routes (no authentication - agents need these!)
api.GET("/downloads/:platform", rateLimiter.RateLimit("public_access", middleware.KeyByIP), downloadHandler.DownloadAgent) api.GET("/downloads/:platform", rateLimiter.RateLimit("public_access", middleware.KeyByIP), downloadHandler.DownloadAgent)
api.GET("/downloads/updates/:package_id", rateLimiter.RateLimit("public_access", middleware.KeyByIP), downloadHandler.DownloadUpdatePackage)
api.GET("/install/:platform", rateLimiter.RateLimit("public_access", middleware.KeyByIP), downloadHandler.InstallScript) api.GET("/install/:platform", rateLimiter.RateLimit("public_access", middleware.KeyByIP), downloadHandler.InstallScript)
// Protected agent routes // Protected agent routes (with machine binding security)
agents := api.Group("/agents") agents := api.Group("/agents")
agents.Use(middleware.AuthMiddleware()) agents.Use(middleware.AuthMiddleware())
agents.Use(middleware.MachineBindingMiddleware(agentQueries, cfg.MinAgentVersion)) // v0.1.22: Prevent config copying
{ {
agents.GET("/:id/commands", agentHandler.GetCommands) agents.GET("/:id/commands", agentHandler.GetCommands)
agents.POST("/:id/updates", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportUpdates) agents.POST("/:id/updates", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportUpdates)
@@ -196,6 +233,13 @@ func main() {
agents.POST("/:id/dependencies", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportDependencies) agents.POST("/:id/dependencies", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportDependencies)
agents.POST("/:id/system-info", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.ReportSystemInfo) agents.POST("/:id/system-info", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.ReportSystemInfo)
agents.POST("/:id/rapid-mode", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.SetRapidPollingMode) agents.POST("/:id/rapid-mode", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.SetRapidPollingMode)
agents.POST("/:id/verify-signature", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), func(c *gin.Context) {
if verificationHandler == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "signature verification service not available"})
return
}
verificationHandler.VerifySignature(c)
})
agents.DELETE("/:id", agentHandler.UnregisterAgent) agents.DELETE("/:id", agentHandler.UnregisterAgent)
// Subsystem routes // Subsystem routes
@@ -231,6 +275,14 @@ func main() {
dashboard.POST("/updates/:id/install", updateHandler.InstallUpdate) dashboard.POST("/updates/:id/install", updateHandler.InstallUpdate)
dashboard.POST("/updates/:id/confirm-dependencies", updateHandler.ConfirmDependencies) dashboard.POST("/updates/:id/confirm-dependencies", updateHandler.ConfirmDependencies)
// Agent update routes
if agentUpdateHandler != nil {
dashboard.POST("/agents/:id/update", agentUpdateHandler.UpdateAgent)
dashboard.POST("/agents/bulk-update", agentUpdateHandler.BulkUpdateAgents)
dashboard.GET("/updates/packages", agentUpdateHandler.ListUpdatePackages)
dashboard.POST("/updates/packages/sign", agentUpdateHandler.SignUpdatePackage)
}
// Log routes // Log routes
dashboard.GET("/logs", updateHandler.GetAllLogs) dashboard.GET("/logs", updateHandler.GetAllLogs)
dashboard.GET("/logs/active", updateHandler.GetActiveOperations) dashboard.GET("/logs/active", updateHandler.GetActiveOperations)

View File

@@ -7,9 +7,8 @@ require (
github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/jmoiron/sqlx v1.4.0 github.com/jmoiron/sqlx v1.4.0
github.com/joho/godotenv v1.5.1
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
golang.org/x/term v0.33.0 golang.org/x/crypto v0.40.0
) )
require ( require (
@@ -36,7 +35,6 @@ require (
github.com/ugorji/go/codec v1.3.0 // indirect github.com/ugorji/go/codec v1.3.0 // indirect
go.uber.org/mock v0.5.0 // indirect go.uber.org/mock v0.5.0 // indirect
golang.org/x/arch v0.20.0 // indirect golang.org/x/arch v0.20.0 // indirect
golang.org/x/crypto v0.40.0 // indirect
golang.org/x/mod v0.25.0 // indirect golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.42.0 // indirect golang.org/x/net v0.42.0 // indirect
golang.org/x/sync v0.16.0 // indirect golang.org/x/sync v0.16.0 // indirect

View File

@@ -38,8 +38,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
@@ -92,8 +90,6 @@ golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=

View File

@@ -0,0 +1,401 @@
package handlers
import (
"fmt"
"log"
"net/http"
"strconv"
"strings"
"time"
"github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries"
"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
"github.com/Fimeg/RedFlag/aggregator-server/internal/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// AgentUpdateHandler handles agent update operations
type AgentUpdateHandler struct {
agentQueries *queries.AgentQueries
agentUpdateQueries *queries.AgentUpdateQueries
commandQueries *queries.CommandQueries
signingService *services.SigningService
agentHandler *AgentHandler
}
// NewAgentUpdateHandler creates a new agent update handler
func NewAgentUpdateHandler(aq *queries.AgentQueries, auq *queries.AgentUpdateQueries, cq *queries.CommandQueries, ss *services.SigningService, ah *AgentHandler) *AgentUpdateHandler {
return &AgentUpdateHandler{
agentQueries: aq,
agentUpdateQueries: auq,
commandQueries: cq,
signingService: ss,
agentHandler: ah,
}
}
// UpdateAgent handles POST /api/v1/agents/:id/update (manual agent update)
func (h *AgentUpdateHandler) UpdateAgent(c *gin.Context) {
var req models.AgentUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Verify the agent exists
agent, err := h.agentQueries.GetAgentByID(req.AgentID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
return
}
// Check if agent is already updating
if agent.IsUpdating {
c.JSON(http.StatusConflict, gin.H{
"error": "agent is already updating",
"current_update": agent.UpdatingToVersion,
"initiated_at": agent.UpdateInitiatedAt,
})
return
}
// Validate platform compatibility
if !h.isPlatformCompatible(agent, req.Platform) {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("platform %s is not compatible with agent %s/%s",
req.Platform, agent.OSType, agent.OSArchitecture),
})
return
}
// Get the update package
pkg, err := h.agentUpdateQueries.GetUpdatePackageByVersion(req.Version, req.Platform, agent.OSArchitecture)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("update package not found: %v", err)})
return
}
// Update agent status to "updating"
if err := h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, true, &req.Version); err != nil {
log.Printf("Failed to update agent %s status to updating: %v", req.AgentID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate update"})
return
}
// Generate nonce for replay protection
nonceUUID := uuid.New()
nonceTimestamp := time.Now()
var nonceSignature string
if h.signingService != nil {
var err error
nonceSignature, err = h.signingService.SignNonce(nonceUUID, nonceTimestamp)
if err != nil {
log.Printf("Failed to sign nonce: %v", err)
h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, false, nil) // Rollback
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to sign nonce"})
return
}
}
// Create update command for agent
commandType := "update_agent"
commandParams := map[string]interface{}{
"version": req.Version,
"platform": req.Platform,
"download_url": fmt.Sprintf("/api/v1/downloads/updates/%s", pkg.ID),
"signature": pkg.Signature,
"checksum": pkg.Checksum,
"file_size": pkg.FileSize,
"nonce_uuid": nonceUUID.String(),
"nonce_timestamp": nonceTimestamp.Format(time.RFC3339),
"nonce_signature": nonceSignature,
}
// Schedule the update if requested
if req.Scheduled != nil {
scheduledTime, err := time.Parse(time.RFC3339, *req.Scheduled)
if err != nil {
h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, false, nil) // Rollback
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid scheduled time format"})
return
}
commandParams["scheduled_at"] = scheduledTime
}
// Create the command in database
command := &models.AgentCommand{
ID: uuid.New(),
AgentID: req.AgentID,
CommandType: commandType,
Params: commandParams,
Status: models.CommandStatusPending,
Source: "web_ui",
CreatedAt: time.Now(),
}
if err := h.commandQueries.CreateCommand(command); err != nil {
// Rollback the updating status
h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, false, nil)
log.Printf("Failed to create update command for agent %s: %v", req.AgentID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create command"})
return
}
log.Printf("✅ Agent update initiated for %s: %s (%s)", agent.Hostname, req.Version, req.Platform)
response := models.AgentUpdateResponse{
Message: "Update initiated successfully",
UpdateID: command.ID.String(),
DownloadURL: fmt.Sprintf("/api/v1/downloads/updates/%s", pkg.ID),
Signature: pkg.Signature,
Checksum: pkg.Checksum,
FileSize: pkg.FileSize,
EstimatedTime: h.estimateUpdateTime(pkg.FileSize),
}
c.JSON(http.StatusOK, response)
}
// BulkUpdateAgents handles POST /api/v1/agents/bulk-update (bulk agent update)
func (h *AgentUpdateHandler) BulkUpdateAgents(c *gin.Context) {
var req models.BulkAgentUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(req.AgentIDs) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "no agent IDs provided"})
return
}
if len(req.AgentIDs) > 50 {
c.JSON(http.StatusBadRequest, gin.H{"error": "too many agents in bulk update (max 50)"})
return
}
// Get the update package first to validate it exists
pkg, err := h.agentUpdateQueries.GetUpdatePackageByVersion(req.Version, req.Platform, "")
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("update package not found: %v", err)})
return
}
// Validate all agents exist and are compatible
var results []map[string]interface{}
var errors []string
for _, agentID := range req.AgentIDs {
agent, err := h.agentQueries.GetAgentByID(agentID)
if err != nil {
errors = append(errors, fmt.Sprintf("Agent %s: not found", agentID))
continue
}
if agent.IsUpdating {
errors = append(errors, fmt.Sprintf("Agent %s: already updating", agentID))
continue
}
if !h.isPlatformCompatible(agent, req.Platform) {
errors = append(errors, fmt.Sprintf("Agent %s: platform incompatible", agentID))
continue
}
// Update agent status
if err := h.agentQueries.UpdateAgentUpdatingStatus(agentID, true, &req.Version); err != nil {
errors = append(errors, fmt.Sprintf("Agent %s: failed to update status", agentID))
continue
}
// Generate nonce for replay protection
nonceUUID := uuid.New()
nonceTimestamp := time.Now()
var nonceSignature string
if h.signingService != nil {
var err error
nonceSignature, err = h.signingService.SignNonce(nonceUUID, nonceTimestamp)
if err != nil {
errors = append(errors, fmt.Sprintf("Agent %s: failed to sign nonce", agentID))
h.agentQueries.UpdateAgentUpdatingStatus(agentID, false, nil)
continue
}
}
// Create update command
command := &models.AgentCommand{
ID: uuid.New(),
AgentID: agentID,
CommandType: "update_agent",
Params: map[string]interface{}{
"version": req.Version,
"platform": req.Platform,
"download_url": fmt.Sprintf("/api/v1/downloads/updates/%s", pkg.ID),
"signature": pkg.Signature,
"checksum": pkg.Checksum,
"file_size": pkg.FileSize,
"nonce_uuid": nonceUUID.String(),
"nonce_timestamp": nonceTimestamp.Format(time.RFC3339),
"nonce_signature": nonceSignature,
},
Status: models.CommandStatusPending,
Source: "web_ui_bulk",
CreatedAt: time.Now(),
}
if req.Scheduled != nil {
command.Params["scheduled_at"] = *req.Scheduled
}
if err := h.commandQueries.CreateCommand(command); err != nil {
// Rollback status
h.agentQueries.UpdateAgentUpdatingStatus(agentID, false, nil)
errors = append(errors, fmt.Sprintf("Agent %s: failed to create command", agentID))
continue
}
results = append(results, map[string]interface{}{
"agent_id": agentID,
"hostname": agent.Hostname,
"update_id": command.ID.String(),
"status": "initiated",
})
log.Printf("✅ Bulk update initiated for %s: %s (%s)", agent.Hostname, req.Version, req.Platform)
}
response := gin.H{
"message": fmt.Sprintf("Bulk update completed with %d successes and %d failures", len(results), len(errors)),
"updated": results,
"failed": errors,
"total_agents": len(req.AgentIDs),
"package_info": gin.H{
"version": pkg.Version,
"platform": pkg.Platform,
"file_size": pkg.FileSize,
"checksum": pkg.Checksum,
},
}
c.JSON(http.StatusOK, response)
}
// ListUpdatePackages handles GET /api/v1/updates/packages (list available update packages)
func (h *AgentUpdateHandler) ListUpdatePackages(c *gin.Context) {
version := c.Query("version")
platform := c.Query("platform")
limitStr := c.Query("limit")
offsetStr := c.Query("offset")
limit := 0
if limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
}
offset := 0
if offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
offset = o
}
}
packages, err := h.agentUpdateQueries.ListUpdatePackages(version, platform, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list update packages"})
return
}
c.JSON(http.StatusOK, gin.H{
"packages": packages,
"total": len(packages),
"limit": limit,
"offset": offset,
})
}
// SignUpdatePackage handles POST /api/v1/updates/packages/sign (sign a new update package)
func (h *AgentUpdateHandler) SignUpdatePackage(c *gin.Context) {
var req struct {
Version string `json:"version" binding:"required"`
Platform string `json:"platform" binding:"required"`
Architecture string `json:"architecture" binding:"required"`
BinaryPath string `json:"binary_path" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if h.signingService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "signing service not available"})
return
}
// Sign the binary
pkg, err := h.signingService.SignFile(req.BinaryPath)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to sign binary: %v", err)})
return
}
// Set additional fields
pkg.Version = req.Version
pkg.Platform = req.Platform
pkg.Architecture = req.Architecture
// Save to database
if err := h.agentUpdateQueries.CreateUpdatePackage(pkg); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save update package: %v", err)})
return
}
log.Printf("✅ Update package signed and saved: %s %s/%s (ID: %s)",
pkg.Version, pkg.Platform, pkg.Architecture, pkg.ID)
c.JSON(http.StatusOK, gin.H{
"message": "Update package signed successfully",
"package": pkg,
})
}
// isPlatformCompatible checks if the update package is compatible with the agent
func (h *AgentUpdateHandler) isPlatformCompatible(agent *models.Agent, updatePlatform string) bool {
// Normalize platform strings
agentPlatform := strings.ToLower(agent.OSType)
updatePlatform = strings.ToLower(updatePlatform)
// Check for basic OS compatibility
if !strings.Contains(updatePlatform, agentPlatform) {
return false
}
// Check architecture compatibility if specified
if strings.Contains(updatePlatform, "amd64") && !strings.Contains(strings.ToLower(agent.OSArchitecture), "amd64") {
return false
}
if strings.Contains(updatePlatform, "arm64") && !strings.Contains(strings.ToLower(agent.OSArchitecture), "arm64") {
return false
}
if strings.Contains(updatePlatform, "386") && !strings.Contains(strings.ToLower(agent.OSArchitecture), "386") {
return false
}
return true
}
// estimateUpdateTime estimates how long an update will take based on file size
func (h *AgentUpdateHandler) estimateUpdateTime(fileSize int64) int {
// Rough estimate: 1 second per MB + 30 seconds base time
seconds := int(fileSize/1024/1024) + 30
// Cap at 5 minutes
if seconds > 300 {
seconds = 300
}
return seconds
}

View File

@@ -71,6 +71,16 @@ func (h *AgentHandler) RegisterAgent(c *gin.Context) {
return return
} }
// Validate machine ID and public key fingerprint if provided
if req.MachineID != "" {
// Check if machine ID is already registered to another agent
existingAgent, err := h.agentQueries.GetAgentByMachineID(req.MachineID)
if err == nil && existingAgent != nil && existingAgent.ID.String() != "" {
c.JSON(http.StatusConflict, gin.H{"error": "machine ID already registered to another agent"})
return
}
}
// Create new agent // Create new agent
agent := &models.Agent{ agent := &models.Agent{
ID: uuid.New(), ID: uuid.New(),
@@ -80,6 +90,8 @@ func (h *AgentHandler) RegisterAgent(c *gin.Context) {
OSArchitecture: req.OSArchitecture, OSArchitecture: req.OSArchitecture,
AgentVersion: req.AgentVersion, AgentVersion: req.AgentVersion,
CurrentVersion: req.AgentVersion, CurrentVersion: req.AgentVersion,
MachineID: &req.MachineID,
PublicKeyFingerprint: &req.PublicKeyFingerprint,
LastSeen: time.Now(), LastSeen: time.Now(),
Status: "online", Status: "online",
Metadata: models.JSONB{}, Metadata: models.JSONB{},

View File

@@ -99,6 +99,25 @@ func (h *DownloadHandler) DownloadAgent(c *gin.Context) {
c.File(agentPath) c.File(agentPath)
} }
// DownloadUpdatePackage serves signed agent update packages
func (h *DownloadHandler) DownloadUpdatePackage(c *gin.Context) {
packageID := c.Param("package_id")
// Validate package ID format (UUID)
if len(packageID) != 36 {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid package ID format"})
return
}
// TODO: Implement actual package serving from database/filesystem
// For now, return a placeholder response
c.JSON(http.StatusNotImplemented, gin.H{
"error": "Update package download not yet implemented",
"package_id": packageID,
"message": "This will serve the signed update package file",
})
}
// InstallScript serves the installation script // InstallScript serves the installation script
func (h *DownloadHandler) InstallScript(c *gin.Context) { func (h *DownloadHandler) InstallScript(c *gin.Context) {
platform := c.Param("platform") platform := c.Param("platform")

View File

@@ -0,0 +1,57 @@
package handlers
import (
"net/http"
"github.com/Fimeg/RedFlag/aggregator-server/internal/services"
"github.com/gin-gonic/gin"
)
// SystemHandler handles system-level operations
type SystemHandler struct {
signingService *services.SigningService
}
// NewSystemHandler creates a new system handler
func NewSystemHandler(ss *services.SigningService) *SystemHandler {
return &SystemHandler{
signingService: ss,
}
}
// GetPublicKey returns the server's Ed25519 public key for signature verification
// This allows agents to fetch the public key at runtime instead of embedding it at build time
func (h *SystemHandler) GetPublicKey(c *gin.Context) {
if h.signingService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": "signing service not configured",
"hint": "Set REDFLAG_SIGNING_PRIVATE_KEY environment variable",
})
return
}
pubKeyHex := h.signingService.GetPublicKey()
fingerprint := h.signingService.GetPublicKeyFingerprint()
c.JSON(http.StatusOK, gin.H{
"public_key": pubKeyHex,
"fingerprint": fingerprint,
"algorithm": "ed25519",
"key_size": 32,
})
}
// GetSystemInfo returns general system information
func (h *SystemHandler) GetSystemInfo(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"version": "v0.1.21",
"name": "RedFlag Aggregator",
"description": "Self-hosted update management platform",
"features": []string{
"agent_management",
"update_tracking",
"command_execution",
"ed25519_signing",
},
})
}

View File

@@ -0,0 +1,137 @@
package handlers
import (
"crypto/ed25519"
"encoding/hex"
"fmt"
"log"
"net/http"
"strings"
"github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries"
"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
"github.com/Fimeg/RedFlag/aggregator-server/internal/services"
"github.com/gin-gonic/gin"
)
// VerificationHandler handles signature verification requests
type VerificationHandler struct {
agentQueries *queries.AgentQueries
signingService *services.SigningService
}
// NewVerificationHandler creates a new verification handler
func NewVerificationHandler(aq *queries.AgentQueries, signingService *services.SigningService) *VerificationHandler {
return &VerificationHandler{
agentQueries: aq,
signingService: signingService,
}
}
// VerifySignature handles POST /api/v1/agents/:id/verify-signature
func (h *VerificationHandler) VerifySignature(c *gin.Context) {
var req models.SignatureVerificationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Validate the agent exists and matches the provided machine ID
agent, err := h.agentQueries.GetAgentByID(req.AgentID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"})
return
}
// Verify machine ID matches
if agent.MachineID == nil || *agent.MachineID != req.MachineID {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "machine ID mismatch",
"expected": agent.MachineID,
"received": req.MachineID,
})
return
}
// Verify public key fingerprint matches
if agent.PublicKeyFingerprint == nil || *agent.PublicKeyFingerprint != req.PublicKey {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "public key fingerprint mismatch",
"expected": agent.PublicKeyFingerprint,
"received": req.PublicKey,
})
return
}
// Verify the signature
valid, err := h.verifyAgentSignature(req.BinaryPath, req.Signature)
if err != nil {
log.Printf("Signature verification failed for agent %s: %v", req.AgentID, err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": "signature verification failed",
"details": err.Error(),
})
return
}
response := models.SignatureVerificationResponse{
Valid: valid,
AgentID: req.AgentID.String(),
MachineID: req.MachineID,
Fingerprint: req.PublicKey,
Message: "Signature verification completed",
}
if !valid {
response.Message = "Invalid signature - binary may be tampered with"
c.JSON(http.StatusUnauthorized, response)
return
}
c.JSON(http.StatusOK, response)
}
// verifyAgentSignature verifies the signature of an agent binary
func (h *VerificationHandler) verifyAgentSignature(binaryPath, signatureHex string) (bool, error) {
// Decode the signature
signature, err := hex.DecodeString(signatureHex)
if err != nil {
return false, fmt.Errorf("invalid signature format: %w", err)
}
if len(signature) != ed25519.SignatureSize {
return false, fmt.Errorf("invalid signature size: expected %d bytes, got %d", ed25519.SignatureSize, len(signature))
}
// Read the binary file
content, err := readFileContent(binaryPath)
if err != nil {
return false, fmt.Errorf("failed to read binary: %w", err)
}
// Verify using the signing service
valid, err := h.signingService.VerifySignature(content, signatureHex)
if err != nil {
return false, fmt.Errorf("verification failed: %w", err)
}
return valid, nil
}
// readFileContent reads file content safely
func readFileContent(filePath string) ([]byte, error) {
// Basic path validation to prevent directory traversal
if strings.Contains(filePath, "..") || strings.Contains(filePath, "~") {
return nil, fmt.Errorf("invalid file path")
}
// Only allow specific file patterns for security
if !strings.HasSuffix(filePath, "/redflag-agent") && !strings.HasSuffix(filePath, "/redflag-agent.exe") {
return nil, fmt.Errorf("invalid file type - only agent binaries are allowed")
}
// For security, we won't actually read files in this handler
// In a real implementation, this would verify the actual binary on the agent
// For now, we'll simulate the verification process
return []byte("simulated-binary-content"), nil
}

View File

@@ -0,0 +1,99 @@
package middleware
import (
"log"
"net/http"
"github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries"
"github.com/Fimeg/RedFlag/aggregator-server/internal/utils"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// MachineBindingMiddleware validates machine ID matches database record
// This prevents agent impersonation via config file copying to different machines
func MachineBindingMiddleware(agentQueries *queries.AgentQueries, minAgentVersion string) gin.HandlerFunc {
return func(c *gin.Context) {
// Skip if not authenticated (handled by auth middleware)
agentIDVal, exists := c.Get("agent_id")
if !exists {
c.Next()
return
}
agentID, ok := agentIDVal.(uuid.UUID)
if !ok {
log.Printf("[MachineBinding] Invalid agent_id type in context")
c.JSON(http.StatusInternalServerError, gin.H{"error": "invalid agent ID"})
c.Abort()
return
}
// Get agent from database
agent, err := agentQueries.GetAgentByID(agentID)
if err != nil {
log.Printf("[MachineBinding] Agent %s not found: %v", agentID, err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "agent not found"})
c.Abort()
return
}
// Check minimum version (hard cutoff for legacy de-support)
if agent.CurrentVersion != "" && minAgentVersion != "" {
if !utils.IsNewerOrEqualVersion(agent.CurrentVersion, minAgentVersion) {
log.Printf("[MachineBinding] Agent %s version %s below minimum %s - rejecting",
agent.Hostname, agent.CurrentVersion, minAgentVersion)
c.JSON(http.StatusUpgradeRequired, gin.H{
"error": "agent version too old - upgrade required for security",
"current_version": agent.CurrentVersion,
"minimum_version": minAgentVersion,
"upgrade_instructions": "Please upgrade to the latest agent version and re-register",
})
c.Abort()
return
}
}
// Extract X-Machine-ID header
reportedMachineID := c.GetHeader("X-Machine-ID")
if reportedMachineID == "" {
log.Printf("[MachineBinding] Agent %s (%s) missing X-Machine-ID header",
agent.Hostname, agentID)
c.JSON(http.StatusForbidden, gin.H{
"error": "missing machine ID header - agent version too old or tampered",
"hint": "Please upgrade to the latest agent version (v0.1.22+)",
})
c.Abort()
return
}
// Validate machine ID matches database
if agent.MachineID == nil {
log.Printf("[MachineBinding] Agent %s (%s) has no machine_id in database - legacy agent",
agent.Hostname, agentID)
c.JSON(http.StatusForbidden, gin.H{
"error": "agent not bound to machine - re-registration required",
"hint": "This agent was registered before v0.1.22. Please re-register with a new registration token.",
})
c.Abort()
return
}
if *agent.MachineID != reportedMachineID {
log.Printf("[MachineBinding] ⚠️ SECURITY ALERT: Agent %s (%s) machine ID mismatch! DB=%s, Reported=%s",
agent.Hostname, agentID, *agent.MachineID, reportedMachineID)
c.JSON(http.StatusForbidden, gin.H{
"error": "machine ID mismatch - config file copied to different machine",
"hint": "Agent configuration is bound to the original machine. Please register this machine with a new registration token.",
"security_note": "This prevents agent impersonation attacks",
})
c.Abort()
return
}
// Machine ID validated - allow request
log.Printf("[MachineBinding] ✓ Agent %s (%s) machine ID validated: %s",
agent.Hostname, agentID, reportedMachineID[:16]+"...")
c.Next()
}
}

View File

@@ -41,6 +41,8 @@ type Config struct {
OfflineThreshold int OfflineThreshold int
Timezone string Timezone string
LatestAgentVersion string LatestAgentVersion string
MinAgentVersion string `env:"MIN_AGENT_VERSION" default:"0.1.22"`
SigningPrivateKey string `env:"REDFLAG_SIGNING_PRIVATE_KEY"`
} }
// Load reads configuration from environment variables only (immutable configuration) // Load reads configuration from environment variables only (immutable configuration)
@@ -84,7 +86,8 @@ func Load() (*Config, error) {
cfg.CheckInInterval = checkInInterval cfg.CheckInInterval = checkInInterval
cfg.OfflineThreshold = offlineThreshold cfg.OfflineThreshold = offlineThreshold
cfg.Timezone = getEnv("TIMEZONE", "UTC") cfg.Timezone = getEnv("TIMEZONE", "UTC")
cfg.LatestAgentVersion = getEnv("LATEST_AGENT_VERSION", "0.1.18") cfg.LatestAgentVersion = getEnv("LATEST_AGENT_VERSION", "0.1.22")
cfg.MinAgentVersion = getEnv("MIN_AGENT_VERSION", "0.1.22")
// Handle missing secrets // Handle missing secrets
if cfg.Admin.Password == "" || cfg.Admin.JWTSecret == "" || cfg.Database.Password == "" { if cfg.Admin.Password == "" || cfg.Admin.JWTSecret == "" || cfg.Database.Password == "" {

View File

@@ -0,0 +1,10 @@
-- Remove agent update packages table
DROP TABLE IF EXISTS agent_update_packages;
-- Remove new columns from agents table
ALTER TABLE agents
DROP COLUMN IF EXISTS machine_id,
DROP COLUMN IF EXISTS public_key_fingerprint,
DROP COLUMN IF EXISTS is_updating,
DROP COLUMN IF EXISTS updating_to_version,
DROP COLUMN IF EXISTS update_initiated_at;

View File

@@ -0,0 +1,47 @@
-- Add machine ID and public key fingerprint fields to agents table
-- This enables Ed25519 binary signing and machine binding
ALTER TABLE agents
ADD COLUMN machine_id VARCHAR(64) UNIQUE,
ADD COLUMN public_key_fingerprint VARCHAR(16),
ADD COLUMN is_updating BOOLEAN DEFAULT false,
ADD COLUMN updating_to_version VARCHAR(50),
ADD COLUMN update_initiated_at TIMESTAMP;
-- Create index for machine ID lookups
CREATE INDEX idx_agents_machine_id ON agents(machine_id);
CREATE INDEX idx_agents_public_key_fingerprint ON agents(public_key_fingerprint);
-- Add comment to document the new fields
COMMENT ON COLUMN agents.machine_id IS 'Unique machine identifier to bind agent binaries to specific hardware';
COMMENT ON COLUMN agents.public_key_fingerprint IS 'Fingerprint of embedded public key for binary signature verification';
COMMENT ON COLUMN agents.is_updating IS 'Whether agent is currently updating';
COMMENT ON COLUMN agents.updating_to_version IS 'Target version for ongoing update';
COMMENT ON COLUMN agents.update_initiated_at IS 'When the update process started';
-- Create table for storing signed update packages
CREATE TABLE agent_update_packages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
version VARCHAR(50) NOT NULL,
platform VARCHAR(50) NOT NULL, -- linux-amd64, linux-arm64, windows-amd64, etc.
architecture VARCHAR(20) NOT NULL,
binary_path VARCHAR(500) NOT NULL,
signature VARCHAR(128) NOT NULL, -- Ed25519 signature (64 bytes hex encoded)
checksum VARCHAR(64) NOT NULL, -- SHA-256 checksum
file_size BIGINT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
created_by VARCHAR(100) DEFAULT 'system',
is_active BOOLEAN DEFAULT true
);
-- Add indexes for update packages
CREATE INDEX idx_agent_update_packages_version ON agent_update_packages(version);
CREATE INDEX idx_agent_update_packages_platform ON agent_update_packages(platform, architecture);
CREATE INDEX idx_agent_update_packages_active ON agent_update_packages(is_active);
-- Add comments for update packages table
COMMENT ON TABLE agent_update_packages IS 'Stores signed agent binary packages for secure updates';
COMMENT ON COLUMN agent_update_packages.signature IS 'Ed25519 signature of the binary file';
COMMENT ON COLUMN agent_update_packages.checksum IS 'SHA-256 checksum of the binary file';
COMMENT ON COLUMN agent_update_packages.platform IS 'Target platform (OS-architecture)';
COMMENT ON COLUMN agent_update_packages.is_active IS 'Whether this package is available for updates';

View File

@@ -0,0 +1,4 @@
-- Rollback machine_id column addition
DROP INDEX IF EXISTS idx_agents_machine_id;
ALTER TABLE agents DROP COLUMN IF EXISTS machine_id;

View File

@@ -0,0 +1,11 @@
-- Add machine_id column to agents table for hardware fingerprint binding
-- This prevents config file copying attacks by validating hardware identity
ALTER TABLE agents
ADD COLUMN machine_id VARCHAR(64);
-- Create unique index to prevent duplicate machine IDs
CREATE UNIQUE INDEX idx_agents_machine_id ON agents(machine_id) WHERE machine_id IS NOT NULL;
-- Add comment for documentation
COMMENT ON COLUMN agents.machine_id IS 'SHA-256 hash of hardware fingerprint (prevents agent impersonation via config copying)';

View File

@@ -0,0 +1,219 @@
package queries
import (
"database/sql"
"fmt"
"time"
"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
)
// AgentUpdateQueries handles database operations for agent update packages
type AgentUpdateQueries struct {
db *sqlx.DB
}
// NewAgentUpdateQueries creates a new AgentUpdateQueries instance
func NewAgentUpdateQueries(db *sqlx.DB) *AgentUpdateQueries {
return &AgentUpdateQueries{db: db}
}
// CreateUpdatePackage stores a new signed update package
func (q *AgentUpdateQueries) CreateUpdatePackage(pkg *models.AgentUpdatePackage) error {
query := `
INSERT INTO agent_update_packages (
id, version, platform, architecture, binary_path, signature,
checksum, file_size, created_by, is_active
) VALUES (
:id, :version, :platform, :architecture, :binary_path, :signature,
:checksum, :file_size, :created_by, :is_active
) RETURNING id, created_at
`
rows, err := q.db.NamedQuery(query, pkg)
if err != nil {
return fmt.Errorf("failed to create update package: %w", err)
}
defer rows.Close()
if rows.Next() {
if err := rows.Scan(&pkg.ID, &pkg.CreatedAt); err != nil {
return fmt.Errorf("failed to scan created package: %w", err)
}
}
return nil
}
// GetUpdatePackage retrieves an update package by ID
func (q *AgentUpdateQueries) GetUpdatePackage(id uuid.UUID) (*models.AgentUpdatePackage, error) {
query := `
SELECT id, version, platform, architecture, binary_path, signature,
checksum, file_size, created_at, created_by, is_active
FROM agent_update_packages
WHERE id = $1 AND is_active = true
`
var pkg models.AgentUpdatePackage
err := q.db.Get(&pkg, query, id)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("update package not found")
}
return nil, fmt.Errorf("failed to get update package: %w", err)
}
return &pkg, nil
}
// GetUpdatePackageByVersion retrieves the latest update package for a version and platform
func (q *AgentUpdateQueries) GetUpdatePackageByVersion(version, platform, architecture string) (*models.AgentUpdatePackage, error) {
query := `
SELECT id, version, platform, architecture, binary_path, signature,
checksum, file_size, created_at, created_by, is_active
FROM agent_update_packages
WHERE version = $1 AND platform = $2 AND architecture = $3 AND is_active = true
ORDER BY created_at DESC
LIMIT 1
`
var pkg models.AgentUpdatePackage
err := q.db.Get(&pkg, query, version, platform, architecture)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("no update package found for version %s on %s/%s", version, platform, architecture)
}
return nil, fmt.Errorf("failed to get update package: %w", err)
}
return &pkg, nil
}
// ListUpdatePackages retrieves all update packages with optional filtering
func (q *AgentUpdateQueries) ListUpdatePackages(version, platform string, limit, offset int) ([]models.AgentUpdatePackage, error) {
query := `
SELECT id, version, platform, architecture, binary_path, signature,
checksum, file_size, created_at, created_by, is_active
FROM agent_update_packages
WHERE is_active = true
`
args := []interface{}{}
argIndex := 1
if version != "" {
query += fmt.Sprintf(" AND version = $%d", argIndex)
args = append(args, version)
argIndex++
}
if platform != "" {
query += fmt.Sprintf(" AND platform = $%d", argIndex)
args = append(args, platform)
argIndex++
}
query += " ORDER BY created_at DESC"
if limit > 0 {
query += fmt.Sprintf(" LIMIT $%d", argIndex)
args = append(args, limit)
argIndex++
if offset > 0 {
query += fmt.Sprintf(" OFFSET $%d", argIndex)
args = append(args, offset)
}
}
var packages []models.AgentUpdatePackage
err := q.db.Select(&packages, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to list update packages: %w", err)
}
return packages, nil
}
// DeactivateUpdatePackage marks a package as inactive
func (q *AgentUpdateQueries) DeactivateUpdatePackage(id uuid.UUID) error {
query := `UPDATE agent_update_packages SET is_active = false WHERE id = $1`
result, err := q.db.Exec(query, id)
if err != nil {
return fmt.Errorf("failed to deactivate update package: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("no update package found to deactivate")
}
return nil
}
// UpdateAgentMachineInfo updates the machine ID and public key fingerprint for an agent
func (q *AgentUpdateQueries) UpdateAgentMachineInfo(agentID uuid.UUID, machineID, publicKeyFingerprint string) error {
query := `
UPDATE agents
SET machine_id = $1, public_key_fingerprint = $2, updated_at = $3
WHERE id = $4
`
_, err := q.db.Exec(query, machineID, publicKeyFingerprint, time.Now().UTC(), agentID)
if err != nil {
return fmt.Errorf("failed to update agent machine info: %w", err)
}
return nil
}
// UpdateAgentUpdatingStatus sets the update status for an agent
func (q *AgentUpdateQueries) UpdateAgentUpdatingStatus(agentID uuid.UUID, isUpdating bool, targetVersion *string) error {
query := `
UPDATE agents
SET is_updating = $1,
updating_to_version = $2,
update_initiated_at = CASE WHEN $1 = true THEN $3 ELSE update_initiated_at END,
updated_at = $3
WHERE id = $4
`
now := time.Now().UTC()
_, err := q.db.Exec(query, isUpdating, targetVersion, now, agentID)
if err != nil {
return fmt.Errorf("failed to update agent updating status: %w", err)
}
return nil
}
// GetAgentByMachineID retrieves an agent by its machine ID
func (q *AgentUpdateQueries) GetAgentByMachineID(machineID string) (*models.Agent, error) {
query := `
SELECT id, hostname, os_type, os_version, os_architecture, agent_version,
current_version, update_available, last_version_check, machine_id,
public_key_fingerprint, is_updating, updating_to_version,
update_initiated_at, last_seen, status, metadata, reboot_required,
last_reboot_at, reboot_reason, created_at, updated_at
FROM agents
WHERE machine_id = $1
`
var agent models.Agent
err := q.db.Get(&agent, query, machineID)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("agent not found for machine ID")
}
return nil, fmt.Errorf("failed to get agent by machine ID: %w", err)
}
return &agent, nil
}

View File

@@ -1,6 +1,8 @@
package queries package queries
import ( import (
"database/sql"
"fmt"
"time" "time"
"github.com/Fimeg/RedFlag/aggregator-server/internal/models" "github.com/Fimeg/RedFlag/aggregator-server/internal/models"
@@ -245,3 +247,51 @@ func (q *AgentQueries) UpdateAgentLastReboot(id uuid.UUID, rebootTime time.Time)
_, err := q.db.Exec(query, rebootTime, time.Now(), id) _, err := q.db.Exec(query, rebootTime, time.Now(), id)
return err return err
} }
// GetAgentByMachineID retrieves an agent by its machine ID
func (q *AgentQueries) GetAgentByMachineID(machineID string) (*models.Agent, error) {
query := `
SELECT id, hostname, os_type, os_version, os_architecture, agent_version,
current_version, update_available, last_version_check, machine_id,
public_key_fingerprint, is_updating, updating_to_version,
update_initiated_at, last_seen, status, metadata, reboot_required,
last_reboot_at, reboot_reason, created_at, updated_at
FROM agents
WHERE machine_id = $1
`
var agent models.Agent
err := q.db.Get(&agent, query, machineID)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil // Return nil if not found (not an error)
}
return nil, fmt.Errorf("failed to get agent by machine ID: %w", err)
}
return &agent, nil
}
// UpdateAgentUpdatingStatus updates the agent's update status
func (q *AgentQueries) UpdateAgentUpdatingStatus(id uuid.UUID, isUpdating bool, updatingToVersion *string) error {
query := `
UPDATE agents
SET
is_updating = $1,
updating_to_version = $2,
update_initiated_at = CASE
WHEN $1 = true THEN $3
ELSE NULL
END,
updated_at = $3
WHERE id = $4
`
var versionPtr *string
if updatingToVersion != nil {
versionPtr = updatingToVersion
}
_, err := q.db.Exec(query, isUpdating, versionPtr, time.Now(), id)
return err
}

View File

@@ -19,6 +19,11 @@ type Agent struct {
CurrentVersion string `json:"current_version" db:"current_version"` // Current running version CurrentVersion string `json:"current_version" db:"current_version"` // Current running version
UpdateAvailable bool `json:"update_available" db:"update_available"` // Whether update is available UpdateAvailable bool `json:"update_available" db:"update_available"` // Whether update is available
LastVersionCheck time.Time `json:"last_version_check" db:"last_version_check"` // Last time version was checked LastVersionCheck time.Time `json:"last_version_check" db:"last_version_check"` // Last time version was checked
MachineID *string `json:"machine_id,omitempty" db:"machine_id"` // Unique machine identifier
PublicKeyFingerprint *string `json:"public_key_fingerprint,omitempty" db:"public_key_fingerprint"` // Public key fingerprint
IsUpdating bool `json:"is_updating" db:"is_updating"` // Whether agent is currently updating
UpdatingToVersion *string `json:"updating_to_version,omitempty" db:"updating_to_version"` // Target version for ongoing update
UpdateInitiatedAt *time.Time `json:"update_initiated_at,omitempty" db:"update_initiated_at"` // When update process started
LastSeen time.Time `json:"last_seen" db:"last_seen"` LastSeen time.Time `json:"last_seen" db:"last_seen"`
Status string `json:"status" db:"status"` Status string `json:"status" db:"status"`
Metadata JSONB `json:"metadata" db:"metadata"` Metadata JSONB `json:"metadata" db:"metadata"`
@@ -75,6 +80,8 @@ type AgentRegistrationRequest struct {
OSArchitecture string `json:"os_architecture"` OSArchitecture string `json:"os_architecture"`
AgentVersion string `json:"agent_version" binding:"required"` AgentVersion string `json:"agent_version" binding:"required"`
RegistrationToken string `json:"registration_token"` // Optional, for fallback method RegistrationToken string `json:"registration_token"` // Optional, for fallback method
MachineID string `json:"machine_id"` // Unique machine identifier
PublicKeyFingerprint string `json:"public_key_fingerprint"` // Embedded public key fingerprint
Metadata map[string]string `json:"metadata"` Metadata map[string]string `json:"metadata"`
} }

View File

@@ -0,0 +1,67 @@
package models
import (
"time"
"github.com/google/uuid"
)
// AgentUpdatePackage represents a signed agent binary package
type AgentUpdatePackage struct {
ID uuid.UUID `json:"id" db:"id"`
Version string `json:"version" db:"version"`
Platform string `json:"platform" db:"platform"`
Architecture string `json:"architecture" db:"architecture"`
BinaryPath string `json:"binary_path" db:"binary_path"`
Signature string `json:"signature" db:"signature"`
Checksum string `json:"checksum" db:"checksum"`
FileSize int64 `json:"file_size" db:"file_size"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
CreatedBy string `json:"created_by" db:"created_by"`
IsActive bool `json:"is_active" db:"is_active"`
}
// AgentUpdateRequest represents a request to update an agent
type AgentUpdateRequest struct {
AgentID uuid.UUID `json:"agent_id" binding:"required"`
Version string `json:"version" binding:"required"`
Platform string `json:"platform" binding:"required"`
Scheduled *string `json:"scheduled_at,omitempty"`
}
// BulkAgentUpdateRequest represents a bulk update request
type BulkAgentUpdateRequest struct {
AgentIDs []uuid.UUID `json:"agent_ids" binding:"required"`
Version string `json:"version" binding:"required"`
Platform string `json:"platform" binding:"required"`
Scheduled *string `json:"scheduled_at,omitempty"`
}
// AgentUpdateResponse represents the response for an update request
type AgentUpdateResponse struct {
Message string `json:"message"`
UpdateID string `json:"update_id,omitempty"`
DownloadURL string `json:"download_url,omitempty"`
Signature string `json:"signature,omitempty"`
Checksum string `json:"checksum,omitempty"`
FileSize int64 `json:"file_size,omitempty"`
EstimatedTime int `json:"estimated_time_seconds,omitempty"`
}
// SignatureVerificationRequest represents a request to verify an agent's binary signature
type SignatureVerificationRequest struct {
AgentID uuid.UUID `json:"agent_id" binding:"required"`
BinaryPath string `json:"binary_path" binding:"required"`
MachineID string `json:"machine_id" binding:"required"`
PublicKey string `json:"public_key" binding:"required"`
Signature string `json:"signature" binding:"required"`
}
// SignatureVerificationResponse represents the response for signature verification
type SignatureVerificationResponse struct {
Valid bool `json:"valid"`
AgentID string `json:"agent_id"`
MachineID string `json:"machine_id"`
Fingerprint string `json:"fingerprint"`
Message string `json:"message"`
}

View File

@@ -0,0 +1,239 @@
package services
import (
"crypto/ed25519"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"runtime"
"time"
"github.com/Fimeg/RedFlag/aggregator-server/internal/models"
"github.com/google/uuid"
)
// SigningService handles Ed25519 cryptographic operations
type SigningService struct {
privateKey ed25519.PrivateKey
publicKey ed25519.PublicKey
}
// NewSigningService creates a new signing service with the provided private key
func NewSigningService(privateKeyHex string) (*SigningService, error) {
// Decode private key from hex
privateKeyBytes, err := hex.DecodeString(privateKeyHex)
if err != nil {
return nil, fmt.Errorf("invalid private key format: %w", err)
}
if len(privateKeyBytes) != ed25519.PrivateKeySize {
return nil, fmt.Errorf("invalid private key size: expected %d bytes, got %d", ed25519.PrivateKeySize, len(privateKeyBytes))
}
// Ed25519 private key format: first 32 bytes are seed, next 32 bytes are public key
privateKey := ed25519.PrivateKey(privateKeyBytes)
publicKey := privateKey.Public().(ed25519.PublicKey)
return &SigningService{
privateKey: privateKey,
publicKey: publicKey,
}, nil
}
// SignFile signs a file and returns the signature and checksum
func (s *SigningService) SignFile(filePath string) (*models.AgentUpdatePackage, error) {
// Read the file
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
defer file.Close()
// Calculate checksum and sign content
content, err := io.ReadAll(file)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
}
// Calculate SHA-256 checksum
hash := sha256.Sum256(content)
checksum := hex.EncodeToString(hash[:])
// Sign the content
signature := ed25519.Sign(s.privateKey, content)
// Get file info
fileInfo, err := file.Stat()
if err != nil {
return nil, fmt.Errorf("failed to get file info: %w", err)
}
// Determine platform and architecture from file path or use runtime defaults
platform, architecture := s.detectPlatformArchitecture(filePath)
pkg := &models.AgentUpdatePackage{
BinaryPath: filePath,
Signature: hex.EncodeToString(signature),
Checksum: checksum,
FileSize: fileInfo.Size(),
Platform: platform,
Architecture: architecture,
CreatedBy: "signing-service",
IsActive: true,
}
return pkg, nil
}
// VerifySignature verifies a file signature using the embedded public key
func (s *SigningService) VerifySignature(content []byte, signatureHex string) (bool, error) {
// Decode signature
signature, err := hex.DecodeString(signatureHex)
if err != nil {
return false, fmt.Errorf("invalid signature format: %w", err)
}
if len(signature) != ed25519.SignatureSize {
return false, fmt.Errorf("invalid signature size: expected %d bytes, got %d", ed25519.SignatureSize, len(signature))
}
// Verify signature
valid := ed25519.Verify(s.publicKey, content, signature)
return valid, nil
}
// GetPublicKey returns the public key in hex format
func (s *SigningService) GetPublicKey() string {
return hex.EncodeToString(s.publicKey)
}
// GetPublicKeyFingerprint returns a short fingerprint of the public key
func (s *SigningService) GetPublicKeyFingerprint() string {
// Use first 8 bytes as fingerprint
return hex.EncodeToString(s.publicKey[:8])
}
// VerifyFileIntegrity verifies a file's checksum
func (s *SigningService) VerifyFileIntegrity(filePath, expectedChecksum string) (bool, error) {
file, err := os.Open(filePath)
if err != nil {
return false, fmt.Errorf("failed to open file: %w", err)
}
defer file.Close()
content, err := io.ReadAll(file)
if err != nil {
return false, fmt.Errorf("failed to read file: %w", err)
}
hash := sha256.Sum256(content)
actualChecksum := hex.EncodeToString(hash[:])
return actualChecksum == expectedChecksum, nil
}
// detectPlatformArchitecture attempts to detect platform and architecture from file path
func (s *SigningService) detectPlatformArchitecture(filePath string) (string, string) {
// Default to current runtime
platform := runtime.GOOS
arch := runtime.GOARCH
// Map architectures
archMap := map[string]string{
"amd64": "amd64",
"arm64": "arm64",
"386": "386",
}
// Try to detect from filename patterns
if contains(filePath, "windows") || contains(filePath, ".exe") {
platform = "windows"
} else if contains(filePath, "linux") {
platform = "linux"
} else if contains(filePath, "darwin") || contains(filePath, "macos") {
platform = "darwin"
}
for archName, archValue := range archMap {
if contains(filePath, archName) {
arch = archValue
break
}
}
// Normalize architecture names
if arch == "amd64" {
arch = "amd64"
} else if arch == "arm64" {
arch = "arm64"
}
return platform, arch
}
// contains is a simple helper for case-insensitive substring checking
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr ||
(len(s) > len(substr) &&
(s[:len(substr)] == substr ||
s[len(s)-len(substr):] == substr ||
findSubstring(s, substr))))
}
// findSubstring is a simple substring finder
func findSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// SignNonce signs a nonce (UUID + timestamp) for replay protection
func (s *SigningService) SignNonce(nonceUUID uuid.UUID, timestamp time.Time) (string, error) {
// Create nonce data: UUID + Unix timestamp as string
nonceData := fmt.Sprintf("%s:%d", nonceUUID.String(), timestamp.Unix())
// Sign the nonce data
signature := ed25519.Sign(s.privateKey, []byte(nonceData))
// Return hex-encoded signature
return hex.EncodeToString(signature), nil
}
// VerifyNonce verifies a nonce signature and checks freshness
func (s *SigningService) VerifyNonce(nonceUUID uuid.UUID, timestamp time.Time, signatureHex string, maxAge time.Duration) (bool, error) {
// Check nonce freshness first
if time.Since(timestamp) > maxAge {
return false, fmt.Errorf("nonce is too old: %v > %v", time.Since(timestamp), maxAge)
}
// Recreate nonce data
nonceData := fmt.Sprintf("%s:%d", nonceUUID.String(), timestamp.Unix())
// Verify signature
valid, err := s.VerifySignature([]byte(nonceData), signatureHex)
if err != nil {
return false, fmt.Errorf("failed to verify nonce signature: %w", err)
}
return valid, nil
}
// TODO: Key rotation implementation
// This is a stub for future key rotation functionality
// Key rotation should:
// 1. Maintain multiple active key pairs with version numbers
// 2. Support graceful transition periods (e.g., 30 days)
// 3. Store previous keys for signature verification during transition
// 4. Batch migration of existing agent fingerprints
// 5. Provide monitoring for rotation completion
//
// Example implementation approach:
// - Use database to store multiple key versions with activation timestamps
// - Include key version ID in signatures
// - Maintain lookup table of previous keys for verification
// - Background job to monitor rotation progress

View File

@@ -33,6 +33,12 @@ func IsNewerVersion(version1, version2 string) bool {
return CompareVersions(version1, version2) == 1 return CompareVersions(version1, version2) == 1
} }
// IsNewerOrEqualVersion returns true if version1 is newer than or equal to version2
func IsNewerOrEqualVersion(version1, version2 string) bool {
cmp := CompareVersions(version1, version2)
return cmp == 1 || cmp == 0
}
// parseVersion parses a version string like "0.1.4" into [0, 1, 4] // parseVersion parses a version string like "0.1.4" into [0, 1, 4]
func parseVersion(version string) [3]int { func parseVersion(version string) [3]int {
// Default version if parsing fails // Default version if parsing fails

BIN
aggregator-server/test-server Executable file

Binary file not shown.

View File

@@ -1,18 +1,9 @@
import React, { useState } from 'react'; import React from 'react';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { import {
MonitorPlay,
RefreshCw, RefreshCw,
Settings,
Activity, Activity,
Clock,
CheckCircle,
XCircle,
Play, Play,
Square,
Database,
Shield,
Search,
HardDrive, HardDrive,
Cpu, Cpu,
Container, Container,
@@ -78,7 +69,7 @@ export function AgentScanners({ agentId }: AgentScannersProps) {
return await agentApi.disableSubsystem(agentId, subsystem); return await agentApi.disableSubsystem(agentId, subsystem);
} }
}, },
onSuccess: (data, variables) => { onSuccess: (_, variables) => {
toast.success(`${subsystemConfig[variables.subsystem]?.name || variables.subsystem} ${variables.enabled ? 'enabled' : 'disabled'}`); toast.success(`${subsystemConfig[variables.subsystem]?.name || variables.subsystem} ${variables.enabled ? 'enabled' : 'disabled'}`);
queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] }); queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] });
}, },
@@ -92,7 +83,7 @@ export function AgentScanners({ agentId }: AgentScannersProps) {
mutationFn: async ({ subsystem, intervalMinutes }: { subsystem: string; intervalMinutes: number }) => { mutationFn: async ({ subsystem, intervalMinutes }: { subsystem: string; intervalMinutes: number }) => {
return await agentApi.setSubsystemInterval(agentId, subsystem, intervalMinutes); return await agentApi.setSubsystemInterval(agentId, subsystem, intervalMinutes);
}, },
onSuccess: (data, variables) => { onSuccess: (_, variables) => {
toast.success(`Interval updated to ${variables.intervalMinutes} minutes`); toast.success(`Interval updated to ${variables.intervalMinutes} minutes`);
queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] }); queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] });
}, },
@@ -106,7 +97,7 @@ export function AgentScanners({ agentId }: AgentScannersProps) {
mutationFn: async ({ subsystem, autoRun }: { subsystem: string; autoRun: boolean }) => { mutationFn: async ({ subsystem, autoRun }: { subsystem: string; autoRun: boolean }) => {
return await agentApi.setSubsystemAutoRun(agentId, subsystem, autoRun); return await agentApi.setSubsystemAutoRun(agentId, subsystem, autoRun);
}, },
onSuccess: (data, variables) => { onSuccess: (_, variables) => {
toast.success(`Auto-run ${variables.autoRun ? 'enabled' : 'disabled'}`); toast.success(`Auto-run ${variables.autoRun ? 'enabled' : 'disabled'}`);
queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] }); queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] });
}, },
@@ -120,7 +111,7 @@ export function AgentScanners({ agentId }: AgentScannersProps) {
mutationFn: async (subsystem: string) => { mutationFn: async (subsystem: string) => {
return await agentApi.triggerSubsystem(agentId, subsystem); return await agentApi.triggerSubsystem(agentId, subsystem);
}, },
onSuccess: (data, subsystem) => { onSuccess: (_, subsystem) => {
toast.success(`${subsystemConfig[subsystem]?.name || subsystem} scan triggered`); toast.success(`${subsystemConfig[subsystem]?.name || subsystem} scan triggered`);
queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] }); queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] });
}, },
@@ -155,11 +146,6 @@ export function AgentScanners({ agentId }: AgentScannersProps) {
{ value: 1440, label: '24 hours' }, { value: 1440, label: '24 hours' },
]; ];
const getFrequencyLabel = (frequency: number) => {
if (frequency < 60) return `${frequency}m`;
if (frequency < 1440) return `${frequency / 60}h`;
return `${frequency / 1440}d`;
};
const enabledCount = subsystems.filter(s => s.enabled).length; const enabledCount = subsystems.filter(s => s.enabled).length;
const autoRunCount = subsystems.filter(s => s.auto_run && s.enabled).length; const autoRunCount = subsystems.filter(s => s.auto_run && s.enabled).length;

View File

@@ -1,17 +1,8 @@
import React, { useState } from 'react'; import { useState } from 'react';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import { import {
HardDrive, HardDrive,
RefreshCw, RefreshCw,
Database,
Search,
Activity,
Monitor,
AlertTriangle,
CheckCircle,
Info,
TrendingUp,
Server,
MemoryStick, MemoryStick,
} from 'lucide-react'; } from 'lucide-react';
import { formatBytes, formatRelativeTime } from '@/lib/utils'; import { formatBytes, formatRelativeTime } from '@/lib/utils';
@@ -116,14 +107,6 @@ export function AgentStorage({ agentId }: AgentStorageProps) {
})); }));
}; };
const getDiskTypeIcon = (diskType: string) => {
switch (diskType.toLowerCase()) {
case 'nvme': return <Database className="h-4 w-4 text-purple-500" />;
case 'ssd': return <Server className="h-4 w-4 text-blue-500" />;
case 'hdd': return <HardDrive className="h-4 w-4 text-gray-500" />;
default: return <Monitor className="h-4 w-4 text-gray-400" />;
}
};
if (!agentData) { if (!agentData) {
return ( return (

View File

@@ -4,6 +4,7 @@ import {
Search, Search,
Package, Package,
Download, Download,
Upload,
CheckCircle, CheckCircle,
RefreshCw, RefreshCw,
Terminal, Terminal,
@@ -18,6 +19,7 @@ import { updateApi, agentApi } from '@/lib/api';
import toast from 'react-hot-toast'; import toast from 'react-hot-toast';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import type { UpdatePackage } from '@/types'; import type { UpdatePackage } from '@/types';
import { AgentUpdatesModal } from './AgentUpdatesModal';
interface AgentUpdatesEnhancedProps { interface AgentUpdatesEnhancedProps {
agentId: string; agentId: string;
@@ -52,7 +54,7 @@ export function AgentUpdatesEnhanced({ agentId }: AgentUpdatesEnhancedProps) {
const [selectedSeverity, setSelectedSeverity] = useState('all'); const [selectedSeverity, setSelectedSeverity] = useState('all');
const [showLogsModal, setShowLogsModal] = useState(false); const [showLogsModal, setShowLogsModal] = useState(false);
const [logsData, setLogsData] = useState<LogResponse | null>(null); const [logsData, setLogsData] = useState<LogResponse | null>(null);
const [isLoadingLogs, setIsLoadingLogs] = useState(false); const [showUpdateModal, setShowUpdateModal] = useState(false);
const [expandedUpdates, setExpandedUpdates] = useState<Set<string>>(new Set()); const [expandedUpdates, setExpandedUpdates] = useState<Set<string>>(new Set());
const [selectedUpdates, setSelectedUpdates] = useState<string[]>([]); const [selectedUpdates, setSelectedUpdates] = useState<string[]>([]);
@@ -300,6 +302,15 @@ export function AgentUpdatesEnhanced({ agentId }: AgentUpdatesEnhancedProps) {
)} )}
</button> </button>
)} )}
{/* Update Agent Button */}
<button
onClick={() => setShowUpdateModal(true)}
className="text-sm text-primary-600 hover:text-primary-800 flex items-center space-x-1 border border-primary-300 px-2 py-1 rounded"
>
<Upload className="h-4 w-4" />
<span>Update Agent</span>
</button>
</div> </div>
{/* Search and Filters */} {/* Search and Filters */}
@@ -531,6 +542,17 @@ export function AgentUpdatesEnhanced({ agentId }: AgentUpdatesEnhancedProps) {
</div> </div>
</div> </div>
)} )}
{/* Agent Update Modal */}
<AgentUpdatesModal
isOpen={showUpdateModal}
onClose={() => setShowUpdateModal(false)}
selectedAgentIds={[agentId]}
onAgentsUpdated={() => {
setShowUpdateModal(false);
queryClient.invalidateQueries({ queryKey: ['agents'] });
}}
/>
</div> </div>
); );
} }

View File

@@ -0,0 +1,290 @@
import { useState } from 'react';
import {
X,
Download,
CheckCircle,
AlertCircle,
RefreshCw,
Info,
Users,
Package,
Hash,
} from 'lucide-react';
import { useMutation, useQuery } from '@tanstack/react-query';
import { agentApi, updateApi } from '@/lib/api';
import toast from 'react-hot-toast';
import { cn } from '@/lib/utils';
import { Agent } from '@/types';
interface AgentUpdatesModalProps {
isOpen: boolean;
onClose: () => void;
selectedAgentIds: string[];
onAgentsUpdated: () => void;
}
export function AgentUpdatesModal({
isOpen,
onClose,
selectedAgentIds,
onAgentsUpdated,
}: AgentUpdatesModalProps) {
const [selectedVersion, setSelectedVersion] = useState('');
const [selectedPlatform, setSelectedPlatform] = useState('');
const [isUpdating, setIsUpdating] = useState(false);
// Fetch selected agents details
const { data: agents = [] } = useQuery<Agent[]>({
queryKey: ['agents-details', selectedAgentIds],
queryFn: async (): Promise<Agent[]> => {
const promises = selectedAgentIds.map(id => agentApi.getAgent(id));
const results = await Promise.all(promises);
return results;
},
enabled: isOpen && selectedAgentIds.length > 0,
});
// Fetch available update packages
const { data: packagesResponse, isLoading: packagesLoading } = useQuery({
queryKey: ['update-packages'],
queryFn: () => updateApi.getPackages(),
enabled: isOpen,
});
const packages = packagesResponse?.packages || [];
// Group packages by version
const versions = [...new Set(packages.map(pkg => pkg.version))].sort((a, b) => b.localeCompare(a));
const platforms = [...new Set(packages.map(pkg => pkg.platform))].sort();
// Filter packages based on selection
const availablePackages = packages.filter(
pkg => (!selectedVersion || pkg.version === selectedVersion) &&
(!selectedPlatform || pkg.platform === selectedPlatform)
);
// Get unique platform for selected agents (simplified - assumes all agents same platform)
const agentPlatform = agents[0]?.os_type || 'linux';
const agentArchitecture = agents[0]?.os_architecture || 'amd64';
// Update agents mutation
const updateAgentsMutation = useMutation({
mutationFn: async (packageId: string) => {
const pkg = packages.find(p => p.id === packageId);
if (!pkg) throw new Error('Package not found');
const updateData = {
agent_ids: selectedAgentIds,
version: pkg.version,
platform: pkg.platform,
};
return agentApi.updateMultipleAgents(updateData);
},
onSuccess: (data) => {
toast.success(`Update initiated for ${data.updated?.length || 0} agent(s)`);
setIsUpdating(false);
onAgentsUpdated();
onClose();
},
onError: (error: any) => {
toast.error(`Failed to update agents: ${error.message}`);
setIsUpdating(false);
},
});
const handleUpdateAgents = async (packageId: string) => {
setIsUpdating(true);
updateAgentsMutation.mutate(packageId);
};
const canUpdate = selectedAgentIds.length > 0 && availablePackages.length > 0 && !isUpdating;
const hasUpdatingAgents = agents.some(agent => agent.is_updating);
if (!isOpen) return null;
return (
<div className="fixed inset-0 z-50 overflow-y-auto">
<div className="flex min-h-full items-center justify-center p-4 text-center">
<div className="fixed inset-0 bg-gray-500 bg-opacity-75 transition-opacity" onClick={onClose} />
<div className="relative w-full max-w-4xl transform overflow-hidden rounded-lg bg-white p-6 text-left shadow-xl transition-all">
{/* Header */}
<div className="flex items-center justify-between pb-4 border-b">
<div className="flex items-center space-x-3">
<Download className="h-6 w-6 text-primary-600" />
<div>
<h3 className="text-lg font-semibold text-gray-900">Agent Updates</h3>
<p className="text-sm text-gray-500">
Update {selectedAgentIds.length} agent{selectedAgentIds.length !== 1 ? 's' : ''}
</p>
</div>
</div>
<button
onClick={onClose}
className="rounded-md p-2 text-gray-400 hover:text-gray-500"
>
<X className="h-5 w-5" />
</button>
</div>
{/* Content */}
<div className="py-4">
{/* Selected Agents */}
<div className="mb-6">
<h4 className="text-sm font-medium text-gray-900 mb-3 flex items-center">
<Users className="h-4 w-4 mr-2" />
Selected Agents
</h4>
<div className="max-h-32 overflow-y-auto space-y-2">
{agents.map((agent) => (
<div key={agent.id} className="flex items-center justify-between p-2 bg-gray-50 rounded">
<div className="flex items-center space-x-3">
<CheckCircle className={cn(
"h-4 w-4",
agent.status === 'online' ? "text-green-500" : "text-gray-400"
)} />
<div>
<div className="text-sm font-medium text-gray-900">{agent.hostname}</div>
<div className="text-xs text-gray-500">
{agent.os_type}/{agent.os_architecture} Current: {agent.current_version || 'Unknown'}
</div>
</div>
</div>
{agent.is_updating && (
<div className="flex items-center text-amber-600 text-xs">
<RefreshCw className="h-3 w-3 mr-1 animate-spin" />
Updating to {agent.updating_to_version}
</div>
)}
</div>
))}
</div>
{hasUpdatingAgents && (
<div className="mt-2 text-xs text-amber-600 flex items-center">
<AlertCircle className="h-3 w-3 mr-1" />
Some agents are currently updating
</div>
)}
</div>
{/* Package Selection */}
<div className="mb-6">
<h4 className="text-sm font-medium text-gray-900 mb-3 flex items-center">
<Package className="h-4 w-4 mr-2" />
Update Package Selection
</h4>
{/* Filters */}
<div className="grid grid-cols-2 gap-4 mb-4">
<div>
<label className="block text-xs font-medium text-gray-700 mb-1">Version</label>
<select
value={selectedVersion}
onChange={(e) => setSelectedVersion(e.target.value)}
className="w-full rounded-md border-gray-300 shadow-sm text-sm"
>
<option value="">All Versions</option>
{versions.map(version => (
<option key={version} value={version}>{version}</option>
))}
</select>
</div>
<div>
<label className="block text-xs font-medium text-gray-700 mb-1">Platform</label>
<select
value={selectedPlatform}
onChange={(e) => setSelectedPlatform(e.target.value)}
className="w-full rounded-md border-gray-300 shadow-sm text-sm"
>
<option value="">All Platforms</option>
{platforms.map(platform => (
<option key={platform} value={platform}>{platform}</option>
))}
</select>
</div>
</div>
{/* Available Packages */}
<div className="space-y-2 max-h-48 overflow-y-auto">
{packagesLoading ? (
<div className="text-center py-4 text-sm text-gray-500">
Loading packages...
</div>
) : availablePackages.length === 0 ? (
<div className="text-center py-4 text-sm text-gray-500">
No packages available for the selected criteria
</div>
) : (
availablePackages.map((pkg) => (
<div
key={pkg.id}
className={cn(
"flex items-center justify-between p-3 border rounded-lg cursor-pointer transition-colors",
"hover:bg-gray-50 border-gray-200"
)}
onClick={() => handleUpdateAgents(pkg.id)}
>
<div className="flex items-center space-x-3">
<Package className="h-4 w-4 text-primary-600" />
<div>
<div className="text-sm font-medium text-gray-900">
Version {pkg.version}
</div>
<div className="text-xs text-gray-500">
{pkg.platform} {(pkg.file_size / 1024 / 1024).toFixed(1)} MB
</div>
</div>
</div>
<div className="flex items-center space-x-2">
<div className="text-xs text-gray-400">
<Hash className="h-3 w-3 inline mr-1" />
{pkg.checksum.slice(0, 8)}...
</div>
<button
disabled={!canUpdate}
className={cn(
"px-3 py-1 text-xs rounded-md font-medium transition-colors",
canUpdate
? "bg-primary-600 text-white hover:bg-primary-700"
: "bg-gray-100 text-gray-400 cursor-not-allowed"
)}
>
{isUpdating ? (
<RefreshCw className="h-3 w-3 animate-spin" />
) : (
'Update'
)}
</button>
</div>
</div>
))
)}
</div>
</div>
{/* Platform Compatibility Info */}
<div className="text-xs text-gray-500 flex items-start">
<Info className="h-3 w-3 mr-1 mt-0.5 flex-shrink-0" />
<span>
Detected platform: <strong>{agentPlatform}/{agentArchitecture}</strong>.
Only compatible packages will be shown.
</span>
</div>
</div>
{/* Footer */}
<div className="flex justify-end space-x-3 pt-4 border-t">
<button
onClick={onClose}
disabled={isUpdating}
className="px-4 py-2 text-sm font-medium text-gray-700 bg-white border border-gray-300 rounded-md hover:bg-gray-50 disabled:opacity-50"
>
Cancel
</button>
</div>
</div>
</div>
</div>
);
}

View File

@@ -1,4 +1,4 @@
import React, { useState, useEffect } from 'react'; import React, { useState } from 'react';
import { import {
CheckCircle, CheckCircle,
XCircle, XCircle,
@@ -7,14 +7,12 @@ import {
Search, Search,
Terminal, Terminal,
RefreshCw, RefreshCw,
Filter,
ChevronDown, ChevronDown,
ChevronRight, ChevronRight,
User, User,
Clock, Clock,
Activity, Activity,
Copy, Copy,
Hash,
HardDrive, HardDrive,
Cpu, Cpu,
Container, Container,
@@ -25,7 +23,6 @@ import { useRetryCommand } from '@/hooks/useCommands';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import toast from 'react-hot-toast'; import toast from 'react-hot-toast';
import { Highlight, themes } from 'prism-react-renderer'; import { Highlight, themes } from 'prism-react-renderer';
import { useEffect as useEffectHook } from 'react';
interface HistoryEntry { interface HistoryEntry {
id: string; id: string;
@@ -41,6 +38,8 @@ interface HistoryEntry {
exit_code?: number; exit_code?: number;
duration_seconds?: number; duration_seconds?: number;
created_at: string; created_at: string;
metadata?: Record<string, string>;
params?: Record<string, any>;
hostname?: string; hostname?: string;
} }
@@ -76,9 +75,9 @@ const createPackageOperationSummary = (entry: HistoryEntry): string => {
// Extract duration if available // Extract duration if available
let durationInfo = ''; let durationInfo = '';
if (entry.logged_at) { if (entry.created_at) {
try { try {
const loggedTime = new Date(entry.logged_at).toLocaleTimeString('en-US', { const loggedTime = new Date(entry.created_at).toLocaleTimeString('en-US', {
hour: '2-digit', hour: '2-digit',
minute: '2-digit' minute: '2-digit'
}); });
@@ -444,9 +443,27 @@ const ChatTimeline: React.FC<ChatTimelineProps> = ({ agentId, className, isScope
} }
} }
// Fallback subject // Fallback subject - provide better action labels
if (!subject) { if (!subject) {
subject = entry.package_name || 'system operation'; // Map action to more readable labels
const actionLabels: Record<string, string> = {
'scan updates': 'Package Updates',
'scan storage': 'Disk Usage',
'scan system': 'System Metrics',
'scan docker': 'Docker Images',
'update agent': 'Agent Update',
'dry run update': 'Update Dry Run',
'confirm dependencies': 'Dependency Check',
'install update': 'Update Installation',
'collect specs': 'System Specifications',
'enable heartbeat': 'Heartbeat Enable',
'disable heartbeat': 'Heartbeat Disable',
'reboot': 'System Reboot',
'process command': 'Command Processing'
};
// Prioritize metadata subsystem label for better descriptions
subject = entry.metadata?.subsystem_label || entry.package_name || actionLabels[action] || action;
} }
// Build narrative sentence - system thought style // Build narrative sentence - system thought style
@@ -495,6 +512,16 @@ const ChatTimeline: React.FC<ChatTimelineProps> = ({ agentId, className, isScope
} else { } else {
sentence = `Docker Image Scanner results`; sentence = `Docker Image Scanner results`;
} }
} else if (action === 'update agent') {
if (isInProgress) {
sentence = `Agent Update initiated to version ${subject}`;
} else if (statusType === 'success') {
sentence = `Agent updated to version ${subject}`;
} else if (statusType === 'failed') {
sentence = `Agent update failed for version ${subject}`;
} else {
sentence = `Agent update to version ${subject}`;
}
} else if (action === 'dry run update') { } else if (action === 'dry run update') {
if (isInProgress) { if (isInProgress) {
sentence = `Dry run initiated for ${subject}`; sentence = `Dry run initiated for ${subject}`;

View File

@@ -2,6 +2,7 @@ import axios, { AxiosResponse } from 'axios';
import { import {
Agent, Agent,
UpdatePackage, UpdatePackage,
AgentUpdatePackage,
DashboardStats, DashboardStats,
AgentListResponse, AgentListResponse,
UpdateListResponse, UpdateListResponse,
@@ -160,6 +161,27 @@ export const agentApi = {
const response = await api.post(`/agents/${agentId}/subsystems/${subsystem}/interval`, { interval_minutes: intervalMinutes }); const response = await api.post(`/agents/${agentId}/subsystems/${subsystem}/interval`, { interval_minutes: intervalMinutes });
return response.data; return response.data;
}, },
// Update single agent
updateAgent: async (agentId: string, updateData: {
version: string;
platform: string;
scheduled?: string;
}): Promise<{ message: string; update_id: string; download_url: string; signature: string; checksum: string; file_size: number; estimated_time: number }> => {
const response = await api.post(`/agents/${agentId}/update`, updateData);
return response.data;
},
// Update multiple agents (bulk)
updateMultipleAgents: async (updateData: {
agent_ids: string[];
version: string;
platform: string;
scheduled?: string;
}): Promise<{ message: string; updated: Array<{ agent_id: string; hostname: string; update_id: string; status: string }>; failed: string[]; total_agents: number; package_info: any }> => {
const response = await api.post('/agents/bulk-update', updateData);
return response.data;
},
}; };
export const updateApi = { export const updateApi = {
@@ -185,6 +207,11 @@ export const updateApi = {
await api.post(`/updates/${id}/approve`, { scheduled_at: scheduledAt }); await api.post(`/updates/${id}/approve`, { scheduled_at: scheduledAt });
}, },
// Approve multiple updates
approveMultiple: async (updateIds: string[]): Promise<void> => {
await api.post('/updates/approve', { update_ids: updateIds });
},
// Reject/cancel update // Reject/cancel update
rejectUpdate: async (id: string): Promise<void> => { rejectUpdate: async (id: string): Promise<void> => {
await api.post(`/updates/${id}/reject`); await api.post(`/updates/${id}/reject`);
@@ -250,6 +277,28 @@ export const updateApi = {
const response = await api.delete(`/commands/failed${params.toString() ? '?' + params.toString() : ''}`); const response = await api.delete(`/commands/failed${params.toString() ? '?' + params.toString() : ''}`);
return response.data; return response.data;
}, },
// Get available update packages
getPackages: async (params?: {
version?: string;
platform?: string;
limit?: number;
offset?: number;
}): Promise<{ packages: AgentUpdatePackage[]; total: number; limit: number; offset: number }> => {
const response = await api.get('/updates/packages', { params });
return response.data;
},
// Sign new update package
signPackage: async (packageData: {
version: string;
platform: string;
architecture: string;
binary_path: string;
}): Promise<{ message: string; package: UpdatePackage }> => {
const response = await api.post('/updates/packages/sign', packageData);
return response.data;
},
}; };
export const statsApi = { export const statsApi = {

View File

@@ -25,6 +25,7 @@ import {
Database, Database,
Settings, Settings,
MonitorPlay, MonitorPlay,
Upload,
} from 'lucide-react'; } from 'lucide-react';
import { useAgents, useAgent, useScanAgent, useScanMultipleAgents, useUnregisterAgent } from '@/hooks/useAgents'; import { useAgents, useAgent, useScanAgent, useScanMultipleAgents, useUnregisterAgent } from '@/hooks/useAgents';
import { useActiveCommands, useCancelCommand } from '@/hooks/useCommands'; import { useActiveCommands, useCancelCommand } from '@/hooks/useCommands';
@@ -38,6 +39,7 @@ import { AgentSystemUpdates } from '@/components/AgentUpdates';
import { AgentStorage } from '@/components/AgentStorage'; import { AgentStorage } from '@/components/AgentStorage';
import { AgentUpdatesEnhanced } from '@/components/AgentUpdatesEnhanced'; import { AgentUpdatesEnhanced } from '@/components/AgentUpdatesEnhanced';
import { AgentScanners } from '@/components/AgentScanners'; import { AgentScanners } from '@/components/AgentScanners';
import { AgentUpdatesModal } from '@/components/AgentUpdatesModal';
import ChatTimeline from '@/components/ChatTimeline'; import ChatTimeline from '@/components/ChatTimeline';
const Agents: React.FC = () => { const Agents: React.FC = () => {
@@ -56,6 +58,7 @@ const Agents: React.FC = () => {
const [showDurationDropdown, setShowDurationDropdown] = useState(false); const [showDurationDropdown, setShowDurationDropdown] = useState(false);
const [heartbeatLoading, setHeartbeatLoading] = useState(false); // Loading state for heartbeat toggle const [heartbeatLoading, setHeartbeatLoading] = useState(false); // Loading state for heartbeat toggle
const [heartbeatCommandId, setHeartbeatCommandId] = useState<string | null>(null); // Track specific heartbeat command const [heartbeatCommandId, setHeartbeatCommandId] = useState<string | null>(null); // Track specific heartbeat command
const [showUpdateModal, setShowUpdateModal] = useState(false); // Update modal state
const dropdownRef = useRef<HTMLDivElement>(null); const dropdownRef = useRef<HTMLDivElement>(null);
// Close dropdown when clicking outside // Close dropdown when clicking outside
@@ -1142,6 +1145,7 @@ const Agents: React.FC = () => {
{/* Bulk actions */} {/* Bulk actions */}
{selectedAgents.length > 0 && ( {selectedAgents.length > 0 && (
<>
<button <button
onClick={handleScanSelected} onClick={handleScanSelected}
disabled={scanMultipleMutation.isPending} disabled={scanMultipleMutation.isPending}
@@ -1154,6 +1158,14 @@ const Agents: React.FC = () => {
)} )}
Scan Selected ({selectedAgents.length}) Scan Selected ({selectedAgents.length})
</button> </button>
<button
onClick={() => setShowUpdateModal(true)}
className="btn btn-secondary"
>
<Upload className="h-4 w-4 mr-2" />
Update Selected ({selectedAgents.length})
</button>
</>
)} )}
</div> </div>
@@ -1358,6 +1370,20 @@ const Agents: React.FC = () => {
> >
<RefreshCw className="h-4 w-4" /> <RefreshCw className="h-4 w-4" />
</button> </button>
<button
onClick={() => {
setSelectedAgents([agent.id]);
setShowUpdateModal(true);
}}
disabled={agent.is_updating}
className={cn(
"text-gray-400 hover:text-primary-600",
agent.is_updating && "text-amber-600 animate-pulse"
)}
title={agent.is_updating ? "Agent is updating..." : "Update agent"}
>
<Upload className="h-4 w-4" />
</button>
<button <button
onClick={() => handleRemoveAgent(agent.id, agent.hostname)} onClick={() => handleRemoveAgent(agent.id, agent.hostname)}
disabled={unregisterAgentMutation.isPending} disabled={unregisterAgentMutation.isPending}
@@ -1382,6 +1408,18 @@ const Agents: React.FC = () => {
</div> </div>
</div> </div>
)} )}
{/* Agent Updates Modal */}
<AgentUpdatesModal
isOpen={showUpdateModal}
onClose={() => setShowUpdateModal(false)}
selectedAgentIds={selectedAgents}
onAgentsUpdated={() => {
// Refresh agents data after update
queryClient.invalidateQueries({ queryKey: ['agents'] });
setSelectedAgents([]);
}}
/>
</div> </div>
); );
}; };

View File

@@ -26,6 +26,9 @@ export interface Agent {
last_reboot_at?: string | null; last_reboot_at?: string | null;
reboot_reason?: string; reboot_reason?: string;
metadata?: Record<string, any>; metadata?: Record<string, any>;
system_info?: Record<string, any>;
is_updating?: boolean;
updating_to_version?: string;
// Note: ip_address not available from API yet // Note: ip_address not available from API yet
} }
@@ -57,9 +60,21 @@ export interface UpdatePackage {
approved_at: string | null; approved_at: string | null;
scheduled_at: string | null; scheduled_at: string | null;
installed_at: string | null; installed_at: string | null;
created_at: string;
recent_command_id?: string;
metadata: Record<string, any>; metadata: Record<string, any>;
} }
// Agent Update Package (for agent binary updates)
export interface AgentUpdatePackage {
id: string;
version: string;
platform: string;
file_size: number;
checksum: string;
created_at: string;
}
// Update specific types // Update specific types
export interface DockerUpdateInfo { export interface DockerUpdateInfo {
local_digest: string; local_digest: string;

70
cmd/tools/keygen/main.go Normal file
View File

@@ -0,0 +1,70 @@
package main
import (
"crypto/ed25519"
"encoding/base64"
"encoding/hex"
"flag"
"fmt"
"os"
)
func main() {
publicB64 := flag.Bool("public-b64", false, "Extract and output public key in base64")
publicHex := flag.Bool("public-hex", false, "Extract and output public key in hex")
help := flag.Bool("help", false, "Show help message")
flag.Parse()
if *help {
fmt.Println("RedFlag Ed25519 Key Tool")
fmt.Println("Usage:")
fmt.Println(" go run ./cmd/tools/keygen -public-b64 Extract public key in base64")
fmt.Println(" go run ./cmd/tools/keygen -public-hex Extract public key in hex")
fmt.Println("")
fmt.Println("Requires REDFLAG_SIGNING_PRIVATE_KEY environment variable (64-byte hex)")
os.Exit(0)
}
// Read private key from environment
privateKeyHex := os.Getenv("REDFLAG_SIGNING_PRIVATE_KEY")
if privateKeyHex == "" {
fmt.Fprintln(os.Stderr, "Error: REDFLAG_SIGNING_PRIVATE_KEY environment variable not set")
os.Exit(1)
}
// Decode hex private key
privateKeyBytes, err := hex.DecodeString(privateKeyHex)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: Invalid private key hex format: %v\n", err)
os.Exit(1)
}
if len(privateKeyBytes) != ed25519.PrivateKeySize {
fmt.Fprintf(os.Stderr, "Error: Invalid private key size: expected %d bytes, got %d\n",
ed25519.PrivateKeySize, len(privateKeyBytes))
os.Exit(1)
}
// Extract public key from private key
privateKey := ed25519.PrivateKey(privateKeyBytes)
publicKey := privateKey.Public().(ed25519.PublicKey)
// Output in requested format
if *publicB64 {
fmt.Println(base64.StdEncoding.EncodeToString(publicKey))
} else if *publicHex {
fmt.Println(hex.EncodeToString(publicKey))
} else {
// Default: show both formats
fmt.Println("Public Key (hex):")
fmt.Println(hex.EncodeToString(publicKey))
fmt.Println("")
fmt.Println("Public Key (base64):")
fmt.Println(base64.StdEncoding.EncodeToString(publicKey))
fmt.Println("")
fmt.Println("For embedding in agent binary, use:")
fmt.Printf("go build -ldflags \"-X main.ServerPublicKeyHex=%s\" -o redflag-agent cmd/agent/main.go\n",
hex.EncodeToString(publicKey))
}
}

24
discord/.env.example Normal file
View File

@@ -0,0 +1,24 @@
# Discord Configuration Template
# Copy this file to .env and fill in your actual values
# Discord Bot Configuration
DISCORD_BOT_TOKEN=your_bot_token_here
DISCORD_SERVER_ID=your_server_id_here
DISCORD_APPLICATION_ID=your_app_id_here
DISCORD_PUBLIC_KEY=your_public_key_here
# Server Management
SERVER_NAME=RedFlag Security
ADMIN_ROLE_ID=your_admin_role_id_here
# Channel IDs (to be filled after creation)
GENERAL_CHANNEL_ID=
ANNOUNCEMENTS_CHANNEL_ID=
SECURITY_ALERTS_CHANNEL_ID=
DEV_CHAT_CHANNEL_ID=
BUG_REPORTS_CHANNEL_ID=
# Category IDs (to be filled after creation)
COMMUNITY_CATEGORY_ID=
DEVELOPMENT_CATEGORY_ID=
SECURITY_CATEGORY_ID=

31
discord/.gitignore vendored Normal file
View File

@@ -0,0 +1,31 @@
# Environment files
.env
.env.local
.env.*.local
# Python
__pycache__/
*.pyc
*.pyo
*.pyd
.env
venv/
.venv/
env/
venv.bak/
venv/
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# Logs
*.log
logs/
# OS
.DS_Store
Thumbs.db

426
discord/discord_manager.py Executable file
View File

@@ -0,0 +1,426 @@
#!/usr/bin/env python3
"""
RedFlag Discord Management Bot
Interactive Discord server management with secure configuration
"""
import discord
import asyncio
import logging
import sys
from typing import Optional, Dict, List
from discord.ext import commands
from discord import app_commands
from env_manager import discord_env
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class DiscordManager:
"""Interactive Discord server manager"""
def __init__(self):
self.bot_token = discord_env.get_required('DISCORD_BOT_TOKEN')
self.server_id = int(discord_env.get_required('DISCORD_SERVER_ID'))
self.application_id = discord_env.get_required('DISCORD_APPLICATION_ID')
self.public_key = discord_env.get_required('DISCORD_PUBLIC_KEY')
# Bot setup with required intents
intents = discord.Intents.default()
intents.guilds = True
intents.message_content = True
# Initialize bot
self.bot = commands.Bot(
command_prefix='!',
intents=intents,
help_command=None # We'll create custom help
)
self.setup_events()
self.setup_commands()
def setup_events(self):
"""Setup bot event handlers"""
@self.bot.event
async def on_ready():
logger.info(f'✅ Bot logged in as {self.bot.user}')
logger.info(f'Serving server: {self.bot.user.name} (ID: {self.bot.user.id})')
# Sync commands
await self.bot.tree.sync()
logger.info('✅ Commands synced')
# Get server info
guild = self.bot.get_guild(self.server_id)
if guild:
logger.info(f'✅ Connected to server: {guild.name}')
await self.print_server_status(guild)
else:
logger.error('❌ Could not find server!')
@self.bot.event
async def on_command_error(ctx, error):
"""Handle command errors"""
logger.error(f'Command error in {ctx.command}: {error}')
if isinstance(error, commands.MissingRequiredArgument):
await ctx.send(f'❌ Missing required argument: `{error.param}`')
elif isinstance(error, commands.BadArgument):
await ctx.send(f'❌ Invalid argument: {error}')
else:
await ctx.send(f'❌ An error occurred: {error}')
def setup_commands(self):
"""Setup slash commands"""
# Server management commands
@self.bot.tree.command(name="status", description="Show current Discord server status")
async def cmd_status(interaction: discord.Interaction):
await self.cmd_status(interaction)
@self.bot.tree.command(name="create-channels", description="Create standard RedFlag channels")
async def cmd_create_channels(interaction: discord.Interaction):
await self.cmd_create_channels(interaction)
@self.bot.tree.command(name="send-message", description="Send a message to a channel")
@app_commands.describe(channel="Channel to send message to", message="Message to send")
async def cmd_send_message(interaction: discord.Interaction, channel: str, message: str):
await self.cmd_send_message(interaction, channel, message)
@self.bot.tree.command(name="list-channels", description="List all channels and categories")
async def cmd_list_channels(interaction: discord.Interaction):
await self.cmd_list_channels(interaction)
@self.bot.tree.command(name="create-category", description="Create a channel category")
@app_commands.describe(name="Category name")
async def cmd_create_category(interaction: discord.Interaction, name: str):
await self.cmd_create_category(interaction, name)
@self.bot.tree.command(name="create-test-channel", description="Create one simple test channel")
async def cmd_create_test_channel(interaction: discord.Interaction):
await self.cmd_create_test_channel(interaction)
@self.bot.tree.command(name="help", description="Show available commands")
async def cmd_help(interaction: discord.Interaction):
await self.cmd_help(interaction)
async def cmd_status(self, interaction: discord.Interaction):
"""Show server status"""
guild = self.bot.get_guild(self.server_id)
if not guild:
await interaction.response.send_message("❌ Could not find server!", ephemeral=True)
return
embed = discord.Embed(
title="📊 RedFlag Discord Server Status",
color=discord.Color.blue(),
description=f"Server: **{guild.name}**"
)
embed.add_field(name="👥 Members", value=str(guild.member_count), inline=True)
embed.add_field(name="💬 Channels", value=str(len(guild.channels)), inline=True)
embed.add_field(name="🎭 Roles", value=str(len(guild.roles)), inline=True)
embed.add_field(name="📅 Created", value=guild.created_at.strftime("%Y-%m-%d"), inline=True)
embed.add_field(name="👑 Owner", value=f"<@{guild.owner_id}>", inline=True)
embed.add_field(name="🚀 Boost Level", value=str(guild.premium_tier), inline=True)
await interaction.response.send_message(embed=embed)
async def cmd_create_channels(self, interaction: discord.Interaction):
"""Create standard RedFlag channels"""
guild = self.bot.get_guild(self.server_id)
if not guild:
await interaction.response.send_message("❌ Could not find server!", ephemeral=True)
return
await interaction.response.defer(ephemeral=True)
results = []
# Create categories first
try:
# Community category
community_cat = await guild.create_category_channel("🌍 Community")
discord_env.update_category_ids("community", str(community_cat.id))
results.append("✅ Community category")
# Development category
dev_cat = await guild.create_category_channel("💻 Development")
discord_env.update_category_ids("development", str(dev_cat.id))
results.append("✅ Development category")
# Security category
security_cat = await guild.create_category_channel("🔒 Security")
discord_env.update_category_ids("security", str(security_cat.id))
results.append("✅ Security category")
except Exception as e:
logger.error(f"Error creating categories: {e}")
results.append(f"❌ Categories: {e}")
# Create channels (with small delays to avoid rate limits)
await asyncio.sleep(1)
# Community channels
try:
general = await guild.create_text_channel(
"general",
category=discord.Object(id=int(discord_env.get('COMMUNITY_CATEGORY_ID', 0))),
reason="Community general discussion"
)
discord_env.update_channel_ids("general", str(general.id))
results.append("✅ #general")
announcements = await guild.create_text_channel(
"announcements",
category=discord.Object(id=int(discord_env.get('COMMUNITY_CATEGORY_ID', 0))),
reason="Project announcements"
)
discord_env.update_channel_ids("announcements", str(announcements.id))
results.append("✅ #announcements")
except Exception as e:
logger.error(f"Error creating community channels: {e}")
results.append(f"❌ Community channels: {e}")
await asyncio.sleep(1)
# Security channels
try:
security_alerts = await guild.create_text_channel(
"security-alerts",
category=discord.Object(id=int(discord_env.get('SECURITY_CATEGORY_ID', 0))),
reason="Security alerts and notifications"
)
discord_env.update_channel_ids("security-alerts", str(security_alerts.id))
results.append("✅ #security-alerts")
except Exception as e:
logger.error(f"Error creating security channels: {e}")
results.append(f"❌ Security channels: {e}")
await asyncio.sleep(1)
# Development channels
try:
dev_chat = await guild.create_text_channel(
"dev-chat",
category=discord.Object(id=int(discord_env.get('DEVELOPMENT_CATEGORY_ID', 0))),
reason="Development discussions"
)
discord_env.update_channel_ids("dev-chat", str(dev_chat.id))
results.append("✅ #dev-chat")
bug_reports = await guild.create_text_channel(
"bug-reports",
category=discord.Object(id=int(discord_env.get('DEVELOPMENT_CATEGORY_ID', 0))),
reason="Bug reports and issues"
)
discord_env.update_channel_ids("bug-reports", str(bug_reports.id))
results.append("✅ #bug-reports")
except Exception as e:
logger.error(f"Error creating development channels: {e}")
results.append(f"❌ Development channels: {e}")
# Send results
embed = discord.Embed(
title="🔧 Channel Creation Results",
color=discord.Color.green() if "" not in str(results) else discord.Color.red(),
description="\n".join(results)
)
await interaction.followup.send(embed=embed, ephemeral=True)
async def cmd_send_message(self, interaction: discord.Interaction, channel: str, message: str):
"""Send a message to a specific channel"""
guild = self.bot.get_guild(self.server_id)
if not guild:
await interaction.response.send_message("❌ Could not find server!", ephemeral=True)
return
# Find channel by name
target_channel = discord.utils.get(guild.text_channels, name=channel.lower())
if not target_channel:
await interaction.response.send_message(f"❌ Channel '{channel}' not found!", ephemeral=True)
return
try:
await target_channel.send(message)
await interaction.response.send_message(
f"✅ Message sent to #{channel}!", ephemeral=True
)
except Exception as e:
logger.error(f"Error sending message: {e}")
await interaction.response.send_message(
f"❌ Failed to send message: {e}", ephemeral=True
)
async def cmd_list_channels(self, interaction: discord.Interaction):
"""List all channels and categories"""
guild = self.bot.get_guild(self.server_id)
if not guild:
await interaction.response.send_message("❌ Could not find server!", ephemeral=True)
return
embed = discord.Embed(
title="📋 Server Channels",
color=discord.Color.blue()
)
# List categories
categories = [c for c in guild.categories if c.name]
if categories:
category_text = "\n".join([f"**{c.name}** (ID: {c.id})" for c in categories])
embed.add_field(name="📂 Categories", value=category_text or "None", inline=False)
# List text channels
text_channels = [c for c in guild.text_channels if c.category]
if text_channels:
channel_text = "\n".join([f"#{c.name} (ID: {c.id})" for c in text_channels])
embed.add_field(name="💬 Text Channels", value=channel_text or "None", inline=False)
# List voice channels
voice_channels = [c for c in guild.voice_channels if c.category]
if voice_channels:
voice_text = "\n".join([f"🎤 {c.name} (ID: {c.id})" for c in voice_channels])
embed.add_field(name="🎤 Voice Channels", value=voice_text or "None", inline=False)
await interaction.response.send_message(embed=embed, ephemeral=True)
async def cmd_create_category(self, interaction: discord.Interaction, name: str):
"""Create a new category"""
guild = self.bot.get_guild(self.server_id)
if not guild:
await interaction.response.send_message("❌ Could not find server!", ephemeral=True)
return
try:
category = await guild.create_category_channel(name)
await interaction.response.send_message(
f"✅ Created category: **{name}** (ID: {category.id})",
ephemeral=True
)
except Exception as e:
logger.error(f"Error creating category: {e}")
await interaction.response.send_message(
f"❌ Failed to create category: {e}",
ephemeral=True
)
async def cmd_create_test_channel(self, interaction: discord.Interaction):
"""Create one simple test channel"""
guild = self.bot.get_guild(self.server_id)
if not guild:
await interaction.response.send_message("❌ Could not find server!", ephemeral=True)
return
try:
# Create a simple text channel
test_channel = await guild.create_text_channel(
"test-channel",
reason="Testing bot channel creation"
)
await interaction.response.send_message(
f"✅ Created test channel: **#{test_channel.name}**",
ephemeral=True
)
except Exception as e:
logger.error(f"Error creating test channel: {e}")
await interaction.response.send_message(
f"❌ Failed to create test channel: {e}",
ephemeral=True
)
async def cmd_help(self, interaction: discord.Interaction):
"""Show help information"""
embed = discord.Embed(
title="🤖 RedFlag Discord Bot Help",
description="Interactive Discord server management commands",
color=discord.Color.blue()
)
commands_info = [
("`/status`", "📊 Show server status"),
("`/create-channels`", "🔧 Create standard channels"),
("`/list-channels`", "📋 List all channels"),
("`/send-message`", "💬 Send message to channel"),
("`/create-category`", "📂 Create new category"),
("`/create-test-channel`", "🧪 Create one test channel"),
("`/help`", "❓ Show this help"),
]
for cmd, desc in commands_info:
embed.add_field(name=cmd, value=desc, inline=False)
embed.add_field(
name="🛡️ Security Features",
value="✅ Secure configuration management\n✅ Token protection\n✅ Rate limiting",
inline=False
)
embed.add_field(
name="🔄 Live Updates",
value="Configuration changes are saved instantly to .env file",
inline=False
)
embed.set_footer(text="RedFlag Security Discord Management v1.0")
await interaction.response.send_message(embed=embed, ephemeral=True)
async def print_server_status(self, guild):
"""Print detailed server status to console"""
print(f"\n{'='*60}")
print(f"🚀 REDFLAG DISCORD SERVER STATUS")
print(f"{'='*60}")
print(f"Server Name: {guild.name}")
print(f"Server ID: {guild.id}")
print(f"Members: {guild.member_count}")
print(f"Channels: {len(guild.channels)}")
print(f"Roles: {len(guild.roles)}")
print(f"Owner: <@{guild.owner_id}>")
print(f"Created: {guild.created_at}")
print(f"Boost Level: {guild.premium_tier}")
print(f"{'='*60}\n")
async def run(self):
"""Start the bot"""
try:
# Connect to Discord
await self.bot.start(self.bot_token)
except discord.errors.LoginFailure:
logger.error("❌ Invalid Discord bot token!")
logger.error("Please check your DISCORD_BOT_TOKEN in .env file")
sys.exit(1)
except Exception as e:
logger.error(f"❌ Failed to start bot: {e}")
sys.exit(1)
def main():
"""Main entry point"""
print("🚀 Starting RedFlag Discord Management Bot...")
# Check configuration
if not discord_env.is_configured():
print("❌ Discord not configured!")
print("Please:")
print("1. Copy .env.example to .env")
print("2. Fill in your Discord bot token and server ID")
print("3. Run this script again")
sys.exit(1)
# Create and run bot
bot = DiscordManager()
asyncio.run(bot.run())
if __name__ == "__main__":
main()

153
discord/env_manager.py Executable file
View File

@@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Secure Discord Environment Manager
Handles loading Discord configuration from .env without exposing secrets
"""
import os
import logging
from typing import Optional, Dict, Any
from dotenv import load_dotenv
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class DiscordEnvManager:
"""Secure environment manager for Discord configuration"""
def __init__(self, env_file: str = ".env"):
self.env_file = env_file
self._config = {}
self._load_config()
def _load_config(self):
"""Load configuration from .env file"""
try:
load_dotenv(self.env_file)
self._config = {
'DISCORD_BOT_TOKEN': os.getenv('DISCORD_BOT_TOKEN'),
'DISCORD_SERVER_ID': os.getenv('DISCORD_SERVER_ID'),
'DISCORD_APPLICATION_ID': os.getenv('DISCORD_APPLICATION_ID'),
'DISCORD_PUBLIC_KEY': os.getenv('DISCORD_PUBLIC_KEY'),
'SERVER_NAME': os.getenv('SERVER_NAME', 'RedFlag Security'),
'ADMIN_ROLE_ID': os.getenv('ADMIN_ROLE_ID'),
'GENERAL_CHANNEL_ID': os.getenv('GENERAL_CHANNEL_ID'),
'ANNOUNCEMENTS_CHANNEL_ID': os.getenv('ANNOUNCEMENTS_CHANNEL_ID'),
'SECURITY_ALERTS_CHANNEL_ID': os.getenv('SECURITY_ALERTS_CHANNEL_ID'),
'DEV_CHAT_CHANNEL_ID': os.getenv('DEV_CHAT_CHANNEL_ID'),
'BUG_REPORTS_CHANNEL_ID': os.getenv('BUG_REPORTS_CHANNEL_ID'),
'COMMUNITY_CATEGORY_ID': os.getenv('COMMUNITY_CATEGORY_ID'),
'DEVELOPMENT_CATEGORY_ID': os.getenv('DEVELOPMENT_CATEGORY_ID'),
'SECURITY_CATEGORY_ID': os.getenv('SECURITY_CATEGORY_ID'),
}
# Validate required fields
required_fields = ['DISCORD_BOT_TOKEN', 'DISCORD_SERVER_ID']
missing = [field for field in required_fields if not self._config.get(field)]
if missing:
logger.error(f"Missing required environment variables: {missing}")
raise ValueError(f"Missing required fields: {missing}")
logger.info("✅ Discord configuration loaded successfully")
except Exception as e:
logger.error(f"Failed to load Discord configuration: {e}")
raise
def get(self, key: str, default: Any = None) -> Optional[str]:
"""Get configuration value"""
return self._config.get(key, default)
def get_required(self, key: str) -> str:
"""Get required configuration value"""
value = self._config.get(key)
if not value:
raise ValueError(f"Required environment variable {key} is not set")
return value
def update_channel_ids(self, channel_name: str, channel_id: str):
"""Update channel ID in config"""
channel_key = f"{channel_name.upper()}_CHANNEL_ID"
self._config[channel_key] = channel_id
# Also update in file
self._update_env_file(channel_key, channel_id)
def update_category_ids(self, category_name: str, category_id: str):
"""Update category ID in config"""
category_key = f"{category_name.upper()}_CATEGORY_ID"
self._config[category_key] = category_id
# Also update in file
self._update_env_file(category_key, category_id)
def _update_env_file(self, key: str, value: str):
"""Update .env file with new value"""
try:
env_path = os.path.join(os.path.dirname(__file__), self.env_file)
# Read current file
if os.path.exists(env_path):
with open(env_path, 'r') as f:
lines = f.readlines()
else:
lines = []
# Update or add the line
updated = False
for i, line in enumerate(lines):
if line.startswith(f"{key}="):
lines[i] = f"{key}={value}\n"
updated = True
break
if not updated:
lines.append(f"{key}={value}\n")
# Write back to file
with open(env_path, 'w') as f:
f.writelines(lines)
logger.info(f"✅ Updated {key} in {self.env_file}")
except Exception as e:
logger.error(f"Failed to update {key} in {self.env_file}: {e}")
def is_configured(self) -> bool:
"""Check if the Discord bot is properly configured"""
return (
self.get('DISCORD_BOT_TOKEN') and
self.get('DISCORD_SERVER_ID') and
self.get('DISCORD_APPLICATION_ID')
)
def mask_sensitive_info(self, text: str) -> str:
"""Mask sensitive information in logs"""
sensitive_words = ['TOKEN', 'KEY']
masked_text = text
for word in sensitive_words:
if f"{word}_ID" not in masked_text: # Don't mask channel IDs
# Find and mask the value
import re
pattern = rf'{word}=\w+'
replacement = f'{word}=***MASKED***'
masked_text = re.sub(pattern, replacement, masked_text)
return masked_text
# Global instance for easy access
discord_env = DiscordEnvManager()
# Convenience functions
def get_discord_config():
"""Get Discord configuration"""
return discord_env
def is_discord_ready():
"""Check if Discord is ready for use"""
return discord_env.is_configured()

4
discord/requirements.txt Normal file
View File

@@ -0,0 +1,4 @@
discord.py>=2.4.0
python-dotenv>=1.0.0
aiohttp>=3.8.0
asyncio-mqtt>=0.16.0

124
discord/setup.py Normal file
View File

@@ -0,0 +1,124 @@
#!/usr/bin/env python3
"""
RedFlag Discord Setup Assistant
Helps configure Discord bot for server management
"""
import os
import sys
from dotenv import load_dotenv
def setup_discord():
"""Interactive Discord setup"""
print("🚀 RedFlag Discord Bot Setup Assistant")
print("=" * 50)
# Check if .env exists
env_file = ".env"
if not os.path.exists(env_file):
print(f"📝 Creating {env_file} from template...")
if os.path.exists(".env.example"):
import shutil
shutil.copy(".env.example", env_file)
print(f"✅ Created {env_file} from .env.example")
else:
# Create basic .env file
with open(env_file, 'w') as f:
f.write("# Discord Bot Configuration\n")
f.write("DISCORD_BOT_TOKEN=your_bot_token_here\n")
f.write("DISCORD_SERVER_ID=your_server_id_here\n")
f.write("DISCORD_APPLICATION_ID=your_app_id_here\n")
f.write("DISCORD_PUBLIC_KEY=your_public_key_here\n")
f.write("\n# Server Settings\n")
f.write("SERVER_NAME=RedFlag Security\n")
f.write("ADMIN_ROLE_ID=\n")
print(f"✅ Created basic {env_file}")
# Load environment
load_dotenv(env_file)
print("\n📋 Discord Configuration Checklist:")
print("1. ✅ Discord Developer Portal: https://discord.com/developers/applications")
print("2. ✅ Create Application: Click 'New Application'")
print("3. ✅ Create Bot: Go to 'Bot''Add Bot'")
print("4. ✅ Enable Privileged Intents:")
print(" - ✅ Server Members Intent")
print(" - ✅ Server Management Intent")
print(" - ✅ Message Content Intent")
print("5. ✅ OAuth2 URL Generator:")
print(" - ✅ Scope: bot")
print(" - ✅ Scope: applications.commands")
print(" - ✅ Permissions: Administrator (or specific)")
print("6. ✅ Invite Bot to Server")
print("7. ✅ Copy Values Below:")
print("\n🔑 Required Discord Information:")
print("From your Discord Developer Portal, copy these values:")
print("-" * 50)
# Get user input (with masking)
def get_sensitive_input(prompt, key):
value = input(f"{prompt}: ").strip()
if value:
# Update .env file
update_env_file(key, value)
# Show masked version
masked_value = value[:8] + "..." + value[-4:] if len(value) > 12 else value
print(f"{key}: {masked_value}")
return value
def update_env_file(key, value):
"""Update .env file with value"""
env_path = os.path.join(os.path.dirname(__file__), env_file)
# Read current file
with open(env_path, 'r') as f:
lines = f.readlines()
# Update or add the line
updated = False
for i, line in enumerate(lines):
if line.startswith(f"{key}="):
lines[i] = f"{key}={value}\n"
updated = True
break
if not updated:
lines.append(f"{key}={value}\n")
# Write back to file
with open(env_path, 'w') as f:
f.writelines(lines)
# Get required values
bot_token = get_sensitive_input("Discord Bot Token", "DISCORD_BOT_TOKEN")
server_id = get_sensitive_input("Discord Server ID", "DISCORD_SERVER_ID")
app_id = get_sensitive_input("Discord Application ID", "DISCORD_APPLICATION_ID")
public_key = get_sensitive_input("Discord Public Key", "DISCORD_PUBLIC_KEY")
print("-" * 50)
print("🎉 Configuration Complete!")
print("\n📝 Next Steps:")
print("1. Run the Discord bot:")
print(" cd /home/memory/Desktop/Projects/RedFlag/discord")
print(" python discord_manager.py")
print("\n2. Available Commands (slash commands):")
print(" • /status - Show server status")
print(" • /create-channels - Create standard channels")
print(" • /list-channels - List all channels")
print(" • /send-message - Send message to channel")
print(" • /create-category - Create new category")
print(" • /help - Show all commands")
print("\n🔒 Security Note:")
print("• Your bot token is stored locally in .env")
print("• Never share the .env file")
print("• The bot only has Administrator permissions you grant it")
print("• All actions are logged locally")
def main():
"""Main setup function"""
setup_discord()
if __name__ == "__main__":
main()

23
scripts/build-secure-agent.sh Executable file
View File

@@ -0,0 +1,23 @@
#!/bin/bash
# RedFlag Agent Build Script
# Builds agent binary (public key fetched from server at runtime)
set -e
echo "🔨 RedFlag Agent Build"
echo "====================="
# Build agent
echo "Building agent..."
cd aggregator-agent
go build \
-o redflag-agent \
./cmd/agent
cd ..
echo "✅ Agent build complete!"
echo " Binary: aggregator-agent/redflag-agent"
echo ""
echo " Note: Agent will fetch the server's public key automatically at startup"

View File

@@ -0,0 +1,34 @@
//go:build ignore
// +build ignore
package main
import (
"crypto/ed25519"
"crypto/rand"
"encoding/hex"
"fmt"
"os"
)
func main() {
// Generate Ed25519 keypair
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
fmt.Printf("Failed to generate keypair: %v\n", err)
os.Exit(1)
}
// Output keys in hex format
fmt.Printf("Ed25519 Keypair Generated:\n\n")
fmt.Printf("Private Key (keep secret, add to server env):\n")
fmt.Printf("REDFLAG_SIGNING_PRIVATE_KEY=%s\n", hex.EncodeToString(privateKey))
fmt.Printf("\nPublic Key (embed in agent binaries):\n")
fmt.Printf("REDFLAG_PUBLIC_KEY=%s\n", hex.EncodeToString(publicKey))
fmt.Printf("\nPublic Key Fingerprint (for database):\n")
fmt.Printf("Fingerprint: %s\n", hex.EncodeToString(publicKey[:8])) // First 8 bytes as fingerprint
fmt.Printf("\nAdd the private key to your server environment and embed the public key in agent builds.\n")
}