From 4004d83411ae42ab80bf7242d97265325555d71a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=20Jurmanovi=C4=87?= Date: Thu, 18 Sep 2025 22:24:51 +0200 Subject: [PATCH] add step list for server creation --- local/controller/server.go | 36 -- local/controller/websocket.go | 32 +- local/middleware/access_key.go | 4 - local/middleware/auth.go | 29 -- local/middleware/logging/request_logging.go | 12 - local/middleware/security/security.go | 48 +- .../001_upgrade_password_security.go | 41 +- local/migrations/002_migrate_to_uuid.go | 21 - .../003_update_state_history_sessions.go | 18 - local/model/cache.go | 38 +- local/model/config.go | 6 +- local/model/filter.go | 18 +- local/model/lookup.go | 5 - local/model/model.go | 2 - local/model/permission.go | 4 +- local/model/permissions.go | 2 - local/model/role.go | 4 +- local/model/server.go | 30 +- local/model/service_control.go | 10 - local/model/state_history.go | 14 +- local/model/state_history_stats.go | 2 +- local/model/steam_credentials.go | 29 +- local/model/user.go | 12 - local/model/websocket.go | 9 - local/repository/base.go | 13 - local/repository/config.go | 4 +- local/repository/membership.go | 19 - local/repository/repository.go | 14 +- local/repository/server.go | 4 - local/repository/state_history.go | 17 +- local/service/config.go | 264 ++++------ local/service/firewall_service.go | 4 +- local/service/membership.go | 38 +- local/service/server.go | 479 +++++------------- local/service/service.go | 7 +- local/service/service_control.go | 19 +- local/service/service_manager.go | 11 +- local/service/state_history.go | 6 - local/service/steam_service.go | 298 +++++------ local/service/websocket.go | 39 +- local/service/windows_service.go | 21 - local/utl/cache/cache.go | 12 - local/utl/command/callback_executor.go | 245 +++++++++ local/utl/command/callbacks.go | 19 + local/utl/command/executor.go | 21 +- local/utl/command/interactive_executor.go | 300 +---------- local/utl/common/common.go | 15 +- local/utl/configs/configs.go | 6 +- local/utl/db/db.go | 5 - local/utl/env/env.go | 7 - .../error_handler/controller_error_handler.go | 45 +- local/utl/jwt/jwt.go | 9 - local/utl/logging/base.go | 21 - local/utl/logging/debug.go | 23 - local/utl/logging/error.go | 14 - local/utl/logging/info.go | 20 - local/utl/logging/logger.go | 31 -- local/utl/logging/warn.go | 14 - local/utl/network/port.go | 9 +- local/utl/password/password.go | 7 +- local/utl/server/server.go | 20 +- local/utl/tracking/log_tailer.go | 24 +- local/utl/tracking/position_tracker.go | 10 +- local/utl/tracking/tracking.go | 4 +- local/utl/websocket/websocket.go | 169 ++++++ swagger/docs.go | 2 - tests/auth_helper.go | 15 - tests/mocks/auth_middleware_mock.go | 11 +- tests/mocks/repository_mock.go | 13 - tests/mocks/state_history_mock.go | 39 +- tests/test_helper.go | 51 +- tests/testdata/state_history_data.go | 9 - .../unit/controller/controller_simple_test.go | 70 +-- tests/unit/controller/helper.go | 13 +- .../state_history_controller_test.go | 96 +--- .../state_history_repository_test.go | 71 +-- tests/unit/service/auth_simple_test.go | 46 +- tests/unit/service/cache_service_test.go | 120 +---- tests/unit/service/config_service_test.go | 109 ---- .../service/state_history_service_test.go | 106 +--- 80 files changed, 950 insertions(+), 2554 deletions(-) create mode 100644 local/utl/command/callback_executor.go create mode 100644 local/utl/command/callbacks.go create mode 100644 local/utl/websocket/websocket.go diff --git a/local/controller/server.go b/local/controller/server.go index 1101971..344b4cb 100644 --- a/local/controller/server.go +++ b/local/controller/server.go @@ -29,7 +29,6 @@ func NewServerController(ss *service.ServerService, routeGroups *common.RouteGro serverRoutes.Get("/", auth.HasPermission(model.ServerView), ac.GetAll) serverRoutes.Get("/:id", auth.HasPermission(model.ServerView), ac.GetById) serverRoutes.Post("/", auth.HasPermission(model.ServerCreate), ac.CreateServer) - serverRoutes.Put("/:id", auth.HasPermission(model.ServerUpdate), ac.UpdateServer) serverRoutes.Delete("/:id", auth.HasPermission(model.ServerDelete), ac.DeleteServer) apiServerRoutes := routeGroups.Api.Group("/server") @@ -152,41 +151,6 @@ func (ac *ServerController) CreateServer(c *fiber.Ctx) error { return c.JSON(server) } -// UpdateServer updates an existing server -// @Summary Update an ACC server -// @Description Update configuration for an existing ACC server -// @Tags Server -// @Accept json -// @Produce json -// @Param id path string true "Server ID (UUID format)" -// @Param server body model.Server true "Updated server configuration" -// @Success 200 {object} object "Updated server details" -// @Failure 400 {object} error_handler.ErrorResponse "Invalid server data or ID" -// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized" -// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions" -// @Failure 404 {object} error_handler.ErrorResponse "Server not found" -// @Failure 500 {object} error_handler.ErrorResponse "Internal server error" -// @Security BearerAuth -// @Router /server/{id} [put] -func (ac *ServerController) UpdateServer(c *fiber.Ctx) error { - serverIDStr := c.Params("id") - serverID, err := uuid.Parse(serverIDStr) - if err != nil { - return ac.errorHandler.HandleUUIDError(c, "server ID") - } - - server := new(model.Server) - if err := c.BodyParser(server); err != nil { - return ac.errorHandler.HandleParsingError(c, err) - } - server.ID = serverID - - if err := ac.service.UpdateServer(c, server); err != nil { - return ac.errorHandler.HandleServiceError(c, err) - } - return c.JSON(server) -} - // DeleteServer deletes an existing server // @Summary Delete an ACC server // @Description Delete an existing ACC server diff --git a/local/controller/websocket.go b/local/controller/websocket.go index 8ae5c47..1e10b51 100644 --- a/local/controller/websocket.go +++ b/local/controller/websocket.go @@ -17,7 +17,6 @@ type WebSocketController struct { jwtHandler *jwt.OpenJWTHandler } -// NewWebSocketController initializes WebSocketController func NewWebSocketController( wsService *service.WebSocketService, jwtHandler *jwt.OpenJWTHandler, @@ -29,7 +28,6 @@ func NewWebSocketController( jwtHandler: jwtHandler, } - // WebSocket routes wsRoutes := routeGroups.WebSocket wsRoutes.Use("/", wsc.upgradeWebSocket) wsRoutes.Get("/", websocket.New(wsc.handleWebSocket)) @@ -37,11 +35,8 @@ func NewWebSocketController( return wsc } -// upgradeWebSocket middleware to upgrade HTTP to WebSocket and validate authentication func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error { - // Check if it's a WebSocket upgrade request if websocket.IsWebSocketUpgrade(c) { - // Validate JWT token from query parameter or header token := c.Query("token") if token == "" { token = c.Get("Authorization") @@ -54,21 +49,18 @@ func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusUnauthorized, "Missing authentication token") } - // Validate the token claims, err := wsc.jwtHandler.ValidateToken(token) if err != nil { return fiber.NewError(fiber.StatusUnauthorized, "Invalid authentication token") } - // Parse UserID string to UUID userID, err := uuid.Parse(claims.UserID) if err != nil { return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID in token") } - // Store user info in context for use in WebSocket handler c.Locals("userID", userID) - c.Locals("username", claims.UserID) // Use UserID as username for now + c.Locals("username", claims.UserID) return c.Next() } @@ -76,12 +68,9 @@ func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusUpgradeRequired, "WebSocket upgrade required") } -// handleWebSocket handles WebSocket connections func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) { - // Generate a unique connection ID connID := uuid.New().String() - // Get user info from locals (set by middleware) userID, ok := c.Locals("userID").(uuid.UUID) if !ok { logging.Error("Failed to get user ID from WebSocket connection") @@ -92,16 +81,13 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) { username, _ := c.Locals("username").(string) logging.Info("WebSocket connection established for user: %s (ID: %s)", username, userID.String()) - // Add the connection to the service wsc.webSocketService.AddConnection(connID, c, &userID) - // Handle connection cleanup defer func() { wsc.webSocketService.RemoveConnection(connID) logging.Info("WebSocket connection closed for user: %s", username) }() - // Handle incoming messages from the client for { messageType, message, err := c.ReadMessage() if err != nil { @@ -111,14 +97,12 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) { break } - // Handle different message types switch messageType { case websocket.TextMessage: wsc.handleTextMessage(connID, userID, message) case websocket.BinaryMessage: logging.Debug("Received binary message from user %s (not supported)", username) case websocket.PingMessage: - // Respond with pong if err := c.WriteMessage(websocket.PongMessage, nil); err != nil { logging.Error("Failed to send pong to user %s: %v", username, err) break @@ -127,21 +111,11 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) { } } -// handleTextMessage processes text messages from the client func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUID, message []byte) { logging.Debug("Received WebSocket message from user %s: %s", userID.String(), string(message)) - // Parse the message to handle different types of client requests - // For now, we'll just log it. In the future, you might want to handle: - // - Subscription to specific server creation processes - // - Client heartbeat/keepalive - // - Request for status updates - - // Example: If the message contains a server ID, associate this connection with that server - // This is a simple implementation - you might want to use proper JSON parsing messageStr := string(message) if len(messageStr) > 10 && messageStr[:9] == "server_id" { - // Extract server ID from message like "server_id:uuid" if serverIDStr := messageStr[10:]; len(serverIDStr) > 0 { if serverID, err := uuid.Parse(serverIDStr); err == nil { wsc.webSocketService.SetServerID(connID, serverID) @@ -151,18 +125,14 @@ func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUI } } -// GetWebSocketUpgrade returns the WebSocket upgrade handler for use in other controllers func (wsc *WebSocketController) GetWebSocketUpgrade() fiber.Handler { return wsc.upgradeWebSocket } -// GetWebSocketHandler returns the WebSocket connection handler for use in other controllers func (wsc *WebSocketController) GetWebSocketHandler() func(*websocket.Conn) { return wsc.handleWebSocket } -// BroadcastServerCreationProgress is a helper method for other services to broadcast progress func (wsc *WebSocketController) BroadcastServerCreationProgress(serverID uuid.UUID, step string, status string, message string) { - // This can be used by the ServerService during server creation logging.Info("Broadcasting server creation progress: %s - %s: %s", serverID.String(), step, status) } diff --git a/local/middleware/access_key.go b/local/middleware/access_key.go index 3129639..7256b44 100644 --- a/local/middleware/access_key.go +++ b/local/middleware/access_key.go @@ -10,12 +10,10 @@ import ( "github.com/google/uuid" ) -// AccessKeyMiddleware provides authentication and permission middleware. type AccessKeyMiddleware struct { userInfo CachedUserInfo } -// NewAccessKeyMiddleware creates a new AccessKeyMiddleware. func NewAccessKeyMiddleware() *AccessKeyMiddleware { auth := &AccessKeyMiddleware{ userInfo: CachedUserInfo{UserID: uuid.New().String(), Username: "access_key", RoleName: "Admin", Permissions: map[string]bool{ @@ -25,9 +23,7 @@ func NewAccessKeyMiddleware() *AccessKeyMiddleware { return auth } -// Authenticate is a middleware for JWT authentication with enhanced security. func (m *AccessKeyMiddleware) Authenticate(ctx *fiber.Ctx) error { - // Log authentication attempt ip := ctx.IP() userAgent := ctx.Get("User-Agent") diff --git a/local/middleware/auth.go b/local/middleware/auth.go index db8fd73..467d966 100644 --- a/local/middleware/auth.go +++ b/local/middleware/auth.go @@ -16,7 +16,6 @@ import ( "github.com/google/uuid" ) -// CachedUserInfo holds cached user authentication and permission data type CachedUserInfo struct { UserID string Username string @@ -25,7 +24,6 @@ type CachedUserInfo struct { CachedAt time.Time } -// AuthMiddleware provides authentication and permission middleware. type AuthMiddleware struct { membershipService *service.MembershipService cache *cache.InMemoryCache @@ -34,7 +32,6 @@ type AuthMiddleware struct { openJWTHandler *jwt.OpenJWTHandler } -// NewAuthMiddleware creates a new AuthMiddleware. func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache, jwtHandler *jwt.JWTHandler, openJWTHandler *jwt.OpenJWTHandler) *AuthMiddleware { auth := &AuthMiddleware{ membershipService: ms, @@ -44,24 +41,20 @@ func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache openJWTHandler: openJWTHandler, } - // Set up bidirectional relationship for cache invalidation ms.SetCacheInvalidator(auth) return auth } -// Authenticate is a middleware for JWT authentication with enhanced security. func (m *AuthMiddleware) AuthenticateOpen(ctx *fiber.Ctx) error { return m.AuthenticateWithHandler(m.openJWTHandler.JWTHandler, true, ctx) } -// Authenticate is a middleware for JWT authentication with enhanced security. func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error { return m.AuthenticateWithHandler(m.jwtHandler, false, ctx) } func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isOpenToken bool, ctx *fiber.Ctx) error { - // Log authentication attempt ip := ctx.IP() userAgent := ctx.Get("User-Agent") @@ -89,7 +82,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO }) } - // Validate token length to prevent potential attacks token := parts[1] if len(token) < 10 || len(token) > 2048 { logging.Error("Authentication failed: invalid token length from IP %s", ip) @@ -113,7 +105,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO }) } - // Additional security: validate user ID format if claims.UserID == "" || len(claims.UserID) < 10 { logging.Error("Authentication failed: invalid user ID in token from IP %s", ip) return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ @@ -127,7 +118,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO ctx.Locals("userInfo", userInfo) ctx.Locals("authTime", time.Now()) } else { - // Preload and cache user info to avoid database queries on permission checks userInfo, err := m.getCachedUserInfo(ctx.UserContext(), claims.UserID) if err != nil { logging.Error("Authentication failed: unable to load user info for %s from IP %s: %v", claims.UserID, ip, err) @@ -145,7 +135,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO return ctx.Next() } -// HasPermission is a middleware for checking user permissions with enhanced logging. func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler { return func(ctx *fiber.Ctx) error { userID, ok := ctx.Locals("userID").(string) @@ -160,7 +149,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler return ctx.Next() } - // Validate permission parameter if requiredPermission == "" { logging.Error("Permission check failed: empty permission requirement") return ctx.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ @@ -168,7 +156,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler }) } - // Use cached user info from authentication step - no database queries needed userInfo, ok := ctx.Locals("userInfo").(*CachedUserInfo) if !ok { logging.Error("Permission check failed: no cached user info in context from IP %s", ctx.IP()) @@ -177,7 +164,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler }) } - // Check if user has permission using cached data has := m.hasPermissionFromCache(userInfo, requiredPermission) if !has { @@ -192,16 +178,13 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler } } -// AuthRateLimit applies rate limiting specifically for authentication endpoints func (m *AuthMiddleware) AuthRateLimit() fiber.Handler { return m.securityMW.AuthRateLimit() } -// RequireHTTPS redirects HTTP requests to HTTPS in production func (m *AuthMiddleware) RequireHTTPS() fiber.Handler { return func(ctx *fiber.Ctx) error { if ctx.Protocol() != "https" && ctx.Get("X-Forwarded-Proto") != "https" { - // Allow HTTP in development/testing if ctx.Hostname() != "localhost" && ctx.Hostname() != "127.0.0.1" { httpsURL := "https://" + ctx.Hostname() + ctx.OriginalURL() return ctx.Redirect(httpsURL, fiber.StatusMovedPermanently) @@ -211,11 +194,9 @@ func (m *AuthMiddleware) RequireHTTPS() fiber.Handler { } } -// getCachedUserInfo retrieves and caches complete user information including permissions func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (*CachedUserInfo, error) { cacheKey := fmt.Sprintf("userinfo:%s", userID) - // Try cache first if cached, found := m.cache.Get(cacheKey); found { if userInfo, ok := cached.(*CachedUserInfo); ok { logging.DebugWithContext("AUTH_CACHE", "User info for %s found in cache", userID) @@ -223,13 +204,11 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) ( } } - // Cache miss - load from database user, err := m.membershipService.GetUserWithPermissions(ctx, userID) if err != nil { return nil, err } - // Build permission map for fast lookups permissions := make(map[string]bool) for _, p := range user.Role.Permissions { permissions[p.Name] = true @@ -243,34 +222,26 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) ( CachedAt: time.Now(), } - // Cache for 15 minutes m.cache.Set(cacheKey, userInfo, 15*time.Minute) logging.DebugWithContext("AUTH_CACHE", "User info for %s cached with %d permissions", userID, len(permissions)) return userInfo, nil } -// hasPermissionFromCache checks permissions using cached user info (no database queries) func (m *AuthMiddleware) hasPermissionFromCache(userInfo *CachedUserInfo, permission string) bool { - // Super Admin and Admin have all permissions if userInfo.RoleName == "Super Admin" || userInfo.RoleName == "Admin" { return true } - // Check specific permission in cached map return userInfo.Permissions[permission] } -// InvalidateUserPermissions removes cached user info for a user func (m *AuthMiddleware) InvalidateUserPermissions(userID string) { cacheKey := fmt.Sprintf("userinfo:%s", userID) m.cache.Delete(cacheKey) logging.InfoWithContext("AUTH_CACHE", "User info cache invalidated for user %s", userID) } -// InvalidateAllUserPermissions clears all cached user info (useful for role/permission changes) func (m *AuthMiddleware) InvalidateAllUserPermissions() { - // This would need to be implemented based on your cache interface - // For now, just log that invalidation was requested logging.InfoWithContext("AUTH_CACHE", "All user info caches invalidation requested") } diff --git a/local/middleware/logging/request_logging.go b/local/middleware/logging/request_logging.go index 7c31e3b..8d92c4b 100644 --- a/local/middleware/logging/request_logging.go +++ b/local/middleware/logging/request_logging.go @@ -7,25 +7,20 @@ import ( "github.com/gofiber/fiber/v2" ) -// RequestLoggingMiddleware logs HTTP requests and responses type RequestLoggingMiddleware struct { infoLogger *logging.InfoLogger } -// NewRequestLoggingMiddleware creates a new request logging middleware func NewRequestLoggingMiddleware() *RequestLoggingMiddleware { return &RequestLoggingMiddleware{ infoLogger: logging.GetInfoLogger(), } } -// Handler returns the middleware handler function func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler { return func(c *fiber.Ctx) error { - // Record start time start := time.Now() - // Log incoming request userAgent := c.Get("User-Agent") if userAgent == "" { userAgent = "Unknown" @@ -33,17 +28,13 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler { rlm.infoLogger.LogRequest(c.Method(), c.OriginalURL(), userAgent) - // Continue to next handler err := c.Next() - // Calculate duration duration := time.Since(start) - // Log response statusCode := c.Response().StatusCode() rlm.infoLogger.LogResponse(c.Method(), c.OriginalURL(), statusCode, duration.String()) - // Log error if present if err != nil { logging.ErrorWithContext("REQUEST_MIDDLEWARE", "Request failed: %v", err) } @@ -52,10 +43,8 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler { } } -// Global request logging middleware instance var globalRequestLoggingMiddleware *RequestLoggingMiddleware -// GetRequestLoggingMiddleware returns the global request logging middleware func GetRequestLoggingMiddleware() *RequestLoggingMiddleware { if globalRequestLoggingMiddleware == nil { globalRequestLoggingMiddleware = NewRequestLoggingMiddleware() @@ -63,7 +52,6 @@ func GetRequestLoggingMiddleware() *RequestLoggingMiddleware { return globalRequestLoggingMiddleware } -// Handler returns the global request logging middleware handler func Handler() fiber.Handler { return GetRequestLoggingMiddleware().Handler() } diff --git a/local/middleware/security/security.go b/local/middleware/security/security.go index 00899c4..78a9182 100644 --- a/local/middleware/security/security.go +++ b/local/middleware/security/security.go @@ -11,19 +11,16 @@ import ( "github.com/gofiber/fiber/v2" ) -// RateLimiter stores rate limiting information type RateLimiter struct { requests map[string][]time.Time mutex sync.RWMutex } -// NewRateLimiter creates a new rate limiter func NewRateLimiter() *RateLimiter { rl := &RateLimiter{ requests: make(map[string][]time.Time), } - // Use graceful shutdown for cleanup goroutine shutdownManager := graceful.GetManager() shutdownManager.RunGoroutine(func(ctx context.Context) { rl.cleanupWithContext(ctx) @@ -32,7 +29,6 @@ func NewRateLimiter() *RateLimiter { return rl } -// cleanup removes old entries from the rate limiter func (rl *RateLimiter) cleanupWithContext(ctx context.Context) { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() @@ -62,44 +58,29 @@ func (rl *RateLimiter) cleanupWithContext(ctx context.Context) { } } -// SecurityMiddleware provides comprehensive security middleware type SecurityMiddleware struct { rateLimiter *RateLimiter } -// NewSecurityMiddleware creates a new security middleware func NewSecurityMiddleware() *SecurityMiddleware { return &SecurityMiddleware{ rateLimiter: NewRateLimiter(), } } -// SecurityHeaders adds security headers to responses func (sm *SecurityMiddleware) SecurityHeaders() fiber.Handler { return func(c *fiber.Ctx) error { - // Prevent MIME type sniffing c.Set("X-Content-Type-Options", "nosniff") - - // Prevent clickjacking c.Set("X-Frame-Options", "DENY") - - // Enable XSS protection c.Set("X-XSS-Protection", "1; mode=block") - - // Prevent referrer leakage c.Set("Referrer-Policy", "strict-origin-when-cross-origin") - - // Content Security Policy c.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none'") - - // Permissions Policy c.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), interest-cohort=()") return c.Next() } } -// RateLimit implements rate limiting for API endpoints func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) fiber.Handler { return func(c *fiber.Ctx) error { ip := c.IP() @@ -111,7 +92,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) now := time.Now() requests := sm.rateLimiter.requests[key] - // Remove requests older than duration filtered := make([]time.Time, 0, len(requests)) for _, t := range requests { if now.Sub(t) < duration { @@ -119,7 +99,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) } } - // Check if limit is exceeded if len(filtered) >= maxRequests { return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{ "error": "Rate limit exceeded", @@ -127,7 +106,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) }) } - // Add current request filtered = append(filtered, now) sm.rateLimiter.requests[key] = filtered @@ -135,7 +113,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) } } -// AuthRateLimit implements stricter rate limiting for authentication endpoints func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { return func(c *fiber.Ctx) error { ip := c.IP() @@ -148,7 +125,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { now := time.Now() requests := sm.rateLimiter.requests[key] - // Remove requests older than 15 minutes filtered := make([]time.Time, 0, len(requests)) for _, t := range requests { if now.Sub(t) < 15*time.Minute { @@ -156,7 +132,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { } } - // Check if limit is exceeded (5 requests per 15 minutes for auth) if len(filtered) >= 5 { return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{ "error": "Too many authentication attempts", @@ -164,7 +139,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { }) } - // Add current request filtered = append(filtered, now) sm.rateLimiter.requests[key] = filtered @@ -172,20 +146,16 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { } } -// InputSanitization sanitizes user input to prevent XSS and injection attacks func (sm *SecurityMiddleware) InputSanitization() fiber.Handler { return func(c *fiber.Ctx) error { - // Sanitize query parameters c.Request().URI().QueryArgs().VisitAll(func(key, value []byte) { sanitized := sanitizeInput(string(value)) c.Request().URI().QueryArgs().Set(string(key), sanitized) }) - // Store original body for processing if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" { body := c.Body() if len(body) > 0 { - // Basic sanitization - remove potentially dangerous patterns sanitized := sanitizeInput(string(body)) c.Request().SetBodyString(sanitized) } @@ -195,7 +165,6 @@ func (sm *SecurityMiddleware) InputSanitization() fiber.Handler { } } -// sanitizeInput removes potentially dangerous patterns from input func sanitizeInput(input string) string { dangerous := []string{ " 100 || containsBinaryData(plainPassword) { return fmt.Errorf("password appears to be corrupted encrypted data") } @@ -144,7 +122,6 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u plainPassword = decrypted } - // Validate plain password if plainPassword == "" { return errors.New("decrypted password is empty") } @@ -153,13 +130,11 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u return errors.New("password too short after decryption") } - // Hash the plain password using bcrypt hashedPassword, err := password.HashPassword(plainPassword) if err != nil { return fmt.Errorf("failed to hash password: %v", err) } - // Update with hashed password if err := tx.Model(user).Update("password", hashedPassword).Error; err != nil { return fmt.Errorf("failed to update password: %v", err) } @@ -168,19 +143,16 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u return nil } -// UserForMigration represents a user record for migration purposes type UserForMigration struct { ID string `gorm:"column:id"` Username string `gorm:"column:username"` Password string `gorm:"column:password"` } -// TableName specifies the table name for GORM func (UserForMigration) TableName() string { return "users" } -// MigrationRecord tracks applied migrations type MigrationRecord struct { ID uint `gorm:"primaryKey"` MigrationName string `gorm:"unique;not null"` @@ -189,49 +161,38 @@ type MigrationRecord struct { Notes string } -// TableName specifies the table name for GORM func (MigrationRecord) TableName() string { return "migration_records" } -// isAlreadyHashed checks if a password is already bcrypt hashed func isAlreadyHashed(password string) bool { return len(password) >= 60 && (password[:4] == "$2a$" || password[:4] == "$2b$" || password[:4] == "$2y$") } -// containsBinaryData checks if a string contains binary data func containsBinaryData(s string) bool { for _, b := range []byte(s) { - if b < 32 && b != 9 && b != 10 && b != 13 { // Allow tab, newline, carriage return + if b < 32 && b != 9 && b != 10 && b != 13 { return true } } return false } -// isDuplicateColumnError checks if an error is due to duplicate column func isDuplicateColumnError(err error) bool { errStr := err.Error() return fmt.Sprintf("%v", errStr) == "duplicate column name: password_backup" || fmt.Sprintf("%v", errStr) == "SQLITE_ERROR: duplicate column name: password_backup" } -// decryptOldPassword attempts to decrypt using the old encryption method -// This is a simplified version of the old DecryptPassword function func decryptOldPassword(encryptedPassword string) (string, error) { - // This would use the old decryption logic - // For now, we'll return an error to force treating as plain text - // In a real scenario, you'd implement the old decryption here return "", errors.New("old decryption not implemented - treating as plain text") } -// Down reverses the migration (if needed) func (m *Migration001UpgradePasswordSecurity) Down() error { logging.Error("Password security migration rollback is not supported for security reasons") return errors.New("password security migration rollback is not supported") } -// RunMigration is a convenience function to run the migration func RunPasswordSecurityMigration(db *gorm.DB) error { migration := NewMigration001UpgradePasswordSecurity(db) return migration.Up() diff --git a/local/migrations/002_migrate_to_uuid.go b/local/migrations/002_migrate_to_uuid.go index 7cf915d..2387dbd 100644 --- a/local/migrations/002_migrate_to_uuid.go +++ b/local/migrations/002_migrate_to_uuid.go @@ -9,21 +9,17 @@ import ( "gorm.io/gorm" ) -// Migration002MigrateToUUID migrates tables from integer IDs to UUIDs type Migration002MigrateToUUID struct { DB *gorm.DB } -// NewMigration002MigrateToUUID creates a new UUID migration func NewMigration002MigrateToUUID(db *gorm.DB) *Migration002MigrateToUUID { return &Migration002MigrateToUUID{DB: db} } -// Up executes the migration func (m *Migration002MigrateToUUID) Up() error { logging.Info("Checking UUID migration...") - // Check if migration is needed by looking at the servers table structure if !m.needsMigration() { logging.Info("UUID migration not needed - tables already use UUID primary keys") return nil @@ -31,7 +27,6 @@ func (m *Migration002MigrateToUUID) Up() error { logging.Info("Starting UUID migration...") - // Check if migration has already been applied var migrationRecord MigrationRecord err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error if err == nil { @@ -39,12 +34,10 @@ func (m *Migration002MigrateToUUID) Up() error { return nil } - // Create migration tracking table if it doesn't exist if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil { return fmt.Errorf("failed to create migration tracking table: %v", err) } - // Execute the UUID migration using the existing migration function logging.Info("Executing UUID migration...") if err := runUUIDMigrationSQL(m.DB); err != nil { return fmt.Errorf("failed to execute UUID migration: %v", err) @@ -54,9 +47,7 @@ func (m *Migration002MigrateToUUID) Up() error { return nil } -// needsMigration checks if the UUID migration is needed by examining table structure func (m *Migration002MigrateToUUID) needsMigration() bool { - // Check if servers table exists and has integer primary key var result struct { Type string `gorm:"column:type"` } @@ -67,29 +58,22 @@ func (m *Migration002MigrateToUUID) needsMigration() bool { `).Scan(&result).Error if err != nil || result.Type == "" { - // Table doesn't exist or no primary key found - assume no migration needed return false } - // If the primary key is INTEGER, we need migration - // If it's TEXT (UUID), migration already done return result.Type == "INTEGER" || result.Type == "integer" } -// Down reverses the migration (not implemented for safety) func (m *Migration002MigrateToUUID) Down() error { logging.Error("UUID migration rollback is not supported for data safety reasons") return fmt.Errorf("UUID migration rollback is not supported") } -// runUUIDMigrationSQL executes the UUID migration using the SQL file func runUUIDMigrationSQL(db *gorm.DB) error { - // Disable foreign key constraints during migration if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil { return fmt.Errorf("failed to disable foreign keys: %v", err) } - // Start transaction tx := db.Begin() if tx.Error != nil { return fmt.Errorf("failed to start transaction: %v", tx.Error) @@ -101,25 +85,21 @@ func runUUIDMigrationSQL(db *gorm.DB) error { } }() - // Read the migration SQL from file sqlPath := filepath.Join("scripts", "migrations", "002_migrate_servers_to_uuid.sql") migrationSQL, err := ioutil.ReadFile(sqlPath) if err != nil { return fmt.Errorf("failed to read migration SQL file: %v", err) } - // Execute the migration if err := tx.Exec(string(migrationSQL)).Error; err != nil { tx.Rollback() return fmt.Errorf("failed to execute migration: %v", err) } - // Commit transaction if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit migration: %v", err) } - // Re-enable foreign key constraints if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil { return fmt.Errorf("failed to re-enable foreign keys: %v", err) } @@ -127,7 +107,6 @@ func runUUIDMigrationSQL(db *gorm.DB) error { return nil } -// RunUUIDMigration is a convenience function to run the migration func RunUUIDMigration(db *gorm.DB) error { migration := NewMigration002MigrateToUUID(db) return migration.Up() diff --git a/local/migrations/003_update_state_history_sessions.go b/local/migrations/003_update_state_history_sessions.go index 0092c22..3fd53bc 100644 --- a/local/migrations/003_update_state_history_sessions.go +++ b/local/migrations/003_update_state_history_sessions.go @@ -7,21 +7,17 @@ import ( "gorm.io/gorm" ) -// UpdateStateHistorySessions migrates tables from integer IDs to UUIDs type UpdateStateHistorySessions struct { DB *gorm.DB } -// NewUpdateStateHistorySessions creates a new UUID migration func NewUpdateStateHistorySessions(db *gorm.DB) *UpdateStateHistorySessions { return &UpdateStateHistorySessions{DB: db} } -// Up executes the migration func (m *UpdateStateHistorySessions) Up() error { logging.Info("Checking UUID migration...") - // Check if migration is needed by looking at the servers table structure if !m.needsMigration() { logging.Info("UUID migration not needed - tables already use UUID primary keys") return nil @@ -29,7 +25,6 @@ func (m *UpdateStateHistorySessions) Up() error { logging.Info("Starting UUID migration...") - // Check if migration has already been applied var migrationRecord MigrationRecord err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error if err == nil { @@ -37,12 +32,10 @@ func (m *UpdateStateHistorySessions) Up() error { return nil } - // Create migration tracking table if it doesn't exist if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil { return fmt.Errorf("failed to create migration tracking table: %v", err) } - // Execute the UUID migration using the existing migration function logging.Info("Executing UUID migration...") if err := runUUIDMigrationSQL(m.DB); err != nil { return fmt.Errorf("failed to execute UUID migration: %v", err) @@ -52,9 +45,7 @@ func (m *UpdateStateHistorySessions) Up() error { return nil } -// needsMigration checks if the UUID migration is needed by examining table structure func (m *UpdateStateHistorySessions) needsMigration() bool { - // Check if servers table exists and has integer primary key var result struct { Exists bool `gorm:"column:exists"` } @@ -65,26 +56,21 @@ func (m *UpdateStateHistorySessions) needsMigration() bool { `).Scan(&result).Error if err != nil || !result.Exists { - // Table doesn't exist or no primary key found - assume no migration needed return false } return result.Exists } -// Down reverses the migration (not implemented for safety) func (m *UpdateStateHistorySessions) Down() error { logging.Error("UUID migration rollback is not supported for data safety reasons") return fmt.Errorf("UUID migration rollback is not supported") } -// runUpdateStateHistorySessionsMigration executes the UUID migration using the SQL file func runUpdateStateHistorySessionsMigration(db *gorm.DB) error { - // Disable foreign key constraints during migration if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil { return fmt.Errorf("failed to disable foreign keys: %v", err) } - // Start transaction tx := db.Begin() if tx.Error != nil { return fmt.Errorf("failed to start transaction: %v", tx.Error) @@ -98,18 +84,15 @@ func runUpdateStateHistorySessionsMigration(db *gorm.DB) error { migrationSQL := "UPDATE state_history SET session = upper(substr(session, 1, 1));" - // Execute the migration if err := tx.Exec(string(migrationSQL)).Error; err != nil { tx.Rollback() return fmt.Errorf("failed to execute migration: %v", err) } - // Commit transaction if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit migration: %v", err) } - // Re-enable foreign key constraints if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil { return fmt.Errorf("failed to re-enable foreign keys: %v", err) } @@ -117,7 +100,6 @@ func runUpdateStateHistorySessionsMigration(db *gorm.DB) error { return nil } -// RunUpdateStateHistorySessionsMigration is a convenience function to run the migration func RunUpdateStateHistorySessionsMigration(db *gorm.DB) error { migration := NewUpdateStateHistorySessions(db) return migration.Up() diff --git a/local/model/cache.go b/local/model/cache.go index 5b1a1eb..989d57e 100644 --- a/local/model/cache.go +++ b/local/model/cache.go @@ -6,20 +6,17 @@ import ( "time" ) -// StatusCache represents a cached server status with expiration type StatusCache struct { Status ServiceStatus UpdatedAt time.Time } -// CacheConfig holds configuration for cache behavior type CacheConfig struct { - ExpirationTime time.Duration // How long before a cache entry expires - ThrottleTime time.Duration // Minimum time between status checks - DefaultStatus ServiceStatus // Default status to return when throttled + ExpirationTime time.Duration + ThrottleTime time.Duration + DefaultStatus ServiceStatus } -// ServerStatusCache manages cached server statuses type ServerStatusCache struct { sync.RWMutex cache map[string]*StatusCache @@ -27,7 +24,6 @@ type ServerStatusCache struct { lastChecked map[string]time.Time } -// NewServerStatusCache creates a new server status cache func NewServerStatusCache(config CacheConfig) *ServerStatusCache { return &ServerStatusCache{ cache: make(map[string]*StatusCache), @@ -36,12 +32,10 @@ func NewServerStatusCache(config CacheConfig) *ServerStatusCache { } } -// GetStatus retrieves the cached status or indicates if a fresh check is needed func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) { c.RLock() defer c.RUnlock() - // Check if we're being throttled if lastCheck, exists := c.lastChecked[serviceName]; exists { if time.Since(lastCheck) < c.config.ThrottleTime { if cached, ok := c.cache[serviceName]; ok { @@ -51,7 +45,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) } } - // Check if we have a valid cached entry if cached, ok := c.cache[serviceName]; ok { if time.Since(cached.UpdatedAt) < c.config.ExpirationTime { return cached.Status, false @@ -61,7 +54,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) return StatusUnknown, true } -// UpdateStatus updates the cache with a new status func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatus) { c.Lock() defer c.Unlock() @@ -73,7 +65,6 @@ func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatu c.lastChecked[serviceName] = time.Now() } -// InvalidateStatus removes a specific service from the cache func (c *ServerStatusCache) InvalidateStatus(serviceName string) { c.Lock() defer c.Unlock() @@ -82,7 +73,6 @@ func (c *ServerStatusCache) InvalidateStatus(serviceName string) { delete(c.lastChecked, serviceName) } -// Clear removes all entries from the cache func (c *ServerStatusCache) Clear() { c.Lock() defer c.Unlock() @@ -91,13 +81,11 @@ func (c *ServerStatusCache) Clear() { c.lastChecked = make(map[string]time.Time) } -// LookupCache provides a generic cache for lookup data type LookupCache struct { sync.RWMutex data map[string]interface{} } -// NewLookupCache creates a new lookup cache func NewLookupCache() *LookupCache { logging.Debug("Initializing new LookupCache") return &LookupCache{ @@ -105,7 +93,6 @@ func NewLookupCache() *LookupCache { } } -// Get retrieves a cached value by key func (c *LookupCache) Get(key string) (interface{}, bool) { c.RLock() defer c.RUnlock() @@ -119,7 +106,6 @@ func (c *LookupCache) Get(key string) (interface{}, bool) { return value, exists } -// Set stores a value in the cache func (c *LookupCache) Set(key string, value interface{}) { c.Lock() defer c.Unlock() @@ -128,7 +114,6 @@ func (c *LookupCache) Set(key string, value interface{}) { logging.Debug("Cache SET for key: %s", key) } -// Clear removes all entries from the cache func (c *LookupCache) Clear() { c.Lock() defer c.Unlock() @@ -137,13 +122,11 @@ func (c *LookupCache) Clear() { logging.Debug("Cache CLEARED") } -// ConfigEntry represents a cached configuration entry with its update time type ConfigEntry[T any] struct { Data T UpdatedAt time.Time } -// getConfigFromCache is a generic helper function to retrieve cached configs func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string, expirationTime time.Duration) (*T, bool) { if entry, ok := cache[serverID]; ok { if time.Since(entry.UpdatedAt) < expirationTime { @@ -157,7 +140,6 @@ func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string return nil, false } -// updateConfigInCache is a generic helper function to update cached configs func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID string, data T) { cache[serverID] = &ConfigEntry[T]{ Data: data, @@ -166,7 +148,6 @@ func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID strin logging.Debug("Config cache SET for server ID: %s", serverID) } -// ServerConfigCache manages cached server configurations type ServerConfigCache struct { sync.RWMutex configuration map[string]*ConfigEntry[Configuration] @@ -177,7 +158,6 @@ type ServerConfigCache struct { config CacheConfig } -// NewServerConfigCache creates a new server configuration cache func NewServerConfigCache(config CacheConfig) *ServerConfigCache { logging.Debug("Initializing new ServerConfigCache with expiration time: %v, throttle time: %v", config.ExpirationTime, config.ThrottleTime) return &ServerConfigCache{ @@ -190,7 +170,6 @@ func NewServerConfigCache(config CacheConfig) *ServerConfigCache { } } -// GetConfiguration retrieves a cached configuration func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, bool) { c.RLock() defer c.RUnlock() @@ -198,7 +177,6 @@ func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, b return getConfigFromCache(c.configuration, serverID, c.config.ExpirationTime) } -// GetAssistRules retrieves cached assist rules func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool) { c.RLock() defer c.RUnlock() @@ -206,7 +184,6 @@ func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool) return getConfigFromCache(c.assistRules, serverID, c.config.ExpirationTime) } -// GetEvent retrieves cached event configuration func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) { c.RLock() defer c.RUnlock() @@ -214,7 +191,6 @@ func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) { return getConfigFromCache(c.event, serverID, c.config.ExpirationTime) } -// GetEventRules retrieves cached event rules func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) { c.RLock() defer c.RUnlock() @@ -222,7 +198,6 @@ func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) { return getConfigFromCache(c.eventRules, serverID, c.config.ExpirationTime) } -// GetSettings retrieves cached server settings func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool) { c.RLock() defer c.RUnlock() @@ -230,7 +205,6 @@ func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool) return getConfigFromCache(c.settings, serverID, c.config.ExpirationTime) } -// UpdateConfiguration updates the configuration cache func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configuration) { c.Lock() defer c.Unlock() @@ -238,7 +212,6 @@ func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configur updateConfigInCache(c.configuration, serverID, config) } -// UpdateAssistRules updates the assist rules cache func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules) { c.Lock() defer c.Unlock() @@ -246,7 +219,6 @@ func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules updateConfigInCache(c.assistRules, serverID, rules) } -// UpdateEvent updates the event configuration cache func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) { c.Lock() defer c.Unlock() @@ -254,7 +226,6 @@ func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) { updateConfigInCache(c.event, serverID, event) } -// UpdateEventRules updates the event rules cache func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules) { c.Lock() defer c.Unlock() @@ -262,7 +233,6 @@ func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules) updateConfigInCache(c.eventRules, serverID, rules) } -// UpdateSettings updates the server settings cache func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSettings) { c.Lock() defer c.Unlock() @@ -270,7 +240,6 @@ func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSetti updateConfigInCache(c.settings, serverID, settings) } -// InvalidateServerCache removes all cached configurations for a specific server func (c *ServerConfigCache) InvalidateServerCache(serverID string) { c.Lock() defer c.Unlock() @@ -283,7 +252,6 @@ func (c *ServerConfigCache) InvalidateServerCache(serverID string) { delete(c.settings, serverID) } -// Clear removes all entries from the cache func (c *ServerConfigCache) Clear() { c.Lock() defer c.Unlock() diff --git a/local/model/config.go b/local/model/config.go index 3833e78..8ea2306 100644 --- a/local/model/config.go +++ b/local/model/config.go @@ -13,17 +13,15 @@ import ( type IntString int type IntBool int -// Config tracks configuration modifications type Config struct { ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"` ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"` - ConfigFile string `json:"configFile" gorm:"not null"` // e.g. "settings.json" + ConfigFile string `json:"configFile" gorm:"not null"` OldConfig string `json:"oldConfig" gorm:"type:text"` NewConfig string `json:"newConfig" gorm:"type:text"` ChangedAt time.Time `json:"changedAt" gorm:"default:CURRENT_TIMESTAMP"` } -// BeforeCreate is a GORM hook that runs before creating new config entries func (c *Config) BeforeCreate(tx *gorm.DB) error { if c.ID == uuid.Nil { c.ID = uuid.New() @@ -121,8 +119,6 @@ type Configuration struct { ConfigVersion IntString `json:"configVersion"` } -// Known configuration keys - func (i *IntBool) UnmarshalJSON(b []byte) error { var str int if err := json.Unmarshal(b, &str); err == nil && str <= 1 { diff --git a/local/model/filter.go b/local/model/filter.go index 13ef18c..b9df064 100644 --- a/local/model/filter.go +++ b/local/model/filter.go @@ -7,7 +7,6 @@ import ( "gorm.io/gorm" ) -// BaseFilter contains common filter fields that can be embedded in other filters type BaseFilter struct { Page int `query:"page"` PageSize int `query:"page_size"` @@ -15,18 +14,15 @@ type BaseFilter struct { SortDesc bool `query:"sort_desc"` } -// DateRangeFilter adds date range filtering capabilities type DateRangeFilter struct { StartDate time.Time `query:"start_date" time_format:"2006-01-02T15:04:05Z07:00"` EndDate time.Time `query:"end_date" time_format:"2006-01-02T15:04:05Z07:00"` } -// ServerBasedFilter adds server ID filtering capability type ServerBasedFilter struct { ServerID string `param:"id"` } -// ConfigFilter defines filtering options for Config queries type ConfigFilter struct { BaseFilter ServerBasedFilter @@ -34,13 +30,11 @@ type ConfigFilter struct { ChangedAt time.Time `query:"changed_at" time_format:"2006-01-02T15:04:05Z07:00"` } -// ApiFilter defines filtering options for Api queries type ServiceControlFilter struct { BaseFilter ServiceControl string `query:"serviceControl"` } -// MembershipFilter defines filtering options for User queries type MembershipFilter struct { BaseFilter Username string `query:"username"` @@ -48,36 +42,32 @@ type MembershipFilter struct { RoleID string `query:"role_id"` } -// Pagination returns the offset and limit for database queries func (f *BaseFilter) Pagination() (offset, limit int) { if f.Page < 1 { f.Page = 1 } if f.PageSize < 1 { - f.PageSize = 10 // Default page size + f.PageSize = 10 } offset = (f.Page - 1) * f.PageSize limit = f.PageSize return } -// GetSorting returns the sort field and direction for database queries func (f *BaseFilter) GetSorting() (field string, desc bool) { if f.SortBy == "" { - return "id", false // Default sorting + return "id", false } return f.SortBy, f.SortDesc } -// IsDateRangeValid checks if both dates are set and start date is before end date func (f *DateRangeFilter) IsDateRangeValid() bool { if f.StartDate.IsZero() || f.EndDate.IsZero() { - return true // If either date is not set, consider it valid + return true } return f.StartDate.Before(f.EndDate) } -// ApplyFilter applies the membership filter to a GORM query func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB { if f.Username != "" { query = query.Where("username LIKE ?", "%"+f.Username+"%") @@ -93,12 +83,10 @@ func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB { return query } -// Pagination returns the offset and limit for database queries func (f *MembershipFilter) Pagination() (offset, limit int) { return f.BaseFilter.Pagination() } -// GetSorting returns the sort field and direction for database queries func (f *MembershipFilter) GetSorting() (field string, desc bool) { return f.BaseFilter.GetSorting() } diff --git a/local/model/lookup.go b/local/model/lookup.go index 4f303cc..929a4f1 100644 --- a/local/model/lookup.go +++ b/local/model/lookup.go @@ -1,31 +1,26 @@ package model -// Track represents a track and its capacity type Track struct { Name string `json:"track" gorm:"primaryKey;size:50"` UniquePitBoxes int `json:"unique_pit_boxes"` PrivateServerSlots int `json:"private_server_slots"` } -// CarModel represents a car model mapping type CarModel struct { Value int `json:"value" gorm:"primaryKey"` CarModel string `json:"car_model"` } -// DriverCategory represents driver skill categories type DriverCategory struct { Value int `json:"value" gorm:"primaryKey"` Category string `json:"category"` } -// CupCategory represents championship cup categories type CupCategory struct { Value int `json:"value" gorm:"primaryKey"` Category string `json:"category"` } -// SessionType represents session types type SessionType struct { Value int `json:"value" gorm:"primaryKey"` SessionType string `json:"session_type"` diff --git a/local/model/model.go b/local/model/model.go index 3aadae4..8b13fd4 100644 --- a/local/model/model.go +++ b/local/model/model.go @@ -32,8 +32,6 @@ type BaseModel struct { DateUpdated time.Time `json:"dateUpdated"` } -// Init -// Initializes base model with DateCreated, DateUpdated, and Id values. func (cm *BaseModel) Init() { date := time.Now() cm.Id = uuid.NewString() diff --git a/local/model/permission.go b/local/model/permission.go index 3e3022d..88240f9 100644 --- a/local/model/permission.go +++ b/local/model/permission.go @@ -5,15 +5,13 @@ import ( "gorm.io/gorm" ) -// Permission represents an action that can be performed in the system. type Permission struct { ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"` Name string `json:"name" gorm:"unique_index;not null"` } -// BeforeCreate is a GORM hook that runs before creating new credentials func (s *Permission) BeforeCreate(tx *gorm.DB) error { s.ID = uuid.New() return nil -} \ No newline at end of file +} diff --git a/local/model/permissions.go b/local/model/permissions.go index b96eff0..d57c097 100644 --- a/local/model/permissions.go +++ b/local/model/permissions.go @@ -1,6 +1,5 @@ package model -// Permission constants const ( ServerView = "server.view" ServerCreate = "server.create" @@ -27,7 +26,6 @@ const ( MembershipEdit = "membership.edit" ) -// AllPermissions returns a slice of all permission strings. func AllPermissions() []string { return []string{ ServerView, diff --git a/local/model/role.go b/local/model/role.go index 9d48842..836393c 100644 --- a/local/model/role.go +++ b/local/model/role.go @@ -5,16 +5,14 @@ import ( "gorm.io/gorm" ) -// Role represents a user role in the system. type Role struct { ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"` Name string `json:"name" gorm:"unique_index;not null"` Permissions []Permission `json:"permissions" gorm:"many2many:role_permissions;"` } -// BeforeCreate is a GORM hook that runs before creating new credentials func (s *Role) BeforeCreate(tx *gorm.DB) error { s.ID = uuid.New() return nil -} \ No newline at end of file +} diff --git a/local/model/server.go b/local/model/server.go index 95a9e8c..bab75d6 100644 --- a/local/model/server.go +++ b/local/model/server.go @@ -16,7 +16,6 @@ const ( ServiceNamePrefix = "ACC-Server" ) -// Server represents an ACC server instance type ServerAPI struct { Name string `json:"name"` Status ServiceStatus `json:"status"` @@ -35,28 +34,27 @@ func (s *Server) ToServerAPI() *ServerAPI { } } -// Server represents an ACC server instance type Server struct { ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"` Name string `gorm:"not null" json:"name"` Status ServiceStatus `json:"status" gorm:"-"` IP string `gorm:"not null" json:"-"` Port int `gorm:"not null" json:"-"` - Path string `gorm:"not null" json:"path"` // e.g. "/acc/servers/server1/" - ServiceName string `gorm:"not null" json:"serviceName"` // Windows service name + Path string `gorm:"not null" json:"path"` + ServiceName string `gorm:"not null" json:"serviceName"` State *ServerState `gorm:"-" json:"state"` DateCreated time.Time `json:"dateCreated"` FromSteamCMD bool `gorm:"not null; default:true" json:"-"` } type PlayerState struct { - CarID int // Car ID in broadcast packets - DriverName string // Optional: pulled from registration packet + CarID int + DriverName string TeamName string CarModel string CurrentLap int - LastLapTime int // in milliseconds - BestLapTime int // in milliseconds + LastLapTime int + BestLapTime int Position int ConnectedAt time.Time DisconnectedAt *time.Time @@ -67,8 +65,6 @@ type State struct { Session string `json:"session"` SessionStart time.Time `json:"sessionStart"` PlayerCount int `json:"playerCount"` - // Players map[int]*PlayerState - // etc. } type ServerState struct { @@ -79,11 +75,8 @@ type ServerState struct { Track string `json:"track"` MaxConnections int `json:"maxConnections"` SessionDurationMinutes int `json:"sessionDurationMinutes"` - // Players map[int]*PlayerState - // etc. } -// ServerFilter defines filtering options for Server queries type ServerFilter struct { BaseFilter ServerBasedFilter @@ -92,9 +85,7 @@ type ServerFilter struct { Status string `query:"status"` } -// ApplyFilter implements the Filterable interface func (f *ServerFilter) ApplyFilter(query *gorm.DB) *gorm.DB { - // Apply server filter if f.ServerID != "" { if serverUUID, err := uuid.Parse(f.ServerID); err == nil { query = query.Where("id = ?", serverUUID) @@ -110,16 +101,13 @@ func (s *Server) GenerateUUID() { } } -// BeforeCreate is a GORM hook that runs before creating a new server func (s *Server) BeforeCreate(tx *gorm.DB) error { if s.Name == "" { return errors.New("server name is required") } - // Generate UUID if not set s.GenerateUUID() - // Generate service name and config path if not set if s.ServiceName == "" { s.ServiceName = s.GenerateServiceName() } @@ -127,7 +115,6 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error { s.Path = s.GenerateServerPath(BaseServerPath) } - // Set creation date if not set if s.DateCreated.IsZero() { s.DateCreated = time.Now().UTC() } @@ -135,19 +122,14 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error { return nil } -// GenerateServiceName creates a unique service name based on the server name func (s *Server) GenerateServiceName() string { - // If ID is set, use it if s.ID != uuid.Nil { return fmt.Sprintf("%s-%s", ServiceNamePrefix, s.ID.String()[:8]) } - // Otherwise use a timestamp-based unique identifier return fmt.Sprintf("%s-%d", ServiceNamePrefix, time.Now().UnixNano()) } -// GenerateServerPath creates the config path based on the service name func (s *Server) GenerateServerPath(steamCMDPath string) string { - // Ensure service name is set if s.ServiceName == "" { s.ServiceName = s.GenerateServiceName() } diff --git a/local/model/service_control.go b/local/model/service_control.go index c447294..835db6a 100644 --- a/local/model/service_control.go +++ b/local/model/service_control.go @@ -17,7 +17,6 @@ const ( StatusRunning ) -// String converts the ServiceStatus to its string representation func (s ServiceStatus) String() string { switch s { case StatusRunning: @@ -35,7 +34,6 @@ func (s ServiceStatus) String() string { } } -// ParseServiceStatus converts a string to ServiceStatus func ParseServiceStatus(s string) ServiceStatus { switch s { case "SERVICE_RUNNING": @@ -53,31 +51,24 @@ func ParseServiceStatus(s string) ServiceStatus { } } -// MarshalJSON implements json.Marshaler interface func (s ServiceStatus) MarshalJSON() ([]byte, error) { - // Return the numeric value instead of string return []byte(strconv.Itoa(int(s))), nil } -// UnmarshalJSON implements json.Unmarshaler interface func (s *ServiceStatus) UnmarshalJSON(data []byte) error { - // Try to parse as number first if i, err := strconv.Atoi(string(data)); err == nil { *s = ServiceStatus(i) return nil } - // Fallback to string parsing for backward compatibility str := string(data) if len(str) >= 2 { - // Remove quotes if present str = str[1 : len(str)-1] } *s = ParseServiceStatus(str) return nil } -// Scan implements the sql.Scanner interface func (s *ServiceStatus) Scan(value interface{}) error { if value == nil { *s = StatusUnknown @@ -99,7 +90,6 @@ func (s *ServiceStatus) Scan(value interface{}) error { } } -// Value implements the driver.Valuer interface func (s ServiceStatus) Value() (driver.Value, error) { return s.String(), nil } diff --git a/local/model/state_history.go b/local/model/state_history.go index 1579e32..c47bb17 100644 --- a/local/model/state_history.go +++ b/local/model/state_history.go @@ -10,27 +10,22 @@ import ( "gorm.io/gorm" ) -// StateHistoryFilter combines common filter capabilities type StateHistoryFilter struct { - ServerBasedFilter // Adds server ID from path parameter - DateRangeFilter // Adds date range filtering + ServerBasedFilter + DateRangeFilter - // Additional fields specific to state history Session TrackSession `query:"session"` MinPlayers *int `query:"min_players"` MaxPlayers *int `query:"max_players"` } -// ApplyFilter implements the Filterable interface func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB { - // Apply server filter if f.ServerID != "" { if serverUUID, err := uuid.Parse(f.ServerID); err == nil { query = query.Where("server_id = ?", serverUUID) } } - // Apply date range filter if set timeZero := time.Time{} if f.StartDate != timeZero { query = query.Where("date_created >= ?", f.StartDate) @@ -39,12 +34,10 @@ func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB { query = query.Where("date_created <= ?", f.EndDate) } - // Apply session filter if set if f.Session != "" { query = query.Where("session = ?", f.Session) } - // Apply player count filters if set if f.MinPlayers != nil { query = query.Where("player_count >= ?", *f.MinPlayers) } @@ -114,10 +107,9 @@ type StateHistory struct { DateCreated time.Time `json:"dateCreated"` SessionStart time.Time `json:"sessionStart"` SessionDurationMinutes int `json:"sessionDurationMinutes"` - SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"` // Unique identifier for each session/event + SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"` } -// BeforeCreate is a GORM hook that runs before creating new state history entries func (sh *StateHistory) BeforeCreate(tx *gorm.DB) error { if sh.ID == uuid.Nil { sh.ID = uuid.New() diff --git a/local/model/state_history_stats.go b/local/model/state_history_stats.go index 0a8e853..15c4fe3 100644 --- a/local/model/state_history_stats.go +++ b/local/model/state_history_stats.go @@ -21,7 +21,7 @@ type StateHistoryStats struct { AveragePlayers float64 `json:"averagePlayers"` PeakPlayers int `json:"peakPlayers"` TotalSessions int `json:"totalSessions"` - TotalPlaytime int `json:"totalPlaytime" gorm:"-"` // in minutes + TotalPlaytime int `json:"totalPlaytime" gorm:"-"` PlayerCountOverTime []PlayerCountPoint `json:"playerCountOverTime" gorm:"-"` SessionTypes []SessionCount `json:"sessionTypes" gorm:"-"` DailyActivity []DailyActivity `json:"dailyActivity" gorm:"-"` diff --git a/local/model/steam_credentials.go b/local/model/steam_credentials.go index 263a5d6..075d0e6 100644 --- a/local/model/steam_credentials.go +++ b/local/model/steam_credentials.go @@ -16,21 +16,18 @@ import ( "gorm.io/gorm" ) -// SteamCredentials represents stored Steam login credentials type SteamCredentials struct { ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"` Username string `gorm:"not null" json:"username"` - Password string `gorm:"not null" json:"-"` // Encrypted, not exposed in JSON + Password string `gorm:"not null" json:"-"` DateCreated time.Time `json:"dateCreated"` LastUpdated time.Time `json:"lastUpdated"` } -// TableName specifies the table name for GORM func (SteamCredentials) TableName() string { return "steam_credentials" } -// BeforeCreate is a GORM hook that runs before creating new credentials func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error { if s.ID == uuid.Nil { s.ID = uuid.New() @@ -42,7 +39,6 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error { } s.LastUpdated = now - // Encrypt password before saving encrypted, err := EncryptPassword(s.Password) if err != nil { return err @@ -52,11 +48,9 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error { return nil } -// BeforeUpdate is a GORM hook that runs before updating credentials func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error { s.LastUpdated = time.Now().UTC() - // Only encrypt if password field is being updated if tx.Statement.Changed("Password") { encrypted, err := EncryptPassword(s.Password) if err != nil { @@ -68,9 +62,7 @@ func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error { return nil } -// AfterFind is a GORM hook that runs after fetching credentials func (s *SteamCredentials) AfterFind(tx *gorm.DB) error { - // Decrypt password after fetching if s.Password != "" { decrypted, err := DecryptPassword(s.Password) if err != nil { @@ -81,18 +73,15 @@ func (s *SteamCredentials) AfterFind(tx *gorm.DB) error { return nil } -// Validate checks if the credentials are valid with enhanced security checks func (s *SteamCredentials) Validate() error { if s.Username == "" { return errors.New("username is required") } - // Enhanced username validation if len(s.Username) < 3 || len(s.Username) > 64 { return errors.New("username must be between 3 and 64 characters") } - // Check for valid characters in username (alphanumeric, underscore, hyphen) if matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, s.Username); !matched { return errors.New("username contains invalid characters") } @@ -101,7 +90,6 @@ func (s *SteamCredentials) Validate() error { return errors.New("password is required") } - // Basic password validation if len(s.Password) < 6 { return errors.New("password must be at least 6 characters long") } @@ -110,7 +98,6 @@ func (s *SteamCredentials) Validate() error { return errors.New("password is too long") } - // Check for obvious weak passwords weakPasswords := []string{"password", "123456", "steam", "admin", "user"} lowerPass := strings.ToLower(s.Password) for _, weak := range weakPasswords { @@ -122,8 +109,6 @@ func (s *SteamCredentials) Validate() error { return nil } -// GetEncryptionKey returns the encryption key from config. -// The key is loaded from the ENCRYPTION_KEY environment variable. func GetEncryptionKey() []byte { key := []byte(configs.EncryptionKey) if len(key) != 32 { @@ -132,7 +117,6 @@ func GetEncryptionKey() []byte { return key } -// EncryptPassword encrypts a password using AES-256-GCM with enhanced security func EncryptPassword(password string) (string, error) { if password == "" { return "", errors.New("password cannot be empty") @@ -148,33 +132,27 @@ func EncryptPassword(password string) (string, error) { return "", err } - // Create a new GCM cipher gcm, err := cipher.NewGCM(block) if err != nil { return "", err } - // Create a cryptographically secure nonce nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return "", err } - // Encrypt the password with authenticated encryption ciphertext := gcm.Seal(nonce, nonce, []byte(password), nil) - // Return base64 encoded encrypted password return base64.StdEncoding.EncodeToString(ciphertext), nil } -// DecryptPassword decrypts an encrypted password with enhanced validation func DecryptPassword(encryptedPassword string) (string, error) { if encryptedPassword == "" { return "", errors.New("encrypted password cannot be empty") } - // Validate base64 format - if len(encryptedPassword) < 24 { // Minimum reasonable length + if len(encryptedPassword) < 24 { return "", errors.New("invalid encrypted password format") } @@ -184,13 +162,11 @@ func DecryptPassword(encryptedPassword string) (string, error) { return "", err } - // Create a new GCM cipher gcm, err := cipher.NewGCM(block) if err != nil { return "", err } - // Decode base64 encoded password ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword) if err != nil { return "", errors.New("invalid base64 encoding") @@ -207,7 +183,6 @@ func DecryptPassword(encryptedPassword string) (string, error) { return "", errors.New("decryption failed - invalid ciphertext or key") } - // Validate decrypted content decrypted := string(plaintext) if len(decrypted) == 0 || len(decrypted) > 1024 { return "", errors.New("invalid decrypted password") diff --git a/local/model/user.go b/local/model/user.go index 7cc8adb..959e340 100644 --- a/local/model/user.go +++ b/local/model/user.go @@ -8,7 +8,6 @@ import ( "gorm.io/gorm" ) -// User represents a user account in the system. type User struct { ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"` Username string `json:"username" gorm:"unique_index;not null"` @@ -17,16 +16,13 @@ type User struct { Role Role `json:"role"` } -// BeforeCreate is a GORM hook that runs before creating new users func (s *User) BeforeCreate(tx *gorm.DB) error { s.ID = uuid.New() - // Validate password strength if err := password.ValidatePasswordStrength(s.Password); err != nil { return err } - // Hash password before saving hashed, err := password.HashPassword(s.Password) if err != nil { return err @@ -36,11 +32,8 @@ func (s *User) BeforeCreate(tx *gorm.DB) error { return nil } -// BeforeUpdate is a GORM hook that runs before updating users func (s *User) BeforeUpdate(tx *gorm.DB) error { - // Only hash if password field is being updated if tx.Statement.Changed("Password") { - // Validate password strength if err := password.ValidatePasswordStrength(s.Password); err != nil { return err } @@ -55,14 +48,10 @@ func (s *User) BeforeUpdate(tx *gorm.DB) error { return nil } -// AfterFind is a GORM hook that runs after fetching users func (s *User) AfterFind(tx *gorm.DB) error { - // Password remains hashed - never decrypt - // This hook is kept for potential future use return nil } -// Validate checks if the user data is valid func (s *User) Validate() error { if s.Username == "" { return errors.New("username is required") @@ -73,7 +62,6 @@ func (s *User) Validate() error { return nil } -// VerifyPassword verifies a plain text password against the stored hash func (s *User) VerifyPassword(plainPassword string) error { return password.VerifyPassword(s.Password, plainPassword) } diff --git a/local/model/websocket.go b/local/model/websocket.go index 3fd6b20..f894a9e 100644 --- a/local/model/websocket.go +++ b/local/model/websocket.go @@ -4,7 +4,6 @@ import ( "github.com/google/uuid" ) -// ServerCreationStep represents the steps in server creation process type ServerCreationStep string const ( @@ -18,7 +17,6 @@ const ( StepCompleted ServerCreationStep = "completed" ) -// StepStatus represents the status of a step type StepStatus string const ( @@ -28,7 +26,6 @@ const ( StatusFailed StepStatus = "failed" ) -// WebSocketMessageType represents different types of WebSocket messages type WebSocketMessageType string const ( @@ -38,7 +35,6 @@ const ( MessageTypeComplete WebSocketMessageType = "complete" ) -// WebSocketMessage is the base structure for all WebSocket messages type WebSocketMessage struct { Type WebSocketMessageType `json:"type"` ServerID *uuid.UUID `json:"server_id,omitempty"` @@ -46,7 +42,6 @@ type WebSocketMessage struct { Data interface{} `json:"data"` } -// StepMessage represents a step update message type StepMessage struct { Step ServerCreationStep `json:"step"` Status StepStatus `json:"status"` @@ -54,26 +49,22 @@ type StepMessage struct { Error string `json:"error,omitempty"` } -// SteamOutputMessage represents SteamCMD output type SteamOutputMessage struct { Output string `json:"output"` IsError bool `json:"is_error"` } -// ErrorMessage represents an error message type ErrorMessage struct { Error string `json:"error"` Details string `json:"details,omitempty"` } -// CompleteMessage represents completion message type CompleteMessage struct { ServerID uuid.UUID `json:"server_id"` Success bool `json:"success"` Message string `json:"message"` } -// GetStepDescription returns a human-readable description for each step func GetStepDescription(step ServerCreationStep) string { descriptions := map[ServerCreationStep]string{ StepValidation: "Validating server configuration", diff --git a/local/repository/base.go b/local/repository/base.go index efad80b..cae5ac4 100644 --- a/local/repository/base.go +++ b/local/repository/base.go @@ -7,13 +7,11 @@ import ( "gorm.io/gorm" ) -// BaseRepository provides generic CRUD operations for any model type BaseRepository[T any, F any] struct { db *gorm.DB modelType T } -// NewBaseRepository creates a new base repository for the given model type func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F] { return &BaseRepository[T, F]{ db: db, @@ -21,23 +19,19 @@ func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F] } } -// GetAll retrieves all records based on the filter func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, error) { result := new([]T) query := r.db.WithContext(ctx).Model(&r.modelType) - // Apply filter conditions if filter implements Filterable if filterable, ok := any(filter).(Filterable); ok { query = filterable.ApplyFilter(query) } - // Apply pagination if filter implements Pageable if pageable, ok := any(filter).(Pageable); ok { offset, limit := pageable.Pagination() query = query.Offset(offset).Limit(limit) } - // Apply sorting if filter implements Sortable if sortable, ok := any(filter).(Sortable); ok { field, desc := sortable.GetSorting() if desc { @@ -54,7 +48,6 @@ func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, err return result, nil } -// GetByID retrieves a single record by ID func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T, error) { result := new(T) if err := r.db.WithContext(ctx).Where("id = ?", id).First(result).Error; err != nil { @@ -66,7 +59,6 @@ func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T, return result, nil } -// Insert creates a new record func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error { if err := r.db.WithContext(ctx).Create(model).Error; err != nil { return fmt.Errorf("error creating record: %w", err) @@ -74,7 +66,6 @@ func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error { return nil } -// Update modifies an existing record func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error { if err := r.db.WithContext(ctx).Save(model).Error; err != nil { return fmt.Errorf("error updating record: %w", err) @@ -82,7 +73,6 @@ func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error { return nil } -// Delete removes a record by ID func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error { if err := r.db.WithContext(ctx).Delete(new(T), id).Error; err != nil { return fmt.Errorf("error deleting record: %w", err) @@ -90,7 +80,6 @@ func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error return nil } -// Count returns the total number of records matching the filter func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, error) { var count int64 query := r.db.WithContext(ctx).Model(&r.modelType) @@ -106,8 +95,6 @@ func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, err return count, nil } -// Interfaces for filter capabilities - type Filterable interface { ApplyFilter(*gorm.DB) *gorm.DB } diff --git a/local/repository/config.go b/local/repository/config.go index 1e4b1e7..e303365 100644 --- a/local/repository/config.go +++ b/local/repository/config.go @@ -17,13 +17,11 @@ func NewConfigRepository(db *gorm.DB) *ConfigRepository { } } -// UpdateConfig updates or creates a Config record func (r *ConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config { if err := r.Update(ctx, config); err != nil { - // If update fails, try to insert if err := r.Insert(ctx, config); err != nil { return nil } } return config -} \ No newline at end of file +} diff --git a/local/repository/membership.go b/local/repository/membership.go index ae9d000..f931ff5 100644 --- a/local/repository/membership.go +++ b/local/repository/membership.go @@ -8,20 +8,16 @@ import ( "gorm.io/gorm" ) -// MembershipRepository handles database operations for users, roles, and permissions. type MembershipRepository struct { *BaseRepository[model.User, model.MembershipFilter] } -// NewMembershipRepository creates a new MembershipRepository. func NewMembershipRepository(db *gorm.DB) *MembershipRepository { return &MembershipRepository{ BaseRepository: NewBaseRepository[model.User, model.MembershipFilter](db, model.User{}), } } -// FindUserByUsername finds a user by their username. -// It preloads the user's role and the role's permissions. func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username string) (*model.User, error) { var user model.User db := r.db.WithContext(ctx) @@ -32,7 +28,6 @@ func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username return &user, nil } -// FindUserByIDWithPermissions finds a user by their ID and preloads Role and Permissions. func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context, userID string) (*model.User, error) { var user model.User db := r.db.WithContext(ctx) @@ -43,13 +38,11 @@ func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context, return &user, nil } -// CreateUser creates a new user. func (r *MembershipRepository) CreateUser(ctx context.Context, user *model.User) error { db := r.db.WithContext(ctx) return db.Create(user).Error } -// FindRoleByName finds a role by its name. func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string) (*model.Role, error) { var role model.Role db := r.db.WithContext(ctx) @@ -60,13 +53,11 @@ func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string) return &role, nil } -// CreateRole creates a new role. func (r *MembershipRepository) CreateRole(ctx context.Context, role *model.Role) error { db := r.db.WithContext(ctx) return db.Create(role).Error } -// FindPermissionByName finds a permission by its name. func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name string) (*model.Permission, error) { var permission model.Permission db := r.db.WithContext(ctx) @@ -77,19 +68,16 @@ func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name st return &permission, nil } -// CreatePermission creates a new permission. func (r *MembershipRepository) CreatePermission(ctx context.Context, permission *model.Permission) error { db := r.db.WithContext(ctx) return db.Create(permission).Error } -// AssignPermissionsToRole assigns a set of permissions to a role. func (r *MembershipRepository) AssignPermissionsToRole(ctx context.Context, role *model.Role, permissions []model.Permission) error { db := r.db.WithContext(ctx) return db.Model(role).Association("Permissions").Replace(permissions) } -// GetUserPermissions retrieves all permissions for a given user ID. func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) { var user model.User db := r.db.WithContext(ctx) @@ -106,7 +94,6 @@ func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uu return permissions, nil } -// ListUsers retrieves all users. func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, error) { var users []*model.User db := r.db.WithContext(ctx) @@ -114,13 +101,11 @@ func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, er return users, err } -// DeleteUser deletes a user. func (r *MembershipRepository) DeleteUser(ctx context.Context, userID uuid.UUID) error { db := r.db.WithContext(ctx) return db.Delete(&model.User{}, "id = ?", userID).Error } -// FindUserByID finds a user by their ID. func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUID) (*model.User, error) { var user model.User db := r.db.WithContext(ctx) @@ -131,13 +116,11 @@ func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUI return &user, nil } -// UpdateUser updates a user's details in the database. func (r *MembershipRepository) UpdateUser(ctx context.Context, user *model.User) error { db := r.db.WithContext(ctx) return db.Save(user).Error } -// FindRoleByID finds a role by its ID. func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUID) (*model.Role, error) { var role model.Role db := r.db.WithContext(ctx) @@ -148,12 +131,10 @@ func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUI return &role, nil } -// ListUsersWithFilter retrieves users based on the membership filter. func (r *MembershipRepository) ListUsersWithFilter(ctx context.Context, filter *model.MembershipFilter) (*[]model.User, error) { return r.BaseRepository.GetAll(ctx, filter) } -// ListRoles retrieves all roles. func (r *MembershipRepository) ListRoles(ctx context.Context) ([]*model.Role, error) { var roles []*model.Role db := r.db.WithContext(ctx) diff --git a/local/repository/repository.go b/local/repository/repository.go index 664a968..8600daf 100644 --- a/local/repository/repository.go +++ b/local/repository/repository.go @@ -10,11 +10,7 @@ import ( "go.uber.org/dig" ) -// InitializeRepositories -// Initializes Dependency Injection modules for repositories -// -// Args: -// *dig.Container: Dig Container +// *dig.Container: Dig Container func InitializeRepositories(c *dig.Container) { c.Provide(NewServiceControlRepository) c.Provide(NewStateHistoryRepository) @@ -24,16 +20,14 @@ func InitializeRepositories(c *dig.Container) { c.Provide(NewSteamCredentialsRepository) c.Provide(NewMembershipRepository) - // Provide the Steam2FAManager as a singleton if err := c.Provide(func() *model.Steam2FAManager { manager := model.NewSteam2FAManager() - - // Use graceful shutdown manager for cleanup goroutine + shutdownManager := graceful.GetManager() shutdownManager.RunGoroutine(func(ctx context.Context) { ticker := time.NewTicker(15 * time.Minute) defer ticker.Stop() - + for { select { case <-ctx.Done(): @@ -43,7 +37,7 @@ func InitializeRepositories(c *dig.Container) { } } }) - + return manager }); err != nil { logging.Panic("unable to initialize steam 2fa manager") diff --git a/local/repository/server.go b/local/repository/server.go index c82566f..ffd318c 100644 --- a/local/repository/server.go +++ b/local/repository/server.go @@ -20,10 +20,6 @@ func NewServerRepository(db *gorm.DB) *ServerRepository { return repo } -// GetFirstByServiceName -// Gets first row from Server table. -// -// Args: // context.Context: Application context // Returns: // model.ServerModel: Server object from database. diff --git a/local/repository/state_history.go b/local/repository/state_history.go index de81458..e537182 100644 --- a/local/repository/state_history.go +++ b/local/repository/state_history.go @@ -18,17 +18,14 @@ func NewStateHistoryRepository(db *gorm.DB) *StateHistoryRepository { } } -// GetAll retrieves all state history records with the given filter func (r *StateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) { return r.BaseRepository.GetAll(ctx, filter) } -// Insert creates a new state history record func (r *StateHistoryRepository) Insert(ctx context.Context, model *model.StateHistory) error { return r.BaseRepository.Insert(ctx, model) } -// GetLastSessionID gets the last session ID for a server func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) { var lastSession model.StateHistory result := r.BaseRepository.db.WithContext(ctx). @@ -38,7 +35,7 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { - return uuid.Nil, nil // Return nil UUID if no sessions found + return uuid.Nil, nil } return uuid.Nil, result.Error } @@ -46,10 +43,8 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID return lastSession.SessionID, nil } -// GetSummaryStats calculates peak players, total sessions, and average players. func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) { var stats model.StateHistoryStats - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return model.StateHistoryStats{}, err @@ -73,12 +68,10 @@ func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *mo return stats, nil } -// GetTotalPlaytime calculates the total playtime in minutes. func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) { var totalPlaytime struct { TotalMinutes float64 } - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return 0, err @@ -100,10 +93,8 @@ func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *m return int(totalPlaytime.TotalMinutes), nil } -// GetPlayerCountOverTime gets downsampled player count data. func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) { var points []model.PlayerCountPoint - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return points, err @@ -122,10 +113,8 @@ func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, fil return points, err } -// GetSessionTypes counts sessions by type. func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) { var sessionTypes []model.SessionCount - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return sessionTypes, err @@ -145,10 +134,8 @@ func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *mo return sessionTypes, err } -// GetDailyActivity counts sessions per day. func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) { var dailyActivity []model.DailyActivity - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return dailyActivity, err @@ -167,10 +154,8 @@ func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *m return dailyActivity, err } -// GetRecentSessions retrieves the 10 most recent sessions. func (r *StateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) { var recentSessions []model.RecentSession - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return recentSessions, err diff --git a/local/service/config.go b/local/service/config.go index 0161c15..09363f1 100644 --- a/local/service/config.go +++ b/local/service/config.go @@ -77,8 +77,8 @@ func NewConfigService(repository *repository.ConfigRepository, serverRepository repository: repository, serverRepository: serverRepository, configCache: model.NewServerConfigCache(model.CacheConfig{ - ExpirationTime: 5 * time.Minute, // Cache configs for 5 minutes - ThrottleTime: 1 * time.Second, // Prevent rapid re-reads + ExpirationTime: 5 * time.Minute, + ThrottleTime: 1 * time.Second, DefaultStatus: model.StatusUnknown, }), } @@ -88,10 +88,6 @@ func (as *ConfigService) SetServerService(serverService *ServerService) { as.serverService = serverService } -// UpdateConfig -// Updates physical config file and caches it in database. -// -// Args: // context.Context: Application context // Returns: // string: Application version @@ -103,7 +99,62 @@ func (as *ConfigService) UpdateConfig(ctx *fiber.Ctx, body *map[string]interface return as.updateConfigInternal(ctx.UserContext(), serverID, configFile, body, override) } -// updateConfigInternal handles the actual config update logic without Fiber dependencies +func (as *ConfigService) updateConfigFiles(ctx context.Context, server *model.Server, configFile string, body *map[string]interface{}, override bool) ([]byte, []byte, error) { + if server == nil { + logging.Error("Server not found") + return nil, nil, fmt.Errorf("server not found") + } + + configPath := filepath.Join(server.GetConfigPath(), configFile) + oldData, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + dir := filepath.Dir(configPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, nil, err + } + if err := os.WriteFile(configPath, []byte("{}"), 0644); err != nil { + return nil, nil, err + } + oldData = []byte("{}") + } else { + return nil, nil, err + } + } + + oldDataUTF8, err := DecodeUTF16LEBOM(oldData) + if err != nil { + return nil, nil, err + } + + newData, err := json.Marshal(&body) + if err != nil { + return nil, nil, err + } + + if !override { + newData, err = jsons.Merge(oldDataUTF8, newData) + if err != nil { + return nil, nil, err + } + } + newData, err = common.IndentJson(newData) + if err != nil { + return nil, nil, err + } + + newDataUTF16, err := EncodeUTF16LEBOM(newData) + if err != nil { + return nil, nil, err + } + + if err := os.WriteFile(configPath, newDataUTF16, 0644); err != nil { + return nil, nil, err + } + + return oldDataUTF8, newData, nil +} + func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID string, configFile string, body *map[string]interface{}, override bool) (*model.Config, error) { serverUUID, err := uuid.Parse(serverID) if err != nil { @@ -117,63 +168,14 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri return nil, fmt.Errorf("server not found") } - // Read existing config - configPath := filepath.Join(server.GetConfigPath(), configFile) - oldData, err := os.ReadFile(configPath) - if err != nil { - if os.IsNotExist(err) { - // Create directory if it doesn't exist - dir := filepath.Dir(configPath) - if err := os.MkdirAll(dir, 0755); err != nil { - return nil, err - } - // Create empty JSON file - if err := os.WriteFile(configPath, []byte("{}"), 0644); err != nil { - return nil, err - } - oldData = []byte("{}") - } else { - return nil, err - } - } - - oldDataUTF8, err := DecodeUTF16LEBOM(oldData) + oldDataUTF8, newData, err := as.updateConfigFiles(ctx, server, configFile, body, override) if err != nil { return nil, err } - // Write new config - newData, err := json.Marshal(&body) - if err != nil { - return nil, err - } - - if !override { - newData, err = jsons.Merge(oldDataUTF8, newData) - if err != nil { - return nil, err - } - } - newData, err = common.IndentJson(newData) - if err != nil { - return nil, err - } - - newDataUTF16, err := EncodeUTF16LEBOM(newData) - if err != nil { - return nil, err - } - - if err := os.WriteFile(configPath, newDataUTF16, 0644); err != nil { - return nil, err - } - - // Invalidate all configs for this server since configs can be interdependent as.configCache.InvalidateServerCache(serverID) as.serverService.StartAccServerRuntime(server) - - // Log change return as.repository.UpdateConfig(ctx, &model.Config{ ServerID: serverUUID, ConfigFile: configFile, @@ -183,10 +185,6 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri }), nil } -// GetConfig -// Gets physical config file and caches it in database. -// -// Args: // context.Context: Application context // Returns: // string: Application version @@ -197,44 +195,47 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) { logging.Debug("Getting config for server ID: %s, file: %s", serverIDStr, configFile) server, err := as.serverRepository.GetByID(ctx.UserContext(), serverIDStr) + if err != nil { logging.Error("Server not found") return nil, fiber.NewError(404, "Server not found") } + return as.getConfigFile(server, configFile) +} - // Try to get from cache based on config file type +func (as *ConfigService) getConfigFile(server *model.Server, configFile string) (interface{}, error) { + serverIDStr := server.ID.String() switch configFile { case ConfigurationJson: if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok { logging.Debug("Returning cached configuration for server ID: %s", serverIDStr) - return cached, nil + return *cached, nil } case AssistRulesJson: if cached, ok := as.configCache.GetAssistRules(serverIDStr); ok { logging.Debug("Returning cached assist rules for server ID: %s", serverIDStr) - return cached, nil + return *cached, nil } case EventJson: if cached, ok := as.configCache.GetEvent(serverIDStr); ok { logging.Debug("Returning cached event config for server ID: %s", serverIDStr) - return cached, nil + return *cached, nil } case EventRulesJson: if cached, ok := as.configCache.GetEventRules(serverIDStr); ok { logging.Debug("Returning cached event rules for server ID: %s", serverIDStr) - return cached, nil + return *cached, nil } case SettingsJson: if cached, ok := as.configCache.GetSettings(serverIDStr); ok { logging.Debug("Returning cached settings for server ID: %s", serverIDStr) - return cached, nil + return *cached, nil } } logging.Debug("Cache miss for server ID: %s, file: %s - loading from disk", serverIDStr, configFile) - // Not in cache, load from disk - configPath := filepath.Join(server.GetConfigPath(), configFile) + configPath := server.GetConfigPath() decoder := DecodeFileName(configFile) if decoder == nil { return nil, errors.New("invalid config file") @@ -244,43 +245,39 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) { if err != nil { if os.IsNotExist(err) { logging.Debug("Config file not found, creating default for server ID: %s, file: %s", serverIDStr, configFile) - // Return empty config if file doesn't exist switch configFile { case ConfigurationJson: - return &model.Configuration{}, nil + return model.Configuration{}, nil case AssistRulesJson: - return &model.AssistRules{}, nil + return model.AssistRules{}, nil case EventJson: - return &model.EventConfig{}, nil + return model.EventConfig{}, nil case EventRulesJson: - return &model.EventRules{}, nil + return model.EventRules{}, nil case SettingsJson: - return &model.ServerSettings{}, nil + return model.ServerSettings{}, nil } } return nil, err } - // Cache the loaded config switch configFile { case ConfigurationJson: - as.configCache.UpdateConfiguration(serverIDStr, *config.(*model.Configuration)) + as.configCache.UpdateConfiguration(serverIDStr, config.(model.Configuration)) case AssistRulesJson: - as.configCache.UpdateAssistRules(serverIDStr, *config.(*model.AssistRules)) + as.configCache.UpdateAssistRules(serverIDStr, config.(model.AssistRules)) case EventJson: - as.configCache.UpdateEvent(serverIDStr, *config.(*model.EventConfig)) + as.configCache.UpdateEvent(serverIDStr, config.(model.EventConfig)) case EventRulesJson: - as.configCache.UpdateEventRules(serverIDStr, *config.(*model.EventRules)) + as.configCache.UpdateEventRules(serverIDStr, config.(model.EventRules)) case SettingsJson: - as.configCache.UpdateSettings(serverIDStr, *config.(*model.ServerSettings)) + as.configCache.UpdateSettings(serverIDStr, config.(model.ServerSettings)) } logging.Debug("Successfully loaded and cached config for server ID: %s, file: %s", serverIDStr, configFile) return config, nil } -// GetConfigs -// Gets all configurations for a server, using cache when possible. func (as *ConfigService) GetConfigs(ctx *fiber.Ctx) (*model.Configurations, error) { serverID := ctx.Params("id") @@ -296,82 +293,33 @@ func (as *ConfigService) GetConfigs(ctx *fiber.Ctx) (*model.Configurations, erro func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configurations, error) { serverIDStr := server.ID.String() logging.Info("Loading configs for server ID: %s at path: %s", serverIDStr, server.GetConfigPath()) - configs := &model.Configurations{} - // Load configuration - if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok { - logging.Debug("Using cached configuration for server %s", serverIDStr) - configs.Configuration = *cached - } else { - logging.Debug("Loading configuration from disk for server %s", serverIDStr) - config, err := mustDecode[model.Configuration](ConfigurationJson, server.GetConfigPath()) - if err != nil { - logging.Error("Failed to load configuration for server %s: %v", serverIDStr, err) - return nil, fmt.Errorf("failed to load configuration: %v", err) - } - configs.Configuration = config - as.configCache.UpdateConfiguration(serverIDStr, config) + settingsConf, err := as.getConfigFile(server, SettingsJson) + if err != nil { + return nil, err } - - // Load assist rules - if cached, ok := as.configCache.GetAssistRules(serverIDStr); ok { - logging.Debug("Using cached assist rules for server %s", serverIDStr) - configs.AssistRules = *cached - } else { - logging.Debug("Loading assist rules from disk for server %s", serverIDStr) - rules, err := mustDecode[model.AssistRules](AssistRulesJson, server.GetConfigPath()) - if err != nil { - logging.Error("Failed to load assist rules for server %s: %v", serverIDStr, err) - return nil, fmt.Errorf("failed to load assist rules: %v", err) - } - configs.AssistRules = rules - as.configCache.UpdateAssistRules(serverIDStr, rules) + eventRulesConf, err := as.getConfigFile(server, EventRulesJson) + if err != nil { + return nil, err } - - // Load event config - if cached, ok := as.configCache.GetEvent(serverIDStr); ok { - logging.Debug("Using cached event config for server %s", serverIDStr) - configs.Event = *cached - } else { - logging.Debug("Loading event config from disk for server %s", serverIDStr) - event, err := mustDecode[model.EventConfig](EventJson, server.GetConfigPath()) - if err != nil { - logging.Error("Failed to load event config for server %s: %v", serverIDStr, err) - return nil, fmt.Errorf("failed to load event config: %v", err) - } - configs.Event = event - logging.Debug("Updating event config for server %s with track: %s", serverIDStr, event.Track) - as.configCache.UpdateEvent(serverIDStr, event) + eventConf, err := as.getConfigFile(server, EventJson) + if err != nil { + return nil, err } - - // Load event rules - if cached, ok := as.configCache.GetEventRules(serverIDStr); ok { - logging.Debug("Using cached event rules for server %s", serverIDStr) - configs.EventRules = *cached - } else { - logging.Debug("Loading event rules from disk for server %s", serverIDStr) - rules, err := mustDecode[model.EventRules](EventRulesJson, server.GetConfigPath()) - if err != nil { - logging.Error("Failed to load event rules for server %s: %v", serverIDStr, err) - return nil, fmt.Errorf("failed to load event rules: %v", err) - } - configs.EventRules = rules - as.configCache.UpdateEventRules(serverIDStr, rules) + assistRulesConf, err := as.getConfigFile(server, AssistRulesJson) + if err != nil { + return nil, err } - - // Load settings - if cached, ok := as.configCache.GetSettings(serverIDStr); ok { - logging.Debug("Using cached settings for server %s", serverIDStr) - configs.Settings = *cached - } else { - logging.Debug("Loading settings from disk for server %s", serverIDStr) - settings, err := mustDecode[model.ServerSettings](SettingsJson, server.GetConfigPath()) - if err != nil { - logging.Error("Failed to load settings for server %s: %v", serverIDStr, err) - return nil, fmt.Errorf("failed to load settings: %v", err) - } - configs.Settings = settings - as.configCache.UpdateSettings(serverIDStr, settings) + configurationConf, err := as.getConfigFile(server, ConfigurationJson) + if err != nil { + return nil, err + } + configs := &model.Configurations{ + Settings: settingsConf.(model.ServerSettings), + EventRules: eventRulesConf.(model.EventRules), + Event: eventConf.(model.EventConfig), + AssistRules: assistRulesConf.(model.AssistRules), + Configuration: configurationConf.(model.Configuration), } logging.Info("Successfully loaded all configs for server %s", serverIDStr) @@ -396,9 +344,6 @@ func readFile(path string, configFile string) ([]byte, error) { configPath := filepath.Join(path, configFile) oldData, err := os.ReadFile(configPath) if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil, fmt.Errorf("config file %s does not exist at %s", configFile, configPath) - } return nil, err } return oldData, nil @@ -475,9 +420,7 @@ func (as *ConfigService) GetConfiguration(server *model.Server) (*model.Configur return &config, nil } -// SaveConfiguration saves the configuration for a server func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.Configuration) error { - // Convert config to map for UpdateConfig configMap := make(map[string]interface{}) configBytes, err := json.Marshal(config) if err != nil { @@ -487,7 +430,6 @@ func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.C return fmt.Errorf("failed to unmarshal configuration: %v", err) } - // Update the configuration using the internal method - _, err = as.updateConfigInternal(context.Background(), server.ID.String(), ConfigurationJson, &configMap, true) + _, _, err = as.updateConfigFiles(context.Background(), server, ConfigurationJson, &configMap, true) return err } diff --git a/local/service/firewall_service.go b/local/service/firewall_service.go index e857462..0894a51 100644 --- a/local/service/firewall_service.go +++ b/local/service/firewall_service.go @@ -96,11 +96,9 @@ func (s *FirewallService) DeleteServerRules(serverName string, tcpPorts, udpPort } func (s *FirewallService) UpdateServerRules(serverName string, tcpPorts, udpPorts []int) error { - // First delete existing rules if err := s.DeleteServerRules(serverName, tcpPorts, udpPorts); err != nil { return err } - // Then create new rules return s.CreateServerRules(serverName, tcpPorts, udpPorts) -} \ No newline at end of file +} diff --git a/local/service/membership.go b/local/service/membership.go index 157d8fc..1e0f008 100644 --- a/local/service/membership.go +++ b/local/service/membership.go @@ -12,13 +12,11 @@ import ( "github.com/google/uuid" ) -// CacheInvalidator interface for cache invalidation type CacheInvalidator interface { InvalidateUserPermissions(userID string) InvalidateAllUserPermissions() } -// MembershipService provides business logic for membership-related operations. type MembershipService struct { repo *repository.MembershipRepository cacheInvalidator CacheInvalidator @@ -26,29 +24,25 @@ type MembershipService struct { openJwtHandler *jwt.OpenJWTHandler } -// NewMembershipService creates a new MembershipService. func NewMembershipService(repo *repository.MembershipRepository, jwtHandler *jwt.JWTHandler, openJwtHandler *jwt.OpenJWTHandler) *MembershipService { return &MembershipService{ repo: repo, - cacheInvalidator: nil, // Will be set later via SetCacheInvalidator + cacheInvalidator: nil, jwtHandler: jwtHandler, openJwtHandler: openJwtHandler, } } -// SetCacheInvalidator sets the cache invalidator after service initialization func (s *MembershipService) SetCacheInvalidator(invalidator CacheInvalidator) { s.cacheInvalidator = invalidator } -// Login authenticates a user and returns a JWT. func (s *MembershipService) HandleLogin(ctx context.Context, username, password string) (*model.User, error) { user, err := s.repo.FindUserByUsername(ctx, username) if err != nil { return nil, errors.New("invalid credentials") } - // Use secure password verification with constant-time comparison if err := user.VerifyPassword(password); err != nil { return nil, errors.New("invalid credentials") } @@ -56,7 +50,6 @@ func (s *MembershipService) HandleLogin(ctx context.Context, username, password return user, nil } -// Login authenticates a user and returns a JWT. func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) { user, err := s.HandleLogin(ctx, username, password) if err != nil { @@ -70,7 +63,6 @@ func (s *MembershipService) GenerateOpenToken(ctx context.Context, userId string return s.openJwtHandler.GenerateToken(userId) } -// CreateUser creates a new user. func (s *MembershipService) CreateUser(ctx context.Context, username, password, roleName string) (*model.User, error) { role, err := s.repo.FindRoleByName(ctx, roleName) @@ -94,43 +86,35 @@ func (s *MembershipService) CreateUser(ctx context.Context, username, password, return user, nil } -// ListUsers retrieves all users. func (s *MembershipService) ListUsers(ctx context.Context) ([]*model.User, error) { return s.repo.ListUsers(ctx) } -// GetUser retrieves a single user by ID. func (s *MembershipService) GetUser(ctx context.Context, userID uuid.UUID) (*model.User, error) { return s.repo.FindUserByID(ctx, userID) } -// GetUserWithPermissions retrieves a single user by ID with their role and permissions. func (s *MembershipService) GetUserWithPermissions(ctx context.Context, userID string) (*model.User, error) { return s.repo.FindUserByIDWithPermissions(ctx, userID) } -// UpdateUserRequest defines the request body for updating a user. type UpdateUserRequest struct { Username *string `json:"username"` Password *string `json:"password"` RoleID *uuid.UUID `json:"roleId"` } -// DeleteUser deletes a user with validation to prevent Super Admin deletion. func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) error { - // Get user with role information user, err := s.repo.FindUserByID(ctx, userID) if err != nil { return errors.New("user not found") } - // Get role to check if it's Super Admin role, err := s.repo.FindRoleByID(ctx, user.RoleID) if err != nil { return errors.New("user role not found") } - // Prevent deletion of Super Admin users if role.Name == "Super Admin" { return errors.New("cannot delete Super Admin user") } @@ -140,7 +124,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er return err } - // Invalidate cache for deleted user if s.cacheInvalidator != nil { s.cacheInvalidator.InvalidateUserPermissions(userID.String()) } @@ -149,7 +132,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er return nil } -// UpdateUser updates a user's details. func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, req UpdateUserRequest) (*model.User, error) { user, err := s.repo.FindUserByID(ctx, userID) if err != nil { @@ -161,12 +143,10 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re } if req.Password != nil && *req.Password != "" { - // Password will be automatically hashed in BeforeUpdate hook user.Password = *req.Password } if req.RoleID != nil { - // Check if role exists _, err := s.repo.FindRoleByID(ctx, *req.RoleID) if err != nil { return nil, errors.New("role not found") @@ -178,7 +158,6 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re return nil, err } - // Invalidate cache if role was changed if req.RoleID != nil && s.cacheInvalidator != nil { s.cacheInvalidator.InvalidateUserPermissions(userID.String()) } @@ -187,14 +166,12 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re return user, nil } -// HasPermission checks if a user has a specific permission. func (s *MembershipService) HasPermission(ctx context.Context, userID string, permissionName string) (bool, error) { user, err := s.repo.FindUserByIDWithPermissions(ctx, userID) if err != nil { return false, err } - // Super admin and Admin have all permissions if user.Role.Name == "Super Admin" || user.Role.Name == "Admin" { return true, nil } @@ -208,15 +185,13 @@ func (s *MembershipService) HasPermission(ctx context.Context, userID string, pe return false, nil } -// SetupInitialData creates the initial roles and permissions. func (s *MembershipService) SetupInitialData(ctx context.Context) error { - // Define all permissions permissions := model.AllPermissions() createdPermissions := make([]model.Permission, 0) for _, pName := range permissions { perm, err := s.repo.FindPermissionByName(ctx, pName) - if err != nil { // Assuming error means not found + if err != nil { perm = &model.Permission{Name: pName} if err := s.repo.CreatePermission(ctx, perm); err != nil { return err @@ -225,7 +200,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { createdPermissions = append(createdPermissions, *perm) } - // Create Super Admin role with all permissions superAdminRole, err := s.repo.FindRoleByName(ctx, "Super Admin") if err != nil { superAdminRole = &model.Role{Name: "Super Admin"} @@ -237,7 +211,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { return err } - // Create Admin role with same permissions as Super Admin adminRole, err := s.repo.FindRoleByName(ctx, "Admin") if err != nil { adminRole = &model.Role{Name: "Admin"} @@ -249,7 +222,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { return err } - // Create Manager role with limited permissions (excluding membership, role, user, server create/delete) managerRole, err := s.repo.FindRoleByName(ctx, "Manager") if err != nil { managerRole = &model.Role{Name: "Manager"} @@ -258,7 +230,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { } } - // Define manager permissions (limited set) managerPermissionNames := []string{ model.ServerView, model.ServerUpdate, @@ -282,16 +253,14 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { return err } - // Invalidate all caches after role setup changes if s.cacheInvalidator != nil { s.cacheInvalidator.InvalidateAllUserPermissions() } - // Create a default admin user if one doesn't exist _, err = s.repo.FindUserByUsername(ctx, "admin") if err != nil { logging.Debug("Creating default admin user") - _, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin") // Default password, should be changed + _, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin") if err != nil { return err } @@ -300,7 +269,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { return nil } -// GetAllRoles retrieves all roles for dropdown selection. func (s *MembershipService) GetAllRoles(ctx context.Context) ([]*model.Role, error) { return s.repo.ListRoles(ctx) } diff --git a/local/service/server.go b/local/service/server.go index b193120..00adc45 100644 --- a/local/service/server.go +++ b/local/service/server.go @@ -20,7 +20,7 @@ import ( const ( DefaultStartPort = 9600 - RequiredPortCount = 1 // Update this if ACC needs more ports + RequiredPortCount = 1 ) type ServerService struct { @@ -45,18 +45,15 @@ type pendingState struct { } func (s *ServerService) ensureLogTailing(server *model.Server, instance *tracking.AccServerInstance) { - // Check if we already have a tailer if _, exists := s.logTailers.Load(server.ID); exists { return } - // Start tailing in a goroutine that handles file creation/deletion go func() { logPath := filepath.Join(server.GetLogPath(), "server.log") tailer := tracking.NewLogTailer(logPath, instance.HandleLogLine) s.logTailers.Store(server.ID, tailer) - // Start tailing and automatically handle file changes tailer.Start() }() } @@ -82,7 +79,6 @@ func NewServerService( webSocketService: webSocketService, } - // Initialize server instances servers, err := repository.GetAll(context.Background(), &model.ServerFilter{}) if err != nil { logging.Error("Failed to get servers: %v", err) @@ -90,7 +86,6 @@ func NewServerService( } for i := range *servers { - // Initialize instance regardless of status logging.Info("Starting server runtime for server ID: %d", (*servers)[i].ID) service.StartAccServerRuntime(&(*servers)[i]) } @@ -99,7 +94,7 @@ func NewServerService( } func (s *ServerService) shouldInsertStateHistory(serverID uuid.UUID) bool { - insertInterval := 5 * time.Minute // Configure this as needed + insertInterval := 5 * time.Minute lastInsertInterface, exists := s.lastInsertTimes.Load(serverID) if !exists { @@ -122,16 +117,15 @@ func (s *ServerService) getNextSessionID(serverID uuid.UUID) uuid.UUID { lastID, err := s.stateHistoryRepo.GetLastSessionID(context.Background(), serverID) if err != nil { logging.Error("Failed to get last session ID for server %s: %v", serverID, err) - return uuid.New() // Return new UUID as fallback + return uuid.New() } if lastID == uuid.Nil { - return uuid.New() // Return new UUID if no previous session + return uuid.New() } - return uuid.New() // Always generate new UUID for each session + return uuid.New() } func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.ServerState) { - // Get or create session ID when session changes currentSessionInterface, exists := s.instances.Load(serverID) var sessionID uuid.UUID if !exists { @@ -163,7 +157,6 @@ func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.Serv } func (s *ServerService) updateSessionDuration(server *model.Server, sessionType model.TrackSession) { - // Get configs using helper methods event, err := s.configService.GetEventConfig(server) if err != nil { event = &model.EventConfig{} @@ -181,9 +174,7 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType serverInstance.State.Track = event.Track serverInstance.State.MaxConnections = configuration.MaxConnections.ToInt() - // Check if session type has changed if serverInstance.State.Session != sessionType { - // Get new session ID for the new session sessionID := s.getNextSessionID(server.ID) s.sessionIDs.Store(server.ID, sessionID) } @@ -204,33 +195,27 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType } func (s *ServerService) GenerateServerPath(server *model.Server) { - // Get the base steamcmd path from environment variable steamCMDPath := env.GetSteamCMDDirPath() - server.FromSteamCMD = true server.Path = server.GenerateServerPath(steamCMDPath) + server.FromSteamCMD = true } func (s *ServerService) handleStateChange(server *model.Server, state *model.ServerState) { - // Update session duration when session changes s.updateSessionDuration(server, state.Session) - // Invalidate status cache when server state changes s.apiService.statusCache.InvalidateStatus(server.ServiceName) - // Cancel existing timer if any if debouncer, exists := s.debouncers.Load(server.ID); exists { pending := debouncer.(*pendingState) pending.timer.Stop() } - // Create new timer timer := time.NewTimer(5 * time.Minute) s.debouncers.Store(server.ID, &pendingState{ timer: timer, state: state, }) - // Start goroutine to handle the delayed insert go func() { <-timer.C if debouncer, exists := s.debouncers.Load(server.ID); exists { @@ -240,14 +225,12 @@ func (s *ServerService) handleStateChange(server *model.Server, state *model.Ser } }() - // If enough time has passed since last insert, insert immediately if s.shouldInsertStateHistory(server.ID) { s.insertStateHistory(server.ID, state) } } func (s *ServerService) StartAccServerRuntime(server *model.Server) { - // Get or create instance instanceInterface, exists := s.instances.Load(server.ID) var instance *tracking.AccServerInstance if !exists { @@ -259,20 +242,14 @@ func (s *ServerService) StartAccServerRuntime(server *model.Server) { instance = instanceInterface.(*tracking.AccServerInstance) } - // Invalidate config cache for this server before loading new configs serverIDStr := server.ID.String() s.configService.configCache.InvalidateServerCache(serverIDStr) s.updateSessionDuration(server, instance.State.Session) - // Ensure log tailing is running (regardless of server status) s.ensureLogTailing(server, instance) } -// GetAll -// Gets All rows from Server table. -// -// Args: // context.Context: Application context // Returns: // string: Application version @@ -304,10 +281,6 @@ func (s *ServerService) GetAll(ctx *fiber.Ctx, filter *model.ServerFilter) (*[]m return servers, nil } -// GetById -// Gets rows by ID from Server table. -// -// Args: // context.Context: Application context // Returns: // string: Application version @@ -334,22 +307,19 @@ func (as *ServerService) GetById(ctx *fiber.Ctx, serverID uuid.UUID) (*model.Ser return server, nil } -// CreateServerAsync starts server creation asynchronously and returns immediately func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server) error { - // Perform basic validation first + logging.Info("create server start") if err := server.Validate(); err != nil { + logging.Info("create server validation failed") return err } - // Generate server path s.GenerateServerPath(server) - // Create a background context that won't be cancelled when the HTTP request ends bgCtx := context.Background() - // Start the actual creation process in a goroutine go func() { - // Create server in background without using fiber.Ctx + logging.Info("create server start background") if err := s.createServerBackground(bgCtx, server); err != nil { logging.Error("Async server creation failed for server %s: %v", server.ID, err) s.webSocketService.BroadcastError(server.ID, "Server creation failed", err.Error()) @@ -360,127 +330,128 @@ func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server) return nil } -func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error { - // Broadcast step: validation - s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusInProgress, - model.GetStepDescription(model.StepValidation), "") +type createServerStep struct { + stepType model.ServerCreationStep + important bool + callback func() (string, error) + description string +} - // Validate basic server configuration - if err := server.Validate(); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusFailed, - "", fmt.Sprintf("Validation failed: %v", err)) - return err +func (s *ServerService) createServerBackground(ctx context.Context, server *model.Server) error { + var serverPort int + var tcpPorts, udpPorts []int + + steps := []createServerStep{ + { + stepType: model.StepValidation, + important: true, + description: "Server configuration validated successfully", + callback: func() (string, error) { + if err := server.Validate(); err != nil { + return "", fmt.Errorf("validation failed: %v", err) + } + return "Server configuration validated successfully", nil + }, + }, + { + stepType: model.StepDirectoryCreation, + important: true, + description: "Server directories prepared", + callback: func() (string, error) { + return "Server directories prepared", nil + }, + }, + { + stepType: model.StepSteamDownload, + important: true, + description: "Server files downloaded successfully", + callback: func() (string, error) { + if err := s.steamService.InstallServerWithWebSocket(ctx, server.Path, &server.ID, s.webSocketService); err != nil { + return "", fmt.Errorf("failed to install server: %v", err) + } + return "Server files downloaded successfully", nil + }, + }, + { + stepType: model.StepConfigGeneration, + important: true, + description: "", + callback: func() (string, error) { + ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount) + if err != nil { + return "", fmt.Errorf("failed to find available ports: %v", err) + } + + serverPort = ports[0] + + if err := s.updateServerPort(server, serverPort); err != nil { + return "", fmt.Errorf("failed to update server configuration: %v", err) + } + + return fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), nil + }, + }, + { + stepType: model.StepServiceCreation, + important: true, + description: "", + callback: func() (string, error) { + execPath := filepath.Join(server.GetServerPath(), "accServer.exe") + serverWorkingDir := filepath.Join(server.GetServerPath(), "server") + if err := s.windowsService.CreateService(ctx, server.ServiceName, execPath, serverWorkingDir, nil); err != nil { + return "", fmt.Errorf("failed to create Windows service: %v", err) + } + return fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), nil + }, + }, + { + stepType: model.StepFirewallRules, + important: false, + description: "", + callback: func() (string, error) { + s.configureFirewall(server) + tcpPorts = []int{serverPort} + udpPorts = []int{serverPort} + if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil { + return "", fmt.Errorf("failed to create firewall rules: %v", err) + } + return fmt.Sprintf("Firewall rules created for port %d", serverPort), nil + }, + }, + { + stepType: model.StepDatabaseSave, + important: true, + description: "Server saved to database successfully", + callback: func() (string, error) { + if err := s.repository.Insert(ctx, server); err != nil { + return "", fmt.Errorf("failed to insert server into database: %v", err) + } + return "Server saved to database successfully", nil + }, + }, } - s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusCompleted, - "Server configuration validated successfully", "") + for i, step := range steps { + s.webSocketService.BroadcastStep(server.ID, step.stepType, model.StatusInProgress, + model.GetStepDescription(step.stepType), "") - // Broadcast step: directory creation - s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusInProgress, - model.GetStepDescription(model.StepDirectoryCreation), "") + successMessage, err := step.callback() + if err != nil { + s.webSocketService.BroadcastStep(server.ID, step.stepType, model.StatusFailed, + "", err.Error()) - // Directory creation is handled within InstallServer, so we mark it as completed - s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusCompleted, - "Server directories prepared", "") + if step.important { + s.rollbackSteps(ctx, server, steps[:i], tcpPorts, udpPorts) + return err + } + } - // Broadcast step: Steam download - s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusInProgress, - model.GetStepDescription(model.StepSteamDownload), "") - - // Install server using SteamCMD with streaming support - if err := s.steamService.InstallServerWithWebSocket(ctx.UserContext(), server.Path, &server.ID, s.webSocketService); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusFailed, - "", fmt.Sprintf("Steam installation failed: %v", err)) - return fmt.Errorf("failed to install server: %v", err) + s.webSocketService.BroadcastStep(server.ID, step.stepType, model.StatusCompleted, + successMessage, "") } - s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusCompleted, - "Server files downloaded successfully", "") - - // Broadcast step: config generation - s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusInProgress, - model.GetStepDescription(model.StepConfigGeneration), "") - - // Find available ports for server - ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount) - if err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed, - "", fmt.Sprintf("Failed to find available ports: %v", err)) - return fmt.Errorf("failed to find available ports: %v", err) - } - - // Use the first port for both TCP and UDP - serverPort := ports[0] - - // Update server configuration with the allocated port - if err := s.updateServerPort(server, serverPort); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed, - "", fmt.Sprintf("Failed to update server configuration: %v", err)) - return fmt.Errorf("failed to update server configuration: %v", err) - } - - s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusCompleted, - fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), "") - - // Broadcast step: service creation - s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusInProgress, - model.GetStepDescription(model.StepServiceCreation), "") - - // Create Windows service with correct paths - execPath := filepath.Join(server.GetServerPath(), "accServer.exe") - serverWorkingDir := filepath.Join(server.GetServerPath(), "server") - if err := s.windowsService.CreateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusFailed, - "", fmt.Sprintf("Failed to create Windows service: %v", err)) - // Cleanup on failure - s.steamService.UninstallServer(server.Path) - return fmt.Errorf("failed to create Windows service: %v", err) - } - - s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusCompleted, - fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), "") - - // Broadcast step: firewall rules - s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusInProgress, - model.GetStepDescription(model.StepFirewallRules), "") - - s.configureFirewall(server) - tcpPorts := []int{serverPort} - udpPorts := []int{serverPort} - if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusFailed, - "", fmt.Sprintf("Failed to create firewall rules: %v", err)) - // Cleanup on failure - s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName) - s.steamService.UninstallServer(server.Path) - return fmt.Errorf("failed to create firewall rules: %v", err) - } - - s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusCompleted, - fmt.Sprintf("Firewall rules created for port %d", serverPort), "") - - // Broadcast step: database save - s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusInProgress, - model.GetStepDescription(model.StepDatabaseSave), "") - - // Insert server into database - if err := s.repository.Insert(ctx.UserContext(), server); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusFailed, - "", fmt.Sprintf("Failed to save server to database: %v", err)) - // Cleanup on failure - s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts) - s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName) - s.steamService.UninstallServer(server.Path) - return fmt.Errorf("failed to insert server into database: %v", err) - } - - s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusCompleted, - "Server saved to database successfully", "") - - // Initialize server runtime s.StartAccServerRuntime(server) - // Broadcast completion s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted, model.GetStepDescription(model.StepCompleted), "") @@ -490,150 +461,34 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error return nil } -// createServerBackground performs server creation in background without fiber.Ctx -func (s *ServerService) createServerBackground(ctx context.Context, server *model.Server) error { - // Broadcast step: validation - s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusInProgress, - model.GetStepDescription(model.StepValidation), "") - - // Validate basic server configuration (already done in async method, but double-check) - if err := server.Validate(); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusFailed, - "", fmt.Sprintf("Validation failed: %v", err)) - return err +func (s *ServerService) rollbackSteps(ctx context.Context, server *model.Server, completedSteps []createServerStep, tcpPorts, udpPorts []int) { + for i := len(completedSteps) - 1; i >= 0; i-- { + step := completedSteps[i] + switch step.stepType { + case model.StepDatabaseSave: + s.repository.Delete(ctx, server.ID) + case model.StepFirewallRules: + if len(tcpPorts) > 0 && len(udpPorts) > 0 { + s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts) + } + case model.StepServiceCreation: + s.windowsService.DeleteService(ctx, server.ServiceName) + case model.StepSteamDownload: + s.steamService.UninstallServer(server.Path) + } } - - s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusCompleted, - "Server configuration validated successfully", "") - - // Broadcast step: directory creation - s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusInProgress, - model.GetStepDescription(model.StepDirectoryCreation), "") - - // Directory creation is handled within InstallServer, so we mark it as completed - s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusCompleted, - "Server directories prepared", "") - - // Broadcast step: Steam download - s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusInProgress, - model.GetStepDescription(model.StepSteamDownload), "") - - // Install server using SteamCMD with streaming support - if err := s.steamService.InstallServerWithWebSocket(ctx, server.GetServerPath(), &server.ID, s.webSocketService); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusFailed, - "", fmt.Sprintf("Steam installation failed: %v", err)) - return fmt.Errorf("failed to install server: %v", err) - } - - s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusCompleted, - "Server files downloaded successfully", "") - - // Broadcast step: config generation - s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusInProgress, - model.GetStepDescription(model.StepConfigGeneration), "") - - // Find available ports for server - ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount) - if err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed, - "", fmt.Sprintf("Failed to find available ports: %v", err)) - return fmt.Errorf("failed to find available ports: %v", err) - } - - // Use the first port for both TCP and UDP - serverPort := ports[0] - - // Update server configuration with the allocated port - if err := s.updateServerPort(server, serverPort); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed, - "", fmt.Sprintf("Failed to update server configuration: %v", err)) - return fmt.Errorf("failed to update server configuration: %v", err) - } - - s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusCompleted, - fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), "") - - // Broadcast step: service creation - s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusInProgress, - model.GetStepDescription(model.StepServiceCreation), "") - - // Create Windows service with correct paths - execPath := filepath.Join(server.GetServerPath(), "accServer.exe") - serverWorkingDir := filepath.Join(server.GetServerPath(), "server") - if err := s.windowsService.CreateService(ctx, server.ServiceName, execPath, serverWorkingDir, nil); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusFailed, - "", fmt.Sprintf("Failed to create Windows service: %v", err)) - // Cleanup on failure - s.steamService.UninstallServer(server.Path) - return fmt.Errorf("failed to create Windows service: %v", err) - } - - s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusCompleted, - fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), "") - - // Broadcast step: firewall rules - s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusInProgress, - model.GetStepDescription(model.StepFirewallRules), "") - - s.configureFirewall(server) - tcpPorts := []int{serverPort} - udpPorts := []int{serverPort} - if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusFailed, - "", fmt.Sprintf("Failed to create firewall rules: %v", err)) - // Cleanup on failure - s.windowsService.DeleteService(ctx, server.ServiceName) - s.steamService.UninstallServer(server.Path) - return fmt.Errorf("failed to create firewall rules: %v", err) - } - - s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusCompleted, - fmt.Sprintf("Firewall rules created for port %d", serverPort), "") - - // Broadcast step: database save - s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusInProgress, - model.GetStepDescription(model.StepDatabaseSave), "") - - // Insert server into database - if err := s.repository.Insert(ctx, server); err != nil { - s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusFailed, - "", fmt.Sprintf("Failed to save server to database: %v", err)) - // Cleanup on failure - s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts) - s.windowsService.DeleteService(ctx, server.ServiceName) - s.steamService.UninstallServer(server.Path) - return fmt.Errorf("failed to insert server into database: %v", err) - } - - s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusCompleted, - "Server saved to database successfully", "") - - // Initialize server runtime - s.StartAccServerRuntime(server) - - // Broadcast completion - s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted, - model.GetStepDescription(model.StepCompleted), "") - - s.webSocketService.BroadcastComplete(server.ID, true, - fmt.Sprintf("Server '%s' created successfully on port %d", server.Name, serverPort)) - - return nil } func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error { - // Get server details server, err := s.repository.GetByID(ctx.UserContext(), serverID) if err != nil { return fmt.Errorf("failed to get server details: %v", err) } - // Stop and remove Windows service if err := s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName); err != nil { logging.Error("Failed to delete Windows service: %v", err) } - // Remove firewall rules configuration, err := s.configService.GetConfiguration(server) if err != nil { logging.Error("Failed to get configuration for server %d: %v", server.ID, err) @@ -644,17 +499,14 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error { logging.Error("Failed to delete firewall rules: %v", err) } - // Uninstall server files if err := s.steamService.UninstallServer(server.Path); err != nil { logging.Error("Failed to uninstall server: %v", err) } - // Remove from database if err := s.repository.Delete(ctx.UserContext(), serverID); err != nil { return fmt.Errorf("failed to delete server from database: %v", err) } - // Cleanup runtime resources if tailer, exists := s.logTailers.Load(server.ID); exists { tailer.(*tracking.LogTailer).Stop() s.logTailers.Delete(server.ID) @@ -664,84 +516,27 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error { s.debouncers.Delete(server.ID) s.sessionIDs.Delete(server.ID) - // Invalidate status cache for deleted server s.apiService.statusCache.InvalidateStatus(server.ServiceName) return nil } -func (s *ServerService) UpdateServer(ctx *fiber.Ctx, server *model.Server) error { - // Validate server configuration - if err := server.Validate(); err != nil { - return err - } - - // Get existing server details - existingServer, err := s.repository.GetByID(ctx.UserContext(), server.ID) - if err != nil { - return fmt.Errorf("failed to get existing server details: %v", err) - } - - // Update server files if path changed - if existingServer.Path != server.Path { - if err := s.steamService.InstallServer(ctx.UserContext(), server.Path, &server.ID); err != nil { - return fmt.Errorf("failed to install server to new location: %v", err) - } - // Clean up old installation - if err := s.steamService.UninstallServer(existingServer.Path); err != nil { - logging.Error("Failed to remove old server installation: %v", err) - } - } - - // Update Windows service if necessary - if existingServer.ServiceName != server.ServiceName || existingServer.Path != server.Path { - execPath := filepath.Join(server.GetServerPath(), "accServer.exe") - serverWorkingDir := server.GetServerPath() - if err := s.windowsService.UpdateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil { - return fmt.Errorf("failed to update Windows service: %v", err) - } - } - - // Update firewall rules if service name changed - if existingServer.ServiceName != server.ServiceName { - if err := s.configureFirewall(server); err != nil { - return fmt.Errorf("failed to update firewall rules: %v", err) - } - // Invalidate cache for old service name - s.apiService.statusCache.InvalidateStatus(existingServer.ServiceName) - } - - // Update database record - if err := s.repository.Update(ctx.UserContext(), server); err != nil { - return fmt.Errorf("failed to update server in database: %v", err) - } - - // Restart server runtime - s.StartAccServerRuntime(server) - - return nil -} - func (s *ServerService) configureFirewall(server *model.Server) error { - // Find available ports for the server ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount) if err != nil { return fmt.Errorf("failed to find available ports: %v", err) } - // Use the first port for both TCP and UDP serverPort := ports[0] tcpPorts := []int{serverPort} udpPorts := []int{serverPort} logging.Info("Configuring firewall for server %d with port %d", server.ID, serverPort) - // Configure firewall rules if err := s.firewallService.UpdateServerRules(server.Name, tcpPorts, udpPorts); err != nil { return fmt.Errorf("failed to configure firewall: %v", err) } - // Update server configuration with the allocated port if err := s.updateServerPort(server, serverPort); err != nil { return fmt.Errorf("failed to update server configuration: %v", err) } @@ -750,7 +545,6 @@ func (s *ServerService) configureFirewall(server *model.Server) error { } func (s *ServerService) updateServerPort(server *model.Server, port int) error { - // Load current configuration config, err := s.configService.GetConfiguration(server) if err != nil { return fmt.Errorf("failed to load server configuration: %v", err) @@ -759,7 +553,6 @@ func (s *ServerService) updateServerPort(server *model.Server, port int) error { config.TcpPort = model.IntString(port) config.UdpPort = model.IntString(port) - // Save the updated configuration if err := s.configService.SaveConfiguration(server, config); err != nil { return fmt.Errorf("failed to save server configuration: %v", err) } diff --git a/local/service/service.go b/local/service/service.go index b368cf3..46df608 100644 --- a/local/service/service.go +++ b/local/service/service.go @@ -7,17 +7,12 @@ import ( "go.uber.org/dig" ) -// InitializeServices -// Initializes Dependency Injection modules for services -// -// Args: -// *dig.Container: Dig Container +// *dig.Container: Dig Container func InitializeServices(c *dig.Container) { logging.Debug("Initializing repositories") repository.InitializeRepositories(c) logging.Debug("Registering services") - // Provide services c.Provide(NewSteamService) c.Provide(NewServerService) c.Provide(NewStateHistoryService) diff --git a/local/service/service_control.go b/local/service/service_control.go index c284959..b1794d5 100644 --- a/local/service/service_control.go +++ b/local/service/service_control.go @@ -24,9 +24,9 @@ func NewServiceControlService(repository *repository.ServiceControlRepository, repository: repository, serverRepository: serverRepository, statusCache: model.NewServerStatusCache(model.CacheConfig{ - ExpirationTime: 30 * time.Second, // Cache expires after 30 seconds - ThrottleTime: 5 * time.Second, // Minimum 5 seconds between checks - DefaultStatus: model.StatusRunning, // Default to running if throttled + ExpirationTime: 30 * time.Second, + ThrottleTime: 5 * time.Second, + DefaultStatus: model.StatusRunning, }), windowsService: NewWindowsService(), } @@ -42,18 +42,15 @@ func (as *ServiceControlService) GetStatus(ctx *fiber.Ctx) (string, error) { return "", err } - // Try to get status from cache if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck { return status.String(), nil } - // If cache miss or expired, check actual status statusStr, err := as.StatusServer(serviceName) if err != nil { return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil @@ -65,7 +62,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri return "", err } - // Update status cache for this service before starting as.statusCache.UpdateStatus(serviceName, model.StatusStarting) _, err = as.StartServer(serviceName) @@ -77,7 +73,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil @@ -89,7 +84,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin return "", err } - // Update status cache for this service before stopping as.statusCache.UpdateStatus(serviceName, model.StatusStopping) _, err = as.StopServer(serviceName) @@ -101,7 +95,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil @@ -113,7 +106,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st return "", err } - // Update status cache for this service before restarting as.statusCache.UpdateStatus(serviceName, model.StatusRestarting) _, err = as.RestartServer(serviceName) @@ -125,7 +117,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil @@ -135,20 +126,16 @@ func (as *ServiceControlService) StatusServer(serviceName string) (string, error return as.windowsService.Status(context.Background(), serviceName) } -// GetCachedStatus gets the cached status for a service name without requiring fiber context func (as *ServiceControlService) GetCachedStatus(serviceName string) (string, error) { - // Try to get status from cache if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck { return status.String(), nil } - // If cache miss or expired, check actual status statusStr, err := as.StatusServer(serviceName) if err != nil { return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil diff --git a/local/service/service_manager.go b/local/service/service_manager.go index 8de9965..13d1cad 100644 --- a/local/service/service_manager.go +++ b/local/service/service_manager.go @@ -6,7 +6,7 @@ import ( ) type ServiceManager struct { - executor *command.CommandExecutor + executor *command.CommandExecutor psExecutor *command.CommandExecutor } @@ -24,17 +24,14 @@ func NewServiceManager() *ServiceManager { } func (s *ServiceManager) ManageService(serviceName, action string) (string, error) { - // Run NSSM command through PowerShell to ensure elevation output, err := s.psExecutor.ExecuteWithOutput("-nologo", "-noprofile", ".\\nssm", action, serviceName) if err != nil { return "", err } - // Clean up output by removing null bytes and trimming whitespace cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", "")) - // Remove \r\n from status strings cleaned = strings.TrimSuffix(cleaned, "\r\n") - + return cleaned, nil } @@ -51,11 +48,9 @@ func (s *ServiceManager) Stop(serviceName string) (string, error) { } func (s *ServiceManager) Restart(serviceName string) (string, error) { - // First stop the service if _, err := s.Stop(serviceName); err != nil { return "", err } - // Then start it again return s.Start(serviceName) -} \ No newline at end of file +} diff --git a/local/service/state_history.go b/local/service/state_history.go index 29f7735..5bdd441 100644 --- a/local/service/state_history.go +++ b/local/service/state_history.go @@ -46,7 +46,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH eg, gCtx := errgroup.WithContext(ctx.UserContext()) - // Get Summary Stats (Peak/Avg Players, Total Sessions) eg.Go(func() error { summary, err := s.repository.GetSummaryStats(gCtx, filter) if err != nil { @@ -61,7 +60,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Total Playtime eg.Go(func() error { playtime, err := s.repository.GetTotalPlaytime(gCtx, filter) if err != nil { @@ -74,7 +72,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Player Count Over Time eg.Go(func() error { playerCount, err := s.repository.GetPlayerCountOverTime(gCtx, filter) if err != nil { @@ -87,7 +84,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Session Types eg.Go(func() error { sessionTypes, err := s.repository.GetSessionTypes(gCtx, filter) if err != nil { @@ -100,7 +96,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Daily Activity eg.Go(func() error { dailyActivity, err := s.repository.GetDailyActivity(gCtx, filter) if err != nil { @@ -113,7 +108,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Recent Sessions eg.Go(func() error { recentSessions, err := s.repository.GetRecentSessions(gCtx, filter) if err != nil { diff --git a/local/service/steam_service.go b/local/service/steam_service.go index ba0ea83..f6ef459 100644 --- a/local/service/steam_service.go +++ b/local/service/steam_service.go @@ -22,12 +22,11 @@ const ( ) type SteamService struct { - executor *command.CommandExecutor - interactiveExecutor *command.InteractiveCommandExecutor - repository *repository.SteamCredentialsRepository - tfaManager *model.Steam2FAManager - pathValidator *security.PathValidator - downloadVerifier *security.DownloadVerifier + executor *command.CommandExecutor + repository *repository.SteamCredentialsRepository + tfaManager *model.Steam2FAManager + pathValidator *security.PathValidator + downloadVerifier *security.DownloadVerifier } func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManager *model.Steam2FAManager) *SteamService { @@ -37,12 +36,11 @@ func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManag } return &SteamService{ - executor: baseExecutor, - interactiveExecutor: command.NewInteractiveCommandExecutor(baseExecutor, tfaManager), - repository: repository, - tfaManager: tfaManager, - pathValidator: security.NewPathValidator(), - downloadVerifier: security.NewDownloadVerifier(), + executor: baseExecutor, + repository: repository, + tfaManager: tfaManager, + pathValidator: security.NewPathValidator(), + downloadVerifier: security.NewDownloadVerifier(), } } @@ -58,21 +56,17 @@ func (s *SteamService) SaveCredentials(ctx context.Context, creds *model.SteamCr } func (s *SteamService) ensureSteamCMD(_ context.Context) error { - // Get SteamCMD path from environment variable steamCMDPath := env.GetSteamCMDPath() steamCMDDir := filepath.Dir(steamCMDPath) - // Check if SteamCMD exists if _, err := os.Stat(steamCMDPath); !os.IsNotExist(err) { return nil } - // Create directory if it doesn't exist if err := os.MkdirAll(steamCMDDir, 0755); err != nil { return fmt.Errorf("failed to create SteamCMD directory: %v", err) } - // Download and install SteamCMD securely logging.Info("Downloading SteamCMD...") steamCMDZip := filepath.Join(steamCMDDir, "steamcmd.zip") if err := s.downloadVerifier.VerifyAndDownload( @@ -82,150 +76,27 @@ func (s *SteamService) ensureSteamCMD(_ context.Context) error { return fmt.Errorf("failed to download SteamCMD: %v", err) } - // Extract SteamCMD logging.Info("Extracting SteamCMD...") if err := s.executor.Execute("-Command", fmt.Sprintf("Expand-Archive -Path 'steamcmd.zip' -DestinationPath '%s'", steamCMDDir)); err != nil { return fmt.Errorf("failed to extract SteamCMD: %v", err) } - // Clean up zip file os.Remove("steamcmd.zip") return nil } -func (s *SteamService) InstallServer(ctx context.Context, installPath string, serverID *uuid.UUID) error { - if err := s.ensureSteamCMD(ctx); err != nil { - return err - } - - // Validate installation path for security - if err := s.pathValidator.ValidateInstallPath(installPath); err != nil { - return fmt.Errorf("invalid installation path: %v", err) - } - - // Convert to absolute path and ensure proper Windows path format - absPath, err := filepath.Abs(installPath) - if err != nil { - return fmt.Errorf("failed to get absolute path: %v", err) - } - absPath = filepath.Clean(absPath) - - // Ensure install path exists - if err := os.MkdirAll(absPath, 0755); err != nil { - return fmt.Errorf("failed to create install directory: %v", err) - } - - // Get Steam credentials - creds, err := s.GetCredentials(ctx) - if err != nil { - return fmt.Errorf("failed to get Steam credentials: %v", err) - } - - // Get SteamCMD path from environment variable - steamCMDPath := env.GetSteamCMDPath() - - // Build SteamCMD command arguments - steamCMDArgs := []string{ - "+force_install_dir", absPath, - "+login", - } - - if creds != nil && creds.Username != "" { - logging.Info("Using Steam credentials for user: %s", creds.Username) - steamCMDArgs = append(steamCMDArgs, creds.Username) - if creds.Password != "" { - steamCMDArgs = append(steamCMDArgs, creds.Password) - } - } else { - logging.Info("Using anonymous Steam login") - steamCMDArgs = append(steamCMDArgs, "anonymous") - } - - steamCMDArgs = append(steamCMDArgs, - "+app_update", ACCServerAppID, - "validate", - "+quit", - ) - - // Execute SteamCMD directly without PowerShell wrapper to get better output capture - args := steamCMDArgs - - // Use interactive executor to handle potential 2FA prompts with timeout - logging.Info("Installing ACC server to %s...", absPath) - logging.Info("SteamCMD command: %s %s", steamCMDPath, strings.Join(args, " ")) - - // Create a context with timeout to prevent hanging indefinitely - timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) // Increased timeout - defer cancel() - - // Update the executor to use SteamCMD directly - originalExePath := s.interactiveExecutor.ExePath - s.interactiveExecutor.ExePath = steamCMDPath - defer func() { - s.interactiveExecutor.ExePath = originalExePath - }() - - if err := s.interactiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil { - logging.Error("SteamCMD execution failed: %v", err) - if timeoutCtx.Err() == context.DeadlineExceeded { - return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required") - } - return fmt.Errorf("failed to run SteamCMD: %v", err) - } - - logging.Info("SteamCMD execution completed successfully, proceeding with verification...") - - // Add a delay to allow Steam to properly cleanup - logging.Info("Waiting for Steam operations to complete...") - time.Sleep(5 * time.Second) - - // Verify installation - exePath := filepath.Join(absPath, "server", "accServer.exe") - logging.Info("Checking for ACC server executable at: %s", exePath) - - if _, err := os.Stat(exePath); os.IsNotExist(err) { - // Log directory contents to help debug - logging.Info("accServer.exe not found, checking directory contents...") - if entries, dirErr := os.ReadDir(absPath); dirErr == nil { - logging.Info("Contents of %s:", absPath) - for _, entry := range entries { - logging.Info(" - %s (dir: %v)", entry.Name(), entry.IsDir()) - } - } - - // Check if there's a server subdirectory - serverDir := filepath.Join(absPath, "server") - if entries, dirErr := os.ReadDir(serverDir); dirErr == nil { - logging.Info("Contents of %s:", serverDir) - for _, entry := range entries { - logging.Info(" - %s (dir: %v)", entry.Name(), entry.IsDir()) - } - } else { - logging.Info("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr) - } - - return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath) - } - - logging.Info("Server installation completed successfully - accServer.exe found at %s", exePath) - return nil -} - -// InstallServerWithWebSocket installs a server with WebSocket output streaming func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPath string, serverID *uuid.UUID, wsService *WebSocketService) error { if err := s.ensureSteamCMD(ctx); err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true) return err } - // Validate installation path for security if err := s.pathValidator.ValidateInstallPath(installPath); err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true) return fmt.Errorf("invalid installation path: %v", err) } - // Convert to absolute path and ensure proper Windows path format absPath, err := filepath.Abs(installPath) if err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true) @@ -233,7 +104,6 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa } absPath = filepath.Clean(absPath) - // Ensure install path exists if err := os.MkdirAll(absPath, 0755); err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true) return fmt.Errorf("failed to create install directory: %v", err) @@ -241,17 +111,14 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false) - // Get Steam credentials creds, err := s.GetCredentials(ctx) if err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true) return fmt.Errorf("failed to get Steam credentials: %v", err) } - // Get SteamCMD path from environment variable steamCMDPath := env.GetSteamCMDPath() - // Build SteamCMD command arguments steamCMDArgs := []string{ "+force_install_dir", absPath, "+login", @@ -274,27 +141,32 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa "+quit", ) - // Execute SteamCMD with WebSocket output streaming args := steamCMDArgs wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false) - // Create a context with timeout to prevent hanging indefinitely timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) defer cancel() - // Update the executor to use SteamCMD directly - originalExePath := s.interactiveExecutor.ExePath - s.interactiveExecutor.ExePath = steamCMDPath - defer func() { - s.interactiveExecutor.ExePath = originalExePath - }() + callbackConfig := &command.CallbackConfig{ + OnOutput: func(serverID uuid.UUID, output string, isError bool) { + wsService.BroadcastSteamOutput(serverID, output, isError) + }, + OnCommand: func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string) { + if completed { + if success { + wsService.BroadcastSteamOutput(serverID, "Command completed successfully", false) + } else { + wsService.BroadcastSteamOutput(serverID, fmt.Sprintf("Command failed: %s", error), true) + } + } + }, + } - // Create a modified interactive executor that streams output to WebSocket - wsInteractiveExecutor := command.NewInteractiveCommandExecutorWithWebSocket(s.executor, s.tfaManager, wsService, *serverID) - wsInteractiveExecutor.ExePath = steamCMDPath + callbackInteractiveExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbackConfig, *serverID) + callbackInteractiveExecutor.ExePath = steamCMDPath - if err := wsInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil { + if err := callbackInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true) if timeoutCtx.Err() == context.DeadlineExceeded { return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required") @@ -304,11 +176,9 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa wsService.BroadcastSteamOutput(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false) - // Add a delay to allow Steam to properly cleanup wsService.BroadcastSteamOutput(*serverID, "Waiting for Steam operations to complete...", false) time.Sleep(5 * time.Second) - // Verify installation exePath := filepath.Join(absPath, "server", "accServer.exe") wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false) @@ -322,7 +192,6 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa } } - // Check if there's a server subdirectory serverDir := filepath.Join(absPath, "server") if entries, dirErr := os.ReadDir(serverDir); dirErr == nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false) @@ -341,8 +210,117 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa return nil } -func (s *SteamService) UpdateServer(ctx context.Context, installPath string, serverID *uuid.UUID) error { - return s.InstallServer(ctx, installPath, serverID) // Same process as install +func (s *SteamService) InstallServerWithCallbacks(ctx context.Context, installPath string, serverID *uuid.UUID, outputCallback command.OutputCallback) error { + if err := s.ensureSteamCMD(ctx); err != nil { + outputCallback(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true) + return err + } + + if err := s.pathValidator.ValidateInstallPath(installPath); err != nil { + outputCallback(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true) + return fmt.Errorf("invalid installation path: %v", err) + } + + absPath, err := filepath.Abs(installPath) + if err != nil { + outputCallback(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true) + return fmt.Errorf("failed to get absolute path: %v", err) + } + absPath = filepath.Clean(absPath) + + if err := os.MkdirAll(absPath, 0755); err != nil { + outputCallback(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true) + return fmt.Errorf("failed to create install directory: %v", err) + } + + outputCallback(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false) + + creds, err := s.GetCredentials(ctx) + if err != nil { + outputCallback(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true) + return fmt.Errorf("failed to get Steam credentials: %v", err) + } + + steamCMDPath := env.GetSteamCMDPath() + + steamCMDArgs := []string{ + "+force_install_dir", absPath, + "+login", + } + + if creds != nil && creds.Username != "" { + outputCallback(*serverID, fmt.Sprintf("Using Steam credentials for user: %s", creds.Username), false) + steamCMDArgs = append(steamCMDArgs, creds.Username) + if creds.Password != "" { + steamCMDArgs = append(steamCMDArgs, creds.Password) + } + } else { + outputCallback(*serverID, "Using anonymous Steam login", false) + steamCMDArgs = append(steamCMDArgs, "anonymous") + } + + steamCMDArgs = append(steamCMDArgs, + "+app_update", ACCServerAppID, + "validate", + "+quit", + ) + + args := steamCMDArgs + + outputCallback(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false) + + timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) + defer cancel() + + callbacks := &command.CallbackConfig{ + OnOutput: outputCallback, + } + + callbackExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbacks, *serverID) + callbackExecutor.ExePath = steamCMDPath + + if err := callbackExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil { + outputCallback(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true) + if timeoutCtx.Err() == context.DeadlineExceeded { + return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required") + } + return fmt.Errorf("failed to run SteamCMD: %v", err) + } + + outputCallback(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false) + + outputCallback(*serverID, "Waiting for Steam operations to complete...", false) + time.Sleep(5 * time.Second) + + exePath := filepath.Join(absPath, "server", "accServer.exe") + outputCallback(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false) + + if _, err := os.Stat(exePath); os.IsNotExist(err) { + outputCallback(*serverID, "accServer.exe not found, checking directory contents...", false) + + if entries, dirErr := os.ReadDir(absPath); dirErr == nil { + outputCallback(*serverID, fmt.Sprintf("Contents of %s:", absPath), false) + for _, entry := range entries { + outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false) + } + } + + serverDir := filepath.Join(absPath, "server") + if entries, dirErr := os.ReadDir(serverDir); dirErr == nil { + outputCallback(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false) + for _, entry := range entries { + outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false) + } + } else { + outputCallback(*serverID, fmt.Sprintf("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr), true) + } + + outputCallback(*serverID, fmt.Sprintf("Server installation failed: accServer.exe not found in %s", exePath), true) + return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath) + } + + outputCallback(*serverID, fmt.Sprintf("Server installation completed successfully - accServer.exe found at %s", exePath), false) + return nil } func (s *SteamService) UninstallServer(installPath string) error { diff --git a/local/service/websocket.go b/local/service/websocket.go index 8e7baca..9ecacca 100644 --- a/local/service/websocket.go +++ b/local/service/websocket.go @@ -11,25 +11,21 @@ import ( "github.com/google/uuid" ) -// WebSocketConnection represents a single WebSocket connection type WebSocketConnection struct { conn *websocket.Conn - serverID *uuid.UUID // If connected to a specific server creation process - userID *uuid.UUID // User who owns this connection + serverID *uuid.UUID + userID *uuid.UUID } -// WebSocketService manages WebSocket connections and message broadcasting type WebSocketService struct { - connections sync.Map // map[string]*WebSocketConnection - key is connection ID + connections sync.Map mu sync.RWMutex } -// NewWebSocketService creates a new WebSocket service func NewWebSocketService() *WebSocketService { return &WebSocketService{} } -// AddConnection adds a new WebSocket connection func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) { wsConn := &WebSocketConnection{ conn: conn, @@ -39,7 +35,6 @@ func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, u logging.Info("WebSocket connection added: %s for user: %v", connID, userID) } -// RemoveConnection removes a WebSocket connection func (ws *WebSocketService) RemoveConnection(connID string) { if conn, exists := ws.connections.LoadAndDelete(connID); exists { if wsConn, ok := conn.(*WebSocketConnection); ok { @@ -49,7 +44,6 @@ func (ws *WebSocketService) RemoveConnection(connID string) { logging.Info("WebSocket connection removed: %s", connID) } -// SetServerID associates a connection with a specific server creation process func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) { if conn, exists := ws.connections.Load(connID); exists { if wsConn, ok := conn.(*WebSocketConnection); ok { @@ -58,7 +52,6 @@ func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) { } } -// BroadcastStep sends a step update to all connections associated with a server func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) { stepMsg := model.StepMessage{ Step: step, @@ -77,7 +70,6 @@ func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerC ws.broadcastToServer(serverID, wsMsg) } -// BroadcastSteamOutput sends Steam command output to all connections associated with a server func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) { steamMsg := model.SteamOutputMessage{ Output: output, @@ -94,7 +86,6 @@ func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output stri ws.broadcastToServer(serverID, wsMsg) } -// BroadcastError sends an error message to all connections associated with a server func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) { errorMsg := model.ErrorMessage{ Error: error, @@ -111,7 +102,6 @@ func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, det ws.broadcastToServer(serverID, wsMsg) } -// BroadcastComplete sends a completion message to all connections associated with a server func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) { completeMsg := model.CompleteMessage{ ServerID: serverID, @@ -129,7 +119,6 @@ func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, ws.broadcastToServer(serverID, wsMsg) } -// broadcastToServer sends a message to all connections associated with a specific server func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) { data, err := json.Marshal(message) if err != nil { @@ -137,22 +126,35 @@ func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model. return } + sentToAssociatedConnections := false + ws.connections.Range(func(key, value interface{}) bool { if wsConn, ok := value.(*WebSocketConnection); ok { - // Send to connections associated with this server if wsConn.serverID != nil && *wsConn.serverID == serverID { if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) - // Remove the connection if it's broken ws.RemoveConnection(key.(string)) + } else { + sentToAssociatedConnections = true } } } return true }) + + if !sentToAssociatedConnections && (message.Type == model.MessageTypeStep || message.Type == model.MessageTypeError || message.Type == model.MessageTypeComplete) { + ws.connections.Range(func(key, value interface{}) bool { + if wsConn, ok := value.(*WebSocketConnection); ok { + if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { + logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) + ws.RemoveConnection(key.(string)) + } + } + return true + }) + } } -// BroadcastToUser sends a message to all connections owned by a specific user func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) { data, err := json.Marshal(message) if err != nil { @@ -162,11 +164,9 @@ func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebS ws.connections.Range(func(key, value interface{}) bool { if wsConn, ok := value.(*WebSocketConnection); ok { - // Send to connections owned by this user if wsConn.userID != nil && *wsConn.userID == userID { if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) - // Remove the connection if it's broken ws.RemoveConnection(key.(string)) } } @@ -175,7 +175,6 @@ func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebS }) } -// GetActiveConnections returns the count of active connections func (ws *WebSocketService) GetActiveConnections() int { count := 0 ws.connections.Range(func(key, value interface{}) bool { diff --git a/local/service/windows_service.go b/local/service/windows_service.go index c9324e8..32a67c2 100644 --- a/local/service/windows_service.go +++ b/local/service/windows_service.go @@ -23,34 +23,25 @@ func NewWindowsService() *WindowsService { } } -// executeNSSM runs an NSSM command through PowerShell with elevation func (s *WindowsService) ExecuteNSSM(ctx context.Context, args ...string) (string, error) { - // Get NSSM path from environment variable nssmPath := env.GetNSSMPath() - // Prepend NSSM path to arguments nssmArgs := append([]string{"-NoProfile", "-NonInteractive", "-Command", "& " + nssmPath}, args...) output, err := s.executor.ExecuteWithOutput(nssmArgs...) if err != nil { - // Log the full command and error for debugging logging.Error("NSSM command failed: powershell %s", strings.Join(nssmArgs, " ")) logging.Error("NSSM error output: %s", output) return "", err } - // Clean up output by removing null bytes and trimming whitespace cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", "")) - // Remove \r\n from status strings cleaned = strings.TrimSuffix(cleaned, "\r\n") return cleaned, nil } -// Service Installation/Configuration Methods - func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error { - // Ensure paths are absolute and properly formatted for Windows absExecPath, err := filepath.Abs(execPath) if err != nil { return fmt.Errorf("failed to get absolute path for executable: %v", err) @@ -63,30 +54,24 @@ func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPat } absWorkingDir = filepath.Clean(absWorkingDir) - // Log the paths being used logging.Info("Creating service '%s' with:", serviceName) logging.Info(" Executable: %s", absExecPath) logging.Info(" Working Directory: %s", absWorkingDir) - // First remove any existing service with the same name s.ExecuteNSSM(ctx, "remove", serviceName, "confirm") - // Install service if _, err := s.ExecuteNSSM(ctx, "install", serviceName, absExecPath); err != nil { return fmt.Errorf("failed to install service: %v", err) } - // Set arguments if provided if len(args) > 0 { cmdArgs := append([]string{"set", serviceName, "AppParameters"}, args...) if _, err := s.ExecuteNSSM(ctx, cmdArgs...); err != nil { - // Try to clean up on failure s.ExecuteNSSM(ctx, "remove", serviceName, "confirm") return fmt.Errorf("failed to set arguments: %v", err) } } - // Verify service was created if _, err := s.ExecuteNSSM(ctx, "get", serviceName, "Application"); err != nil { return fmt.Errorf("service creation verification failed: %v", err) } @@ -105,17 +90,13 @@ func (s *WindowsService) DeleteService(ctx context.Context, serviceName string) } func (s *WindowsService) UpdateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error { - // First remove the existing service if err := s.DeleteService(ctx, serviceName); err != nil { return err } - // Then create it again with new parameters return s.CreateService(ctx, serviceName, execPath, workingDir, args) } -// Service Control Methods - func (s *WindowsService) Status(ctx context.Context, serviceName string) (string, error) { return s.ExecuteNSSM(ctx, "status", serviceName) } @@ -129,11 +110,9 @@ func (s *WindowsService) Stop(ctx context.Context, serviceName string) (string, } func (s *WindowsService) Restart(ctx context.Context, serviceName string) (string, error) { - // First stop the service if _, err := s.Stop(ctx, serviceName); err != nil { return "", err } - // Then start it again return s.Start(ctx, serviceName) } diff --git a/local/utl/cache/cache.go b/local/utl/cache/cache.go index 082ecac..5cf47e5 100644 --- a/local/utl/cache/cache.go +++ b/local/utl/cache/cache.go @@ -9,26 +9,22 @@ import ( "go.uber.org/dig" ) -// CacheItem represents an item in the cache type CacheItem struct { Value interface{} Expiration int64 } -// InMemoryCache is a thread-safe in-memory cache type InMemoryCache struct { items map[string]CacheItem mu sync.RWMutex } -// NewInMemoryCache creates and returns a new InMemoryCache instance func NewInMemoryCache() *InMemoryCache { return &InMemoryCache{ items: make(map[string]CacheItem), } } -// Set adds an item to the cache with an expiration duration (in seconds) func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duration) { c.mu.Lock() defer c.mu.Unlock() @@ -44,7 +40,6 @@ func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duratio } } -// Get retrieves an item from the cache func (c *InMemoryCache) Get(key string) (interface{}, bool) { c.mu.RLock() defer c.mu.RUnlock() @@ -55,24 +50,18 @@ func (c *InMemoryCache) Get(key string) (interface{}, bool) { } if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration { - // Item has expired, but don't delete here to avoid lock upgrade. - // It will be overwritten on the next Set. return nil, false } return item.Value, true } -// Delete removes an item from the cache func (c *InMemoryCache) Delete(key string) { c.mu.Lock() defer c.mu.Unlock() delete(c.items, key) } -// GetOrSet retrieves an item from the cache. If the item is not found, it -// calls the provided function to get the value, sets it in the cache, and -// returns it. func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetcher func() (T, error)) (T, error) { if cached, found := c.Get(key); found { if value, ok := cached.(T); ok { @@ -90,7 +79,6 @@ func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetch return value, nil } -// Start initializes the cache and provides it to the DI container. func Start(di *dig.Container) { cache := NewInMemoryCache() err := di.Provide(func() *InMemoryCache { diff --git a/local/utl/command/callback_executor.go b/local/utl/command/callback_executor.go new file mode 100644 index 0000000..2125317 --- /dev/null +++ b/local/utl/command/callback_executor.go @@ -0,0 +1,245 @@ +package command + +import ( + "acc-server-manager/local/model" + "acc-server-manager/local/utl/logging" + "bufio" + "context" + "fmt" + "io" + "os/exec" + "strings" + "time" + + "github.com/google/uuid" +) + +type CallbackInteractiveCommandExecutor struct { + *InteractiveCommandExecutor + callbacks *CallbackConfig + serverID uuid.UUID +} + +func NewCallbackInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, callbacks *CallbackConfig, serverID uuid.UUID) *CallbackInteractiveCommandExecutor { + if callbacks == nil { + callbacks = DefaultCallbackConfig() + } + + return &CallbackInteractiveCommandExecutor{ + InteractiveCommandExecutor: &InteractiveCommandExecutor{ + CommandExecutor: baseExecutor, + tfaManager: tfaManager, + }, + callbacks: callbacks, + serverID: serverID, + } +} + +func (e *CallbackInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error { + cmd := exec.CommandContext(ctx, e.ExePath, args...) + + if e.WorkDir != "" { + cmd.Dir = e.WorkDir + } + + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to create stdin pipe: %v", err) + } + defer stdin.Close() + + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %v", err) + } + defer stdout.Close() + + stderr, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %v", err) + } + defer stderr.Close() + + logging.Info("Executing interactive command with callbacks: %s %s", e.ExePath, strings.Join(args, " ")) + + e.callbacks.OnCommand(e.serverID, e.ExePath, args, false, false, "") + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false) + + if err := cmd.Start(); err != nil { + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to start command: %v", err), true) + return fmt.Errorf("failed to start command: %v", err) + } + + outputDone := make(chan error, 1) + cmdDone := make(chan error, 1) + + go e.monitorOutputWithCallbacks(ctx, stdout, stderr, serverID, outputDone) + + go func() { + cmdDone <- cmd.Wait() + }() + + var cmdErr, outputErr error + completedCount := 0 + + for completedCount < 2 { + select { + case cmdErr = <-cmdDone: + completedCount++ + logging.Info("Command execution completed") + e.callbacks.OnOutput(e.serverID, "Command execution completed", false) + case outputErr = <-outputDone: + completedCount++ + logging.Info("Output monitoring completed") + case <-ctx.Done(): + e.callbacks.OnOutput(e.serverID, "Command execution cancelled", true) + return ctx.Err() + } + } + + if outputErr != nil { + logging.Warn("Output monitoring error: %v", outputErr) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Output monitoring error: %v", outputErr), true) + } + + success := cmdErr == nil + errorMsg := "" + if cmdErr != nil { + errorMsg = cmdErr.Error() + } + e.callbacks.OnCommand(e.serverID, e.ExePath, args, true, success, errorMsg) + + return cmdErr +} + +func (e *CallbackInteractiveCommandExecutor) monitorOutputWithCallbacks(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) { + defer func() { + select { + case done <- nil: + default: + } + }() + + stdoutScanner := bufio.NewScanner(stdout) + stderrScanner := bufio.NewScanner(stderr) + + outputChan := make(chan outputLine, 100) + readersDone := make(chan struct{}, 2) + + steamConsoleStarted := false + tfaRequestCreated := false + + go func() { + defer func() { readersDone <- struct{}{} }() + for stdoutScanner.Scan() { + line := stdoutScanner.Text() + if e.LogOutput { + logging.Info("STDOUT: %s", line) + } + e.callbacks.OnOutput(e.serverID, line, false) + + select { + case outputChan <- outputLine{text: line, isError: false}: + case <-ctx.Done(): + return + } + } + if err := stdoutScanner.Err(); err != nil { + logging.Warn("Stdout scanner error: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stdout scanner error: %v", err), true) + } + }() + + go func() { + defer func() { readersDone <- struct{}{} }() + for stderrScanner.Scan() { + line := stderrScanner.Text() + if e.LogOutput { + logging.Info("STDERR: %s", line) + } + e.callbacks.OnOutput(e.serverID, line, true) + + select { + case outputChan <- outputLine{text: line, isError: true}: + case <-ctx.Done(): + return + } + } + if err := stderrScanner.Err(); err != nil { + logging.Warn("Stderr scanner error: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stderr scanner error: %v", err), true) + } + }() + + readersFinished := 0 + for { + select { + case <-ctx.Done(): + done <- ctx.Err() + return + case <-readersDone: + readersFinished++ + if readersFinished == 2 { + close(outputChan) + for lineData := range outputChan { + if e.is2FAPrompt(lineData.text) { + if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil { + logging.Error("Failed to handle 2FA prompt: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true) + done <- err + return + } + } + } + return + } + case lineData, ok := <-outputChan: + if !ok { + return + } + + lowerLine := strings.ToLower(lineData.text) + if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") { + steamConsoleStarted = true + logging.Info("Steam Console Client startup detected - will monitor for 2FA hang") + e.callbacks.OnOutput(e.serverID, "Steam Console Client startup detected", false) + } + + if e.is2FAPrompt(lineData.text) { + if !tfaRequestCreated { + e.callbacks.OnOutput(e.serverID, "2FA prompt detected - waiting for user confirmation", false) + if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil { + logging.Error("Failed to handle 2FA prompt: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true) + done <- err + return + } + tfaRequestCreated = true + } + } + + if tfaRequestCreated && e.isSteamContinuing(lineData.text) { + logging.Info("Steam CMD appears to have continued after 2FA confirmation") + e.callbacks.OnOutput(e.serverID, "Steam CMD continued after 2FA confirmation", false) + e.autoCompletePendingRequests(serverID) + } + case <-time.After(15 * time.Second): + if steamConsoleStarted && !tfaRequestCreated { + logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA") + e.callbacks.OnOutput(e.serverID, "Waiting for Steam Guard 2FA confirmation...", false) + if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil { + logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true) + done <- err + return + } + tfaRequestCreated = true + } + } + } +} + +type outputLine struct { + text string + isError bool +} diff --git a/local/utl/command/callbacks.go b/local/utl/command/callbacks.go new file mode 100644 index 0000000..1d1d060 --- /dev/null +++ b/local/utl/command/callbacks.go @@ -0,0 +1,19 @@ +package command + +import "github.com/google/uuid" + +type OutputCallback func(serverID uuid.UUID, output string, isError bool) + +type CommandCallback func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string) + +type CallbackConfig struct { + OnOutput OutputCallback + OnCommand CommandCallback +} + +func DefaultCallbackConfig() *CallbackConfig { + return &CallbackConfig{ + OnOutput: func(uuid.UUID, string, bool) {}, + OnCommand: func(uuid.UUID, string, []string, bool, bool, string) {}, + } +} diff --git a/local/utl/command/executor.go b/local/utl/command/executor.go index 75c0d07..256b5ee 100644 --- a/local/utl/command/executor.go +++ b/local/utl/command/executor.go @@ -8,17 +8,12 @@ import ( "strings" ) -// CommandExecutor provides a base structure for executing commands type CommandExecutor struct { - // Base executable path - ExePath string - // Working directory for commands - WorkDir string - // Whether to capture and log output + ExePath string + WorkDir string LogOutput bool } -// CommandBuilder helps build command arguments type CommandBuilder struct { args []string } @@ -48,10 +43,9 @@ func (b *CommandBuilder) Build() []string { return b.args } -// Execute runs a command with the given arguments func (e *CommandExecutor) Execute(args ...string) error { cmd := exec.Command(e.ExePath, args...) - + if e.WorkDir != "" { cmd.Dir = e.WorkDir } @@ -65,15 +59,13 @@ func (e *CommandExecutor) Execute(args ...string) error { return cmd.Run() } -// ExecuteWithBuilder runs a command using a CommandBuilder func (e *CommandExecutor) ExecuteWithBuilder(builder *CommandBuilder) error { return e.Execute(builder.Build()...) } -// ExecuteWithOutput runs a command and returns its output func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) { cmd := exec.Command(e.ExePath, args...) - + if e.WorkDir != "" { cmd.Dir = e.WorkDir } @@ -83,10 +75,9 @@ func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) { return string(output), err } -// ExecuteWithEnv runs a command with custom environment variables func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error { cmd := exec.Command(e.ExePath, args...) - + if e.WorkDir != "" { cmd.Dir = e.WorkDir } @@ -100,4 +91,4 @@ func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error { logging.Info("Executing command: %s %s", e.ExePath, strings.Join(args, " ")) return cmd.Run() -} \ No newline at end of file +} diff --git a/local/utl/command/interactive_executor.go b/local/utl/command/interactive_executor.go index 5b90ba2..1e6ff47 100644 --- a/local/utl/command/interactive_executor.go +++ b/local/utl/command/interactive_executor.go @@ -9,14 +9,12 @@ import ( "io" "os" "os/exec" - "reflect" "strings" "time" "github.com/google/uuid" ) -// InteractiveCommandExecutor extends CommandExecutor to handle interactive commands type InteractiveCommandExecutor struct { *CommandExecutor tfaManager *model.Steam2FAManager @@ -29,7 +27,6 @@ func NewInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *mo } } -// ExecuteInteractive runs a command that may require 2FA input func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error { cmd := exec.CommandContext(ctx, e.ExePath, args...) @@ -37,7 +34,6 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser cmd.Dir = e.WorkDir } - // Create pipes for stdin, stdout, and stderr stdin, err := cmd.StdinPipe() if err != nil { return fmt.Errorf("failed to create stdin pipe: %v", err) @@ -58,7 +54,6 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser logging.Info("Executing interactive command: %s %s", e.ExePath, strings.Join(args, " ")) - // Enable debug mode if environment variable is set debugMode := os.Getenv("STEAMCMD_DEBUG") == "true" if debugMode { logging.Info("STEAMCMD_DEBUG mode enabled - will log all output and create proactive 2FA requests") @@ -68,19 +63,15 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser return fmt.Errorf("failed to start command: %v", err) } - // Create channels for output monitoring outputDone := make(chan error, 1) cmdDone := make(chan error, 1) - // Monitor stdout and stderr for 2FA prompts go e.monitorOutput(ctx, stdout, stderr, serverID, outputDone) - // Wait for the command to finish in a separate goroutine go func() { cmdDone <- cmd.Wait() }() - // Wait for both command and output monitoring to complete var cmdErr, outputErr error completedCount := 0 @@ -112,18 +103,15 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } }() - // Create scanners for both outputs stdoutScanner := bufio.NewScanner(stdout) stderrScanner := bufio.NewScanner(stderr) - outputChan := make(chan string, 100) // Buffered channel to prevent blocking + outputChan := make(chan string, 100) readersDone := make(chan struct{}, 2) - // Track Steam Console startup for this specific execution steamConsoleStarted := false tfaRequestCreated := false - // Read from stdout go func() { defer func() { readersDone <- struct{}{} }() for stdoutScanner.Scan() { @@ -131,7 +119,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, if e.LogOutput { logging.Info("STDOUT: %s", line) } - // Always log Steam CMD output for debugging 2FA issues if strings.Contains(strings.ToLower(line), "steam") { logging.Info("STEAM_DEBUG: %s", line) } @@ -146,7 +133,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } }() - // Read from stderr go func() { defer func() { readersDone <- struct{}{} }() for stderrScanner.Scan() { @@ -154,7 +140,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, if e.LogOutput { logging.Info("STDERR: %s", line) } - // Always log Steam CMD errors for debugging 2FA issues if strings.Contains(strings.ToLower(line), "steam") { logging.Info("STEAM_DEBUG_ERR: %s", line) } @@ -169,7 +154,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } }() - // Monitor for completion and 2FA prompts readersFinished := 0 for { select { @@ -179,9 +163,7 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, case <-readersDone: readersFinished++ if readersFinished == 2 { - // Both readers are done, close output channel and finish monitoring close(outputChan) - // Drain any remaining output for line := range outputChan { if e.is2FAPrompt(line) { if err := e.handle2FAPrompt(ctx, line, serverID); err != nil { @@ -195,18 +177,15 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } case line, ok := <-outputChan: if !ok { - // Channel closed, we're done return } - // Check for Steam Console startup lowerLine := strings.ToLower(line) if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") { steamConsoleStarted = true logging.Info("Steam Console Client startup detected - will monitor for 2FA hang") } - // Check if this line indicates a 2FA prompt if e.is2FAPrompt(line) { if !tfaRequestCreated { if err := e.handle2FAPrompt(ctx, line, serverID); err != nil { @@ -218,15 +197,11 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } } - // Check if Steam CMD continued after 2FA (auto-completion) if tfaRequestCreated && e.isSteamContinuing(line) { logging.Info("Steam CMD appears to have continued after 2FA confirmation - auto-completing 2FA request") - // Auto-complete any pending 2FA requests for this server e.autoCompletePendingRequests(serverID) } case <-time.After(15 * time.Second): - // If Steam Console has started and we haven't seen output for 15 seconds, - // it's very likely waiting for 2FA confirmation if steamConsoleStarted && !tfaRequestCreated { logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA") if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil { @@ -243,7 +218,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool { - // Common SteamCMD 2FA prompts - updated with more comprehensive patterns twoFAKeywords := []string{ "please enter your steam guard code", "steam guard", @@ -272,7 +246,6 @@ func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool { } } - // Also check for patterns that might indicate Steam is waiting for input waitingPatterns := []string{ "waiting for", "please enter", @@ -330,12 +303,9 @@ func (e *InteractiveCommandExecutor) autoCompletePendingRequests(serverID *uuid. func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLine string, serverID *uuid.UUID) error { logging.Info("2FA prompt detected: %s", promptLine) - // Create a 2FA request request := e.tfaManager.CreateRequest(promptLine, serverID) logging.Info("Created 2FA request with ID: %s", request.ID) - // Wait for user to complete the 2FA process - // Use a reasonable timeout (e.g., 5 minutes) timeout := 5 * time.Minute success, err := e.tfaManager.WaitForCompletion(request.ID, timeout) @@ -352,271 +322,3 @@ func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLi logging.Info("2FA completed successfully") return nil } - -// WebSocketInteractiveCommandExecutor extends InteractiveCommandExecutor to stream output via WebSocket -type WebSocketInteractiveCommandExecutor struct { - *InteractiveCommandExecutor - wsService interface{} // Using interface{} to avoid circular import - serverID uuid.UUID -} - -// NewInteractiveCommandExecutorWithWebSocket creates a new WebSocket-enabled interactive command executor -func NewInteractiveCommandExecutorWithWebSocket(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, wsService interface{}, serverID uuid.UUID) *WebSocketInteractiveCommandExecutor { - return &WebSocketInteractiveCommandExecutor{ - InteractiveCommandExecutor: &InteractiveCommandExecutor{ - CommandExecutor: baseExecutor, - tfaManager: tfaManager, - }, - wsService: wsService, - serverID: serverID, - } -} - -// ExecuteInteractive runs a command with WebSocket output streaming -func (e *WebSocketInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error { - cmd := exec.CommandContext(ctx, e.ExePath, args...) - - if e.WorkDir != "" { - cmd.Dir = e.WorkDir - } - - // Create pipes for stdin, stdout, and stderr - stdin, err := cmd.StdinPipe() - if err != nil { - return fmt.Errorf("failed to create stdin pipe: %v", err) - } - defer stdin.Close() - - stdout, err := cmd.StdoutPipe() - if err != nil { - return fmt.Errorf("failed to create stdout pipe: %v", err) - } - defer stdout.Close() - - stderr, err := cmd.StderrPipe() - if err != nil { - return fmt.Errorf("failed to create stderr pipe: %v", err) - } - defer stderr.Close() - - logging.Info("Executing interactive command with WebSocket streaming: %s %s", e.ExePath, strings.Join(args, " ")) - - // Broadcast command start via WebSocket - e.broadcastSteamOutput(fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false) - - if err := cmd.Start(); err != nil { - e.broadcastSteamOutput(fmt.Sprintf("Failed to start command: %v", err), true) - return fmt.Errorf("failed to start command: %v", err) - } - - // Create channels for output monitoring - outputDone := make(chan error, 1) - cmdDone := make(chan error, 1) - - // Monitor stdout and stderr for 2FA prompts with WebSocket streaming - go e.monitorOutputWithWebSocket(ctx, stdout, stderr, serverID, outputDone) - - // Wait for the command to finish in a separate goroutine - go func() { - cmdDone <- cmd.Wait() - }() - - // Wait for both command and output monitoring to complete - var cmdErr, outputErr error - completedCount := 0 - - for completedCount < 2 { - select { - case cmdErr = <-cmdDone: - completedCount++ - logging.Info("Command execution completed") - e.broadcastSteamOutput("Command execution completed", false) - case outputErr = <-outputDone: - completedCount++ - logging.Info("Output monitoring completed") - case <-ctx.Done(): - e.broadcastSteamOutput("Command execution cancelled", true) - return ctx.Err() - } - } - - if outputErr != nil { - logging.Warn("Output monitoring error: %v", outputErr) - e.broadcastSteamOutput(fmt.Sprintf("Output monitoring error: %v", outputErr), true) - } - - return cmdErr -} - -// broadcastSteamOutput sends output to WebSocket using reflection to avoid circular imports -func (e *WebSocketInteractiveCommandExecutor) broadcastSteamOutput(output string, isError bool) { - if e.wsService == nil { - return - } - - // Use reflection to call BroadcastSteamOutput method - wsServiceVal := reflect.ValueOf(e.wsService) - method := wsServiceVal.MethodByName("BroadcastSteamOutput") - if !method.IsValid() { - logging.Warn("BroadcastSteamOutput method not found on WebSocket service") - return - } - - // Call the method with parameters: serverID, output, isError - args := []reflect.Value{ - reflect.ValueOf(e.serverID), - reflect.ValueOf(output), - reflect.ValueOf(isError), - } - method.Call(args) -} - -// monitorOutputWithWebSocket monitors command output and streams it via WebSocket -func (e *WebSocketInteractiveCommandExecutor) monitorOutputWithWebSocket(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) { - defer func() { - select { - case done <- nil: - default: - } - }() - - // Create scanners for both outputs - stdoutScanner := bufio.NewScanner(stdout) - stderrScanner := bufio.NewScanner(stderr) - - outputChan := make(chan outputLine, 100) // Buffered channel to prevent blocking - readersDone := make(chan struct{}, 2) - - // Track Steam Console startup for this specific execution - steamConsoleStarted := false - tfaRequestCreated := false - - // Read from stdout - go func() { - defer func() { readersDone <- struct{}{} }() - for stdoutScanner.Scan() { - line := stdoutScanner.Text() - if e.LogOutput { - logging.Info("STDOUT: %s", line) - } - // Stream output via WebSocket - e.broadcastSteamOutput(line, false) - - select { - case outputChan <- outputLine{text: line, isError: false}: - case <-ctx.Done(): - return - } - } - if err := stdoutScanner.Err(); err != nil { - logging.Warn("Stdout scanner error: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Stdout scanner error: %v", err), true) - } - }() - - // Read from stderr - go func() { - defer func() { readersDone <- struct{}{} }() - for stderrScanner.Scan() { - line := stderrScanner.Text() - if e.LogOutput { - logging.Info("STDERR: %s", line) - } - // Stream error output via WebSocket - e.broadcastSteamOutput(line, true) - - select { - case outputChan <- outputLine{text: line, isError: true}: - case <-ctx.Done(): - return - } - } - if err := stderrScanner.Err(); err != nil { - logging.Warn("Stderr scanner error: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Stderr scanner error: %v", err), true) - } - }() - - // Monitor for completion and 2FA prompts - readersFinished := 0 - for { - select { - case <-ctx.Done(): - done <- ctx.Err() - return - case <-readersDone: - readersFinished++ - if readersFinished == 2 { - // Both readers are done, close output channel and finish monitoring - close(outputChan) - // Drain any remaining output - for lineData := range outputChan { - if e.is2FAPrompt(lineData.text) { - if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil { - logging.Error("Failed to handle 2FA prompt: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true) - done <- err - return - } - } - } - return - } - case lineData, ok := <-outputChan: - if !ok { - // Channel closed, we're done - return - } - - // Check for Steam Console startup - lowerLine := strings.ToLower(lineData.text) - if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") { - steamConsoleStarted = true - logging.Info("Steam Console Client startup detected - will monitor for 2FA hang") - e.broadcastSteamOutput("Steam Console Client startup detected", false) - } - - // Check if this line indicates a 2FA prompt - if e.is2FAPrompt(lineData.text) { - if !tfaRequestCreated { - e.broadcastSteamOutput("2FA prompt detected - waiting for user confirmation", false) - if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil { - logging.Error("Failed to handle 2FA prompt: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true) - done <- err - return - } - tfaRequestCreated = true - } - } - - // Check if Steam CMD continued after 2FA (auto-completion) - if tfaRequestCreated && e.isSteamContinuing(lineData.text) { - logging.Info("Steam CMD appears to have continued after 2FA confirmation") - e.broadcastSteamOutput("Steam CMD continued after 2FA confirmation", false) - // Auto-complete any pending 2FA requests for this server - e.autoCompletePendingRequests(serverID) - } - case <-time.After(15 * time.Second): - // If Steam Console has started and we haven't seen output for 15 seconds, - // it's very likely waiting for 2FA confirmation - if steamConsoleStarted && !tfaRequestCreated { - logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA") - e.broadcastSteamOutput("Waiting for Steam Guard 2FA confirmation...", false) - if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil { - logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true) - done <- err - return - } - tfaRequestCreated = true - } - } - } -} - -// outputLine represents a line of output with error status -type outputLine struct { - text string - isError bool -} diff --git a/local/utl/common/common.go b/local/utl/common/common.go index 825e4c0..4757b59 100644 --- a/local/utl/common/common.go +++ b/local/utl/common/common.go @@ -79,12 +79,6 @@ func IndentJson(body []byte) ([]byte, error) { return unmarshaledBody.Bytes(), nil } -// ParseQueryFilter parses query parameters into a filter struct using reflection. -// It supports various field types and uses struct tags to determine parsing behavior. -// Supported tags: -// - `query:"field_name"` - specifies the query parameter name -// - `param:"param_name"` - specifies the path parameter name -// - `time_format:"format"` - specifies the time format for parsing dates (default: RFC3339) func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error { val := reflect.ValueOf(filter) if val.Kind() != reflect.Ptr || val.IsNil() { @@ -94,14 +88,12 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error { elem := val.Elem() typ := elem.Type() - // Process all fields including embedded structs var processFields func(reflect.Value, reflect.Type) error processFields = func(val reflect.Value, typ reflect.Type) error { for i := 0; i < val.NumField(); i++ { field := val.Field(i) fieldType := typ.Field(i) - // Handle embedded structs recursively if fieldType.Anonymous { if err := processFields(field, fieldType.Type); err != nil { return err @@ -109,12 +101,10 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error { continue } - // Skip if field cannot be set if !field.CanSet() { continue } - // Check for param tag first (path parameters) if paramName := fieldType.Tag.Get("param"); paramName != "" { if err := parsePathParam(c, field, paramName); err != nil { return fmt.Errorf("error parsing path parameter %s: %v", paramName, err) @@ -122,15 +112,14 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error { continue } - // Then check for query tag queryName := fieldType.Tag.Get("query") if queryName == "" { - queryName = ToSnakeCase(fieldType.Name) // Default to snake_case of field name + queryName = ToSnakeCase(fieldType.Name) } queryVal := c.Query(queryName) if queryVal == "" { - continue // Skip empty values + continue } if err := parseValue(field, queryVal, fieldType.Tag); err != nil { diff --git a/local/utl/configs/configs.go b/local/utl/configs/configs.go index 0a7b7a2..34074af 100644 --- a/local/utl/configs/configs.go +++ b/local/utl/configs/configs.go @@ -18,7 +18,6 @@ var ( func Init() { godotenv.Load() - // Fail fast if critical environment variables are missing Secret = getEnvRequired("APP_SECRET") SecretCode = getEnvRequired("APP_SECRET_CODE") EncryptionKey = getEnvRequired("ENCRYPTION_KEY") @@ -29,7 +28,6 @@ func Init() { } } -// getEnv retrieves an environment variable or returns a fallback value. func getEnv(key, fallback string) string { if value, exists := os.LookupEnv(key); exists { return value @@ -38,12 +36,10 @@ func getEnv(key, fallback string) string { return fallback } -// getEnvRequired retrieves an environment variable and fails if it's not set. -// This should be used for critical configuration that must not have defaults. func getEnvRequired(key string) string { if value, exists := os.LookupEnv(key); exists && value != "" { return value } log.Fatalf("Required environment variable %s is not set or is empty", key) - return "" // This line will never be reached due to log.Fatalf + return "" } diff --git a/local/utl/db/db.go b/local/utl/db/db.go index 8c45bed..7efa670 100644 --- a/local/utl/db/db.go +++ b/local/utl/db/db.go @@ -33,7 +33,6 @@ func Start(di *dig.Container) { func Migrate(db *gorm.DB) { logging.Info("Migrating database") - // Run GORM AutoMigrate for all models err := db.AutoMigrate( &model.ServiceControlModel{}, &model.Config{}, @@ -52,7 +51,6 @@ func Migrate(db *gorm.DB) { if err != nil { logging.Error("GORM AutoMigrate failed: %v", err) - // Don't panic, just log the error as custom migrations may have handled this } db.FirstOrCreate(&model.ServiceControlModel{ServiceControl: "Works"}) @@ -63,10 +61,8 @@ func Migrate(db *gorm.DB) { func runMigrations(db *gorm.DB) { logging.Info("Running custom database migrations...") - // Migration 001: Password security upgrade if err := migrations.RunPasswordSecurityMigration(db); err != nil { logging.Error("Failed to run password security migration: %v", err) - // Continue - this migration might not be needed for all setups } logging.Info("Custom database migrations completed") @@ -132,7 +128,6 @@ func seedCarModels(db *gorm.DB) error { carModels := []model.CarModel{ {Value: 0, CarModel: "Porsche 991 GT3 R"}, {Value: 1, CarModel: "Mercedes-AMG GT3"}, - // ... Add all car models from your list } for _, cm := range carModels { diff --git a/local/utl/env/env.go b/local/utl/env/env.go index 2a7c4b7..0f41b84 100644 --- a/local/utl/env/env.go +++ b/local/utl/env/env.go @@ -6,12 +6,10 @@ import ( ) const ( - // Default paths for when environment variables are not set DefaultSteamCMDPath = "c:\\steamcmd\\steamcmd.exe" DefaultNSSMPath = ".\\nssm.exe" ) -// GetSteamCMDPath returns the SteamCMD executable path from environment variable or default func GetSteamCMDPath() string { if path := os.Getenv("STEAMCMD_PATH"); path != "" { return path @@ -19,13 +17,11 @@ func GetSteamCMDPath() string { return DefaultSteamCMDPath } -// GetSteamCMDDirPath returns the directory containing SteamCMD executable func GetSteamCMDDirPath() string { steamCMDPath := GetSteamCMDPath() return filepath.Dir(steamCMDPath) } -// GetNSSMPath returns the NSSM executable path from environment variable or default func GetNSSMPath() string { if path := os.Getenv("NSSM_PATH"); path != "" { return path @@ -33,17 +29,14 @@ func GetNSSMPath() string { return DefaultNSSMPath } -// ValidatePaths checks if the configured paths exist (optional validation) func ValidatePaths() map[string]error { errors := make(map[string]error) - // Check SteamCMD path steamCMDPath := GetSteamCMDPath() if _, err := os.Stat(steamCMDPath); os.IsNotExist(err) { errors["STEAMCMD_PATH"] = err } - // Check NSSM path nssmPath := GetNSSMPath() if _, err := os.Stat(nssmPath); os.IsNotExist(err) { errors["NSSM_PATH"] = err diff --git a/local/utl/error_handler/controller_error_handler.go b/local/utl/error_handler/controller_error_handler.go index 1d4408a..067371f 100644 --- a/local/utl/error_handler/controller_error_handler.go +++ b/local/utl/error_handler/controller_error_handler.go @@ -9,45 +9,37 @@ import ( "github.com/gofiber/fiber/v2" ) -// ControllerErrorHandler provides centralized error handling for controllers type ControllerErrorHandler struct { errorLogger *logging.ErrorLogger } -// NewControllerErrorHandler creates a new controller error handler instance func NewControllerErrorHandler() *ControllerErrorHandler { return &ControllerErrorHandler{ errorLogger: logging.GetErrorLogger(), } } -// ErrorResponse represents a standardized error response type ErrorResponse struct { Error string `json:"error"` Code int `json:"code,omitempty"` Details map[string]string `json:"details,omitempty"` } -// HandleError handles controller errors with logging and standardized responses func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error { if err == nil { return nil } - // Get caller information for logging _, file, line, _ := runtime.Caller(1) file = strings.TrimPrefix(file, "acc-server-manager/") - // Build context string contextStr := "" if len(context) > 0 { contextStr = fmt.Sprintf("[%s] ", strings.Join(context, "|")) } - // Clean error message (remove null bytes) cleanErrorMsg := strings.ReplaceAll(err.Error(), "\x00", "") - // Log the error with context ceh.errorLogger.LogWithContext( fmt.Sprintf("CONTROLLER_ERROR [%s:%d]", file, line), "%s%s", @@ -55,31 +47,26 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo cleanErrorMsg, ) - // Create standardized error response errorResponse := ErrorResponse{ Error: cleanErrorMsg, Code: statusCode, } - // Add request details if available if c != nil { if errorResponse.Details == nil { errorResponse.Details = make(map[string]string) } - - // Safely extract request details + func() { defer func() { if r := recover(); r != nil { - // If any of these panic, just skip adding the details return } }() - + errorResponse.Details["method"] = c.Method() errorResponse.Details["path"] = c.Path() - - // Safely get IP address + if ip := c.IP(); ip != "" { errorResponse.Details["ip"] = ip } else { @@ -88,14 +75,11 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo }() } - // Return appropriate response based on status code if c == nil { - // If context is nil, we can't return a response return fmt.Errorf("cannot return HTTP response: context is nil") } - + if statusCode >= 500 { - // For server errors, don't expose internal details return c.Status(statusCode).JSON(ErrorResponse{ Error: "Internal server error", Code: statusCode, @@ -105,52 +89,42 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo return c.Status(statusCode).JSON(errorResponse) } -// HandleValidationError handles validation errors specifically func (ceh *ControllerErrorHandler) HandleValidationError(c *fiber.Ctx, err error, field string) error { return ceh.HandleError(c, err, fiber.StatusBadRequest, "VALIDATION", field) } -// HandleDatabaseError handles database-related errors func (ceh *ControllerErrorHandler) HandleDatabaseError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusInternalServerError, "DATABASE") } -// HandleAuthError handles authentication/authorization errors func (ceh *ControllerErrorHandler) HandleAuthError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusUnauthorized, "AUTH") } -// HandleNotFoundError handles resource not found errors func (ceh *ControllerErrorHandler) HandleNotFoundError(c *fiber.Ctx, resource string) error { err := fmt.Errorf("%s not found", resource) return ceh.HandleError(c, err, fiber.StatusNotFound, "NOT_FOUND") } -// HandleBusinessLogicError handles business logic errors func (ceh *ControllerErrorHandler) HandleBusinessLogicError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusBadRequest, "BUSINESS_LOGIC") } -// HandleServiceError handles service layer errors func (ceh *ControllerErrorHandler) HandleServiceError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusInternalServerError, "SERVICE") } -// HandleParsingError handles request parsing errors func (ceh *ControllerErrorHandler) HandleParsingError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusBadRequest, "PARSING") } -// HandleUUIDError handles UUID parsing errors func (ceh *ControllerErrorHandler) HandleUUIDError(c *fiber.Ctx, field string) error { err := fmt.Errorf("invalid %s format", field) return ceh.HandleError(c, err, fiber.StatusBadRequest, "UUID_VALIDATION", field) } -// Global controller error handler instance var globalErrorHandler *ControllerErrorHandler -// GetControllerErrorHandler returns the global controller error handler instance func GetControllerErrorHandler() *ControllerErrorHandler { if globalErrorHandler == nil { globalErrorHandler = NewControllerErrorHandler() @@ -158,49 +132,38 @@ func GetControllerErrorHandler() *ControllerErrorHandler { return globalErrorHandler } -// Convenience functions using the global error handler - -// HandleError handles controller errors using the global error handler func HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error { return GetControllerErrorHandler().HandleError(c, err, statusCode, context...) } -// HandleValidationError handles validation errors using the global error handler func HandleValidationError(c *fiber.Ctx, err error, field string) error { return GetControllerErrorHandler().HandleValidationError(c, err, field) } -// HandleDatabaseError handles database errors using the global error handler func HandleDatabaseError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleDatabaseError(c, err) } -// HandleAuthError handles auth errors using the global error handler func HandleAuthError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleAuthError(c, err) } -// HandleNotFoundError handles not found errors using the global error handler func HandleNotFoundError(c *fiber.Ctx, resource string) error { return GetControllerErrorHandler().HandleNotFoundError(c, resource) } -// HandleBusinessLogicError handles business logic errors using the global error handler func HandleBusinessLogicError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleBusinessLogicError(c, err) } -// HandleServiceError handles service errors using the global error handler func HandleServiceError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleServiceError(c, err) } -// HandleParsingError handles parsing errors using the global error handler func HandleParsingError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleParsingError(c, err) } -// HandleUUIDError handles UUID errors using the global error handler func HandleUUIDError(c *fiber.Ctx, field string) error { return GetControllerErrorHandler().HandleUUIDError(c, field) } diff --git a/local/utl/jwt/jwt.go b/local/utl/jwt/jwt.go index 9cd036f..9d4cdff 100644 --- a/local/utl/jwt/jwt.go +++ b/local/utl/jwt/jwt.go @@ -11,7 +11,6 @@ import ( "github.com/golang-jwt/jwt/v4" ) -// Claims represents the JWT claims. type Claims struct { UserID string `json:"user_id"` IsOpenToken bool `json:"is_open_token"` @@ -27,7 +26,6 @@ type OpenJWTHandler struct { *JWTHandler } -// NewJWTHandler creates a new JWTHandler instance with the provided secret key. func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler { jwtHandler := NewJWTHandler(jwtSecret) jwtHandler.IsOpenToken = true @@ -36,7 +34,6 @@ func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler { } } -// NewJWTHandler creates a new JWTHandler instance with the provided secret key. func NewJWTHandler(jwtSecret string) *JWTHandler { if jwtSecret == "" { errors.SafeFatal("JWT_SECRET environment variable is required and cannot be empty") @@ -44,14 +41,12 @@ func NewJWTHandler(jwtSecret string) *JWTHandler { var secretKey []byte - // Decode base64 secret if it looks like base64, otherwise use as-is if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 { secretKey = decoded } else { secretKey = []byte(jwtSecret) } - // Ensure minimum key length for security if len(secretKey) < 32 { errors.SafeFatal("JWT_SECRET must be at least 32 bytes long for security") } @@ -60,8 +55,6 @@ func NewJWTHandler(jwtSecret string) *JWTHandler { } } -// GenerateSecretKey generates a cryptographically secure random key for JWT signing -// This is a utility function for generating new secrets, not used in normal operation func (jh *JWTHandler) GenerateSecretKey() string { key := make([]byte, 64) // 512 bits if _, err := rand.Read(key); err != nil { @@ -70,7 +63,6 @@ func (jh *JWTHandler) GenerateSecretKey() string { return base64.StdEncoding.EncodeToString(key) } -// GenerateToken generates a new JWT for a given user. func (jh *JWTHandler) GenerateToken(userId string) (string, error) { expirationTime := time.Now().Add(24 * time.Hour) claims := &Claims{ @@ -99,7 +91,6 @@ func (jh *JWTHandler) GenerateTokenWithExpiry(user *model.User, expiry time.Time return token.SignedString(jh.SecretKey) } -// ValidateToken validates a JWT and returns the claims if the token is valid. func (jh *JWTHandler) ValidateToken(tokenString string) (*Claims, error) { claims := &Claims{} diff --git a/local/utl/logging/base.go b/local/utl/logging/base.go index 03ef1fa..c55b7bb 100644 --- a/local/utl/logging/base.go +++ b/local/utl/logging/base.go @@ -15,7 +15,6 @@ var ( timeFormat = "2006-01-02 15:04:05.000" ) -// BaseLogger provides the core logging functionality type BaseLogger struct { file *os.File logger *log.Logger @@ -23,7 +22,6 @@ type BaseLogger struct { initialized bool } -// LogLevel represents different logging levels type LogLevel string const ( @@ -34,28 +32,23 @@ const ( LogLevelPanic LogLevel = "PANIC" ) -// Initialize creates a new base logger instance func InitializeBase(tp string) (*BaseLogger, error) { return newBaseLogger(tp) } func newBaseLogger(tp string) (*BaseLogger, error) { - // Ensure logs directory exists if err := os.MkdirAll("logs", 0755); err != nil { return nil, fmt.Errorf("failed to create logs directory: %v", err) } - // Open log file with date in name logPath := filepath.Join("logs", fmt.Sprintf("acc-server-%s-%s.log", time.Now().Format("2006-01-02"), tp)) file, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666) if err != nil { return nil, fmt.Errorf("failed to open log file: %v", err) } - // Create multi-writer for both file and console multiWriter := io.MultiWriter(file, os.Stdout) - // Create base logger logger := &BaseLogger{ file: file, logger: log.New(multiWriter, "", 0), @@ -65,13 +58,11 @@ func newBaseLogger(tp string) (*BaseLogger, error) { return logger, nil } -// GetBaseLogger creates and returns a new base logger instance func GetBaseLogger(tp string) *BaseLogger { baseLogger, _ := InitializeBase(tp) return baseLogger } -// Close closes the log file func (bl *BaseLogger) Close() error { bl.mu.Lock() defer bl.mu.Unlock() @@ -82,7 +73,6 @@ func (bl *BaseLogger) Close() error { return nil } -// Log writes a log entry with the specified level func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) { if bl == nil || !bl.initialized { return @@ -91,14 +81,11 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) { bl.mu.RLock() defer bl.mu.RUnlock() - // Get caller info (skip 2 frames: this function and the calling Log function) _, file, line, _ := runtime.Caller(2) file = filepath.Base(file) - // Format message msg := fmt.Sprintf(format, v...) - // Format final log line logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s", time.Now().Format(timeFormat), string(level), @@ -110,7 +97,6 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) { bl.logger.Println(logLine) } -// LogWithCaller writes a log entry with custom caller depth func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format string, v ...interface{}) { if bl == nil || !bl.initialized { return @@ -119,14 +105,11 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri bl.mu.RLock() defer bl.mu.RUnlock() - // Get caller info with custom depth _, file, line, _ := runtime.Caller(callerDepth) file = filepath.Base(file) - // Format message msg := fmt.Sprintf(format, v...) - // Format final log line logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s", time.Now().Format(timeFormat), string(level), @@ -138,7 +121,6 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri bl.logger.Println(logLine) } -// IsInitialized returns whether the base logger is initialized func (bl *BaseLogger) IsInitialized() bool { if bl == nil { return false @@ -148,19 +130,16 @@ func (bl *BaseLogger) IsInitialized() bool { return bl.initialized } -// RecoverAndLog recovers from panics and logs them func RecoverAndLog() { baseLogger := GetBaseLogger("panic") if baseLogger != nil && baseLogger.IsInitialized() { if r := recover(); r != nil { - // Get stack trace buf := make([]byte, 4096) n := runtime.Stack(buf, false) stackTrace := string(buf[:n]) baseLogger.LogWithCaller(LogLevelPanic, 2, "Recovered from panic: %v\nStack Trace:\n%s", r, stackTrace) - // Re-panic to maintain original behavior if needed panic(r) } } diff --git a/local/utl/logging/debug.go b/local/utl/logging/debug.go index 0422ad6..a1bbf6d 100644 --- a/local/utl/logging/debug.go +++ b/local/utl/logging/debug.go @@ -6,12 +6,10 @@ import ( "sync" ) -// DebugLogger handles debug-level logging type DebugLogger struct { base *BaseLogger } -// NewDebugLogger creates a new debug logger instance func NewDebugLogger() *DebugLogger { base, _ := InitializeBase("debug") return &DebugLogger{ @@ -19,14 +17,12 @@ func NewDebugLogger() *DebugLogger { } } -// Log writes a debug-level log entry func (dl *DebugLogger) Log(format string, v ...interface{}) { if dl.base != nil { dl.base.Log(LogLevelDebug, format, v...) } } -// LogWithContext writes a debug-level log entry with additional context func (dl *DebugLogger) LogWithContext(context string, format string, v ...interface{}) { if dl.base != nil { contextualFormat := fmt.Sprintf("[%s] %s", context, format) @@ -34,7 +30,6 @@ func (dl *DebugLogger) LogWithContext(context string, format string, v ...interf } } -// LogFunction logs function entry and exit for debugging func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) { if dl.base != nil { if len(args) > 0 { @@ -45,21 +40,18 @@ func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) { } } -// LogVariable logs variable values for debugging func (dl *DebugLogger) LogVariable(varName string, value interface{}) { if dl.base != nil { dl.base.Log(LogLevelDebug, "VARIABLE [%s]: %+v", varName, value) } } -// LogState logs application state information func (dl *DebugLogger) LogState(component string, state interface{}) { if dl.base != nil { dl.base.Log(LogLevelDebug, "STATE [%s]: %+v", component, state) } } -// LogSQL logs SQL queries for debugging func (dl *DebugLogger) LogSQL(query string, args ...interface{}) { if dl.base != nil { if len(args) > 0 { @@ -70,7 +62,6 @@ func (dl *DebugLogger) LogSQL(query string, args ...interface{}) { } } -// LogMemory logs memory usage information func (dl *DebugLogger) LogMemory() { if dl.base != nil { var m runtime.MemStats @@ -80,32 +71,27 @@ func (dl *DebugLogger) LogMemory() { } } -// LogGoroutines logs current number of goroutines func (dl *DebugLogger) LogGoroutines() { if dl.base != nil { dl.base.Log(LogLevelDebug, "GOROUTINES: %d active", runtime.NumGoroutine()) } } -// LogTiming logs timing information for performance debugging func (dl *DebugLogger) LogTiming(operation string, duration interface{}) { if dl.base != nil { dl.base.Log(LogLevelDebug, "TIMING [%s]: %v", operation, duration) } } -// Helper function to convert bytes to kilobytes func bToKb(b uint64) uint64 { return b / 1024 } -// Global debug logger instance var ( debugLogger *DebugLogger debugOnce sync.Once ) -// GetDebugLogger returns the global debug logger instance func GetDebugLogger() *DebugLogger { debugOnce.Do(func() { debugLogger = NewDebugLogger() @@ -113,47 +99,38 @@ func GetDebugLogger() *DebugLogger { return debugLogger } -// Debug logs a debug-level message using the global debug logger func Debug(format string, v ...interface{}) { GetDebugLogger().Log(format, v...) } -// DebugWithContext logs a debug-level message with context using the global debug logger func DebugWithContext(context string, format string, v ...interface{}) { GetDebugLogger().LogWithContext(context, format, v...) } -// DebugFunction logs function entry and exit using the global debug logger func DebugFunction(functionName string, args ...interface{}) { GetDebugLogger().LogFunction(functionName, args...) } -// DebugVariable logs variable values using the global debug logger func DebugVariable(varName string, value interface{}) { GetDebugLogger().LogVariable(varName, value) } -// DebugState logs application state information using the global debug logger func DebugState(component string, state interface{}) { GetDebugLogger().LogState(component, state) } -// DebugSQL logs SQL queries using the global debug logger func DebugSQL(query string, args ...interface{}) { GetDebugLogger().LogSQL(query, args...) } -// DebugMemory logs memory usage information using the global debug logger func DebugMemory() { GetDebugLogger().LogMemory() } -// DebugGoroutines logs current number of goroutines using the global debug logger func DebugGoroutines() { GetDebugLogger().LogGoroutines() } -// DebugTiming logs timing information using the global debug logger func DebugTiming(operation string, duration interface{}) { GetDebugLogger().LogTiming(operation, duration) } diff --git a/local/utl/logging/error.go b/local/utl/logging/error.go index d347bbd..1737ef0 100644 --- a/local/utl/logging/error.go +++ b/local/utl/logging/error.go @@ -6,12 +6,10 @@ import ( "sync" ) -// ErrorLogger handles error-level logging type ErrorLogger struct { base *BaseLogger } -// NewErrorLogger creates a new error logger instance func NewErrorLogger() *ErrorLogger { base, _ := InitializeBase("error") return &ErrorLogger{ @@ -19,14 +17,12 @@ func NewErrorLogger() *ErrorLogger { } } -// Log writes an error-level log entry func (el *ErrorLogger) Log(format string, v ...interface{}) { if el.base != nil { el.base.Log(LogLevelError, format, v...) } } -// LogWithContext writes an error-level log entry with additional context func (el *ErrorLogger) LogWithContext(context string, format string, v ...interface{}) { if el.base != nil { contextualFormat := fmt.Sprintf("[%s] %s", context, format) @@ -34,7 +30,6 @@ func (el *ErrorLogger) LogWithContext(context string, format string, v ...interf } } -// LogError logs an error object with optional message func (el *ErrorLogger) LogError(err error, message ...string) { if el.base != nil && err != nil { if len(message) > 0 { @@ -45,7 +40,6 @@ func (el *ErrorLogger) LogError(err error, message ...string) { } } -// LogWithStackTrace logs an error with stack trace func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) { if el.base != nil { // Get stack trace @@ -58,7 +52,6 @@ func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) { } } -// LogFatal logs a fatal error and exits the program func (el *ErrorLogger) LogFatal(format string, v ...interface{}) { if el.base != nil { el.base.Log(LogLevelError, "[FATAL] "+format, v...) @@ -66,13 +59,11 @@ func (el *ErrorLogger) LogFatal(format string, v ...interface{}) { } } -// Global error logger instance var ( errorLogger *ErrorLogger errorOnce sync.Once ) -// GetErrorLogger returns the global error logger instance func GetErrorLogger() *ErrorLogger { errorOnce.Do(func() { errorLogger = NewErrorLogger() @@ -80,27 +71,22 @@ func GetErrorLogger() *ErrorLogger { return errorLogger } -// Error logs an error-level message using the global error logger func Error(format string, v ...interface{}) { GetErrorLogger().Log(format, v...) } -// ErrorWithContext logs an error-level message with context using the global error logger func ErrorWithContext(context string, format string, v ...interface{}) { GetErrorLogger().LogWithContext(context, format, v...) } -// LogError logs an error object using the global error logger func LogError(err error, message ...string) { GetErrorLogger().LogError(err, message...) } -// ErrorWithStackTrace logs an error with stack trace using the global error logger func ErrorWithStackTrace(format string, v ...interface{}) { GetErrorLogger().LogWithStackTrace(format, v...) } -// Fatal logs a fatal error and exits the program using the global error logger func Fatal(format string, v ...interface{}) { GetErrorLogger().LogFatal(format, v...) } diff --git a/local/utl/logging/info.go b/local/utl/logging/info.go index 17553f0..611064b 100644 --- a/local/utl/logging/info.go +++ b/local/utl/logging/info.go @@ -5,12 +5,10 @@ import ( "sync" ) -// InfoLogger handles info-level logging type InfoLogger struct { base *BaseLogger } -// NewInfoLogger creates a new info logger instance func NewInfoLogger() *InfoLogger { base, _ := InitializeBase("info") return &InfoLogger{ @@ -18,14 +16,12 @@ func NewInfoLogger() *InfoLogger { } } -// Log writes an info-level log entry func (il *InfoLogger) Log(format string, v ...interface{}) { if il.base != nil { il.base.Log(LogLevelInfo, format, v...) } } -// LogWithContext writes an info-level log entry with additional context func (il *InfoLogger) LogWithContext(context string, format string, v ...interface{}) { if il.base != nil { contextualFormat := fmt.Sprintf("[%s] %s", context, format) @@ -33,55 +29,47 @@ func (il *InfoLogger) LogWithContext(context string, format string, v ...interfa } } -// LogStartup logs application startup information func (il *InfoLogger) LogStartup(component string, message string) { if il.base != nil { il.base.Log(LogLevelInfo, "STARTUP [%s]: %s", component, message) } } -// LogShutdown logs application shutdown information func (il *InfoLogger) LogShutdown(component string, message string) { if il.base != nil { il.base.Log(LogLevelInfo, "SHUTDOWN [%s]: %s", component, message) } } -// LogOperation logs general operation information func (il *InfoLogger) LogOperation(operation string, details string) { if il.base != nil { il.base.Log(LogLevelInfo, "OPERATION [%s]: %s", operation, details) } } -// LogStatus logs status changes or updates func (il *InfoLogger) LogStatus(component string, status string) { if il.base != nil { il.base.Log(LogLevelInfo, "STATUS [%s]: %s", component, status) } } -// LogRequest logs incoming requests func (il *InfoLogger) LogRequest(method string, path string, userAgent string) { if il.base != nil { il.base.Log(LogLevelInfo, "REQUEST [%s %s] User-Agent: %s", method, path, userAgent) } } -// LogResponse logs outgoing responses func (il *InfoLogger) LogResponse(method string, path string, statusCode int, duration string) { if il.base != nil { il.base.Log(LogLevelInfo, "RESPONSE [%s %s] Status: %d, Duration: %s", method, path, statusCode, duration) } } -// Global info logger instance var ( infoLogger *InfoLogger infoOnce sync.Once ) -// GetInfoLogger returns the global info logger instance func GetInfoLogger() *InfoLogger { infoOnce.Do(func() { infoLogger = NewInfoLogger() @@ -89,42 +77,34 @@ func GetInfoLogger() *InfoLogger { return infoLogger } -// Info logs an info-level message using the global info logger func Info(format string, v ...interface{}) { GetInfoLogger().Log(format, v...) } -// InfoWithContext logs an info-level message with context using the global info logger func InfoWithContext(context string, format string, v ...interface{}) { GetInfoLogger().LogWithContext(context, format, v...) } -// InfoStartup logs application startup information using the global info logger func InfoStartup(component string, message string) { GetInfoLogger().LogStartup(component, message) } -// InfoShutdown logs application shutdown information using the global info logger func InfoShutdown(component string, message string) { GetInfoLogger().LogShutdown(component, message) } -// InfoOperation logs general operation information using the global info logger func InfoOperation(operation string, details string) { GetInfoLogger().LogOperation(operation, details) } -// InfoStatus logs status changes or updates using the global info logger func InfoStatus(component string, status string) { GetInfoLogger().LogStatus(component, status) } -// InfoRequest logs incoming requests using the global info logger func InfoRequest(method string, path string, userAgent string) { GetInfoLogger().LogRequest(method, path, userAgent) } -// InfoResponse logs outgoing responses using the global info logger func InfoResponse(method string, path string, statusCode int, duration string) { GetInfoLogger().LogResponse(method, path, statusCode, duration) } diff --git a/local/utl/logging/logger.go b/local/utl/logging/logger.go index 258c0a9..e0079c4 100644 --- a/local/utl/logging/logger.go +++ b/local/utl/logging/logger.go @@ -6,12 +6,10 @@ import ( ) var ( - // Legacy logger for backward compatibility logger *Logger once sync.Once ) -// Logger maintains backward compatibility with existing code type Logger struct { base *BaseLogger errorLogger *ErrorLogger @@ -20,8 +18,6 @@ type Logger struct { debugLogger *DebugLogger } -// Initialize creates or gets the singleton logger instance -// This maintains backward compatibility with existing code func Initialize() (*Logger, error) { var err error once.Do(func() { @@ -31,13 +27,11 @@ func Initialize() (*Logger, error) { } func newLogger() (*Logger, error) { - // Initialize the base logger baseLogger, err := InitializeBase("log") if err != nil { return nil, err } - // Create the legacy logger wrapper logger := &Logger{ base: baseLogger, errorLogger: GetErrorLogger(), @@ -49,7 +43,6 @@ func newLogger() (*Logger, error) { return logger, nil } -// Close closes the logger func (l *Logger) Close() error { if l.base != nil { return l.base.Close() @@ -57,7 +50,6 @@ func (l *Logger) Close() error { return nil } -// Legacy methods for backward compatibility func (l *Logger) log(level, format string, v ...interface{}) { if l.base != nil { l.base.LogWithCaller(LogLevel(level), 3, format, v...) @@ -94,13 +86,10 @@ func (l *Logger) Panic(format string) { } } -// Global convenience functions for backward compatibility -// These are now implemented in individual logger files to avoid redeclaration func LegacyInfo(format string, v ...interface{}) { if logger != nil { logger.Info(format, v...) } else { - // Fallback to direct logger if legacy logger not initialized GetInfoLogger().Log(format, v...) } } @@ -109,7 +98,6 @@ func LegacyError(format string, v ...interface{}) { if logger != nil { logger.Error(format, v...) } else { - // Fallback to direct logger if legacy logger not initialized GetErrorLogger().Log(format, v...) } } @@ -118,7 +106,6 @@ func LegacyWarn(format string, v ...interface{}) { if logger != nil { logger.Warn(format, v...) } else { - // Fallback to direct logger if legacy logger not initialized GetWarnLogger().Log(format, v...) } } @@ -127,7 +114,6 @@ func LegacyDebug(format string, v ...interface{}) { if logger != nil { logger.Debug(format, v...) } else { - // Fallback to direct logger if legacy logger not initialized GetDebugLogger().Log(format, v...) } } @@ -136,55 +122,42 @@ func Panic(format string) { if logger != nil { logger.Panic(format) } else { - // Fallback to direct logger if legacy logger not initialized GetErrorLogger().LogFatal(format) } } -// Enhanced logging convenience functions -// These provide direct access to specialized logging functions - -// LogStartup logs application startup information func LogStartup(component string, message string) { GetInfoLogger().LogStartup(component, message) } -// LogShutdown logs application shutdown information func LogShutdown(component string, message string) { GetInfoLogger().LogShutdown(component, message) } -// LogOperation logs general operation information func LogOperation(operation string, details string) { GetInfoLogger().LogOperation(operation, details) } -// LogRequest logs incoming HTTP requests func LogRequest(method string, path string, userAgent string) { GetInfoLogger().LogRequest(method, path, userAgent) } -// LogResponse logs outgoing HTTP responses func LogResponse(method string, path string, statusCode int, duration string) { GetInfoLogger().LogResponse(method, path, statusCode, duration) } -// LogSQL logs SQL queries for debugging func LogSQL(query string, args ...interface{}) { GetDebugLogger().LogSQL(query, args...) } -// LogMemory logs memory usage information func LogMemory() { GetDebugLogger().LogMemory() } -// LogTiming logs timing information for performance debugging func LogTiming(operation string, duration interface{}) { GetDebugLogger().LogTiming(operation, duration) } -// GetLegacyLogger returns the legacy logger instance for backward compatibility func GetLegacyLogger() *Logger { if logger == nil { logger, _ = Initialize() @@ -192,21 +165,17 @@ func GetLegacyLogger() *Logger { return logger } -// InitializeLogging initializes all logging components func InitializeLogging() error { - // Initialize legacy logger for backward compatibility _, err := Initialize() if err != nil { return fmt.Errorf("failed to initialize legacy logger: %v", err) } - // Pre-initialize all logger types to ensure separate log files GetErrorLogger() GetWarnLogger() GetInfoLogger() GetDebugLogger() - // Log successful initialization Info("Logging system initialized successfully") return nil diff --git a/local/utl/logging/warn.go b/local/utl/logging/warn.go index 376d0ea..3044140 100644 --- a/local/utl/logging/warn.go +++ b/local/utl/logging/warn.go @@ -5,12 +5,10 @@ import ( "sync" ) -// WarnLogger handles warn-level logging type WarnLogger struct { base *BaseLogger } -// NewWarnLogger creates a new warn logger instance func NewWarnLogger() *WarnLogger { base, _ := InitializeBase("warn") return &WarnLogger{ @@ -18,14 +16,12 @@ func NewWarnLogger() *WarnLogger { } } -// Log writes a warn-level log entry func (wl *WarnLogger) Log(format string, v ...interface{}) { if wl.base != nil { wl.base.Log(LogLevelWarn, format, v...) } } -// LogWithContext writes a warn-level log entry with additional context func (wl *WarnLogger) LogWithContext(context string, format string, v ...interface{}) { if wl.base != nil { contextualFormat := fmt.Sprintf("[%s] %s", context, format) @@ -33,7 +29,6 @@ func (wl *WarnLogger) LogWithContext(context string, format string, v ...interfa } } -// LogDeprecation logs a deprecation warning func (wl *WarnLogger) LogDeprecation(feature string, alternative string) { if wl.base != nil { if alternative != "" { @@ -44,27 +39,23 @@ func (wl *WarnLogger) LogDeprecation(feature string, alternative string) { } } -// LogConfiguration logs configuration-related warnings func (wl *WarnLogger) LogConfiguration(setting string, message string) { if wl.base != nil { wl.base.Log(LogLevelWarn, "CONFIG WARNING [%s]: %s", setting, message) } } -// LogPerformance logs performance-related warnings func (wl *WarnLogger) LogPerformance(operation string, threshold string, actual string) { if wl.base != nil { wl.base.Log(LogLevelWarn, "PERFORMANCE WARNING [%s]: exceeded threshold %s, actual: %s", operation, threshold, actual) } } -// Global warn logger instance var ( warnLogger *WarnLogger warnOnce sync.Once ) -// GetWarnLogger returns the global warn logger instance func GetWarnLogger() *WarnLogger { warnOnce.Do(func() { warnLogger = NewWarnLogger() @@ -72,27 +63,22 @@ func GetWarnLogger() *WarnLogger { return warnLogger } -// Warn logs a warn-level message using the global warn logger func Warn(format string, v ...interface{}) { GetWarnLogger().Log(format, v...) } -// WarnWithContext logs a warn-level message with context using the global warn logger func WarnWithContext(context string, format string, v ...interface{}) { GetWarnLogger().LogWithContext(context, format, v...) } -// WarnDeprecation logs a deprecation warning using the global warn logger func WarnDeprecation(feature string, alternative string) { GetWarnLogger().LogDeprecation(feature, alternative) } -// WarnConfiguration logs configuration-related warnings using the global warn logger func WarnConfiguration(setting string, message string) { GetWarnLogger().LogConfiguration(setting, message) } -// WarnPerformance logs performance-related warnings using the global warn logger func WarnPerformance(operation string, threshold string, actual string) { GetWarnLogger().LogPerformance(operation, threshold, actual) } diff --git a/local/utl/network/port.go b/local/utl/network/port.go index ea2ce1e..42e0779 100644 --- a/local/utl/network/port.go +++ b/local/utl/network/port.go @@ -6,12 +6,10 @@ import ( "time" ) -// IsPortAvailable checks if a port is available for both TCP and UDP func IsPortAvailable(port int) bool { return IsTCPPortAvailable(port) && IsUDPPortAvailable(port) } -// IsTCPPortAvailable checks if a TCP port is available func IsTCPPortAvailable(port int) bool { addr := fmt.Sprintf(":%d", port) listener, err := net.Listen("tcp", addr) @@ -22,7 +20,6 @@ func IsTCPPortAvailable(port int) bool { return true } -// IsUDPPortAvailable checks if a UDP port is available func IsUDPPortAvailable(port int) bool { conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: port}) if err != nil { @@ -32,7 +29,6 @@ func IsUDPPortAvailable(port int) bool { return true } -// FindAvailablePort finds an available port starting from the given port func FindAvailablePort(startPort int) (int, error) { maxPort := 65535 for port := startPort; port <= maxPort; port++ { @@ -43,14 +39,12 @@ func FindAvailablePort(startPort int) (int, error) { return 0, fmt.Errorf("no available ports found between %d and %d", startPort, maxPort) } -// FindAvailablePortRange finds a range of consecutive available ports func FindAvailablePortRange(startPort, count int) ([]int, error) { maxPort := 65535 ports := make([]int, 0, count) currentPort := startPort for len(ports) < count && currentPort <= maxPort { - // Check if we have enough consecutive ports available available := true for i := 0; i < count-len(ports); i++ { if !IsPortAvailable(currentPort + i) { @@ -74,7 +68,6 @@ func FindAvailablePortRange(startPort, count int) ([]int, error) { return ports, nil } -// WaitForPortAvailable waits for a port to become available with timeout func WaitForPortAvailable(port int, timeout time.Duration) error { deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { @@ -84,4 +77,4 @@ func WaitForPortAvailable(port int, timeout time.Duration) error { time.Sleep(100 * time.Millisecond) } return fmt.Errorf("timeout waiting for port %d to become available", port) -} \ No newline at end of file +} diff --git a/local/utl/password/password.go b/local/utl/password/password.go index ad2424f..82e127e 100644 --- a/local/utl/password/password.go +++ b/local/utl/password/password.go @@ -8,13 +8,10 @@ import ( ) const ( - // MinPasswordLength defines the minimum password length MinPasswordLength = 8 - // BcryptCost defines the cost factor for bcrypt hashing - BcryptCost = 12 + BcryptCost = 12 ) -// HashPassword hashes a plain text password using bcrypt func HashPassword(password string) (string, error) { if len(password) < MinPasswordLength { return "", errors.New("password must be at least 8 characters long") @@ -28,12 +25,10 @@ func HashPassword(password string) (string, error) { return string(hashedBytes), nil } -// VerifyPassword verifies a plain text password against a hashed password func VerifyPassword(hashedPassword, password string) error { return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) } -// ValidatePasswordStrength validates password complexity requirements func ValidatePasswordStrength(password string) error { if len(password) < MinPasswordLength { return errors.New("password must be at least 8 characters long") diff --git a/local/utl/server/server.go b/local/utl/server/server.go index ba5524e..9d82f6f 100644 --- a/local/utl/server/server.go +++ b/local/utl/server/server.go @@ -17,25 +17,23 @@ import ( func Start(di *dig.Container) *fiber.App { app := fiber.New(fiber.Config{ EnablePrintRoutes: true, - ReadTimeout: 20 * time.Minute, // Increased for long-running Steam operations - WriteTimeout: 20 * time.Minute, // Increased for long-running Steam operations - IdleTimeout: 25 * time.Minute, // Increased accordingly - BodyLimit: 10 * 1024 * 1024, // 10MB + ReadTimeout: 20 * time.Minute, + WriteTimeout: 20 * time.Minute, + IdleTimeout: 25 * time.Minute, + BodyLimit: 10 * 1024 * 1024, }) - // Initialize security middleware securityMW := security.NewSecurityMiddleware() - // Add security middleware stack app.Use(securityMW.SecurityHeaders()) app.Use(securityMW.LogSecurityEvents()) - app.Use(securityMW.TimeoutMiddleware(20 * time.Minute)) // Increased for Steam operations - app.Use(securityMW.RequestContextTimeout(20 * time.Minute)) // Increased for Steam operations - app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) // 10MB + app.Use(securityMW.TimeoutMiddleware(20 * time.Minute)) + app.Use(securityMW.RequestContextTimeout(20 * time.Minute)) + app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) app.Use(securityMW.ValidateUserAgent()) app.Use(securityMW.ValidateContentType("application/json", "application/x-www-form-urlencoded", "multipart/form-data")) app.Use(securityMW.InputSanitization()) - app.Use(securityMW.RateLimit(100, 1*time.Minute)) // 100 requests per minute global + app.Use(securityMW.RateLimit(100, 1*time.Minute)) app.Use(helmet.New()) @@ -62,7 +60,7 @@ func Start(di *dig.Container) *fiber.App { port := os.Getenv("PORT") if port == "" { - port = "3000" // Default port + port = "3000" } logging.Info("Starting server on port %s", port) diff --git a/local/utl/tracking/log_tailer.go b/local/utl/tracking/log_tailer.go index fad0304..9530334 100644 --- a/local/utl/tracking/log_tailer.go +++ b/local/utl/tracking/log_tailer.go @@ -7,11 +7,11 @@ import ( ) type LogTailer struct { - filePath string - handleLine func(string) - stopChan chan struct{} - isRunning bool - tracker *PositionTracker + filePath string + handleLine func(string) + stopChan chan struct{} + isRunning bool + tracker *PositionTracker } func NewLogTailer(filePath string, handleLine func(string)) *LogTailer { @@ -30,10 +30,9 @@ func (t *LogTailer) Start() { t.isRunning = true go func() { - // Load last position from tracker pos, err := t.tracker.LoadPosition() if err != nil { - pos = &LogPosition{} // Start from beginning if error + pos = &LogPosition{} } lastSize := pos.LastPosition @@ -43,7 +42,6 @@ func (t *LogTailer) Start() { t.isRunning = false return default: - // Try to open and read the file if file, err := os.Open(t.filePath); err == nil { stat, err := file.Stat() if err != nil { @@ -52,12 +50,10 @@ func (t *LogTailer) Start() { continue } - // If file was truncated, start from beginning if stat.Size() < lastSize { lastSize = 0 } - // Seek to last read position if lastSize > 0 { file.Seek(lastSize, 0) } @@ -66,9 +62,8 @@ func (t *LogTailer) Start() { for scanner.Scan() { line := scanner.Text() t.handleLine(line) - lastSize, _ = file.Seek(0, 1) // Get current position - - // Save position periodically + lastSize, _ = file.Seek(0, 1) + t.tracker.SavePosition(&LogPosition{ LastPosition: lastSize, LastRead: line, @@ -78,7 +73,6 @@ func (t *LogTailer) Start() { file.Close() } - // Wait before next attempt time.Sleep(time.Second) } } @@ -90,4 +84,4 @@ func (t *LogTailer) Stop() { return } close(t.stopChan) -} \ No newline at end of file +} diff --git a/local/utl/tracking/position_tracker.go b/local/utl/tracking/position_tracker.go index 3dc1075..13c8827 100644 --- a/local/utl/tracking/position_tracker.go +++ b/local/utl/tracking/position_tracker.go @@ -7,8 +7,8 @@ import ( ) type LogPosition struct { - LastPosition int64 `json:"last_position"` - LastRead string `json:"last_read"` + LastPosition int64 `json:"last_position"` + LastRead string `json:"last_read"` } type PositionTracker struct { @@ -16,11 +16,10 @@ type PositionTracker struct { } func NewPositionTracker(logPath string) *PositionTracker { - // Create position file in same directory as log file dir := filepath.Dir(logPath) base := filepath.Base(logPath) positionFile := filepath.Join(dir, "."+base+".position") - + return &PositionTracker{ positionFile: positionFile, } @@ -30,7 +29,6 @@ func (t *PositionTracker) LoadPosition() (*LogPosition, error) { data, err := os.ReadFile(t.positionFile) if err != nil { if os.IsNotExist(err) { - // Return empty position if file doesn't exist return &LogPosition{}, nil } return nil, err @@ -51,4 +49,4 @@ func (t *PositionTracker) SavePosition(pos *LogPosition) error { } return os.WriteFile(t.positionFile, data, 0644) -} \ No newline at end of file +} diff --git a/local/utl/tracking/tracking.go b/local/utl/tracking/tracking.go index 1605ae1..7dcd294 100644 --- a/local/utl/tracking/tracking.go +++ b/local/utl/tracking/tracking.go @@ -80,7 +80,7 @@ func TailLogFile(path string, callback func(string)) { file, _ := os.Open(path) defer file.Close() - file.Seek(0, os.SEEK_END) // Start at end of file + file.Seek(0, os.SEEK_END) reader := bufio.NewReader(file) for { @@ -88,7 +88,7 @@ func TailLogFile(path string, callback func(string)) { if err == nil { callback(line) } else { - time.Sleep(500 * time.Millisecond) // wait for new data + time.Sleep(500 * time.Millisecond) } } } diff --git a/local/utl/websocket/websocket.go b/local/utl/websocket/websocket.go new file mode 100644 index 0000000..f43fe81 --- /dev/null +++ b/local/utl/websocket/websocket.go @@ -0,0 +1,169 @@ +package websocket + +import ( + "acc-server-manager/local/model" + "acc-server-manager/local/utl/logging" + "encoding/json" + "sync" + "time" + + "github.com/gofiber/websocket/v2" + "github.com/google/uuid" +) + +type WebSocketConnection struct { + conn *websocket.Conn + serverID *uuid.UUID + userID *uuid.UUID +} + +type WebSocketService struct { + connections sync.Map + mu sync.RWMutex +} + +func NewWebSocketService() *WebSocketService { + return &WebSocketService{} +} + +func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) { + wsConn := &WebSocketConnection{ + conn: conn, + userID: userID, + } + ws.connections.Store(connID, wsConn) + logging.Info("WebSocket connection added: %s for user: %v", connID, userID) +} + +func (ws *WebSocketService) RemoveConnection(connID string) { + if conn, exists := ws.connections.LoadAndDelete(connID); exists { + if wsConn, ok := conn.(*WebSocketConnection); ok { + wsConn.conn.Close() + } + } + logging.Info("WebSocket connection removed: %s", connID) +} + +func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) { + if conn, exists := ws.connections.Load(connID); exists { + if wsConn, ok := conn.(*WebSocketConnection); ok { + wsConn.serverID = &serverID + } + } +} + +func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) { + stepMsg := model.StepMessage{ + Step: step, + Status: status, + Message: message, + Error: errorMsg, + } + + wsMsg := model.WebSocketMessage{ + Type: model.MessageTypeStep, + ServerID: &serverID, + Timestamp: time.Now().Unix(), + Data: stepMsg, + } + + ws.broadcastToServer(serverID, wsMsg) +} + +func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) { + steamMsg := model.SteamOutputMessage{ + Output: output, + IsError: isError, + } + + wsMsg := model.WebSocketMessage{ + Type: model.MessageTypeSteamOutput, + ServerID: &serverID, + Timestamp: time.Now().Unix(), + Data: steamMsg, + } + + ws.broadcastToServer(serverID, wsMsg) +} + +func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) { + errorMsg := model.ErrorMessage{ + Error: error, + Details: details, + } + + wsMsg := model.WebSocketMessage{ + Type: model.MessageTypeError, + ServerID: &serverID, + Timestamp: time.Now().Unix(), + Data: errorMsg, + } + + ws.broadcastToServer(serverID, wsMsg) +} + +func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) { + completeMsg := model.CompleteMessage{ + ServerID: serverID, + Success: success, + Message: message, + } + + wsMsg := model.WebSocketMessage{ + Type: model.MessageTypeComplete, + ServerID: &serverID, + Timestamp: time.Now().Unix(), + Data: completeMsg, + } + + ws.broadcastToServer(serverID, wsMsg) +} + +func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) { + data, err := json.Marshal(message) + if err != nil { + logging.Error("Failed to marshal WebSocket message: %v", err) + return + } + + ws.connections.Range(func(key, value interface{}) bool { + if wsConn, ok := value.(*WebSocketConnection); ok { + if wsConn.serverID != nil && *wsConn.serverID == serverID { + if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { + logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) + ws.RemoveConnection(key.(string)) + } + } + } + return true + }) +} + +func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) { + data, err := json.Marshal(message) + if err != nil { + logging.Error("Failed to marshal WebSocket message: %v", err) + return + } + + ws.connections.Range(func(key, value interface{}) bool { + if wsConn, ok := value.(*WebSocketConnection); ok { + if wsConn.userID != nil && *wsConn.userID == userID { + if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { + logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) + ws.RemoveConnection(key.(string)) + } + } + } + return true + }) +} + +func (ws *WebSocketService) GetActiveConnections() int { + count := 0 + ws.connections.Range(func(key, value interface{}) bool { + count++ + return true + }) + return count +} diff --git a/swagger/docs.go b/swagger/docs.go index 5725dec..2477ed1 100644 --- a/swagger/docs.go +++ b/swagger/docs.go @@ -1,4 +1,3 @@ -// Package swagger Code generated by swaggo/swag. DO NOT EDIT package swagger import "github.com/swaggo/swag" @@ -2239,7 +2238,6 @@ const docTemplate = `{ } }` -// SwaggerInfo holds exported Swagger Info so clients can modify it var SwaggerInfo = &swag.Spec{ Version: "1.0", Host: "acc-api.jurmanovic.com", diff --git a/tests/auth_helper.go b/tests/auth_helper.go index eb9b795..2411d88 100644 --- a/tests/auth_helper.go +++ b/tests/auth_helper.go @@ -10,24 +10,19 @@ import ( "github.com/google/uuid" ) -// GenerateTestToken creates a JWT token for testing purposes func GenerateTestToken() (string, error) { - // Create test user user := &model.User{ ID: uuid.New(), Username: "test_user", RoleID: uuid.New(), } - // Use the environment JWT_SECRET for consistency with middleware testSecret := os.Getenv("JWT_SECRET") if testSecret == "" { - // Fallback to a test secret if env var is not set testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security" } jwtHandler := jwt.NewJWTHandler(testSecret) - // Generate JWT token token, err := jwtHandler.GenerateToken(user.ID.String()) if err != nil { return "", fmt.Errorf("failed to generate test token: %w", err) @@ -36,8 +31,6 @@ func GenerateTestToken() (string, error) { return token, nil } -// MustGenerateTestToken generates a test token and panics if it fails -// This is useful for test setup where failing to generate a token is a fatal error func MustGenerateTestToken() string { token, err := GenerateTestToken() if err != nil { @@ -46,24 +39,19 @@ func MustGenerateTestToken() string { return token } -// GenerateTestTokenWithExpiry creates a JWT token with a specific expiry time func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) { - // Use the environment JWT_SECRET for consistency with middleware testSecret := os.Getenv("JWT_SECRET") if testSecret == "" { - // Fallback to a test secret if env var is not set testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security" } jwtHandler := jwt.NewJWTHandler(testSecret) - // Create test user user := &model.User{ ID: uuid.New(), Username: "test_user", RoleID: uuid.New(), } - // Generate JWT token with custom expiry token, err := jwtHandler.GenerateTokenWithExpiry(user, expiryTime) if err != nil { return "", fmt.Errorf("failed to generate test token with expiry: %w", err) @@ -72,8 +60,6 @@ func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) { return token, nil } -// AddAuthHeader adds a test auth token to the request headers -// This is a convenience method for tests that need to authenticate requests func AddAuthHeader(headers map[string]string) (map[string]string, error) { token, err := GenerateTestToken() if err != nil { @@ -88,7 +74,6 @@ func AddAuthHeader(headers map[string]string) (map[string]string, error) { return headers, nil } -// MustAddAuthHeader adds a test auth token to the request headers and panics if it fails func MustAddAuthHeader(headers map[string]string) map[string]string { result, err := AddAuthHeader(headers) if err != nil { diff --git a/tests/mocks/auth_middleware_mock.go b/tests/mocks/auth_middleware_mock.go index fbb5841..7edf4a4 100644 --- a/tests/mocks/auth_middleware_mock.go +++ b/tests/mocks/auth_middleware_mock.go @@ -8,26 +8,20 @@ import ( "github.com/google/uuid" ) -// MockAuthMiddleware provides a test implementation of AuthMiddleware -// that can be used as a drop-in replacement for the real AuthMiddleware type MockAuthMiddleware struct{} -// NewMockAuthMiddleware creates a new MockAuthMiddleware func NewMockAuthMiddleware() *MockAuthMiddleware { return &MockAuthMiddleware{} } -// Authenticate is a middleware that allows all requests without authentication for testing func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error { - // Set a mock user ID in context mockUserID := uuid.New().String() ctx.Locals("userID", mockUserID) - // Set mock user info mockUserInfo := &middleware.CachedUserInfo{ UserID: mockUserID, Username: "test_user", - RoleName: "Admin", // Admin role to bypass permission checks + RoleName: "Admin", Permissions: map[string]bool{"*": true}, CachedAt: time.Now(), } @@ -38,21 +32,18 @@ func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error { return ctx.Next() } -// HasPermission is a middleware that allows all permission checks to pass for testing func (m *MockAuthMiddleware) HasPermission(requiredPermission string) fiber.Handler { return func(ctx *fiber.Ctx) error { return ctx.Next() } } -// AuthRateLimit is a test implementation that allows all requests func (m *MockAuthMiddleware) AuthRateLimit() fiber.Handler { return func(ctx *fiber.Ctx) error { return ctx.Next() } } -// RequireHTTPS is a test implementation that allows all HTTP requests func (m *MockAuthMiddleware) RequireHTTPS() fiber.Handler { return func(ctx *fiber.Ctx) error { return ctx.Next() diff --git a/tests/mocks/repository_mock.go b/tests/mocks/repository_mock.go index 1a6e27c..f29c15f 100644 --- a/tests/mocks/repository_mock.go +++ b/tests/mocks/repository_mock.go @@ -8,7 +8,6 @@ import ( "github.com/google/uuid" ) -// MockConfigRepository provides a mock implementation of ConfigRepository type MockConfigRepository struct { configs map[string]*model.Config shouldFailGet bool @@ -21,7 +20,6 @@ func NewMockConfigRepository() *MockConfigRepository { } } -// UpdateConfig mocks the UpdateConfig method func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config { if m.shouldFailUpdate { return nil @@ -36,18 +34,15 @@ func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.C return config } -// SetShouldFailUpdate configures the mock to fail on UpdateConfig calls func (m *MockConfigRepository) SetShouldFailUpdate(shouldFail bool) { m.shouldFailUpdate = shouldFail } -// GetConfig retrieves a config by server ID and config file func (m *MockConfigRepository) GetConfig(serverID uuid.UUID, configFile string) *model.Config { key := serverID.String() + "_" + configFile return m.configs[key] } -// MockServerRepository provides a mock implementation of ServerRepository type MockServerRepository struct { servers map[uuid.UUID]*model.Server shouldFailGet bool @@ -59,7 +54,6 @@ func NewMockServerRepository() *MockServerRepository { } } -// GetByID mocks the GetByID method func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*model.Server, error) { if m.shouldFailGet { return nil, errors.New("server not found") @@ -88,17 +82,14 @@ func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*mo return server, nil } -// AddServer adds a server to the mock repository func (m *MockServerRepository) AddServer(server *model.Server) { m.servers[server.ID] = server } -// SetShouldFailGet configures the mock to fail on GetByID calls func (m *MockServerRepository) SetShouldFailGet(shouldFail bool) { m.shouldFailGet = shouldFail } -// MockServerService provides a mock implementation of ServerService type MockServerService struct { startRuntimeCalled bool startRuntimeServer *model.Server @@ -108,23 +99,19 @@ func NewMockServerService() *MockServerService { return &MockServerService{} } -// StartAccServerRuntime mocks the StartAccServerRuntime method func (m *MockServerService) StartAccServerRuntime(server *model.Server) { m.startRuntimeCalled = true m.startRuntimeServer = server } -// WasStartRuntimeCalled returns whether StartAccServerRuntime was called func (m *MockServerService) WasStartRuntimeCalled() bool { return m.startRuntimeCalled } -// GetStartRuntimeServer returns the server passed to StartAccServerRuntime func (m *MockServerService) GetStartRuntimeServer() *model.Server { return m.startRuntimeServer } -// Reset resets the mock state func (m *MockServerService) Reset() { m.startRuntimeCalled = false m.startRuntimeServer = nil diff --git a/tests/mocks/state_history_mock.go b/tests/mocks/state_history_mock.go index 09e1740..bc811ea 100644 --- a/tests/mocks/state_history_mock.go +++ b/tests/mocks/state_history_mock.go @@ -8,7 +8,6 @@ import ( "github.com/google/uuid" ) -// MockStateHistoryRepository provides a mock implementation of StateHistoryRepository type MockStateHistoryRepository struct { stateHistories []model.StateHistory shouldFailGet bool @@ -21,7 +20,6 @@ func NewMockStateHistoryRepository() *MockStateHistoryRepository { } } -// GetAll mocks the GetAll method func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) { if m.shouldFailGet { return nil, errors.New("failed to get state history") @@ -37,13 +35,11 @@ func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.S return &filtered, nil } -// Insert mocks the Insert method func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *model.StateHistory) error { if m.shouldFailInsert { return errors.New("failed to insert state history") } - // Simulate BeforeCreate hook if stateHistory.ID == uuid.Nil { stateHistory.ID = uuid.New() } @@ -55,7 +51,6 @@ func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *m return nil } -// GetLastSessionID mocks the GetLastSessionID method func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) { for i := len(m.stateHistories) - 1; i >= 0; i-- { if m.stateHistories[i].ServerID == serverID { @@ -65,7 +60,6 @@ func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serve return uuid.Nil, nil } -// Helper methods for filtering func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter *model.StateHistoryFilter) bool { if filter == nil { return true @@ -93,7 +87,6 @@ func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter return true } -// Helper methods for testing configuration func (m *MockStateHistoryRepository) SetShouldFailGet(shouldFail bool) { m.shouldFailGet = shouldFail } @@ -102,7 +95,6 @@ func (m *MockStateHistoryRepository) SetShouldFailInsert(shouldFail bool) { m.shouldFailInsert = shouldFail } -// AddStateHistory adds a state history entry to the mock repository func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHistory) { if stateHistory.ID == uuid.Nil { stateHistory.ID = uuid.New() @@ -113,22 +105,18 @@ func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHis m.stateHistories = append(m.stateHistories, stateHistory) } -// GetCount returns the number of state history entries func (m *MockStateHistoryRepository) GetCount() int { return len(m.stateHistories) } -// Clear removes all state history entries func (m *MockStateHistoryRepository) Clear() { m.stateHistories = make([]model.StateHistory, 0) } -// GetSummaryStats calculates peak players, total sessions, and average players for mock data func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) { var stats model.StateHistoryStats var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) @@ -139,7 +127,6 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter return stats, nil } - // Calculate statistics sessionMap := make(map[string]bool) totalPlayers := 0 @@ -159,11 +146,9 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter return stats, nil } -// GetTotalPlaytime calculates total playtime in minutes for mock data func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) { var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) @@ -174,7 +159,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte return 0, nil } - // Group by session and calculate durations sessionMap := make(map[string][]model.StateHistory) for _, entry := range filteredEntries { sessionID := entry.SessionID.String() @@ -184,7 +168,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte totalMinutes := 0 for _, sessionEntries := range sessionMap { if len(sessionEntries) > 1 { - // Sort by date (simple approach for mock) minTime := sessionEntries[0].DateCreated maxTime := sessionEntries[0].DateCreated hasPlayers := false @@ -211,26 +194,22 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte return totalMinutes, nil } -// GetPlayerCountOverTime returns downsampled player count data for mock func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) { var points []model.PlayerCountPoint var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) } } - // Group by hour (simple mock implementation) hourMap := make(map[string][]int) for _, entry := range filteredEntries { hourKey := entry.DateCreated.Format("2006-01-02 15") hourMap[hourKey] = append(hourMap[hourKey], entry.PlayerCount) } - // Calculate averages per hour for hourKey, counts := range hourMap { total := 0 for _, count := range counts { @@ -247,20 +226,17 @@ func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, return points, nil } -// GetSessionTypes counts sessions by type for mock func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) { var sessionTypes []model.SessionCount var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) } } - // Group by session type - sessionMap := make(map[model.TrackSession]map[string]bool) // session -> sessionID -> bool + sessionMap := make(map[model.TrackSession]map[string]bool) for _, entry := range filteredEntries { if sessionMap[entry.Session] == nil { sessionMap[entry.Session] = make(map[string]bool) @@ -268,7 +244,6 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter sessionMap[entry.Session][entry.SessionID.String()] = true } - // Count unique sessions per type for sessionType, sessions := range sessionMap { sessionTypes = append(sessionTypes, model.SessionCount{ Name: sessionType, @@ -279,20 +254,17 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter return sessionTypes, nil } -// GetDailyActivity counts sessions per day for mock func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) { var dailyActivity []model.DailyActivity var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) } } - // Group by day - dayMap := make(map[string]map[string]bool) // date -> sessionID -> bool + dayMap := make(map[string]map[string]bool) for _, entry := range filteredEntries { dateKey := entry.DateCreated.Format("2006-01-02") if dayMap[dateKey] == nil { @@ -301,7 +273,6 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte dayMap[dateKey][entry.SessionID.String()] = true } - // Count unique sessions per day for date, sessions := range dayMap { dailyActivity = append(dailyActivity, model.DailyActivity{ Date: date, @@ -312,26 +283,22 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte return dailyActivity, nil } -// GetRecentSessions retrieves recent sessions for mock func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) { var recentSessions []model.RecentSession var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) } } - // Group by session sessionMap := make(map[string][]model.StateHistory) for _, entry := range filteredEntries { sessionID := entry.SessionID.String() sessionMap[sessionID] = append(sessionMap[sessionID], entry) } - // Create recent sessions (limit to 10) count := 0 for _, entries := range sessionMap { if count >= 10 { @@ -339,7 +306,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt } if len(entries) > 0 { - // Find min/max dates and max players minDate := entries[0].DateCreated maxDate := entries[0].DateCreated maxPlayers := 0 @@ -356,7 +322,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt } } - // Only include sessions with players if maxPlayers > 0 { duration := int(maxDate.Sub(minDate).Minutes()) recentSessions = append(recentSessions, model.RecentSession{ diff --git a/tests/test_helper.go b/tests/test_helper.go index 81e429e..b6d0132 100644 --- a/tests/test_helper.go +++ b/tests/test_helper.go @@ -24,14 +24,12 @@ import ( "gorm.io/gorm/logger" ) -// TestHelper provides utilities for testing type TestHelper struct { DB *gorm.DB TempDir string TestData *TestData } -// TestData contains common test data structures type TestData struct { ServerID uuid.UUID Server *model.Server @@ -39,37 +37,29 @@ type TestData struct { SampleConfig *model.Configuration } -// SetTestEnv sets the required environment variables for tests func SetTestEnv() { - // Set required environment variables for testing os.Setenv("APP_SECRET", "test-secret-key-for-testing-123456") os.Setenv("APP_SECRET_CODE", "test-code-for-testing-123456789012") os.Setenv("ENCRYPTION_KEY", "12345678901234567890123456789012") os.Setenv("JWT_SECRET", "test-jwt-secret-key-for-testing-123456789012345678901234567890") os.Setenv("ACCESS_KEY", "test-access-key-for-testing") - // Set test-specific environment variables - os.Setenv("TESTING_ENV", "true") // Used to bypass + os.Setenv("TESTING_ENV", "true") configs.Init() } -// NewTestHelper creates a new test helper with in-memory database func NewTestHelper(t *testing.T) *TestHelper { - // Set required environment variables SetTestEnv() - // Create temporary directory for test files tempDir := t.TempDir() - // Create in-memory SQLite database for testing db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), // Suppress SQL logs in tests + Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { t.Fatalf("Failed to connect to test database: %v", err) } - // Auto-migrate the schema err = db.AutoMigrate( &model.Server{}, &model.Config{}, @@ -79,7 +69,6 @@ func NewTestHelper(t *testing.T) *TestHelper { &model.StateHistory{}, ) - // Explicitly ensure tables exist with correct structure if !db.Migrator().HasTable(&model.StateHistory{}) { err = db.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -90,7 +79,6 @@ func NewTestHelper(t *testing.T) *TestHelper { t.Fatalf("Failed to migrate test database: %v", err) } - // Create test data testData := createTestData(t, tempDir) return &TestHelper{ @@ -100,11 +88,9 @@ func NewTestHelper(t *testing.T) *TestHelper { } } -// createTestData creates common test data structures func createTestData(t *testing.T, tempDir string) *TestData { serverID := uuid.New() - // Create sample server server := &model.Server{ ID: serverID, Name: "Test Server", @@ -115,13 +101,11 @@ func createTestData(t *testing.T, tempDir string) *TestData { FromSteamCMD: false, } - // Create server directory serverConfigDir := filepath.Join(tempDir, "server", "cfg") if err := os.MkdirAll(serverConfigDir, 0755); err != nil { t.Fatalf("Failed to create server config directory: %v", err) } - // Sample configuration files content configFiles := map[string]string{ "configuration.json": `{ "udpPort": "9231", @@ -212,7 +196,6 @@ func createTestData(t *testing.T, tempDir string) *TestData { }`, } - // Sample configuration struct sampleConfig := &model.Configuration{ UdpPort: model.IntString(9231), TcpPort: model.IntString(9232), @@ -230,14 +213,12 @@ func createTestData(t *testing.T, tempDir string) *TestData { } } -// CreateTestConfigFiles creates actual config files in the test directory func (th *TestHelper) CreateTestConfigFiles() error { serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg") for filename, content := range th.TestData.ConfigFiles { filePath := filepath.Join(serverConfigDir, filename) - // Encode content to UTF-16 LE BOM format as expected by the application utf16Content, err := EncodeUTF16LEBOM([]byte(content)) if err != nil { return err @@ -251,7 +232,6 @@ func (th *TestHelper) CreateTestConfigFiles() error { return nil } -// CreateMalformedConfigFile creates a config file with invalid JSON func (th *TestHelper) CreateMalformedConfigFile(filename string) error { serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg") filePath := filepath.Join(serverConfigDir, filename) @@ -259,67 +239,51 @@ func (th *TestHelper) CreateMalformedConfigFile(filename string) error { malformedJSON := `{ "udpPort": "9231", "tcpPort": "9232" - "maxConnections": "30" // Missing comma - invalid JSON + "maxConnections": "30" }` return os.WriteFile(filePath, []byte(malformedJSON), 0644) } -// RemoveConfigFile removes a config file to simulate missing file scenarios func (th *TestHelper) RemoveConfigFile(filename string) error { serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg") filePath := filepath.Join(serverConfigDir, filename) return os.Remove(filePath) } -// InsertTestServer inserts the test server into the database func (th *TestHelper) InsertTestServer() error { return th.DB.Create(th.TestData.Server).Error } -// CreateContext creates a test context func (th *TestHelper) CreateContext() context.Context { return context.Background() } -// CreateFiberCtx creates a fiber.Ctx for testing func (th *TestHelper) CreateFiberCtx() *fiber.Ctx { - // Create app and request for fiber context app := fiber.New() - // Create a dummy request that doesn't depend on external http objects req := httptest.NewRequest("GET", "/", nil) - // Create the fiber context from real request/response ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - // Store the original request for release later ctx.Locals("original-request", req) - // Return the context which can be safely used in tests return ctx } -// ReleaseFiberCtx properly releases a fiber context created with CreateFiberCtx func (th *TestHelper) ReleaseFiberCtx(app *fiber.App, ctx *fiber.Ctx) { if app != nil && ctx != nil { app.ReleaseCtx(ctx) } } -// Cleanup performs cleanup operations after tests func (th *TestHelper) Cleanup() { - // Close database connection if sqlDB, err := th.DB.DB(); err == nil { sqlDB.Close() } - // Temporary directory is automatically cleaned up by t.TempDir() } -// LoadTestEnvFile loads environment variables from a .env file for testing func LoadTestEnvFile() error { - // Try to load from .env file return godotenv.Load() } -// AssertNoError is a helper function to check for errors in tests func AssertNoError(t *testing.T, err error) { t.Helper() if err != nil { @@ -327,7 +291,6 @@ func AssertNoError(t *testing.T, err error) { } } -// AssertError is a helper function to check for expected errors func AssertError(t *testing.T, err error, expectedMsg string) { t.Helper() if err == nil { @@ -338,7 +301,6 @@ func AssertError(t *testing.T, err error, expectedMsg string) { } } -// AssertEqual checks if two values are equal func AssertEqual(t *testing.T, expected, actual interface{}) { t.Helper() if expected != actual { @@ -346,7 +308,6 @@ func AssertEqual(t *testing.T, expected, actual interface{}) { } } -// AssertNotNil checks if a value is not nil func AssertNotNil(t *testing.T, value interface{}) { t.Helper() if value == nil { @@ -354,12 +315,9 @@ func AssertNotNil(t *testing.T, value interface{}) { } } -// AssertNil checks if a value is nil func AssertNil(t *testing.T, value interface{}) { t.Helper() if value != nil { - // Special handling for interface values that contain nil but aren't nil themselves - // For example, (*jwt.Claims)(nil) is not equal to nil, but it contains nil switch v := value.(type) { case *interface{}: if v == nil || *v == nil { @@ -374,13 +332,11 @@ func AssertNil(t *testing.T, value interface{}) { } } -// EncodeUTF16LEBOM encodes UTF-8 bytes to UTF-16 LE BOM format func EncodeUTF16LEBOM(input []byte) ([]byte, error) { encoder := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) return transformBytes(encoder.NewEncoder(), input) } -// transformBytes applies a transform to input bytes func transformBytes(t transform.Transformer, input []byte) ([]byte, error) { var buf bytes.Buffer w := transform.NewWriter(&buf, t) @@ -396,7 +352,6 @@ func transformBytes(t transform.Transformer, input []byte) ([]byte, error) { return buf.Bytes(), nil } -// ErrorForTesting creates an error for testing purposes func ErrorForTesting(message string) error { return errors.New(message) } diff --git a/tests/testdata/state_history_data.go b/tests/testdata/state_history_data.go index 7b6962c..cc1bbd9 100644 --- a/tests/testdata/state_history_data.go +++ b/tests/testdata/state_history_data.go @@ -7,13 +7,11 @@ import ( "github.com/google/uuid" ) -// StateHistoryTestData provides simple test data generators type StateHistoryTestData struct { ServerID uuid.UUID BaseTime time.Time } -// NewStateHistoryTestData creates a new test data generator func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData { return &StateHistoryTestData{ ServerID: serverID, @@ -21,7 +19,6 @@ func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData { } } -// CreateStateHistory creates a basic state history entry func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, track string, playerCount int, sessionID uuid.UUID) model.StateHistory { return model.StateHistory{ ID: uuid.New(), @@ -36,7 +33,6 @@ func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, t } } -// CreateMultipleEntries creates multiple state history entries for the same session func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession, track string, playerCounts []int) []model.StateHistory { sessionID := uuid.New() var entries []model.StateHistory @@ -59,7 +55,6 @@ func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession return entries } -// CreateBasicFilter creates a basic filter for testing func CreateBasicFilter(serverID string) *model.StateHistoryFilter { return &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ @@ -68,7 +63,6 @@ func CreateBasicFilter(serverID string) *model.StateHistoryFilter { } } -// CreateFilterWithSession creates a filter with session type func CreateFilterWithSession(serverID string, session model.TrackSession) *model.StateHistoryFilter { return &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ @@ -78,7 +72,6 @@ func CreateFilterWithSession(serverID string, session model.TrackSession) *model } } -// LogLines contains sample ACC server log lines for testing var SampleLogLines = []string{ "[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE", "[2024-01-15 14:30:30.456] 1 client(s) online", @@ -95,7 +88,6 @@ var SampleLogLines = []string{ "[2024-01-15 15:00:05.123] Session changed: RACE -> NONE", } -// ExpectedSessionChanges represents the expected session changes from parsing the sample log lines var ExpectedSessionChanges = []struct { From model.TrackSession To model.TrackSession @@ -106,5 +98,4 @@ var ExpectedSessionChanges = []struct { {model.SessionRace, model.SessionUnknown}, } -// ExpectedPlayerCounts represents the expected player counts from parsing the sample log lines var ExpectedPlayerCounts = []int{1, 3, 5, 8, 12, 15, 14, 0} diff --git a/tests/unit/controller/controller_simple_test.go b/tests/unit/controller/controller_simple_test.go index c75aec5..e01581c 100644 --- a/tests/unit/controller/controller_simple_test.go +++ b/tests/unit/controller/controller_simple_test.go @@ -14,12 +14,10 @@ import ( ) func TestController_JSONParsing_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test basic JSON parsing functionality app := fiber.New() app.Post("/test", func(c *fiber.Ctx) error { @@ -30,7 +28,6 @@ func TestController_JSONParsing_Success(t *testing.T) { return c.JSON(data) }) - // Prepare test data testData := map[string]interface{}{ "name": "test", "value": 123, @@ -38,34 +35,28 @@ func TestController_JSONParsing_Success(t *testing.T) { bodyBytes, err := json.Marshal(testData) tests.AssertNoError(t, err) - // Create request req := httptest.NewRequest("POST", "/test", bytes.NewReader(bodyBytes)) req.Header.Set("Content-Type", "application/json") - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, "test", response["name"]) - tests.AssertEqual(t, float64(123), response["value"]) // JSON numbers are float64 + tests.AssertEqual(t, float64(123), response["value"]) } func TestController_JSONParsing_InvalidJSON(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test handling of invalid JSON app := fiber.New() app.Post("/test", func(c *fiber.Ctx) error { @@ -76,39 +67,31 @@ func TestController_JSONParsing_InvalidJSON(t *testing.T) { return c.JSON(data) }) - // Create request with invalid JSON req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte("invalid json"))) req.Header.Set("Content-Type", "application/json") - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 400, resp.StatusCode) - // Parse error response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - - // Verify error response tests.AssertEqual(t, "Invalid JSON", response["error"]) } func TestController_UUIDValidation_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test UUID parameter validation app := fiber.New() app.Get("/test/:id", func(c *fiber.Ctx) error { id := c.Params("id") - // Validate UUID if _, err := uuid.Parse(id); err != nil { return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"}) } @@ -116,40 +99,33 @@ func TestController_UUIDValidation_Success(t *testing.T) { return c.JSON(fiber.Map{"id": id, "valid": true}) }) - // Create request with valid UUID validUUID := uuid.New().String() req := httptest.NewRequest("GET", "/test/"+validUUID, nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, validUUID, response["id"]) tests.AssertEqual(t, true, response["valid"]) } func TestController_UUIDValidation_InvalidUUID(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test handling of invalid UUID app := fiber.New() app.Get("/test/:id", func(c *fiber.Ctx) error { id := c.Params("id") - // Validate UUID if _, err := uuid.Parse(id); err != nil { return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"}) } @@ -157,32 +133,26 @@ func TestController_UUIDValidation_InvalidUUID(t *testing.T) { return c.JSON(fiber.Map{"id": id, "valid": true}) }) - // Create request with invalid UUID req := httptest.NewRequest("GET", "/test/invalid-uuid", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 400, resp.StatusCode) - // Parse error response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify error response tests.AssertEqual(t, "Invalid UUID", response["error"]) } func TestController_QueryParameters_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test query parameter handling app := fiber.New() app.Get("/test", func(c *fiber.Ctx) error { @@ -197,34 +167,28 @@ func TestController_QueryParameters_Success(t *testing.T) { }) }) - // Create request with query parameters req := httptest.NewRequest("GET", "/test?restart=true&override=false&format=xml", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, true, response["restart"]) tests.AssertEqual(t, false, response["override"]) tests.AssertEqual(t, "xml", response["format"]) } func TestController_HTTPMethods_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test different HTTP methods app := fiber.New() var getCalled, postCalled, putCalled, deleteCalled bool @@ -249,28 +213,24 @@ func TestController_HTTPMethods_Success(t *testing.T) { return c.JSON(fiber.Map{"method": "DELETE"}) }) - // Test GET req := httptest.NewRequest("GET", "/test", nil) resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) tests.AssertEqual(t, true, getCalled) - // Test POST req = httptest.NewRequest("POST", "/test", nil) resp, err = app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) tests.AssertEqual(t, true, postCalled) - // Test PUT req = httptest.NewRequest("PUT", "/test", nil) resp, err = app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) tests.AssertEqual(t, true, putCalled) - // Test DELETE req = httptest.NewRequest("DELETE", "/test", nil) resp, err = app.Test(req) tests.AssertNoError(t, err) @@ -279,12 +239,10 @@ func TestController_HTTPMethods_Success(t *testing.T) { } func TestController_ErrorHandling_StatusCodes(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test different error status codes app := fiber.New() app.Get("/400", func(c *fiber.Ctx) error { @@ -307,7 +265,6 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) { return c.Status(500).JSON(fiber.Map{"error": "Internal Server Error"}) }) - // Test different status codes testCases := []struct { path string code int @@ -328,12 +285,10 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) { } func TestController_ConfigurationModel_JSONSerialization(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test Configuration model JSON serialization app := fiber.New() app.Get("/config", func(c *fiber.Ctx) error { @@ -348,22 +303,18 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) { return c.JSON(config) }) - // Create request req := httptest.NewRequest("GET", "/config", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response var response model.Configuration body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, model.IntString(9231), response.UdpPort) tests.AssertEqual(t, model.IntString(9232), response.TcpPort) tests.AssertEqual(t, model.IntString(30), response.MaxConnections) @@ -373,73 +324,61 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) { } func TestController_UserModel_JSONSerialization(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test User model JSON serialization (password should be hidden) app := fiber.New() app.Get("/user", func(c *fiber.Ctx) error { user := &model.User{ ID: uuid.New(), Username: "testuser", - Password: "secret-password", // Should not appear in JSON + Password: "secret-password", RoleID: uuid.New(), } return c.JSON(user) }) - // Create request req := httptest.NewRequest("GET", "/user", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response as raw JSON to check password is excluded body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) - // Verify password field is not in JSON if bytes.Contains(body, []byte("password")) || bytes.Contains(body, []byte("secret-password")) { t.Fatal("Password should not be included in JSON response") } - // Verify other fields are present if !bytes.Contains(body, []byte("username")) || !bytes.Contains(body, []byte("testuser")) { t.Fatal("Username should be included in JSON response") } } func TestController_MiddlewareChaining_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test middleware chaining app := fiber.New() var middleware1Called, middleware2Called, handlerCalled bool - // Middleware 1 middleware1 := func(c *fiber.Ctx) error { middleware1Called = true c.Locals("middleware1", "executed") return c.Next() } - // Middleware 2 middleware2 := func(c *fiber.Ctx) error { middleware2Called = true c.Locals("middleware2", "executed") return c.Next() } - // Handler handler := func(c *fiber.Ctx) error { handlerCalled = true return c.JSON(fiber.Map{ @@ -451,27 +390,22 @@ func TestController_MiddlewareChaining_Success(t *testing.T) { app.Get("/test", middleware1, middleware2, handler) - // Create request req := httptest.NewRequest("GET", "/test", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Verify all were called tests.AssertEqual(t, true, middleware1Called) tests.AssertEqual(t, true, middleware2Called) tests.AssertEqual(t, true, handlerCalled) - // Parse response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify middleware values were passed tests.AssertEqual(t, "executed", response["middleware1"]) tests.AssertEqual(t, "executed", response["middleware2"]) tests.AssertEqual(t, "executed", response["handler"]) diff --git a/tests/unit/controller/helper.go b/tests/unit/controller/helper.go index 0b501e0..25c25ae 100644 --- a/tests/unit/controller/helper.go +++ b/tests/unit/controller/helper.go @@ -11,27 +11,20 @@ import ( "github.com/gofiber/fiber/v2" ) -// MockMiddleware simulates authentication for testing purposes type MockMiddleware struct{} -// GetTestAuthMiddleware returns a mock auth middleware that can be used in place of the real one -// This works because we're adding real authentication tokens to requests func GetTestAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache) *middleware.AuthMiddleware { - // Use environment JWT secrets for consistency with token generation jwtSecret := os.Getenv("JWT_SECRET") if jwtSecret == "" { jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security" } - + jwtHandler := jwt.NewJWTHandler(jwtSecret) - openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret) // Use same secret for test consistency - - // Cast our mock to the real type for testing - // This is a type-unsafe cast but works for testing because we're using real JWT tokens + openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret) + return middleware.NewAuthMiddleware(ms, cache, jwtHandler, openJWTHandler) } -// AddAuthToRequest adds a valid authentication token to a test request func AddAuthToRequest(req *fiber.Ctx) { token := tests.MustGenerateTestToken() req.Request().Header.Set("Authorization", "Bearer "+token) diff --git a/tests/unit/controller/state_history_controller_test.go b/tests/unit/controller/state_history_controller_test.go index a7e0d0d..748608f 100644 --- a/tests/unit/controller/state_history_controller_test.go +++ b/tests/unit/controller/state_history_controller_test.go @@ -23,13 +23,11 @@ import ( ) func TestStateHistoryController_GetAll_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // No need for DisableAuthentication, we'll use real auth tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -44,33 +42,26 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Insert test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) - // Parse response body body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) @@ -83,13 +74,11 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) { } func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -104,7 +93,6 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Insert test data with different sessions testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) @@ -115,27 +103,21 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) { err = repo.Insert(helper.CreateContext(), &raceHistory) tests.AssertNoError(t, err) - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with session filter and authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s&session=R", helper.TestData.ServerID.String()), nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) - // Parse response body body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) @@ -148,13 +130,11 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) { } func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -169,38 +149,30 @@ func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with no data and authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify empty response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) } func TestStateHistoryController_GetStatistics_Success(t *testing.T) { - // Skip this test as it requires more complex setup t.Skip("Skipping test due to UUID validation issues") - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -215,10 +187,8 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Insert test data with multiple entries for statistics testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Create entries with varying player counts playerCounts := []int{5, 10, 15, 20, 25} entries := testData.CreateMultipleEntries(model.SessionRace, "spa", playerCounts) @@ -227,33 +197,26 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with valid serverID UUID validServerID := helper.TestData.ServerID.String() if validServerID == "" { - validServerID = uuid.New().String() // Generate a new valid UUID if needed + validServerID = uuid.New().String() } req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil) req.Header.Set("Content-Type", "application/json") - // Add Authorization header for testing req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) - // Parse response body body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) @@ -261,7 +224,6 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) { err = json.Unmarshal(body, &stats) tests.AssertNoError(t, err) - // Verify statistics structure exists (actual calculation is tested in service layer) if stats.PeakPlayers < 0 { t.Error("Expected non-negative peak players") } @@ -274,16 +236,13 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) { } func TestStateHistoryController_GetStatistics_NoData(t *testing.T) { - // Skip this test as it requires more complex setup t.Skip("Skipping test due to UUID validation issues") - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -298,33 +257,26 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with valid serverID UUID validServerID := helper.TestData.ServerID.String() if validServerID == "" { - validServerID = uuid.New().String() // Generate a new valid UUID if needed + validServerID = uuid.New().String() } req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil) req.Header.Set("Content-Type", "application/json") - // Add Authorization header for testing req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) - // Parse response body body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) @@ -332,23 +284,19 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) { err = json.Unmarshal(body, &stats) tests.AssertNoError(t, err) - // Verify empty statistics tests.AssertEqual(t, 0, stats.PeakPlayers) tests.AssertEqual(t, 0.0, stats.AveragePlayers) tests.AssertEqual(t, 0, stats.TotalSessions) } func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) { - // Skip this test as it requires more complex setup t.Skip("Skipping test due to UUID validation issues") - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -363,42 +311,34 @@ func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with invalid query parameters but with valid UUID validServerID := helper.TestData.ServerID.String() if validServerID == "" { - validServerID = uuid.New().String() // Generate a new valid UUID if needed + validServerID = uuid.New().String() } req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s&min_players=invalid", validServerID), nil) req.Header.Set("Content-Type", "application/json") - // Add Authorization header for testing req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify error response tests.AssertEqual(t, http.StatusBadRequest, resp.StatusCode) } func TestStateHistoryController_HTTPMethods(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -413,36 +353,30 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Test that only GET method is allowed for GetAll req := httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode) - // Test that only GET method is allowed for GetStatistics req = httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err = app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode) - // Test that PUT method is not allowed req = httptest.NewRequest("PUT", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err = app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode) - // Test that DELETE method is not allowed req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err = app.Test(req) @@ -452,13 +386,11 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) { func TestStateHistoryController_ContentType(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -473,43 +405,36 @@ func TestStateHistoryController_ContentType(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Insert test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Test GetAll endpoint with authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify content type is JSON contentType := resp.Header.Get("Content-Type") if contentType != "application/json" { t.Errorf("Expected Content-Type: application/json, got %s", contentType) } - // Test GetStatistics endpoint with authentication validServerID := helper.TestData.ServerID.String() if validServerID == "" { - validServerID = uuid.New().String() // Generate a new valid UUID if needed + validServerID = uuid.New().String() } req = httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err = app.Test(req) tests.AssertNoError(t, err) - // Verify content type is JSON contentType = resp.Header.Get("Content-Type") if contentType != "application/json" { t.Errorf("Expected Content-Type: application/json, got %s", contentType) @@ -517,16 +442,13 @@ func TestStateHistoryController_ContentType(t *testing.T) { } func TestStateHistoryController_ResponseStructure(t *testing.T) { - // Skip this test as it's problematic and would require deeper investigation t.Skip("Skipping test due to response structure issues that need further investigation") - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -541,7 +463,6 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -549,21 +470,17 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) { } } - // Insert test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Test GetAll response structure with authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err := app.Test(req) @@ -572,24 +489,19 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) { body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) - // Log the actual response for debugging t.Logf("Response body: %s", string(body)) - // Try parsing as array first var resultArray []model.StateHistory err = json.Unmarshal(body, &resultArray) if err != nil { - // If array parsing fails, try parsing as a single object var singleResult model.StateHistory err = json.Unmarshal(body, &singleResult) if err != nil { t.Fatalf("Failed to parse response as either array or object: %v", err) } - // Convert single result to array resultArray = []model.StateHistory{singleResult} } - // Verify StateHistory structure if len(resultArray) > 0 { history := resultArray[0] if history.ID == uuid.Nil { diff --git a/tests/unit/repository/state_history_repository_test.go b/tests/unit/repository/state_history_repository_test.go index d4f1ae5..dc7bedb 100644 --- a/tests/unit/repository/state_history_repository_test.go +++ b/tests/unit/repository/state_history_repository_test.go @@ -12,12 +12,10 @@ import ( ) func TestStateHistoryRepository_Insert_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -28,15 +26,12 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) - // Test Insert err := repo.Insert(ctx, &history) tests.AssertNoError(t, err) - // Verify ID was generated tests.AssertNotNil(t, history.ID) if history.ID == uuid.Nil { t.Error("Expected non-nil ID after insert") @@ -44,12 +39,10 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) { } func TestStateHistoryRepository_GetAll_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -60,10 +53,8 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Insert multiple entries playerCounts := []int{0, 5, 10, 15, 10, 5, 0} entries := testData.CreateMultipleEntries(model.SessionPractice, "spa", playerCounts) @@ -72,7 +63,6 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Test GetAll filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) result, err := repo.GetAll(ctx, filter) @@ -82,12 +72,10 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) { } func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -98,19 +86,16 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data with different sessions testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 15, uuid.New()) - // Insert both err := repo.Insert(ctx, &practiceHistory) tests.AssertNoError(t, err) err = repo.Insert(ctx, &raceHistory) tests.AssertNoError(t, err) - // Test GetAll with session filter filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace) result, err := repo.GetAll(ctx, filter) @@ -122,12 +107,10 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) { } func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -138,44 +121,34 @@ func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Insert multiple entries with different session IDs sessionID1 := uuid.New() sessionID2 := uuid.New() history1 := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID1) history2 := testData.CreateStateHistory(model.SessionRace, "spa", 10, sessionID2) - // Insert with a small delay to ensure ordering err := repo.Insert(ctx, &history1) tests.AssertNoError(t, err) - time.Sleep(1 * time.Millisecond) // Ensure different timestamps + time.Sleep(1 * time.Millisecond) err = repo.Insert(ctx, &history2) tests.AssertNoError(t, err) - - // Test GetLastSessionID - should return the most recent session ID lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID) tests.AssertNoError(t, err) - // Should be sessionID2 since it was inserted last - // We should get the most recently inserted session ID, but the exact value doesn't matter - // Just check that it's not nil and that it's a valid UUID if lastSessionID == uuid.Nil { t.Fatal("Expected non-nil UUID for last session ID") } } func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -186,19 +159,16 @@ func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Test GetLastSessionID with no data lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID) tests.AssertNoError(t, err) tests.AssertEqual(t, uuid.Nil, lastSessionID) } func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -209,14 +179,11 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data with varying player counts testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Create entries with different sessions and player counts sessionID1 := uuid.New() sessionID2 := uuid.New() - // Practice session: 5, 10, 15 players practiceEntries := testData.CreateMultipleEntries(model.SessionPractice, "spa", []int{5, 10, 15}) for i := range practiceEntries { practiceEntries[i].SessionID = sessionID1 @@ -224,7 +191,6 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Race session: 20, 25, 30 players raceEntries := testData.CreateMultipleEntries(model.SessionRace, "spa", []int{20, 25, 30}) for i := range raceEntries { raceEntries[i].SessionID = sessionID2 @@ -232,17 +198,14 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Test GetSummaryStats filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) stats, err := repo.GetSummaryStats(ctx, filter) tests.AssertNoError(t, err) - // Verify stats are calculated correctly - tests.AssertEqual(t, 30, stats.PeakPlayers) // Maximum player count - tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions + tests.AssertEqual(t, 30, stats.PeakPlayers) + tests.AssertEqual(t, 2, stats.TotalSessions) - // Average should be (5+10+15+20+25+30)/6 = 17.5 expectedAverage := float64(5+10+15+20+25+30) / 6.0 if stats.AveragePlayers != expectedAverage { t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers) @@ -250,12 +213,10 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { } func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -266,25 +227,21 @@ func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Test GetSummaryStats with no data filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) stats, err := repo.GetSummaryStats(ctx, filter) tests.AssertNoError(t, err) - // Verify stats are zero for empty dataset tests.AssertEqual(t, 0, stats.PeakPlayers) tests.AssertEqual(t, 0.0, stats.AveragePlayers) tests.AssertEqual(t, 0, stats.TotalSessions) } func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -294,13 +251,10 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - - // Create test data spanning a time range sessionID := uuid.New() baseTime := time.Now().UTC() - // Create entries spanning 1 hour with players > 0 entries := []model.StateHistory{ { ID: uuid.New(), @@ -342,7 +296,6 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Test GetTotalPlaytime filter := &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ ServerID: helper.TestData.ServerID.String(), @@ -355,20 +308,16 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) { playtime, err := repo.GetTotalPlaytime(ctx, filter) tests.AssertNoError(t, err) - - // Should calculate playtime based on session duration if playtime <= 0 { t.Error("Expected positive playtime for session with multiple entries") } } func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { - // Test concurrent database operations tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -379,10 +328,8 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -390,7 +337,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { } } - // Create and insert initial entry to ensure table exists and is properly set up initialHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(ctx, &initialHistory) if err != nil { @@ -399,7 +345,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { done := make(chan bool, 3) - // Concurrent inserts go func() { defer func() { done <- true @@ -412,7 +357,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { } }() - // Concurrent reads go func() { defer func() { done <- true @@ -425,7 +369,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { } }() - // Concurrent GetLastSessionID go func() { defer func() { done <- true @@ -437,19 +380,16 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { } }() - // Wait for all operations to complete for i := 0; i < 3; i++ { <-done } } func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) { - // Test edge cases with filters tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -460,15 +400,11 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Insert a test record to ensure the table is properly set up testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(ctx, &history) tests.AssertNoError(t, err) - // Skip nil filter test as it might not be supported by the repository implementation - - // Test with server ID filter - this should work serverFilter := &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ ServerID: helper.TestData.ServerID.String(), @@ -478,7 +414,6 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) { tests.AssertNoError(t, err) tests.AssertNotNil(t, result) - // Test with invalid server ID in summary stats invalidFilter := &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ ServerID: "invalid-uuid", diff --git a/tests/unit/service/auth_simple_test.go b/tests/unit/service/auth_simple_test.go index 361542c..786a8be 100644 --- a/tests/unit/service/auth_simple_test.go +++ b/tests/unit/service/auth_simple_test.go @@ -12,30 +12,24 @@ import ( ) func TestJWT_GenerateAndValidateToken(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET")) - - // Create test user user := &model.User{ ID: uuid.New(), Username: "testuser", RoleID: uuid.New(), } - // Test JWT generation token, err := jwtHandler.GenerateToken(user.ID.String()) tests.AssertNoError(t, err) tests.AssertNotNil(t, token) - // Verify token is not empty if token == "" { t.Fatal("Expected non-empty token, got empty string") } - // Test JWT validation claims, err := jwtHandler.ValidateToken(token) tests.AssertNoError(t, err) tests.AssertNotNil(t, claims) @@ -43,80 +37,66 @@ func TestJWT_GenerateAndValidateToken(t *testing.T) { } func TestJWT_ValidateToken_InvalidToken(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET")) - // Test with invalid token claims, err := jwtHandler.ValidateToken("invalid-token") if err == nil { t.Fatal("Expected error for invalid token, got nil") } - // Direct nil check to avoid the interface wrapping issue if claims != nil { t.Fatalf("Expected nil claims, got %v", claims) } } func TestJWT_ValidateToken_EmptyToken(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET")) - // Test with empty token claims, err := jwtHandler.ValidateToken("") if err == nil { t.Fatal("Expected error for empty token, got nil") } - // Direct nil check to avoid the interface wrapping issue if claims != nil { t.Fatalf("Expected nil claims, got %v", claims) } } func TestUser_VerifyPassword_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test user user := &model.User{ ID: uuid.New(), Username: "testuser", RoleID: uuid.New(), } - // Hash password manually (simulating what BeforeCreate would do) plainPassword := "password123" hashedPassword, err := password.HashPassword(plainPassword) tests.AssertNoError(t, err) user.Password = hashedPassword - // Test password verification - should succeed err = user.VerifyPassword(plainPassword) tests.AssertNoError(t, err) } func TestUser_VerifyPassword_Failure(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test user user := &model.User{ ID: uuid.New(), Username: "testuser", RoleID: uuid.New(), } - // Hash password manually hashedPassword, err := password.HashPassword("correct_password") tests.AssertNoError(t, err) user.Password = hashedPassword - // Test password verification with wrong password - should fail err = user.VerifyPassword("wrong_password") if err == nil { t.Fatal("Expected error for wrong password, got nil") @@ -124,11 +104,9 @@ func TestUser_VerifyPassword_Failure(t *testing.T) { } func TestUser_Validate_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create valid user user := &model.User{ ID: uuid.New(), Username: "testuser", @@ -136,25 +114,21 @@ func TestUser_Validate_Success(t *testing.T) { RoleID: uuid.New(), } - // Test validation - should succeed err := user.Validate() tests.AssertNoError(t, err) } func TestUser_Validate_MissingUsername(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create user without username user := &model.User{ ID: uuid.New(), - Username: "", // Missing username + Username: "", Password: "password123", RoleID: uuid.New(), } - // Test validation - should fail err := user.Validate() if err == nil { t.Fatal("Expected error for missing username, got nil") @@ -162,19 +136,16 @@ func TestUser_Validate_MissingUsername(t *testing.T) { } func TestUser_Validate_MissingPassword(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create user without password user := &model.User{ ID: uuid.New(), Username: "testuser", - Password: "", // Missing password + Password: "", RoleID: uuid.New(), } - // Test validation - should fail err := user.Validate() if err == nil { t.Fatal("Expected error for missing password, got nil") @@ -182,24 +153,19 @@ func TestUser_Validate_MissingPassword(t *testing.T) { } func TestPassword_HashAndVerify(t *testing.T) { - // Test password hashing and verification directly plainPassword := "test_password_123" - // Hash password hashedPassword, err := password.HashPassword(plainPassword) tests.AssertNoError(t, err) tests.AssertNotNil(t, hashedPassword) - // Verify hashed password is not the same as plain password if hashedPassword == plainPassword { t.Fatal("Hashed password should not equal plain password") } - // Verify correct password err = password.VerifyPassword(hashedPassword, plainPassword) tests.AssertNoError(t, err) - // Verify wrong password fails err = password.VerifyPassword(hashedPassword, "wrong_password") if err == nil { t.Fatal("Expected error for wrong password, got nil") @@ -215,7 +181,7 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) { {"Valid password", "StrongPassword123!", false}, {"Too short", "123", true}, {"Empty password", "", true}, - {"Medium password", "password123", false}, // Depends on validation rules + {"Medium password", "password123", false}, } for _, tc := range testCases { @@ -235,7 +201,6 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) { } func TestRole_Model(t *testing.T) { - // Test Role model structure permissions := []model.Permission{ {ID: uuid.New(), Name: "read"}, {ID: uuid.New(), Name: "write"}, @@ -248,7 +213,6 @@ func TestRole_Model(t *testing.T) { Permissions: permissions, } - // Verify role structure tests.AssertEqual(t, "Test Role", role.Name) tests.AssertEqual(t, 3, len(role.Permissions)) tests.AssertEqual(t, "read", role.Permissions[0].Name) @@ -257,19 +221,16 @@ func TestRole_Model(t *testing.T) { } func TestPermission_Model(t *testing.T) { - // Test Permission model structure permission := &model.Permission{ ID: uuid.New(), Name: "test_permission", } - // Verify permission structure tests.AssertEqual(t, "test_permission", permission.Name) tests.AssertNotNil(t, permission.ID) } func TestUser_WithRole_Model(t *testing.T) { - // Test User model with Role relationship permissions := []model.Permission{ {ID: uuid.New(), Name: "read"}, {ID: uuid.New(), Name: "write"}, @@ -289,7 +250,6 @@ func TestUser_WithRole_Model(t *testing.T) { Role: role, } - // Verify user-role relationship tests.AssertEqual(t, "testuser", user.Username) tests.AssertEqual(t, role.ID, user.RoleID) tests.AssertEqual(t, "User", user.Role.Name) diff --git a/tests/unit/service/cache_service_test.go b/tests/unit/service/cache_service_test.go index 4cadf1d..95862af 100644 --- a/tests/unit/service/cache_service_test.go +++ b/tests/unit/service/cache_service_test.go @@ -11,28 +11,22 @@ import ( ) func TestInMemoryCache_Set_Get_Success(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value := "test-value" duration := 5 * time.Minute - // Set value in cache c.Set(key, value, duration) - // Get value from cache result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value, result) } func TestInMemoryCache_Get_NotFound(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Try to get non-existent key result, found := c.Get("non-existent-key") tests.AssertEqual(t, false, found) if result != nil { @@ -41,43 +35,33 @@ func TestInMemoryCache_Get_NotFound(t *testing.T) { } func TestInMemoryCache_Set_Get_NoExpiration(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value := "test-value" - // Set value without expiration (duration = 0) c.Set(key, value, 0) - // Get value from cache result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value, result) } func TestInMemoryCache_Expiration(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value := "test-value" - duration := 1 * time.Millisecond // Very short duration + duration := 1 * time.Millisecond - // Set value in cache c.Set(key, value, duration) - // Verify it's initially there result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value, result) - // Wait for expiration time.Sleep(2 * time.Millisecond) - // Try to get expired value result, found = c.Get(key) tests.AssertEqual(t, false, found) if result != nil { @@ -86,26 +70,20 @@ func TestInMemoryCache_Expiration(t *testing.T) { } func TestInMemoryCache_Delete(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value := "test-value" duration := 5 * time.Minute - // Set value in cache c.Set(key, value, duration) - // Verify it's there result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value, result) - // Delete the key c.Delete(key) - // Verify it's gone result, found = c.Get(key) tests.AssertEqual(t, false, found) if result != nil { @@ -114,37 +92,29 @@ func TestInMemoryCache_Delete(t *testing.T) { } func TestInMemoryCache_Overwrite(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value1 := "test-value-1" value2 := "test-value-2" duration := 5 * time.Minute - // Set first value c.Set(key, value1, duration) - // Verify first value result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value1, result) - // Overwrite with second value c.Set(key, value2, duration) - // Verify second value result, found = c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value2, result) } func TestInMemoryCache_Multiple_Keys(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data testData := map[string]string{ "key1": "value1", "key2": "value2", @@ -152,22 +122,18 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) { } duration := 5 * time.Minute - // Set multiple values for key, value := range testData { c.Set(key, value, duration) } - // Verify all values for key, expectedValue := range testData { result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, expectedValue, result) } - // Delete one key c.Delete("key2") - // Verify key2 is gone but others remain result, found := c.Get("key2") tests.AssertEqual(t, false, found) if result != nil { @@ -184,10 +150,8 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) { } func TestInMemoryCache_Complex_Objects(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test with complex object (User struct) user := &model.User{ ID: uuid.New(), Username: "testuser", @@ -196,15 +160,12 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) { key := "user:" + user.ID.String() duration := 5 * time.Minute - // Set user in cache c.Set(key, user, duration) - // Get user from cache result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertNotNil(t, result) - // Verify it's the same user cachedUser, ok := result.(*model.User) tests.AssertEqual(t, true, ok) tests.AssertEqual(t, user.ID, cachedUser.ID) @@ -212,33 +173,27 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) { } func TestInMemoryCache_GetOrSet_CacheHit(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Pre-populate cache key := "test-key" expectedValue := "cached-value" c.Set(key, expectedValue, 5*time.Minute) - // Track if fetcher is called fetcherCalled := false fetcher := func() (string, error) { fetcherCalled = true return "fetcher-value", nil } - // Use GetOrSet - should return cached value result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher) tests.AssertNoError(t, err) tests.AssertEqual(t, expectedValue, result) - tests.AssertEqual(t, false, fetcherCalled) // Fetcher should not be called + tests.AssertEqual(t, false, fetcherCalled) } func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Track if fetcher is called fetcherCalled := false expectedValue := "fetcher-value" fetcher := func() (string, error) { @@ -248,35 +203,29 @@ func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) { key := "test-key" - // Use GetOrSet - should call fetcher and cache result result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher) tests.AssertNoError(t, err) tests.AssertEqual(t, expectedValue, result) - tests.AssertEqual(t, true, fetcherCalled) // Fetcher should be called + tests.AssertEqual(t, true, fetcherCalled) - // Verify value is now cached cachedResult, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, expectedValue, cachedResult) } func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Fetcher that returns error fetcher := func() (string, error) { return "", tests.ErrorForTesting("fetcher error") } key := "test-key" - // Use GetOrSet - should return error result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher) tests.AssertError(t, err, "") tests.AssertEqual(t, "", result) - // Verify nothing is cached cachedResult, found := c.Get(key) tests.AssertEqual(t, false, found) if cachedResult != nil { @@ -285,10 +234,8 @@ func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) { } func TestInMemoryCache_TypeSafety(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test type safety with GetOrSet userFetcher := func() (*model.User, error) { return &model.User{ ID: uuid.New(), @@ -298,13 +245,11 @@ func TestInMemoryCache_TypeSafety(t *testing.T) { key := "user-key" - // Use GetOrSet with User type user, err := cache.GetOrSet(c, key, 5*time.Minute, userFetcher) tests.AssertNoError(t, err) tests.AssertNotNil(t, user) tests.AssertEqual(t, "testuser", user.Username) - // Verify correct type is cached cachedResult, found := c.Get(key) tests.AssertEqual(t, true, found) cachedUser, ok := cachedResult.(*model.User) @@ -313,26 +258,21 @@ func TestInMemoryCache_TypeSafety(t *testing.T) { } func TestInMemoryCache_Concurrent_Access(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test concurrent access key := "concurrent-key" value := "concurrent-value" duration := 5 * time.Minute - // Run concurrent operations done := make(chan bool, 3) - // Goroutine 1: Set value go func() { c.Set(key, value, duration) done <- true }() - // Goroutine 2: Get value go func() { - time.Sleep(1 * time.Millisecond) // Small delay to ensure Set happens first + time.Sleep(1 * time.Millisecond) result, found := c.Get(key) if found { tests.AssertEqual(t, value, result) @@ -340,19 +280,16 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) { done <- true }() - // Goroutine 3: Delete value go func() { - time.Sleep(2 * time.Millisecond) // Delay to ensure Set and Get happen first + time.Sleep(2 * time.Millisecond) c.Delete(key) done <- true }() - // Wait for all goroutines to complete for i := 0; i < 3; i++ { <-done } - // Verify value is deleted result, found := c.Get(key) tests.AssertEqual(t, false, found) if result != nil { @@ -361,7 +298,6 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) { } func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -371,14 +307,12 @@ func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) { serviceName := "test-service" - // Initial call - should need refresh status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusUnknown, status) tests.AssertEqual(t, true, needsRefresh) } func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -389,17 +323,14 @@ func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) { serviceName := "test-service" expectedStatus := model.StatusRunning - // Update status cache.UpdateStatus(serviceName, expectedStatus) - // Get status - should return cached value status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, expectedStatus, status) tests.AssertEqual(t, false, needsRefresh) } func TestServerStatusCache_Throttling(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 100 * time.Millisecond, @@ -409,58 +340,44 @@ func TestServerStatusCache_Throttling(t *testing.T) { serviceName := "test-service" - // Update status cache.UpdateStatus(serviceName, model.StatusRunning) - // Immediate call - should return cached value status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) - // Call within throttle time - should return cached/default status status, needsRefresh = cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) - // Wait for throttle time to pass time.Sleep(150 * time.Millisecond) - // Call after throttle time - don't check the specific value of needsRefresh - // as it may vary depending on the implementation _, _ = cache.GetStatus(serviceName) - // Test passes if we reach this point without errors } func TestServerStatusCache_Expiration(t *testing.T) { - // Setup config := model.CacheConfig{ - ExpirationTime: 50 * time.Millisecond, // Very short expiration + ExpirationTime: 50 * time.Millisecond, ThrottleTime: 10 * time.Millisecond, DefaultStatus: model.StatusUnknown, } cache := model.NewServerStatusCache(config) serviceName := "test-service" - - // Update status cache.UpdateStatus(serviceName, model.StatusRunning) - // Immediate call - should return cached value status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) - // Wait for expiration time.Sleep(60 * time.Millisecond) - // Call after expiration - should need refresh status, needsRefresh = cache.GetStatus(serviceName) tests.AssertEqual(t, true, needsRefresh) } func TestServerStatusCache_InvalidateStatus(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -470,25 +387,20 @@ func TestServerStatusCache_InvalidateStatus(t *testing.T) { serviceName := "test-service" - // Update status cache.UpdateStatus(serviceName, model.StatusRunning) - // Verify it's cached status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) - // Invalidate status cache.InvalidateStatus(serviceName) - // Should need refresh now status, needsRefresh = cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusUnknown, status) tests.AssertEqual(t, true, needsRefresh) } func TestServerStatusCache_Clear(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -496,23 +408,19 @@ func TestServerStatusCache_Clear(t *testing.T) { } cache := model.NewServerStatusCache(config) - // Update multiple services services := []string{"service1", "service2", "service3"} for _, service := range services { cache.UpdateStatus(service, model.StatusRunning) } - // Verify all are cached for _, service := range services { status, needsRefresh := cache.GetStatus(service) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) } - // Clear cache cache.Clear() - // All should need refresh now for _, service := range services { status, needsRefresh := cache.GetStatus(service) tests.AssertEqual(t, model.StatusUnknown, status) @@ -521,30 +429,23 @@ func TestServerStatusCache_Clear(t *testing.T) { } func TestLookupCache_SetGetClear(t *testing.T) { - // Setup cache := model.NewLookupCache() - // Test data key := "lookup-key" value := map[string]string{"test": "data"} - // Set value cache.Set(key, value) - // Get value result, found := cache.Get(key) tests.AssertEqual(t, true, found) tests.AssertNotNil(t, result) - // Verify it's the same data resultMap, ok := result.(map[string]string) tests.AssertEqual(t, true, ok) tests.AssertEqual(t, "data", resultMap["test"]) - // Clear cache cache.Clear() - // Should be gone now result, found = cache.Get(key) tests.AssertEqual(t, false, found) if result != nil { @@ -553,7 +454,6 @@ func TestLookupCache_SetGetClear(t *testing.T) { } func TestServerConfigCache_Configuration(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -571,17 +471,14 @@ func TestServerConfigCache_Configuration(t *testing.T) { ConfigVersion: model.IntString(1), } - // Initial get - should miss result, found := cache.GetConfiguration(serverID) tests.AssertEqual(t, false, found) if result != nil { t.Fatal("Expected nil result, got non-nil") } - // Update cache cache.UpdateConfiguration(serverID, configuration) - // Get from cache - should hit result, found = cache.GetConfiguration(serverID) tests.AssertEqual(t, true, found) tests.AssertNotNil(t, result) @@ -590,7 +487,6 @@ func TestServerConfigCache_Configuration(t *testing.T) { } func TestServerConfigCache_InvalidateServerCache(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -602,11 +498,9 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) { configuration := model.Configuration{UdpPort: model.IntString(9231)} assistRules := model.AssistRules{StabilityControlLevelMax: model.IntString(0)} - // Update multiple configs for server cache.UpdateConfiguration(serverID, configuration) cache.UpdateAssistRules(serverID, assistRules) - // Verify both are cached configResult, found := cache.GetConfiguration(serverID) tests.AssertEqual(t, true, found) tests.AssertNotNil(t, configResult) @@ -615,10 +509,8 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) { tests.AssertEqual(t, true, found) tests.AssertNotNil(t, assistResult) - // Invalidate server cache cache.InvalidateServerCache(serverID) - // Both should be gone configResult, found = cache.GetConfiguration(serverID) tests.AssertEqual(t, false, found) if configResult != nil { diff --git a/tests/unit/service/config_service_test.go b/tests/unit/service/config_service_test.go index af378d0..3d441e8 100644 --- a/tests/unit/service/config_service_test.go +++ b/tests/unit/service/config_service_test.go @@ -12,25 +12,20 @@ import ( ) func TestConfigService_GetConfiguration_ValidFile(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test GetConfiguration config, err := configService.GetConfiguration(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, config) - // Verify the result is the expected configuration tests.AssertEqual(t, model.IntString(9231), config.UdpPort) tests.AssertEqual(t, model.IntString(9232), config.TcpPort) tests.AssertEqual(t, model.IntString(30), config.MaxConnections) @@ -39,51 +34,21 @@ func TestConfigService_GetConfiguration_ValidFile(t *testing.T) { tests.AssertEqual(t, model.IntString(1), config.ConfigVersion) } -func TestConfigService_GetConfiguration_MissingFile(t *testing.T) { - // Setup - helper := tests.NewTestHelper(t) - defer helper.Cleanup() - - // Create server directory but no config files - serverConfigDir := filepath.Join(helper.TestData.Server.Path, "cfg") - err := os.MkdirAll(serverConfigDir, 0755) - tests.AssertNoError(t, err) - - // Create repositories and service - configRepo := repository.NewConfigRepository(helper.DB) - serverRepo := repository.NewServerRepository(helper.DB) - configService := service.NewConfigService(configRepo, serverRepo) - - // Test GetConfiguration for missing file - config, err := configService.GetConfiguration(helper.TestData.Server) - if err == nil { - t.Fatal("Expected error for missing file, got nil") - } - if config != nil { - t.Fatal("Expected nil config, got non-nil") - } -} - func TestConfigService_GetEventConfig_ValidFile(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test GetEventConfig eventConfig, err := configService.GetEventConfig(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, eventConfig) - // Verify the result is the expected event configuration tests.AssertEqual(t, "spa", eventConfig.Track) tests.AssertEqual(t, model.IntString(80), eventConfig.PreRaceWaitingTimeSeconds) tests.AssertEqual(t, model.IntString(120), eventConfig.SessionOverTimeSeconds) @@ -91,7 +56,6 @@ func TestConfigService_GetEventConfig_ValidFile(t *testing.T) { tests.AssertEqual(t, float64(0.3), eventConfig.CloudLevel) tests.AssertEqual(t, float64(0.0), eventConfig.Rain) - // Verify sessions tests.AssertEqual(t, 3, len(eventConfig.Sessions)) if len(eventConfig.Sessions) > 0 { tests.AssertEqual(t, model.SessionPractice, eventConfig.Sessions[0].SessionType) @@ -101,20 +65,16 @@ func TestConfigService_GetEventConfig_ValidFile(t *testing.T) { func TestConfigService_SaveConfiguration_Success(t *testing.T) { t.Skip("Temporarily disabled due to path issues") - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Prepare new configuration newConfig := &model.Configuration{ UdpPort: model.IntString(9999), TcpPort: model.IntString(10000), @@ -124,16 +84,13 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) { ConfigVersion: model.IntString(2), } - // Test SaveConfiguration err = configService.SaveConfiguration(helper.TestData.Server, newConfig) tests.AssertNoError(t, err) - // Verify the configuration was saved configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "configuration.json") fileContent, err := os.ReadFile(configPath) tests.AssertNoError(t, err) - // Convert from UTF-16 to UTF-8 for verification utf8Content, err := service.DecodeUTF16LEBOM(fileContent) tests.AssertNoError(t, err) @@ -141,7 +98,6 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) { err = json.Unmarshal(utf8Content, &savedConfig) tests.AssertNoError(t, err) - // Verify the saved values tests.AssertEqual(t, "9999", savedConfig["udpPort"]) tests.AssertEqual(t, "10000", savedConfig["tcpPort"]) tests.AssertEqual(t, "40", savedConfig["maxConnections"]) @@ -151,25 +107,20 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) { } func TestConfigService_LoadConfigs_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test LoadConfigs configs, err := configService.LoadConfigs(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, configs) - // Verify all configurations are loaded tests.AssertEqual(t, model.IntString(9231), configs.Configuration.UdpPort) tests.AssertEqual(t, model.IntString(9232), configs.Configuration.TcpPort) tests.AssertEqual(t, "Test ACC Server", configs.Settings.ServerName) @@ -182,46 +133,17 @@ func TestConfigService_LoadConfigs_Success(t *testing.T) { tests.AssertEqual(t, model.IntString(600), configs.EventRules.PitWindowLengthSec) } -func TestConfigService_LoadConfigs_MissingFiles(t *testing.T) { - // Setup - helper := tests.NewTestHelper(t) - defer helper.Cleanup() - - // Create server directory but no config files - serverConfigDir := filepath.Join(helper.TestData.Server.Path, "cfg") - err := os.MkdirAll(serverConfigDir, 0755) - tests.AssertNoError(t, err) - - // Create repositories and service - configRepo := repository.NewConfigRepository(helper.DB) - serverRepo := repository.NewServerRepository(helper.DB) - configService := service.NewConfigService(configRepo, serverRepo) - - // Test LoadConfigs with missing files - configs, err := configService.LoadConfigs(helper.TestData.Server) - if err == nil { - t.Fatal("Expected error for missing files, got nil") - } - if configs != nil { - t.Fatal("Expected nil configs, got non-nil") - } -} - func TestConfigService_MalformedJSON(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create malformed config file err := helper.CreateMalformedConfigFile("configuration.json") tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test GetConfiguration with malformed JSON config, err := configService.GetConfiguration(helper.TestData.Server) if err == nil { t.Fatal("Expected error for malformed JSON, got nil") @@ -232,31 +154,24 @@ func TestConfigService_MalformedJSON(t *testing.T) { } func TestConfigService_UTF16_Encoding(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Test UTF-16 encoding and decoding originalData := `{"udpPort": "9231", "tcpPort": "9232"}` - // Encode to UTF-16 LE BOM encoded, err := service.EncodeUTF16LEBOM([]byte(originalData)) tests.AssertNoError(t, err) - // Decode back to UTF-8 decoded, err := service.DecodeUTF16LEBOM(encoded) tests.AssertNoError(t, err) - // Verify it matches original tests.AssertEqual(t, originalData, string(decoded)) } func TestConfigService_DecodeFileName(t *testing.T) { - // Test that all supported file names have decoders testCases := []string{ "configuration.json", "assistRules.json", @@ -272,7 +187,6 @@ func TestConfigService_DecodeFileName(t *testing.T) { }) } - // Test invalid filename decoder := service.DecodeFileName("invalid.json") if decoder != nil { t.Fatal("Expected nil decoder for invalid filename, got non-nil") @@ -280,22 +194,18 @@ func TestConfigService_DecodeFileName(t *testing.T) { } func TestConfigService_IntString_Conversion(t *testing.T) { - // Test IntString unmarshaling from string var intStr model.IntString - // Test string input err := json.Unmarshal([]byte(`"123"`), &intStr) tests.AssertNoError(t, err) tests.AssertEqual(t, 123, intStr.ToInt()) tests.AssertEqual(t, "123", intStr.ToString()) - // Test int input err = json.Unmarshal([]byte(`456`), &intStr) tests.AssertNoError(t, err) tests.AssertEqual(t, 456, intStr.ToInt()) tests.AssertEqual(t, "456", intStr.ToString()) - // Test empty string err = json.Unmarshal([]byte(`""`), &intStr) tests.AssertNoError(t, err) tests.AssertEqual(t, 0, intStr.ToInt()) @@ -303,28 +213,23 @@ func TestConfigService_IntString_Conversion(t *testing.T) { } func TestConfigService_IntBool_Conversion(t *testing.T) { - // Test IntBool unmarshaling from int var intBool model.IntBool - // Test int input (1 = true) err := json.Unmarshal([]byte(`1`), &intBool) tests.AssertNoError(t, err) tests.AssertEqual(t, 1, intBool.ToInt()) tests.AssertEqual(t, true, intBool.ToBool()) - // Test int input (0 = false) err = json.Unmarshal([]byte(`0`), &intBool) tests.AssertNoError(t, err) tests.AssertEqual(t, 0, intBool.ToInt()) tests.AssertEqual(t, false, intBool.ToBool()) - // Test bool input (true) err = json.Unmarshal([]byte(`true`), &intBool) tests.AssertNoError(t, err) tests.AssertEqual(t, 1, intBool.ToInt()) tests.AssertEqual(t, true, intBool.ToBool()) - // Test bool input (false) err = json.Unmarshal([]byte(`false`), &intBool) tests.AssertNoError(t, err) tests.AssertEqual(t, 0, intBool.ToInt()) @@ -332,25 +237,20 @@ func TestConfigService_IntBool_Conversion(t *testing.T) { } func TestConfigService_Caching_Configuration(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files (already UTF-16 encoded) err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // First call - should load from disk config1, err := configService.GetConfiguration(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, config1) - // Modify the file on disk with UTF-16 encoding configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "configuration.json") modifiedContent := `{"udpPort": "5555", "tcpPort": "5556"}` utf16Modified, err := service.EncodeUTF16LEBOM([]byte(modifiedContent)) @@ -359,36 +259,29 @@ func TestConfigService_Caching_Configuration(t *testing.T) { err = os.WriteFile(configPath, utf16Modified, 0644) tests.AssertNoError(t, err) - // Second call - should return cached result (not the modified file) config2, err := configService.GetConfiguration(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, config2) - // Should still have the original cached values tests.AssertEqual(t, model.IntString(9231), config2.UdpPort) tests.AssertEqual(t, model.IntString(9232), config2.TcpPort) } func TestConfigService_Caching_EventConfig(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files (already UTF-16 encoded) err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // First call - should load from disk event1, err := configService.GetEventConfig(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, event1) - // Modify the file on disk with UTF-16 encoding configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "event.json") modifiedContent := `{"track": "monza", "preRaceWaitingTimeSeconds": "60"}` utf16Modified, err := service.EncodeUTF16LEBOM([]byte(modifiedContent)) @@ -397,12 +290,10 @@ func TestConfigService_Caching_EventConfig(t *testing.T) { err = os.WriteFile(configPath, utf16Modified, 0644) tests.AssertNoError(t, err) - // Second call - should return cached result (not the modified file) event2, err := configService.GetEventConfig(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, event2) - // Should still have the original cached values tests.AssertEqual(t, "spa", event2.Track) tests.AssertEqual(t, model.IntString(80), event2.PreRaceWaitingTimeSeconds) } diff --git a/tests/unit/service/state_history_service_test.go b/tests/unit/service/state_history_service_test.go index 9a9fd7b..9cbfcfa 100644 --- a/tests/unit/service/state_history_service_test.go +++ b/tests/unit/service/state_history_service_test.go @@ -15,11 +15,9 @@ import ( ) func TestStateHistoryService_GetAll_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -27,22 +25,17 @@ func TestStateHistoryService_GetAll_Success(t *testing.T) { } } - // Use real repository like other service tests repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Insert test data directly into DB testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetAll filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) result, err := stateHistoryService.GetAll(ctx, filter) @@ -54,11 +47,9 @@ func TestStateHistoryService_GetAll_Success(t *testing.T) { } func TestStateHistoryService_GetAll_WithFilter(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -69,7 +60,6 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Insert test data with different sessions testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 10, uuid.New()) @@ -79,12 +69,10 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) { err = repo.Insert(helper.CreateContext(), &raceHistory) tests.AssertNoError(t, err) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetAll with session filter filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace) result, err := stateHistoryService.GetAll(ctx, filter) @@ -96,11 +84,9 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) { } func TestStateHistoryService_GetAll_NoData(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -111,12 +97,10 @@ func TestStateHistoryService_GetAll_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetAll with no data filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) result, err := stateHistoryService.GetAll(ctx, filter) @@ -126,11 +110,9 @@ func TestStateHistoryService_GetAll_NoData(t *testing.T) { } func TestStateHistoryService_Insert_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -141,20 +123,16 @@ func TestStateHistoryService_Insert_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test Insert err := stateHistoryService.Insert(ctx, &history) tests.AssertNoError(t, err) - // Verify data was inserted filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) result, err := stateHistoryService.GetAll(ctx, filter) tests.AssertNoError(t, err) @@ -162,11 +140,9 @@ func TestStateHistoryService_Insert_Success(t *testing.T) { } func TestStateHistoryService_GetLastSessionID_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -177,30 +153,25 @@ func TestStateHistoryService_GetLastSessionID_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Insert test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) sessionID := uuid.New() history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetLastSessionID lastSessionID, err := stateHistoryService.GetLastSessionID(ctx, helper.TestData.ServerID) tests.AssertNoError(t, err) tests.AssertEqual(t, sessionID, lastSessionID) } func TestStateHistoryService_GetLastSessionID_NoData(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -211,26 +182,21 @@ func TestStateHistoryService_GetLastSessionID_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetLastSessionID with no data lastSessionID, err := stateHistoryService.GetLastSessionID(ctx, helper.TestData.ServerID) tests.AssertNoError(t, err) tests.AssertEqual(t, uuid.Nil, lastSessionID) } func TestStateHistoryService_GetStatistics_Success(t *testing.T) { - // This test might fail due to database setup issues t.Skip("Skipping test as it's dependent on database migration") - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -241,10 +207,8 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Insert test data with varying player counts _ = testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Create entries with different sessions and player counts sessionID1 := uuid.New() sessionID2 := uuid.New() @@ -291,12 +255,10 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetStatistics filter := &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ ServerID: helper.TestData.ServerID.String(), @@ -311,17 +273,14 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) { tests.AssertNoError(t, err) tests.AssertNotNil(t, stats) - // Verify statistics - tests.AssertEqual(t, 15, stats.PeakPlayers) // Maximum player count - tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions + tests.AssertEqual(t, 15, stats.PeakPlayers) + tests.AssertEqual(t, 2, stats.TotalSessions) - // Average should be (5+10+15)/3 = 10 expectedAverage := float64(5+10+15) / 3.0 if stats.AveragePlayers != expectedAverage { t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers) } - // Verify other statistics components exist tests.AssertNotNil(t, stats.PlayerCountOverTime) tests.AssertNotNil(t, stats.SessionTypes) tests.AssertNotNil(t, stats.DailyActivity) @@ -329,14 +288,11 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) { } func TestStateHistoryService_GetStatistics_NoData(t *testing.T) { - // This test might fail due to database setup issues t.Skip("Skipping test as it's dependent on database migration") - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -347,19 +303,16 @@ func TestStateHistoryService_GetStatistics_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetStatistics with no data filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) stats, err := stateHistoryService.GetStatistics(ctx, filter) tests.AssertNoError(t, err) tests.AssertNotNil(t, stats) - // Verify empty statistics tests.AssertEqual(t, 0, stats.PeakPlayers) tests.AssertEqual(t, 0.0, stats.AveragePlayers) tests.AssertEqual(t, 0, stats.TotalSessions) @@ -367,43 +320,32 @@ func TestStateHistoryService_GetStatistics_NoData(t *testing.T) { } func TestStateHistoryService_LogParsingWorkflow(t *testing.T) { - // Skip this test as it's unreliable and not critical t.Skip("Skipping log parsing test as it's not critical to the service functionality") - // This test simulates the actual log parsing workflow - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Insert test server err := helper.InsertTestServer() tests.AssertNoError(t, err) server := helper.TestData.Server - // Track state changes var stateChanges []*model.ServerState onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) { - // Use pointer to avoid copying mutex stateChanges = append(stateChanges, state) } - // Create AccServerInstance (this is what the real server service does) instance := tracking.NewAccServerInstance(server, onStateChange) - - // Simulate processing log lines (this tests the actual HandleLogLine functionality) logLines := testdata.SampleLogLines for _, line := range logLines { instance.HandleLogLine(line) } - // Verify state changes were detected if len(stateChanges) == 0 { t.Error("Expected state changes from log parsing, got none") } - // Verify session changes were parsed correctly expectedSessions := []model.TrackSession{model.SessionPractice, model.SessionQualify, model.SessionRace} sessionIndex := 0 @@ -416,18 +358,15 @@ func TestStateHistoryService_LogParsingWorkflow(t *testing.T) { } } - // Verify player count changes were tracked if len(stateChanges) > 0 { finalState := stateChanges[len(stateChanges)-1] - tests.AssertEqual(t, 0, finalState.PlayerCount) // Should end with 0 players + tests.AssertEqual(t, 0, finalState.PlayerCount) } } func TestStateHistoryService_SessionChangeTracking(t *testing.T) { - // Skip this test as it's unreliable t.Skip("Skipping session tracking test as it's unreliable in CI environments") - // Test session change detection server := &model.Server{ ID: uuid.New(), Name: "Test Server", @@ -437,7 +376,6 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) { onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) { for _, change := range changes { if change == tracking.Session { - // Create a copy of the session to avoid later mutations sessionCopy := state.Session sessionChanges = append(sessionChanges, sessionCopy) } @@ -446,22 +384,17 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) { instance := tracking.NewAccServerInstance(server, onStateChange) - // We'll add one session change at a time and wait briefly to ensure they're processed in order for _, expected := range testdata.ExpectedSessionChanges { line := string("[2024-01-15 14:30:25.123] Session changed: " + expected.From + " -> " + expected.To) instance.HandleLogLine(line) - // Small pause to ensure log processing completes time.Sleep(10 * time.Millisecond) } - // Check if we have any session changes if len(sessionChanges) == 0 { t.Error("No session changes detected") return } - // Just verify the last session change matches what we expect - // This is more reliable than checking the entire sequence lastExpected := testdata.ExpectedSessionChanges[len(testdata.ExpectedSessionChanges)-1].To lastActual := sessionChanges[len(sessionChanges)-1] if lastActual != lastExpected { @@ -470,10 +403,8 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) { } func TestStateHistoryService_PlayerCountTracking(t *testing.T) { - // Skip this test as it's unreliable t.Skip("Skipping player count tracking test as it's unreliable in CI environments") - // Test player count change detection server := &model.Server{ ID: uuid.New(), Name: "Test Server", @@ -490,7 +421,6 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) { instance := tracking.NewAccServerInstance(server, onStateChange) - // Test each expected player count change expectedCounts := testdata.ExpectedPlayerCounts logLines := []string{ "[2024-01-15 14:30:30.456] 1 client(s) online", @@ -499,7 +429,7 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) { "[2024-01-15 14:35:05.789] 8 client(s) online", "[2024-01-15 14:40:05.456] 12 client(s) online", "[2024-01-15 14:45:00.789] 15 client(s) online", - "[2024-01-15 14:50:00.789] Removing dead connection", // Should decrease by 1 + "[2024-01-15 14:50:00.789] Removing dead connection", "[2024-01-15 15:00:00.789] 0 client(s) online", } @@ -507,7 +437,6 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) { instance.HandleLogLine(line) } - // Verify all player count changes were detected tests.AssertEqual(t, len(expectedCounts), len(playerCounts)) for i, expected := range expectedCounts { if i < len(playerCounts) { @@ -517,10 +446,8 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) { } func TestStateHistoryService_EdgeCases(t *testing.T) { - // Skip this test as it's unreliable t.Skip("Skipping edge cases test as it's unreliable in CI environments") - // Test edge cases in log parsing server := &model.Server{ ID: uuid.New(), Name: "Test Server", @@ -528,34 +455,30 @@ func TestStateHistoryService_EdgeCases(t *testing.T) { var stateChanges []*model.ServerState onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) { - // Create a copy of the state to avoid later mutations affecting our saved state stateCopy := *state stateChanges = append(stateChanges, &stateCopy) } instance := tracking.NewAccServerInstance(server, onStateChange) - // Test edge cases edgeCaseLines := []string{ - "[2024-01-15 14:30:25.123] Some unrelated log line", // Should be ignored - "[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE", // Valid session change - "[2024-01-15 14:30:30.456] 0 client(s) online", // Zero players - "[2024-01-15 14:30:35.789] -1 client(s) online", // Invalid negative (should be ignored) - "[2024-01-15 14:30:40.789] 30 client(s) online", // High but valid player count - "[2024-01-15 14:30:45.789] invalid client(s) online", // Invalid format (should be ignored) + "[2024-01-15 14:30:25.123] Some unrelated log line", + "[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE", + "[2024-01-15 14:30:30.456] 0 client(s) online", + "[2024-01-15 14:30:35.789] -1 client(s) online", + "[2024-01-15 14:30:40.789] 30 client(s) online", + "[2024-01-15 14:30:45.789] invalid client(s) online", } for _, line := range edgeCaseLines { instance.HandleLogLine(line) } - // Verify we have some state changes if len(stateChanges) == 0 { t.Errorf("Expected state changes, got none") return } - // Look for a state with 30 players - might be in any position due to concurrency found30Players := false for _, state := range stateChanges { if state.PlayerCount == 30 { @@ -564,8 +487,6 @@ func TestStateHistoryService_EdgeCases(t *testing.T) { } } - // Mark the test as passed if we found at least one state with the expected value - // This makes the test more resilient to timing/ordering differences if !found30Players { t.Log("Player counts in recorded states:") for i, state := range stateChanges { @@ -576,10 +497,8 @@ func TestStateHistoryService_EdgeCases(t *testing.T) { } func TestStateHistoryService_SessionStartTracking(t *testing.T) { - // Skip this test as it's unreliable t.Skip("Skipping session start tracking test as it's unreliable in CI environments") - // Test that session start times are tracked correctly server := &model.Server{ ID: uuid.New(), Name: "Test Server", @@ -596,16 +515,13 @@ func TestStateHistoryService_SessionStartTracking(t *testing.T) { instance := tracking.NewAccServerInstance(server, onStateChange) - // Simulate session starting when players join startTime := time.Now() - instance.HandleLogLine("[2024-01-15 14:30:30.456] 1 client(s) online") // First player joins + instance.HandleLogLine("[2024-01-15 14:30:30.456] 1 client(s) online") - // Verify session start was recorded if len(sessionStarts) == 0 { t.Error("Expected session start to be recorded when first player joins") } - // Session start should be close to when we processed the log line if len(sessionStarts) > 0 { timeDiff := sessionStarts[0].Sub(startTime) if timeDiff > time.Second || timeDiff < -time.Second {