diff --git a/Makefile b/Makefile index 63700fb..89e2f61 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,9 @@ build-server: ## Build server binary build-agent: ## Build agent binary cd aggregator-agent && go mod tidy && go build -o bin/aggregator-agent cmd/agent/main.go +build-agent-simple: ## Build agent binary with simple script + @./scripts/build-secure-agent.sh + clean: ## Clean build artifacts rm -rf aggregator-server/bin aggregator-agent/bin diff --git a/aggregator-agent/cmd/agent/main.go b/aggregator-agent/cmd/agent/main.go index 50cf0f1..3e4ac00 100644 --- a/aggregator-agent/cmd/agent/main.go +++ b/aggregator-agent/cmd/agent/main.go @@ -17,6 +17,7 @@ import ( "github.com/Fimeg/RedFlag/aggregator-agent/internal/circuitbreaker" "github.com/Fimeg/RedFlag/aggregator-agent/internal/client" "github.com/Fimeg/RedFlag/aggregator-agent/internal/config" + "github.com/Fimeg/RedFlag/aggregator-agent/internal/crypto" "github.com/Fimeg/RedFlag/aggregator-agent/internal/display" "github.com/Fimeg/RedFlag/aggregator-agent/internal/installer" "github.com/Fimeg/RedFlag/aggregator-agent/internal/orchestrator" @@ -348,13 +349,28 @@ func registerAgent(cfg *config.Config, serverURL string) error { } } + // Get machine ID for binding + machineID, err := system.GetMachineID() + if err != nil { + log.Printf("Warning: Failed to get machine ID: %v", err) + machineID = "unknown-" + sysInfo.Hostname + } + + // Get embedded public key fingerprint + publicKeyFingerprint := system.GetPublicKeyFingerprint() + if publicKeyFingerprint == "" { + log.Printf("Warning: No embedded public key fingerprint found") + } + req := client.RegisterRequest{ - Hostname: sysInfo.Hostname, - OSType: sysInfo.OSType, - OSVersion: sysInfo.OSVersion, - OSArchitecture: sysInfo.OSArchitecture, - AgentVersion: sysInfo.AgentVersion, - Metadata: metadata, + Hostname: sysInfo.Hostname, + OSType: sysInfo.OSType, + OSVersion: sysInfo.OSVersion, + OSArchitecture: sysInfo.OSArchitecture, + AgentVersion: sysInfo.AgentVersion, + MachineID: machineID, + PublicKeyFingerprint: publicKeyFingerprint, + Metadata: metadata, } resp, err := apiClient.Register(req) @@ -376,7 +392,27 @@ func registerAgent(cfg *config.Config, serverURL string) error { } // Save configuration - return cfg.Save(getConfigPath()) + if err := cfg.Save(getConfigPath()); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + + // Fetch and cache server public key for signature verification + log.Println("Fetching server public key for update signature verification...") + if err := fetchAndCachePublicKey(cfg.ServerURL); err != nil { + log.Printf("Warning: Failed to fetch server public key: %v", err) + log.Printf("Agent will not be able to verify update signatures") + // Don't fail registration - key can be fetched later + } else { + log.Println("✓ Server public key cached successfully") + } + + return nil +} + +// fetchAndCachePublicKey fetches the server's Ed25519 public key and caches it locally +func fetchAndCachePublicKey(serverURL string) error { + _, err := crypto.FetchAndCacheServerPublicKey(serverURL) + return err } // renewTokenIfNeeded handles 401 errors by renewing the agent token using refresh token @@ -694,6 +730,12 @@ func runAgent(cfg *config.Config) error { if err := handleReboot(apiClient, cfg, ackTracker, cmd.ID, cmd.Params); err != nil { log.Printf("[Reboot] Error processing reboot command: %v\n", err) } + + case "update_agent": + if err := handleUpdateAgent(apiClient, cfg, ackTracker, cmd.Params, cmd.ID); err != nil { + log.Printf("[Update] Error processing agent update command: %v\n", err) + } + default: log.Printf("Unknown command type: %s - reporting as invalid command\n", cmd.Type) // Report invalid command back to server diff --git a/aggregator-agent/cmd/agent/subsystem_handlers.go b/aggregator-agent/cmd/agent/subsystem_handlers.go index d65c308..d52eb11 100644 --- a/aggregator-agent/cmd/agent/subsystem_handlers.go +++ b/aggregator-agent/cmd/agent/subsystem_handlers.go @@ -2,8 +2,18 @@ package main import ( "context" + "crypto/aes" + "crypto/cipher" + "crypto/ed25519" + "crypto/sha256" + "encoding/hex" "fmt" + "io" "log" + "net/http" + "os" + "os/exec" + "runtime" "time" "github.com/Fimeg/RedFlag/aggregator-agent/internal/acknowledgment" @@ -39,6 +49,10 @@ func handleScanUpdatesV2(apiClient *client.Client, cfg *config.Config, ackTracke Stderr: stderr, ExitCode: exitCode, DurationSeconds: int(duration.Seconds()), + Metadata: map[string]string{ + "subsystem_label": "Package Updates", + "subsystem": "updates", + }, } // Report the scan log @@ -96,6 +110,10 @@ func handleScanStorage(apiClient *client.Client, cfg *config.Config, ackTracker Stderr: stderr, ExitCode: exitCode, DurationSeconds: int(duration.Seconds()), + Metadata: map[string]string{ + "subsystem_label": "Disk Usage", + "subsystem": "storage", + }, } // Report the scan log @@ -150,6 +168,10 @@ func handleScanSystem(apiClient *client.Client, cfg *config.Config, ackTracker * Stderr: stderr, ExitCode: exitCode, DurationSeconds: int(duration.Seconds()), + Metadata: map[string]string{ + "subsystem_label": "System Metrics", + "subsystem": "system", + }, } // Report the scan log @@ -204,6 +226,10 @@ func handleScanDocker(apiClient *client.Client, cfg *config.Config, ackTracker * Stderr: stderr, ExitCode: exitCode, DurationSeconds: int(duration.Seconds()), + Metadata: map[string]string{ + "subsystem_label": "Docker Images", + "subsystem": "docker", + }, } // Report the scan log @@ -230,3 +256,550 @@ func handleScanDocker(apiClient *client.Client, cfg *config.Config, ackTracker * return nil } + +// handleUpdateAgent handles agent update commands with signature verification +func handleUpdateAgent(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, params map[string]interface{}, commandID string) error { + log.Println("Processing agent update command...") + + // Extract parameters + version, ok := params["version"].(string) + if !ok { + return fmt.Errorf("missing version parameter") + } + + platform, ok := params["platform"].(string) + if !ok { + return fmt.Errorf("missing platform parameter") + } + + downloadURL, ok := params["download_url"].(string) + if !ok { + return fmt.Errorf("missing download_url parameter") + } + + signature, ok := params["signature"].(string) + if !ok { + return fmt.Errorf("missing signature parameter") + } + + checksum, ok := params["checksum"].(string) + if !ok { + return fmt.Errorf("missing checksum parameter") + } + + // Extract nonce parameters for replay protection + nonceUUIDStr, ok := params["nonce_uuid"].(string) + if !ok { + return fmt.Errorf("missing nonce_uuid parameter") + } + + nonceTimestampStr, ok := params["nonce_timestamp"].(string) + if !ok { + return fmt.Errorf("missing nonce_timestamp parameter") + } + + nonceSignature, ok := params["nonce_signature"].(string) + if !ok { + return fmt.Errorf("missing nonce_signature parameter") + } + + log.Printf("Updating agent to version %s (%s)", version, platform) + + // Validate nonce for replay protection + log.Printf("[tunturi_ed25519] Validating nonce...") + if err := validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature); err != nil { + return fmt.Errorf("[tunturi_ed25519] nonce validation failed: %w", err) + } + log.Printf("[tunturi_ed25519] ✓ Nonce validated") + + // Record start time for duration calculation + updateStartTime := time.Now() + + // Report the update command as started + logReport := client.LogReport{ + CommandID: commandID, + Action: "update_agent", + Result: "started", + Stdout: fmt.Sprintf("Starting agent update to version %s\n", version), + Stderr: "", + ExitCode: 0, + DurationSeconds: 0, + Metadata: map[string]string{ + "subsystem_label": "Agent Update", + "subsystem": "agent", + "target_version": version, + }, + } + + if err := reportLogWithAck(apiClient, cfg, ackTracker, logReport); err != nil { + log.Printf("Failed to report update start log: %v\n", err) + } + + // TODO: Implement actual download, signature verification, and update installation + // This is a placeholder that simulates the update process + // Phase 5: Actual Ed25519-signed update implementation + log.Printf("Starting secure update process for version %s", version) + log.Printf("Download URL: %s", downloadURL) + log.Printf("Signature: %s...", signature[:16]) // Log first 16 chars of signature + log.Printf("Expected checksum: %s", checksum) + + // Step 1: Download the update package + log.Printf("Step 1: Downloading update package...") + tempBinaryPath, err := downloadUpdatePackage(downloadURL) + if err != nil { + return fmt.Errorf("failed to download update package: %w", err) + } + defer os.Remove(tempBinaryPath) // Cleanup on exit + + // Step 2: Verify checksum + log.Printf("Step 2: Verifying checksum...") + actualChecksum, err := computeSHA256(tempBinaryPath) + if err != nil { + return fmt.Errorf("failed to compute checksum: %w", err) + } + + if actualChecksum != checksum { + return fmt.Errorf("checksum mismatch: expected %s, got %s", checksum, actualChecksum) + } + log.Printf("✓ Checksum verified: %s", actualChecksum) + + // Step 3: Verify Ed25519 signature + log.Printf("[tunturi_ed25519] Step 3: Verifying Ed25519 signature...") + if err := verifyBinarySignature(tempBinaryPath, signature); err != nil { + return fmt.Errorf("[tunturi_ed25519] signature verification failed: %w", err) + } + log.Printf("[tunturi_ed25519] ✓ Signature verified") + + // Step 4: Create backup of current binary + log.Printf("Step 4: Creating backup...") + currentBinaryPath, err := getCurrentBinaryPath() + if err != nil { + return fmt.Errorf("failed to determine current binary path: %w", err) + } + + backupPath := currentBinaryPath + ".bak" + var updateSuccess bool = false // Track overall success + + if err := createBackup(currentBinaryPath, backupPath); err != nil { + log.Printf("Warning: Failed to create backup: %v", err) + } else { + // Defer rollback/cleanup logic + defer func() { + if !updateSuccess { + // Rollback on failure + log.Printf("[tunturi_ed25519] Rollback: restoring from backup...") + if restoreErr := restoreFromBackup(backupPath, currentBinaryPath); restoreErr != nil { + log.Printf("[tunturi_ed25519] CRITICAL: Failed to restore backup: %v", restoreErr) + } else { + log.Printf("[tunturi_ed25519] ✓ Successfully rolled back to backup") + } + } else { + // Clean up backup on success + log.Printf("[tunturi_ed25519] ✓ Update successful, cleaning up backup") + os.Remove(backupPath) + } + }() + } + + // Step 5: Atomic installation + log.Printf("Step 5: Installing new binary...") + if err := installNewBinary(tempBinaryPath, currentBinaryPath); err != nil { + return fmt.Errorf("failed to install new binary: %w", err) + } + + // Step 6: Restart agent service + log.Printf("Step 6: Restarting agent service...") + if err := restartAgentService(); err != nil { + return fmt.Errorf("failed to restart agent: %w", err) + } + + // Step 7: Watchdog timer for confirmation + log.Printf("Step 7: Starting watchdog for update confirmation...") + updateSuccess = waitForUpdateConfirmation(apiClient, cfg, ackTracker, version, 5*time.Minute) + success := updateSuccess // Alias for logging below + + finalLogReport := client.LogReport{ + CommandID: commandID, + Action: "update_agent", + Result: map[bool]string{true: "success", false: "failure"}[success], + Stdout: fmt.Sprintf("Agent update to version %s %s\n", version, map[bool]string{true: "completed successfully", false: "failed"}[success]), + Stderr: map[bool]string{true: "", false: "Update verification timeout or restart failure"}[success], + ExitCode: map[bool]int{true: 0, false: 1}[success], + DurationSeconds: int(time.Since(updateStartTime).Seconds()), + Metadata: map[string]string{ + "subsystem_label": "Agent Update", + "subsystem": "agent", + "target_version": version, + "success": map[bool]string{true: "true", false: "false"}[success], + }, + } + + if err := reportLogWithAck(apiClient, cfg, ackTracker, finalLogReport); err != nil { + log.Printf("Failed to report update completion log: %v\n", err) + } + + if success { + log.Printf("✓ Agent successfully updated to version %s", version) + } else { + return fmt.Errorf("agent update verification failed") + } + + return nil +} + +// Helper functions for the update process + +func downloadUpdatePackage(downloadURL string) (string, error) { + // Download to temporary file + tempFile, err := os.CreateTemp("", "redflag-update-*.bin") + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + defer tempFile.Close() + + resp, err := http.Get(downloadURL) + if err != nil { + return "", fmt.Errorf("failed to download: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download failed with status: %d", resp.StatusCode) + } + + if _, err := tempFile.ReadFrom(resp.Body); err != nil { + return "", fmt.Errorf("failed to write download: %w", err) + } + + return tempFile.Name(), nil +} + +func computeSHA256(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return "", fmt.Errorf("failed to compute hash: %w", err) + } + + return hex.EncodeToString(hash.Sum(nil)), nil +} + +func getCurrentBinaryPath() (string, error) { + execPath, err := os.Executable() + if err != nil { + return "", fmt.Errorf("failed to get executable path: %w", err) + } + return execPath, nil +} + +func createBackup(src, dst string) error { + srcFile, err := os.Open(src) + if err != nil { + return fmt.Errorf("failed to open source: %w", err) + } + defer srcFile.Close() + + dstFile, err := os.Create(dst) + if err != nil { + return fmt.Errorf("failed to create backup: %w", err) + } + defer dstFile.Close() + + if _, err := dstFile.ReadFrom(srcFile); err != nil { + return fmt.Errorf("failed to copy backup: %w", err) + } + + // Ensure backup is executable + if err := os.Chmod(dst, 0755); err != nil { + return fmt.Errorf("failed to set backup permissions: %w", err) + } + + return nil +} + +func restoreFromBackup(backup, target string) error { + // Remove current binary if it exists + if _, err := os.Stat(target); err == nil { + if err := os.Remove(target); err != nil { + return fmt.Errorf("failed to remove current binary: %w", err) + } + } + + // Copy backup to target + return createBackup(backup, target) +} + +func installNewBinary(src, dst string) error { + // Copy new binary to a temporary location first + tempDst := dst + ".new" + + srcFile, err := os.Open(src) + if err != nil { + return fmt.Errorf("failed to open source binary: %w", err) + } + defer srcFile.Close() + + dstFile, err := os.Create(tempDst) + if err != nil { + return fmt.Errorf("failed to create temp binary: %w", err) + } + defer dstFile.Close() + + if _, err := dstFile.ReadFrom(srcFile); err != nil { + return fmt.Errorf("failed to copy binary: %w", err) + } + dstFile.Close() + + // Set executable permissions + if err := os.Chmod(tempDst, 0755); err != nil { + return fmt.Errorf("failed to set binary permissions: %w", err) + } + + // Atomic rename + if err := os.Rename(tempDst, dst); err != nil { + os.Remove(tempDst) // Cleanup temp file + return fmt.Errorf("failed to atomically replace binary: %w", err) + } + + return nil +} + +func restartAgentService() error { + var cmd *exec.Cmd + + switch runtime.GOOS { + case "linux": + // Try systemd first + cmd = exec.Command("systemctl", "restart", "redflag-agent") + if err := cmd.Run(); err == nil { + log.Printf("✓ Systemd service restarted") + return nil + } + // Fallback to service command + cmd = exec.Command("service", "redflag-agent", "restart") + + case "windows": + cmd = exec.Command("sc", "stop", "RedFlagAgent") + cmd.Run() + cmd = exec.Command("sc", "start", "RedFlagAgent") + + default: + return fmt.Errorf("unsupported OS for service restart: %s", runtime.GOOS) + } + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to restart service: %w", err) + } + + log.Printf("✓ Agent service restarted") + return nil +} + +func waitForUpdateConfirmation(apiClient *client.Client, cfg *config.Config, ackTracker *acknowledgment.Tracker, expectedVersion string, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + pollInterval := 15 * time.Second + + log.Printf("[tunturi_ed25519] Watchdog: waiting for version %s confirmation (timeout: %v)...", expectedVersion, timeout) + + for time.Now().Before(deadline) { + // Poll server for current agent version + agent, err := apiClient.GetAgent(cfg.AgentID.String()) + if err != nil { + log.Printf("[tunturi_ed25519] Watchdog: failed to poll server: %v (retrying...)", err) + time.Sleep(pollInterval) + continue + } + + // Check if the version matches the expected version + if agent != nil && agent.CurrentVersion == expectedVersion { + log.Printf("[tunturi_ed25519] Watchdog: ✓ Version confirmed: %s", expectedVersion) + return true + } + + log.Printf("[tunturi_ed25519] Watchdog: Current version: %s, Expected: %s (polling...)", + agent.CurrentVersion, expectedVersion) + time.Sleep(pollInterval) + } + + log.Printf("[tunturi_ed25519] Watchdog: ✗ Timeout after %v - version not confirmed", timeout) + log.Printf("[tunturi_ed25519] Rollback initiated") + return false +} + +// AES-256-GCM decryption helper functions for encrypted update packages + +// deriveKeyFromNonce derives an AES-256 key from a nonce using SHA-256 +func deriveKeyFromNonce(nonce string) []byte { + hash := sha256.Sum256([]byte(nonce)) + return hash[:] // 32 bytes for AES-256 +} + +// decryptAES256GCM decrypts data using AES-256-GCM with the provided nonce-derived key +func decryptAES256GCM(encryptedData, nonce string) ([]byte, error) { + // Derive key from nonce + key := deriveKeyFromNonce(nonce) + + // Decode hex data + data, err := hex.DecodeString(encryptedData) + if err != nil { + return nil, fmt.Errorf("failed to decode hex data: %w", err) + } + + // Create AES cipher + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create AES cipher: %w", err) + } + + // Create GCM + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + // Check minimum length + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return nil, fmt.Errorf("encrypted data too short") + } + + // Extract nonce and ciphertext + nonceBytes, ciphertext := data[:nonceSize], data[nonceSize:] + + // Decrypt + plaintext, err := gcm.Open(nil, nonceBytes, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt: %w", err) + } + + return plaintext, nil +} + +// TODO: Integration with system/machine_id.go for key derivation +// This stub should be integrated with the existing machine ID system +// for more sophisticated key management based on hardware fingerprinting +// +// Example integration approach: +// - Use machine_id.go to generate stable hardware fingerprint +// - Combine hardware fingerprint with nonce for key derivation +// - Store derived keys securely in memory only +// - Implement key rotation support for long-running agents + +// verifyBinarySignature verifies the Ed25519 signature of a binary file +func verifyBinarySignature(binaryPath, signatureHex string) error { + // Get the server public key from cache + publicKey, err := getServerPublicKey() + if err != nil { + return fmt.Errorf("failed to get server public key: %w", err) + } + + // Read the binary content + content, err := os.ReadFile(binaryPath) + if err != nil { + return fmt.Errorf("failed to read binary: %w", err) + } + + // Decode signature from hex + signatureBytes, err := hex.DecodeString(signatureHex) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } + + // Verify signature length + if len(signatureBytes) != ed25519.SignatureSize { + return fmt.Errorf("invalid signature length: expected %d bytes, got %d", ed25519.SignatureSize, len(signatureBytes)) + } + + // Ed25519 verification + valid := ed25519.Verify(ed25519.PublicKey(publicKey), content, signatureBytes) + if !valid { + return fmt.Errorf("signature verification failed: invalid signature") + } + + return nil +} + +// getServerPublicKey retrieves the Ed25519 public key from cache +// The key is fetched from the server at startup and cached locally +func getServerPublicKey() ([]byte, error) { + // Load from cache (fetched during agent startup) + publicKey, err := loadCachedPublicKeyDirect() + if err != nil { + return nil, fmt.Errorf("failed to load server public key: %w (hint: key is fetched at agent startup)", err) + } + + return publicKey, nil +} + +// loadCachedPublicKeyDirect loads the cached public key from the standard location +func loadCachedPublicKeyDirect() ([]byte, error) { + var keyPath string + if runtime.GOOS == "windows" { + keyPath = "C:\\ProgramData\\RedFlag\\server_public_key" + } else { + keyPath = "/etc/aggregator/server_public_key" + } + + data, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("public key not found: %w", err) + } + + if len(data) != 32 { // ed25519.PublicKeySize + return nil, fmt.Errorf("invalid public key size: expected 32 bytes, got %d", len(data)) + } + + return data, nil +} + +// validateNonce validates the nonce for replay protection +func validateNonce(nonceUUIDStr, nonceTimestampStr, nonceSignature string) error { + // Parse timestamp + nonceTimestamp, err := time.Parse(time.RFC3339, nonceTimestampStr) + if err != nil { + return fmt.Errorf("invalid nonce timestamp format: %w", err) + } + + // Check freshness (< 5 minutes) + age := time.Since(nonceTimestamp) + if age > 5*time.Minute { + return fmt.Errorf("nonce expired: age %v > 5 minutes", age) + } + + if age < 0 { + return fmt.Errorf("nonce timestamp in the future: %v", nonceTimestamp) + } + + // Get server public key from cache + publicKey, err := getServerPublicKey() + if err != nil { + return fmt.Errorf("failed to get server public key: %w", err) + } + + // Recreate nonce data (must match server format) + nonceData := fmt.Sprintf("%s:%d", nonceUUIDStr, nonceTimestamp.Unix()) + + // Decode signature + signatureBytes, err := hex.DecodeString(nonceSignature) + if err != nil { + return fmt.Errorf("invalid nonce signature format: %w", err) + } + + if len(signatureBytes) != ed25519.SignatureSize { + return fmt.Errorf("invalid nonce signature length: expected %d bytes, got %d", + ed25519.SignatureSize, len(signatureBytes)) + } + + // Verify Ed25519 signature + valid := ed25519.Verify(ed25519.PublicKey(publicKey), []byte(nonceData), signatureBytes) + if !valid { + return fmt.Errorf("invalid nonce signature") + } + + return nil +} diff --git a/aggregator-agent/go.mod b/aggregator-agent/go.mod index 671578e..f2a5f96 100644 --- a/aggregator-agent/go.mod +++ b/aggregator-agent/go.mod @@ -3,10 +3,12 @@ module github.com/Fimeg/RedFlag/aggregator-agent go 1.23.0 require ( + github.com/denisbrodbeck/machineid v1.0.1 github.com/docker/docker v27.4.1+incompatible github.com/go-ole/go-ole v1.3.0 github.com/google/uuid v1.6.0 github.com/scjalliance/comshim v0.0.0-20250111221056-b2ef9d8d7e0f + golang.org/x/sys v0.35.0 ) require ( @@ -31,7 +33,6 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 // indirect go.opentelemetry.io/otel/metric v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect - golang.org/x/sys v0.35.0 // indirect golang.org/x/time v0.5.0 // indirect gotest.tools/v3 v3.5.2 // indirect ) diff --git a/aggregator-agent/go.sum b/aggregator-agent/go.sum index 2840a43..88cf721 100644 --- a/aggregator-agent/go.sum +++ b/aggregator-agent/go.sum @@ -8,6 +8,8 @@ github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisbrodbeck/machineid v1.0.1 h1:geKr9qtkB876mXguW2X6TU4ZynleN6ezuMSRhl4D7AQ= +github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbjJCrnectwCyxcUSI= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/docker/docker v27.4.1+incompatible h1:ZJvcY7gfwHn1JF48PfbyXg7Jyt9ZCWDW+GGXOIxEwp4= diff --git a/aggregator-agent/internal/client/client.go b/aggregator-agent/internal/client/client.go index 75565f3..d70bf95 100644 --- a/aggregator-agent/internal/client/client.go +++ b/aggregator-agent/internal/client/client.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/Fimeg/RedFlag/aggregator-agent/internal/system" "github.com/google/uuid" ) @@ -21,19 +22,36 @@ type Client struct { http *http.Client RapidPollingEnabled bool RapidPollingUntil time.Time + machineID string // Cached machine ID for security binding } // NewClient creates a new API client func NewClient(baseURL, token string) *Client { + // Get machine ID for security binding (v0.1.22+) + machineID, err := system.GetMachineID() + if err != nil { + // Log warning but don't fail - older servers may not require it + fmt.Printf("Warning: Failed to get machine ID: %v\n", err) + machineID = "" // Will be handled by server validation + } + return &Client{ - baseURL: baseURL, - token: token, + baseURL: baseURL, + token: token, + machineID: machineID, http: &http.Client{ Timeout: 30 * time.Second, }, } } +// addMachineIDHeader adds X-Machine-ID header to authenticated requests (v0.1.22+) +func (c *Client) addMachineIDHeader(req *http.Request) { + if c.machineID != "" { + req.Header.Set("X-Machine-ID", c.machineID) + } +} + // GetToken returns the current JWT token func (c *Client) GetToken() string { return c.token @@ -46,13 +64,15 @@ func (c *Client) SetToken(token string) { // RegisterRequest is the payload for agent registration type RegisterRequest struct { - Hostname string `json:"hostname"` - OSType string `json:"os_type"` - OSVersion string `json:"os_version"` - OSArchitecture string `json:"os_architecture"` - AgentVersion string `json:"agent_version"` - RegistrationToken string `json:"registration_token,omitempty"` // Fallback method - Metadata map[string]string `json:"metadata"` + Hostname string `json:"hostname"` + OSType string `json:"os_type"` + OSVersion string `json:"os_version"` + OSArchitecture string `json:"os_architecture"` + AgentVersion string `json:"agent_version"` + RegistrationToken string `json:"registration_token,omitempty"` // Fallback method + MachineID string `json:"machine_id"` + PublicKeyFingerprint string `json:"public_key_fingerprint"` + Metadata map[string]string `json:"metadata"` } // RegisterResponse is returned after successful registration @@ -230,6 +250,7 @@ func (c *Client) GetCommands(agentID uuid.UUID, metrics *SystemMetrics) (*Comman } req.Header.Set("Authorization", "Bearer "+c.token) + c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+) resp, err := c.http.Do(req) if err != nil { @@ -297,6 +318,7 @@ func (c *Client) ReportUpdates(agentID uuid.UUID, report UpdateReport) error { } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.token) + c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+) resp, err := c.http.Do(req) if err != nil { @@ -314,13 +336,14 @@ func (c *Client) ReportUpdates(agentID uuid.UUID, report UpdateReport) error { // LogReport represents an execution log type LogReport struct { - CommandID string `json:"command_id"` - Action string `json:"action"` - Result string `json:"result"` - Stdout string `json:"stdout"` - Stderr string `json:"stderr"` - ExitCode int `json:"exit_code"` - DurationSeconds int `json:"duration_seconds"` + CommandID string `json:"command_id"` + Action string `json:"action"` + Result string `json:"result"` + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + ExitCode int `json:"exit_code"` + DurationSeconds int `json:"duration_seconds"` + Metadata map[string]string `json:"metadata,omitempty"` } // ReportLog sends an execution log to the server @@ -338,6 +361,7 @@ func (c *Client) ReportLog(agentID uuid.UUID, report LogReport) error { } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.token) + c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+) resp, err := c.http.Do(req) if err != nil { @@ -392,6 +416,7 @@ func (c *Client) ReportDependencies(agentID uuid.UUID, report DependencyReport) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.token) + c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+) resp, err := c.http.Do(req) if err != nil { @@ -437,6 +462,7 @@ func (c *Client) ReportSystemInfo(agentID uuid.UUID, report SystemInfoReport) er } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.token) + c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+) resp, err := c.http.Do(req) if err != nil { @@ -474,6 +500,49 @@ func DetectSystem() (osType, osVersion, osArch string) { return } +// AgentInfo represents agent information from the server +type AgentInfo struct { + ID string `json:"id"` + Hostname string `json:"hostname"` + CurrentVersion string `json:"current_version"` + OSType string `json:"os_type"` + OSVersion string `json:"os_version"` + OSArchitecture string `json:"os_architecture"` + LastCheckIn string `json:"last_check_in"` +} + +// GetAgent retrieves agent information from the server +func (c *Client) GetAgent(agentID string) (*AgentInfo, error) { + url := fmt.Sprintf("%s/api/v1/agents/%s", c.baseURL, agentID) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Content-Type", "application/json") + c.addMachineIDHeader(req) // Security: Validate machine binding (v0.1.22+) + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("server returned status %d: %s", resp.StatusCode, string(body)) + } + + var agent AgentInfo + if err := json.NewDecoder(resp.Body).Decode(&agent); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return &agent, nil +} + // parseOSRelease parses /etc/os-release to get proper distro name func parseOSRelease(data []byte) string { lines := strings.Split(string(data), "\n") diff --git a/aggregator-agent/internal/crypto/pubkey.go b/aggregator-agent/internal/crypto/pubkey.go new file mode 100644 index 0000000..53269e7 --- /dev/null +++ b/aggregator-agent/internal/crypto/pubkey.go @@ -0,0 +1,130 @@ +package crypto + +import ( + "crypto/ed25519" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" +) + +// getPublicKeyPath returns the platform-specific path for storing the server's public key +func getPublicKeyPath() string { + if runtime.GOOS == "windows" { + return "C:\\ProgramData\\RedFlag\\server_public_key" + } + return "/etc/aggregator/server_public_key" +} + +// PublicKeyResponse represents the server's public key response +type PublicKeyResponse struct { + PublicKey string `json:"public_key"` + Fingerprint string `json:"fingerprint"` + Algorithm string `json:"algorithm"` + KeySize int `json:"key_size"` +} + +// FetchAndCacheServerPublicKey fetches the server's Ed25519 public key and caches it locally +// This implements Trust-On-First-Use (TOFU) security model +func FetchAndCacheServerPublicKey(serverURL string) (ed25519.PublicKey, error) { + // Check if we already have a cached key + if cachedKey, err := LoadCachedPublicKey(); err == nil && cachedKey != nil { + return cachedKey, nil + } + + // Fetch from server + resp, err := http.Get(serverURL + "/api/v1/public-key") + if err != nil { + return nil, fmt.Errorf("failed to fetch public key from server: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("server returned status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var keyResp PublicKeyResponse + if err := json.NewDecoder(resp.Body).Decode(&keyResp); err != nil { + return nil, fmt.Errorf("failed to parse public key response: %w", err) + } + + // Validate algorithm + if keyResp.Algorithm != "ed25519" { + return nil, fmt.Errorf("unsupported signature algorithm: %s (expected ed25519)", keyResp.Algorithm) + } + + // Decode hex public key + pubKeyBytes, err := hex.DecodeString(keyResp.PublicKey) + if err != nil { + return nil, fmt.Errorf("invalid public key format: %w", err) + } + + if len(pubKeyBytes) != ed25519.PublicKeySize { + return nil, fmt.Errorf("invalid public key size: expected %d bytes, got %d", + ed25519.PublicKeySize, len(pubKeyBytes)) + } + + publicKey := ed25519.PublicKey(pubKeyBytes) + + // Cache it for future use + if err := cachePublicKey(publicKey); err != nil { + // Log warning but don't fail - we have the key in memory + fmt.Printf("Warning: Failed to cache public key: %v\n", err) + } + + fmt.Printf("✓ Server public key fetched and cached (fingerprint: %s)\n", keyResp.Fingerprint) + + return publicKey, nil +} + +// LoadCachedPublicKey loads the cached public key from disk +func LoadCachedPublicKey() (ed25519.PublicKey, error) { + keyPath := getPublicKeyPath() + + data, err := os.ReadFile(keyPath) + if err != nil { + return nil, err // File doesn't exist or can't be read + } + + if len(data) != ed25519.PublicKeySize { + return nil, fmt.Errorf("cached public key has invalid size: %d bytes", len(data)) + } + + return ed25519.PublicKey(data), nil +} + +// cachePublicKey saves the public key to disk +func cachePublicKey(publicKey ed25519.PublicKey) error { + keyPath := getPublicKeyPath() + + // Ensure directory exists + dir := filepath.Dir(keyPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + // Write public key (read-only for non-root users) + if err := os.WriteFile(keyPath, publicKey, 0644); err != nil { + return fmt.Errorf("failed to write public key: %w", err) + } + + return nil +} + +// GetPublicKey returns the cached public key or fetches it from the server +// This is the main entry point for getting the verification key +func GetPublicKey(serverURL string) (ed25519.PublicKey, error) { + // Try cached key first + if cachedKey, err := LoadCachedPublicKey(); err == nil { + return cachedKey, nil + } + + // Fetch from server if not cached + return FetchAndCacheServerPublicKey(serverURL) +} diff --git a/aggregator-agent/internal/service/windows.go b/aggregator-agent/internal/service/windows.go index b9f4be6..a593106 100644 --- a/aggregator-agent/internal/service/windows.go +++ b/aggregator-agent/internal/service/windows.go @@ -209,6 +209,13 @@ func (s *redflagService) runAgent() { } } + // Check if commands response is valid + if commands == nil { + log.Printf("Check-in successful - no commands received (nil response)") + elog.Info(1, "Check-in successful - no commands received (nil response)") + continue + } + if len(commands.Commands) == 0 { log.Printf("Check-in successful - no new commands") elog.Info(1, "Check-in successful - no new commands") diff --git a/aggregator-agent/internal/system/machine_id.go b/aggregator-agent/internal/system/machine_id.go new file mode 100644 index 0000000..73db35c --- /dev/null +++ b/aggregator-agent/internal/system/machine_id.go @@ -0,0 +1,129 @@ +package system + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "os" + "runtime" + "strings" + + "github.com/denisbrodbeck/machineid" +) + +// GetMachineID generates a unique machine identifier that persists across reboots +func GetMachineID() (string, error) { + // Try machineid library first (cross-platform) + id, err := machineid.ID() + if err == nil && id != "" { + // Hash the machine ID for consistency and privacy + return hashMachineID(id), nil + } + + // Fallback to OS-specific methods + switch runtime.GOOS { + case "linux": + return getLinuxMachineID() + case "windows": + return getWindowsMachineID() + case "darwin": + return getDarwinMachineID() + default: + return generateGenericMachineID() + } +} + +// hashMachineID creates a consistent hash from machine ID +func hashMachineID(id string) string { + hash := sha256.Sum256([]byte(id)) + return hex.EncodeToString(hash[:]) // Return full hash for uniqueness +} + +// getLinuxMachineID tries multiple sources for Linux machine ID +func getLinuxMachineID() (string, error) { + // Try /etc/machine-id first (systemd) + if id, err := os.ReadFile("/etc/machine-id"); err == nil { + idStr := strings.TrimSpace(string(id)) + if idStr != "" { + return hashMachineID(idStr), nil + } + } + + // Try /var/lib/dbus/machine-id + if id, err := os.ReadFile("/var/lib/dbus/machine-id"); err == nil { + idStr := strings.TrimSpace(string(id)) + if idStr != "" { + return hashMachineID(idStr), nil + } + } + + // Try DMI product UUID + if id, err := os.ReadFile("/sys/class/dmi/id/product_uuid"); err == nil { + idStr := strings.TrimSpace(string(id)) + if idStr != "" { + return hashMachineID(idStr), nil + } + } + + // Try /etc/hostname as last resort + if hostname, err := os.ReadFile("/etc/hostname"); err == nil { + hostnameStr := strings.TrimSpace(string(hostname)) + if hostnameStr != "" { + return hashMachineID(hostnameStr + "-linux-fallback"), nil + } + } + + return generateGenericMachineID() +} + +// getWindowsMachineID gets Windows machine ID +func getWindowsMachineID() (string, error) { + // Try machineid library Windows registry keys first + if id, err := machineid.ID(); err == nil && id != "" { + return hashMachineID(id), nil + } + + // Fallback to generating generic ID + return generateGenericMachineID() +} + +// getDarwinMachineID gets macOS machine ID +func getDarwinMachineID() (string, error) { + // Try machineid library platform-specific keys first + if id, err := machineid.ID(); err == nil && id != "" { + return hashMachineID(id), nil + } + + // Fallback to generating generic ID + return generateGenericMachineID() +} + +// generateGenericMachineID creates a fallback machine ID from available system info +func generateGenericMachineID() (string, error) { + // Combine hostname with other available info + hostname, _ := os.Hostname() + if hostname == "" { + hostname = "unknown" + } + + // Create a reasonably unique ID from available system info + idSource := fmt.Sprintf("%s-%s-%s", hostname, runtime.GOOS, runtime.GOARCH) + return hashMachineID(idSource), nil +} + +// GetEmbeddedPublicKey returns the embedded public key fingerprint +// This should be set at build time using ldflags +var EmbeddedPublicKey = "not-set-at-build-time" + +// GetPublicKeyFingerprint returns the fingerprint of the embedded public key +func GetPublicKeyFingerprint() string { + if EmbeddedPublicKey == "not-set-at-build-time" { + return "" + } + + // Return first 8 bytes as fingerprint + if len(EmbeddedPublicKey) >= 16 { + return EmbeddedPublicKey[:16] + } + return EmbeddedPublicKey +} \ No newline at end of file diff --git a/aggregator-agent/test-redflag-agent b/aggregator-agent/test-redflag-agent new file mode 100755 index 0000000..34ab035 Binary files /dev/null and b/aggregator-agent/test-redflag-agent differ diff --git a/aggregator-server/cmd/server/main.go b/aggregator-server/cmd/server/main.go index 7a2e824..5236009 100644 --- a/aggregator-server/cmd/server/main.go +++ b/aggregator-server/cmd/server/main.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "log" + "net/http" "path/filepath" "time" @@ -129,6 +130,7 @@ func main() { registrationTokenQueries := queries.NewRegistrationTokenQueries(db.DB) userQueries := queries.NewUserQueries(db.DB) subsystemQueries := queries.NewSubsystemQueries(db.DB) + agentUpdateQueries := queries.NewAgentUpdateQueries(db.DB) // Ensure admin user exists if err := userQueries.EnsureAdminUser(cfg.Admin.Username, cfg.Admin.Username+"@redflag.local", cfg.Admin.Password); err != nil { @@ -141,6 +143,20 @@ func main() { timezoneService := services.NewTimezoneService(cfg) timeoutService := services.NewTimeoutService(commandQueries, updateQueries) + // Initialize signing service if private key is configured + var signingService *services.SigningService + if cfg.SigningPrivateKey != "" { + var err error + signingService, err = services.NewSigningService(cfg.SigningPrivateKey) + if err != nil { + log.Printf("Warning: Failed to initialize signing service: %v", err) + } else { + log.Printf("✅ Ed25519 signing service initialized") + } + } else { + log.Printf("Warning: No signing private key configured - agent update signing disabled") + } + // Initialize rate limiter rateLimiter := middleware.NewRateLimiter() @@ -156,6 +172,21 @@ func main() { downloadHandler := handlers.NewDownloadHandler(filepath.Join("/app"), cfg) subsystemHandler := handlers.NewSubsystemHandler(subsystemQueries, commandQueries) + // Initialize verification handler + var verificationHandler *handlers.VerificationHandler + if signingService != nil { + verificationHandler = handlers.NewVerificationHandler(agentQueries, signingService) + } + + // Initialize agent update handler + var agentUpdateHandler *handlers.AgentUpdateHandler + if signingService != nil { + agentUpdateHandler = handlers.NewAgentUpdateHandler(agentQueries, agentUpdateQueries, commandQueries, signingService, agentHandler) + } + + // Initialize system handler + systemHandler := handlers.NewSystemHandler(signingService) + // Setup router router := gin.Default() @@ -178,17 +209,23 @@ func main() { api.POST("/auth/logout", authHandler.Logout) api.GET("/auth/verify", authHandler.VerifyToken) + // Public system routes (no authentication required) + api.GET("/public-key", rateLimiter.RateLimit("public_access", middleware.KeyByIP), systemHandler.GetPublicKey) + api.GET("/info", rateLimiter.RateLimit("public_access", middleware.KeyByIP), systemHandler.GetSystemInfo) + // Public routes (no authentication required, with rate limiting) api.POST("/agents/register", rateLimiter.RateLimit("agent_registration", middleware.KeyByIP), agentHandler.RegisterAgent) api.POST("/agents/renew", rateLimiter.RateLimit("public_access", middleware.KeyByIP), agentHandler.RenewToken) // Public download routes (no authentication - agents need these!) api.GET("/downloads/:platform", rateLimiter.RateLimit("public_access", middleware.KeyByIP), downloadHandler.DownloadAgent) + api.GET("/downloads/updates/:package_id", rateLimiter.RateLimit("public_access", middleware.KeyByIP), downloadHandler.DownloadUpdatePackage) api.GET("/install/:platform", rateLimiter.RateLimit("public_access", middleware.KeyByIP), downloadHandler.InstallScript) - // Protected agent routes + // Protected agent routes (with machine binding security) agents := api.Group("/agents") agents.Use(middleware.AuthMiddleware()) + agents.Use(middleware.MachineBindingMiddleware(agentQueries, cfg.MinAgentVersion)) // v0.1.22: Prevent config copying { agents.GET("/:id/commands", agentHandler.GetCommands) agents.POST("/:id/updates", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportUpdates) @@ -196,6 +233,13 @@ func main() { agents.POST("/:id/dependencies", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), updateHandler.ReportDependencies) agents.POST("/:id/system-info", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.ReportSystemInfo) agents.POST("/:id/rapid-mode", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), agentHandler.SetRapidPollingMode) + agents.POST("/:id/verify-signature", rateLimiter.RateLimit("agent_reports", middleware.KeyByAgentID), func(c *gin.Context) { + if verificationHandler == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "signature verification service not available"}) + return + } + verificationHandler.VerifySignature(c) + }) agents.DELETE("/:id", agentHandler.UnregisterAgent) // Subsystem routes @@ -231,6 +275,14 @@ func main() { dashboard.POST("/updates/:id/install", updateHandler.InstallUpdate) dashboard.POST("/updates/:id/confirm-dependencies", updateHandler.ConfirmDependencies) + // Agent update routes + if agentUpdateHandler != nil { + dashboard.POST("/agents/:id/update", agentUpdateHandler.UpdateAgent) + dashboard.POST("/agents/bulk-update", agentUpdateHandler.BulkUpdateAgents) + dashboard.GET("/updates/packages", agentUpdateHandler.ListUpdatePackages) + dashboard.POST("/updates/packages/sign", agentUpdateHandler.SignUpdatePackage) + } + // Log routes dashboard.GET("/logs", updateHandler.GetAllLogs) dashboard.GET("/logs/active", updateHandler.GetActiveOperations) diff --git a/aggregator-server/go.mod b/aggregator-server/go.mod index c835911..f1dde0f 100644 --- a/aggregator-server/go.mod +++ b/aggregator-server/go.mod @@ -7,9 +7,8 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/jmoiron/sqlx v1.4.0 - github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 - golang.org/x/term v0.33.0 + golang.org/x/crypto v0.40.0 ) require ( @@ -36,7 +35,6 @@ require ( github.com/ugorji/go/codec v1.3.0 // indirect go.uber.org/mock v0.5.0 // indirect golang.org/x/arch v0.20.0 // indirect - golang.org/x/crypto v0.40.0 // indirect golang.org/x/mod v0.25.0 // indirect golang.org/x/net v0.42.0 // indirect golang.org/x/sync v0.16.0 // indirect diff --git a/aggregator-server/go.sum b/aggregator-server/go.sum index bc4792e..0cdc5c4 100644 --- a/aggregator-server/go.sum +++ b/aggregator-server/go.sum @@ -38,8 +38,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= @@ -92,8 +90,6 @@ golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= -golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/aggregator-server/internal/api/handlers/agent_updates.go b/aggregator-server/internal/api/handlers/agent_updates.go new file mode 100644 index 0000000..326ed5e --- /dev/null +++ b/aggregator-server/internal/api/handlers/agent_updates.go @@ -0,0 +1,401 @@ +package handlers + +import ( + "fmt" + "log" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries" + "github.com/Fimeg/RedFlag/aggregator-server/internal/models" + "github.com/Fimeg/RedFlag/aggregator-server/internal/services" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// AgentUpdateHandler handles agent update operations +type AgentUpdateHandler struct { + agentQueries *queries.AgentQueries + agentUpdateQueries *queries.AgentUpdateQueries + commandQueries *queries.CommandQueries + signingService *services.SigningService + agentHandler *AgentHandler +} + +// NewAgentUpdateHandler creates a new agent update handler +func NewAgentUpdateHandler(aq *queries.AgentQueries, auq *queries.AgentUpdateQueries, cq *queries.CommandQueries, ss *services.SigningService, ah *AgentHandler) *AgentUpdateHandler { + return &AgentUpdateHandler{ + agentQueries: aq, + agentUpdateQueries: auq, + commandQueries: cq, + signingService: ss, + agentHandler: ah, + } +} + +// UpdateAgent handles POST /api/v1/agents/:id/update (manual agent update) +func (h *AgentUpdateHandler) UpdateAgent(c *gin.Context) { + var req models.AgentUpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Verify the agent exists + agent, err := h.agentQueries.GetAgentByID(req.AgentID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"}) + return + } + + // Check if agent is already updating + if agent.IsUpdating { + c.JSON(http.StatusConflict, gin.H{ + "error": "agent is already updating", + "current_update": agent.UpdatingToVersion, + "initiated_at": agent.UpdateInitiatedAt, + }) + return + } + + // Validate platform compatibility + if !h.isPlatformCompatible(agent, req.Platform) { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("platform %s is not compatible with agent %s/%s", + req.Platform, agent.OSType, agent.OSArchitecture), + }) + return + } + + // Get the update package + pkg, err := h.agentUpdateQueries.GetUpdatePackageByVersion(req.Version, req.Platform, agent.OSArchitecture) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("update package not found: %v", err)}) + return + } + + // Update agent status to "updating" + if err := h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, true, &req.Version); err != nil { + log.Printf("Failed to update agent %s status to updating: %v", req.AgentID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate update"}) + return + } + + // Generate nonce for replay protection + nonceUUID := uuid.New() + nonceTimestamp := time.Now() + var nonceSignature string + if h.signingService != nil { + var err error + nonceSignature, err = h.signingService.SignNonce(nonceUUID, nonceTimestamp) + if err != nil { + log.Printf("Failed to sign nonce: %v", err) + h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, false, nil) // Rollback + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to sign nonce"}) + return + } + } + + // Create update command for agent + commandType := "update_agent" + commandParams := map[string]interface{}{ + "version": req.Version, + "platform": req.Platform, + "download_url": fmt.Sprintf("/api/v1/downloads/updates/%s", pkg.ID), + "signature": pkg.Signature, + "checksum": pkg.Checksum, + "file_size": pkg.FileSize, + "nonce_uuid": nonceUUID.String(), + "nonce_timestamp": nonceTimestamp.Format(time.RFC3339), + "nonce_signature": nonceSignature, + } + + // Schedule the update if requested + if req.Scheduled != nil { + scheduledTime, err := time.Parse(time.RFC3339, *req.Scheduled) + if err != nil { + h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, false, nil) // Rollback + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid scheduled time format"}) + return + } + commandParams["scheduled_at"] = scheduledTime + } + + // Create the command in database + command := &models.AgentCommand{ + ID: uuid.New(), + AgentID: req.AgentID, + CommandType: commandType, + Params: commandParams, + Status: models.CommandStatusPending, + Source: "web_ui", + CreatedAt: time.Now(), + } + + if err := h.commandQueries.CreateCommand(command); err != nil { + // Rollback the updating status + h.agentQueries.UpdateAgentUpdatingStatus(req.AgentID, false, nil) + log.Printf("Failed to create update command for agent %s: %v", req.AgentID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create command"}) + return + } + + log.Printf("✅ Agent update initiated for %s: %s (%s)", agent.Hostname, req.Version, req.Platform) + + response := models.AgentUpdateResponse{ + Message: "Update initiated successfully", + UpdateID: command.ID.String(), + DownloadURL: fmt.Sprintf("/api/v1/downloads/updates/%s", pkg.ID), + Signature: pkg.Signature, + Checksum: pkg.Checksum, + FileSize: pkg.FileSize, + EstimatedTime: h.estimateUpdateTime(pkg.FileSize), + } + + c.JSON(http.StatusOK, response) +} + +// BulkUpdateAgents handles POST /api/v1/agents/bulk-update (bulk agent update) +func (h *AgentUpdateHandler) BulkUpdateAgents(c *gin.Context) { + var req models.BulkAgentUpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if len(req.AgentIDs) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "no agent IDs provided"}) + return + } + + if len(req.AgentIDs) > 50 { + c.JSON(http.StatusBadRequest, gin.H{"error": "too many agents in bulk update (max 50)"}) + return + } + + // Get the update package first to validate it exists + pkg, err := h.agentUpdateQueries.GetUpdatePackageByVersion(req.Version, req.Platform, "") + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("update package not found: %v", err)}) + return + } + + // Validate all agents exist and are compatible + var results []map[string]interface{} + var errors []string + + for _, agentID := range req.AgentIDs { + agent, err := h.agentQueries.GetAgentByID(agentID) + if err != nil { + errors = append(errors, fmt.Sprintf("Agent %s: not found", agentID)) + continue + } + + if agent.IsUpdating { + errors = append(errors, fmt.Sprintf("Agent %s: already updating", agentID)) + continue + } + + if !h.isPlatformCompatible(agent, req.Platform) { + errors = append(errors, fmt.Sprintf("Agent %s: platform incompatible", agentID)) + continue + } + + // Update agent status + if err := h.agentQueries.UpdateAgentUpdatingStatus(agentID, true, &req.Version); err != nil { + errors = append(errors, fmt.Sprintf("Agent %s: failed to update status", agentID)) + continue + } + + // Generate nonce for replay protection + nonceUUID := uuid.New() + nonceTimestamp := time.Now() + var nonceSignature string + if h.signingService != nil { + var err error + nonceSignature, err = h.signingService.SignNonce(nonceUUID, nonceTimestamp) + if err != nil { + errors = append(errors, fmt.Sprintf("Agent %s: failed to sign nonce", agentID)) + h.agentQueries.UpdateAgentUpdatingStatus(agentID, false, nil) + continue + } + } + + // Create update command + command := &models.AgentCommand{ + ID: uuid.New(), + AgentID: agentID, + CommandType: "update_agent", + Params: map[string]interface{}{ + "version": req.Version, + "platform": req.Platform, + "download_url": fmt.Sprintf("/api/v1/downloads/updates/%s", pkg.ID), + "signature": pkg.Signature, + "checksum": pkg.Checksum, + "file_size": pkg.FileSize, + "nonce_uuid": nonceUUID.String(), + "nonce_timestamp": nonceTimestamp.Format(time.RFC3339), + "nonce_signature": nonceSignature, + }, + Status: models.CommandStatusPending, + Source: "web_ui_bulk", + CreatedAt: time.Now(), + } + + if req.Scheduled != nil { + command.Params["scheduled_at"] = *req.Scheduled + } + + if err := h.commandQueries.CreateCommand(command); err != nil { + // Rollback status + h.agentQueries.UpdateAgentUpdatingStatus(agentID, false, nil) + errors = append(errors, fmt.Sprintf("Agent %s: failed to create command", agentID)) + continue + } + + results = append(results, map[string]interface{}{ + "agent_id": agentID, + "hostname": agent.Hostname, + "update_id": command.ID.String(), + "status": "initiated", + }) + + log.Printf("✅ Bulk update initiated for %s: %s (%s)", agent.Hostname, req.Version, req.Platform) + } + + response := gin.H{ + "message": fmt.Sprintf("Bulk update completed with %d successes and %d failures", len(results), len(errors)), + "updated": results, + "failed": errors, + "total_agents": len(req.AgentIDs), + "package_info": gin.H{ + "version": pkg.Version, + "platform": pkg.Platform, + "file_size": pkg.FileSize, + "checksum": pkg.Checksum, + }, + } + + c.JSON(http.StatusOK, response) +} + +// ListUpdatePackages handles GET /api/v1/updates/packages (list available update packages) +func (h *AgentUpdateHandler) ListUpdatePackages(c *gin.Context) { + version := c.Query("version") + platform := c.Query("platform") + limitStr := c.Query("limit") + offsetStr := c.Query("offset") + + limit := 0 + if limitStr != "" { + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { + limit = l + } + } + + offset := 0 + if offsetStr != "" { + if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 { + offset = o + } + } + + packages, err := h.agentUpdateQueries.ListUpdatePackages(version, platform, limit, offset) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list update packages"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "packages": packages, + "total": len(packages), + "limit": limit, + "offset": offset, + }) +} + +// SignUpdatePackage handles POST /api/v1/updates/packages/sign (sign a new update package) +func (h *AgentUpdateHandler) SignUpdatePackage(c *gin.Context) { + var req struct { + Version string `json:"version" binding:"required"` + Platform string `json:"platform" binding:"required"` + Architecture string `json:"architecture" binding:"required"` + BinaryPath string `json:"binary_path" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if h.signingService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "signing service not available"}) + return + } + + // Sign the binary + pkg, err := h.signingService.SignFile(req.BinaryPath) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to sign binary: %v", err)}) + return + } + + // Set additional fields + pkg.Version = req.Version + pkg.Platform = req.Platform + pkg.Architecture = req.Architecture + + // Save to database + if err := h.agentUpdateQueries.CreateUpdatePackage(pkg); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save update package: %v", err)}) + return + } + + log.Printf("✅ Update package signed and saved: %s %s/%s (ID: %s)", + pkg.Version, pkg.Platform, pkg.Architecture, pkg.ID) + + c.JSON(http.StatusOK, gin.H{ + "message": "Update package signed successfully", + "package": pkg, + }) +} + +// isPlatformCompatible checks if the update package is compatible with the agent +func (h *AgentUpdateHandler) isPlatformCompatible(agent *models.Agent, updatePlatform string) bool { + // Normalize platform strings + agentPlatform := strings.ToLower(agent.OSType) + updatePlatform = strings.ToLower(updatePlatform) + + // Check for basic OS compatibility + if !strings.Contains(updatePlatform, agentPlatform) { + return false + } + + // Check architecture compatibility if specified + if strings.Contains(updatePlatform, "amd64") && !strings.Contains(strings.ToLower(agent.OSArchitecture), "amd64") { + return false + } + if strings.Contains(updatePlatform, "arm64") && !strings.Contains(strings.ToLower(agent.OSArchitecture), "arm64") { + return false + } + if strings.Contains(updatePlatform, "386") && !strings.Contains(strings.ToLower(agent.OSArchitecture), "386") { + return false + } + + return true +} + +// estimateUpdateTime estimates how long an update will take based on file size +func (h *AgentUpdateHandler) estimateUpdateTime(fileSize int64) int { + // Rough estimate: 1 second per MB + 30 seconds base time + seconds := int(fileSize/1024/1024) + 30 + + // Cap at 5 minutes + if seconds > 300 { + seconds = 300 + } + + return seconds +} \ No newline at end of file diff --git a/aggregator-server/internal/api/handlers/agents.go b/aggregator-server/internal/api/handlers/agents.go index 586f7cd..0edde0e 100644 --- a/aggregator-server/internal/api/handlers/agents.go +++ b/aggregator-server/internal/api/handlers/agents.go @@ -71,18 +71,30 @@ func (h *AgentHandler) RegisterAgent(c *gin.Context) { return } + // Validate machine ID and public key fingerprint if provided + if req.MachineID != "" { + // Check if machine ID is already registered to another agent + existingAgent, err := h.agentQueries.GetAgentByMachineID(req.MachineID) + if err == nil && existingAgent != nil && existingAgent.ID.String() != "" { + c.JSON(http.StatusConflict, gin.H{"error": "machine ID already registered to another agent"}) + return + } + } + // Create new agent agent := &models.Agent{ - ID: uuid.New(), - Hostname: req.Hostname, - OSType: req.OSType, - OSVersion: req.OSVersion, - OSArchitecture: req.OSArchitecture, - AgentVersion: req.AgentVersion, - CurrentVersion: req.AgentVersion, - LastSeen: time.Now(), - Status: "online", - Metadata: models.JSONB{}, + ID: uuid.New(), + Hostname: req.Hostname, + OSType: req.OSType, + OSVersion: req.OSVersion, + OSArchitecture: req.OSArchitecture, + AgentVersion: req.AgentVersion, + CurrentVersion: req.AgentVersion, + MachineID: &req.MachineID, + PublicKeyFingerprint: &req.PublicKeyFingerprint, + LastSeen: time.Now(), + Status: "online", + Metadata: models.JSONB{}, } // Add metadata if provided diff --git a/aggregator-server/internal/api/handlers/downloads.go b/aggregator-server/internal/api/handlers/downloads.go index 9a4bfdb..03f638e 100644 --- a/aggregator-server/internal/api/handlers/downloads.go +++ b/aggregator-server/internal/api/handlers/downloads.go @@ -99,6 +99,25 @@ func (h *DownloadHandler) DownloadAgent(c *gin.Context) { c.File(agentPath) } +// DownloadUpdatePackage serves signed agent update packages +func (h *DownloadHandler) DownloadUpdatePackage(c *gin.Context) { + packageID := c.Param("package_id") + + // Validate package ID format (UUID) + if len(packageID) != 36 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid package ID format"}) + return + } + + // TODO: Implement actual package serving from database/filesystem + // For now, return a placeholder response + c.JSON(http.StatusNotImplemented, gin.H{ + "error": "Update package download not yet implemented", + "package_id": packageID, + "message": "This will serve the signed update package file", + }) +} + // InstallScript serves the installation script func (h *DownloadHandler) InstallScript(c *gin.Context) { platform := c.Param("platform") diff --git a/aggregator-server/internal/api/handlers/system.go b/aggregator-server/internal/api/handlers/system.go new file mode 100644 index 0000000..c046a98 --- /dev/null +++ b/aggregator-server/internal/api/handlers/system.go @@ -0,0 +1,57 @@ +package handlers + +import ( + "net/http" + + "github.com/Fimeg/RedFlag/aggregator-server/internal/services" + "github.com/gin-gonic/gin" +) + +// SystemHandler handles system-level operations +type SystemHandler struct { + signingService *services.SigningService +} + +// NewSystemHandler creates a new system handler +func NewSystemHandler(ss *services.SigningService) *SystemHandler { + return &SystemHandler{ + signingService: ss, + } +} + +// GetPublicKey returns the server's Ed25519 public key for signature verification +// This allows agents to fetch the public key at runtime instead of embedding it at build time +func (h *SystemHandler) GetPublicKey(c *gin.Context) { + if h.signingService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "signing service not configured", + "hint": "Set REDFLAG_SIGNING_PRIVATE_KEY environment variable", + }) + return + } + + pubKeyHex := h.signingService.GetPublicKey() + fingerprint := h.signingService.GetPublicKeyFingerprint() + + c.JSON(http.StatusOK, gin.H{ + "public_key": pubKeyHex, + "fingerprint": fingerprint, + "algorithm": "ed25519", + "key_size": 32, + }) +} + +// GetSystemInfo returns general system information +func (h *SystemHandler) GetSystemInfo(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "version": "v0.1.21", + "name": "RedFlag Aggregator", + "description": "Self-hosted update management platform", + "features": []string{ + "agent_management", + "update_tracking", + "command_execution", + "ed25519_signing", + }, + }) +} diff --git a/aggregator-server/internal/api/handlers/verification.go b/aggregator-server/internal/api/handlers/verification.go new file mode 100644 index 0000000..33b0fa8 --- /dev/null +++ b/aggregator-server/internal/api/handlers/verification.go @@ -0,0 +1,137 @@ +package handlers + +import ( + "crypto/ed25519" + "encoding/hex" + "fmt" + "log" + "net/http" + "strings" + + "github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries" + "github.com/Fimeg/RedFlag/aggregator-server/internal/models" + "github.com/Fimeg/RedFlag/aggregator-server/internal/services" + "github.com/gin-gonic/gin" +) + +// VerificationHandler handles signature verification requests +type VerificationHandler struct { + agentQueries *queries.AgentQueries + signingService *services.SigningService +} + +// NewVerificationHandler creates a new verification handler +func NewVerificationHandler(aq *queries.AgentQueries, signingService *services.SigningService) *VerificationHandler { + return &VerificationHandler{ + agentQueries: aq, + signingService: signingService, + } +} + +// VerifySignature handles POST /api/v1/agents/:id/verify-signature +func (h *VerificationHandler) VerifySignature(c *gin.Context) { + var req models.SignatureVerificationRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Validate the agent exists and matches the provided machine ID + agent, err := h.agentQueries.GetAgentByID(req.AgentID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "agent not found"}) + return + } + + // Verify machine ID matches + if agent.MachineID == nil || *agent.MachineID != req.MachineID { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "machine ID mismatch", + "expected": agent.MachineID, + "received": req.MachineID, + }) + return + } + + // Verify public key fingerprint matches + if agent.PublicKeyFingerprint == nil || *agent.PublicKeyFingerprint != req.PublicKey { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "public key fingerprint mismatch", + "expected": agent.PublicKeyFingerprint, + "received": req.PublicKey, + }) + return + } + + // Verify the signature + valid, err := h.verifyAgentSignature(req.BinaryPath, req.Signature) + if err != nil { + log.Printf("Signature verification failed for agent %s: %v", req.AgentID, err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "signature verification failed", + "details": err.Error(), + }) + return + } + + response := models.SignatureVerificationResponse{ + Valid: valid, + AgentID: req.AgentID.String(), + MachineID: req.MachineID, + Fingerprint: req.PublicKey, + Message: "Signature verification completed", + } + + if !valid { + response.Message = "Invalid signature - binary may be tampered with" + c.JSON(http.StatusUnauthorized, response) + return + } + + c.JSON(http.StatusOK, response) +} + +// verifyAgentSignature verifies the signature of an agent binary +func (h *VerificationHandler) verifyAgentSignature(binaryPath, signatureHex string) (bool, error) { + // Decode the signature + signature, err := hex.DecodeString(signatureHex) + if err != nil { + return false, fmt.Errorf("invalid signature format: %w", err) + } + + if len(signature) != ed25519.SignatureSize { + return false, fmt.Errorf("invalid signature size: expected %d bytes, got %d", ed25519.SignatureSize, len(signature)) + } + + // Read the binary file + content, err := readFileContent(binaryPath) + if err != nil { + return false, fmt.Errorf("failed to read binary: %w", err) + } + + // Verify using the signing service + valid, err := h.signingService.VerifySignature(content, signatureHex) + if err != nil { + return false, fmt.Errorf("verification failed: %w", err) + } + + return valid, nil +} + +// readFileContent reads file content safely +func readFileContent(filePath string) ([]byte, error) { + // Basic path validation to prevent directory traversal + if strings.Contains(filePath, "..") || strings.Contains(filePath, "~") { + return nil, fmt.Errorf("invalid file path") + } + + // Only allow specific file patterns for security + if !strings.HasSuffix(filePath, "/redflag-agent") && !strings.HasSuffix(filePath, "/redflag-agent.exe") { + return nil, fmt.Errorf("invalid file type - only agent binaries are allowed") + } + + // For security, we won't actually read files in this handler + // In a real implementation, this would verify the actual binary on the agent + // For now, we'll simulate the verification process + return []byte("simulated-binary-content"), nil +} \ No newline at end of file diff --git a/aggregator-server/internal/api/middleware/machine_binding.go b/aggregator-server/internal/api/middleware/machine_binding.go new file mode 100644 index 0000000..150bc21 --- /dev/null +++ b/aggregator-server/internal/api/middleware/machine_binding.go @@ -0,0 +1,99 @@ +package middleware + +import ( + "log" + "net/http" + + "github.com/Fimeg/RedFlag/aggregator-server/internal/database/queries" + "github.com/Fimeg/RedFlag/aggregator-server/internal/utils" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// MachineBindingMiddleware validates machine ID matches database record +// This prevents agent impersonation via config file copying to different machines +func MachineBindingMiddleware(agentQueries *queries.AgentQueries, minAgentVersion string) gin.HandlerFunc { + return func(c *gin.Context) { + // Skip if not authenticated (handled by auth middleware) + agentIDVal, exists := c.Get("agent_id") + if !exists { + c.Next() + return + } + + agentID, ok := agentIDVal.(uuid.UUID) + if !ok { + log.Printf("[MachineBinding] Invalid agent_id type in context") + c.JSON(http.StatusInternalServerError, gin.H{"error": "invalid agent ID"}) + c.Abort() + return + } + + // Get agent from database + agent, err := agentQueries.GetAgentByID(agentID) + if err != nil { + log.Printf("[MachineBinding] Agent %s not found: %v", agentID, err) + c.JSON(http.StatusUnauthorized, gin.H{"error": "agent not found"}) + c.Abort() + return + } + + // Check minimum version (hard cutoff for legacy de-support) + if agent.CurrentVersion != "" && minAgentVersion != "" { + if !utils.IsNewerOrEqualVersion(agent.CurrentVersion, minAgentVersion) { + log.Printf("[MachineBinding] Agent %s version %s below minimum %s - rejecting", + agent.Hostname, agent.CurrentVersion, minAgentVersion) + c.JSON(http.StatusUpgradeRequired, gin.H{ + "error": "agent version too old - upgrade required for security", + "current_version": agent.CurrentVersion, + "minimum_version": minAgentVersion, + "upgrade_instructions": "Please upgrade to the latest agent version and re-register", + }) + c.Abort() + return + } + } + + // Extract X-Machine-ID header + reportedMachineID := c.GetHeader("X-Machine-ID") + if reportedMachineID == "" { + log.Printf("[MachineBinding] Agent %s (%s) missing X-Machine-ID header", + agent.Hostname, agentID) + c.JSON(http.StatusForbidden, gin.H{ + "error": "missing machine ID header - agent version too old or tampered", + "hint": "Please upgrade to the latest agent version (v0.1.22+)", + }) + c.Abort() + return + } + + // Validate machine ID matches database + if agent.MachineID == nil { + log.Printf("[MachineBinding] Agent %s (%s) has no machine_id in database - legacy agent", + agent.Hostname, agentID) + c.JSON(http.StatusForbidden, gin.H{ + "error": "agent not bound to machine - re-registration required", + "hint": "This agent was registered before v0.1.22. Please re-register with a new registration token.", + }) + c.Abort() + return + } + + if *agent.MachineID != reportedMachineID { + log.Printf("[MachineBinding] ⚠️ SECURITY ALERT: Agent %s (%s) machine ID mismatch! DB=%s, Reported=%s", + agent.Hostname, agentID, *agent.MachineID, reportedMachineID) + c.JSON(http.StatusForbidden, gin.H{ + "error": "machine ID mismatch - config file copied to different machine", + "hint": "Agent configuration is bound to the original machine. Please register this machine with a new registration token.", + "security_note": "This prevents agent impersonation attacks", + }) + c.Abort() + return + } + + // Machine ID validated - allow request + log.Printf("[MachineBinding] ✓ Agent %s (%s) machine ID validated: %s", + agent.Hostname, agentID, reportedMachineID[:16]+"...") + c.Next() + } +} diff --git a/aggregator-server/internal/config/config.go b/aggregator-server/internal/config/config.go index 7e20fa7..cc84572 100644 --- a/aggregator-server/internal/config/config.go +++ b/aggregator-server/internal/config/config.go @@ -37,10 +37,12 @@ type Config struct { MaxTokens int `env:"REDFLAG_MAX_TOKENS" default:"100"` MaxSeats int `env:"REDFLAG_MAX_SEATS" default:"50"` } - CheckInInterval int - OfflineThreshold int - Timezone string + CheckInInterval int + OfflineThreshold int + Timezone string LatestAgentVersion string + MinAgentVersion string `env:"MIN_AGENT_VERSION" default:"0.1.22"` + SigningPrivateKey string `env:"REDFLAG_SIGNING_PRIVATE_KEY"` } // Load reads configuration from environment variables only (immutable configuration) @@ -84,7 +86,8 @@ func Load() (*Config, error) { cfg.CheckInInterval = checkInInterval cfg.OfflineThreshold = offlineThreshold cfg.Timezone = getEnv("TIMEZONE", "UTC") - cfg.LatestAgentVersion = getEnv("LATEST_AGENT_VERSION", "0.1.18") + cfg.LatestAgentVersion = getEnv("LATEST_AGENT_VERSION", "0.1.22") + cfg.MinAgentVersion = getEnv("MIN_AGENT_VERSION", "0.1.22") // Handle missing secrets if cfg.Admin.Password == "" || cfg.Admin.JWTSecret == "" || cfg.Database.Password == "" { diff --git a/aggregator-server/internal/database/migrations/016_agent_update_packages.down.sql b/aggregator-server/internal/database/migrations/016_agent_update_packages.down.sql new file mode 100644 index 0000000..8f2ad63 --- /dev/null +++ b/aggregator-server/internal/database/migrations/016_agent_update_packages.down.sql @@ -0,0 +1,10 @@ +-- Remove agent update packages table +DROP TABLE IF EXISTS agent_update_packages; + +-- Remove new columns from agents table +ALTER TABLE agents +DROP COLUMN IF EXISTS machine_id, +DROP COLUMN IF EXISTS public_key_fingerprint, +DROP COLUMN IF EXISTS is_updating, +DROP COLUMN IF EXISTS updating_to_version, +DROP COLUMN IF EXISTS update_initiated_at; \ No newline at end of file diff --git a/aggregator-server/internal/database/migrations/016_agent_update_packages.up.sql b/aggregator-server/internal/database/migrations/016_agent_update_packages.up.sql new file mode 100644 index 0000000..f06167a --- /dev/null +++ b/aggregator-server/internal/database/migrations/016_agent_update_packages.up.sql @@ -0,0 +1,47 @@ +-- Add machine ID and public key fingerprint fields to agents table +-- This enables Ed25519 binary signing and machine binding + +ALTER TABLE agents +ADD COLUMN machine_id VARCHAR(64) UNIQUE, +ADD COLUMN public_key_fingerprint VARCHAR(16), +ADD COLUMN is_updating BOOLEAN DEFAULT false, +ADD COLUMN updating_to_version VARCHAR(50), +ADD COLUMN update_initiated_at TIMESTAMP; + +-- Create index for machine ID lookups +CREATE INDEX idx_agents_machine_id ON agents(machine_id); +CREATE INDEX idx_agents_public_key_fingerprint ON agents(public_key_fingerprint); + +-- Add comment to document the new fields +COMMENT ON COLUMN agents.machine_id IS 'Unique machine identifier to bind agent binaries to specific hardware'; +COMMENT ON COLUMN agents.public_key_fingerprint IS 'Fingerprint of embedded public key for binary signature verification'; +COMMENT ON COLUMN agents.is_updating IS 'Whether agent is currently updating'; +COMMENT ON COLUMN agents.updating_to_version IS 'Target version for ongoing update'; +COMMENT ON COLUMN agents.update_initiated_at IS 'When the update process started'; + +-- Create table for storing signed update packages +CREATE TABLE agent_update_packages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + version VARCHAR(50) NOT NULL, + platform VARCHAR(50) NOT NULL, -- linux-amd64, linux-arm64, windows-amd64, etc. + architecture VARCHAR(20) NOT NULL, + binary_path VARCHAR(500) NOT NULL, + signature VARCHAR(128) NOT NULL, -- Ed25519 signature (64 bytes hex encoded) + checksum VARCHAR(64) NOT NULL, -- SHA-256 checksum + file_size BIGINT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100) DEFAULT 'system', + is_active BOOLEAN DEFAULT true +); + +-- Add indexes for update packages +CREATE INDEX idx_agent_update_packages_version ON agent_update_packages(version); +CREATE INDEX idx_agent_update_packages_platform ON agent_update_packages(platform, architecture); +CREATE INDEX idx_agent_update_packages_active ON agent_update_packages(is_active); + +-- Add comments for update packages table +COMMENT ON TABLE agent_update_packages IS 'Stores signed agent binary packages for secure updates'; +COMMENT ON COLUMN agent_update_packages.signature IS 'Ed25519 signature of the binary file'; +COMMENT ON COLUMN agent_update_packages.checksum IS 'SHA-256 checksum of the binary file'; +COMMENT ON COLUMN agent_update_packages.platform IS 'Target platform (OS-architecture)'; +COMMENT ON COLUMN agent_update_packages.is_active IS 'Whether this package is available for updates'; \ No newline at end of file diff --git a/aggregator-server/internal/database/migrations/017_add_machine_id.down.sql b/aggregator-server/internal/database/migrations/017_add_machine_id.down.sql new file mode 100644 index 0000000..5aad745 --- /dev/null +++ b/aggregator-server/internal/database/migrations/017_add_machine_id.down.sql @@ -0,0 +1,4 @@ +-- Rollback machine_id column addition + +DROP INDEX IF EXISTS idx_agents_machine_id; +ALTER TABLE agents DROP COLUMN IF EXISTS machine_id; diff --git a/aggregator-server/internal/database/migrations/017_add_machine_id.up.sql b/aggregator-server/internal/database/migrations/017_add_machine_id.up.sql new file mode 100644 index 0000000..6d4d015 --- /dev/null +++ b/aggregator-server/internal/database/migrations/017_add_machine_id.up.sql @@ -0,0 +1,11 @@ +-- Add machine_id column to agents table for hardware fingerprint binding +-- This prevents config file copying attacks by validating hardware identity + +ALTER TABLE agents +ADD COLUMN machine_id VARCHAR(64); + +-- Create unique index to prevent duplicate machine IDs +CREATE UNIQUE INDEX idx_agents_machine_id ON agents(machine_id) WHERE machine_id IS NOT NULL; + +-- Add comment for documentation +COMMENT ON COLUMN agents.machine_id IS 'SHA-256 hash of hardware fingerprint (prevents agent impersonation via config copying)'; diff --git a/aggregator-server/internal/database/queries/agent_updates.go b/aggregator-server/internal/database/queries/agent_updates.go new file mode 100644 index 0000000..2bc92f4 --- /dev/null +++ b/aggregator-server/internal/database/queries/agent_updates.go @@ -0,0 +1,219 @@ +package queries + +import ( + "database/sql" + "fmt" + "time" + + "github.com/Fimeg/RedFlag/aggregator-server/internal/models" + "github.com/google/uuid" + "github.com/jmoiron/sqlx" +) + +// AgentUpdateQueries handles database operations for agent update packages +type AgentUpdateQueries struct { + db *sqlx.DB +} + +// NewAgentUpdateQueries creates a new AgentUpdateQueries instance +func NewAgentUpdateQueries(db *sqlx.DB) *AgentUpdateQueries { + return &AgentUpdateQueries{db: db} +} + +// CreateUpdatePackage stores a new signed update package +func (q *AgentUpdateQueries) CreateUpdatePackage(pkg *models.AgentUpdatePackage) error { + query := ` + INSERT INTO agent_update_packages ( + id, version, platform, architecture, binary_path, signature, + checksum, file_size, created_by, is_active + ) VALUES ( + :id, :version, :platform, :architecture, :binary_path, :signature, + :checksum, :file_size, :created_by, :is_active + ) RETURNING id, created_at + ` + + rows, err := q.db.NamedQuery(query, pkg) + if err != nil { + return fmt.Errorf("failed to create update package: %w", err) + } + defer rows.Close() + + if rows.Next() { + if err := rows.Scan(&pkg.ID, &pkg.CreatedAt); err != nil { + return fmt.Errorf("failed to scan created package: %w", err) + } + } + + return nil +} + +// GetUpdatePackage retrieves an update package by ID +func (q *AgentUpdateQueries) GetUpdatePackage(id uuid.UUID) (*models.AgentUpdatePackage, error) { + query := ` + SELECT id, version, platform, architecture, binary_path, signature, + checksum, file_size, created_at, created_by, is_active + FROM agent_update_packages + WHERE id = $1 AND is_active = true + ` + + var pkg models.AgentUpdatePackage + err := q.db.Get(&pkg, query, id) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("update package not found") + } + return nil, fmt.Errorf("failed to get update package: %w", err) + } + + return &pkg, nil +} + +// GetUpdatePackageByVersion retrieves the latest update package for a version and platform +func (q *AgentUpdateQueries) GetUpdatePackageByVersion(version, platform, architecture string) (*models.AgentUpdatePackage, error) { + query := ` + SELECT id, version, platform, architecture, binary_path, signature, + checksum, file_size, created_at, created_by, is_active + FROM agent_update_packages + WHERE version = $1 AND platform = $2 AND architecture = $3 AND is_active = true + ORDER BY created_at DESC + LIMIT 1 + ` + + var pkg models.AgentUpdatePackage + err := q.db.Get(&pkg, query, version, platform, architecture) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("no update package found for version %s on %s/%s", version, platform, architecture) + } + return nil, fmt.Errorf("failed to get update package: %w", err) + } + + return &pkg, nil +} + +// ListUpdatePackages retrieves all update packages with optional filtering +func (q *AgentUpdateQueries) ListUpdatePackages(version, platform string, limit, offset int) ([]models.AgentUpdatePackage, error) { + query := ` + SELECT id, version, platform, architecture, binary_path, signature, + checksum, file_size, created_at, created_by, is_active + FROM agent_update_packages + WHERE is_active = true + ` + + args := []interface{}{} + argIndex := 1 + + if version != "" { + query += fmt.Sprintf(" AND version = $%d", argIndex) + args = append(args, version) + argIndex++ + } + + if platform != "" { + query += fmt.Sprintf(" AND platform = $%d", argIndex) + args = append(args, platform) + argIndex++ + } + + query += " ORDER BY created_at DESC" + + if limit > 0 { + query += fmt.Sprintf(" LIMIT $%d", argIndex) + args = append(args, limit) + argIndex++ + + if offset > 0 { + query += fmt.Sprintf(" OFFSET $%d", argIndex) + args = append(args, offset) + } + } + + var packages []models.AgentUpdatePackage + err := q.db.Select(&packages, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to list update packages: %w", err) + } + + return packages, nil +} + +// DeactivateUpdatePackage marks a package as inactive +func (q *AgentUpdateQueries) DeactivateUpdatePackage(id uuid.UUID) error { + query := `UPDATE agent_update_packages SET is_active = false WHERE id = $1` + + result, err := q.db.Exec(query, id) + if err != nil { + return fmt.Errorf("failed to deactivate update package: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("no update package found to deactivate") + } + + return nil +} + +// UpdateAgentMachineInfo updates the machine ID and public key fingerprint for an agent +func (q *AgentUpdateQueries) UpdateAgentMachineInfo(agentID uuid.UUID, machineID, publicKeyFingerprint string) error { + query := ` + UPDATE agents + SET machine_id = $1, public_key_fingerprint = $2, updated_at = $3 + WHERE id = $4 + ` + + _, err := q.db.Exec(query, machineID, publicKeyFingerprint, time.Now().UTC(), agentID) + if err != nil { + return fmt.Errorf("failed to update agent machine info: %w", err) + } + + return nil +} + +// UpdateAgentUpdatingStatus sets the update status for an agent +func (q *AgentUpdateQueries) UpdateAgentUpdatingStatus(agentID uuid.UUID, isUpdating bool, targetVersion *string) error { + query := ` + UPDATE agents + SET is_updating = $1, + updating_to_version = $2, + update_initiated_at = CASE WHEN $1 = true THEN $3 ELSE update_initiated_at END, + updated_at = $3 + WHERE id = $4 + ` + + now := time.Now().UTC() + _, err := q.db.Exec(query, isUpdating, targetVersion, now, agentID) + if err != nil { + return fmt.Errorf("failed to update agent updating status: %w", err) + } + + return nil +} + +// GetAgentByMachineID retrieves an agent by its machine ID +func (q *AgentUpdateQueries) GetAgentByMachineID(machineID string) (*models.Agent, error) { + query := ` + SELECT id, hostname, os_type, os_version, os_architecture, agent_version, + current_version, update_available, last_version_check, machine_id, + public_key_fingerprint, is_updating, updating_to_version, + update_initiated_at, last_seen, status, metadata, reboot_required, + last_reboot_at, reboot_reason, created_at, updated_at + FROM agents + WHERE machine_id = $1 + ` + + var agent models.Agent + err := q.db.Get(&agent, query, machineID) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("agent not found for machine ID") + } + return nil, fmt.Errorf("failed to get agent by machine ID: %w", err) + } + + return &agent, nil +} \ No newline at end of file diff --git a/aggregator-server/internal/database/queries/agents.go b/aggregator-server/internal/database/queries/agents.go index 2f009be..183925e 100644 --- a/aggregator-server/internal/database/queries/agents.go +++ b/aggregator-server/internal/database/queries/agents.go @@ -1,6 +1,8 @@ package queries import ( + "database/sql" + "fmt" "time" "github.com/Fimeg/RedFlag/aggregator-server/internal/models" @@ -245,3 +247,51 @@ func (q *AgentQueries) UpdateAgentLastReboot(id uuid.UUID, rebootTime time.Time) _, err := q.db.Exec(query, rebootTime, time.Now(), id) return err } + +// GetAgentByMachineID retrieves an agent by its machine ID +func (q *AgentQueries) GetAgentByMachineID(machineID string) (*models.Agent, error) { + query := ` + SELECT id, hostname, os_type, os_version, os_architecture, agent_version, + current_version, update_available, last_version_check, machine_id, + public_key_fingerprint, is_updating, updating_to_version, + update_initiated_at, last_seen, status, metadata, reboot_required, + last_reboot_at, reboot_reason, created_at, updated_at + FROM agents + WHERE machine_id = $1 + ` + + var agent models.Agent + err := q.db.Get(&agent, query, machineID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil // Return nil if not found (not an error) + } + return nil, fmt.Errorf("failed to get agent by machine ID: %w", err) + } + + return &agent, nil +} + +// UpdateAgentUpdatingStatus updates the agent's update status +func (q *AgentQueries) UpdateAgentUpdatingStatus(id uuid.UUID, isUpdating bool, updatingToVersion *string) error { + query := ` + UPDATE agents + SET + is_updating = $1, + updating_to_version = $2, + update_initiated_at = CASE + WHEN $1 = true THEN $3 + ELSE NULL + END, + updated_at = $3 + WHERE id = $4 + ` + + var versionPtr *string + if updatingToVersion != nil { + versionPtr = updatingToVersion + } + + _, err := q.db.Exec(query, isUpdating, versionPtr, time.Now(), id) + return err +} diff --git a/aggregator-server/internal/models/agent.go b/aggregator-server/internal/models/agent.go index 75dd62d..073c0f0 100644 --- a/aggregator-server/internal/models/agent.go +++ b/aggregator-server/internal/models/agent.go @@ -18,15 +18,20 @@ type Agent struct { AgentVersion string `json:"agent_version" db:"agent_version"` // Version at registration CurrentVersion string `json:"current_version" db:"current_version"` // Current running version UpdateAvailable bool `json:"update_available" db:"update_available"` // Whether update is available - LastVersionCheck time.Time `json:"last_version_check" db:"last_version_check"` // Last time version was checked - LastSeen time.Time `json:"last_seen" db:"last_seen"` - Status string `json:"status" db:"status"` - Metadata JSONB `json:"metadata" db:"metadata"` - RebootRequired bool `json:"reboot_required" db:"reboot_required"` - LastRebootAt *time.Time `json:"last_reboot_at,omitempty" db:"last_reboot_at"` - RebootReason *string `json:"reboot_reason,omitempty" db:"reboot_reason"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + LastVersionCheck time.Time `json:"last_version_check" db:"last_version_check"` // Last time version was checked + MachineID *string `json:"machine_id,omitempty" db:"machine_id"` // Unique machine identifier + PublicKeyFingerprint *string `json:"public_key_fingerprint,omitempty" db:"public_key_fingerprint"` // Public key fingerprint + IsUpdating bool `json:"is_updating" db:"is_updating"` // Whether agent is currently updating + UpdatingToVersion *string `json:"updating_to_version,omitempty" db:"updating_to_version"` // Target version for ongoing update + UpdateInitiatedAt *time.Time `json:"update_initiated_at,omitempty" db:"update_initiated_at"` // When update process started + LastSeen time.Time `json:"last_seen" db:"last_seen"` + Status string `json:"status" db:"status"` + Metadata JSONB `json:"metadata" db:"metadata"` + RebootRequired bool `json:"reboot_required" db:"reboot_required"` + LastRebootAt *time.Time `json:"last_reboot_at,omitempty" db:"last_reboot_at"` + RebootReason *string `json:"reboot_reason,omitempty" db:"reboot_reason"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` } // AgentWithLastScan extends Agent with last scan information @@ -69,13 +74,15 @@ type AgentSpecs struct { // AgentRegistrationRequest is the payload for agent registration type AgentRegistrationRequest struct { - Hostname string `json:"hostname" binding:"required"` - OSType string `json:"os_type" binding:"required"` - OSVersion string `json:"os_version"` - OSArchitecture string `json:"os_architecture"` - AgentVersion string `json:"agent_version" binding:"required"` - RegistrationToken string `json:"registration_token"` // Optional, for fallback method - Metadata map[string]string `json:"metadata"` + Hostname string `json:"hostname" binding:"required"` + OSType string `json:"os_type" binding:"required"` + OSVersion string `json:"os_version"` + OSArchitecture string `json:"os_architecture"` + AgentVersion string `json:"agent_version" binding:"required"` + RegistrationToken string `json:"registration_token"` // Optional, for fallback method + MachineID string `json:"machine_id"` // Unique machine identifier + PublicKeyFingerprint string `json:"public_key_fingerprint"` // Embedded public key fingerprint + Metadata map[string]string `json:"metadata"` } // AgentRegistrationResponse is returned after successful registration diff --git a/aggregator-server/internal/models/agent_update.go b/aggregator-server/internal/models/agent_update.go new file mode 100644 index 0000000..1575429 --- /dev/null +++ b/aggregator-server/internal/models/agent_update.go @@ -0,0 +1,67 @@ +package models + +import ( + "time" + + "github.com/google/uuid" +) + +// AgentUpdatePackage represents a signed agent binary package +type AgentUpdatePackage struct { + ID uuid.UUID `json:"id" db:"id"` + Version string `json:"version" db:"version"` + Platform string `json:"platform" db:"platform"` + Architecture string `json:"architecture" db:"architecture"` + BinaryPath string `json:"binary_path" db:"binary_path"` + Signature string `json:"signature" db:"signature"` + Checksum string `json:"checksum" db:"checksum"` + FileSize int64 `json:"file_size" db:"file_size"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + CreatedBy string `json:"created_by" db:"created_by"` + IsActive bool `json:"is_active" db:"is_active"` +} + +// AgentUpdateRequest represents a request to update an agent +type AgentUpdateRequest struct { + AgentID uuid.UUID `json:"agent_id" binding:"required"` + Version string `json:"version" binding:"required"` + Platform string `json:"platform" binding:"required"` + Scheduled *string `json:"scheduled_at,omitempty"` +} + +// BulkAgentUpdateRequest represents a bulk update request +type BulkAgentUpdateRequest struct { + AgentIDs []uuid.UUID `json:"agent_ids" binding:"required"` + Version string `json:"version" binding:"required"` + Platform string `json:"platform" binding:"required"` + Scheduled *string `json:"scheduled_at,omitempty"` +} + +// AgentUpdateResponse represents the response for an update request +type AgentUpdateResponse struct { + Message string `json:"message"` + UpdateID string `json:"update_id,omitempty"` + DownloadURL string `json:"download_url,omitempty"` + Signature string `json:"signature,omitempty"` + Checksum string `json:"checksum,omitempty"` + FileSize int64 `json:"file_size,omitempty"` + EstimatedTime int `json:"estimated_time_seconds,omitempty"` +} + +// SignatureVerificationRequest represents a request to verify an agent's binary signature +type SignatureVerificationRequest struct { + AgentID uuid.UUID `json:"agent_id" binding:"required"` + BinaryPath string `json:"binary_path" binding:"required"` + MachineID string `json:"machine_id" binding:"required"` + PublicKey string `json:"public_key" binding:"required"` + Signature string `json:"signature" binding:"required"` +} + +// SignatureVerificationResponse represents the response for signature verification +type SignatureVerificationResponse struct { + Valid bool `json:"valid"` + AgentID string `json:"agent_id"` + MachineID string `json:"machine_id"` + Fingerprint string `json:"fingerprint"` + Message string `json:"message"` +} \ No newline at end of file diff --git a/aggregator-server/internal/services/signing.go b/aggregator-server/internal/services/signing.go new file mode 100644 index 0000000..b101962 --- /dev/null +++ b/aggregator-server/internal/services/signing.go @@ -0,0 +1,239 @@ +package services + +import ( + "crypto/ed25519" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "os" + "runtime" + "time" + + "github.com/Fimeg/RedFlag/aggregator-server/internal/models" + "github.com/google/uuid" +) + +// SigningService handles Ed25519 cryptographic operations +type SigningService struct { + privateKey ed25519.PrivateKey + publicKey ed25519.PublicKey +} + +// NewSigningService creates a new signing service with the provided private key +func NewSigningService(privateKeyHex string) (*SigningService, error) { + // Decode private key from hex + privateKeyBytes, err := hex.DecodeString(privateKeyHex) + if err != nil { + return nil, fmt.Errorf("invalid private key format: %w", err) + } + + if len(privateKeyBytes) != ed25519.PrivateKeySize { + return nil, fmt.Errorf("invalid private key size: expected %d bytes, got %d", ed25519.PrivateKeySize, len(privateKeyBytes)) + } + + // Ed25519 private key format: first 32 bytes are seed, next 32 bytes are public key + privateKey := ed25519.PrivateKey(privateKeyBytes) + publicKey := privateKey.Public().(ed25519.PublicKey) + + return &SigningService{ + privateKey: privateKey, + publicKey: publicKey, + }, nil +} + +// SignFile signs a file and returns the signature and checksum +func (s *SigningService) SignFile(filePath string) (*models.AgentUpdatePackage, error) { + // Read the file + file, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + // Calculate checksum and sign content + content, err := io.ReadAll(file) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + // Calculate SHA-256 checksum + hash := sha256.Sum256(content) + checksum := hex.EncodeToString(hash[:]) + + // Sign the content + signature := ed25519.Sign(s.privateKey, content) + + // Get file info + fileInfo, err := file.Stat() + if err != nil { + return nil, fmt.Errorf("failed to get file info: %w", err) + } + + // Determine platform and architecture from file path or use runtime defaults + platform, architecture := s.detectPlatformArchitecture(filePath) + + pkg := &models.AgentUpdatePackage{ + BinaryPath: filePath, + Signature: hex.EncodeToString(signature), + Checksum: checksum, + FileSize: fileInfo.Size(), + Platform: platform, + Architecture: architecture, + CreatedBy: "signing-service", + IsActive: true, + } + + return pkg, nil +} + +// VerifySignature verifies a file signature using the embedded public key +func (s *SigningService) VerifySignature(content []byte, signatureHex string) (bool, error) { + // Decode signature + signature, err := hex.DecodeString(signatureHex) + if err != nil { + return false, fmt.Errorf("invalid signature format: %w", err) + } + + if len(signature) != ed25519.SignatureSize { + return false, fmt.Errorf("invalid signature size: expected %d bytes, got %d", ed25519.SignatureSize, len(signature)) + } + + // Verify signature + valid := ed25519.Verify(s.publicKey, content, signature) + return valid, nil +} + +// GetPublicKey returns the public key in hex format +func (s *SigningService) GetPublicKey() string { + return hex.EncodeToString(s.publicKey) +} + +// GetPublicKeyFingerprint returns a short fingerprint of the public key +func (s *SigningService) GetPublicKeyFingerprint() string { + // Use first 8 bytes as fingerprint + return hex.EncodeToString(s.publicKey[:8]) +} + +// VerifyFileIntegrity verifies a file's checksum +func (s *SigningService) VerifyFileIntegrity(filePath, expectedChecksum string) (bool, error) { + file, err := os.Open(filePath) + if err != nil { + return false, fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + content, err := io.ReadAll(file) + if err != nil { + return false, fmt.Errorf("failed to read file: %w", err) + } + + hash := sha256.Sum256(content) + actualChecksum := hex.EncodeToString(hash[:]) + + return actualChecksum == expectedChecksum, nil +} + +// detectPlatformArchitecture attempts to detect platform and architecture from file path +func (s *SigningService) detectPlatformArchitecture(filePath string) (string, string) { + // Default to current runtime + platform := runtime.GOOS + arch := runtime.GOARCH + + // Map architectures + archMap := map[string]string{ + "amd64": "amd64", + "arm64": "arm64", + "386": "386", + } + + // Try to detect from filename patterns + if contains(filePath, "windows") || contains(filePath, ".exe") { + platform = "windows" + } else if contains(filePath, "linux") { + platform = "linux" + } else if contains(filePath, "darwin") || contains(filePath, "macos") { + platform = "darwin" + } + + for archName, archValue := range archMap { + if contains(filePath, archName) { + arch = archValue + break + } + } + + // Normalize architecture names + if arch == "amd64" { + arch = "amd64" + } else if arch == "arm64" { + arch = "arm64" + } + + return platform, arch +} + +// contains is a simple helper for case-insensitive substring checking +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || + (len(s) > len(substr) && + (s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + findSubstring(s, substr)))) +} + +// findSubstring is a simple substring finder +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// SignNonce signs a nonce (UUID + timestamp) for replay protection +func (s *SigningService) SignNonce(nonceUUID uuid.UUID, timestamp time.Time) (string, error) { + // Create nonce data: UUID + Unix timestamp as string + nonceData := fmt.Sprintf("%s:%d", nonceUUID.String(), timestamp.Unix()) + + // Sign the nonce data + signature := ed25519.Sign(s.privateKey, []byte(nonceData)) + + // Return hex-encoded signature + return hex.EncodeToString(signature), nil +} + +// VerifyNonce verifies a nonce signature and checks freshness +func (s *SigningService) VerifyNonce(nonceUUID uuid.UUID, timestamp time.Time, signatureHex string, maxAge time.Duration) (bool, error) { + // Check nonce freshness first + if time.Since(timestamp) > maxAge { + return false, fmt.Errorf("nonce is too old: %v > %v", time.Since(timestamp), maxAge) + } + + // Recreate nonce data + nonceData := fmt.Sprintf("%s:%d", nonceUUID.String(), timestamp.Unix()) + + // Verify signature + valid, err := s.VerifySignature([]byte(nonceData), signatureHex) + if err != nil { + return false, fmt.Errorf("failed to verify nonce signature: %w", err) + } + + return valid, nil +} + +// TODO: Key rotation implementation +// This is a stub for future key rotation functionality +// Key rotation should: +// 1. Maintain multiple active key pairs with version numbers +// 2. Support graceful transition periods (e.g., 30 days) +// 3. Store previous keys for signature verification during transition +// 4. Batch migration of existing agent fingerprints +// 5. Provide monitoring for rotation completion +// +// Example implementation approach: +// - Use database to store multiple key versions with activation timestamps +// - Include key version ID in signatures +// - Maintain lookup table of previous keys for verification +// - Background job to monitor rotation progress \ No newline at end of file diff --git a/aggregator-server/internal/utils/version.go b/aggregator-server/internal/utils/version.go index 39491e8..b0c4c8b 100644 --- a/aggregator-server/internal/utils/version.go +++ b/aggregator-server/internal/utils/version.go @@ -33,6 +33,12 @@ func IsNewerVersion(version1, version2 string) bool { return CompareVersions(version1, version2) == 1 } +// IsNewerOrEqualVersion returns true if version1 is newer than or equal to version2 +func IsNewerOrEqualVersion(version1, version2 string) bool { + cmp := CompareVersions(version1, version2) + return cmp == 1 || cmp == 0 +} + // parseVersion parses a version string like "0.1.4" into [0, 1, 4] func parseVersion(version string) [3]int { // Default version if parsing fails diff --git a/aggregator-server/test-server b/aggregator-server/test-server new file mode 100755 index 0000000..d50e74e Binary files /dev/null and b/aggregator-server/test-server differ diff --git a/aggregator-web/src/components/AgentScanners.tsx b/aggregator-web/src/components/AgentScanners.tsx index ae4ab31..c944bba 100644 --- a/aggregator-web/src/components/AgentScanners.tsx +++ b/aggregator-web/src/components/AgentScanners.tsx @@ -1,18 +1,9 @@ -import React, { useState } from 'react'; +import React from 'react'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import { - MonitorPlay, RefreshCw, - Settings, Activity, - Clock, - CheckCircle, - XCircle, Play, - Square, - Database, - Shield, - Search, HardDrive, Cpu, Container, @@ -78,7 +69,7 @@ export function AgentScanners({ agentId }: AgentScannersProps) { return await agentApi.disableSubsystem(agentId, subsystem); } }, - onSuccess: (data, variables) => { + onSuccess: (_, variables) => { toast.success(`${subsystemConfig[variables.subsystem]?.name || variables.subsystem} ${variables.enabled ? 'enabled' : 'disabled'}`); queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] }); }, @@ -92,7 +83,7 @@ export function AgentScanners({ agentId }: AgentScannersProps) { mutationFn: async ({ subsystem, intervalMinutes }: { subsystem: string; intervalMinutes: number }) => { return await agentApi.setSubsystemInterval(agentId, subsystem, intervalMinutes); }, - onSuccess: (data, variables) => { + onSuccess: (_, variables) => { toast.success(`Interval updated to ${variables.intervalMinutes} minutes`); queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] }); }, @@ -106,7 +97,7 @@ export function AgentScanners({ agentId }: AgentScannersProps) { mutationFn: async ({ subsystem, autoRun }: { subsystem: string; autoRun: boolean }) => { return await agentApi.setSubsystemAutoRun(agentId, subsystem, autoRun); }, - onSuccess: (data, variables) => { + onSuccess: (_, variables) => { toast.success(`Auto-run ${variables.autoRun ? 'enabled' : 'disabled'}`); queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] }); }, @@ -120,7 +111,7 @@ export function AgentScanners({ agentId }: AgentScannersProps) { mutationFn: async (subsystem: string) => { return await agentApi.triggerSubsystem(agentId, subsystem); }, - onSuccess: (data, subsystem) => { + onSuccess: (_, subsystem) => { toast.success(`${subsystemConfig[subsystem]?.name || subsystem} scan triggered`); queryClient.invalidateQueries({ queryKey: ['subsystems', agentId] }); }, @@ -155,12 +146,7 @@ export function AgentScanners({ agentId }: AgentScannersProps) { { value: 1440, label: '24 hours' }, ]; - const getFrequencyLabel = (frequency: number) => { - if (frequency < 60) return `${frequency}m`; - if (frequency < 1440) return `${frequency / 60}h`; - return `${frequency / 1440}d`; - }; - + const enabledCount = subsystems.filter(s => s.enabled).length; const autoRunCount = subsystems.filter(s => s.auto_run && s.enabled).length; diff --git a/aggregator-web/src/components/AgentStorage.tsx b/aggregator-web/src/components/AgentStorage.tsx index a9aa931..c6a9547 100644 --- a/aggregator-web/src/components/AgentStorage.tsx +++ b/aggregator-web/src/components/AgentStorage.tsx @@ -1,17 +1,8 @@ -import React, { useState } from 'react'; +import { useState } from 'react'; import { useQuery } from '@tanstack/react-query'; import { HardDrive, RefreshCw, - Database, - Search, - Activity, - Monitor, - AlertTriangle, - CheckCircle, - Info, - TrendingUp, - Server, MemoryStick, } from 'lucide-react'; import { formatBytes, formatRelativeTime } from '@/lib/utils'; @@ -116,15 +107,7 @@ export function AgentStorage({ agentId }: AgentStorageProps) { })); }; - const getDiskTypeIcon = (diskType: string) => { - switch (diskType.toLowerCase()) { - case 'nvme': return ; - case 'ssd': return ; - case 'hdd': return ; - default: return ; - } - }; - + if (!agentData) { return (
diff --git a/aggregator-web/src/components/AgentUpdatesEnhanced.tsx b/aggregator-web/src/components/AgentUpdatesEnhanced.tsx index 4bcee6d..ece0cb1 100644 --- a/aggregator-web/src/components/AgentUpdatesEnhanced.tsx +++ b/aggregator-web/src/components/AgentUpdatesEnhanced.tsx @@ -4,6 +4,7 @@ import { Search, Package, Download, + Upload, CheckCircle, RefreshCw, Terminal, @@ -18,6 +19,7 @@ import { updateApi, agentApi } from '@/lib/api'; import toast from 'react-hot-toast'; import { cn } from '@/lib/utils'; import type { UpdatePackage } from '@/types'; +import { AgentUpdatesModal } from './AgentUpdatesModal'; interface AgentUpdatesEnhancedProps { agentId: string; @@ -52,7 +54,7 @@ export function AgentUpdatesEnhanced({ agentId }: AgentUpdatesEnhancedProps) { const [selectedSeverity, setSelectedSeverity] = useState('all'); const [showLogsModal, setShowLogsModal] = useState(false); const [logsData, setLogsData] = useState(null); - const [isLoadingLogs, setIsLoadingLogs] = useState(false); + const [showUpdateModal, setShowUpdateModal] = useState(false); const [expandedUpdates, setExpandedUpdates] = useState>(new Set()); const [selectedUpdates, setSelectedUpdates] = useState([]); @@ -300,6 +302,15 @@ export function AgentUpdatesEnhanced({ agentId }: AgentUpdatesEnhancedProps) { )} )} + + {/* Update Agent Button */} +
{/* Search and Filters */} @@ -531,6 +542,17 @@ export function AgentUpdatesEnhanced({ agentId }: AgentUpdatesEnhancedProps) { )} + + {/* Agent Update Modal */} + setShowUpdateModal(false)} + selectedAgentIds={[agentId]} + onAgentsUpdated={() => { + setShowUpdateModal(false); + queryClient.invalidateQueries({ queryKey: ['agents'] }); + }} + /> ); } diff --git a/aggregator-web/src/components/AgentUpdatesModal.tsx b/aggregator-web/src/components/AgentUpdatesModal.tsx new file mode 100644 index 0000000..2da7863 --- /dev/null +++ b/aggregator-web/src/components/AgentUpdatesModal.tsx @@ -0,0 +1,290 @@ +import { useState } from 'react'; +import { + X, + Download, + CheckCircle, + AlertCircle, + RefreshCw, + Info, + Users, + Package, + Hash, +} from 'lucide-react'; +import { useMutation, useQuery } from '@tanstack/react-query'; +import { agentApi, updateApi } from '@/lib/api'; +import toast from 'react-hot-toast'; +import { cn } from '@/lib/utils'; +import { Agent } from '@/types'; + +interface AgentUpdatesModalProps { + isOpen: boolean; + onClose: () => void; + selectedAgentIds: string[]; + onAgentsUpdated: () => void; +} + +export function AgentUpdatesModal({ + isOpen, + onClose, + selectedAgentIds, + onAgentsUpdated, +}: AgentUpdatesModalProps) { + const [selectedVersion, setSelectedVersion] = useState(''); + const [selectedPlatform, setSelectedPlatform] = useState(''); + const [isUpdating, setIsUpdating] = useState(false); + + // Fetch selected agents details + const { data: agents = [] } = useQuery({ + queryKey: ['agents-details', selectedAgentIds], + queryFn: async (): Promise => { + const promises = selectedAgentIds.map(id => agentApi.getAgent(id)); + const results = await Promise.all(promises); + return results; + }, + enabled: isOpen && selectedAgentIds.length > 0, + }); + + // Fetch available update packages + const { data: packagesResponse, isLoading: packagesLoading } = useQuery({ + queryKey: ['update-packages'], + queryFn: () => updateApi.getPackages(), + enabled: isOpen, + }); + + const packages = packagesResponse?.packages || []; + + // Group packages by version + const versions = [...new Set(packages.map(pkg => pkg.version))].sort((a, b) => b.localeCompare(a)); + const platforms = [...new Set(packages.map(pkg => pkg.platform))].sort(); + + // Filter packages based on selection + const availablePackages = packages.filter( + pkg => (!selectedVersion || pkg.version === selectedVersion) && + (!selectedPlatform || pkg.platform === selectedPlatform) + ); + + // Get unique platform for selected agents (simplified - assumes all agents same platform) + const agentPlatform = agents[0]?.os_type || 'linux'; + const agentArchitecture = agents[0]?.os_architecture || 'amd64'; + + // Update agents mutation + const updateAgentsMutation = useMutation({ + mutationFn: async (packageId: string) => { + const pkg = packages.find(p => p.id === packageId); + if (!pkg) throw new Error('Package not found'); + + const updateData = { + agent_ids: selectedAgentIds, + version: pkg.version, + platform: pkg.platform, + }; + + return agentApi.updateMultipleAgents(updateData); + }, + onSuccess: (data) => { + toast.success(`Update initiated for ${data.updated?.length || 0} agent(s)`); + setIsUpdating(false); + onAgentsUpdated(); + onClose(); + }, + onError: (error: any) => { + toast.error(`Failed to update agents: ${error.message}`); + setIsUpdating(false); + }, + }); + + const handleUpdateAgents = async (packageId: string) => { + setIsUpdating(true); + updateAgentsMutation.mutate(packageId); + }; + + const canUpdate = selectedAgentIds.length > 0 && availablePackages.length > 0 && !isUpdating; + const hasUpdatingAgents = agents.some(agent => agent.is_updating); + + if (!isOpen) return null; + + return ( +
+
+
+ +
+ {/* Header */} +
+
+ +
+

Agent Updates

+

+ Update {selectedAgentIds.length} agent{selectedAgentIds.length !== 1 ? 's' : ''} +

+
+
+ +
+ + {/* Content */} +
+ {/* Selected Agents */} +
+

+ + Selected Agents +

+
+ {agents.map((agent) => ( +
+
+ +
+
{agent.hostname}
+
+ {agent.os_type}/{agent.os_architecture} • Current: {agent.current_version || 'Unknown'} +
+
+
+ {agent.is_updating && ( +
+ + Updating to {agent.updating_to_version} +
+ )} +
+ ))} +
+ {hasUpdatingAgents && ( +
+ + Some agents are currently updating +
+ )} +
+ + {/* Package Selection */} +
+

+ + Update Package Selection +

+ + {/* Filters */} +
+
+ + +
+
+ + +
+
+ + {/* Available Packages */} +
+ {packagesLoading ? ( +
+ Loading packages... +
+ ) : availablePackages.length === 0 ? ( +
+ No packages available for the selected criteria +
+ ) : ( + availablePackages.map((pkg) => ( +
handleUpdateAgents(pkg.id)} + > +
+ +
+
+ Version {pkg.version} +
+
+ {pkg.platform} • {(pkg.file_size / 1024 / 1024).toFixed(1)} MB +
+
+
+
+
+ + {pkg.checksum.slice(0, 8)}... +
+ +
+
+ )) + )} +
+
+ + {/* Platform Compatibility Info */} +
+ + + Detected platform: {agentPlatform}/{agentArchitecture}. + Only compatible packages will be shown. + +
+
+ + {/* Footer */} +
+ +
+
+
+
+ ); +} \ No newline at end of file diff --git a/aggregator-web/src/components/ChatTimeline.tsx b/aggregator-web/src/components/ChatTimeline.tsx index b84db99..04a974a 100644 --- a/aggregator-web/src/components/ChatTimeline.tsx +++ b/aggregator-web/src/components/ChatTimeline.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect } from 'react'; +import React, { useState } from 'react'; import { CheckCircle, XCircle, @@ -7,14 +7,12 @@ import { Search, Terminal, RefreshCw, - Filter, ChevronDown, ChevronRight, User, Clock, Activity, Copy, - Hash, HardDrive, Cpu, Container, @@ -25,7 +23,6 @@ import { useRetryCommand } from '@/hooks/useCommands'; import { cn } from '@/lib/utils'; import toast from 'react-hot-toast'; import { Highlight, themes } from 'prism-react-renderer'; -import { useEffect as useEffectHook } from 'react'; interface HistoryEntry { id: string; @@ -41,6 +38,8 @@ interface HistoryEntry { exit_code?: number; duration_seconds?: number; created_at: string; + metadata?: Record; + params?: Record; hostname?: string; } @@ -76,9 +75,9 @@ const createPackageOperationSummary = (entry: HistoryEntry): string => { // Extract duration if available let durationInfo = ''; - if (entry.logged_at) { + if (entry.created_at) { try { - const loggedTime = new Date(entry.logged_at).toLocaleTimeString('en-US', { + const loggedTime = new Date(entry.created_at).toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit' }); @@ -444,9 +443,27 @@ const ChatTimeline: React.FC = ({ agentId, className, isScope } } - // Fallback subject + // Fallback subject - provide better action labels if (!subject) { - subject = entry.package_name || 'system operation'; + // Map action to more readable labels + const actionLabels: Record = { + 'scan updates': 'Package Updates', + 'scan storage': 'Disk Usage', + 'scan system': 'System Metrics', + 'scan docker': 'Docker Images', + 'update agent': 'Agent Update', + 'dry run update': 'Update Dry Run', + 'confirm dependencies': 'Dependency Check', + 'install update': 'Update Installation', + 'collect specs': 'System Specifications', + 'enable heartbeat': 'Heartbeat Enable', + 'disable heartbeat': 'Heartbeat Disable', + 'reboot': 'System Reboot', + 'process command': 'Command Processing' + }; + + // Prioritize metadata subsystem label for better descriptions + subject = entry.metadata?.subsystem_label || entry.package_name || actionLabels[action] || action; } // Build narrative sentence - system thought style @@ -495,6 +512,16 @@ const ChatTimeline: React.FC = ({ agentId, className, isScope } else { sentence = `Docker Image Scanner results`; } + } else if (action === 'update agent') { + if (isInProgress) { + sentence = `Agent Update initiated to version ${subject}`; + } else if (statusType === 'success') { + sentence = `Agent updated to version ${subject}`; + } else if (statusType === 'failed') { + sentence = `Agent update failed for version ${subject}`; + } else { + sentence = `Agent update to version ${subject}`; + } } else if (action === 'dry run update') { if (isInProgress) { sentence = `Dry run initiated for ${subject}`; diff --git a/aggregator-web/src/lib/api.ts b/aggregator-web/src/lib/api.ts index cfc347d..9db9081 100644 --- a/aggregator-web/src/lib/api.ts +++ b/aggregator-web/src/lib/api.ts @@ -2,6 +2,7 @@ import axios, { AxiosResponse } from 'axios'; import { Agent, UpdatePackage, + AgentUpdatePackage, DashboardStats, AgentListResponse, UpdateListResponse, @@ -160,6 +161,27 @@ export const agentApi = { const response = await api.post(`/agents/${agentId}/subsystems/${subsystem}/interval`, { interval_minutes: intervalMinutes }); return response.data; }, + + // Update single agent + updateAgent: async (agentId: string, updateData: { + version: string; + platform: string; + scheduled?: string; + }): Promise<{ message: string; update_id: string; download_url: string; signature: string; checksum: string; file_size: number; estimated_time: number }> => { + const response = await api.post(`/agents/${agentId}/update`, updateData); + return response.data; + }, + + // Update multiple agents (bulk) + updateMultipleAgents: async (updateData: { + agent_ids: string[]; + version: string; + platform: string; + scheduled?: string; + }): Promise<{ message: string; updated: Array<{ agent_id: string; hostname: string; update_id: string; status: string }>; failed: string[]; total_agents: number; package_info: any }> => { + const response = await api.post('/agents/bulk-update', updateData); + return response.data; + }, }; export const updateApi = { @@ -185,6 +207,11 @@ export const updateApi = { await api.post(`/updates/${id}/approve`, { scheduled_at: scheduledAt }); }, + // Approve multiple updates + approveMultiple: async (updateIds: string[]): Promise => { + await api.post('/updates/approve', { update_ids: updateIds }); + }, + // Reject/cancel update rejectUpdate: async (id: string): Promise => { await api.post(`/updates/${id}/reject`); @@ -250,6 +277,28 @@ export const updateApi = { const response = await api.delete(`/commands/failed${params.toString() ? '?' + params.toString() : ''}`); return response.data; }, + + // Get available update packages + getPackages: async (params?: { + version?: string; + platform?: string; + limit?: number; + offset?: number; + }): Promise<{ packages: AgentUpdatePackage[]; total: number; limit: number; offset: number }> => { + const response = await api.get('/updates/packages', { params }); + return response.data; + }, + + // Sign new update package + signPackage: async (packageData: { + version: string; + platform: string; + architecture: string; + binary_path: string; + }): Promise<{ message: string; package: UpdatePackage }> => { + const response = await api.post('/updates/packages/sign', packageData); + return response.data; + }, }; export const statsApi = { diff --git a/aggregator-web/src/pages/Agents.tsx b/aggregator-web/src/pages/Agents.tsx index 899ccce..42d1e47 100644 --- a/aggregator-web/src/pages/Agents.tsx +++ b/aggregator-web/src/pages/Agents.tsx @@ -25,6 +25,7 @@ import { Database, Settings, MonitorPlay, + Upload, } from 'lucide-react'; import { useAgents, useAgent, useScanAgent, useScanMultipleAgents, useUnregisterAgent } from '@/hooks/useAgents'; import { useActiveCommands, useCancelCommand } from '@/hooks/useCommands'; @@ -38,6 +39,7 @@ import { AgentSystemUpdates } from '@/components/AgentUpdates'; import { AgentStorage } from '@/components/AgentStorage'; import { AgentUpdatesEnhanced } from '@/components/AgentUpdatesEnhanced'; import { AgentScanners } from '@/components/AgentScanners'; +import { AgentUpdatesModal } from '@/components/AgentUpdatesModal'; import ChatTimeline from '@/components/ChatTimeline'; const Agents: React.FC = () => { @@ -56,6 +58,7 @@ const Agents: React.FC = () => { const [showDurationDropdown, setShowDurationDropdown] = useState(false); const [heartbeatLoading, setHeartbeatLoading] = useState(false); // Loading state for heartbeat toggle const [heartbeatCommandId, setHeartbeatCommandId] = useState(null); // Track specific heartbeat command + const [showUpdateModal, setShowUpdateModal] = useState(false); // Update modal state const dropdownRef = useRef(null); // Close dropdown when clicking outside @@ -1142,18 +1145,27 @@ const Agents: React.FC = () => { {/* Bulk actions */} {selectedAgents.length > 0 && ( - + <> + + + )}
@@ -1358,6 +1370,20 @@ const Agents: React.FC = () => { > +