security improvements

This commit is contained in:
Fran Jurmanović
2025-06-29 21:59:41 +02:00
parent 7fdda06dba
commit caba5bae70
30 changed files with 3929 additions and 147 deletions

View File

@@ -20,8 +20,6 @@ func Init(di *dig.Container, app *fiber.App) {
// Protected routes
groups := app.Group(configs.Prefix)
serverIdGroup := groups.Group("/server/:id")
routeGroups := &common.RouteGroups{
Api: groups.Group("/api"),
@@ -32,20 +30,12 @@ func Init(di *dig.Container, app *fiber.App) {
StateHistory: serverIdGroup.Group("/state-history"),
}
err := di.Provide(func() *common.RouteGroups {
return routeGroups
})
if err != nil {
logging.Panic("unable to bind routes")
}
err = di.Provide(func() *dig.Container {
return di
})
if err != nil {
logging.Panic("unable to bind dig")
}
controller.InitializeControllers(di)
}

View File

@@ -59,7 +59,7 @@ func (ac *ConfigController) UpdateConfig(c *fiber.Ctx) error {
if err != nil {
return c.Status(400).SendString(err.Error())
}
logging.Info("restart", restart)
logging.Info("restart: %v", restart)
if restart {
_, err := ac.apiService.ApiRestartServer(c)
if err != nil {

View File

@@ -55,6 +55,7 @@ func (c *MembershipController) Login(ctx *fiber.Ctx) error {
return ctx.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request"})
}
logging.Debug("Login request received")
token, err := c.service.Login(ctx.UserContext(), req.Username, req.Password)
if err != nil {
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})

View File

@@ -1,9 +1,12 @@
package middleware
import (
"acc-server-manager/local/middleware/security"
"acc-server-manager/local/service"
"acc-server-manager/local/utl/jwt"
"acc-server-manager/local/utl/logging"
"strings"
"time"
"github.com/gofiber/fiber/v2"
)
@@ -11,49 +14,125 @@ import (
// AuthMiddleware provides authentication and permission middleware.
type AuthMiddleware struct {
membershipService *service.MembershipService
securityMW *security.SecurityMiddleware
}
// NewAuthMiddleware creates a new AuthMiddleware.
func NewAuthMiddleware(ms *service.MembershipService) *AuthMiddleware {
return &AuthMiddleware{
membershipService: ms,
securityMW: security.NewSecurityMiddleware(),
}
}
// Authenticate is a middleware for JWT authentication.
// Authenticate is a middleware for JWT authentication with enhanced security.
func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
// Log authentication attempt
ip := ctx.IP()
userAgent := ctx.Get("User-Agent")
authHeader := ctx.Get("Authorization")
if authHeader == "" {
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Missing or malformed JWT"})
logging.Error("Authentication failed: missing Authorization header from IP %s", ip)
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "Missing or malformed JWT",
})
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Missing or malformed JWT"})
logging.Error("Authentication failed: malformed Authorization header from IP %s", ip)
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "Missing or malformed JWT",
})
}
claims, err := jwt.ValidateToken(parts[1])
// Validate token length to prevent potential attacks
token := parts[1]
if len(token) < 10 || len(token) > 2048 {
logging.Error("Authentication failed: invalid token length from IP %s", ip)
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "Invalid or expired JWT",
})
}
claims, err := jwt.ValidateToken(token)
if err != nil {
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid or expired JWT"})
logging.Error("Authentication failed: invalid token from IP %s, User-Agent: %s, Error: %v", ip, userAgent, err)
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "Invalid or expired JWT",
})
}
// Additional security: validate user ID format
if claims.UserID == "" || len(claims.UserID) < 10 {
logging.Error("Authentication failed: invalid user ID in token from IP %s", ip)
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "Invalid or expired JWT",
})
}
ctx.Locals("userID", claims.UserID)
ctx.Locals("authTime", time.Now())
logging.Info("User %s authenticated successfully from IP %s", claims.UserID, ip)
return ctx.Next()
}
// HasPermission is a middleware for checking user permissions.
// HasPermission is a middleware for checking user permissions with enhanced logging.
func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler {
return func(ctx *fiber.Ctx) error {
userID, ok := ctx.Locals("userID").(string)
if !ok {
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Unauthorized"})
logging.Error("Permission check failed: no user ID in context from IP %s", ctx.IP())
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "Unauthorized",
})
}
// Validate permission parameter
if requiredPermission == "" {
logging.Error("Permission check failed: empty permission requirement")
return ctx.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"error": "Internal server error",
})
}
has, err := m.membershipService.HasPermission(ctx.UserContext(), userID, requiredPermission)
if err != nil || !has {
return ctx.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "Forbidden"})
if err != nil {
logging.Error("Permission check error for user %s, permission %s: %v", userID, requiredPermission, err)
return ctx.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "Forbidden",
})
}
if !has {
logging.Error("Permission denied: user %s lacks permission %s, IP %s", userID, requiredPermission, ctx.IP())
return ctx.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "Forbidden",
})
}
logging.Info("Permission granted: user %s has permission %s", userID, requiredPermission)
return ctx.Next()
}
}
// AuthRateLimit applies rate limiting specifically for authentication endpoints
func (m *AuthMiddleware) AuthRateLimit() fiber.Handler {
return m.securityMW.AuthRateLimit()
}
// RequireHTTPS redirects HTTP requests to HTTPS in production
func (m *AuthMiddleware) RequireHTTPS() fiber.Handler {
return func(ctx *fiber.Ctx) error {
if ctx.Protocol() != "https" && ctx.Get("X-Forwarded-Proto") != "https" {
// Allow HTTP in development/testing
if ctx.Hostname() != "localhost" && ctx.Hostname() != "127.0.0.1" {
httpsURL := "https://" + ctx.Hostname() + ctx.OriginalURL()
return ctx.Redirect(httpsURL, fiber.StatusMovedPermanently)
}
}
return ctx.Next()
}
}

