256 lines
7.5 KiB
Go
256 lines
7.5 KiB
Go
package queries
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
type RegistrationTokenQueries struct {
|
|
db *sqlx.DB
|
|
}
|
|
|
|
type RegistrationToken struct {
|
|
ID uuid.UUID `json:"id" db:"id"`
|
|
Token string `json:"token" db:"token"`
|
|
Label *string `json:"label" db:"label"`
|
|
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
|
|
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
|
UsedAt *time.Time `json:"used_at" db:"used_at"`
|
|
UsedByAgentID *uuid.UUID `json:"used_by_agent_id" db:"used_by_agent_id"`
|
|
Revoked bool `json:"revoked" db:"revoked"`
|
|
RevokedAt *time.Time `json:"revoked_at" db:"revoked_at"`
|
|
RevokedReason *string `json:"revoked_reason" db:"revoked_reason"`
|
|
Status string `json:"status" db:"status"`
|
|
CreatedBy string `json:"created_by" db:"created_by"`
|
|
Metadata json.RawMessage `json:"metadata" db:"metadata"`
|
|
MaxSeats int `json:"max_seats" db:"max_seats"`
|
|
SeatsUsed int `json:"seats_used" db:"seats_used"`
|
|
}
|
|
|
|
type TokenRequest struct {
|
|
Label string `json:"label"`
|
|
ExpiresIn string `json:"expires_in"` // e.g., "24h", "7d"
|
|
MaxSeats int `json:"max_seats"` // Number of agents that can use this token (default: 1)
|
|
Metadata map[string]interface{} `json:"metadata"`
|
|
}
|
|
|
|
type TokenResponse struct {
|
|
Token string `json:"token"`
|
|
Label string `json:"label"`
|
|
ExpiresAt time.Time `json:"expires_at"`
|
|
InstallCommand string `json:"install_command"`
|
|
}
|
|
|
|
func NewRegistrationTokenQueries(db *sqlx.DB) *RegistrationTokenQueries {
|
|
return &RegistrationTokenQueries{db: db}
|
|
}
|
|
|
|
// CreateRegistrationToken creates a new registration token with seat tracking
|
|
func (q *RegistrationTokenQueries) CreateRegistrationToken(token, label string, expiresAt time.Time, maxSeats int, metadata map[string]interface{}) error {
|
|
metadataJSON, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal metadata: %w", err)
|
|
}
|
|
|
|
// Ensure maxSeats is at least 1
|
|
if maxSeats < 1 {
|
|
maxSeats = 1
|
|
}
|
|
|
|
query := `
|
|
INSERT INTO registration_tokens (token, label, expires_at, max_seats, metadata)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
`
|
|
|
|
_, err = q.db.Exec(query, token, label, expiresAt, maxSeats, metadataJSON)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create registration token: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidateRegistrationToken checks if a token is valid and has available seats
|
|
func (q *RegistrationTokenQueries) ValidateRegistrationToken(token string) (*RegistrationToken, error) {
|
|
var regToken RegistrationToken
|
|
query := `
|
|
SELECT id, token, label, expires_at, created_at, used_at, used_by_agent_id,
|
|
revoked, revoked_at, revoked_reason, status, created_by, metadata,
|
|
max_seats, seats_used
|
|
FROM registration_tokens
|
|
WHERE token = $1 AND status = 'active' AND expires_at > NOW() AND seats_used < max_seats
|
|
`
|
|
|
|
err := q.db.Get(®Token, query, token)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("invalid, expired, or seats full")
|
|
}
|
|
return nil, fmt.Errorf("failed to validate token: %w", err)
|
|
}
|
|
|
|
return ®Token, nil
|
|
}
|
|
|
|
// MarkTokenUsed marks a token as used by an agent
|
|
// With seat tracking, this increments seats_used and only marks status='used' when all seats are taken
|
|
func (q *RegistrationTokenQueries) MarkTokenUsed(token string, agentID uuid.UUID) error {
|
|
// Call the PostgreSQL function that handles seat tracking logic
|
|
query := `SELECT mark_registration_token_used($1, $2)`
|
|
|
|
var success bool
|
|
err := q.db.QueryRow(query, token, agentID).Scan(&success)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to mark token as used: %w", err)
|
|
}
|
|
|
|
if !success {
|
|
return fmt.Errorf("token not found, already used, expired, or seats full")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetActiveRegistrationTokens returns all active tokens that haven't expired
|
|
func (q *RegistrationTokenQueries) GetActiveRegistrationTokens() ([]RegistrationToken, error) {
|
|
var tokens []RegistrationToken
|
|
query := `
|
|
SELECT id, token, label, expires_at, created_at, used_at, used_by_agent_id,
|
|
revoked, revoked_at, revoked_reason, status, created_by, metadata,
|
|
max_seats, seats_used
|
|
FROM registration_tokens
|
|
WHERE status = 'active' AND expires_at > NOW()
|
|
ORDER BY created_at DESC
|
|
`
|
|
|
|
err := q.db.Select(&tokens, query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get active tokens: %w", err)
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
// GetAllRegistrationTokens returns all tokens with pagination
|
|
func (q *RegistrationTokenQueries) GetAllRegistrationTokens(limit, offset int) ([]RegistrationToken, error) {
|
|
var tokens []RegistrationToken
|
|
query := `
|
|
SELECT id, token, label, expires_at, created_at, used_at, used_by_agent_id,
|
|
revoked, revoked_at, revoked_reason, status, created_by, metadata,
|
|
max_seats, seats_used
|
|
FROM registration_tokens
|
|
ORDER BY created_at DESC
|
|
LIMIT $1 OFFSET $2
|
|
`
|
|
|
|
err := q.db.Select(&tokens, query, limit, offset)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get all tokens: %w", err)
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
// RevokeRegistrationToken revokes a token (can revoke tokens in any status)
|
|
func (q *RegistrationTokenQueries) RevokeRegistrationToken(token, reason string) error {
|
|
query := `
|
|
UPDATE registration_tokens
|
|
SET status = 'revoked',
|
|
revoked = true,
|
|
revoked_at = NOW(),
|
|
revoked_reason = $1
|
|
WHERE token = $2
|
|
`
|
|
|
|
result, err := q.db.Exec(query, reason, token)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to revoke token: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("token not found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteRegistrationToken permanently deletes a token from the database
|
|
func (q *RegistrationTokenQueries) DeleteRegistrationToken(tokenID uuid.UUID) error {
|
|
query := `DELETE FROM registration_tokens WHERE id = $1`
|
|
|
|
result, err := q.db.Exec(query, tokenID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete token: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("token not found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CleanupExpiredTokens marks expired tokens as expired
|
|
func (q *RegistrationTokenQueries) CleanupExpiredTokens() (int, error) {
|
|
query := `
|
|
UPDATE registration_tokens
|
|
SET status = 'expired',
|
|
used_at = NOW()
|
|
WHERE status = 'active' AND expires_at < NOW() AND used_at IS NULL
|
|
`
|
|
|
|
result, err := q.db.Exec(query)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to cleanup expired tokens: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
return int(rowsAffected), nil
|
|
}
|
|
|
|
// GetTokenUsageStats returns statistics about token usage
|
|
func (q *RegistrationTokenQueries) GetTokenUsageStats() (map[string]int, error) {
|
|
stats := make(map[string]int)
|
|
|
|
query := `
|
|
SELECT status, COUNT(*) as count
|
|
FROM registration_tokens
|
|
GROUP BY status
|
|
`
|
|
|
|
rows, err := q.db.Query(query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get token stats: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var status string
|
|
var count int
|
|
if err := rows.Scan(&status, &count); err != nil {
|
|
return nil, fmt.Errorf("failed to scan token stats row: %w", err)
|
|
}
|
|
stats[status] = count
|
|
}
|
|
|
|
return stats, nil
|
|
} |