steam 2fa for polling and security
This commit is contained in:
@@ -31,6 +31,7 @@ func Init(di *dig.Container, app *fiber.App) {
|
||||
StateHistory: serverIdGroup.Group("/state-history"),
|
||||
Membership: groups.Group("/membership"),
|
||||
System: groups.Group("/system"),
|
||||
Steam2FA: groups.Group("/steam2fa"),
|
||||
}
|
||||
|
||||
accessKeyMiddleware := middleware.NewAccessKeyMiddleware()
|
||||
|
||||
@@ -54,4 +54,9 @@ func InitializeControllers(c *dig.Container) {
|
||||
if err != nil {
|
||||
logging.Panic("unable to initialize membership controller")
|
||||
}
|
||||
|
||||
err = c.Invoke(NewSteam2FAController)
|
||||
if err != nil {
|
||||
logging.Panic("unable to initialize steam 2fa controller")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ func NewMembershipController(service *service.MembershipService, auth *middlewar
|
||||
}
|
||||
|
||||
routeGroups.Auth.Post("/login", mc.Login)
|
||||
routeGroups.Auth.Post("/open-token", mc.GenerateOpenToken)
|
||||
|
||||
usersGroup := routeGroups.Membership
|
||||
usersGroup.Use(mc.auth.Authenticate)
|
||||
@@ -82,6 +83,26 @@ func (c *MembershipController) Login(ctx *fiber.Ctx) error {
|
||||
return ctx.JSON(fiber.Map{"token": token})
|
||||
}
|
||||
|
||||
// GenerateOpenToken generates an open token for a user.
|
||||
// @Summary Generate an open token
|
||||
// @Description Generate an open token for a user
|
||||
// @Tags Authentication
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} object{token=string} "JWT token"
|
||||
// @Failure 400 {object} error_handler.ErrorResponse "Invalid request body"
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Invalid credentials"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Router /auth/open-token [post]
|
||||
func (c *MembershipController) GenerateOpenToken(ctx *fiber.Ctx) error {
|
||||
token, err := c.service.GenerateOpenToken(ctx.UserContext(), ctx.Locals("userId").(string))
|
||||
if err != nil {
|
||||
return c.errorHandler.HandleAuthError(ctx, err)
|
||||
}
|
||||
|
||||
return ctx.JSON(fiber.Map{"token": token})
|
||||
}
|
||||
|
||||
// CreateUser creates a new user.
|
||||
// @Summary Create a new user
|
||||
// @Description Create a new user account with specified role
|
||||
|
||||
139
local/controller/steam_2fa.go
Normal file
139
local/controller/steam_2fa.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/middleware"
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/common"
|
||||
"acc-server-manager/local/utl/error_handler"
|
||||
"acc-server-manager/local/utl/jwt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type Steam2FAController struct {
|
||||
tfaManager *model.Steam2FAManager
|
||||
errorHandler *error_handler.ControllerErrorHandler
|
||||
jwtHandler *jwt.OpenJWTHandler
|
||||
}
|
||||
|
||||
func NewSteam2FAController(tfaManager *model.Steam2FAManager, routeGroups *common.RouteGroups, auth *middleware.AuthMiddleware, jwtHandler *jwt.OpenJWTHandler) *Steam2FAController {
|
||||
controller := &Steam2FAController{
|
||||
tfaManager: tfaManager,
|
||||
errorHandler: error_handler.NewControllerErrorHandler(),
|
||||
jwtHandler: jwtHandler,
|
||||
}
|
||||
|
||||
steam2faRoutes := routeGroups.Steam2FA
|
||||
steam2faRoutes.Use(auth.AuthenticateOpen)
|
||||
|
||||
// Define routes
|
||||
steam2faRoutes.Get("/pending", auth.HasPermission(model.ServerView), controller.GetPendingRequests)
|
||||
steam2faRoutes.Get("/:id", auth.HasPermission(model.ServerView), controller.GetRequest)
|
||||
steam2faRoutes.Post("/:id/complete", auth.HasPermission(model.ServerUpdate), controller.CompleteRequest)
|
||||
steam2faRoutes.Post("/:id/cancel", auth.HasPermission(model.ServerUpdate), controller.CancelRequest)
|
||||
|
||||
return controller
|
||||
}
|
||||
|
||||
// GetPendingRequests gets all pending 2FA requests
|
||||
//
|
||||
// @Summary Get pending 2FA requests
|
||||
// @Description Get all pending Steam 2FA authentication requests
|
||||
// @Tags Steam 2FA
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {array} model.Steam2FARequest
|
||||
// @Failure 500 {object} error_handler.ErrorResponse
|
||||
// @Router /steam2fa/pending [get]
|
||||
func (c *Steam2FAController) GetPendingRequests(ctx *fiber.Ctx) error {
|
||||
requests := c.tfaManager.GetPendingRequests()
|
||||
return ctx.JSON(requests)
|
||||
}
|
||||
|
||||
// GetRequest gets a specific 2FA request by ID
|
||||
//
|
||||
// @Summary Get 2FA request
|
||||
// @Description Get a specific Steam 2FA authentication request by ID
|
||||
// @Tags Steam 2FA
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "2FA Request ID"
|
||||
// @Success 200 {object} model.Steam2FARequest
|
||||
// @Failure 404 {object} error_handler.ErrorResponse
|
||||
// @Failure 500 {object} error_handler.ErrorResponse
|
||||
// @Router /steam2fa/{id} [get]
|
||||
func (c *Steam2FAController) GetRequest(ctx *fiber.Ctx) error {
|
||||
id := ctx.Params("id")
|
||||
if id == "" {
|
||||
return c.errorHandler.HandleError(ctx, fiber.ErrBadRequest, fiber.StatusBadRequest)
|
||||
}
|
||||
|
||||
request, exists := c.tfaManager.GetRequest(id)
|
||||
if !exists {
|
||||
return c.errorHandler.HandleNotFoundError(ctx, "2FA request")
|
||||
}
|
||||
|
||||
return ctx.JSON(request)
|
||||
}
|
||||
|
||||
// CompleteRequest marks a 2FA request as completed
|
||||
//
|
||||
// @Summary Complete 2FA request
|
||||
// @Description Mark a Steam 2FA authentication request as completed
|
||||
// @Tags Steam 2FA
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "2FA Request ID"
|
||||
// @Success 200 {object} model.Steam2FARequest
|
||||
// @Failure 400 {object} error_handler.ErrorResponse
|
||||
// @Failure 404 {object} error_handler.ErrorResponse
|
||||
// @Failure 500 {object} error_handler.ErrorResponse
|
||||
// @Router /steam2fa/{id}/complete [post]
|
||||
func (c *Steam2FAController) CompleteRequest(ctx *fiber.Ctx) error {
|
||||
id := ctx.Params("id")
|
||||
if id == "" {
|
||||
return c.errorHandler.HandleError(ctx, fiber.ErrBadRequest, fiber.StatusBadRequest)
|
||||
}
|
||||
|
||||
if err := c.tfaManager.CompleteRequest(id); err != nil {
|
||||
return c.errorHandler.HandleError(ctx, err, fiber.StatusBadRequest)
|
||||
}
|
||||
|
||||
request, exists := c.tfaManager.GetRequest(id)
|
||||
if !exists {
|
||||
return c.errorHandler.HandleNotFoundError(ctx, "2FA request")
|
||||
}
|
||||
|
||||
return ctx.JSON(request)
|
||||
}
|
||||
|
||||
// CancelRequest cancels a 2FA request
|
||||
//
|
||||
// @Summary Cancel 2FA request
|
||||
// @Description Cancel a Steam 2FA authentication request
|
||||
// @Tags Steam 2FA
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "2FA Request ID"
|
||||
// @Success 200 {object} model.Steam2FARequest
|
||||
// @Failure 400 {object} error_handler.ErrorResponse
|
||||
// @Failure 404 {object} error_handler.ErrorResponse
|
||||
// @Failure 500 {object} error_handler.ErrorResponse
|
||||
// @Router /steam2fa/{id}/cancel [post]
|
||||
func (c *Steam2FAController) CancelRequest(ctx *fiber.Ctx) error {
|
||||
id := ctx.Params("id")
|
||||
if id == "" {
|
||||
return c.errorHandler.HandleError(ctx, fiber.ErrBadRequest, fiber.StatusBadRequest)
|
||||
}
|
||||
|
||||
if err := c.tfaManager.ErrorRequest(id, "cancelled by user"); err != nil {
|
||||
return c.errorHandler.HandleError(ctx, err, fiber.StatusBadRequest)
|
||||
}
|
||||
|
||||
request, exists := c.tfaManager.GetRequest(id)
|
||||
if !exists {
|
||||
return c.errorHandler.HandleNotFoundError(ctx, "2FA request")
|
||||
}
|
||||
|
||||
return ctx.JSON(request)
|
||||
}
|
||||
@@ -30,14 +30,18 @@ type AuthMiddleware struct {
|
||||
membershipService *service.MembershipService
|
||||
cache *cache.InMemoryCache
|
||||
securityMW *security.SecurityMiddleware
|
||||
jwtHandler *jwt.JWTHandler
|
||||
openJWTHandler *jwt.OpenJWTHandler
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new AuthMiddleware.
|
||||
func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache) *AuthMiddleware {
|
||||
func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache, jwtHandler *jwt.JWTHandler, openJWTHandler *jwt.OpenJWTHandler) *AuthMiddleware {
|
||||
auth := &AuthMiddleware{
|
||||
membershipService: ms,
|
||||
cache: cache,
|
||||
securityMW: security.NewSecurityMiddleware(),
|
||||
jwtHandler: jwtHandler,
|
||||
openJWTHandler: openJWTHandler,
|
||||
}
|
||||
|
||||
// Set up bidirectional relationship for cache invalidation
|
||||
@@ -46,8 +50,17 @@ func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache
|
||||
return auth
|
||||
}
|
||||
|
||||
// Authenticate is a middleware for JWT authentication with enhanced security.
|
||||
func (m *AuthMiddleware) AuthenticateOpen(ctx *fiber.Ctx) error {
|
||||
return m.AuthenticateWithHandler(m.openJWTHandler.JWTHandler, ctx)
|
||||
}
|
||||
|
||||
// Authenticate is a middleware for JWT authentication with enhanced security.
|
||||
func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
return m.AuthenticateWithHandler(m.jwtHandler, ctx)
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, ctx *fiber.Ctx) error {
|
||||
// Log authentication attempt
|
||||
ip := ctx.IP()
|
||||
userAgent := ctx.Get("User-Agent")
|
||||
@@ -77,7 +90,7 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
})
|
||||
}
|
||||
|
||||
claims, err := jwt.ValidateToken(token)
|
||||
claims, err := jwtHandler.ValidateToken(token)
|
||||
if err != nil {
|
||||
logging.Error("Authentication failed: invalid token from IP %s, User-Agent: %s, Error: %v", ip, userAgent, err)
|
||||
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/utl/graceful"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -22,35 +23,42 @@ func NewRateLimiter() *RateLimiter {
|
||||
requests: make(map[string][]time.Time),
|
||||
}
|
||||
|
||||
// Clean up old entries every 5 minutes
|
||||
go rl.cleanup()
|
||||
// Use graceful shutdown for cleanup goroutine
|
||||
shutdownManager := graceful.GetManager()
|
||||
shutdownManager.RunGoroutine(func(ctx context.Context) {
|
||||
rl.cleanupWithContext(ctx)
|
||||
})
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanup removes old entries from the rate limiter
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
func (rl *RateLimiter) cleanupWithContext(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mutex.Lock()
|
||||
now := time.Now()
|
||||
for key, times := range rl.requests {
|
||||
// Remove entries older than 1 hour
|
||||
filtered := make([]time.Time, 0, len(times))
|
||||
for _, t := range times {
|
||||
if now.Sub(t) < time.Hour {
|
||||
filtered = append(filtered, t)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rl.mutex.Lock()
|
||||
now := time.Now()
|
||||
for key, times := range rl.requests {
|
||||
filtered := make([]time.Time, 0, len(times))
|
||||
for _, t := range times {
|
||||
if now.Sub(t) < time.Hour {
|
||||
filtered = append(filtered, t)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(rl.requests, key)
|
||||
} else {
|
||||
rl.requests[key] = filtered
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(rl.requests, key)
|
||||
} else {
|
||||
rl.requests[key] = filtered
|
||||
}
|
||||
rl.mutex.Unlock()
|
||||
}
|
||||
rl.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,13 +197,13 @@ func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
|
||||
|
||||
// sanitizeInput removes potentially dangerous patterns from input
|
||||
func sanitizeInput(input string) string {
|
||||
// Remove common XSS patterns
|
||||
dangerous := []string{
|
||||
"<script",
|
||||
"</script>",
|
||||
"javascript:",
|
||||
"vbscript:",
|
||||
"data:text/html",
|
||||
"data:application",
|
||||
"onload=",
|
||||
"onerror=",
|
||||
"onclick=",
|
||||
@@ -204,25 +212,46 @@ func sanitizeInput(input string) string {
|
||||
"onblur=",
|
||||
"onchange=",
|
||||
"onsubmit=",
|
||||
"onkeydown=",
|
||||
"onkeyup=",
|
||||
"<iframe",
|
||||
"<object",
|
||||
"<embed",
|
||||
"<link",
|
||||
"<meta",
|
||||
"<style",
|
||||
"<form",
|
||||
"<input",
|
||||
"<button",
|
||||
"<svg",
|
||||
"<math",
|
||||
"expression(",
|
||||
"@import",
|
||||
"url(",
|
||||
"\\x",
|
||||
"\\u",
|
||||
"&#x",
|
||||
"&#",
|
||||
}
|
||||
|
||||
result := strings.ToLower(input)
|
||||
result := input
|
||||
lowerInput := strings.ToLower(input)
|
||||
|
||||
for _, pattern := range dangerous {
|
||||
result = strings.ReplaceAll(result, pattern, "")
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// If the sanitized version is very different, it might be malicious
|
||||
if len(result) < len(input)/2 {
|
||||
if strings.Contains(result, "\x00") {
|
||||
return ""
|
||||
}
|
||||
|
||||
return input
|
||||
if len(strings.TrimSpace(result)) == 0 && len(input) > 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateContentType ensures only expected content types are accepted
|
||||
@@ -349,3 +378,24 @@ func (sm *SecurityMiddleware) TimeoutMiddleware(timeout time.Duration) fiber.Han
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SecurityMiddleware) RequestContextTimeout(timeout time.Duration) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
ctx, cancel := context.WithTimeout(c.UserContext(), timeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.Next()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return c.Status(fiber.StatusRequestTimeout).JSON(fiber.Map{
|
||||
"error": "Request timeout",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
168
local/model/steam_2fa.go
Normal file
168
local/model/steam_2fa.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Steam2FAStatus string
|
||||
|
||||
const (
|
||||
Steam2FAStatusIdle Steam2FAStatus = "idle"
|
||||
Steam2FAStatusPending Steam2FAStatus = "pending"
|
||||
Steam2FAStatusComplete Steam2FAStatus = "complete"
|
||||
Steam2FAStatusError Steam2FAStatus = "error"
|
||||
)
|
||||
|
||||
type Steam2FARequest struct {
|
||||
ID string `json:"id"`
|
||||
Status Steam2FAStatus `json:"status"`
|
||||
Message string `json:"message"`
|
||||
RequestTime time.Time `json:"requestTime"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||
ServerID *uuid.UUID `json:"serverId,omitempty"`
|
||||
}
|
||||
|
||||
// Steam2FAManager manages 2FA requests and responses
|
||||
type Steam2FAManager struct {
|
||||
mu sync.RWMutex
|
||||
requests map[string]*Steam2FARequest
|
||||
channels map[string]chan bool
|
||||
}
|
||||
|
||||
func NewSteam2FAManager() *Steam2FAManager {
|
||||
return &Steam2FAManager{
|
||||
requests: make(map[string]*Steam2FARequest),
|
||||
channels: make(map[string]chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) CreateRequest(message string, serverID *uuid.UUID) *Steam2FARequest {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
id := uuid.New().String()
|
||||
request := &Steam2FARequest{
|
||||
ID: id,
|
||||
Status: Steam2FAStatusPending,
|
||||
Message: message,
|
||||
RequestTime: time.Now(),
|
||||
ServerID: serverID,
|
||||
}
|
||||
|
||||
m.requests[id] = request
|
||||
m.channels[id] = make(chan bool, 1)
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) GetRequest(id string) (*Steam2FARequest, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
req, exists := m.requests[id]
|
||||
return req, exists
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) GetPendingRequests() []*Steam2FARequest {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var pending []*Steam2FARequest
|
||||
for _, req := range m.requests {
|
||||
if req.Status == Steam2FAStatusPending {
|
||||
pending = append(pending, req)
|
||||
}
|
||||
}
|
||||
return pending
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) CompleteRequest(id string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
req, exists := m.requests[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("request %s not found", id)
|
||||
}
|
||||
|
||||
if req.Status != Steam2FAStatusPending {
|
||||
return fmt.Errorf("request %s is not pending", id)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
req.Status = Steam2FAStatusComplete
|
||||
req.CompletedAt = &now
|
||||
|
||||
// Signal the waiting goroutine
|
||||
if ch, exists := m.channels[id]; exists {
|
||||
select {
|
||||
case ch <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) ErrorRequest(id string, errorMsg string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
req, exists := m.requests[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("request %s not found", id)
|
||||
}
|
||||
|
||||
req.Status = Steam2FAStatusError
|
||||
req.ErrorMsg = errorMsg
|
||||
|
||||
// Signal the waiting goroutine with error
|
||||
if ch, exists := m.channels[id]; exists {
|
||||
select {
|
||||
case ch <- false:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) WaitForCompletion(id string, timeout time.Duration) (bool, error) {
|
||||
m.mu.RLock()
|
||||
ch, exists := m.channels[id]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, fmt.Errorf("request %s not found", id)
|
||||
}
|
||||
|
||||
select {
|
||||
case success := <-ch:
|
||||
return success, nil
|
||||
case <-time.After(timeout):
|
||||
// Timeout - mark as error
|
||||
m.ErrorRequest(id, "timeout waiting for 2FA confirmation")
|
||||
return false, fmt.Errorf("timeout waiting for 2FA confirmation")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) CleanupOldRequests(maxAge time.Duration) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
for id, req := range m.requests {
|
||||
if req.RequestTime.Before(cutoff) {
|
||||
delete(m.requests, id)
|
||||
if ch, exists := m.channels[id]; exists {
|
||||
close(ch)
|
||||
delete(m.channels, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,12 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/graceful"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
@@ -17,4 +23,29 @@ func InitializeRepositories(c *dig.Container) {
|
||||
c.Provide(NewLookupRepository)
|
||||
c.Provide(NewSteamCredentialsRepository)
|
||||
c.Provide(NewMembershipRepository)
|
||||
|
||||
// Provide the Steam2FAManager as a singleton
|
||||
if err := c.Provide(func() *model.Steam2FAManager {
|
||||
manager := model.NewSteam2FAManager()
|
||||
|
||||
// Use graceful shutdown manager for cleanup goroutine
|
||||
shutdownManager := graceful.GetManager()
|
||||
shutdownManager.RunGoroutine(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
manager.CleanupOldRequests(30 * time.Minute)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return manager
|
||||
}); err != nil {
|
||||
logging.Panic("unable to initialize steam 2fa manager")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,13 +22,17 @@ type CacheInvalidator interface {
|
||||
type MembershipService struct {
|
||||
repo *repository.MembershipRepository
|
||||
cacheInvalidator CacheInvalidator
|
||||
jwtHandler *jwt.JWTHandler
|
||||
openJwtHandler *jwt.OpenJWTHandler
|
||||
}
|
||||
|
||||
// NewMembershipService creates a new MembershipService.
|
||||
func NewMembershipService(repo *repository.MembershipRepository) *MembershipService {
|
||||
func NewMembershipService(repo *repository.MembershipRepository, jwtHandler *jwt.JWTHandler, openJwtHandler *jwt.OpenJWTHandler) *MembershipService {
|
||||
return &MembershipService{
|
||||
repo: repo,
|
||||
cacheInvalidator: nil, // Will be set later via SetCacheInvalidator
|
||||
jwtHandler: jwtHandler,
|
||||
openJwtHandler: openJwtHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,18 +42,37 @@ func (s *MembershipService) SetCacheInvalidator(invalidator CacheInvalidator) {
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a JWT.
|
||||
func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) {
|
||||
func (s *MembershipService) HandleLogin(ctx context.Context, username, password string) (*model.User, error) {
|
||||
user, err := s.repo.FindUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid credentials")
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
// Use secure password verification with constant-time comparison
|
||||
if err := user.VerifyPassword(password); err != nil {
|
||||
return "", errors.New("invalid credentials")
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
return jwt.GenerateToken(user)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a JWT.
|
||||
func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) {
|
||||
user, err := s.HandleLogin(ctx, username, password)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return s.jwtHandler.GenerateToken(user)
|
||||
}
|
||||
|
||||
func (s *MembershipService) GenerateOpenToken(ctx context.Context, userId string) (string, error) {
|
||||
user, err := s.repo.GetByID(ctx, userId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return s.openJwtHandler.GenerateToken(user)
|
||||
}
|
||||
|
||||
// CreateUser creates a new user.
|
||||
|
||||
@@ -337,7 +337,7 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
}
|
||||
|
||||
// Install server using SteamCMD
|
||||
if err := s.steamService.InstallServer(ctx.UserContext(), server.GetServerPath()); err != nil {
|
||||
if err := s.steamService.InstallServer(ctx.UserContext(), server.GetServerPath(), &server.ID); err != nil {
|
||||
return fmt.Errorf("failed to install server: %v", err)
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ func (s *ServerService) UpdateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
|
||||
// Update server files if path changed
|
||||
if existingServer.Path != server.Path {
|
||||
if err := s.steamService.InstallServer(ctx.UserContext(), server.Path); err != nil {
|
||||
if err := s.steamService.InstallServer(ctx.UserContext(), server.Path, &server.ID); err != nil {
|
||||
return fmt.Errorf("failed to install server to new location: %v", err)
|
||||
}
|
||||
// Clean up old installation
|
||||
|
||||
@@ -18,12 +18,12 @@ func InitializeServices(c *dig.Container) {
|
||||
|
||||
logging.Debug("Registering services")
|
||||
// Provide services
|
||||
c.Provide(NewSteamService)
|
||||
c.Provide(NewServerService)
|
||||
c.Provide(NewStateHistoryService)
|
||||
c.Provide(NewServiceControlService)
|
||||
c.Provide(NewConfigService)
|
||||
c.Provide(NewLookupService)
|
||||
c.Provide(NewSteamService)
|
||||
c.Provide(NewWindowsService)
|
||||
c.Provide(NewFirewallService)
|
||||
c.Provide(NewMembershipService)
|
||||
|
||||
@@ -6,10 +6,14 @@ import (
|
||||
"acc-server-manager/local/utl/command"
|
||||
"acc-server-manager/local/utl/env"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"acc-server-manager/local/utl/security"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -17,17 +21,27 @@ const (
|
||||
)
|
||||
|
||||
type SteamService struct {
|
||||
executor *command.CommandExecutor
|
||||
repository *repository.SteamCredentialsRepository
|
||||
executor *command.CommandExecutor
|
||||
interactiveExecutor *command.InteractiveCommandExecutor
|
||||
repository *repository.SteamCredentialsRepository
|
||||
tfaManager *model.Steam2FAManager
|
||||
pathValidator *security.PathValidator
|
||||
downloadVerifier *security.DownloadVerifier
|
||||
}
|
||||
|
||||
func NewSteamService(repository *repository.SteamCredentialsRepository) *SteamService {
|
||||
func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManager *model.Steam2FAManager) *SteamService {
|
||||
baseExecutor := &command.CommandExecutor{
|
||||
ExePath: "powershell",
|
||||
LogOutput: true,
|
||||
}
|
||||
|
||||
return &SteamService{
|
||||
executor: &command.CommandExecutor{
|
||||
ExePath: "powershell",
|
||||
LogOutput: true,
|
||||
},
|
||||
repository: repository,
|
||||
executor: baseExecutor,
|
||||
interactiveExecutor: command.NewInteractiveCommandExecutor(baseExecutor, tfaManager),
|
||||
repository: repository,
|
||||
tfaManager: tfaManager,
|
||||
pathValidator: security.NewPathValidator(),
|
||||
downloadVerifier: security.NewDownloadVerifier(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +56,7 @@ func (s *SteamService) SaveCredentials(ctx context.Context, creds *model.SteamCr
|
||||
return s.repository.Save(ctx, creds)
|
||||
}
|
||||
|
||||
func (s *SteamService) ensureSteamCMD(ctx context.Context) error {
|
||||
func (s *SteamService) ensureSteamCMD(_ context.Context) error {
|
||||
// Get SteamCMD path from environment variable
|
||||
steamCMDPath := env.GetSteamCMDPath()
|
||||
steamCMDDir := filepath.Dir(steamCMDPath)
|
||||
@@ -57,10 +71,13 @@ func (s *SteamService) ensureSteamCMD(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to create SteamCMD directory: %v", err)
|
||||
}
|
||||
|
||||
// Download and install SteamCMD
|
||||
// Download and install SteamCMD securely
|
||||
logging.Info("Downloading SteamCMD...")
|
||||
if err := s.executor.Execute("-Command",
|
||||
"Invoke-WebRequest -Uri 'https://steamcdn-a.akamaihd.net/client/installer/steamcmd.zip' -OutFile 'steamcmd.zip'"); err != nil {
|
||||
steamCMDZip := filepath.Join(steamCMDDir, "steamcmd.zip")
|
||||
if err := s.downloadVerifier.VerifyAndDownload(
|
||||
"https://steamcdn-a.akamaihd.net/client/installer/steamcmd.zip",
|
||||
steamCMDZip,
|
||||
""); err != nil {
|
||||
return fmt.Errorf("failed to download SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
@@ -76,11 +93,16 @@ func (s *SteamService) ensureSteamCMD(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SteamService) InstallServer(ctx context.Context, installPath string) error {
|
||||
func (s *SteamService) InstallServer(ctx context.Context, installPath string, serverID *uuid.UUID) error {
|
||||
if err := s.ensureSteamCMD(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate installation path for security
|
||||
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
|
||||
return fmt.Errorf("invalid installation path: %v", err)
|
||||
}
|
||||
|
||||
// Convert to absolute path and ensure proper Windows path format
|
||||
absPath, err := filepath.Abs(installPath)
|
||||
if err != nil {
|
||||
@@ -126,17 +148,15 @@ func (s *SteamService) InstallServer(ctx context.Context, installPath string) er
|
||||
"+quit",
|
||||
)
|
||||
|
||||
// Run SteamCMD
|
||||
// Use interactive executor to handle potential 2FA prompts
|
||||
logging.Info("Installing ACC server to %s...", absPath)
|
||||
if err := s.executor.Execute(args...); err != nil {
|
||||
if err := s.interactiveExecutor.ExecuteInteractive(ctx, serverID, args...); err != nil {
|
||||
return fmt.Errorf("failed to run SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
// Add a delay to allow Steam to properly cleanup
|
||||
logging.Info("Waiting for Steam operations to complete...")
|
||||
if err := s.executor.Execute("-Command", "Start-Sleep -Seconds 5"); err != nil {
|
||||
logging.Warn("Failed to wait after Steam operations: %v", err)
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Verify installation
|
||||
exePath := filepath.Join(absPath, "server", "accServer.exe")
|
||||
@@ -148,8 +168,8 @@ func (s *SteamService) InstallServer(ctx context.Context, installPath string) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SteamService) UpdateServer(ctx context.Context, installPath string) error {
|
||||
return s.InstallServer(ctx, installPath) // Same process as install
|
||||
func (s *SteamService) UpdateServer(ctx context.Context, installPath string, serverID *uuid.UUID) error {
|
||||
return s.InstallServer(ctx, installPath, serverID) // Same process as install
|
||||
}
|
||||
|
||||
func (s *SteamService) UninstallServer(installPath string) error {
|
||||
|
||||
64
local/utl/audit/audit.go
Normal file
64
local/utl/audit/audit.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuditAction string
|
||||
|
||||
const (
|
||||
ActionLogin AuditAction = "LOGIN"
|
||||
ActionLogout AuditAction = "LOGOUT"
|
||||
ActionServerCreate AuditAction = "SERVER_CREATE"
|
||||
ActionServerUpdate AuditAction = "SERVER_UPDATE"
|
||||
ActionServerDelete AuditAction = "SERVER_DELETE"
|
||||
ActionServerStart AuditAction = "SERVER_START"
|
||||
ActionServerStop AuditAction = "SERVER_STOP"
|
||||
ActionUserCreate AuditAction = "USER_CREATE"
|
||||
ActionUserUpdate AuditAction = "USER_UPDATE"
|
||||
ActionUserDelete AuditAction = "USER_DELETE"
|
||||
ActionConfigUpdate AuditAction = "CONFIG_UPDATE"
|
||||
ActionSteamAuth AuditAction = "STEAM_AUTH"
|
||||
ActionPermissionGrant AuditAction = "PERMISSION_GRANT"
|
||||
ActionPermissionRevoke AuditAction = "PERMISSION_REVOKE"
|
||||
)
|
||||
|
||||
type AuditEntry struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Action AuditAction `json:"action"`
|
||||
Resource string `json:"resource"`
|
||||
Details string `json:"details"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
func LogAction(ctx context.Context, userID, username string, action AuditAction, resource, details, ipAddress, userAgent string, success bool) {
|
||||
logging.InfoWithContext("AUDIT", "User %s (%s) performed %s on %s from %s - Success: %t - Details: %s",
|
||||
username, userID, action, resource, ipAddress, success, details)
|
||||
}
|
||||
|
||||
func LogAuthAction(ctx context.Context, username, ipAddress, userAgent string, success bool, details string) {
|
||||
action := ActionLogin
|
||||
if !success {
|
||||
details = "Failed: " + details
|
||||
}
|
||||
|
||||
LogAction(ctx, "", username, action, "authentication", details, ipAddress, userAgent, success)
|
||||
}
|
||||
|
||||
func LogServerAction(ctx context.Context, userID, username string, action AuditAction, serverID, ipAddress, userAgent string, success bool, details string) {
|
||||
LogAction(ctx, userID, username, action, "server:"+serverID, details, ipAddress, userAgent, success)
|
||||
}
|
||||
|
||||
func LogUserManagementAction(ctx context.Context, adminUserID, adminUsername string, action AuditAction, targetUserID, ipAddress, userAgent string, success bool, details string) {
|
||||
LogAction(ctx, adminUserID, adminUsername, action, "user:"+targetUserID, details, ipAddress, userAgent, success)
|
||||
}
|
||||
|
||||
func LogConfigAction(ctx context.Context, userID, username string, configType, ipAddress, userAgent string, success bool, details string) {
|
||||
LogAction(ctx, userID, username, ActionConfigUpdate, "config:"+configType, details, ipAddress, userAgent, success)
|
||||
}
|
||||
179
local/utl/command/interactive_executor.go
Normal file
179
local/utl/command/interactive_executor.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// InteractiveCommandExecutor extends CommandExecutor to handle interactive commands
|
||||
type InteractiveCommandExecutor struct {
|
||||
*CommandExecutor
|
||||
tfaManager *model.Steam2FAManager
|
||||
}
|
||||
|
||||
func NewInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager) *InteractiveCommandExecutor {
|
||||
return &InteractiveCommandExecutor{
|
||||
CommandExecutor: baseExecutor,
|
||||
tfaManager: tfaManager,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteInteractive runs a command that may require 2FA input
|
||||
func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
|
||||
cmd := exec.CommandContext(ctx, e.ExePath, args...)
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
|
||||
// Create pipes for stdin, stdout, and stderr
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdin pipe: %v", err)
|
||||
}
|
||||
defer stdin.Close()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdout pipe: %v", err)
|
||||
}
|
||||
defer stdout.Close()
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stderr pipe: %v", err)
|
||||
}
|
||||
defer stderr.Close()
|
||||
|
||||
logging.Info("Executing interactive command: %s %s", e.ExePath, strings.Join(args, " "))
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start command: %v", err)
|
||||
}
|
||||
|
||||
// Create channels for output monitoring
|
||||
outputDone := make(chan error)
|
||||
|
||||
// Monitor stdout and stderr for 2FA prompts
|
||||
go e.monitorOutput(ctx, stdout, stderr, serverID, outputDone)
|
||||
|
||||
// Wait for either the command to finish or output monitoring to complete
|
||||
cmdErr := cmd.Wait()
|
||||
outputErr := <-outputDone
|
||||
|
||||
if outputErr != nil {
|
||||
logging.Warn("Output monitoring error: %v", outputErr)
|
||||
}
|
||||
|
||||
return cmdErr
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) {
|
||||
defer close(done)
|
||||
|
||||
// Create scanners for both outputs
|
||||
stdoutScanner := bufio.NewScanner(stdout)
|
||||
stderrScanner := bufio.NewScanner(stderr)
|
||||
|
||||
outputChan := make(chan string)
|
||||
|
||||
// Read from stdout
|
||||
go func() {
|
||||
for stdoutScanner.Scan() {
|
||||
line := stdoutScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDOUT: %s", line)
|
||||
}
|
||||
outputChan <- line
|
||||
}
|
||||
}()
|
||||
|
||||
// Read from stderr
|
||||
go func() {
|
||||
for stderrScanner.Scan() {
|
||||
line := stderrScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDERR: %s", line)
|
||||
}
|
||||
outputChan <- line
|
||||
}
|
||||
}()
|
||||
|
||||
// Monitor for 2FA prompts
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
done <- ctx.Err()
|
||||
return
|
||||
case line, ok := <-outputChan:
|
||||
if !ok {
|
||||
done <- nil
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this line indicates a 2FA prompt
|
||||
if e.is2FAPrompt(line) {
|
||||
if err := e.handle2FAPrompt(ctx, line, serverID); err != nil {
|
||||
logging.Error("Failed to handle 2FA prompt: %v", err)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool {
|
||||
// Common SteamCMD 2FA prompts
|
||||
twoFAKeywords := []string{
|
||||
"please enter your steam guard code",
|
||||
"steam guard",
|
||||
"two-factor",
|
||||
"authentication code",
|
||||
"please check your steam mobile app",
|
||||
"confirm in application",
|
||||
}
|
||||
|
||||
lowerLine := strings.ToLower(line)
|
||||
for _, keyword := range twoFAKeywords {
|
||||
if strings.Contains(lowerLine, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLine string, serverID *uuid.UUID) error {
|
||||
logging.Info("2FA prompt detected: %s", promptLine)
|
||||
|
||||
// Create a 2FA request
|
||||
request := e.tfaManager.CreateRequest(promptLine, serverID)
|
||||
logging.Info("Created 2FA request with ID: %s", request.ID)
|
||||
|
||||
// Wait for user to complete the 2FA process
|
||||
// Use a reasonable timeout (e.g., 5 minutes)
|
||||
timeout := 5 * time.Minute
|
||||
success, err := e.tfaManager.WaitForCompletion(request.ID, timeout)
|
||||
|
||||
if err != nil {
|
||||
logging.Error("2FA completion failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !success {
|
||||
logging.Error("2FA was not completed successfully")
|
||||
return fmt.Errorf("2FA authentication failed")
|
||||
}
|
||||
|
||||
logging.Info("2FA completed successfully")
|
||||
return nil
|
||||
}
|
||||
@@ -25,6 +25,7 @@ type RouteGroups struct {
|
||||
StateHistory fiber.Router
|
||||
Membership fiber.Router
|
||||
System fiber.Router
|
||||
Steam2FA fiber.Router
|
||||
}
|
||||
|
||||
func CheckError(err error) {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
Version = "0.10.3"
|
||||
Version = "0.10.5"
|
||||
Prefix = "v1"
|
||||
Secret string
|
||||
SecretCode string
|
||||
|
||||
67
local/utl/errors/safe_error.go
Normal file
67
local/utl/errors/safe_error.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
type SafeError struct {
|
||||
Message string
|
||||
Code int
|
||||
Fatal bool
|
||||
}
|
||||
|
||||
func (e *SafeError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func NewSafeError(message string, code int) *SafeError {
|
||||
return &SafeError{
|
||||
Message: message,
|
||||
Code: code,
|
||||
Fatal: false,
|
||||
}
|
||||
}
|
||||
|
||||
func NewFatalError(message string, code int) *SafeError {
|
||||
return &SafeError{
|
||||
Message: message,
|
||||
Code: code,
|
||||
Fatal: true,
|
||||
}
|
||||
}
|
||||
|
||||
func HandleError(err error, context string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if safeErr, ok := err.(*SafeError); ok {
|
||||
if safeErr.Fatal {
|
||||
logging.Error("Fatal error in %s: %s", context, safeErr.Message)
|
||||
if os.Getenv("ENVIRONMENT") == "production" {
|
||||
logging.Error("Application shutting down due to fatal error")
|
||||
os.Exit(safeErr.Code)
|
||||
} else {
|
||||
logging.Warn("Fatal error occurred but not exiting in non-production environment")
|
||||
}
|
||||
} else {
|
||||
logging.Error("Error in %s: %s", context, safeErr.Message)
|
||||
}
|
||||
} else {
|
||||
logging.Error("Unexpected error in %s: %v", context, err)
|
||||
}
|
||||
}
|
||||
|
||||
func SafeFatal(message string, args ...interface{}) {
|
||||
formattedMessage := fmt.Sprintf(message, args...)
|
||||
err := NewFatalError(formattedMessage, 1)
|
||||
HandleError(err, "application")
|
||||
}
|
||||
|
||||
func SafeLog(message string, args ...interface{}) {
|
||||
formattedMessage := fmt.Sprintf(message, args...)
|
||||
err := NewSafeError(formattedMessage, 0)
|
||||
HandleError(err, "application")
|
||||
}
|
||||
91
local/utl/graceful/shutdown.go
Normal file
91
local/utl/graceful/shutdown.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package graceful
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ShutdownManager struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
handlers []func() error
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
var globalManager *ShutdownManager
|
||||
var once sync.Once
|
||||
|
||||
func GetManager() *ShutdownManager {
|
||||
once.Do(func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
globalManager = &ShutdownManager{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
handlers: make([]func() error, 0),
|
||||
}
|
||||
|
||||
go globalManager.watchSignals()
|
||||
})
|
||||
return globalManager
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) watchSignals() {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
<-sigChan
|
||||
sm.Shutdown(30 * time.Second)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) AddHandler(handler func() error) {
|
||||
sm.mutex.Lock()
|
||||
defer sm.mutex.Unlock()
|
||||
sm.handlers = append(sm.handlers, handler)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) Context() context.Context {
|
||||
return sm.ctx
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) AddGoroutine() {
|
||||
sm.wg.Add(1)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) GoroutineDone() {
|
||||
sm.wg.Done()
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) RunGoroutine(fn func(ctx context.Context)) {
|
||||
sm.wg.Add(1)
|
||||
go func() {
|
||||
defer sm.wg.Done()
|
||||
fn(sm.ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) Shutdown(timeout time.Duration) {
|
||||
sm.cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sm.wg.Wait()
|
||||
|
||||
sm.mutex.Lock()
|
||||
for _, handler := range sm.handlers {
|
||||
handler()
|
||||
}
|
||||
sm.mutex.Unlock()
|
||||
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
@@ -2,57 +2,73 @@ package jwt
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/errors"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"log"
|
||||
"os"
|
||||
goerrors "errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// SecretKey holds the JWT signing key loaded from environment
|
||||
var SecretKey []byte
|
||||
|
||||
// Claims represents the JWT claims.
|
||||
type Claims struct {
|
||||
UserID string `json:"user_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// init initializes the JWT secret key from environment variable
|
||||
func Init() {
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
log.Fatal("JWT_SECRET environment variable is required and cannot be empty")
|
||||
type JWTHandler struct {
|
||||
SecretKey []byte
|
||||
}
|
||||
|
||||
type OpenJWTHandler struct {
|
||||
*JWTHandler
|
||||
}
|
||||
|
||||
// NewJWTHandler creates a new JWTHandler instance with the provided secret key.
|
||||
func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler {
|
||||
jwtHandler := NewJWTHandler(jwtSecret)
|
||||
return &OpenJWTHandler{
|
||||
JWTHandler: jwtHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// NewJWTHandler creates a new JWTHandler instance with the provided secret key.
|
||||
func NewJWTHandler(jwtSecret string) *JWTHandler {
|
||||
if jwtSecret == "" {
|
||||
errors.SafeFatal("JWT_SECRET environment variable is required and cannot be empty")
|
||||
}
|
||||
|
||||
var secretKey []byte
|
||||
|
||||
// Decode base64 secret if it looks like base64, otherwise use as-is
|
||||
if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 {
|
||||
SecretKey = decoded
|
||||
secretKey = decoded
|
||||
} else {
|
||||
SecretKey = []byte(jwtSecret)
|
||||
secretKey = []byte(jwtSecret)
|
||||
}
|
||||
|
||||
// Ensure minimum key length for security
|
||||
if len(SecretKey) < 32 {
|
||||
log.Fatal("JWT_SECRET must be at least 32 bytes long for security")
|
||||
if len(secretKey) < 32 {
|
||||
errors.SafeFatal("JWT_SECRET must be at least 32 bytes long for security")
|
||||
}
|
||||
return &JWTHandler{
|
||||
SecretKey: secretKey,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSecretKey generates a cryptographically secure random key for JWT signing
|
||||
// This is a utility function for generating new secrets, not used in normal operation
|
||||
func GenerateSecretKey() string {
|
||||
func (jh *JWTHandler) GenerateSecretKey() string {
|
||||
key := make([]byte, 64) // 512 bits
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
log.Fatal("Failed to generate random key: ", err)
|
||||
errors.SafeFatal("Failed to generate random key: %v", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key)
|
||||
}
|
||||
|
||||
// GenerateToken generates a new JWT for a given user.
|
||||
func GenerateToken(user *model.User) (string, error) {
|
||||
func (jh *JWTHandler) GenerateToken(user *model.User) (string, error) {
|
||||
expirationTime := time.Now().Add(24 * time.Hour)
|
||||
claims := &Claims{
|
||||
UserID: user.ID.String(),
|
||||
@@ -62,10 +78,10 @@ func GenerateToken(user *model.User) (string, error) {
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(SecretKey)
|
||||
return token.SignedString(jh.SecretKey)
|
||||
}
|
||||
|
||||
func GenerateTokenWithExpiry(user *model.User, expiry time.Time) (string, error) {
|
||||
func (jh *JWTHandler) GenerateTokenWithExpiry(user *model.User, expiry time.Time) (string, error) {
|
||||
expirationTime := expiry
|
||||
claims := &Claims{
|
||||
UserID: user.ID.String(),
|
||||
@@ -75,15 +91,15 @@ func GenerateTokenWithExpiry(user *model.User, expiry time.Time) (string, error)
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(SecretKey)
|
||||
return token.SignedString(jh.SecretKey)
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT and returns the claims if the token is valid.
|
||||
func ValidateToken(tokenString string) (*Claims, error) {
|
||||
func (jh *JWTHandler) ValidateToken(tokenString string) (*Claims, error) {
|
||||
claims := &Claims{}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return SecretKey, nil
|
||||
return jh.SecretKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -91,7 +107,7 @@ func ValidateToken(tokenString string) (*Claims, error) {
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
return nil, goerrors.New("invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
|
||||
76
local/utl/security/download_verifier.go
Normal file
76
local/utl/security/download_verifier.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadVerifier struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewDownloadVerifier() *DownloadVerifier {
|
||||
return &DownloadVerifier{
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (dv *DownloadVerifier) VerifyAndDownload(url, outputPath, expectedSHA256 string) error {
|
||||
if url == "" {
|
||||
return fmt.Errorf("URL cannot be empty")
|
||||
}
|
||||
if outputPath == "" {
|
||||
return fmt.Errorf("output path cannot be empty")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "ACC-Server-Manager/1.0")
|
||||
|
||||
resp, err := dv.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create output file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
writer := io.MultiWriter(file, hash)
|
||||
|
||||
_, err = io.Copy(writer, resp.Body)
|
||||
if err != nil {
|
||||
os.Remove(outputPath)
|
||||
return fmt.Errorf("failed to write file: %v", err)
|
||||
}
|
||||
|
||||
if expectedSHA256 != "" {
|
||||
actualHash := fmt.Sprintf("%x", hash.Sum(nil))
|
||||
if actualHash != expectedSHA256 {
|
||||
os.Remove(outputPath)
|
||||
return fmt.Errorf("file hash mismatch: expected %s, got %s", expectedSHA256, actualHash)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
95
local/utl/security/path_validator.go
Normal file
95
local/utl/security/path_validator.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PathValidator struct {
|
||||
allowedBasePaths []string
|
||||
blockedPatterns []*regexp.Regexp
|
||||
}
|
||||
|
||||
func NewPathValidator() *PathValidator {
|
||||
blockedPatterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(`\.\.`),
|
||||
regexp.MustCompile(`[<>:"|?*]`),
|
||||
regexp.MustCompile(`^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])$`),
|
||||
regexp.MustCompile(`\x00`),
|
||||
regexp.MustCompile(`^\\\\`),
|
||||
regexp.MustCompile(`^[a-zA-Z]:\\Windows`),
|
||||
regexp.MustCompile(`^[a-zA-Z]:\\Program Files`),
|
||||
}
|
||||
|
||||
return &PathValidator{
|
||||
allowedBasePaths: []string{
|
||||
`C:\ACC-Servers`,
|
||||
`D:\ACC-Servers`,
|
||||
`E:\ACC-Servers`,
|
||||
`C:\SteamCMD`,
|
||||
`D:\SteamCMD`,
|
||||
`E:\SteamCMD`,
|
||||
},
|
||||
blockedPatterns: blockedPatterns,
|
||||
}
|
||||
}
|
||||
|
||||
func (pv *PathValidator) ValidateInstallPath(path string) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("path cannot be empty")
|
||||
}
|
||||
|
||||
cleanPath := filepath.Clean(path)
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid path: %v", err)
|
||||
}
|
||||
|
||||
for _, pattern := range pv.blockedPatterns {
|
||||
if pattern.MatchString(absPath) || pattern.MatchString(strings.ToUpper(filepath.Base(absPath))) {
|
||||
return fmt.Errorf("path contains forbidden patterns")
|
||||
}
|
||||
}
|
||||
|
||||
allowed := false
|
||||
for _, basePath := range pv.allowedBasePaths {
|
||||
if strings.HasPrefix(strings.ToLower(absPath), strings.ToLower(basePath)) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return fmt.Errorf("path must be within allowed directories: %v", pv.allowedBasePaths)
|
||||
}
|
||||
|
||||
if len(absPath) > 260 {
|
||||
return fmt.Errorf("path too long (max 260 characters)")
|
||||
}
|
||||
|
||||
parentDir := filepath.Dir(absPath)
|
||||
if parentInfo, err := os.Stat(parentDir); err == nil {
|
||||
if !parentInfo.IsDir() {
|
||||
return fmt.Errorf("parent path is not a directory")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pv *PathValidator) AddAllowedBasePath(path string) error {
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid base path: %v", err)
|
||||
}
|
||||
|
||||
pv.allowedBasePaths = append(pv.allowedBasePaths, absPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pv *PathValidator) GetAllowedBasePaths() []string {
|
||||
return append([]string(nil), pv.allowedBasePaths...)
|
||||
}
|
||||
@@ -30,6 +30,7 @@ func Start(di *dig.Container) *fiber.App {
|
||||
app.Use(securityMW.SecurityHeaders())
|
||||
app.Use(securityMW.LogSecurityEvents())
|
||||
app.Use(securityMW.TimeoutMiddleware(30 * time.Second))
|
||||
app.Use(securityMW.RequestContextTimeout(60 * time.Second))
|
||||
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) // 10MB
|
||||
app.Use(securityMW.ValidateUserAgent())
|
||||
app.Use(securityMW.ValidateContentType("application/json", "application/x-www-form-urlencoded", "multipart/form-data"))
|
||||
|
||||
Reference in New Issue
Block a user