View File

@@ -0,0 +1,351 @@
package security
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/gofiber/fiber/v2"
)
// RateLimiter stores rate limiting information
type RateLimiter struct {
requests map[string][]time.Time
mutex sync.RWMutex
}
// NewRateLimiter creates a new rate limiter
func NewRateLimiter() *RateLimiter {
rl := &RateLimiter{
requests: make(map[string][]time.Time),
}
// Clean up old entries every 5 minutes
go rl.cleanup()
return rl
}
// cleanup removes old entries from the rate limiter
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mutex.Lock()
now := time.Now()
for key, times := range rl.requests {
// Remove entries older than 1 hour
filtered := make([]time.Time, 0, len(times))
for _, t := range times {
if now.Sub(t) < time.Hour {
filtered = append(filtered, t)
}
}
if len(filtered) == 0 {
delete(rl.requests, key)
} else {
rl.requests[key] = filtered
}
}
rl.mutex.Unlock()
}
}
// SecurityMiddleware provides comprehensive security middleware
type SecurityMiddleware struct {
rateLimiter *RateLimiter
}
// NewSecurityMiddleware creates a new security middleware
func NewSecurityMiddleware() *SecurityMiddleware {
return &SecurityMiddleware{
rateLimiter: NewRateLimiter(),
}
}
// SecurityHeaders adds security headers to responses
func (sm *SecurityMiddleware) SecurityHeaders() fiber.Handler {
return func(c *fiber.Ctx) error {
// Prevent MIME type sniffing
c.Set("X-Content-Type-Options", "nosniff")
// Prevent clickjacking
c.Set("X-Frame-Options", "DENY")
// Enable XSS protection
c.Set("X-XSS-Protection", "1; mode=block")
// Prevent referrer leakage
c.Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Content Security Policy
c.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none'")
// Permissions Policy
c.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), interest-cohort=()")
return c.Next()
}
}
// RateLimit implements rate limiting for API endpoints
func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) fiber.Handler {
return func(c *fiber.Ctx) error {
ip := c.IP()
key := fmt.Sprintf("rate_limit:%s", ip)
sm.rateLimiter.mutex.Lock()
defer sm.rateLimiter.mutex.Unlock()
now := time.Now()
requests := sm.rateLimiter.requests[key]
// Remove requests older than duration
filtered := make([]time.Time, 0, len(requests))
for _, t := range requests {
if now.Sub(t) < duration {
filtered = append(filtered, t)
}
}
// Check if limit is exceeded
if len(filtered) >= maxRequests {
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
"error": "Rate limit exceeded",
"retry_after": duration.Seconds(),
})
}
// Add current request
filtered = append(filtered, now)
sm.rateLimiter.requests[key] = filtered
return c.Next()
}
}
// AuthRateLimit implements stricter rate limiting for authentication endpoints
func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
return func(c *fiber.Ctx) error {
ip := c.IP()
userAgent := c.Get("User-Agent")
key := fmt.Sprintf("%s:%s", ip, userAgent)
sm.rateLimiter.mutex.Lock()
defer sm.rateLimiter.mutex.Unlock()
now := time.Now()
requests := sm.rateLimiter.requests[key]
// Remove requests older than 15 minutes
filtered := make([]time.Time, 0, len(requests))
for _, t := range requests {
if now.Sub(t) < 15*time.Minute {
filtered = append(filtered, t)
}
}
// Check if limit is exceeded (5 requests per 15 minutes for auth)
if len(filtered) >= 5 {
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
"error": "Too many authentication attempts",
"retry_after": 900, // 15 minutes
})
}
// Add current request
filtered = append(filtered, now)
sm.rateLimiter.requests[key] = filtered
return c.Next()
}
}
// InputSanitization sanitizes user input to prevent XSS and injection attacks
func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
return func(c *fiber.Ctx) error {
// Sanitize query parameters
c.Request().URI().QueryArgs().VisitAll(func(key, value []byte) {
sanitized := sanitizeInput(string(value))
c.Request().URI().QueryArgs().Set(string(key), sanitized)
})
// Store original body for processing
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
body := c.Body()
if len(body) > 0 {
// Basic sanitization - remove potentially dangerous patterns
sanitized := sanitizeInput(string(body))
c.Request().SetBodyString(sanitized)
}
}
return c.Next()
}
}
// sanitizeInput removes potentially dangerous patterns from input
func sanitizeInput(input string) string {
// Remove common XSS patterns
dangerous := []string{
"<script",
"</script>",
"javascript:",
"vbscript:",
"data:text/html",
"onload=",
"onerror=",
"onclick=",
"onmouseover=",
"onfocus=",
"onblur=",
"onchange=",
"onsubmit=",
"<iframe",
"<object",
"<embed",
"<link",
"<meta",
"<style",
}
result := strings.ToLower(input)
for _, pattern := range dangerous {
result = strings.ReplaceAll(result, pattern, "")
}
// If the sanitized version is very different, it might be malicious
if len(result) < len(input)/2 {
return ""
}
return input
}
// ValidateContentType ensures only expected content types are accepted
func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
contentType := c.Get("Content-Type")
if contentType == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "Content-Type header is required",
})
}
// Check if content type is allowed
allowed := false
for _, allowedType := range allowedTypes {
if strings.Contains(contentType, allowedType) {
allowed = true
break
}
}
if !allowed {
return c.Status(fiber.StatusUnsupportedMediaType).JSON(fiber.Map{
"error": "Unsupported content type",
})
}
}
return c.Next()
}
}
// ValidateUserAgent blocks requests with suspicious or missing user agents
func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
suspiciousAgents := []string{
"sqlmap",
"nikto",
"nmap",
"masscan",
"gobuster",
"dirb",
"dirbuster",
"wpscan",
"curl/7.0", // Very old curl versions
"wget/1.0", // Very old wget versions
}
return func(c *fiber.Ctx) error {
userAgent := strings.ToLower(c.Get("User-Agent"))
// Block empty user agents
if userAgent == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "User-Agent header is required",
})
}
// Block suspicious user agents
for _, suspicious := range suspiciousAgents {
if strings.Contains(userAgent, suspicious) {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "Access denied",
})
}
}
return c.Next()
}
}
// RequestSizeLimit limits the size of incoming requests
func (sm *SecurityMiddleware) RequestSizeLimit(maxSize int) fiber.Handler {
return func(c *fiber.Ctx) error {
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
contentLength := c.Request().Header.ContentLength()
if contentLength > maxSize {
return c.Status(fiber.StatusRequestEntityTooLarge).JSON(fiber.Map{
"error": "Request too large",
"max_size": maxSize,
})
}
}
return c.Next()
}
}
// LogSecurityEvents logs security-related events
func (sm *SecurityMiddleware) LogSecurityEvents() fiber.Handler {
return func(c *fiber.Ctx) error {
start := time.Now()
// Process request
err := c.Next()
// Log suspicious activity
status := c.Response().StatusCode()
if status == 401 || status == 403 || status == 429 {
duration := time.Since(start)
// In a real implementation, you would send this to your logging system
fmt.Printf("[SECURITY] %s %s %s %d %v %s\n",
time.Now().Format(time.RFC3339),
c.IP(),
c.Method(),
status,
duration,
c.Path(),
)
}
return err
}
}
// TimeoutMiddleware adds request timeout
func (sm *SecurityMiddleware) TimeoutMiddleware(timeout time.Duration) fiber.Handler {
return func(c *fiber.Ctx) error {
ctx, cancel := context.WithTimeout(c.UserContext(), timeout)
defer cancel()
c.SetUserContext(ctx)
return c.Next()
}
}

View File

@@ -0,0 +1,238 @@
package migrations
import (
"acc-server-manager/local/utl/logging"
"acc-server-manager/local/utl/password"
"errors"
"fmt"
"gorm.io/gorm"
)
// Migration001UpgradePasswordSecurity migrates existing user passwords from encrypted to hashed format
type Migration001UpgradePasswordSecurity struct {
DB *gorm.DB
}
// NewMigration001UpgradePasswordSecurity creates a new password security migration
func NewMigration001UpgradePasswordSecurity(db *gorm.DB) *Migration001UpgradePasswordSecurity {
return &Migration001UpgradePasswordSecurity{DB: db}
}
// Up executes the migration
func (m *Migration001UpgradePasswordSecurity) Up() error {
logging.Info("Starting password security upgrade migration...")
// Check if migration has already been applied
var migrationRecord MigrationRecord
err := m.DB.Where("migration_name = ?", "001_upgrade_password_security").First(&migrationRecord).Error
if err == nil {
logging.Info("Password security migration already applied, skipping")
return nil
}
// Create migration tracking table if it doesn't exist
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
return fmt.Errorf("failed to create migration tracking table: %v", err)
}
// Start transaction
tx := m.DB.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to start transaction: %v", tx.Error)
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// Add a backup column for old passwords (temporary)
if err := tx.Exec("ALTER TABLE users ADD COLUMN password_backup TEXT").Error; err != nil {
// Column might already exist, ignore if it's a duplicate column error
if !isDuplicateColumnError(err) {
tx.Rollback()
return fmt.Errorf("failed to add backup column: %v", err)
}
}
// Get all users with encrypted passwords
var users []UserForMigration
if err := tx.Find(&users).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to fetch users: %v", err)
}
logging.Info("Found %d users to migrate", len(users))
migratedCount := 0
failedCount := 0
for _, user := range users {
if err := m.migrateUserPassword(tx, &user); err != nil {
logging.Error("Failed to migrate user %s (ID: %s): %v", user.Username, user.ID, err)
failedCount++
// Continue with other users rather than failing completely
continue
}
migratedCount++
}
// Remove backup column after successful migration
if err := tx.Exec("ALTER TABLE users DROP COLUMN password_backup").Error; err != nil {
logging.Error("Failed to remove backup column (non-critical): %v", err)
// Don't fail the migration for this
}
// Record successful migration
migrationRecord = MigrationRecord{
MigrationName: "001_upgrade_password_security",
AppliedAt: "datetime('now')",
Success: true,
Notes: fmt.Sprintf("Migrated %d users, %d failed", migratedCount, failedCount),
}
if err := tx.Create(&migrationRecord).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to record migration: %v", err)
}
// Commit transaction
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit migration: %v", err)
}
logging.Info("Password security migration completed successfully. Migrated: %d, Failed: %d", migratedCount, failedCount)
if failedCount > 0 {
logging.Error("Some users failed to migrate. They will need to reset their passwords.")
}
return nil
}
// migrateUserPassword migrates a single user's password
func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, user *UserForMigration) error {
// Skip if password is already hashed (bcrypt hashes start with $2a$, $2b$, or $2y$)
if isAlreadyHashed(user.Password) {
logging.Debug("User %s already has hashed password, skipping", user.Username)
return nil
}
// Backup original password
if err := tx.Model(user).Update("password_backup", user.Password).Error; err != nil {
return fmt.Errorf("failed to backup password: %v", err)
}
// Try to decrypt the old password
var plainPassword string
// First, try to decrypt using the old encryption method
decrypted, err := decryptOldPassword(user.Password)
if err != nil {
// If decryption fails, the password might already be plain text or corrupted
logging.Error("Failed to decrypt password for user %s, treating as plain text: %v", user.Username, err)
// Use original password as-is (might be plain text from development)
plainPassword = user.Password
// Validate it's not obviously encrypted data
if len(plainPassword) > 100 || containsBinaryData(plainPassword) {
return fmt.Errorf("password appears to be corrupted encrypted data")
}
} else {
plainPassword = decrypted
}
// Validate plain password
if plainPassword == "" {
return errors.New("decrypted password is empty")
}
if len(plainPassword) < 1 {
return errors.New("password too short after decryption")
}
// Hash the plain password using bcrypt
hashedPassword, err := password.HashPassword(plainPassword)
if err != nil {
return fmt.Errorf("failed to hash password: %v", err)
}
// Update with hashed password
if err := tx.Model(user).Update("password", hashedPassword).Error; err != nil {
return fmt.Errorf("failed to update password: %v", err)
}
logging.Debug("Successfully migrated password for user %s", user.Username)
return nil
}
// UserForMigration represents a user record for migration purposes
type UserForMigration struct {
ID string `gorm:"column:id"`
Username string `gorm:"column:username"`
Password string `gorm:"column:password"`
}
// TableName specifies the table name for GORM
func (UserForMigration) TableName() string {
return "users"
}
// MigrationRecord tracks applied migrations
type MigrationRecord struct {
ID uint `gorm:"primaryKey"`
MigrationName string `gorm:"unique;not null"`
AppliedAt string `gorm:"not null"`
Success bool `gorm:"not null"`
Notes string
}
// TableName specifies the table name for GORM
func (MigrationRecord) TableName() string {
return "migration_records"
}
// isAlreadyHashed checks if a password is already bcrypt hashed
func isAlreadyHashed(password string) bool {
return len(password) >= 60 && (password[:4] == "$2a$" || password[:4] == "$2b$" || password[:4] == "$2y$")
}
// containsBinaryData checks if a string contains binary data
func containsBinaryData(s string) bool {
for _, b := range []byte(s) {
if b < 32 && b != 9 && b != 10 && b != 13 { // Allow tab, newline, carriage return
return true
}
}
return false
}
// isDuplicateColumnError checks if an error is due to duplicate column
func isDuplicateColumnError(err error) bool {
errStr := err.Error()
return fmt.Sprintf("%v", errStr) == "duplicate column name: password_backup" ||
fmt.Sprintf("%v", errStr) == "SQLITE_ERROR: duplicate column name: password_backup"
}
// decryptOldPassword attempts to decrypt using the old encryption method
// This is a simplified version of the old DecryptPassword function
func decryptOldPassword(encryptedPassword string) (string, error) {
// This would use the old decryption logic
// For now, we'll return an error to force treating as plain text
// In a real scenario, you'd implement the old decryption here
return "", errors.New("old decryption not implemented - treating as plain text")
}
// Down reverses the migration (if needed)
func (m *Migration001UpgradePasswordSecurity) Down() error {
logging.Error("Password security migration rollback is not supported for security reasons")
return errors.New("password security migration rollback is not supported")
}
// RunMigration is a convenience function to run the migration
func RunPasswordSecurityMigration(db *gorm.DB) error {
migration := NewMigration001UpgradePasswordSecurity(db)
return migration.Up()
}

View File

@@ -8,6 +8,8 @@ import (
"encoding/base64"
"errors"
"io"
"regexp"
"strings"
"time"
"gorm.io/gorm"
@@ -74,25 +76,67 @@ func (s *SteamCredentials) AfterFind(tx *gorm.DB) error {
return nil
}
// Validate checks if the credentials are valid
// Validate checks if the credentials are valid with enhanced security checks
func (s *SteamCredentials) Validate() error {
if s.Username == "" {
return errors.New("username is required")
}
// Enhanced username validation
if len(s.Username) < 3 || len(s.Username) > 64 {
return errors.New("username must be between 3 and 64 characters")
}
// Check for valid characters in username (alphanumeric, underscore, hyphen)
if matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, s.Username); !matched {
return errors.New("username contains invalid characters")
}
if s.Password == "" {
return errors.New("password is required")
}
// Basic password validation
if len(s.Password) < 6 {
return errors.New("password must be at least 6 characters long")
}
if len(s.Password) > 128 {
return errors.New("password is too long")
}
// Check for obvious weak passwords
weakPasswords := []string{"password", "123456", "steam", "admin", "user"}
lowerPass := strings.ToLower(s.Password)
for _, weak := range weakPasswords {
if lowerPass == weak {
return errors.New("password is too weak")
}
}
return nil
}
// GetEncryptionKey returns the encryption key from config.
// The key is loaded from the ENCRYPTION_KEY environment variable.
func GetEncryptionKey() []byte {
return []byte(configs.EncryptionKey)
key := []byte(configs.EncryptionKey)
if len(key) != 32 {
panic("encryption key must be exactly 32 bytes for AES-256")
}
return key
}
// EncryptPassword encrypts a password using AES-256
// EncryptPassword encrypts a password using AES-256-GCM with enhanced security
func EncryptPassword(password string) (string, error) {
if password == "" {
return "", errors.New("password cannot be empty")
}
if len(password) > 1024 {
return "", errors.New("password too long")
}
key := GetEncryptionKey()
block, err := aes.NewCipher(key)
if err != nil {
@@ -105,21 +149,30 @@ func EncryptPassword(password string) (string, error) {
return "", err
}
// Create a nonce
// Create a cryptographically secure nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
// Encrypt the password
// Encrypt the password with authenticated encryption
ciphertext := gcm.Seal(nonce, nonce, []byte(password), nil)
// Return base64 encoded encrypted password
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptPassword decrypts an encrypted password
// DecryptPassword decrypts an encrypted password with enhanced validation
func DecryptPassword(encryptedPassword string) (string, error) {
if encryptedPassword == "" {
return "", errors.New("encrypted password cannot be empty")
}
// Validate base64 format
if len(encryptedPassword) < 24 { // Minimum reasonable length
return "", errors.New("invalid encrypted password format")
}
key := GetEncryptionKey()
block, err := aes.NewCipher(key)
if err != nil {
@@ -135,7 +188,7 @@ func DecryptPassword(encryptedPassword string) (string, error) {
// Decode base64 encoded password
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
if err != nil {
return "", err
return "", errors.New("invalid base64 encoding")
}
nonceSize := gcm.NonceSize()
@@ -146,8 +199,14 @@ func DecryptPassword(encryptedPassword string) (string, error) {
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", err
return "", errors.New("decryption failed - invalid ciphertext or key")
}
return string(plaintext), nil
}
// Validate decrypted content
decrypted := string(plaintext)
if len(decrypted) == 0 || len(decrypted) > 1024 {
return "", errors.New("invalid decrypted password")
}
return decrypted, nil
}

View File

@@ -1,6 +1,7 @@
package model
import (
"acc-server-manager/local/utl/password"
"errors"
"github.com/google/uuid"
@@ -11,54 +12,57 @@ import (
type User struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
Username string `json:"username" gorm:"unique_index;not null"`
Password string `json:"password" gorm:"not null"`
Password string `json:"-" gorm:"not null"` // Never expose password in JSON
RoleID uuid.UUID `json:"role_id" gorm:"type:uuid"`
Role Role `json:"role"`
}
// BeforeCreate is a GORM hook that runs before creating new credentials
// BeforeCreate is a GORM hook that runs before creating new users
func (s *User) BeforeCreate(tx *gorm.DB) error {
s.ID = uuid.New()
// Encrypt password before saving
encrypted, err := EncryptPassword(s.Password)
// Validate password strength
if err := password.ValidatePasswordStrength(s.Password); err != nil {
return err
}
// Hash password before saving
hashed, err := password.HashPassword(s.Password)
if err != nil {
return err
}
s.Password = encrypted
s.Password = hashed
return nil
}
// BeforeUpdate is a GORM hook that runs before updating credentials
// BeforeUpdate is a GORM hook that runs before updating users
func (s *User) BeforeUpdate(tx *gorm.DB) error {
// Only encrypt if password field is being updated
// Only hash if password field is being updated
if tx.Statement.Changed("Password") {
encrypted, err := EncryptPassword(s.Password)
// Validate password strength
if err := password.ValidatePasswordStrength(s.Password); err != nil {
return err
}
hashed, err := password.HashPassword(s.Password)
if err != nil {
return err
}
s.Password = encrypted
s.Password = hashed
}
return nil
}
// AfterFind is a GORM hook that runs after fetching credentials
// AfterFind is a GORM hook that runs after fetching users
func (s *User) AfterFind(tx *gorm.DB) error {
// Decrypt password after fetching
if s.Password != "" {
decrypted, err := DecryptPassword(s.Password)
if err != nil {
return err
}
s.Password = decrypted
}
// Password remains hashed - never decrypt
// This hook is kept for potential future use
return nil
}
// Validate checks if the credentials are valid
// Validate checks if the user data is valid
func (s *User) Validate() error {
if s.Username == "" {
return errors.New("username is required")
@@ -67,4 +71,9 @@ func (s *User) Validate() error {
return errors.New("password is required")
}
return nil
}
}
// VerifyPassword verifies a plain text password against the stored hash
func (s *User) VerifyPassword(plainPassword string) error {
return password.VerifyPassword(s.Password, plainPassword)
}

View File

@@ -41,7 +41,6 @@ func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context,
return &user, nil
}
// CreateUser creates a new user.
func (r *MembershipRepository) CreateUser(ctx context.Context, user *model.User) error {
db := r.db.WithContext(ctx)

View File

@@ -17,73 +17,9 @@ func NewServerRepository(db *gorm.DB) *ServerRepository {
BaseRepository: NewBaseRepository[model.Server, model.ServerFilter](db, model.Server{}),
}
// Run migrations
if err := repo.migrateServerTable(); err != nil {
panic(err)
}
return repo
}
// migrateServerTable ensures all required columns exist with proper defaults
func (r *ServerRepository) migrateServerTable() error {
// Create a temporary table with all required columns
if err := r.db.Exec(`
CREATE TABLE IF NOT EXISTS servers_new (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
ip TEXT NOT NULL,
port INTEGER NOT NULL DEFAULT 9600,
path TEXT NOT NULL,
service_name TEXT NOT NULL,
date_created DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
from_steam_cmd BOOLEAN NOT NULL DEFAULT 1
)
`).Error; err != nil {
return err
}
// Copy data from old table, setting defaults for new columns
if err := r.db.Exec(`
INSERT INTO servers_new (
id,
name,
ip,
port,
path,
service_name,
date_created,
from_steam_cmd
)
SELECT
id,
COALESCE(name, 'Server ' || id) as name,
COALESCE(ip, '127.0.0.1') as ip,
COALESCE(port, 9600) as port,
path,
COALESCE(service_name, 'ACC-Server-' || id) as service_name,
COALESCE(date_created, CURRENT_TIMESTAMP) as date_created,
COALESCE(from_steam_cmd, 1) as from_steam_cmd
FROM servers
`).Error; err != nil {
// If the old table doesn't exist, this is a fresh install
if err := r.db.Exec(`DROP TABLE IF EXISTS servers_new`).Error; err != nil {
return err
}
return nil
}
// Replace old table with new one
if err := r.db.Exec(`DROP TABLE IF EXISTS servers`).Error; err != nil {
return err
}
if err := r.db.Exec(`ALTER TABLE servers_new RENAME TO servers`).Error; err != nil {
return err
}
return nil
}
// GetFirstByServiceName
// Gets first row from Server table.
//
@@ -100,4 +36,4 @@ func (r *ServerRepository) GetFirstByServiceName(ctx context.Context, serviceNam
return nil, err
}
return result, nil
}
}

View File

@@ -93,7 +93,7 @@ func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, fil
rawQuery := `
SELECT
DATETIME(MIN(date_created)) as timestamp,
AVG(player_count) as count
ROUND(AVG(player_count)) as count
FROM state_histories
WHERE server_id = ? AND date_created BETWEEN ? AND ?
GROUP BY strftime('%Y-%m-%d %H', date_created)

View File

@@ -4,6 +4,7 @@ import (
"acc-server-manager/local/model"
"acc-server-manager/local/repository"
"acc-server-manager/local/utl/jwt"
"acc-server-manager/local/utl/logging"
"context"
"errors"
"os"
@@ -28,7 +29,8 @@ func (s *MembershipService) Login(ctx context.Context, username, password string
return "", errors.New("invalid credentials")
}
if user.Password != password {
// Use secure password verification with constant-time comparison
if err := user.VerifyPassword(password); err != nil {
return "", errors.New("invalid credentials")
}
@@ -40,6 +42,7 @@ func (s *MembershipService) CreateUser(ctx context.Context, username, password,
role, err := s.repo.FindRoleByName(ctx, roleName)
if err != nil {
logging.Error("Failed to find role by name: %v", err)
return nil, errors.New("role not found")
}
@@ -50,8 +53,10 @@ func (s *MembershipService) CreateUser(ctx context.Context, username, password,
}
if err := s.repo.CreateUser(ctx, user); err != nil {
logging.Error("Failed to create user: %v", err)
return nil, err
}
logging.Debug("User created successfully")
return user, nil
}
@@ -90,6 +95,7 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
}
if req.Password != nil && *req.Password != "" {
// Password will be automatically hashed in BeforeUpdate hook
user.Password = *req.Password
}
@@ -162,6 +168,7 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
// Create a default admin user if one doesn't exist
_, err = s.repo.FindUserByUsername(ctx, "admin")
if err != nil {
logging.Debug("Creating default admin user")
_, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin") // Default password, should be changed
if err != nil {
return err

View File

@@ -3,6 +3,8 @@ package configs
import (
"log"
"os"
"github.com/joho/godotenv"
)
var (
@@ -14,12 +16,14 @@ var (
)
func init() {
Secret = getEnv("APP_SECRET", "default-secret-for-dev-use-only")
SecretCode = getEnv("APP_SECRET_CODE", "another-secret-for-dev-use-only")
EncryptionKey = getEnv("ENCRYPTION_KEY", "a-secure-32-byte-long-key-!!!!!!") // Fallback MUST be 32 bytes for AES-256
godotenv.Load()
// Fail fast if critical environment variables are missing
Secret = getEnvRequired("APP_SECRET")
SecretCode = getEnvRequired("APP_SECRET_CODE")
EncryptionKey = getEnvRequired("ENCRYPTION_KEY")
if len(EncryptionKey) != 32 {
log.Fatal("ENCRYPTION_KEY must be 32 bytes long")
log.Fatal("ENCRYPTION_KEY must be exactly 32 bytes long for AES-256")
}
}
@@ -31,3 +35,13 @@ func getEnv(key, fallback string) string {
log.Printf("Environment variable %s not set, using fallback.", key)
return fallback
}
// getEnvRequired retrieves an environment variable and fails if it's not set.
// This should be used for critical configuration that must not have defaults.
func getEnvRequired(key string) string {
if value, exists := os.LookupEnv(key); exists && value != "" {
return value
}
log.Fatalf("Required environment variable %s is not set or is empty", key)
return "" // This line will never be reached due to log.Fatalf
}

View File

@@ -44,9 +44,9 @@ func Migrate(db *gorm.DB) {
&model.StateHistory{},
&model.SteamCredentials{},
&model.SystemConfig{},
&model.User{},
&model.Role{},
&model.Permission{},
&model.Role{},
&model.User{},
)
if err != nil {
@@ -55,6 +55,10 @@ func Migrate(db *gorm.DB) {
db.FirstOrCreate(&model.ApiModel{Api: "Works"})
// Run security migrations - temporarily disabled until migration is fixed
// TODO: Implement proper migration system
logging.Info("Database migration system needs to be implemented")
Seed(db)
}
@@ -80,8 +84,6 @@ func Seed(db *gorm.DB) error {
return nil
}
func seedTracks(db *gorm.DB) error {
tracks := []model.Track{
{Name: "monza", UniquePitBoxes: 29, PrivateServerSlots: 60},

View File

@@ -2,16 +2,18 @@ package jwt
import (
"acc-server-manager/local/model"
"crypto/rand"
"encoding/base64"
"errors"
"log"
"os"
"time"
"github.com/golang-jwt/jwt/v4"
)
// SecretKey is the secret key for signing the JWT.
// It is recommended to use a long, complex string for this.
// In a production environment, this should be loaded from a secure configuration source.
var SecretKey = []byte("your-secret-key")
// SecretKey holds the JWT signing key loaded from environment
var SecretKey []byte
// Claims represents the JWT claims.
type Claims struct {
@@ -19,6 +21,36 @@ type Claims struct {
jwt.RegisteredClaims
}
// init initializes the JWT secret key from environment variable
func init() {
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
log.Fatal("JWT_SECRET environment variable is required and cannot be empty")
}
// Decode base64 secret if it looks like base64, otherwise use as-is
if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 {
SecretKey = decoded
} else {
SecretKey = []byte(jwtSecret)
}
// Ensure minimum key length for security
if len(SecretKey) < 32 {
log.Fatal("JWT_SECRET must be at least 32 bytes long for security")
}
}
// GenerateSecretKey generates a cryptographically secure random key for JWT signing
// This is a utility function for generating new secrets, not used in normal operation
func GenerateSecretKey() string {
key := make([]byte, 64) // 512 bits
if _, err := rand.Read(key); err != nil {
log.Fatal("Failed to generate random key: ", err)
}
return base64.StdEncoding.EncodeToString(key)
}
// GenerateToken generates a new JWT for a given user.
func GenerateToken(user *model.User) (string, error) {
expirationTime := time.Now().Add(24 * time.Hour)

View File

@@ -0,0 +1,82 @@
package password
import (
"errors"
"os"
"golang.org/x/crypto/bcrypt"
)
const (
// MinPasswordLength defines the minimum password length
MinPasswordLength = 8
// BcryptCost defines the cost factor for bcrypt hashing
BcryptCost = 12
)
// HashPassword hashes a plain text password using bcrypt
func HashPassword(password string) (string, error) {
if len(password) < MinPasswordLength {
return "", errors.New("password must be at least 8 characters long")
}
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
if err != nil {
return "", err
}
return string(hashedBytes), nil
}
// VerifyPassword verifies a plain text password against a hashed password
func VerifyPassword(hashedPassword, password string) error {
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
}
// ValidatePasswordStrength validates password complexity requirements
func ValidatePasswordStrength(password string) error {
if len(password) < MinPasswordLength {
return errors.New("password must be at least 8 characters long")
}
if os.Getenv("ENFORCE_PASSWORD_STRENGTH") == "true" {
if len(password) < MinPasswordLength {
return errors.New("password must be at least 8 characters long")
}
hasUpper := false
hasLower := false
hasDigit := false
hasSpecial := false
for _, char := range password {
switch {
case char >= 'A' && char <= 'Z':
hasUpper = true
case char >= 'a' && char <= 'z':
hasLower = true
case char >= '0' && char <= '9':
hasDigit = true
case char >= '!' && char <= '/' || char >= ':' && char <= '@' || char >= '[' && char <= '`' || char >= '{' && char <= '~':
hasSpecial = true
}
}
if !hasUpper {
return errors.New("password must contain at least one uppercase letter")
}
if !hasLower {
return errors.New("password must contain at least one lowercase letter")
}
if !hasDigit {
return errors.New("password must contain at least one digit")
}
if !hasSpecial {
return errors.New("password must contain at least one special character")
}
return nil
}
return nil
}

View File

@@ -2,8 +2,10 @@ package server
import (
"acc-server-manager/local/api"
"acc-server-manager/local/middleware/security"
"acc-server-manager/local/utl/logging"
"os"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
@@ -15,8 +17,25 @@ import (
func Start(di *dig.Container) *fiber.App {
app := fiber.New(fiber.Config{
EnablePrintRoutes: true,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
BodyLimit: 10 * 1024 * 1024, // 10MB
})
// Initialize security middleware
securityMW := security.NewSecurityMiddleware()
// Add security middleware stack
app.Use(securityMW.SecurityHeaders())
app.Use(securityMW.LogSecurityEvents())
app.Use(securityMW.TimeoutMiddleware(30 * time.Second))
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) // 10MB
app.Use(securityMW.ValidateUserAgent())
app.Use(securityMW.ValidateContentType("application/json", "application/x-www-form-urlencoded", "multipart/form-data"))
app.Use(securityMW.InputSanitization())
app.Use(securityMW.RateLimit(100, 1*time.Minute)) // 100 requests per minute global
app.Use(helmet.New())
allowedOrigin := os.Getenv("CORS_ALLOWED_ORIGIN")
@@ -25,8 +44,11 @@ func Start(di *dig.Container) *fiber.App {
}
app.Use(cors.New(cors.Config{
AllowOrigins: allowedOrigin,
AllowHeaders: "Origin, Content-Type, Accept",
AllowOrigins: allowedOrigin,
AllowHeaders: "Origin, Content-Type, Accept, Authorization",
AllowMethods: "GET, POST, PUT, DELETE, OPTIONS",
AllowCredentials: true,
MaxAge: 86400, // 24 hours
}))
app.Get("/swagger/*", swagger.HandlerDefault)

View File

@@ -154,6 +154,9 @@ func (instance *AccServerInstance) UpdateState(callback func(state *model.Server
}
func (instance *AccServerInstance) UpdatePlayerCount(count int) {
if (count < 0) {
return
}
instance.UpdateState(func (state *model.ServerState, changes *[]StateChange) {
if (count == state.PlayerCount) {
return