code cleanup

This commit is contained in:
Fran Jurmanović
2025-09-18 13:33:51 +02:00
parent 901dbe697e
commit 5e7c96697a
83 changed files with 2832 additions and 2186 deletions

View File

@@ -0,0 +1,5 @@
{
"permissions": {
"defaultMode": "acceptEdits"
}
}

121
CLAUDE.md Normal file
View File

@@ -0,0 +1,121 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Development Commands
### Building and Running
```bash
# Build the main application
go build -o api.exe cmd/api/main.go
# Build with hot reload (requires air)
go install github.com/cosmtrek/air@latest
air
# Run the built binary
./api.exe
# Build migration utility
go build -o acc-server-migration.exe cmd/migrate/main.go
```
### Testing
```bash
# Run all tests
go test ./...
# Run tests with verbose output
go test -v ./...
# Run specific test package
go test ./tests/unit/service/
go test ./tests/unit/controller/
go test ./tests/unit/repository/
```
### Documentation
```bash
# Generate Swagger documentation (if swag is installed)
swag init -g cmd/api/main.go
```
### Setup and Configuration
```bash
# Generate security configuration (Windows PowerShell)
.\scripts\generate-secrets.ps1
# Deploy (requires configuration)
.\scripts\deploy.ps1
```
## Architecture Overview
This is a Go-based web application for managing Assetto Corsa Competizione (ACC) dedicated servers on Windows. The architecture follows a layered approach:
### Core Layers
- **cmd/**: Application entry points (main.go for API server, migrate/main.go for migrations)
- **local/**: Core application code organized by architectural layer
- **api/**: HTTP route definitions and API setup
- **controller/**: HTTP request handlers (config, membership, server, service_control, steam_2fa, system)
- **service/**: Business logic layer (server management, Steam integration, Windows services)
- **repository/**: Data access layer with GORM ORM
- **model/**: Data models and structures
- **middleware/**: HTTP middleware (auth, security, logging)
- **utl/**: Utilities organized by function (cache, command execution, JWT, logging, etc.)
### Key Components
#### Dependency Injection
Uses `go.uber.org/dig` for dependency injection. Main dependencies are set up in `cmd/api/main.go`.
#### Database
- SQLite with GORM ORM
- Database migrations in `local/migrations/`
- Models support UUID primary keys
#### Authentication & Security
- JWT-based authentication with two token types (regular and "open")
- Comprehensive security middleware stack including rate limiting, input sanitization, CORS
- Encrypted credential storage for Steam integration
#### Server Management
- Windows service integration via NSSM
- Steam integration for server installation/updates via SteamCMD
- Interactive command execution for Steam 2FA
- Firewall management
- Configuration file generation and management
#### Logging
- Custom logging system with multiple levels (debug, info, warn, error)
- Request logging middleware
- Structured logging with categories
### Testing Structure
- Unit tests in `tests/unit/` organized by layer (controller, service, repository)
- Test helpers and mocks in `tests/` directory
- Uses standard Go testing with mocks for external dependencies
### External Dependencies
- **Fiber v2**: Web framework
- **GORM**: ORM for database operations
- **SteamCMD**: External tool for Steam server management (configured via STEAMCMD_PATH env var)
- **NSSM**: Windows service management (configured via NSSM_PATH env var)
### Configuration
- Environment variables for external tool paths and configuration
- JWT secrets generated via setup scripts
- CORS configuration with configurable allowed origins
- Default port 3000 (configurable via PORT env var)
## Important Notes
### Windows-Specific Features
This application is designed specifically for Windows and includes:
- Windows service management integration
- PowerShell script execution
- Windows-specific path handling
- Firewall rule management
### Steam Integration
The Steam 2FA implementation (`local/controller/steam_2fa.go`, `local/model/steam_2fa.go`) provides interactive Steam authentication for automated server management.

1986
frontend.md Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -29,7 +29,6 @@ func NewServerController(ss *service.ServerService, routeGroups *common.RouteGro
serverRoutes.Get("/", auth.HasPermission(model.ServerView), ac.GetAll)
serverRoutes.Get("/:id", auth.HasPermission(model.ServerView), ac.GetById)
serverRoutes.Post("/", auth.HasPermission(model.ServerCreate), ac.CreateServer)
serverRoutes.Put("/:id", auth.HasPermission(model.ServerUpdate), ac.UpdateServer)
serverRoutes.Delete("/:id", auth.HasPermission(model.ServerDelete), ac.DeleteServer)
apiServerRoutes := routeGroups.Api.Group("/server")
@@ -152,41 +151,6 @@ func (ac *ServerController) CreateServer(c *fiber.Ctx) error {
return c.JSON(server)
}
// UpdateServer updates an existing server
// @Summary Update an ACC server
// @Description Update configuration for an existing ACC server
// @Tags Server
// @Accept json
// @Produce json
// @Param id path string true "Server ID (UUID format)"
// @Param server body model.Server true "Updated server configuration"
// @Success 200 {object} object "Updated server details"
// @Failure 400 {object} error_handler.ErrorResponse "Invalid server data or ID"
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
// @Failure 404 {object} error_handler.ErrorResponse "Server not found"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /server/{id} [put]
func (ac *ServerController) UpdateServer(c *fiber.Ctx) error {
serverIDStr := c.Params("id")
serverID, err := uuid.Parse(serverIDStr)
if err != nil {
return ac.errorHandler.HandleUUIDError(c, "server ID")
}
server := new(model.Server)
if err := c.BodyParser(server); err != nil {
return ac.errorHandler.HandleParsingError(c, err)
}
server.ID = serverID
if err := ac.service.UpdateServer(c, server); err != nil {
return ac.errorHandler.HandleServiceError(c, err)
}
return c.JSON(server)
}
// DeleteServer deletes an existing server
// @Summary Delete an ACC server
// @Description Delete an existing ACC server

View File

@@ -17,7 +17,6 @@ type WebSocketController struct {
jwtHandler *jwt.OpenJWTHandler
}
// NewWebSocketController initializes WebSocketController
func NewWebSocketController(
wsService *service.WebSocketService,
jwtHandler *jwt.OpenJWTHandler,
@@ -29,7 +28,6 @@ func NewWebSocketController(
jwtHandler: jwtHandler,
}
// WebSocket routes
wsRoutes := routeGroups.WebSocket
wsRoutes.Use("/", wsc.upgradeWebSocket)
wsRoutes.Get("/", websocket.New(wsc.handleWebSocket))
@@ -37,11 +35,8 @@ func NewWebSocketController(
return wsc
}
// upgradeWebSocket middleware to upgrade HTTP to WebSocket and validate authentication
func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error {
// Check if it's a WebSocket upgrade request
if websocket.IsWebSocketUpgrade(c) {
// Validate JWT token from query parameter or header
token := c.Query("token")
if token == "" {
token = c.Get("Authorization")
@@ -54,21 +49,18 @@ func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnauthorized, "Missing authentication token")
}
// Validate the token
claims, err := wsc.jwtHandler.ValidateToken(token)
if err != nil {
return fiber.NewError(fiber.StatusUnauthorized, "Invalid authentication token")
}
// Parse UserID string to UUID
userID, err := uuid.Parse(claims.UserID)
if err != nil {
return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID in token")
}
// Store user info in context for use in WebSocket handler
c.Locals("userID", userID)
c.Locals("username", claims.UserID) // Use UserID as username for now
c.Locals("username", claims.UserID)
return c.Next()
}
@@ -76,12 +68,9 @@ func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUpgradeRequired, "WebSocket upgrade required")
}
// handleWebSocket handles WebSocket connections
func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
// Generate a unique connection ID
connID := uuid.New().String()
// Get user info from locals (set by middleware)
userID, ok := c.Locals("userID").(uuid.UUID)
if !ok {
logging.Error("Failed to get user ID from WebSocket connection")
@@ -92,16 +81,13 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
username, _ := c.Locals("username").(string)
logging.Info("WebSocket connection established for user: %s (ID: %s)", username, userID.String())
// Add the connection to the service
wsc.webSocketService.AddConnection(connID, c, &userID)
// Handle connection cleanup
defer func() {
wsc.webSocketService.RemoveConnection(connID)
logging.Info("WebSocket connection closed for user: %s", username)
}()
// Handle incoming messages from the client
for {
messageType, message, err := c.ReadMessage()
if err != nil {
@@ -111,14 +97,12 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
break
}
// Handle different message types
switch messageType {
case websocket.TextMessage:
wsc.handleTextMessage(connID, userID, message)
case websocket.BinaryMessage:
logging.Debug("Received binary message from user %s (not supported)", username)
case websocket.PingMessage:
// Respond with pong
if err := c.WriteMessage(websocket.PongMessage, nil); err != nil {
logging.Error("Failed to send pong to user %s: %v", username, err)
break
@@ -127,21 +111,11 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
}
}
// handleTextMessage processes text messages from the client
func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUID, message []byte) {
logging.Debug("Received WebSocket message from user %s: %s", userID.String(), string(message))
// Parse the message to handle different types of client requests
// For now, we'll just log it. In the future, you might want to handle:
// - Subscription to specific server creation processes
// - Client heartbeat/keepalive
// - Request for status updates
// Example: If the message contains a server ID, associate this connection with that server
// This is a simple implementation - you might want to use proper JSON parsing
messageStr := string(message)
if len(messageStr) > 10 && messageStr[:9] == "server_id" {
// Extract server ID from message like "server_id:uuid"
if serverIDStr := messageStr[10:]; len(serverIDStr) > 0 {
if serverID, err := uuid.Parse(serverIDStr); err == nil {
wsc.webSocketService.SetServerID(connID, serverID)
@@ -151,18 +125,14 @@ func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUI
}
}
// GetWebSocketUpgrade returns the WebSocket upgrade handler for use in other controllers
func (wsc *WebSocketController) GetWebSocketUpgrade() fiber.Handler {
return wsc.upgradeWebSocket
}
// GetWebSocketHandler returns the WebSocket connection handler for use in other controllers
func (wsc *WebSocketController) GetWebSocketHandler() func(*websocket.Conn) {
return wsc.handleWebSocket
}
// BroadcastServerCreationProgress is a helper method for other services to broadcast progress
func (wsc *WebSocketController) BroadcastServerCreationProgress(serverID uuid.UUID, step string, status string, message string) {
// This can be used by the ServerService during server creation
logging.Info("Broadcasting server creation progress: %s - %s: %s", serverID.String(), step, status)
}

View File

@@ -10,12 +10,10 @@ import (
"github.com/google/uuid"
)
// AccessKeyMiddleware provides authentication and permission middleware.
type AccessKeyMiddleware struct {
userInfo CachedUserInfo
}
// NewAccessKeyMiddleware creates a new AccessKeyMiddleware.
func NewAccessKeyMiddleware() *AccessKeyMiddleware {
auth := &AccessKeyMiddleware{
userInfo: CachedUserInfo{UserID: uuid.New().String(), Username: "access_key", RoleName: "Admin", Permissions: map[string]bool{
@@ -25,9 +23,7 @@ func NewAccessKeyMiddleware() *AccessKeyMiddleware {
return auth
}
// Authenticate is a middleware for JWT authentication with enhanced security.
func (m *AccessKeyMiddleware) Authenticate(ctx *fiber.Ctx) error {
// Log authentication attempt
ip := ctx.IP()
userAgent := ctx.Get("User-Agent")

View File

@@ -16,7 +16,6 @@ import (
"github.com/google/uuid"
)
// CachedUserInfo holds cached user authentication and permission data
type CachedUserInfo struct {
UserID string
Username string
@@ -25,7 +24,6 @@ type CachedUserInfo struct {
CachedAt time.Time
}
// AuthMiddleware provides authentication and permission middleware.
type AuthMiddleware struct {
membershipService *service.MembershipService
cache *cache.InMemoryCache
@@ -34,7 +32,6 @@ type AuthMiddleware struct {
openJWTHandler *jwt.OpenJWTHandler
}
// NewAuthMiddleware creates a new AuthMiddleware.
func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache, jwtHandler *jwt.JWTHandler, openJWTHandler *jwt.OpenJWTHandler) *AuthMiddleware {
auth := &AuthMiddleware{
membershipService: ms,
@@ -44,24 +41,20 @@ func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache
openJWTHandler: openJWTHandler,
}
// Set up bidirectional relationship for cache invalidation
ms.SetCacheInvalidator(auth)
return auth
}
// Authenticate is a middleware for JWT authentication with enhanced security.
func (m *AuthMiddleware) AuthenticateOpen(ctx *fiber.Ctx) error {
return m.AuthenticateWithHandler(m.openJWTHandler.JWTHandler, true, ctx)
}
// Authenticate is a middleware for JWT authentication with enhanced security.
func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
return m.AuthenticateWithHandler(m.jwtHandler, false, ctx)
}
func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isOpenToken bool, ctx *fiber.Ctx) error {
// Log authentication attempt
ip := ctx.IP()
userAgent := ctx.Get("User-Agent")
@@ -89,7 +82,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO
})
}
// Validate token length to prevent potential attacks
token := parts[1]
if len(token) < 10 || len(token) > 2048 {
logging.Error("Authentication failed: invalid token length from IP %s", ip)
@@ -113,7 +105,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO
})
}
// Additional security: validate user ID format
if claims.UserID == "" || len(claims.UserID) < 10 {
logging.Error("Authentication failed: invalid user ID in token from IP %s", ip)
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
@@ -127,7 +118,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO
ctx.Locals("userInfo", userInfo)
ctx.Locals("authTime", time.Now())
} else {
// Preload and cache user info to avoid database queries on permission checks
userInfo, err := m.getCachedUserInfo(ctx.UserContext(), claims.UserID)
if err != nil {
logging.Error("Authentication failed: unable to load user info for %s from IP %s: %v", claims.UserID, ip, err)
@@ -145,7 +135,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO
return ctx.Next()
}
// HasPermission is a middleware for checking user permissions with enhanced logging.
func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler {
return func(ctx *fiber.Ctx) error {
userID, ok := ctx.Locals("userID").(string)
@@ -160,7 +149,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
return ctx.Next()
}
// Validate permission parameter
if requiredPermission == "" {
logging.Error("Permission check failed: empty permission requirement")
return ctx.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
@@ -168,7 +156,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
})
}
// Use cached user info from authentication step - no database queries needed
userInfo, ok := ctx.Locals("userInfo").(*CachedUserInfo)
if !ok {
logging.Error("Permission check failed: no cached user info in context from IP %s", ctx.IP())
@@ -177,7 +164,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
})
}
// Check if user has permission using cached data
has := m.hasPermissionFromCache(userInfo, requiredPermission)
if !has {
@@ -192,16 +178,13 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
}
}
// AuthRateLimit applies rate limiting specifically for authentication endpoints
func (m *AuthMiddleware) AuthRateLimit() fiber.Handler {
return m.securityMW.AuthRateLimit()
}
// RequireHTTPS redirects HTTP requests to HTTPS in production
func (m *AuthMiddleware) RequireHTTPS() fiber.Handler {
return func(ctx *fiber.Ctx) error {
if ctx.Protocol() != "https" && ctx.Get("X-Forwarded-Proto") != "https" {
// Allow HTTP in development/testing
if ctx.Hostname() != "localhost" && ctx.Hostname() != "127.0.0.1" {
httpsURL := "https://" + ctx.Hostname() + ctx.OriginalURL()
return ctx.Redirect(httpsURL, fiber.StatusMovedPermanently)
@@ -211,11 +194,9 @@ func (m *AuthMiddleware) RequireHTTPS() fiber.Handler {
}
}
// getCachedUserInfo retrieves and caches complete user information including permissions
func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (*CachedUserInfo, error) {
cacheKey := fmt.Sprintf("userinfo:%s", userID)
// Try cache first
if cached, found := m.cache.Get(cacheKey); found {
if userInfo, ok := cached.(*CachedUserInfo); ok {
logging.DebugWithContext("AUTH_CACHE", "User info for %s found in cache", userID)
@@ -223,13 +204,11 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (
}
}
// Cache miss - load from database
user, err := m.membershipService.GetUserWithPermissions(ctx, userID)
if err != nil {
return nil, err
}
// Build permission map for fast lookups
permissions := make(map[string]bool)
for _, p := range user.Role.Permissions {
permissions[p.Name] = true
@@ -243,34 +222,26 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (
CachedAt: time.Now(),
}
// Cache for 15 minutes
m.cache.Set(cacheKey, userInfo, 15*time.Minute)
logging.DebugWithContext("AUTH_CACHE", "User info for %s cached with %d permissions", userID, len(permissions))
return userInfo, nil
}
// hasPermissionFromCache checks permissions using cached user info (no database queries)
func (m *AuthMiddleware) hasPermissionFromCache(userInfo *CachedUserInfo, permission string) bool {
// Super Admin and Admin have all permissions
if userInfo.RoleName == "Super Admin" || userInfo.RoleName == "Admin" {
return true
}
// Check specific permission in cached map
return userInfo.Permissions[permission]
}
// InvalidateUserPermissions removes cached user info for a user
func (m *AuthMiddleware) InvalidateUserPermissions(userID string) {
cacheKey := fmt.Sprintf("userinfo:%s", userID)
m.cache.Delete(cacheKey)
logging.InfoWithContext("AUTH_CACHE", "User info cache invalidated for user %s", userID)
}
// InvalidateAllUserPermissions clears all cached user info (useful for role/permission changes)
func (m *AuthMiddleware) InvalidateAllUserPermissions() {
// This would need to be implemented based on your cache interface
// For now, just log that invalidation was requested
logging.InfoWithContext("AUTH_CACHE", "All user info caches invalidation requested")
}

View File

@@ -7,25 +7,20 @@ import (
"github.com/gofiber/fiber/v2"
)
// RequestLoggingMiddleware logs HTTP requests and responses
type RequestLoggingMiddleware struct {
infoLogger *logging.InfoLogger
}
// NewRequestLoggingMiddleware creates a new request logging middleware
func NewRequestLoggingMiddleware() *RequestLoggingMiddleware {
return &RequestLoggingMiddleware{
infoLogger: logging.GetInfoLogger(),
}
}
// Handler returns the middleware handler function
func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
return func(c *fiber.Ctx) error {
// Record start time
start := time.Now()
// Log incoming request
userAgent := c.Get("User-Agent")
if userAgent == "" {
userAgent = "Unknown"
@@ -33,17 +28,13 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
rlm.infoLogger.LogRequest(c.Method(), c.OriginalURL(), userAgent)
// Continue to next handler
err := c.Next()
// Calculate duration
duration := time.Since(start)
// Log response
statusCode := c.Response().StatusCode()
rlm.infoLogger.LogResponse(c.Method(), c.OriginalURL(), statusCode, duration.String())
// Log error if present
if err != nil {
logging.ErrorWithContext("REQUEST_MIDDLEWARE", "Request failed: %v", err)
}
@@ -52,10 +43,8 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
}
}
// Global request logging middleware instance
var globalRequestLoggingMiddleware *RequestLoggingMiddleware
// GetRequestLoggingMiddleware returns the global request logging middleware
func GetRequestLoggingMiddleware() *RequestLoggingMiddleware {
if globalRequestLoggingMiddleware == nil {
globalRequestLoggingMiddleware = NewRequestLoggingMiddleware()
@@ -63,7 +52,6 @@ func GetRequestLoggingMiddleware() *RequestLoggingMiddleware {
return globalRequestLoggingMiddleware
}
// Handler returns the global request logging middleware handler
func Handler() fiber.Handler {
return GetRequestLoggingMiddleware().Handler()
}

View File

@@ -11,19 +11,16 @@ import (
"github.com/gofiber/fiber/v2"
)
// RateLimiter stores rate limiting information
type RateLimiter struct {
requests map[string][]time.Time
mutex sync.RWMutex
}
// NewRateLimiter creates a new rate limiter
func NewRateLimiter() *RateLimiter {
rl := &RateLimiter{
requests: make(map[string][]time.Time),
}
// Use graceful shutdown for cleanup goroutine
shutdownManager := graceful.GetManager()
shutdownManager.RunGoroutine(func(ctx context.Context) {
rl.cleanupWithContext(ctx)
@@ -32,7 +29,6 @@ func NewRateLimiter() *RateLimiter {
return rl
}
// cleanup removes old entries from the rate limiter
func (rl *RateLimiter) cleanupWithContext(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
@@ -62,44 +58,29 @@ func (rl *RateLimiter) cleanupWithContext(ctx context.Context) {
}
}
// SecurityMiddleware provides comprehensive security middleware
type SecurityMiddleware struct {
rateLimiter *RateLimiter
}
// NewSecurityMiddleware creates a new security middleware
func NewSecurityMiddleware() *SecurityMiddleware {
return &SecurityMiddleware{
rateLimiter: NewRateLimiter(),
}
}
// SecurityHeaders adds security headers to responses
func (sm *SecurityMiddleware) SecurityHeaders() fiber.Handler {
return func(c *fiber.Ctx) error {
// Prevent MIME type sniffing
c.Set("X-Content-Type-Options", "nosniff")
// Prevent clickjacking
c.Set("X-Frame-Options", "DENY")
// Enable XSS protection
c.Set("X-XSS-Protection", "1; mode=block")
// Prevent referrer leakage
c.Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Content Security Policy
c.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none'")
// Permissions Policy
c.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), interest-cohort=()")
return c.Next()
}
}
// RateLimit implements rate limiting for API endpoints
func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) fiber.Handler {
return func(c *fiber.Ctx) error {
ip := c.IP()
@@ -111,7 +92,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
now := time.Now()
requests := sm.rateLimiter.requests[key]
// Remove requests older than duration
filtered := make([]time.Time, 0, len(requests))
for _, t := range requests {
if now.Sub(t) < duration {
@@ -119,7 +99,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
}
}
// Check if limit is exceeded
if len(filtered) >= maxRequests {
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
"error": "Rate limit exceeded",
@@ -127,7 +106,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
})
}
// Add current request
filtered = append(filtered, now)
sm.rateLimiter.requests[key] = filtered
@@ -135,7 +113,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
}
}
// AuthRateLimit implements stricter rate limiting for authentication endpoints
func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
return func(c *fiber.Ctx) error {
ip := c.IP()
@@ -148,7 +125,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
now := time.Now()
requests := sm.rateLimiter.requests[key]
// Remove requests older than 15 minutes
filtered := make([]time.Time, 0, len(requests))
for _, t := range requests {
if now.Sub(t) < 15*time.Minute {
@@ -156,7 +132,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
}
}
// Check if limit is exceeded (5 requests per 15 minutes for auth)
if len(filtered) >= 5 {
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
"error": "Too many authentication attempts",
@@ -164,7 +139,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
})
}
// Add current request
filtered = append(filtered, now)
sm.rateLimiter.requests[key] = filtered
@@ -172,20 +146,16 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
}
}
// InputSanitization sanitizes user input to prevent XSS and injection attacks
func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
return func(c *fiber.Ctx) error {
// Sanitize query parameters
c.Request().URI().QueryArgs().VisitAll(func(key, value []byte) {
sanitized := sanitizeInput(string(value))
c.Request().URI().QueryArgs().Set(string(key), sanitized)
})
// Store original body for processing
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
body := c.Body()
if len(body) > 0 {
// Basic sanitization - remove potentially dangerous patterns
sanitized := sanitizeInput(string(body))
c.Request().SetBodyString(sanitized)
}
@@ -195,7 +165,6 @@ func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
}
}
// sanitizeInput removes potentially dangerous patterns from input
func sanitizeInput(input string) string {
dangerous := []string{
"<script",
@@ -236,7 +205,7 @@ func sanitizeInput(input string) string {
result := input
lowerInput := strings.ToLower(input)
for _, pattern := range dangerous {
if strings.Contains(lowerInput, pattern) {
return ""
@@ -254,7 +223,6 @@ func sanitizeInput(input string) string {
return result
}
// ValidateContentType ensures only expected content types are accepted
func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
@@ -265,7 +233,6 @@ func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.
})
}
// Check if content type is allowed
allowed := false
for _, allowedType := range allowedTypes {
if strings.Contains(contentType, allowedType) {
@@ -285,7 +252,6 @@ func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.
}
}
// ValidateUserAgent blocks requests with suspicious or missing user agents
func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
suspiciousAgents := []string{
"sqlmap",
@@ -296,21 +262,19 @@ func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
"dirb",
"dirbuster",
"wpscan",
"curl/7.0", // Very old curl versions
"wget/1.0", // Very old wget versions
"curl/7.0",
"wget/1.0",
}
return func(c *fiber.Ctx) error {
userAgent := strings.ToLower(c.Get("User-Agent"))
// Block empty user agents
if userAgent == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "User-Agent header is required",
})
}
// Block suspicious user agents
for _, suspicious := range suspiciousAgents {
if strings.Contains(userAgent, suspicious) {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
@@ -323,7 +287,6 @@ func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
}
}
// RequestSizeLimit limits the size of incoming requests
func (sm *SecurityMiddleware) RequestSizeLimit(maxSize int) fiber.Handler {
return func(c *fiber.Ctx) error {
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
@@ -340,19 +303,15 @@ func (sm *SecurityMiddleware) RequestSizeLimit(maxSize int) fiber.Handler {
}
}
// LogSecurityEvents logs security-related events
func (sm *SecurityMiddleware) LogSecurityEvents() fiber.Handler {
return func(c *fiber.Ctx) error {
start := time.Now()
// Process request
err := c.Next()
// Log suspicious activity
status := c.Response().StatusCode()
if status == 401 || status == 403 || status == 429 {
duration := time.Since(start)
// In a real implementation, you would send this to your logging system
fmt.Printf("[SECURITY] %s %s %s %d %v %s\n",
time.Now().Format(time.RFC3339),
c.IP(),
@@ -367,7 +326,6 @@ func (sm *SecurityMiddleware) LogSecurityEvents() fiber.Handler {
}
}
// TimeoutMiddleware adds request timeout
func (sm *SecurityMiddleware) TimeoutMiddleware(timeout time.Duration) fiber.Handler {
return func(c *fiber.Ctx) error {
ctx, cancel := context.WithTimeout(c.UserContext(), timeout)

View File

@@ -9,21 +9,17 @@ import (
"gorm.io/gorm"
)
// Migration001UpgradePasswordSecurity migrates existing user passwords from encrypted to hashed format
type Migration001UpgradePasswordSecurity struct {
DB *gorm.DB
}
// NewMigration001UpgradePasswordSecurity creates a new password security migration
func NewMigration001UpgradePasswordSecurity(db *gorm.DB) *Migration001UpgradePasswordSecurity {
return &Migration001UpgradePasswordSecurity{DB: db}
}
// Up executes the migration
func (m *Migration001UpgradePasswordSecurity) Up() error {
logging.Info("Starting password security upgrade migration...")
// Check if migration has already been applied
var migrationRecord MigrationRecord
err := m.DB.Where("migration_name = ?", "001_upgrade_password_security").First(&migrationRecord).Error
if err == nil {
@@ -31,12 +27,10 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
return nil
}
// Create migration tracking table if it doesn't exist
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
return fmt.Errorf("failed to create migration tracking table: %v", err)
}
// Start transaction
tx := m.DB.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to start transaction: %v", tx.Error)
@@ -47,16 +41,13 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
}
}()
// Add a backup column for old passwords (temporary)
if err := tx.Exec("ALTER TABLE users ADD COLUMN password_backup TEXT").Error; err != nil {
// Column might already exist, ignore if it's a duplicate column error
if !isDuplicateColumnError(err) {
tx.Rollback()
return fmt.Errorf("failed to add backup column: %v", err)
}
}
// Get all users with encrypted passwords
var users []UserForMigration
if err := tx.Find(&users).Error; err != nil {
tx.Rollback()
@@ -72,19 +63,15 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
if err := m.migrateUserPassword(tx, &user); err != nil {
logging.Error("Failed to migrate user %s (ID: %s): %v", user.Username, user.ID, err)
failedCount++
// Continue with other users rather than failing completely
continue
}
migratedCount++
}
// Remove backup column after successful migration
if err := tx.Exec("ALTER TABLE users DROP COLUMN password_backup").Error; err != nil {
logging.Error("Failed to remove backup column (non-critical): %v", err)
// Don't fail the migration for this
}
// Record successful migration
migrationRecord = MigrationRecord{
MigrationName: "001_upgrade_password_security",
AppliedAt: "datetime('now')",
@@ -97,7 +84,6 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
return fmt.Errorf("failed to record migration: %v", err)
}
// Commit transaction
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit migration: %v", err)
}
@@ -111,32 +97,24 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
return nil
}
// migrateUserPassword migrates a single user's password
func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, user *UserForMigration) error {
// Skip if password is already hashed (bcrypt hashes start with $2a$, $2b$, or $2y$)
if isAlreadyHashed(user.Password) {
logging.Debug("User %s already has hashed password, skipping", user.Username)
return nil
}
// Backup original password
if err := tx.Model(user).Update("password_backup", user.Password).Error; err != nil {
return fmt.Errorf("failed to backup password: %v", err)
}
// Try to decrypt the old password
var plainPassword string
// First, try to decrypt using the old encryption method
decrypted, err := decryptOldPassword(user.Password)
if err != nil {
// If decryption fails, the password might already be plain text or corrupted
logging.Error("Failed to decrypt password for user %s, treating as plain text: %v", user.Username, err)
// Use original password as-is (might be plain text from development)
plainPassword = user.Password
// Validate it's not obviously encrypted data
if len(plainPassword) > 100 || containsBinaryData(plainPassword) {
return fmt.Errorf("password appears to be corrupted encrypted data")
}
@@ -144,7 +122,6 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
plainPassword = decrypted
}
// Validate plain password
if plainPassword == "" {
return errors.New("decrypted password is empty")
}
@@ -153,13 +130,11 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
return errors.New("password too short after decryption")
}
// Hash the plain password using bcrypt
hashedPassword, err := password.HashPassword(plainPassword)
if err != nil {
return fmt.Errorf("failed to hash password: %v", err)
}
// Update with hashed password
if err := tx.Model(user).Update("password", hashedPassword).Error; err != nil {
return fmt.Errorf("failed to update password: %v", err)
}
@@ -168,19 +143,16 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
return nil
}
// UserForMigration represents a user record for migration purposes
type UserForMigration struct {
ID string `gorm:"column:id"`
Username string `gorm:"column:username"`
Password string `gorm:"column:password"`
}
// TableName specifies the table name for GORM
func (UserForMigration) TableName() string {
return "users"
}
// MigrationRecord tracks applied migrations
type MigrationRecord struct {
ID uint `gorm:"primaryKey"`
MigrationName string `gorm:"unique;not null"`
@@ -189,49 +161,38 @@ type MigrationRecord struct {
Notes string
}
// TableName specifies the table name for GORM
func (MigrationRecord) TableName() string {
return "migration_records"
}
// isAlreadyHashed checks if a password is already bcrypt hashed
func isAlreadyHashed(password string) bool {
return len(password) >= 60 && (password[:4] == "$2a$" || password[:4] == "$2b$" || password[:4] == "$2y$")
}
// containsBinaryData checks if a string contains binary data
func containsBinaryData(s string) bool {
for _, b := range []byte(s) {
if b < 32 && b != 9 && b != 10 && b != 13 { // Allow tab, newline, carriage return
if b < 32 && b != 9 && b != 10 && b != 13 {
return true
}
}
return false
}
// isDuplicateColumnError checks if an error is due to duplicate column
func isDuplicateColumnError(err error) bool {
errStr := err.Error()
return fmt.Sprintf("%v", errStr) == "duplicate column name: password_backup" ||
fmt.Sprintf("%v", errStr) == "SQLITE_ERROR: duplicate column name: password_backup"
}
// decryptOldPassword attempts to decrypt using the old encryption method
// This is a simplified version of the old DecryptPassword function
func decryptOldPassword(encryptedPassword string) (string, error) {
// This would use the old decryption logic
// For now, we'll return an error to force treating as plain text
// In a real scenario, you'd implement the old decryption here
return "", errors.New("old decryption not implemented - treating as plain text")
}
// Down reverses the migration (if needed)
func (m *Migration001UpgradePasswordSecurity) Down() error {
logging.Error("Password security migration rollback is not supported for security reasons")
return errors.New("password security migration rollback is not supported")
}
// RunMigration is a convenience function to run the migration
func RunPasswordSecurityMigration(db *gorm.DB) error {
migration := NewMigration001UpgradePasswordSecurity(db)
return migration.Up()

View File

@@ -9,21 +9,17 @@ import (
"gorm.io/gorm"
)
// Migration002MigrateToUUID migrates tables from integer IDs to UUIDs
type Migration002MigrateToUUID struct {
DB *gorm.DB
}
// NewMigration002MigrateToUUID creates a new UUID migration
func NewMigration002MigrateToUUID(db *gorm.DB) *Migration002MigrateToUUID {
return &Migration002MigrateToUUID{DB: db}
}
// Up executes the migration
func (m *Migration002MigrateToUUID) Up() error {
logging.Info("Checking UUID migration...")
// Check if migration is needed by looking at the servers table structure
if !m.needsMigration() {
logging.Info("UUID migration not needed - tables already use UUID primary keys")
return nil
@@ -31,7 +27,6 @@ func (m *Migration002MigrateToUUID) Up() error {
logging.Info("Starting UUID migration...")
// Check if migration has already been applied
var migrationRecord MigrationRecord
err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error
if err == nil {
@@ -39,12 +34,10 @@ func (m *Migration002MigrateToUUID) Up() error {
return nil
}
// Create migration tracking table if it doesn't exist
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
return fmt.Errorf("failed to create migration tracking table: %v", err)
}
// Execute the UUID migration using the existing migration function
logging.Info("Executing UUID migration...")
if err := runUUIDMigrationSQL(m.DB); err != nil {
return fmt.Errorf("failed to execute UUID migration: %v", err)
@@ -54,9 +47,7 @@ func (m *Migration002MigrateToUUID) Up() error {
return nil
}
// needsMigration checks if the UUID migration is needed by examining table structure
func (m *Migration002MigrateToUUID) needsMigration() bool {
// Check if servers table exists and has integer primary key
var result struct {
Type string `gorm:"column:type"`
}
@@ -67,29 +58,22 @@ func (m *Migration002MigrateToUUID) needsMigration() bool {
`).Scan(&result).Error
if err != nil || result.Type == "" {
// Table doesn't exist or no primary key found - assume no migration needed
return false
}
// If the primary key is INTEGER, we need migration
// If it's TEXT (UUID), migration already done
return result.Type == "INTEGER" || result.Type == "integer"
}
// Down reverses the migration (not implemented for safety)
func (m *Migration002MigrateToUUID) Down() error {
logging.Error("UUID migration rollback is not supported for data safety reasons")
return fmt.Errorf("UUID migration rollback is not supported")
}
// runUUIDMigrationSQL executes the UUID migration using the SQL file
func runUUIDMigrationSQL(db *gorm.DB) error {
// Disable foreign key constraints during migration
if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil {
return fmt.Errorf("failed to disable foreign keys: %v", err)
}
// Start transaction
tx := db.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to start transaction: %v", tx.Error)
@@ -101,25 +85,21 @@ func runUUIDMigrationSQL(db *gorm.DB) error {
}
}()
// Read the migration SQL from file
sqlPath := filepath.Join("scripts", "migrations", "002_migrate_servers_to_uuid.sql")
migrationSQL, err := ioutil.ReadFile(sqlPath)
if err != nil {
return fmt.Errorf("failed to read migration SQL file: %v", err)
}
// Execute the migration
if err := tx.Exec(string(migrationSQL)).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to execute migration: %v", err)
}
// Commit transaction
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit migration: %v", err)
}
// Re-enable foreign key constraints
if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil {
return fmt.Errorf("failed to re-enable foreign keys: %v", err)
}
@@ -127,7 +107,6 @@ func runUUIDMigrationSQL(db *gorm.DB) error {
return nil
}
// RunUUIDMigration is a convenience function to run the migration
func RunUUIDMigration(db *gorm.DB) error {
migration := NewMigration002MigrateToUUID(db)
return migration.Up()

View File

@@ -7,21 +7,17 @@ import (
"gorm.io/gorm"
)
// UpdateStateHistorySessions migrates tables from integer IDs to UUIDs
type UpdateStateHistorySessions struct {
DB *gorm.DB
}
// NewUpdateStateHistorySessions creates a new UUID migration
func NewUpdateStateHistorySessions(db *gorm.DB) *UpdateStateHistorySessions {
return &UpdateStateHistorySessions{DB: db}
}
// Up executes the migration
func (m *UpdateStateHistorySessions) Up() error {
logging.Info("Checking UUID migration...")
// Check if migration is needed by looking at the servers table structure
if !m.needsMigration() {
logging.Info("UUID migration not needed - tables already use UUID primary keys")
return nil
@@ -29,7 +25,6 @@ func (m *UpdateStateHistorySessions) Up() error {
logging.Info("Starting UUID migration...")
// Check if migration has already been applied
var migrationRecord MigrationRecord
err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error
if err == nil {
@@ -37,12 +32,10 @@ func (m *UpdateStateHistorySessions) Up() error {
return nil
}
// Create migration tracking table if it doesn't exist
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
return fmt.Errorf("failed to create migration tracking table: %v", err)
}
// Execute the UUID migration using the existing migration function
logging.Info("Executing UUID migration...")
if err := runUUIDMigrationSQL(m.DB); err != nil {
return fmt.Errorf("failed to execute UUID migration: %v", err)
@@ -52,9 +45,7 @@ func (m *UpdateStateHistorySessions) Up() error {
return nil
}
// needsMigration checks if the UUID migration is needed by examining table structure
func (m *UpdateStateHistorySessions) needsMigration() bool {
// Check if servers table exists and has integer primary key
var result struct {
Exists bool `gorm:"column:exists"`
}
@@ -65,26 +56,21 @@ func (m *UpdateStateHistorySessions) needsMigration() bool {
`).Scan(&result).Error
if err != nil || !result.Exists {
// Table doesn't exist or no primary key found - assume no migration needed
return false
}
return result.Exists
}
// Down reverses the migration (not implemented for safety)
func (m *UpdateStateHistorySessions) Down() error {
logging.Error("UUID migration rollback is not supported for data safety reasons")
return fmt.Errorf("UUID migration rollback is not supported")
}
// runUpdateStateHistorySessionsMigration executes the UUID migration using the SQL file
func runUpdateStateHistorySessionsMigration(db *gorm.DB) error {
// Disable foreign key constraints during migration
if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil {
return fmt.Errorf("failed to disable foreign keys: %v", err)
}
// Start transaction
tx := db.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to start transaction: %v", tx.Error)
@@ -98,18 +84,15 @@ func runUpdateStateHistorySessionsMigration(db *gorm.DB) error {
migrationSQL := "UPDATE state_history SET session = upper(substr(session, 1, 1));"
// Execute the migration
if err := tx.Exec(string(migrationSQL)).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to execute migration: %v", err)
}
// Commit transaction
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit migration: %v", err)
}
// Re-enable foreign key constraints
if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil {
return fmt.Errorf("failed to re-enable foreign keys: %v", err)
}
@@ -117,7 +100,6 @@ func runUpdateStateHistorySessionsMigration(db *gorm.DB) error {
return nil
}
// RunUpdateStateHistorySessionsMigration is a convenience function to run the migration
func RunUpdateStateHistorySessionsMigration(db *gorm.DB) error {
migration := NewUpdateStateHistorySessions(db)
return migration.Up()

View File

@@ -6,20 +6,17 @@ import (
"time"
)
// StatusCache represents a cached server status with expiration
type StatusCache struct {
Status ServiceStatus
UpdatedAt time.Time
}
// CacheConfig holds configuration for cache behavior
type CacheConfig struct {
ExpirationTime time.Duration // How long before a cache entry expires
ThrottleTime time.Duration // Minimum time between status checks
DefaultStatus ServiceStatus // Default status to return when throttled
ExpirationTime time.Duration
ThrottleTime time.Duration
DefaultStatus ServiceStatus
}
// ServerStatusCache manages cached server statuses
type ServerStatusCache struct {
sync.RWMutex
cache map[string]*StatusCache
@@ -27,7 +24,6 @@ type ServerStatusCache struct {
lastChecked map[string]time.Time
}
// NewServerStatusCache creates a new server status cache
func NewServerStatusCache(config CacheConfig) *ServerStatusCache {
return &ServerStatusCache{
cache: make(map[string]*StatusCache),
@@ -36,12 +32,10 @@ func NewServerStatusCache(config CacheConfig) *ServerStatusCache {
}
}
// GetStatus retrieves the cached status or indicates if a fresh check is needed
func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) {
c.RLock()
defer c.RUnlock()
// Check if we're being throttled
if lastCheck, exists := c.lastChecked[serviceName]; exists {
if time.Since(lastCheck) < c.config.ThrottleTime {
if cached, ok := c.cache[serviceName]; ok {
@@ -51,7 +45,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool)
}
}
// Check if we have a valid cached entry
if cached, ok := c.cache[serviceName]; ok {
if time.Since(cached.UpdatedAt) < c.config.ExpirationTime {
return cached.Status, false
@@ -61,7 +54,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool)
return StatusUnknown, true
}
// UpdateStatus updates the cache with a new status
func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatus) {
c.Lock()
defer c.Unlock()
@@ -73,7 +65,6 @@ func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatu
c.lastChecked[serviceName] = time.Now()
}
// InvalidateStatus removes a specific service from the cache
func (c *ServerStatusCache) InvalidateStatus(serviceName string) {
c.Lock()
defer c.Unlock()
@@ -82,7 +73,6 @@ func (c *ServerStatusCache) InvalidateStatus(serviceName string) {
delete(c.lastChecked, serviceName)
}
// Clear removes all entries from the cache
func (c *ServerStatusCache) Clear() {
c.Lock()
defer c.Unlock()
@@ -91,13 +81,11 @@ func (c *ServerStatusCache) Clear() {
c.lastChecked = make(map[string]time.Time)
}
// LookupCache provides a generic cache for lookup data
type LookupCache struct {
sync.RWMutex
data map[string]interface{}
}
// NewLookupCache creates a new lookup cache
func NewLookupCache() *LookupCache {
logging.Debug("Initializing new LookupCache")
return &LookupCache{
@@ -105,7 +93,6 @@ func NewLookupCache() *LookupCache {
}
}
// Get retrieves a cached value by key
func (c *LookupCache) Get(key string) (interface{}, bool) {
c.RLock()
defer c.RUnlock()
@@ -119,7 +106,6 @@ func (c *LookupCache) Get(key string) (interface{}, bool) {
return value, exists
}
// Set stores a value in the cache
func (c *LookupCache) Set(key string, value interface{}) {
c.Lock()
defer c.Unlock()
@@ -128,7 +114,6 @@ func (c *LookupCache) Set(key string, value interface{}) {
logging.Debug("Cache SET for key: %s", key)
}
// Clear removes all entries from the cache
func (c *LookupCache) Clear() {
c.Lock()
defer c.Unlock()
@@ -137,13 +122,11 @@ func (c *LookupCache) Clear() {
logging.Debug("Cache CLEARED")
}
// ConfigEntry represents a cached configuration entry with its update time
type ConfigEntry[T any] struct {
Data T
UpdatedAt time.Time
}
// getConfigFromCache is a generic helper function to retrieve cached configs
func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string, expirationTime time.Duration) (*T, bool) {
if entry, ok := cache[serverID]; ok {
if time.Since(entry.UpdatedAt) < expirationTime {
@@ -157,7 +140,6 @@ func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string
return nil, false
}
// updateConfigInCache is a generic helper function to update cached configs
func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID string, data T) {
cache[serverID] = &ConfigEntry[T]{
Data: data,
@@ -166,7 +148,6 @@ func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID strin
logging.Debug("Config cache SET for server ID: %s", serverID)
}
// ServerConfigCache manages cached server configurations
type ServerConfigCache struct {
sync.RWMutex
configuration map[string]*ConfigEntry[Configuration]
@@ -177,7 +158,6 @@ type ServerConfigCache struct {
config CacheConfig
}
// NewServerConfigCache creates a new server configuration cache
func NewServerConfigCache(config CacheConfig) *ServerConfigCache {
logging.Debug("Initializing new ServerConfigCache with expiration time: %v, throttle time: %v", config.ExpirationTime, config.ThrottleTime)
return &ServerConfigCache{
@@ -190,7 +170,6 @@ func NewServerConfigCache(config CacheConfig) *ServerConfigCache {
}
}
// GetConfiguration retrieves a cached configuration
func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, bool) {
c.RLock()
defer c.RUnlock()
@@ -198,7 +177,6 @@ func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, b
return getConfigFromCache(c.configuration, serverID, c.config.ExpirationTime)
}
// GetAssistRules retrieves cached assist rules
func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool) {
c.RLock()
defer c.RUnlock()
@@ -206,7 +184,6 @@ func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool)
return getConfigFromCache(c.assistRules, serverID, c.config.ExpirationTime)
}
// GetEvent retrieves cached event configuration
func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) {
c.RLock()
defer c.RUnlock()
@@ -214,7 +191,6 @@ func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) {
return getConfigFromCache(c.event, serverID, c.config.ExpirationTime)
}
// GetEventRules retrieves cached event rules
func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) {
c.RLock()
defer c.RUnlock()
@@ -222,7 +198,6 @@ func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) {
return getConfigFromCache(c.eventRules, serverID, c.config.ExpirationTime)
}
// GetSettings retrieves cached server settings
func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool) {
c.RLock()
defer c.RUnlock()
@@ -230,7 +205,6 @@ func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool)
return getConfigFromCache(c.settings, serverID, c.config.ExpirationTime)
}
// UpdateConfiguration updates the configuration cache
func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configuration) {
c.Lock()
defer c.Unlock()
@@ -238,7 +212,6 @@ func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configur
updateConfigInCache(c.configuration, serverID, config)
}
// UpdateAssistRules updates the assist rules cache
func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules) {
c.Lock()
defer c.Unlock()
@@ -246,7 +219,6 @@ func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules
updateConfigInCache(c.assistRules, serverID, rules)
}
// UpdateEvent updates the event configuration cache
func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) {
c.Lock()
defer c.Unlock()
@@ -254,7 +226,6 @@ func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) {
updateConfigInCache(c.event, serverID, event)
}
// UpdateEventRules updates the event rules cache
func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules) {
c.Lock()
defer c.Unlock()
@@ -262,7 +233,6 @@ func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules)
updateConfigInCache(c.eventRules, serverID, rules)
}
// UpdateSettings updates the server settings cache
func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSettings) {
c.Lock()
defer c.Unlock()
@@ -270,7 +240,6 @@ func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSetti
updateConfigInCache(c.settings, serverID, settings)
}
// InvalidateServerCache removes all cached configurations for a specific server
func (c *ServerConfigCache) InvalidateServerCache(serverID string) {
c.Lock()
defer c.Unlock()
@@ -283,7 +252,6 @@ func (c *ServerConfigCache) InvalidateServerCache(serverID string) {
delete(c.settings, serverID)
}
// Clear removes all entries from the cache
func (c *ServerConfigCache) Clear() {
c.Lock()
defer c.Unlock()

View File

@@ -13,17 +13,15 @@ import (
type IntString int
type IntBool int
// Config tracks configuration modifications
type Config struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"`
ConfigFile string `json:"configFile" gorm:"not null"` // e.g. "settings.json"
ConfigFile string `json:"configFile" gorm:"not null"`
OldConfig string `json:"oldConfig" gorm:"type:text"`
NewConfig string `json:"newConfig" gorm:"type:text"`
ChangedAt time.Time `json:"changedAt" gorm:"default:CURRENT_TIMESTAMP"`
}
// BeforeCreate is a GORM hook that runs before creating new config entries
func (c *Config) BeforeCreate(tx *gorm.DB) error {
if c.ID == uuid.Nil {
c.ID = uuid.New()
@@ -121,8 +119,6 @@ type Configuration struct {
ConfigVersion IntString `json:"configVersion"`
}
// Known configuration keys
func (i *IntBool) UnmarshalJSON(b []byte) error {
var str int
if err := json.Unmarshal(b, &str); err == nil && str <= 1 {

View File

@@ -7,7 +7,6 @@ import (
"gorm.io/gorm"
)
// BaseFilter contains common filter fields that can be embedded in other filters
type BaseFilter struct {
Page int `query:"page"`
PageSize int `query:"page_size"`
@@ -15,18 +14,15 @@ type BaseFilter struct {
SortDesc bool `query:"sort_desc"`
}
// DateRangeFilter adds date range filtering capabilities
type DateRangeFilter struct {
StartDate time.Time `query:"start_date" time_format:"2006-01-02T15:04:05Z07:00"`
EndDate time.Time `query:"end_date" time_format:"2006-01-02T15:04:05Z07:00"`
}
// ServerBasedFilter adds server ID filtering capability
type ServerBasedFilter struct {
ServerID string `param:"id"`
}
// ConfigFilter defines filtering options for Config queries
type ConfigFilter struct {
BaseFilter
ServerBasedFilter
@@ -34,13 +30,11 @@ type ConfigFilter struct {
ChangedAt time.Time `query:"changed_at" time_format:"2006-01-02T15:04:05Z07:00"`
}
// ApiFilter defines filtering options for Api queries
type ServiceControlFilter struct {
BaseFilter
ServiceControl string `query:"serviceControl"`
}
// MembershipFilter defines filtering options for User queries
type MembershipFilter struct {
BaseFilter
Username string `query:"username"`
@@ -48,36 +42,32 @@ type MembershipFilter struct {
RoleID string `query:"role_id"`
}
// Pagination returns the offset and limit for database queries
func (f *BaseFilter) Pagination() (offset, limit int) {
if f.Page < 1 {
f.Page = 1
}
if f.PageSize < 1 {
f.PageSize = 10 // Default page size
f.PageSize = 10
}
offset = (f.Page - 1) * f.PageSize
limit = f.PageSize
return
}
// GetSorting returns the sort field and direction for database queries
func (f *BaseFilter) GetSorting() (field string, desc bool) {
if f.SortBy == "" {
return "id", false // Default sorting
return "id", false
}
return f.SortBy, f.SortDesc
}
// IsDateRangeValid checks if both dates are set and start date is before end date
func (f *DateRangeFilter) IsDateRangeValid() bool {
if f.StartDate.IsZero() || f.EndDate.IsZero() {
return true // If either date is not set, consider it valid
return true
}
return f.StartDate.Before(f.EndDate)
}
// ApplyFilter applies the membership filter to a GORM query
func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
if f.Username != "" {
query = query.Where("username LIKE ?", "%"+f.Username+"%")
@@ -93,12 +83,10 @@ func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
return query
}
// Pagination returns the offset and limit for database queries
func (f *MembershipFilter) Pagination() (offset, limit int) {
return f.BaseFilter.Pagination()
}
// GetSorting returns the sort field and direction for database queries
func (f *MembershipFilter) GetSorting() (field string, desc bool) {
return f.BaseFilter.GetSorting()
}

View File

@@ -1,31 +1,26 @@
package model
// Track represents a track and its capacity
type Track struct {
Name string `json:"track" gorm:"primaryKey;size:50"`
UniquePitBoxes int `json:"unique_pit_boxes"`
PrivateServerSlots int `json:"private_server_slots"`
}
// CarModel represents a car model mapping
type CarModel struct {
Value int `json:"value" gorm:"primaryKey"`
CarModel string `json:"car_model"`
}
// DriverCategory represents driver skill categories
type DriverCategory struct {
Value int `json:"value" gorm:"primaryKey"`
Category string `json:"category"`
}
// CupCategory represents championship cup categories
type CupCategory struct {
Value int `json:"value" gorm:"primaryKey"`
Category string `json:"category"`
}
// SessionType represents session types
type SessionType struct {
Value int `json:"value" gorm:"primaryKey"`
SessionType string `json:"session_type"`

View File

@@ -32,8 +32,6 @@ type BaseModel struct {
DateUpdated time.Time `json:"dateUpdated"`
}
// Init
// Initializes base model with DateCreated, DateUpdated, and Id values.
func (cm *BaseModel) Init() {
date := time.Now()
cm.Id = uuid.NewString()

View File

@@ -5,15 +5,13 @@ import (
"gorm.io/gorm"
)
// Permission represents an action that can be performed in the system.
type Permission struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
Name string `json:"name" gorm:"unique_index;not null"`
}
// BeforeCreate is a GORM hook that runs before creating new credentials
func (s *Permission) BeforeCreate(tx *gorm.DB) error {
s.ID = uuid.New()
return nil
}
}

View File

@@ -1,6 +1,5 @@
package model
// Permission constants
const (
ServerView = "server.view"
ServerCreate = "server.create"
@@ -27,7 +26,6 @@ const (
MembershipEdit = "membership.edit"
)
// AllPermissions returns a slice of all permission strings.
func AllPermissions() []string {
return []string{
ServerView,

View File

@@ -5,16 +5,14 @@ import (
"gorm.io/gorm"
)
// Role represents a user role in the system.
type Role struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
Name string `json:"name" gorm:"unique_index;not null"`
Permissions []Permission `json:"permissions" gorm:"many2many:role_permissions;"`
}
// BeforeCreate is a GORM hook that runs before creating new credentials
func (s *Role) BeforeCreate(tx *gorm.DB) error {
s.ID = uuid.New()
return nil
}
}

View File

@@ -16,7 +16,6 @@ const (
ServiceNamePrefix = "ACC-Server"
)
// Server represents an ACC server instance
type ServerAPI struct {
Name string `json:"name"`
Status ServiceStatus `json:"status"`
@@ -35,28 +34,27 @@ func (s *Server) ToServerAPI() *ServerAPI {
}
}
// Server represents an ACC server instance
type Server struct {
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
Name string `gorm:"not null" json:"name"`
Status ServiceStatus `json:"status" gorm:"-"`
IP string `gorm:"not null" json:"-"`
Port int `gorm:"not null" json:"-"`
Path string `gorm:"not null" json:"path"` // e.g. "/acc/servers/server1/"
ServiceName string `gorm:"not null" json:"serviceName"` // Windows service name
Path string `gorm:"not null" json:"path"`
ServiceName string `gorm:"not null" json:"serviceName"`
State *ServerState `gorm:"-" json:"state"`
DateCreated time.Time `json:"dateCreated"`
FromSteamCMD bool `gorm:"not null; default:true" json:"-"`
}
type PlayerState struct {
CarID int // Car ID in broadcast packets
DriverName string // Optional: pulled from registration packet
CarID int
DriverName string
TeamName string
CarModel string
CurrentLap int
LastLapTime int // in milliseconds
BestLapTime int // in milliseconds
LastLapTime int
BestLapTime int
Position int
ConnectedAt time.Time
DisconnectedAt *time.Time
@@ -67,8 +65,6 @@ type State struct {
Session string `json:"session"`
SessionStart time.Time `json:"sessionStart"`
PlayerCount int `json:"playerCount"`
// Players map[int]*PlayerState
// etc.
}
type ServerState struct {
@@ -79,11 +75,8 @@ type ServerState struct {
Track string `json:"track"`
MaxConnections int `json:"maxConnections"`
SessionDurationMinutes int `json:"sessionDurationMinutes"`
// Players map[int]*PlayerState
// etc.
}
// ServerFilter defines filtering options for Server queries
type ServerFilter struct {
BaseFilter
ServerBasedFilter
@@ -92,9 +85,7 @@ type ServerFilter struct {
Status string `query:"status"`
}
// ApplyFilter implements the Filterable interface
func (f *ServerFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
// Apply server filter
if f.ServerID != "" {
if serverUUID, err := uuid.Parse(f.ServerID); err == nil {
query = query.Where("id = ?", serverUUID)
@@ -110,16 +101,13 @@ func (s *Server) GenerateUUID() {
}
}
// BeforeCreate is a GORM hook that runs before creating a new server
func (s *Server) BeforeCreate(tx *gorm.DB) error {
if s.Name == "" {
return errors.New("server name is required")
}
// Generate UUID if not set
s.GenerateUUID()
// Generate service name and config path if not set
if s.ServiceName == "" {
s.ServiceName = s.GenerateServiceName()
}
@@ -127,7 +115,6 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error {
s.Path = s.GenerateServerPath(BaseServerPath)
}
// Set creation date if not set
if s.DateCreated.IsZero() {
s.DateCreated = time.Now().UTC()
}
@@ -135,19 +122,14 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error {
return nil
}
// GenerateServiceName creates a unique service name based on the server name
func (s *Server) GenerateServiceName() string {
// If ID is set, use it
if s.ID != uuid.Nil {
return fmt.Sprintf("%s-%s", ServiceNamePrefix, s.ID.String()[:8])
}
// Otherwise use a timestamp-based unique identifier
return fmt.Sprintf("%s-%d", ServiceNamePrefix, time.Now().UnixNano())
}
// GenerateServerPath creates the config path based on the service name
func (s *Server) GenerateServerPath(steamCMDPath string) string {
// Ensure service name is set
if s.ServiceName == "" {
s.ServiceName = s.GenerateServiceName()
}

View File

@@ -17,7 +17,6 @@ const (
StatusRunning
)
// String converts the ServiceStatus to its string representation
func (s ServiceStatus) String() string {
switch s {
case StatusRunning:
@@ -35,7 +34,6 @@ func (s ServiceStatus) String() string {
}
}
// ParseServiceStatus converts a string to ServiceStatus
func ParseServiceStatus(s string) ServiceStatus {
switch s {
case "SERVICE_RUNNING":
@@ -53,31 +51,24 @@ func ParseServiceStatus(s string) ServiceStatus {
}
}
// MarshalJSON implements json.Marshaler interface
func (s ServiceStatus) MarshalJSON() ([]byte, error) {
// Return the numeric value instead of string
return []byte(strconv.Itoa(int(s))), nil
}
// UnmarshalJSON implements json.Unmarshaler interface
func (s *ServiceStatus) UnmarshalJSON(data []byte) error {
// Try to parse as number first
if i, err := strconv.Atoi(string(data)); err == nil {
*s = ServiceStatus(i)
return nil
}
// Fallback to string parsing for backward compatibility
str := string(data)
if len(str) >= 2 {
// Remove quotes if present
str = str[1 : len(str)-1]
}
*s = ParseServiceStatus(str)
return nil
}
// Scan implements the sql.Scanner interface
func (s *ServiceStatus) Scan(value interface{}) error {
if value == nil {
*s = StatusUnknown
@@ -99,7 +90,6 @@ func (s *ServiceStatus) Scan(value interface{}) error {
}
}
// Value implements the driver.Valuer interface
func (s ServiceStatus) Value() (driver.Value, error) {
return s.String(), nil
}

View File

@@ -10,27 +10,22 @@ import (
"gorm.io/gorm"
)
// StateHistoryFilter combines common filter capabilities
type StateHistoryFilter struct {
ServerBasedFilter // Adds server ID from path parameter
DateRangeFilter // Adds date range filtering
ServerBasedFilter
DateRangeFilter
// Additional fields specific to state history
Session TrackSession `query:"session"`
MinPlayers *int `query:"min_players"`
MaxPlayers *int `query:"max_players"`
}
// ApplyFilter implements the Filterable interface
func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
// Apply server filter
if f.ServerID != "" {
if serverUUID, err := uuid.Parse(f.ServerID); err == nil {
query = query.Where("server_id = ?", serverUUID)
}
}
// Apply date range filter if set
timeZero := time.Time{}
if f.StartDate != timeZero {
query = query.Where("date_created >= ?", f.StartDate)
@@ -39,12 +34,10 @@ func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
query = query.Where("date_created <= ?", f.EndDate)
}
// Apply session filter if set
if f.Session != "" {
query = query.Where("session = ?", f.Session)
}
// Apply player count filters if set
if f.MinPlayers != nil {
query = query.Where("player_count >= ?", *f.MinPlayers)
}
@@ -114,10 +107,9 @@ type StateHistory struct {
DateCreated time.Time `json:"dateCreated"`
SessionStart time.Time `json:"sessionStart"`
SessionDurationMinutes int `json:"sessionDurationMinutes"`
SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"` // Unique identifier for each session/event
SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"`
}
// BeforeCreate is a GORM hook that runs before creating new state history entries
func (sh *StateHistory) BeforeCreate(tx *gorm.DB) error {
if sh.ID == uuid.Nil {
sh.ID = uuid.New()

View File

@@ -21,7 +21,7 @@ type StateHistoryStats struct {
AveragePlayers float64 `json:"averagePlayers"`
PeakPlayers int `json:"peakPlayers"`
TotalSessions int `json:"totalSessions"`
TotalPlaytime int `json:"totalPlaytime" gorm:"-"` // in minutes
TotalPlaytime int `json:"totalPlaytime" gorm:"-"`
PlayerCountOverTime []PlayerCountPoint `json:"playerCountOverTime" gorm:"-"`
SessionTypes []SessionCount `json:"sessionTypes" gorm:"-"`
DailyActivity []DailyActivity `json:"dailyActivity" gorm:"-"`

View File

@@ -16,21 +16,18 @@ import (
"gorm.io/gorm"
)
// SteamCredentials represents stored Steam login credentials
type SteamCredentials struct {
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
Username string `gorm:"not null" json:"username"`
Password string `gorm:"not null" json:"-"` // Encrypted, not exposed in JSON
Password string `gorm:"not null" json:"-"`
DateCreated time.Time `json:"dateCreated"`
LastUpdated time.Time `json:"lastUpdated"`
}
// TableName specifies the table name for GORM
func (SteamCredentials) TableName() string {
return "steam_credentials"
}
// BeforeCreate is a GORM hook that runs before creating new credentials
func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
if s.ID == uuid.Nil {
s.ID = uuid.New()
@@ -42,7 +39,6 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
}
s.LastUpdated = now
// Encrypt password before saving
encrypted, err := EncryptPassword(s.Password)
if err != nil {
return err
@@ -52,11 +48,9 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
return nil
}
// BeforeUpdate is a GORM hook that runs before updating credentials
func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error {
s.LastUpdated = time.Now().UTC()
// Only encrypt if password field is being updated
if tx.Statement.Changed("Password") {
encrypted, err := EncryptPassword(s.Password)
if err != nil {
@@ -68,9 +62,7 @@ func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error {
return nil
}
// AfterFind is a GORM hook that runs after fetching credentials
func (s *SteamCredentials) AfterFind(tx *gorm.DB) error {
// Decrypt password after fetching
if s.Password != "" {
decrypted, err := DecryptPassword(s.Password)
if err != nil {
@@ -81,18 +73,15 @@ func (s *SteamCredentials) AfterFind(tx *gorm.DB) error {
return nil
}
// Validate checks if the credentials are valid with enhanced security checks
func (s *SteamCredentials) Validate() error {
if s.Username == "" {
return errors.New("username is required")
}
// Enhanced username validation
if len(s.Username) < 3 || len(s.Username) > 64 {
return errors.New("username must be between 3 and 64 characters")
}
// Check for valid characters in username (alphanumeric, underscore, hyphen)
if matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, s.Username); !matched {
return errors.New("username contains invalid characters")
}
@@ -101,7 +90,6 @@ func (s *SteamCredentials) Validate() error {
return errors.New("password is required")
}
// Basic password validation
if len(s.Password) < 6 {
return errors.New("password must be at least 6 characters long")
}
@@ -110,7 +98,6 @@ func (s *SteamCredentials) Validate() error {
return errors.New("password is too long")
}
// Check for obvious weak passwords
weakPasswords := []string{"password", "123456", "steam", "admin", "user"}
lowerPass := strings.ToLower(s.Password)
for _, weak := range weakPasswords {
@@ -122,8 +109,6 @@ func (s *SteamCredentials) Validate() error {
return nil
}
// GetEncryptionKey returns the encryption key from config.
// The key is loaded from the ENCRYPTION_KEY environment variable.
func GetEncryptionKey() []byte {
key := []byte(configs.EncryptionKey)
if len(key) != 32 {
@@ -132,7 +117,6 @@ func GetEncryptionKey() []byte {
return key
}
// EncryptPassword encrypts a password using AES-256-GCM with enhanced security
func EncryptPassword(password string) (string, error) {
if password == "" {
return "", errors.New("password cannot be empty")
@@ -148,33 +132,27 @@ func EncryptPassword(password string) (string, error) {
return "", err
}
// Create a new GCM cipher
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
// Create a cryptographically secure nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
// Encrypt the password with authenticated encryption
ciphertext := gcm.Seal(nonce, nonce, []byte(password), nil)
// Return base64 encoded encrypted password
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptPassword decrypts an encrypted password with enhanced validation
func DecryptPassword(encryptedPassword string) (string, error) {
if encryptedPassword == "" {
return "", errors.New("encrypted password cannot be empty")
}
// Validate base64 format
if len(encryptedPassword) < 24 { // Minimum reasonable length
if len(encryptedPassword) < 24 {
return "", errors.New("invalid encrypted password format")
}
@@ -184,13 +162,11 @@ func DecryptPassword(encryptedPassword string) (string, error) {
return "", err
}
// Create a new GCM cipher
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
// Decode base64 encoded password
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
if err != nil {
return "", errors.New("invalid base64 encoding")
@@ -207,7 +183,6 @@ func DecryptPassword(encryptedPassword string) (string, error) {
return "", errors.New("decryption failed - invalid ciphertext or key")
}
// Validate decrypted content
decrypted := string(plaintext)
if len(decrypted) == 0 || len(decrypted) > 1024 {
return "", errors.New("invalid decrypted password")

View File

@@ -8,7 +8,6 @@ import (
"gorm.io/gorm"
)
// User represents a user account in the system.
type User struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
Username string `json:"username" gorm:"unique_index;not null"`
@@ -17,16 +16,13 @@ type User struct {
Role Role `json:"role"`
}
// BeforeCreate is a GORM hook that runs before creating new users
func (s *User) BeforeCreate(tx *gorm.DB) error {
s.ID = uuid.New()
// Validate password strength
if err := password.ValidatePasswordStrength(s.Password); err != nil {
return err
}
// Hash password before saving
hashed, err := password.HashPassword(s.Password)
if err != nil {
return err
@@ -36,11 +32,8 @@ func (s *User) BeforeCreate(tx *gorm.DB) error {
return nil
}
// BeforeUpdate is a GORM hook that runs before updating users
func (s *User) BeforeUpdate(tx *gorm.DB) error {
// Only hash if password field is being updated
if tx.Statement.Changed("Password") {
// Validate password strength
if err := password.ValidatePasswordStrength(s.Password); err != nil {
return err
}
@@ -55,14 +48,10 @@ func (s *User) BeforeUpdate(tx *gorm.DB) error {
return nil
}
// AfterFind is a GORM hook that runs after fetching users
func (s *User) AfterFind(tx *gorm.DB) error {
// Password remains hashed - never decrypt
// This hook is kept for potential future use
return nil
}
// Validate checks if the user data is valid
func (s *User) Validate() error {
if s.Username == "" {
return errors.New("username is required")
@@ -73,7 +62,6 @@ func (s *User) Validate() error {
return nil
}
// VerifyPassword verifies a plain text password against the stored hash
func (s *User) VerifyPassword(plainPassword string) error {
return password.VerifyPassword(s.Password, plainPassword)
}

View File

@@ -4,7 +4,6 @@ import (
"github.com/google/uuid"
)
// ServerCreationStep represents the steps in server creation process
type ServerCreationStep string
const (
@@ -18,7 +17,6 @@ const (
StepCompleted ServerCreationStep = "completed"
)
// StepStatus represents the status of a step
type StepStatus string
const (
@@ -28,7 +26,6 @@ const (
StatusFailed StepStatus = "failed"
)
// WebSocketMessageType represents different types of WebSocket messages
type WebSocketMessageType string
const (
@@ -38,7 +35,6 @@ const (
MessageTypeComplete WebSocketMessageType = "complete"
)
// WebSocketMessage is the base structure for all WebSocket messages
type WebSocketMessage struct {
Type WebSocketMessageType `json:"type"`
ServerID *uuid.UUID `json:"server_id,omitempty"`
@@ -46,7 +42,6 @@ type WebSocketMessage struct {
Data interface{} `json:"data"`
}
// StepMessage represents a step update message
type StepMessage struct {
Step ServerCreationStep `json:"step"`
Status StepStatus `json:"status"`
@@ -54,26 +49,22 @@ type StepMessage struct {
Error string `json:"error,omitempty"`
}
// SteamOutputMessage represents SteamCMD output
type SteamOutputMessage struct {
Output string `json:"output"`
IsError bool `json:"is_error"`
}
// ErrorMessage represents an error message
type ErrorMessage struct {
Error string `json:"error"`
Details string `json:"details,omitempty"`
}
// CompleteMessage represents completion message
type CompleteMessage struct {
ServerID uuid.UUID `json:"server_id"`
Success bool `json:"success"`
Message string `json:"message"`
}
// GetStepDescription returns a human-readable description for each step
func GetStepDescription(step ServerCreationStep) string {
descriptions := map[ServerCreationStep]string{
StepValidation: "Validating server configuration",

View File

@@ -7,13 +7,11 @@ import (
"gorm.io/gorm"
)
// BaseRepository provides generic CRUD operations for any model
type BaseRepository[T any, F any] struct {
db *gorm.DB
modelType T
}
// NewBaseRepository creates a new base repository for the given model type
func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F] {
return &BaseRepository[T, F]{
db: db,
@@ -21,23 +19,19 @@ func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F]
}
}
// GetAll retrieves all records based on the filter
func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, error) {
result := new([]T)
query := r.db.WithContext(ctx).Model(&r.modelType)
// Apply filter conditions if filter implements Filterable
if filterable, ok := any(filter).(Filterable); ok {
query = filterable.ApplyFilter(query)
}
// Apply pagination if filter implements Pageable
if pageable, ok := any(filter).(Pageable); ok {
offset, limit := pageable.Pagination()
query = query.Offset(offset).Limit(limit)
}
// Apply sorting if filter implements Sortable
if sortable, ok := any(filter).(Sortable); ok {
field, desc := sortable.GetSorting()
if desc {
@@ -54,7 +48,6 @@ func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, err
return result, nil
}
// GetByID retrieves a single record by ID
func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T, error) {
result := new(T)
if err := r.db.WithContext(ctx).Where("id = ?", id).First(result).Error; err != nil {
@@ -66,7 +59,6 @@ func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T,
return result, nil
}
// Insert creates a new record
func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error {
if err := r.db.WithContext(ctx).Create(model).Error; err != nil {
return fmt.Errorf("error creating record: %w", err)
@@ -74,7 +66,6 @@ func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error {
return nil
}
// Update modifies an existing record
func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error {
if err := r.db.WithContext(ctx).Save(model).Error; err != nil {
return fmt.Errorf("error updating record: %w", err)
@@ -82,7 +73,6 @@ func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error {
return nil
}
// Delete removes a record by ID
func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error {
if err := r.db.WithContext(ctx).Delete(new(T), id).Error; err != nil {
return fmt.Errorf("error deleting record: %w", err)
@@ -90,7 +80,6 @@ func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error
return nil
}
// Count returns the total number of records matching the filter
func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&r.modelType)
@@ -106,8 +95,6 @@ func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, err
return count, nil
}
// Interfaces for filter capabilities
type Filterable interface {
ApplyFilter(*gorm.DB) *gorm.DB
}

View File

@@ -17,13 +17,11 @@ func NewConfigRepository(db *gorm.DB) *ConfigRepository {
}
}
// UpdateConfig updates or creates a Config record
func (r *ConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config {
if err := r.Update(ctx, config); err != nil {
// If update fails, try to insert
if err := r.Insert(ctx, config); err != nil {
return nil
}
}
return config
}
}

View File

@@ -8,20 +8,16 @@ import (
"gorm.io/gorm"
)
// MembershipRepository handles database operations for users, roles, and permissions.
type MembershipRepository struct {
*BaseRepository[model.User, model.MembershipFilter]
}
// NewMembershipRepository creates a new MembershipRepository.
func NewMembershipRepository(db *gorm.DB) *MembershipRepository {
return &MembershipRepository{
BaseRepository: NewBaseRepository[model.User, model.MembershipFilter](db, model.User{}),
}
}
// FindUserByUsername finds a user by their username.
// It preloads the user's role and the role's permissions.
func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username string) (*model.User, error) {
var user model.User
db := r.db.WithContext(ctx)
@@ -32,7 +28,6 @@ func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username
return &user, nil
}
// FindUserByIDWithPermissions finds a user by their ID and preloads Role and Permissions.
func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context, userID string) (*model.User, error) {
var user model.User
db := r.db.WithContext(ctx)
@@ -43,13 +38,11 @@ func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context,
return &user, nil
}
// CreateUser creates a new user.
func (r *MembershipRepository) CreateUser(ctx context.Context, user *model.User) error {
db := r.db.WithContext(ctx)
return db.Create(user).Error
}
// FindRoleByName finds a role by its name.
func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string) (*model.Role, error) {
var role model.Role
db := r.db.WithContext(ctx)
@@ -60,13 +53,11 @@ func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string)
return &role, nil
}
// CreateRole creates a new role.
func (r *MembershipRepository) CreateRole(ctx context.Context, role *model.Role) error {
db := r.db.WithContext(ctx)
return db.Create(role).Error
}
// FindPermissionByName finds a permission by its name.
func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name string) (*model.Permission, error) {
var permission model.Permission
db := r.db.WithContext(ctx)
@@ -77,19 +68,16 @@ func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name st
return &permission, nil
}
// CreatePermission creates a new permission.
func (r *MembershipRepository) CreatePermission(ctx context.Context, permission *model.Permission) error {
db := r.db.WithContext(ctx)
return db.Create(permission).Error
}
// AssignPermissionsToRole assigns a set of permissions to a role.
func (r *MembershipRepository) AssignPermissionsToRole(ctx context.Context, role *model.Role, permissions []model.Permission) error {
db := r.db.WithContext(ctx)
return db.Model(role).Association("Permissions").Replace(permissions)
}
// GetUserPermissions retrieves all permissions for a given user ID.
func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) {
var user model.User
db := r.db.WithContext(ctx)
@@ -106,7 +94,6 @@ func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uu
return permissions, nil
}
// ListUsers retrieves all users.
func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, error) {
var users []*model.User
db := r.db.WithContext(ctx)
@@ -114,13 +101,11 @@ func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, er
return users, err
}
// DeleteUser deletes a user.
func (r *MembershipRepository) DeleteUser(ctx context.Context, userID uuid.UUID) error {
db := r.db.WithContext(ctx)
return db.Delete(&model.User{}, "id = ?", userID).Error
}
// FindUserByID finds a user by their ID.
func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUID) (*model.User, error) {
var user model.User
db := r.db.WithContext(ctx)
@@ -131,13 +116,11 @@ func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUI
return &user, nil
}
// UpdateUser updates a user's details in the database.
func (r *MembershipRepository) UpdateUser(ctx context.Context, user *model.User) error {
db := r.db.WithContext(ctx)
return db.Save(user).Error
}
// FindRoleByID finds a role by its ID.
func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUID) (*model.Role, error) {
var role model.Role
db := r.db.WithContext(ctx)
@@ -148,12 +131,10 @@ func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUI
return &role, nil
}
// ListUsersWithFilter retrieves users based on the membership filter.
func (r *MembershipRepository) ListUsersWithFilter(ctx context.Context, filter *model.MembershipFilter) (*[]model.User, error) {
return r.BaseRepository.GetAll(ctx, filter)
}
// ListRoles retrieves all roles.
func (r *MembershipRepository) ListRoles(ctx context.Context) ([]*model.Role, error) {
var roles []*model.Role
db := r.db.WithContext(ctx)

View File

@@ -10,11 +10,7 @@ import (
"go.uber.org/dig"
)
// InitializeRepositories
// Initializes Dependency Injection modules for repositories
//
// Args:
// *dig.Container: Dig Container
// *dig.Container: Dig Container
func InitializeRepositories(c *dig.Container) {
c.Provide(NewServiceControlRepository)
c.Provide(NewStateHistoryRepository)
@@ -24,16 +20,14 @@ func InitializeRepositories(c *dig.Container) {
c.Provide(NewSteamCredentialsRepository)
c.Provide(NewMembershipRepository)
// Provide the Steam2FAManager as a singleton
if err := c.Provide(func() *model.Steam2FAManager {
manager := model.NewSteam2FAManager()
// Use graceful shutdown manager for cleanup goroutine
shutdownManager := graceful.GetManager()
shutdownManager.RunGoroutine(func(ctx context.Context) {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
@@ -43,7 +37,7 @@ func InitializeRepositories(c *dig.Container) {
}
}
})
return manager
}); err != nil {
logging.Panic("unable to initialize steam 2fa manager")

View File

@@ -20,10 +20,6 @@ func NewServerRepository(db *gorm.DB) *ServerRepository {
return repo
}
// GetFirstByServiceName
// Gets first row from Server table.
//
// Args:
// context.Context: Application context
// Returns:
// model.ServerModel: Server object from database.

View File

@@ -18,17 +18,14 @@ func NewStateHistoryRepository(db *gorm.DB) *StateHistoryRepository {
}
}
// GetAll retrieves all state history records with the given filter
func (r *StateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) {
return r.BaseRepository.GetAll(ctx, filter)
}
// Insert creates a new state history record
func (r *StateHistoryRepository) Insert(ctx context.Context, model *model.StateHistory) error {
return r.BaseRepository.Insert(ctx, model)
}
// GetLastSessionID gets the last session ID for a server
func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) {
var lastSession model.StateHistory
result := r.BaseRepository.db.WithContext(ctx).
@@ -38,7 +35,7 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return uuid.Nil, nil // Return nil UUID if no sessions found
return uuid.Nil, nil
}
return uuid.Nil, result.Error
}
@@ -46,10 +43,8 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID
return lastSession.SessionID, nil
}
// GetSummaryStats calculates peak players, total sessions, and average players.
func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) {
var stats model.StateHistoryStats
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return model.StateHistoryStats{}, err
@@ -73,12 +68,10 @@ func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *mo
return stats, nil
}
// GetTotalPlaytime calculates the total playtime in minutes.
func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) {
var totalPlaytime struct {
TotalMinutes float64
}
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return 0, err
@@ -100,10 +93,8 @@ func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *m
return int(totalPlaytime.TotalMinutes), nil
}
// GetPlayerCountOverTime gets downsampled player count data.
func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) {
var points []model.PlayerCountPoint
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return points, err
@@ -122,10 +113,8 @@ func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, fil
return points, err
}
// GetSessionTypes counts sessions by type.
func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) {
var sessionTypes []model.SessionCount
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return sessionTypes, err
@@ -145,10 +134,8 @@ func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *mo
return sessionTypes, err
}
// GetDailyActivity counts sessions per day.
func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) {
var dailyActivity []model.DailyActivity
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return dailyActivity, err
@@ -167,10 +154,8 @@ func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *m
return dailyActivity, err
}
// GetRecentSessions retrieves the 10 most recent sessions.
func (r *StateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) {
var recentSessions []model.RecentSession
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return recentSessions, err

View File

@@ -77,8 +77,8 @@ func NewConfigService(repository *repository.ConfigRepository, serverRepository
repository: repository,
serverRepository: serverRepository,
configCache: model.NewServerConfigCache(model.CacheConfig{
ExpirationTime: 5 * time.Minute, // Cache configs for 5 minutes
ThrottleTime: 1 * time.Second, // Prevent rapid re-reads
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
DefaultStatus: model.StatusUnknown,
}),
}
@@ -88,10 +88,6 @@ func (as *ConfigService) SetServerService(serverService *ServerService) {
as.serverService = serverService
}
// UpdateConfig
// Updates physical config file and caches it in database.
//
// Args:
// context.Context: Application context
// Returns:
// string: Application version
@@ -103,7 +99,6 @@ func (as *ConfigService) UpdateConfig(ctx *fiber.Ctx, body *map[string]interface
return as.updateConfigInternal(ctx.UserContext(), serverID, configFile, body, override)
}
// updateConfigInternal handles the actual config update logic without Fiber dependencies
func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID string, configFile string, body *map[string]interface{}, override bool) (*model.Config, error) {
serverUUID, err := uuid.Parse(serverID)
if err != nil {
@@ -117,17 +112,14 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
return nil, fmt.Errorf("server not found")
}
// Read existing config
configPath := filepath.Join(server.GetConfigPath(), configFile)
oldData, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
// Create directory if it doesn't exist
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, err
}
// Create empty JSON file
if err := os.WriteFile(configPath, []byte("{}"), 0644); err != nil {
return nil, err
}
@@ -142,7 +134,6 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
return nil, err
}
// Write new config
newData, err := json.Marshal(&body)
if err != nil {
return nil, err
@@ -168,12 +159,9 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
return nil, err
}
// Invalidate all configs for this server since configs can be interdependent
as.configCache.InvalidateServerCache(serverID)
as.serverService.StartAccServerRuntime(server)
// Log change
return as.repository.UpdateConfig(ctx, &model.Config{
ServerID: serverUUID,
ConfigFile: configFile,
@@ -183,10 +171,6 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
}), nil
}
// GetConfig
// Gets physical config file and caches it in database.
//
// Args:
// context.Context: Application context
// Returns:
// string: Application version
@@ -202,7 +186,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
return nil, fiber.NewError(404, "Server not found")
}
// Try to get from cache based on config file type
switch configFile {
case ConfigurationJson:
if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok {
@@ -233,7 +216,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
logging.Debug("Cache miss for server ID: %s, file: %s - loading from disk", serverIDStr, configFile)
// Not in cache, load from disk
configPath := filepath.Join(server.GetConfigPath(), configFile)
decoder := DecodeFileName(configFile)
if decoder == nil {
@@ -244,7 +226,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
if err != nil {
if os.IsNotExist(err) {
logging.Debug("Config file not found, creating default for server ID: %s, file: %s", serverIDStr, configFile)
// Return empty config if file doesn't exist
switch configFile {
case ConfigurationJson:
return &model.Configuration{}, nil
@@ -261,7 +242,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
return nil, err
}
// Cache the loaded config
switch configFile {
case ConfigurationJson:
as.configCache.UpdateConfiguration(serverIDStr, *config.(*model.Configuration))
@@ -279,8 +259,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
return config, nil
}
// GetConfigs
// Gets all configurations for a server, using cache when possible.
func (as *ConfigService) GetConfigs(ctx *fiber.Ctx) (*model.Configurations, error) {
serverID := ctx.Params("id")
@@ -298,7 +276,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
logging.Info("Loading configs for server ID: %s at path: %s", serverIDStr, server.GetConfigPath())
configs := &model.Configurations{}
// Load configuration
if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok {
logging.Debug("Using cached configuration for server %s", serverIDStr)
configs.Configuration = *cached
@@ -313,7 +290,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
as.configCache.UpdateConfiguration(serverIDStr, config)
}
// Load assist rules
if cached, ok := as.configCache.GetAssistRules(serverIDStr); ok {
logging.Debug("Using cached assist rules for server %s", serverIDStr)
configs.AssistRules = *cached
@@ -328,7 +304,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
as.configCache.UpdateAssistRules(serverIDStr, rules)
}
// Load event config
if cached, ok := as.configCache.GetEvent(serverIDStr); ok {
logging.Debug("Using cached event config for server %s", serverIDStr)
configs.Event = *cached
@@ -344,7 +319,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
as.configCache.UpdateEvent(serverIDStr, event)
}
// Load event rules
if cached, ok := as.configCache.GetEventRules(serverIDStr); ok {
logging.Debug("Using cached event rules for server %s", serverIDStr)
configs.EventRules = *cached
@@ -359,7 +333,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
as.configCache.UpdateEventRules(serverIDStr, rules)
}
// Load settings
if cached, ok := as.configCache.GetSettings(serverIDStr); ok {
logging.Debug("Using cached settings for server %s", serverIDStr)
configs.Settings = *cached
@@ -475,9 +448,7 @@ func (as *ConfigService) GetConfiguration(server *model.Server) (*model.Configur
return &config, nil
}
// SaveConfiguration saves the configuration for a server
func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.Configuration) error {
// Convert config to map for UpdateConfig
configMap := make(map[string]interface{})
configBytes, err := json.Marshal(config)
if err != nil {
@@ -487,7 +458,6 @@ func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.C
return fmt.Errorf("failed to unmarshal configuration: %v", err)
}
// Update the configuration using the internal method
_, err = as.updateConfigInternal(context.Background(), server.ID.String(), ConfigurationJson, &configMap, true)
return err
}

View File

@@ -96,11 +96,9 @@ func (s *FirewallService) DeleteServerRules(serverName string, tcpPorts, udpPort
}
func (s *FirewallService) UpdateServerRules(serverName string, tcpPorts, udpPorts []int) error {
// First delete existing rules
if err := s.DeleteServerRules(serverName, tcpPorts, udpPorts); err != nil {
return err
}
// Then create new rules
return s.CreateServerRules(serverName, tcpPorts, udpPorts)
}
}

View File

@@ -12,13 +12,11 @@ import (
"github.com/google/uuid"
)
// CacheInvalidator interface for cache invalidation
type CacheInvalidator interface {
InvalidateUserPermissions(userID string)
InvalidateAllUserPermissions()
}
// MembershipService provides business logic for membership-related operations.
type MembershipService struct {
repo *repository.MembershipRepository
cacheInvalidator CacheInvalidator
@@ -26,29 +24,25 @@ type MembershipService struct {
openJwtHandler *jwt.OpenJWTHandler
}
// NewMembershipService creates a new MembershipService.
func NewMembershipService(repo *repository.MembershipRepository, jwtHandler *jwt.JWTHandler, openJwtHandler *jwt.OpenJWTHandler) *MembershipService {
return &MembershipService{
repo: repo,
cacheInvalidator: nil, // Will be set later via SetCacheInvalidator
cacheInvalidator: nil,
jwtHandler: jwtHandler,
openJwtHandler: openJwtHandler,
}
}
// SetCacheInvalidator sets the cache invalidator after service initialization
func (s *MembershipService) SetCacheInvalidator(invalidator CacheInvalidator) {
s.cacheInvalidator = invalidator
}
// Login authenticates a user and returns a JWT.
func (s *MembershipService) HandleLogin(ctx context.Context, username, password string) (*model.User, error) {
user, err := s.repo.FindUserByUsername(ctx, username)
if err != nil {
return nil, errors.New("invalid credentials")
}
// Use secure password verification with constant-time comparison
if err := user.VerifyPassword(password); err != nil {
return nil, errors.New("invalid credentials")
}
@@ -56,7 +50,6 @@ func (s *MembershipService) HandleLogin(ctx context.Context, username, password
return user, nil
}
// Login authenticates a user and returns a JWT.
func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) {
user, err := s.HandleLogin(ctx, username, password)
if err != nil {
@@ -70,7 +63,6 @@ func (s *MembershipService) GenerateOpenToken(ctx context.Context, userId string
return s.openJwtHandler.GenerateToken(userId)
}
// CreateUser creates a new user.
func (s *MembershipService) CreateUser(ctx context.Context, username, password, roleName string) (*model.User, error) {
role, err := s.repo.FindRoleByName(ctx, roleName)
@@ -94,43 +86,35 @@ func (s *MembershipService) CreateUser(ctx context.Context, username, password,
return user, nil
}
// ListUsers retrieves all users.
func (s *MembershipService) ListUsers(ctx context.Context) ([]*model.User, error) {
return s.repo.ListUsers(ctx)
}
// GetUser retrieves a single user by ID.
func (s *MembershipService) GetUser(ctx context.Context, userID uuid.UUID) (*model.User, error) {
return s.repo.FindUserByID(ctx, userID)
}
// GetUserWithPermissions retrieves a single user by ID with their role and permissions.
func (s *MembershipService) GetUserWithPermissions(ctx context.Context, userID string) (*model.User, error) {
return s.repo.FindUserByIDWithPermissions(ctx, userID)
}
// UpdateUserRequest defines the request body for updating a user.
type UpdateUserRequest struct {
Username *string `json:"username"`
Password *string `json:"password"`
RoleID *uuid.UUID `json:"roleId"`
}
// DeleteUser deletes a user with validation to prevent Super Admin deletion.
func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) error {
// Get user with role information
user, err := s.repo.FindUserByID(ctx, userID)
if err != nil {
return errors.New("user not found")
}
// Get role to check if it's Super Admin
role, err := s.repo.FindRoleByID(ctx, user.RoleID)
if err != nil {
return errors.New("user role not found")
}
// Prevent deletion of Super Admin users
if role.Name == "Super Admin" {
return errors.New("cannot delete Super Admin user")
}
@@ -140,7 +124,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er
return err
}
// Invalidate cache for deleted user
if s.cacheInvalidator != nil {
s.cacheInvalidator.InvalidateUserPermissions(userID.String())
}
@@ -149,7 +132,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er
return nil
}
// UpdateUser updates a user's details.
func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, req UpdateUserRequest) (*model.User, error) {
user, err := s.repo.FindUserByID(ctx, userID)
if err != nil {
@@ -161,12 +143,10 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
}
if req.Password != nil && *req.Password != "" {
// Password will be automatically hashed in BeforeUpdate hook
user.Password = *req.Password
}
if req.RoleID != nil {
// Check if role exists
_, err := s.repo.FindRoleByID(ctx, *req.RoleID)
if err != nil {
return nil, errors.New("role not found")
@@ -178,7 +158,6 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
return nil, err
}
// Invalidate cache if role was changed
if req.RoleID != nil && s.cacheInvalidator != nil {
s.cacheInvalidator.InvalidateUserPermissions(userID.String())
}
@@ -187,14 +166,12 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
return user, nil
}
// HasPermission checks if a user has a specific permission.
func (s *MembershipService) HasPermission(ctx context.Context, userID string, permissionName string) (bool, error) {
user, err := s.repo.FindUserByIDWithPermissions(ctx, userID)
if err != nil {
return false, err
}
// Super admin and Admin have all permissions
if user.Role.Name == "Super Admin" || user.Role.Name == "Admin" {
return true, nil
}
@@ -208,15 +185,13 @@ func (s *MembershipService) HasPermission(ctx context.Context, userID string, pe
return false, nil
}
// SetupInitialData creates the initial roles and permissions.
func (s *MembershipService) SetupInitialData(ctx context.Context) error {
// Define all permissions
permissions := model.AllPermissions()
createdPermissions := make([]model.Permission, 0)
for _, pName := range permissions {
perm, err := s.repo.FindPermissionByName(ctx, pName)
if err != nil { // Assuming error means not found
if err != nil {
perm = &model.Permission{Name: pName}
if err := s.repo.CreatePermission(ctx, perm); err != nil {
return err
@@ -225,7 +200,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
createdPermissions = append(createdPermissions, *perm)
}
// Create Super Admin role with all permissions
superAdminRole, err := s.repo.FindRoleByName(ctx, "Super Admin")
if err != nil {
superAdminRole = &model.Role{Name: "Super Admin"}
@@ -237,7 +211,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
return err
}
// Create Admin role with same permissions as Super Admin
adminRole, err := s.repo.FindRoleByName(ctx, "Admin")
if err != nil {
adminRole = &model.Role{Name: "Admin"}
@@ -249,7 +222,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
return err
}
// Create Manager role with limited permissions (excluding membership, role, user, server create/delete)
managerRole, err := s.repo.FindRoleByName(ctx, "Manager")
if err != nil {
managerRole = &model.Role{Name: "Manager"}
@@ -258,7 +230,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
}
}
// Define manager permissions (limited set)
managerPermissionNames := []string{
model.ServerView,
model.ServerUpdate,
@@ -282,16 +253,14 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
return err
}
// Invalidate all caches after role setup changes
if s.cacheInvalidator != nil {
s.cacheInvalidator.InvalidateAllUserPermissions()
}
// Create a default admin user if one doesn't exist
_, err = s.repo.FindUserByUsername(ctx, "admin")
if err != nil {
logging.Debug("Creating default admin user")
_, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin") // Default password, should be changed
_, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin")
if err != nil {
return err
}
@@ -300,7 +269,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
return nil
}
// GetAllRoles retrieves all roles for dropdown selection.
func (s *MembershipService) GetAllRoles(ctx context.Context) ([]*model.Role, error) {
return s.repo.ListRoles(ctx)
}

View File

@@ -20,7 +20,7 @@ import (
const (
DefaultStartPort = 9600
RequiredPortCount = 1 // Update this if ACC needs more ports
RequiredPortCount = 1
)
type ServerService struct {
@@ -45,18 +45,15 @@ type pendingState struct {
}
func (s *ServerService) ensureLogTailing(server *model.Server, instance *tracking.AccServerInstance) {
// Check if we already have a tailer
if _, exists := s.logTailers.Load(server.ID); exists {
return
}
// Start tailing in a goroutine that handles file creation/deletion
go func() {
logPath := filepath.Join(server.GetLogPath(), "server.log")
tailer := tracking.NewLogTailer(logPath, instance.HandleLogLine)
s.logTailers.Store(server.ID, tailer)
// Start tailing and automatically handle file changes
tailer.Start()
}()
}
@@ -82,7 +79,6 @@ func NewServerService(
webSocketService: webSocketService,
}
// Initialize server instances
servers, err := repository.GetAll(context.Background(), &model.ServerFilter{})
if err != nil {
logging.Error("Failed to get servers: %v", err)
@@ -90,7 +86,6 @@ func NewServerService(
}
for i := range *servers {
// Initialize instance regardless of status
logging.Info("Starting server runtime for server ID: %d", (*servers)[i].ID)
service.StartAccServerRuntime(&(*servers)[i])
}
@@ -99,7 +94,7 @@ func NewServerService(
}
func (s *ServerService) shouldInsertStateHistory(serverID uuid.UUID) bool {
insertInterval := 5 * time.Minute // Configure this as needed
insertInterval := 5 * time.Minute
lastInsertInterface, exists := s.lastInsertTimes.Load(serverID)
if !exists {
@@ -122,16 +117,15 @@ func (s *ServerService) getNextSessionID(serverID uuid.UUID) uuid.UUID {
lastID, err := s.stateHistoryRepo.GetLastSessionID(context.Background(), serverID)
if err != nil {
logging.Error("Failed to get last session ID for server %s: %v", serverID, err)
return uuid.New() // Return new UUID as fallback
return uuid.New()
}
if lastID == uuid.Nil {
return uuid.New() // Return new UUID if no previous session
return uuid.New()
}
return uuid.New() // Always generate new UUID for each session
return uuid.New()
}
func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.ServerState) {
// Get or create session ID when session changes
currentSessionInterface, exists := s.instances.Load(serverID)
var sessionID uuid.UUID
if !exists {
@@ -163,7 +157,6 @@ func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.Serv
}
func (s *ServerService) updateSessionDuration(server *model.Server, sessionType model.TrackSession) {
// Get configs using helper methods
event, err := s.configService.GetEventConfig(server)
if err != nil {
event = &model.EventConfig{}
@@ -181,9 +174,7 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType
serverInstance.State.Track = event.Track
serverInstance.State.MaxConnections = configuration.MaxConnections.ToInt()
// Check if session type has changed
if serverInstance.State.Session != sessionType {
// Get new session ID for the new session
sessionID := s.getNextSessionID(server.ID)
s.sessionIDs.Store(server.ID, sessionID)
}
@@ -204,33 +195,27 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType
}
func (s *ServerService) GenerateServerPath(server *model.Server) {
// Get the base steamcmd path from environment variable
steamCMDPath := env.GetSteamCMDDirPath()
server.FromSteamCMD = true
server.Path = server.GenerateServerPath(steamCMDPath)
server.FromSteamCMD = true
}
func (s *ServerService) handleStateChange(server *model.Server, state *model.ServerState) {
// Update session duration when session changes
s.updateSessionDuration(server, state.Session)
// Invalidate status cache when server state changes
s.apiService.statusCache.InvalidateStatus(server.ServiceName)
// Cancel existing timer if any
if debouncer, exists := s.debouncers.Load(server.ID); exists {
pending := debouncer.(*pendingState)
pending.timer.Stop()
}
// Create new timer
timer := time.NewTimer(5 * time.Minute)
s.debouncers.Store(server.ID, &pendingState{
timer: timer,
state: state,
})
// Start goroutine to handle the delayed insert
go func() {
<-timer.C
if debouncer, exists := s.debouncers.Load(server.ID); exists {
@@ -240,14 +225,12 @@ func (s *ServerService) handleStateChange(server *model.Server, state *model.Ser
}
}()
// If enough time has passed since last insert, insert immediately
if s.shouldInsertStateHistory(server.ID) {
s.insertStateHistory(server.ID, state)
}
}
func (s *ServerService) StartAccServerRuntime(server *model.Server) {
// Get or create instance
instanceInterface, exists := s.instances.Load(server.ID)
var instance *tracking.AccServerInstance
if !exists {
@@ -259,20 +242,14 @@ func (s *ServerService) StartAccServerRuntime(server *model.Server) {
instance = instanceInterface.(*tracking.AccServerInstance)
}
// Invalidate config cache for this server before loading new configs
serverIDStr := server.ID.String()
s.configService.configCache.InvalidateServerCache(serverIDStr)
s.updateSessionDuration(server, instance.State.Session)
// Ensure log tailing is running (regardless of server status)
s.ensureLogTailing(server, instance)
}
// GetAll
// Gets All rows from Server table.
//
// Args:
// context.Context: Application context
// Returns:
// string: Application version
@@ -304,10 +281,6 @@ func (s *ServerService) GetAll(ctx *fiber.Ctx, filter *model.ServerFilter) (*[]m
return servers, nil
}
// GetById
// Gets rows by ID from Server table.
//
// Args:
// context.Context: Application context
// Returns:
// string: Application version
@@ -334,22 +307,16 @@ func (as *ServerService) GetById(ctx *fiber.Ctx, serverID uuid.UUID) (*model.Ser
return server, nil
}
// CreateServerAsync starts server creation asynchronously and returns immediately
func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server) error {
// Perform basic validation first
if err := server.Validate(); err != nil {
return err
}
// Generate server path
s.GenerateServerPath(server)
// Create a background context that won't be cancelled when the HTTP request ends
bgCtx := context.Background()
// Start the actual creation process in a goroutine
go func() {
// Create server in background without using fiber.Ctx
if err := s.createServerBackground(bgCtx, server); err != nil {
logging.Error("Async server creation failed for server %s: %v", server.ID, err)
s.webSocketService.BroadcastError(server.ID, "Server creation failed", err.Error())
@@ -361,11 +328,9 @@ func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server)
}
func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error {
// Broadcast step: validation
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusInProgress,
model.GetStepDescription(model.StepValidation), "")
// Validate basic server configuration
if err := server.Validate(); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusFailed,
"", fmt.Sprintf("Validation failed: %v", err))
@@ -375,19 +340,15 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusCompleted,
"Server configuration validated successfully", "")
// Broadcast step: directory creation
s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusInProgress,
model.GetStepDescription(model.StepDirectoryCreation), "")
// Directory creation is handled within InstallServer, so we mark it as completed
s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusCompleted,
"Server directories prepared", "")
// Broadcast step: Steam download
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusInProgress,
model.GetStepDescription(model.StepSteamDownload), "")
// Install server using SteamCMD with streaming support
if err := s.steamService.InstallServerWithWebSocket(ctx.UserContext(), server.Path, &server.ID, s.webSocketService); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusFailed,
"", fmt.Sprintf("Steam installation failed: %v", err))
@@ -397,11 +358,9 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusCompleted,
"Server files downloaded successfully", "")
// Broadcast step: config generation
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusInProgress,
model.GetStepDescription(model.StepConfigGeneration), "")
// Find available ports for server
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
if err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed,
@@ -409,10 +368,8 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
return fmt.Errorf("failed to find available ports: %v", err)
}
// Use the first port for both TCP and UDP
serverPort := ports[0]
// Update server configuration with the allocated port
if err := s.updateServerPort(server, serverPort); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed,
"", fmt.Sprintf("Failed to update server configuration: %v", err))
@@ -422,17 +379,14 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusCompleted,
fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), "")
// Broadcast step: service creation
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusInProgress,
model.GetStepDescription(model.StepServiceCreation), "")
// Create Windows service with correct paths
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
serverWorkingDir := filepath.Join(server.GetServerPath(), "server")
if err := s.windowsService.CreateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusFailed,
"", fmt.Sprintf("Failed to create Windows service: %v", err))
// Cleanup on failure
s.steamService.UninstallServer(server.Path)
return fmt.Errorf("failed to create Windows service: %v", err)
}
@@ -440,7 +394,6 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusCompleted,
fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), "")
// Broadcast step: firewall rules
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusInProgress,
model.GetStepDescription(model.StepFirewallRules), "")
@@ -450,7 +403,6 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusFailed,
"", fmt.Sprintf("Failed to create firewall rules: %v", err))
// Cleanup on failure
s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName)
s.steamService.UninstallServer(server.Path)
return fmt.Errorf("failed to create firewall rules: %v", err)
@@ -459,15 +411,12 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusCompleted,
fmt.Sprintf("Firewall rules created for port %d", serverPort), "")
// Broadcast step: database save
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusInProgress,
model.GetStepDescription(model.StepDatabaseSave), "")
// Insert server into database
if err := s.repository.Insert(ctx.UserContext(), server); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusFailed,
"", fmt.Sprintf("Failed to save server to database: %v", err))
// Cleanup on failure
s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts)
s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName)
s.steamService.UninstallServer(server.Path)
@@ -477,10 +426,8 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusCompleted,
"Server saved to database successfully", "")
// Initialize server runtime
s.StartAccServerRuntime(server)
// Broadcast completion
s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted,
model.GetStepDescription(model.StepCompleted), "")
@@ -490,13 +437,10 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
return nil
}
// createServerBackground performs server creation in background without fiber.Ctx
func (s *ServerService) createServerBackground(ctx context.Context, server *model.Server) error {
// Broadcast step: validation
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusInProgress,
model.GetStepDescription(model.StepValidation), "")
// Validate basic server configuration (already done in async method, but double-check)
if err := server.Validate(); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusFailed,
"", fmt.Sprintf("Validation failed: %v", err))
@@ -506,20 +450,16 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusCompleted,
"Server configuration validated successfully", "")
// Broadcast step: directory creation
s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusInProgress,
model.GetStepDescription(model.StepDirectoryCreation), "")
// Directory creation is handled within InstallServer, so we mark it as completed
s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusCompleted,
"Server directories prepared", "")
// Broadcast step: Steam download
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusInProgress,
model.GetStepDescription(model.StepSteamDownload), "")
// Install server using SteamCMD with streaming support
if err := s.steamService.InstallServerWithWebSocket(ctx, server.GetServerPath(), &server.ID, s.webSocketService); err != nil {
if err := s.steamService.InstallServerWithWebSocket(ctx, server.Path, &server.ID, s.webSocketService); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusFailed,
"", fmt.Sprintf("Steam installation failed: %v", err))
return fmt.Errorf("failed to install server: %v", err)
@@ -528,11 +468,9 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusCompleted,
"Server files downloaded successfully", "")
// Broadcast step: config generation
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusInProgress,
model.GetStepDescription(model.StepConfigGeneration), "")
// Find available ports for server
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
if err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed,
@@ -540,10 +478,8 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
return fmt.Errorf("failed to find available ports: %v", err)
}
// Use the first port for both TCP and UDP
serverPort := ports[0]
// Update server configuration with the allocated port
if err := s.updateServerPort(server, serverPort); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed,
"", fmt.Sprintf("Failed to update server configuration: %v", err))
@@ -553,17 +489,14 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusCompleted,
fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), "")
// Broadcast step: service creation
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusInProgress,
model.GetStepDescription(model.StepServiceCreation), "")
// Create Windows service with correct paths
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
serverWorkingDir := filepath.Join(server.GetServerPath(), "server")
if err := s.windowsService.CreateService(ctx, server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusFailed,
"", fmt.Sprintf("Failed to create Windows service: %v", err))
// Cleanup on failure
s.steamService.UninstallServer(server.Path)
return fmt.Errorf("failed to create Windows service: %v", err)
}
@@ -571,7 +504,6 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusCompleted,
fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), "")
// Broadcast step: firewall rules
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusInProgress,
model.GetStepDescription(model.StepFirewallRules), "")
@@ -581,7 +513,6 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusFailed,
"", fmt.Sprintf("Failed to create firewall rules: %v", err))
// Cleanup on failure
s.windowsService.DeleteService(ctx, server.ServiceName)
s.steamService.UninstallServer(server.Path)
return fmt.Errorf("failed to create firewall rules: %v", err)
@@ -590,15 +521,12 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusCompleted,
fmt.Sprintf("Firewall rules created for port %d", serverPort), "")
// Broadcast step: database save
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusInProgress,
model.GetStepDescription(model.StepDatabaseSave), "")
// Insert server into database
if err := s.repository.Insert(ctx, server); err != nil {
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusFailed,
"", fmt.Sprintf("Failed to save server to database: %v", err))
// Cleanup on failure
s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts)
s.windowsService.DeleteService(ctx, server.ServiceName)
s.steamService.UninstallServer(server.Path)
@@ -608,10 +536,8 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusCompleted,
"Server saved to database successfully", "")
// Initialize server runtime
s.StartAccServerRuntime(server)
// Broadcast completion
s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted,
model.GetStepDescription(model.StepCompleted), "")
@@ -622,18 +548,15 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
}
func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
// Get server details
server, err := s.repository.GetByID(ctx.UserContext(), serverID)
if err != nil {
return fmt.Errorf("failed to get server details: %v", err)
}
// Stop and remove Windows service
if err := s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName); err != nil {
logging.Error("Failed to delete Windows service: %v", err)
}
// Remove firewall rules
configuration, err := s.configService.GetConfiguration(server)
if err != nil {
logging.Error("Failed to get configuration for server %d: %v", server.ID, err)
@@ -644,17 +567,14 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
logging.Error("Failed to delete firewall rules: %v", err)
}
// Uninstall server files
if err := s.steamService.UninstallServer(server.Path); err != nil {
logging.Error("Failed to uninstall server: %v", err)
}
// Remove from database
if err := s.repository.Delete(ctx.UserContext(), serverID); err != nil {
return fmt.Errorf("failed to delete server from database: %v", err)
}
// Cleanup runtime resources
if tailer, exists := s.logTailers.Load(server.ID); exists {
tailer.(*tracking.LogTailer).Stop()
s.logTailers.Delete(server.ID)
@@ -664,84 +584,27 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
s.debouncers.Delete(server.ID)
s.sessionIDs.Delete(server.ID)
// Invalidate status cache for deleted server
s.apiService.statusCache.InvalidateStatus(server.ServiceName)
return nil
}
func (s *ServerService) UpdateServer(ctx *fiber.Ctx, server *model.Server) error {
// Validate server configuration
if err := server.Validate(); err != nil {
return err
}
// Get existing server details
existingServer, err := s.repository.GetByID(ctx.UserContext(), server.ID)
if err != nil {
return fmt.Errorf("failed to get existing server details: %v", err)
}
// Update server files if path changed
if existingServer.Path != server.Path {
if err := s.steamService.InstallServer(ctx.UserContext(), server.Path, &server.ID); err != nil {
return fmt.Errorf("failed to install server to new location: %v", err)
}
// Clean up old installation
if err := s.steamService.UninstallServer(existingServer.Path); err != nil {
logging.Error("Failed to remove old server installation: %v", err)
}
}
// Update Windows service if necessary
if existingServer.ServiceName != server.ServiceName || existingServer.Path != server.Path {
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
serverWorkingDir := server.GetServerPath()
if err := s.windowsService.UpdateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
return fmt.Errorf("failed to update Windows service: %v", err)
}
}
// Update firewall rules if service name changed
if existingServer.ServiceName != server.ServiceName {
if err := s.configureFirewall(server); err != nil {
return fmt.Errorf("failed to update firewall rules: %v", err)
}
// Invalidate cache for old service name
s.apiService.statusCache.InvalidateStatus(existingServer.ServiceName)
}
// Update database record
if err := s.repository.Update(ctx.UserContext(), server); err != nil {
return fmt.Errorf("failed to update server in database: %v", err)
}
// Restart server runtime
s.StartAccServerRuntime(server)
return nil
}
func (s *ServerService) configureFirewall(server *model.Server) error {
// Find available ports for the server
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
if err != nil {
return fmt.Errorf("failed to find available ports: %v", err)
}
// Use the first port for both TCP and UDP
serverPort := ports[0]
tcpPorts := []int{serverPort}
udpPorts := []int{serverPort}
logging.Info("Configuring firewall for server %d with port %d", server.ID, serverPort)
// Configure firewall rules
if err := s.firewallService.UpdateServerRules(server.Name, tcpPorts, udpPorts); err != nil {
return fmt.Errorf("failed to configure firewall: %v", err)
}
// Update server configuration with the allocated port
if err := s.updateServerPort(server, serverPort); err != nil {
return fmt.Errorf("failed to update server configuration: %v", err)
}
@@ -750,7 +613,6 @@ func (s *ServerService) configureFirewall(server *model.Server) error {
}
func (s *ServerService) updateServerPort(server *model.Server, port int) error {
// Load current configuration
config, err := s.configService.GetConfiguration(server)
if err != nil {
return fmt.Errorf("failed to load server configuration: %v", err)
@@ -759,7 +621,6 @@ func (s *ServerService) updateServerPort(server *model.Server, port int) error {
config.TcpPort = model.IntString(port)
config.UdpPort = model.IntString(port)
// Save the updated configuration
if err := s.configService.SaveConfiguration(server, config); err != nil {
return fmt.Errorf("failed to save server configuration: %v", err)
}

View File

@@ -7,17 +7,12 @@ import (
"go.uber.org/dig"
)
// InitializeServices
// Initializes Dependency Injection modules for services
//
// Args:
// *dig.Container: Dig Container
// *dig.Container: Dig Container
func InitializeServices(c *dig.Container) {
logging.Debug("Initializing repositories")
repository.InitializeRepositories(c)
logging.Debug("Registering services")
// Provide services
c.Provide(NewSteamService)
c.Provide(NewServerService)
c.Provide(NewStateHistoryService)

View File

@@ -24,9 +24,9 @@ func NewServiceControlService(repository *repository.ServiceControlRepository,
repository: repository,
serverRepository: serverRepository,
statusCache: model.NewServerStatusCache(model.CacheConfig{
ExpirationTime: 30 * time.Second, // Cache expires after 30 seconds
ThrottleTime: 5 * time.Second, // Minimum 5 seconds between checks
DefaultStatus: model.StatusRunning, // Default to running if throttled
ExpirationTime: 30 * time.Second,
ThrottleTime: 5 * time.Second,
DefaultStatus: model.StatusRunning,
}),
windowsService: NewWindowsService(),
}
@@ -42,18 +42,15 @@ func (as *ServiceControlService) GetStatus(ctx *fiber.Ctx) (string, error) {
return "", err
}
// Try to get status from cache
if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck {
return status.String(), nil
}
// If cache miss or expired, check actual status
statusStr, err := as.StatusServer(serviceName)
if err != nil {
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil
@@ -65,7 +62,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri
return "", err
}
// Update status cache for this service before starting
as.statusCache.UpdateStatus(serviceName, model.StatusStarting)
_, err = as.StartServer(serviceName)
@@ -77,7 +73,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil
@@ -89,7 +84,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin
return "", err
}
// Update status cache for this service before stopping
as.statusCache.UpdateStatus(serviceName, model.StatusStopping)
_, err = as.StopServer(serviceName)
@@ -101,7 +95,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil
@@ -113,7 +106,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st
return "", err
}
// Update status cache for this service before restarting
as.statusCache.UpdateStatus(serviceName, model.StatusRestarting)
_, err = as.RestartServer(serviceName)
@@ -125,7 +117,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil
@@ -135,20 +126,16 @@ func (as *ServiceControlService) StatusServer(serviceName string) (string, error
return as.windowsService.Status(context.Background(), serviceName)
}
// GetCachedStatus gets the cached status for a service name without requiring fiber context
func (as *ServiceControlService) GetCachedStatus(serviceName string) (string, error) {
// Try to get status from cache
if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck {
return status.String(), nil
}
// If cache miss or expired, check actual status
statusStr, err := as.StatusServer(serviceName)
if err != nil {
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil

View File

@@ -6,7 +6,7 @@ import (
)
type ServiceManager struct {
executor *command.CommandExecutor
executor *command.CommandExecutor
psExecutor *command.CommandExecutor
}
@@ -24,17 +24,14 @@ func NewServiceManager() *ServiceManager {
}
func (s *ServiceManager) ManageService(serviceName, action string) (string, error) {
// Run NSSM command through PowerShell to ensure elevation
output, err := s.psExecutor.ExecuteWithOutput("-nologo", "-noprofile", ".\\nssm", action, serviceName)
if err != nil {
return "", err
}
// Clean up output by removing null bytes and trimming whitespace
cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", ""))
// Remove \r\n from status strings
cleaned = strings.TrimSuffix(cleaned, "\r\n")
return cleaned, nil
}
@@ -51,11 +48,9 @@ func (s *ServiceManager) Stop(serviceName string) (string, error) {
}
func (s *ServiceManager) Restart(serviceName string) (string, error) {
// First stop the service
if _, err := s.Stop(serviceName); err != nil {
return "", err
}
// Then start it again
return s.Start(serviceName)
}
}

View File

@@ -46,7 +46,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
eg, gCtx := errgroup.WithContext(ctx.UserContext())
// Get Summary Stats (Peak/Avg Players, Total Sessions)
eg.Go(func() error {
summary, err := s.repository.GetSummaryStats(gCtx, filter)
if err != nil {
@@ -61,7 +60,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Total Playtime
eg.Go(func() error {
playtime, err := s.repository.GetTotalPlaytime(gCtx, filter)
if err != nil {
@@ -74,7 +72,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Player Count Over Time
eg.Go(func() error {
playerCount, err := s.repository.GetPlayerCountOverTime(gCtx, filter)
if err != nil {
@@ -87,7 +84,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Session Types
eg.Go(func() error {
sessionTypes, err := s.repository.GetSessionTypes(gCtx, filter)
if err != nil {
@@ -100,7 +96,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Daily Activity
eg.Go(func() error {
dailyActivity, err := s.repository.GetDailyActivity(gCtx, filter)
if err != nil {
@@ -113,7 +108,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Recent Sessions
eg.Go(func() error {
recentSessions, err := s.repository.GetRecentSessions(gCtx, filter)
if err != nil {

View File

@@ -22,12 +22,11 @@ const (
)
type SteamService struct {
executor *command.CommandExecutor
interactiveExecutor *command.InteractiveCommandExecutor
repository *repository.SteamCredentialsRepository
tfaManager *model.Steam2FAManager
pathValidator *security.PathValidator
downloadVerifier *security.DownloadVerifier
executor *command.CommandExecutor
repository *repository.SteamCredentialsRepository
tfaManager *model.Steam2FAManager
pathValidator *security.PathValidator
downloadVerifier *security.DownloadVerifier
}
func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManager *model.Steam2FAManager) *SteamService {
@@ -37,12 +36,11 @@ func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManag
}
return &SteamService{
executor: baseExecutor,
interactiveExecutor: command.NewInteractiveCommandExecutor(baseExecutor, tfaManager),
repository: repository,
tfaManager: tfaManager,
pathValidator: security.NewPathValidator(),
downloadVerifier: security.NewDownloadVerifier(),
executor: baseExecutor,
repository: repository,
tfaManager: tfaManager,
pathValidator: security.NewPathValidator(),
downloadVerifier: security.NewDownloadVerifier(),
}
}
@@ -58,21 +56,17 @@ func (s *SteamService) SaveCredentials(ctx context.Context, creds *model.SteamCr
}
func (s *SteamService) ensureSteamCMD(_ context.Context) error {
// Get SteamCMD path from environment variable
steamCMDPath := env.GetSteamCMDPath()
steamCMDDir := filepath.Dir(steamCMDPath)
// Check if SteamCMD exists
if _, err := os.Stat(steamCMDPath); !os.IsNotExist(err) {
return nil
}
// Create directory if it doesn't exist
if err := os.MkdirAll(steamCMDDir, 0755); err != nil {
return fmt.Errorf("failed to create SteamCMD directory: %v", err)
}
// Download and install SteamCMD securely
logging.Info("Downloading SteamCMD...")
steamCMDZip := filepath.Join(steamCMDDir, "steamcmd.zip")
if err := s.downloadVerifier.VerifyAndDownload(
@@ -82,150 +76,27 @@ func (s *SteamService) ensureSteamCMD(_ context.Context) error {
return fmt.Errorf("failed to download SteamCMD: %v", err)
}
// Extract SteamCMD
logging.Info("Extracting SteamCMD...")
if err := s.executor.Execute("-Command",
fmt.Sprintf("Expand-Archive -Path 'steamcmd.zip' -DestinationPath '%s'", steamCMDDir)); err != nil {
return fmt.Errorf("failed to extract SteamCMD: %v", err)
}
// Clean up zip file
os.Remove("steamcmd.zip")
return nil
}
func (s *SteamService) InstallServer(ctx context.Context, installPath string, serverID *uuid.UUID) error {
if err := s.ensureSteamCMD(ctx); err != nil {
return err
}
// Validate installation path for security
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
return fmt.Errorf("invalid installation path: %v", err)
}
// Convert to absolute path and ensure proper Windows path format
absPath, err := filepath.Abs(installPath)
if err != nil {
return fmt.Errorf("failed to get absolute path: %v", err)
}
absPath = filepath.Clean(absPath)
// Ensure install path exists
if err := os.MkdirAll(absPath, 0755); err != nil {
return fmt.Errorf("failed to create install directory: %v", err)
}
// Get Steam credentials
creds, err := s.GetCredentials(ctx)
if err != nil {
return fmt.Errorf("failed to get Steam credentials: %v", err)
}
// Get SteamCMD path from environment variable
steamCMDPath := env.GetSteamCMDPath()
// Build SteamCMD command arguments
steamCMDArgs := []string{
"+force_install_dir", absPath,
"+login",
}
if creds != nil && creds.Username != "" {
logging.Info("Using Steam credentials for user: %s", creds.Username)
steamCMDArgs = append(steamCMDArgs, creds.Username)
if creds.Password != "" {
steamCMDArgs = append(steamCMDArgs, creds.Password)
}
} else {
logging.Info("Using anonymous Steam login")
steamCMDArgs = append(steamCMDArgs, "anonymous")
}
steamCMDArgs = append(steamCMDArgs,
"+app_update", ACCServerAppID,
"validate",
"+quit",
)
// Execute SteamCMD directly without PowerShell wrapper to get better output capture
args := steamCMDArgs
// Use interactive executor to handle potential 2FA prompts with timeout
logging.Info("Installing ACC server to %s...", absPath)
logging.Info("SteamCMD command: %s %s", steamCMDPath, strings.Join(args, " "))
// Create a context with timeout to prevent hanging indefinitely
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) // Increased timeout
defer cancel()
// Update the executor to use SteamCMD directly
originalExePath := s.interactiveExecutor.ExePath
s.interactiveExecutor.ExePath = steamCMDPath
defer func() {
s.interactiveExecutor.ExePath = originalExePath
}()
if err := s.interactiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
logging.Error("SteamCMD execution failed: %v", err)
if timeoutCtx.Err() == context.DeadlineExceeded {
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
}
return fmt.Errorf("failed to run SteamCMD: %v", err)
}
logging.Info("SteamCMD execution completed successfully, proceeding with verification...")
// Add a delay to allow Steam to properly cleanup
logging.Info("Waiting for Steam operations to complete...")
time.Sleep(5 * time.Second)
// Verify installation
exePath := filepath.Join(absPath, "server", "accServer.exe")
logging.Info("Checking for ACC server executable at: %s", exePath)
if _, err := os.Stat(exePath); os.IsNotExist(err) {
// Log directory contents to help debug
logging.Info("accServer.exe not found, checking directory contents...")
if entries, dirErr := os.ReadDir(absPath); dirErr == nil {
logging.Info("Contents of %s:", absPath)
for _, entry := range entries {
logging.Info(" - %s (dir: %v)", entry.Name(), entry.IsDir())
}
}
// Check if there's a server subdirectory
serverDir := filepath.Join(absPath, "server")
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
logging.Info("Contents of %s:", serverDir)
for _, entry := range entries {
logging.Info(" - %s (dir: %v)", entry.Name(), entry.IsDir())
}
} else {
logging.Info("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr)
}
return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath)
}
logging.Info("Server installation completed successfully - accServer.exe found at %s", exePath)
return nil
}
// InstallServerWithWebSocket installs a server with WebSocket output streaming
func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPath string, serverID *uuid.UUID, wsService *WebSocketService) error {
if err := s.ensureSteamCMD(ctx); err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true)
return err
}
// Validate installation path for security
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true)
return fmt.Errorf("invalid installation path: %v", err)
}
// Convert to absolute path and ensure proper Windows path format
absPath, err := filepath.Abs(installPath)
if err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true)
@@ -233,7 +104,6 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
}
absPath = filepath.Clean(absPath)
// Ensure install path exists
if err := os.MkdirAll(absPath, 0755); err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true)
return fmt.Errorf("failed to create install directory: %v", err)
@@ -241,17 +111,14 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false)
// Get Steam credentials
creds, err := s.GetCredentials(ctx)
if err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true)
return fmt.Errorf("failed to get Steam credentials: %v", err)
}
// Get SteamCMD path from environment variable
steamCMDPath := env.GetSteamCMDPath()
// Build SteamCMD command arguments
steamCMDArgs := []string{
"+force_install_dir", absPath,
"+login",
@@ -274,27 +141,32 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
"+quit",
)
// Execute SteamCMD with WebSocket output streaming
args := steamCMDArgs
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false)
// Create a context with timeout to prevent hanging indefinitely
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
defer cancel()
// Update the executor to use SteamCMD directly
originalExePath := s.interactiveExecutor.ExePath
s.interactiveExecutor.ExePath = steamCMDPath
defer func() {
s.interactiveExecutor.ExePath = originalExePath
}()
callbackConfig := &command.CallbackConfig{
OnOutput: func(serverID uuid.UUID, output string, isError bool) {
wsService.BroadcastSteamOutput(serverID, output, isError)
},
OnCommand: func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string) {
if completed {
if success {
wsService.BroadcastSteamOutput(serverID, "Command completed successfully", false)
} else {
wsService.BroadcastSteamOutput(serverID, fmt.Sprintf("Command failed: %s", error), true)
}
}
},
}
// Create a modified interactive executor that streams output to WebSocket
wsInteractiveExecutor := command.NewInteractiveCommandExecutorWithWebSocket(s.executor, s.tfaManager, wsService, *serverID)
wsInteractiveExecutor.ExePath = steamCMDPath
callbackInteractiveExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbackConfig, *serverID)
callbackInteractiveExecutor.ExePath = steamCMDPath
if err := wsInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
if err := callbackInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true)
if timeoutCtx.Err() == context.DeadlineExceeded {
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
@@ -304,11 +176,9 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
wsService.BroadcastSteamOutput(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false)
// Add a delay to allow Steam to properly cleanup
wsService.BroadcastSteamOutput(*serverID, "Waiting for Steam operations to complete...", false)
time.Sleep(5 * time.Second)
// Verify installation
exePath := filepath.Join(absPath, "server", "accServer.exe")
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false)
@@ -322,7 +192,6 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
}
}
// Check if there's a server subdirectory
serverDir := filepath.Join(absPath, "server")
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false)
@@ -341,8 +210,117 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
return nil
}
func (s *SteamService) UpdateServer(ctx context.Context, installPath string, serverID *uuid.UUID) error {
return s.InstallServer(ctx, installPath, serverID) // Same process as install
func (s *SteamService) InstallServerWithCallbacks(ctx context.Context, installPath string, serverID *uuid.UUID, outputCallback command.OutputCallback) error {
if err := s.ensureSteamCMD(ctx); err != nil {
outputCallback(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true)
return err
}
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
outputCallback(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true)
return fmt.Errorf("invalid installation path: %v", err)
}
absPath, err := filepath.Abs(installPath)
if err != nil {
outputCallback(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true)
return fmt.Errorf("failed to get absolute path: %v", err)
}
absPath = filepath.Clean(absPath)
if err := os.MkdirAll(absPath, 0755); err != nil {
outputCallback(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true)
return fmt.Errorf("failed to create install directory: %v", err)
}
outputCallback(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false)
creds, err := s.GetCredentials(ctx)
if err != nil {
outputCallback(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true)
return fmt.Errorf("failed to get Steam credentials: %v", err)
}
steamCMDPath := env.GetSteamCMDPath()
steamCMDArgs := []string{
"+force_install_dir", absPath,
"+login",
}
if creds != nil && creds.Username != "" {
outputCallback(*serverID, fmt.Sprintf("Using Steam credentials for user: %s", creds.Username), false)
steamCMDArgs = append(steamCMDArgs, creds.Username)
if creds.Password != "" {
steamCMDArgs = append(steamCMDArgs, creds.Password)
}
} else {
outputCallback(*serverID, "Using anonymous Steam login", false)
steamCMDArgs = append(steamCMDArgs, "anonymous")
}
steamCMDArgs = append(steamCMDArgs,
"+app_update", ACCServerAppID,
"validate",
"+quit",
)
args := steamCMDArgs
outputCallback(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false)
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
defer cancel()
callbacks := &command.CallbackConfig{
OnOutput: outputCallback,
}
callbackExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbacks, *serverID)
callbackExecutor.ExePath = steamCMDPath
if err := callbackExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
outputCallback(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true)
if timeoutCtx.Err() == context.DeadlineExceeded {
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
}
return fmt.Errorf("failed to run SteamCMD: %v", err)
}
outputCallback(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false)
outputCallback(*serverID, "Waiting for Steam operations to complete...", false)
time.Sleep(5 * time.Second)
exePath := filepath.Join(absPath, "server", "accServer.exe")
outputCallback(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false)
if _, err := os.Stat(exePath); os.IsNotExist(err) {
outputCallback(*serverID, "accServer.exe not found, checking directory contents...", false)
if entries, dirErr := os.ReadDir(absPath); dirErr == nil {
outputCallback(*serverID, fmt.Sprintf("Contents of %s:", absPath), false)
for _, entry := range entries {
outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
}
}
serverDir := filepath.Join(absPath, "server")
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
outputCallback(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false)
for _, entry := range entries {
outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
}
} else {
outputCallback(*serverID, fmt.Sprintf("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr), true)
}
outputCallback(*serverID, fmt.Sprintf("Server installation failed: accServer.exe not found in %s", exePath), true)
return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath)
}
outputCallback(*serverID, fmt.Sprintf("Server installation completed successfully - accServer.exe found at %s", exePath), false)
return nil
}
func (s *SteamService) UninstallServer(installPath string) error {

View File

@@ -11,25 +11,21 @@ import (
"github.com/google/uuid"
)
// WebSocketConnection represents a single WebSocket connection
type WebSocketConnection struct {
conn *websocket.Conn
serverID *uuid.UUID // If connected to a specific server creation process
userID *uuid.UUID // User who owns this connection
serverID *uuid.UUID
userID *uuid.UUID
}
// WebSocketService manages WebSocket connections and message broadcasting
type WebSocketService struct {
connections sync.Map // map[string]*WebSocketConnection - key is connection ID
connections sync.Map
mu sync.RWMutex
}
// NewWebSocketService creates a new WebSocket service
func NewWebSocketService() *WebSocketService {
return &WebSocketService{}
}
// AddConnection adds a new WebSocket connection
func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) {
wsConn := &WebSocketConnection{
conn: conn,
@@ -39,7 +35,6 @@ func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, u
logging.Info("WebSocket connection added: %s for user: %v", connID, userID)
}
// RemoveConnection removes a WebSocket connection
func (ws *WebSocketService) RemoveConnection(connID string) {
if conn, exists := ws.connections.LoadAndDelete(connID); exists {
if wsConn, ok := conn.(*WebSocketConnection); ok {
@@ -49,7 +44,6 @@ func (ws *WebSocketService) RemoveConnection(connID string) {
logging.Info("WebSocket connection removed: %s", connID)
}
// SetServerID associates a connection with a specific server creation process
func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
if conn, exists := ws.connections.Load(connID); exists {
if wsConn, ok := conn.(*WebSocketConnection); ok {
@@ -58,7 +52,6 @@ func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
}
}
// BroadcastStep sends a step update to all connections associated with a server
func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) {
stepMsg := model.StepMessage{
Step: step,
@@ -77,7 +70,6 @@ func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerC
ws.broadcastToServer(serverID, wsMsg)
}
// BroadcastSteamOutput sends Steam command output to all connections associated with a server
func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) {
steamMsg := model.SteamOutputMessage{
Output: output,
@@ -94,7 +86,6 @@ func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output stri
ws.broadcastToServer(serverID, wsMsg)
}
// BroadcastError sends an error message to all connections associated with a server
func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) {
errorMsg := model.ErrorMessage{
Error: error,
@@ -111,7 +102,6 @@ func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, det
ws.broadcastToServer(serverID, wsMsg)
}
// BroadcastComplete sends a completion message to all connections associated with a server
func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) {
completeMsg := model.CompleteMessage{
ServerID: serverID,
@@ -129,7 +119,6 @@ func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool,
ws.broadcastToServer(serverID, wsMsg)
}
// broadcastToServer sends a message to all connections associated with a specific server
func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) {
data, err := json.Marshal(message)
if err != nil {
@@ -137,22 +126,35 @@ func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.
return
}
sentToAssociatedConnections := false
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
// Send to connections associated with this server
if wsConn.serverID != nil && *wsConn.serverID == serverID {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
// Remove the connection if it's broken
ws.RemoveConnection(key.(string))
} else {
sentToAssociatedConnections = true
}
}
}
return true
})
if !sentToAssociatedConnections && (message.Type == model.MessageTypeStep || message.Type == model.MessageTypeError || message.Type == model.MessageTypeComplete) {
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
ws.RemoveConnection(key.(string))
}
}
return true
})
}
}
// BroadcastToUser sends a message to all connections owned by a specific user
func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) {
data, err := json.Marshal(message)
if err != nil {
@@ -162,11 +164,9 @@ func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebS
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
// Send to connections owned by this user
if wsConn.userID != nil && *wsConn.userID == userID {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
// Remove the connection if it's broken
ws.RemoveConnection(key.(string))
}
}
@@ -175,7 +175,6 @@ func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebS
})
}
// GetActiveConnections returns the count of active connections
func (ws *WebSocketService) GetActiveConnections() int {
count := 0
ws.connections.Range(func(key, value interface{}) bool {

View File

@@ -23,34 +23,25 @@ func NewWindowsService() *WindowsService {
}
}
// executeNSSM runs an NSSM command through PowerShell with elevation
func (s *WindowsService) ExecuteNSSM(ctx context.Context, args ...string) (string, error) {
// Get NSSM path from environment variable
nssmPath := env.GetNSSMPath()
// Prepend NSSM path to arguments
nssmArgs := append([]string{"-NoProfile", "-NonInteractive", "-Command", "& " + nssmPath}, args...)
output, err := s.executor.ExecuteWithOutput(nssmArgs...)
if err != nil {
// Log the full command and error for debugging
logging.Error("NSSM command failed: powershell %s", strings.Join(nssmArgs, " "))
logging.Error("NSSM error output: %s", output)
return "", err
}
// Clean up output by removing null bytes and trimming whitespace
cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", ""))
// Remove \r\n from status strings
cleaned = strings.TrimSuffix(cleaned, "\r\n")
return cleaned, nil
}
// Service Installation/Configuration Methods
func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error {
// Ensure paths are absolute and properly formatted for Windows
absExecPath, err := filepath.Abs(execPath)
if err != nil {
return fmt.Errorf("failed to get absolute path for executable: %v", err)
@@ -63,30 +54,24 @@ func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPat
}
absWorkingDir = filepath.Clean(absWorkingDir)
// Log the paths being used
logging.Info("Creating service '%s' with:", serviceName)
logging.Info(" Executable: %s", absExecPath)
logging.Info(" Working Directory: %s", absWorkingDir)
// First remove any existing service with the same name
s.ExecuteNSSM(ctx, "remove", serviceName, "confirm")
// Install service
if _, err := s.ExecuteNSSM(ctx, "install", serviceName, absExecPath); err != nil {
return fmt.Errorf("failed to install service: %v", err)
}
// Set arguments if provided
if len(args) > 0 {
cmdArgs := append([]string{"set", serviceName, "AppParameters"}, args...)
if _, err := s.ExecuteNSSM(ctx, cmdArgs...); err != nil {
// Try to clean up on failure
s.ExecuteNSSM(ctx, "remove", serviceName, "confirm")
return fmt.Errorf("failed to set arguments: %v", err)
}
}
// Verify service was created
if _, err := s.ExecuteNSSM(ctx, "get", serviceName, "Application"); err != nil {
return fmt.Errorf("service creation verification failed: %v", err)
}
@@ -105,17 +90,13 @@ func (s *WindowsService) DeleteService(ctx context.Context, serviceName string)
}
func (s *WindowsService) UpdateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error {
// First remove the existing service
if err := s.DeleteService(ctx, serviceName); err != nil {
return err
}
// Then create it again with new parameters
return s.CreateService(ctx, serviceName, execPath, workingDir, args)
}
// Service Control Methods
func (s *WindowsService) Status(ctx context.Context, serviceName string) (string, error) {
return s.ExecuteNSSM(ctx, "status", serviceName)
}
@@ -129,11 +110,9 @@ func (s *WindowsService) Stop(ctx context.Context, serviceName string) (string,
}
func (s *WindowsService) Restart(ctx context.Context, serviceName string) (string, error) {
// First stop the service
if _, err := s.Stop(ctx, serviceName); err != nil {
return "", err
}
// Then start it again
return s.Start(ctx, serviceName)
}

View File

@@ -9,26 +9,22 @@ import (
"go.uber.org/dig"
)
// CacheItem represents an item in the cache
type CacheItem struct {
Value interface{}
Expiration int64
}
// InMemoryCache is a thread-safe in-memory cache
type InMemoryCache struct {
items map[string]CacheItem
mu sync.RWMutex
}
// NewInMemoryCache creates and returns a new InMemoryCache instance
func NewInMemoryCache() *InMemoryCache {
return &InMemoryCache{
items: make(map[string]CacheItem),
}
}
// Set adds an item to the cache with an expiration duration (in seconds)
func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
@@ -44,7 +40,6 @@ func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duratio
}
}
// Get retrieves an item from the cache
func (c *InMemoryCache) Get(key string) (interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -55,24 +50,18 @@ func (c *InMemoryCache) Get(key string) (interface{}, bool) {
}
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
// Item has expired, but don't delete here to avoid lock upgrade.
// It will be overwritten on the next Set.
return nil, false
}
return item.Value, true
}
// Delete removes an item from the cache
func (c *InMemoryCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.items, key)
}
// GetOrSet retrieves an item from the cache. If the item is not found, it
// calls the provided function to get the value, sets it in the cache, and
// returns it.
func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetcher func() (T, error)) (T, error) {
if cached, found := c.Get(key); found {
if value, ok := cached.(T); ok {
@@ -90,7 +79,6 @@ func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetch
return value, nil
}
// Start initializes the cache and provides it to the DI container.
func Start(di *dig.Container) {
cache := NewInMemoryCache()
err := di.Provide(func() *InMemoryCache {

View File

@@ -0,0 +1,245 @@
package command
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/logging"
"bufio"
"context"
"fmt"
"io"
"os/exec"
"strings"
"time"
"github.com/google/uuid"
)
type CallbackInteractiveCommandExecutor struct {
*InteractiveCommandExecutor
callbacks *CallbackConfig
serverID uuid.UUID
}
func NewCallbackInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, callbacks *CallbackConfig, serverID uuid.UUID) *CallbackInteractiveCommandExecutor {
if callbacks == nil {
callbacks = DefaultCallbackConfig()
}
return &CallbackInteractiveCommandExecutor{
InteractiveCommandExecutor: &InteractiveCommandExecutor{
CommandExecutor: baseExecutor,
tfaManager: tfaManager,
},
callbacks: callbacks,
serverID: serverID,
}
}
func (e *CallbackInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
cmd := exec.CommandContext(ctx, e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("failed to create stdin pipe: %v", err)
}
defer stdin.Close()
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %v", err)
}
defer stdout.Close()
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %v", err)
}
defer stderr.Close()
logging.Info("Executing interactive command with callbacks: %s %s", e.ExePath, strings.Join(args, " "))
e.callbacks.OnCommand(e.serverID, e.ExePath, args, false, false, "")
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false)
if err := cmd.Start(); err != nil {
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to start command: %v", err), true)
return fmt.Errorf("failed to start command: %v", err)
}
outputDone := make(chan error, 1)
cmdDone := make(chan error, 1)
go e.monitorOutputWithCallbacks(ctx, stdout, stderr, serverID, outputDone)
go func() {
cmdDone <- cmd.Wait()
}()
var cmdErr, outputErr error
completedCount := 0
for completedCount < 2 {
select {
case cmdErr = <-cmdDone:
completedCount++
logging.Info("Command execution completed")
e.callbacks.OnOutput(e.serverID, "Command execution completed", false)
case outputErr = <-outputDone:
completedCount++
logging.Info("Output monitoring completed")
case <-ctx.Done():
e.callbacks.OnOutput(e.serverID, "Command execution cancelled", true)
return ctx.Err()
}
}
if outputErr != nil {
logging.Warn("Output monitoring error: %v", outputErr)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Output monitoring error: %v", outputErr), true)
}
success := cmdErr == nil
errorMsg := ""
if cmdErr != nil {
errorMsg = cmdErr.Error()
}
e.callbacks.OnCommand(e.serverID, e.ExePath, args, true, success, errorMsg)
return cmdErr
}
func (e *CallbackInteractiveCommandExecutor) monitorOutputWithCallbacks(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) {
defer func() {
select {
case done <- nil:
default:
}
}()
stdoutScanner := bufio.NewScanner(stdout)
stderrScanner := bufio.NewScanner(stderr)
outputChan := make(chan outputLine, 100)
readersDone := make(chan struct{}, 2)
steamConsoleStarted := false
tfaRequestCreated := false
go func() {
defer func() { readersDone <- struct{}{} }()
for stdoutScanner.Scan() {
line := stdoutScanner.Text()
if e.LogOutput {
logging.Info("STDOUT: %s", line)
}
e.callbacks.OnOutput(e.serverID, line, false)
select {
case outputChan <- outputLine{text: line, isError: false}:
case <-ctx.Done():
return
}
}
if err := stdoutScanner.Err(); err != nil {
logging.Warn("Stdout scanner error: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stdout scanner error: %v", err), true)
}
}()
go func() {
defer func() { readersDone <- struct{}{} }()
for stderrScanner.Scan() {
line := stderrScanner.Text()
if e.LogOutput {
logging.Info("STDERR: %s", line)
}
e.callbacks.OnOutput(e.serverID, line, true)
select {
case outputChan <- outputLine{text: line, isError: true}:
case <-ctx.Done():
return
}
}
if err := stderrScanner.Err(); err != nil {
logging.Warn("Stderr scanner error: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stderr scanner error: %v", err), true)
}
}()
readersFinished := 0
for {
select {
case <-ctx.Done():
done <- ctx.Err()
return
case <-readersDone:
readersFinished++
if readersFinished == 2 {
close(outputChan)
for lineData := range outputChan {
if e.is2FAPrompt(lineData.text) {
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
logging.Error("Failed to handle 2FA prompt: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
done <- err
return
}
}
}
return
}
case lineData, ok := <-outputChan:
if !ok {
return
}
lowerLine := strings.ToLower(lineData.text)
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
steamConsoleStarted = true
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
e.callbacks.OnOutput(e.serverID, "Steam Console Client startup detected", false)
}
if e.is2FAPrompt(lineData.text) {
if !tfaRequestCreated {
e.callbacks.OnOutput(e.serverID, "2FA prompt detected - waiting for user confirmation", false)
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
logging.Error("Failed to handle 2FA prompt: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
done <- err
return
}
tfaRequestCreated = true
}
}
if tfaRequestCreated && e.isSteamContinuing(lineData.text) {
logging.Info("Steam CMD appears to have continued after 2FA confirmation")
e.callbacks.OnOutput(e.serverID, "Steam CMD continued after 2FA confirmation", false)
e.autoCompletePendingRequests(serverID)
}
case <-time.After(15 * time.Second):
if steamConsoleStarted && !tfaRequestCreated {
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
e.callbacks.OnOutput(e.serverID, "Waiting for Steam Guard 2FA confirmation...", false)
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true)
done <- err
return
}
tfaRequestCreated = true
}
}
}
}
type outputLine struct {
text string
isError bool
}

View File

@@ -0,0 +1,19 @@
package command
import "github.com/google/uuid"
type OutputCallback func(serverID uuid.UUID, output string, isError bool)
type CommandCallback func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string)
type CallbackConfig struct {
OnOutput OutputCallback
OnCommand CommandCallback
}
func DefaultCallbackConfig() *CallbackConfig {
return &CallbackConfig{
OnOutput: func(uuid.UUID, string, bool) {},
OnCommand: func(uuid.UUID, string, []string, bool, bool, string) {},
}
}

View File

@@ -8,17 +8,12 @@ import (
"strings"
)
// CommandExecutor provides a base structure for executing commands
type CommandExecutor struct {
// Base executable path
ExePath string
// Working directory for commands
WorkDir string
// Whether to capture and log output
ExePath string
WorkDir string
LogOutput bool
}
// CommandBuilder helps build command arguments
type CommandBuilder struct {
args []string
}
@@ -48,10 +43,9 @@ func (b *CommandBuilder) Build() []string {
return b.args
}
// Execute runs a command with the given arguments
func (e *CommandExecutor) Execute(args ...string) error {
cmd := exec.Command(e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
@@ -65,15 +59,13 @@ func (e *CommandExecutor) Execute(args ...string) error {
return cmd.Run()
}
// ExecuteWithBuilder runs a command using a CommandBuilder
func (e *CommandExecutor) ExecuteWithBuilder(builder *CommandBuilder) error {
return e.Execute(builder.Build()...)
}
// ExecuteWithOutput runs a command and returns its output
func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) {
cmd := exec.Command(e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
@@ -83,10 +75,9 @@ func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) {
return string(output), err
}
// ExecuteWithEnv runs a command with custom environment variables
func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error {
cmd := exec.Command(e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
@@ -100,4 +91,4 @@ func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error {
logging.Info("Executing command: %s %s", e.ExePath, strings.Join(args, " "))
return cmd.Run()
}
}

View File

@@ -9,14 +9,12 @@ import (
"io"
"os"
"os/exec"
"reflect"
"strings"
"time"
"github.com/google/uuid"
)
// InteractiveCommandExecutor extends CommandExecutor to handle interactive commands
type InteractiveCommandExecutor struct {
*CommandExecutor
tfaManager *model.Steam2FAManager
@@ -29,7 +27,6 @@ func NewInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *mo
}
}
// ExecuteInteractive runs a command that may require 2FA input
func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
cmd := exec.CommandContext(ctx, e.ExePath, args...)
@@ -37,7 +34,6 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser
cmd.Dir = e.WorkDir
}
// Create pipes for stdin, stdout, and stderr
stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("failed to create stdin pipe: %v", err)
@@ -58,7 +54,6 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser
logging.Info("Executing interactive command: %s %s", e.ExePath, strings.Join(args, " "))
// Enable debug mode if environment variable is set
debugMode := os.Getenv("STEAMCMD_DEBUG") == "true"
if debugMode {
logging.Info("STEAMCMD_DEBUG mode enabled - will log all output and create proactive 2FA requests")
@@ -68,19 +63,15 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser
return fmt.Errorf("failed to start command: %v", err)
}
// Create channels for output monitoring
outputDone := make(chan error, 1)
cmdDone := make(chan error, 1)
// Monitor stdout and stderr for 2FA prompts
go e.monitorOutput(ctx, stdout, stderr, serverID, outputDone)
// Wait for the command to finish in a separate goroutine
go func() {
cmdDone <- cmd.Wait()
}()
// Wait for both command and output monitoring to complete
var cmdErr, outputErr error
completedCount := 0
@@ -112,18 +103,15 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
}
}()
// Create scanners for both outputs
stdoutScanner := bufio.NewScanner(stdout)
stderrScanner := bufio.NewScanner(stderr)
outputChan := make(chan string, 100) // Buffered channel to prevent blocking
outputChan := make(chan string, 100)
readersDone := make(chan struct{}, 2)
// Track Steam Console startup for this specific execution
steamConsoleStarted := false
tfaRequestCreated := false
// Read from stdout
go func() {
defer func() { readersDone <- struct{}{} }()
for stdoutScanner.Scan() {
@@ -131,7 +119,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
if e.LogOutput {
logging.Info("STDOUT: %s", line)
}
// Always log Steam CMD output for debugging 2FA issues
if strings.Contains(strings.ToLower(line), "steam") {
logging.Info("STEAM_DEBUG: %s", line)
}
@@ -146,7 +133,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
}
}()
// Read from stderr
go func() {
defer func() { readersDone <- struct{}{} }()
for stderrScanner.Scan() {
@@ -154,7 +140,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
if e.LogOutput {
logging.Info("STDERR: %s", line)
}
// Always log Steam CMD errors for debugging 2FA issues
if strings.Contains(strings.ToLower(line), "steam") {
logging.Info("STEAM_DEBUG_ERR: %s", line)
}
@@ -169,7 +154,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
}
}()
// Monitor for completion and 2FA prompts
readersFinished := 0
for {
select {
@@ -179,9 +163,7 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
case <-readersDone:
readersFinished++
if readersFinished == 2 {
// Both readers are done, close output channel and finish monitoring
close(outputChan)
// Drain any remaining output
for line := range outputChan {
if e.is2FAPrompt(line) {
if err := e.handle2FAPrompt(ctx, line, serverID); err != nil {
@@ -195,18 +177,15 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
}
case line, ok := <-outputChan:
if !ok {
// Channel closed, we're done
return
}
// Check for Steam Console startup
lowerLine := strings.ToLower(line)
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
steamConsoleStarted = true
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
}
// Check if this line indicates a 2FA prompt
if e.is2FAPrompt(line) {
if !tfaRequestCreated {
if err := e.handle2FAPrompt(ctx, line, serverID); err != nil {
@@ -218,15 +197,11 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
}
}
// Check if Steam CMD continued after 2FA (auto-completion)
if tfaRequestCreated && e.isSteamContinuing(line) {
logging.Info("Steam CMD appears to have continued after 2FA confirmation - auto-completing 2FA request")
// Auto-complete any pending 2FA requests for this server
e.autoCompletePendingRequests(serverID)
}
case <-time.After(15 * time.Second):
// If Steam Console has started and we haven't seen output for 15 seconds,
// it's very likely waiting for 2FA confirmation
if steamConsoleStarted && !tfaRequestCreated {
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
@@ -243,7 +218,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
}
func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool {
// Common SteamCMD 2FA prompts - updated with more comprehensive patterns
twoFAKeywords := []string{
"please enter your steam guard code",
"steam guard",
@@ -272,7 +246,6 @@ func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool {
}
}
// Also check for patterns that might indicate Steam is waiting for input
waitingPatterns := []string{
"waiting for",
"please enter",
@@ -330,12 +303,9 @@ func (e *InteractiveCommandExecutor) autoCompletePendingRequests(serverID *uuid.
func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLine string, serverID *uuid.UUID) error {
logging.Info("2FA prompt detected: %s", promptLine)
// Create a 2FA request
request := e.tfaManager.CreateRequest(promptLine, serverID)
logging.Info("Created 2FA request with ID: %s", request.ID)
// Wait for user to complete the 2FA process
// Use a reasonable timeout (e.g., 5 minutes)
timeout := 5 * time.Minute
success, err := e.tfaManager.WaitForCompletion(request.ID, timeout)
@@ -352,271 +322,3 @@ func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLi
logging.Info("2FA completed successfully")
return nil
}
// WebSocketInteractiveCommandExecutor extends InteractiveCommandExecutor to stream output via WebSocket
type WebSocketInteractiveCommandExecutor struct {
*InteractiveCommandExecutor
wsService interface{} // Using interface{} to avoid circular import
serverID uuid.UUID
}
// NewInteractiveCommandExecutorWithWebSocket creates a new WebSocket-enabled interactive command executor
func NewInteractiveCommandExecutorWithWebSocket(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, wsService interface{}, serverID uuid.UUID) *WebSocketInteractiveCommandExecutor {
return &WebSocketInteractiveCommandExecutor{
InteractiveCommandExecutor: &InteractiveCommandExecutor{
CommandExecutor: baseExecutor,
tfaManager: tfaManager,
},
wsService: wsService,
serverID: serverID,
}
}
// ExecuteInteractive runs a command with WebSocket output streaming
func (e *WebSocketInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
cmd := exec.CommandContext(ctx, e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
// Create pipes for stdin, stdout, and stderr
stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("failed to create stdin pipe: %v", err)
}
defer stdin.Close()
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %v", err)
}
defer stdout.Close()
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %v", err)
}
defer stderr.Close()
logging.Info("Executing interactive command with WebSocket streaming: %s %s", e.ExePath, strings.Join(args, " "))
// Broadcast command start via WebSocket
e.broadcastSteamOutput(fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false)
if err := cmd.Start(); err != nil {
e.broadcastSteamOutput(fmt.Sprintf("Failed to start command: %v", err), true)
return fmt.Errorf("failed to start command: %v", err)
}
// Create channels for output monitoring
outputDone := make(chan error, 1)
cmdDone := make(chan error, 1)
// Monitor stdout and stderr for 2FA prompts with WebSocket streaming
go e.monitorOutputWithWebSocket(ctx, stdout, stderr, serverID, outputDone)
// Wait for the command to finish in a separate goroutine
go func() {
cmdDone <- cmd.Wait()
}()
// Wait for both command and output monitoring to complete
var cmdErr, outputErr error
completedCount := 0
for completedCount < 2 {
select {
case cmdErr = <-cmdDone:
completedCount++
logging.Info("Command execution completed")
e.broadcastSteamOutput("Command execution completed", false)
case outputErr = <-outputDone:
completedCount++
logging.Info("Output monitoring completed")
case <-ctx.Done():
e.broadcastSteamOutput("Command execution cancelled", true)
return ctx.Err()
}
}
if outputErr != nil {
logging.Warn("Output monitoring error: %v", outputErr)
e.broadcastSteamOutput(fmt.Sprintf("Output monitoring error: %v", outputErr), true)
}
return cmdErr
}
// broadcastSteamOutput sends output to WebSocket using reflection to avoid circular imports
func (e *WebSocketInteractiveCommandExecutor) broadcastSteamOutput(output string, isError bool) {
if e.wsService == nil {
return
}
// Use reflection to call BroadcastSteamOutput method
wsServiceVal := reflect.ValueOf(e.wsService)
method := wsServiceVal.MethodByName("BroadcastSteamOutput")
if !method.IsValid() {
logging.Warn("BroadcastSteamOutput method not found on WebSocket service")
return
}
// Call the method with parameters: serverID, output, isError
args := []reflect.Value{
reflect.ValueOf(e.serverID),
reflect.ValueOf(output),
reflect.ValueOf(isError),
}
method.Call(args)
}
// monitorOutputWithWebSocket monitors command output and streams it via WebSocket
func (e *WebSocketInteractiveCommandExecutor) monitorOutputWithWebSocket(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) {
defer func() {
select {
case done <- nil:
default:
}
}()
// Create scanners for both outputs
stdoutScanner := bufio.NewScanner(stdout)
stderrScanner := bufio.NewScanner(stderr)
outputChan := make(chan outputLine, 100) // Buffered channel to prevent blocking
readersDone := make(chan struct{}, 2)
// Track Steam Console startup for this specific execution
steamConsoleStarted := false
tfaRequestCreated := false
// Read from stdout
go func() {
defer func() { readersDone <- struct{}{} }()
for stdoutScanner.Scan() {
line := stdoutScanner.Text()
if e.LogOutput {
logging.Info("STDOUT: %s", line)
}
// Stream output via WebSocket
e.broadcastSteamOutput(line, false)
select {
case outputChan <- outputLine{text: line, isError: false}:
case <-ctx.Done():
return
}
}
if err := stdoutScanner.Err(); err != nil {
logging.Warn("Stdout scanner error: %v", err)
e.broadcastSteamOutput(fmt.Sprintf("Stdout scanner error: %v", err), true)
}
}()
// Read from stderr
go func() {
defer func() { readersDone <- struct{}{} }()
for stderrScanner.Scan() {
line := stderrScanner.Text()
if e.LogOutput {
logging.Info("STDERR: %s", line)
}
// Stream error output via WebSocket
e.broadcastSteamOutput(line, true)
select {
case outputChan <- outputLine{text: line, isError: true}:
case <-ctx.Done():
return
}
}
if err := stderrScanner.Err(); err != nil {
logging.Warn("Stderr scanner error: %v", err)
e.broadcastSteamOutput(fmt.Sprintf("Stderr scanner error: %v", err), true)
}
}()
// Monitor for completion and 2FA prompts
readersFinished := 0
for {
select {
case <-ctx.Done():
done <- ctx.Err()
return
case <-readersDone:
readersFinished++
if readersFinished == 2 {
// Both readers are done, close output channel and finish monitoring
close(outputChan)
// Drain any remaining output
for lineData := range outputChan {
if e.is2FAPrompt(lineData.text) {
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
logging.Error("Failed to handle 2FA prompt: %v", err)
e.broadcastSteamOutput(fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
done <- err
return
}
}
}
return
}
case lineData, ok := <-outputChan:
if !ok {
// Channel closed, we're done
return
}
// Check for Steam Console startup
lowerLine := strings.ToLower(lineData.text)
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
steamConsoleStarted = true
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
e.broadcastSteamOutput("Steam Console Client startup detected", false)
}
// Check if this line indicates a 2FA prompt
if e.is2FAPrompt(lineData.text) {
if !tfaRequestCreated {
e.broadcastSteamOutput("2FA prompt detected - waiting for user confirmation", false)
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
logging.Error("Failed to handle 2FA prompt: %v", err)
e.broadcastSteamOutput(fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
done <- err
return
}
tfaRequestCreated = true
}
}
// Check if Steam CMD continued after 2FA (auto-completion)
if tfaRequestCreated && e.isSteamContinuing(lineData.text) {
logging.Info("Steam CMD appears to have continued after 2FA confirmation")
e.broadcastSteamOutput("Steam CMD continued after 2FA confirmation", false)
// Auto-complete any pending 2FA requests for this server
e.autoCompletePendingRequests(serverID)
}
case <-time.After(15 * time.Second):
// If Steam Console has started and we haven't seen output for 15 seconds,
// it's very likely waiting for 2FA confirmation
if steamConsoleStarted && !tfaRequestCreated {
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
e.broadcastSteamOutput("Waiting for Steam Guard 2FA confirmation...", false)
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err)
e.broadcastSteamOutput(fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true)
done <- err
return
}
tfaRequestCreated = true
}
}
}
}
// outputLine represents a line of output with error status
type outputLine struct {
text string
isError bool
}

View File

@@ -79,12 +79,6 @@ func IndentJson(body []byte) ([]byte, error) {
return unmarshaledBody.Bytes(), nil
}
// ParseQueryFilter parses query parameters into a filter struct using reflection.
// It supports various field types and uses struct tags to determine parsing behavior.
// Supported tags:
// - `query:"field_name"` - specifies the query parameter name
// - `param:"param_name"` - specifies the path parameter name
// - `time_format:"format"` - specifies the time format for parsing dates (default: RFC3339)
func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
val := reflect.ValueOf(filter)
if val.Kind() != reflect.Ptr || val.IsNil() {
@@ -94,14 +88,12 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
elem := val.Elem()
typ := elem.Type()
// Process all fields including embedded structs
var processFields func(reflect.Value, reflect.Type) error
processFields = func(val reflect.Value, typ reflect.Type) error {
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
// Handle embedded structs recursively
if fieldType.Anonymous {
if err := processFields(field, fieldType.Type); err != nil {
return err
@@ -109,12 +101,10 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
continue
}
// Skip if field cannot be set
if !field.CanSet() {
continue
}
// Check for param tag first (path parameters)
if paramName := fieldType.Tag.Get("param"); paramName != "" {
if err := parsePathParam(c, field, paramName); err != nil {
return fmt.Errorf("error parsing path parameter %s: %v", paramName, err)
@@ -122,15 +112,14 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
continue
}
// Then check for query tag
queryName := fieldType.Tag.Get("query")
if queryName == "" {
queryName = ToSnakeCase(fieldType.Name) // Default to snake_case of field name
queryName = ToSnakeCase(fieldType.Name)
}
queryVal := c.Query(queryName)
if queryVal == "" {
continue // Skip empty values
continue
}
if err := parseValue(field, queryVal, fieldType.Tag); err != nil {

View File

@@ -18,7 +18,6 @@ var (
func Init() {
godotenv.Load()
// Fail fast if critical environment variables are missing
Secret = getEnvRequired("APP_SECRET")
SecretCode = getEnvRequired("APP_SECRET_CODE")
EncryptionKey = getEnvRequired("ENCRYPTION_KEY")
@@ -29,7 +28,6 @@ func Init() {
}
}
// getEnv retrieves an environment variable or returns a fallback value.
func getEnv(key, fallback string) string {
if value, exists := os.LookupEnv(key); exists {
return value
@@ -38,12 +36,10 @@ func getEnv(key, fallback string) string {
return fallback
}
// getEnvRequired retrieves an environment variable and fails if it's not set.
// This should be used for critical configuration that must not have defaults.
func getEnvRequired(key string) string {
if value, exists := os.LookupEnv(key); exists && value != "" {
return value
}
log.Fatalf("Required environment variable %s is not set or is empty", key)
return "" // This line will never be reached due to log.Fatalf
return ""
}

View File

@@ -33,7 +33,6 @@ func Start(di *dig.Container) {
func Migrate(db *gorm.DB) {
logging.Info("Migrating database")
// Run GORM AutoMigrate for all models
err := db.AutoMigrate(
&model.ServiceControlModel{},
&model.Config{},
@@ -52,7 +51,6 @@ func Migrate(db *gorm.DB) {
if err != nil {
logging.Error("GORM AutoMigrate failed: %v", err)
// Don't panic, just log the error as custom migrations may have handled this
}
db.FirstOrCreate(&model.ServiceControlModel{ServiceControl: "Works"})
@@ -63,10 +61,8 @@ func Migrate(db *gorm.DB) {
func runMigrations(db *gorm.DB) {
logging.Info("Running custom database migrations...")
// Migration 001: Password security upgrade
if err := migrations.RunPasswordSecurityMigration(db); err != nil {
logging.Error("Failed to run password security migration: %v", err)
// Continue - this migration might not be needed for all setups
}
logging.Info("Custom database migrations completed")
@@ -132,7 +128,6 @@ func seedCarModels(db *gorm.DB) error {
carModels := []model.CarModel{
{Value: 0, CarModel: "Porsche 991 GT3 R"},
{Value: 1, CarModel: "Mercedes-AMG GT3"},
// ... Add all car models from your list
}
for _, cm := range carModels {

View File

@@ -6,12 +6,10 @@ import (
)
const (
// Default paths for when environment variables are not set
DefaultSteamCMDPath = "c:\\steamcmd\\steamcmd.exe"
DefaultNSSMPath = ".\\nssm.exe"
)
// GetSteamCMDPath returns the SteamCMD executable path from environment variable or default
func GetSteamCMDPath() string {
if path := os.Getenv("STEAMCMD_PATH"); path != "" {
return path
@@ -19,13 +17,11 @@ func GetSteamCMDPath() string {
return DefaultSteamCMDPath
}
// GetSteamCMDDirPath returns the directory containing SteamCMD executable
func GetSteamCMDDirPath() string {
steamCMDPath := GetSteamCMDPath()
return filepath.Dir(steamCMDPath)
}
// GetNSSMPath returns the NSSM executable path from environment variable or default
func GetNSSMPath() string {
if path := os.Getenv("NSSM_PATH"); path != "" {
return path
@@ -33,17 +29,14 @@ func GetNSSMPath() string {
return DefaultNSSMPath
}
// ValidatePaths checks if the configured paths exist (optional validation)
func ValidatePaths() map[string]error {
errors := make(map[string]error)
// Check SteamCMD path
steamCMDPath := GetSteamCMDPath()
if _, err := os.Stat(steamCMDPath); os.IsNotExist(err) {
errors["STEAMCMD_PATH"] = err
}
// Check NSSM path
nssmPath := GetNSSMPath()
if _, err := os.Stat(nssmPath); os.IsNotExist(err) {
errors["NSSM_PATH"] = err

View File

@@ -9,45 +9,37 @@ import (
"github.com/gofiber/fiber/v2"
)
// ControllerErrorHandler provides centralized error handling for controllers
type ControllerErrorHandler struct {
errorLogger *logging.ErrorLogger
}
// NewControllerErrorHandler creates a new controller error handler instance
func NewControllerErrorHandler() *ControllerErrorHandler {
return &ControllerErrorHandler{
errorLogger: logging.GetErrorLogger(),
}
}
// ErrorResponse represents a standardized error response
type ErrorResponse struct {
Error string `json:"error"`
Code int `json:"code,omitempty"`
Details map[string]string `json:"details,omitempty"`
}
// HandleError handles controller errors with logging and standardized responses
func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
if err == nil {
return nil
}
// Get caller information for logging
_, file, line, _ := runtime.Caller(1)
file = strings.TrimPrefix(file, "acc-server-manager/")
// Build context string
contextStr := ""
if len(context) > 0 {
contextStr = fmt.Sprintf("[%s] ", strings.Join(context, "|"))
}
// Clean error message (remove null bytes)
cleanErrorMsg := strings.ReplaceAll(err.Error(), "\x00", "")
// Log the error with context
ceh.errorLogger.LogWithContext(
fmt.Sprintf("CONTROLLER_ERROR [%s:%d]", file, line),
"%s%s",
@@ -55,31 +47,26 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
cleanErrorMsg,
)
// Create standardized error response
errorResponse := ErrorResponse{
Error: cleanErrorMsg,
Code: statusCode,
}
// Add request details if available
if c != nil {
if errorResponse.Details == nil {
errorResponse.Details = make(map[string]string)
}
// Safely extract request details
func() {
defer func() {
if r := recover(); r != nil {
// If any of these panic, just skip adding the details
return
}
}()
errorResponse.Details["method"] = c.Method()
errorResponse.Details["path"] = c.Path()
// Safely get IP address
if ip := c.IP(); ip != "" {
errorResponse.Details["ip"] = ip
} else {
@@ -88,14 +75,11 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
}()
}
// Return appropriate response based on status code
if c == nil {
// If context is nil, we can't return a response
return fmt.Errorf("cannot return HTTP response: context is nil")
}
if statusCode >= 500 {
// For server errors, don't expose internal details
return c.Status(statusCode).JSON(ErrorResponse{
Error: "Internal server error",
Code: statusCode,
@@ -105,52 +89,42 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
return c.Status(statusCode).JSON(errorResponse)
}
// HandleValidationError handles validation errors specifically
func (ceh *ControllerErrorHandler) HandleValidationError(c *fiber.Ctx, err error, field string) error {
return ceh.HandleError(c, err, fiber.StatusBadRequest, "VALIDATION", field)
}
// HandleDatabaseError handles database-related errors
func (ceh *ControllerErrorHandler) HandleDatabaseError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "DATABASE")
}
// HandleAuthError handles authentication/authorization errors
func (ceh *ControllerErrorHandler) HandleAuthError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusUnauthorized, "AUTH")
}
// HandleNotFoundError handles resource not found errors
func (ceh *ControllerErrorHandler) HandleNotFoundError(c *fiber.Ctx, resource string) error {
err := fmt.Errorf("%s not found", resource)
return ceh.HandleError(c, err, fiber.StatusNotFound, "NOT_FOUND")
}
// HandleBusinessLogicError handles business logic errors
func (ceh *ControllerErrorHandler) HandleBusinessLogicError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusBadRequest, "BUSINESS_LOGIC")
}
// HandleServiceError handles service layer errors
func (ceh *ControllerErrorHandler) HandleServiceError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "SERVICE")
}
// HandleParsingError handles request parsing errors
func (ceh *ControllerErrorHandler) HandleParsingError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusBadRequest, "PARSING")
}
// HandleUUIDError handles UUID parsing errors
func (ceh *ControllerErrorHandler) HandleUUIDError(c *fiber.Ctx, field string) error {
err := fmt.Errorf("invalid %s format", field)
return ceh.HandleError(c, err, fiber.StatusBadRequest, "UUID_VALIDATION", field)
}
// Global controller error handler instance
var globalErrorHandler *ControllerErrorHandler
// GetControllerErrorHandler returns the global controller error handler instance
func GetControllerErrorHandler() *ControllerErrorHandler {
if globalErrorHandler == nil {
globalErrorHandler = NewControllerErrorHandler()
@@ -158,49 +132,38 @@ func GetControllerErrorHandler() *ControllerErrorHandler {
return globalErrorHandler
}
// Convenience functions using the global error handler
// HandleError handles controller errors using the global error handler
func HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
return GetControllerErrorHandler().HandleError(c, err, statusCode, context...)
}
// HandleValidationError handles validation errors using the global error handler
func HandleValidationError(c *fiber.Ctx, err error, field string) error {
return GetControllerErrorHandler().HandleValidationError(c, err, field)
}
// HandleDatabaseError handles database errors using the global error handler
func HandleDatabaseError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleDatabaseError(c, err)
}
// HandleAuthError handles auth errors using the global error handler
func HandleAuthError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleAuthError(c, err)
}
// HandleNotFoundError handles not found errors using the global error handler
func HandleNotFoundError(c *fiber.Ctx, resource string) error {
return GetControllerErrorHandler().HandleNotFoundError(c, resource)
}
// HandleBusinessLogicError handles business logic errors using the global error handler
func HandleBusinessLogicError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleBusinessLogicError(c, err)
}
// HandleServiceError handles service errors using the global error handler
func HandleServiceError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleServiceError(c, err)
}
// HandleParsingError handles parsing errors using the global error handler
func HandleParsingError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleParsingError(c, err)
}
// HandleUUIDError handles UUID errors using the global error handler
func HandleUUIDError(c *fiber.Ctx, field string) error {
return GetControllerErrorHandler().HandleUUIDError(c, field)
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/golang-jwt/jwt/v4"
)
// Claims represents the JWT claims.
type Claims struct {
UserID string `json:"user_id"`
IsOpenToken bool `json:"is_open_token"`
@@ -27,7 +26,6 @@ type OpenJWTHandler struct {
*JWTHandler
}
// NewJWTHandler creates a new JWTHandler instance with the provided secret key.
func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler {
jwtHandler := NewJWTHandler(jwtSecret)
jwtHandler.IsOpenToken = true
@@ -36,7 +34,6 @@ func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler {
}
}
// NewJWTHandler creates a new JWTHandler instance with the provided secret key.
func NewJWTHandler(jwtSecret string) *JWTHandler {
if jwtSecret == "" {
errors.SafeFatal("JWT_SECRET environment variable is required and cannot be empty")
@@ -44,14 +41,12 @@ func NewJWTHandler(jwtSecret string) *JWTHandler {
var secretKey []byte
// Decode base64 secret if it looks like base64, otherwise use as-is
if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 {
secretKey = decoded
} else {
secretKey = []byte(jwtSecret)
}
// Ensure minimum key length for security
if len(secretKey) < 32 {
errors.SafeFatal("JWT_SECRET must be at least 32 bytes long for security")
}
@@ -60,8 +55,6 @@ func NewJWTHandler(jwtSecret string) *JWTHandler {
}
}
// GenerateSecretKey generates a cryptographically secure random key for JWT signing
// This is a utility function for generating new secrets, not used in normal operation
func (jh *JWTHandler) GenerateSecretKey() string {
key := make([]byte, 64) // 512 bits
if _, err := rand.Read(key); err != nil {
@@ -70,7 +63,6 @@ func (jh *JWTHandler) GenerateSecretKey() string {
return base64.StdEncoding.EncodeToString(key)
}
// GenerateToken generates a new JWT for a given user.
func (jh *JWTHandler) GenerateToken(userId string) (string, error) {
expirationTime := time.Now().Add(24 * time.Hour)
claims := &Claims{
@@ -99,7 +91,6 @@ func (jh *JWTHandler) GenerateTokenWithExpiry(user *model.User, expiry time.Time
return token.SignedString(jh.SecretKey)
}
// ValidateToken validates a JWT and returns the claims if the token is valid.
func (jh *JWTHandler) ValidateToken(tokenString string) (*Claims, error) {
claims := &Claims{}

View File

@@ -15,7 +15,6 @@ var (
timeFormat = "2006-01-02 15:04:05.000"
)
// BaseLogger provides the core logging functionality
type BaseLogger struct {
file *os.File
logger *log.Logger
@@ -23,7 +22,6 @@ type BaseLogger struct {
initialized bool
}
// LogLevel represents different logging levels
type LogLevel string
const (
@@ -34,28 +32,23 @@ const (
LogLevelPanic LogLevel = "PANIC"
)
// Initialize creates a new base logger instance
func InitializeBase(tp string) (*BaseLogger, error) {
return newBaseLogger(tp)
}
func newBaseLogger(tp string) (*BaseLogger, error) {
// Ensure logs directory exists
if err := os.MkdirAll("logs", 0755); err != nil {
return nil, fmt.Errorf("failed to create logs directory: %v", err)
}
// Open log file with date in name
logPath := filepath.Join("logs", fmt.Sprintf("acc-server-%s-%s.log", time.Now().Format("2006-01-02"), tp))
file, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
if err != nil {
return nil, fmt.Errorf("failed to open log file: %v", err)
}
// Create multi-writer for both file and console
multiWriter := io.MultiWriter(file, os.Stdout)
// Create base logger
logger := &BaseLogger{
file: file,
logger: log.New(multiWriter, "", 0),
@@ -65,13 +58,11 @@ func newBaseLogger(tp string) (*BaseLogger, error) {
return logger, nil
}
// GetBaseLogger creates and returns a new base logger instance
func GetBaseLogger(tp string) *BaseLogger {
baseLogger, _ := InitializeBase(tp)
return baseLogger
}
// Close closes the log file
func (bl *BaseLogger) Close() error {
bl.mu.Lock()
defer bl.mu.Unlock()
@@ -82,7 +73,6 @@ func (bl *BaseLogger) Close() error {
return nil
}
// Log writes a log entry with the specified level
func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
if bl == nil || !bl.initialized {
return
@@ -91,14 +81,11 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
bl.mu.RLock()
defer bl.mu.RUnlock()
// Get caller info (skip 2 frames: this function and the calling Log function)
_, file, line, _ := runtime.Caller(2)
file = filepath.Base(file)
// Format message
msg := fmt.Sprintf(format, v...)
// Format final log line
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
time.Now().Format(timeFormat),
string(level),
@@ -110,7 +97,6 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
bl.logger.Println(logLine)
}
// LogWithCaller writes a log entry with custom caller depth
func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format string, v ...interface{}) {
if bl == nil || !bl.initialized {
return
@@ -119,14 +105,11 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri
bl.mu.RLock()
defer bl.mu.RUnlock()
// Get caller info with custom depth
_, file, line, _ := runtime.Caller(callerDepth)
file = filepath.Base(file)
// Format message
msg := fmt.Sprintf(format, v...)
// Format final log line
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
time.Now().Format(timeFormat),
string(level),
@@ -138,7 +121,6 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri
bl.logger.Println(logLine)
}
// IsInitialized returns whether the base logger is initialized
func (bl *BaseLogger) IsInitialized() bool {
if bl == nil {
return false
@@ -148,19 +130,16 @@ func (bl *BaseLogger) IsInitialized() bool {
return bl.initialized
}
// RecoverAndLog recovers from panics and logs them
func RecoverAndLog() {
baseLogger := GetBaseLogger("panic")
if baseLogger != nil && baseLogger.IsInitialized() {
if r := recover(); r != nil {
// Get stack trace
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
stackTrace := string(buf[:n])
baseLogger.LogWithCaller(LogLevelPanic, 2, "Recovered from panic: %v\nStack Trace:\n%s", r, stackTrace)
// Re-panic to maintain original behavior if needed
panic(r)
}
}

View File

@@ -6,12 +6,10 @@ import (
"sync"
)
// DebugLogger handles debug-level logging
type DebugLogger struct {
base *BaseLogger
}
// NewDebugLogger creates a new debug logger instance
func NewDebugLogger() *DebugLogger {
base, _ := InitializeBase("debug")
return &DebugLogger{
@@ -19,14 +17,12 @@ func NewDebugLogger() *DebugLogger {
}
}
// Log writes a debug-level log entry
func (dl *DebugLogger) Log(format string, v ...interface{}) {
if dl.base != nil {
dl.base.Log(LogLevelDebug, format, v...)
}
}
// LogWithContext writes a debug-level log entry with additional context
func (dl *DebugLogger) LogWithContext(context string, format string, v ...interface{}) {
if dl.base != nil {
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
@@ -34,7 +30,6 @@ func (dl *DebugLogger) LogWithContext(context string, format string, v ...interf
}
}
// LogFunction logs function entry and exit for debugging
func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) {
if dl.base != nil {
if len(args) > 0 {
@@ -45,21 +40,18 @@ func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) {
}
}
// LogVariable logs variable values for debugging
func (dl *DebugLogger) LogVariable(varName string, value interface{}) {
if dl.base != nil {
dl.base.Log(LogLevelDebug, "VARIABLE [%s]: %+v", varName, value)
}
}
// LogState logs application state information
func (dl *DebugLogger) LogState(component string, state interface{}) {
if dl.base != nil {
dl.base.Log(LogLevelDebug, "STATE [%s]: %+v", component, state)
}
}
// LogSQL logs SQL queries for debugging
func (dl *DebugLogger) LogSQL(query string, args ...interface{}) {
if dl.base != nil {
if len(args) > 0 {
@@ -70,7 +62,6 @@ func (dl *DebugLogger) LogSQL(query string, args ...interface{}) {
}
}
// LogMemory logs memory usage information
func (dl *DebugLogger) LogMemory() {
if dl.base != nil {
var m runtime.MemStats
@@ -80,32 +71,27 @@ func (dl *DebugLogger) LogMemory() {
}
}
// LogGoroutines logs current number of goroutines
func (dl *DebugLogger) LogGoroutines() {
if dl.base != nil {
dl.base.Log(LogLevelDebug, "GOROUTINES: %d active", runtime.NumGoroutine())
}
}
// LogTiming logs timing information for performance debugging
func (dl *DebugLogger) LogTiming(operation string, duration interface{}) {
if dl.base != nil {
dl.base.Log(LogLevelDebug, "TIMING [%s]: %v", operation, duration)
}
}
// Helper function to convert bytes to kilobytes
func bToKb(b uint64) uint64 {
return b / 1024
}
// Global debug logger instance
var (
debugLogger *DebugLogger
debugOnce sync.Once
)
// GetDebugLogger returns the global debug logger instance
func GetDebugLogger() *DebugLogger {
debugOnce.Do(func() {
debugLogger = NewDebugLogger()
@@ -113,47 +99,38 @@ func GetDebugLogger() *DebugLogger {
return debugLogger
}
// Debug logs a debug-level message using the global debug logger
func Debug(format string, v ...interface{}) {
GetDebugLogger().Log(format, v...)
}
// DebugWithContext logs a debug-level message with context using the global debug logger
func DebugWithContext(context string, format string, v ...interface{}) {
GetDebugLogger().LogWithContext(context, format, v...)
}
// DebugFunction logs function entry and exit using the global debug logger
func DebugFunction(functionName string, args ...interface{}) {
GetDebugLogger().LogFunction(functionName, args...)
}
// DebugVariable logs variable values using the global debug logger
func DebugVariable(varName string, value interface{}) {
GetDebugLogger().LogVariable(varName, value)
}
// DebugState logs application state information using the global debug logger
func DebugState(component string, state interface{}) {
GetDebugLogger().LogState(component, state)
}
// DebugSQL logs SQL queries using the global debug logger
func DebugSQL(query string, args ...interface{}) {
GetDebugLogger().LogSQL(query, args...)
}
// DebugMemory logs memory usage information using the global debug logger
func DebugMemory() {
GetDebugLogger().LogMemory()
}
// DebugGoroutines logs current number of goroutines using the global debug logger
func DebugGoroutines() {
GetDebugLogger().LogGoroutines()
}
// DebugTiming logs timing information using the global debug logger
func DebugTiming(operation string, duration interface{}) {
GetDebugLogger().LogTiming(operation, duration)
}

View File

@@ -6,12 +6,10 @@ import (
"sync"
)
// ErrorLogger handles error-level logging
type ErrorLogger struct {
base *BaseLogger
}
// NewErrorLogger creates a new error logger instance
func NewErrorLogger() *ErrorLogger {
base, _ := InitializeBase("error")
return &ErrorLogger{
@@ -19,14 +17,12 @@ func NewErrorLogger() *ErrorLogger {
}
}
// Log writes an error-level log entry
func (el *ErrorLogger) Log(format string, v ...interface{}) {
if el.base != nil {
el.base.Log(LogLevelError, format, v...)
}
}
// LogWithContext writes an error-level log entry with additional context
func (el *ErrorLogger) LogWithContext(context string, format string, v ...interface{}) {
if el.base != nil {
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
@@ -34,7 +30,6 @@ func (el *ErrorLogger) LogWithContext(context string, format string, v ...interf
}
}
// LogError logs an error object with optional message
func (el *ErrorLogger) LogError(err error, message ...string) {
if el.base != nil && err != nil {
if len(message) > 0 {
@@ -45,7 +40,6 @@ func (el *ErrorLogger) LogError(err error, message ...string) {
}
}
// LogWithStackTrace logs an error with stack trace
func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) {
if el.base != nil {
// Get stack trace
@@ -58,7 +52,6 @@ func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) {
}
}
// LogFatal logs a fatal error and exits the program
func (el *ErrorLogger) LogFatal(format string, v ...interface{}) {
if el.base != nil {
el.base.Log(LogLevelError, "[FATAL] "+format, v...)
@@ -66,13 +59,11 @@ func (el *ErrorLogger) LogFatal(format string, v ...interface{}) {
}
}
// Global error logger instance
var (
errorLogger *ErrorLogger
errorOnce sync.Once
)
// GetErrorLogger returns the global error logger instance
func GetErrorLogger() *ErrorLogger {
errorOnce.Do(func() {
errorLogger = NewErrorLogger()
@@ -80,27 +71,22 @@ func GetErrorLogger() *ErrorLogger {
return errorLogger
}
// Error logs an error-level message using the global error logger
func Error(format string, v ...interface{}) {
GetErrorLogger().Log(format, v...)
}
// ErrorWithContext logs an error-level message with context using the global error logger
func ErrorWithContext(context string, format string, v ...interface{}) {
GetErrorLogger().LogWithContext(context, format, v...)
}
// LogError logs an error object using the global error logger
func LogError(err error, message ...string) {
GetErrorLogger().LogError(err, message...)
}
// ErrorWithStackTrace logs an error with stack trace using the global error logger
func ErrorWithStackTrace(format string, v ...interface{}) {
GetErrorLogger().LogWithStackTrace(format, v...)
}
// Fatal logs a fatal error and exits the program using the global error logger
func Fatal(format string, v ...interface{}) {
GetErrorLogger().LogFatal(format, v...)
}

View File

@@ -5,12 +5,10 @@ import (
"sync"
)
// InfoLogger handles info-level logging
type InfoLogger struct {
base *BaseLogger
}
// NewInfoLogger creates a new info logger instance
func NewInfoLogger() *InfoLogger {
base, _ := InitializeBase("info")
return &InfoLogger{
@@ -18,14 +16,12 @@ func NewInfoLogger() *InfoLogger {
}
}
// Log writes an info-level log entry
func (il *InfoLogger) Log(format string, v ...interface{}) {
if il.base != nil {
il.base.Log(LogLevelInfo, format, v...)
}
}
// LogWithContext writes an info-level log entry with additional context
func (il *InfoLogger) LogWithContext(context string, format string, v ...interface{}) {
if il.base != nil {
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
@@ -33,55 +29,47 @@ func (il *InfoLogger) LogWithContext(context string, format string, v ...interfa
}
}
// LogStartup logs application startup information
func (il *InfoLogger) LogStartup(component string, message string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "STARTUP [%s]: %s", component, message)
}
}
// LogShutdown logs application shutdown information
func (il *InfoLogger) LogShutdown(component string, message string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "SHUTDOWN [%s]: %s", component, message)
}
}
// LogOperation logs general operation information
func (il *InfoLogger) LogOperation(operation string, details string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "OPERATION [%s]: %s", operation, details)
}
}
// LogStatus logs status changes or updates
func (il *InfoLogger) LogStatus(component string, status string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "STATUS [%s]: %s", component, status)
}
}
// LogRequest logs incoming requests
func (il *InfoLogger) LogRequest(method string, path string, userAgent string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "REQUEST [%s %s] User-Agent: %s", method, path, userAgent)
}
}
// LogResponse logs outgoing responses
func (il *InfoLogger) LogResponse(method string, path string, statusCode int, duration string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "RESPONSE [%s %s] Status: %d, Duration: %s", method, path, statusCode, duration)
}
}
// Global info logger instance
var (
infoLogger *InfoLogger
infoOnce sync.Once
)
// GetInfoLogger returns the global info logger instance
func GetInfoLogger() *InfoLogger {
infoOnce.Do(func() {
infoLogger = NewInfoLogger()
@@ -89,42 +77,34 @@ func GetInfoLogger() *InfoLogger {
return infoLogger
}
// Info logs an info-level message using the global info logger
func Info(format string, v ...interface{}) {
GetInfoLogger().Log(format, v...)
}
// InfoWithContext logs an info-level message with context using the global info logger
func InfoWithContext(context string, format string, v ...interface{}) {
GetInfoLogger().LogWithContext(context, format, v...)
}
// InfoStartup logs application startup information using the global info logger
func InfoStartup(component string, message string) {
GetInfoLogger().LogStartup(component, message)
}
// InfoShutdown logs application shutdown information using the global info logger
func InfoShutdown(component string, message string) {
GetInfoLogger().LogShutdown(component, message)
}
// InfoOperation logs general operation information using the global info logger
func InfoOperation(operation string, details string) {
GetInfoLogger().LogOperation(operation, details)
}
// InfoStatus logs status changes or updates using the global info logger
func InfoStatus(component string, status string) {
GetInfoLogger().LogStatus(component, status)
}
// InfoRequest logs incoming requests using the global info logger
func InfoRequest(method string, path string, userAgent string) {
GetInfoLogger().LogRequest(method, path, userAgent)
}
// InfoResponse logs outgoing responses using the global info logger
func InfoResponse(method string, path string, statusCode int, duration string) {
GetInfoLogger().LogResponse(method, path, statusCode, duration)
}

View File

@@ -6,12 +6,10 @@ import (
)
var (
// Legacy logger for backward compatibility
logger *Logger
once sync.Once
)
// Logger maintains backward compatibility with existing code
type Logger struct {
base *BaseLogger
errorLogger *ErrorLogger
@@ -20,8 +18,6 @@ type Logger struct {
debugLogger *DebugLogger
}
// Initialize creates or gets the singleton logger instance
// This maintains backward compatibility with existing code
func Initialize() (*Logger, error) {
var err error
once.Do(func() {
@@ -31,13 +27,11 @@ func Initialize() (*Logger, error) {
}
func newLogger() (*Logger, error) {
// Initialize the base logger
baseLogger, err := InitializeBase("log")
if err != nil {
return nil, err
}
// Create the legacy logger wrapper
logger := &Logger{
base: baseLogger,
errorLogger: GetErrorLogger(),
@@ -49,7 +43,6 @@ func newLogger() (*Logger, error) {
return logger, nil
}
// Close closes the logger
func (l *Logger) Close() error {
if l.base != nil {
return l.base.Close()
@@ -57,7 +50,6 @@ func (l *Logger) Close() error {
return nil
}
// Legacy methods for backward compatibility
func (l *Logger) log(level, format string, v ...interface{}) {
if l.base != nil {
l.base.LogWithCaller(LogLevel(level), 3, format, v...)
@@ -94,13 +86,10 @@ func (l *Logger) Panic(format string) {
}
}
// Global convenience functions for backward compatibility
// These are now implemented in individual logger files to avoid redeclaration
func LegacyInfo(format string, v ...interface{}) {
if logger != nil {
logger.Info(format, v...)
} else {
// Fallback to direct logger if legacy logger not initialized
GetInfoLogger().Log(format, v...)
}
}
@@ -109,7 +98,6 @@ func LegacyError(format string, v ...interface{}) {
if logger != nil {
logger.Error(format, v...)
} else {
// Fallback to direct logger if legacy logger not initialized
GetErrorLogger().Log(format, v...)
}
}
@@ -118,7 +106,6 @@ func LegacyWarn(format string, v ...interface{}) {
if logger != nil {
logger.Warn(format, v...)
} else {
// Fallback to direct logger if legacy logger not initialized
GetWarnLogger().Log(format, v...)
}
}
@@ -127,7 +114,6 @@ func LegacyDebug(format string, v ...interface{}) {
if logger != nil {
logger.Debug(format, v...)
} else {
// Fallback to direct logger if legacy logger not initialized
GetDebugLogger().Log(format, v...)
}
}
@@ -136,55 +122,42 @@ func Panic(format string) {
if logger != nil {
logger.Panic(format)
} else {
// Fallback to direct logger if legacy logger not initialized
GetErrorLogger().LogFatal(format)
}
}
// Enhanced logging convenience functions
// These provide direct access to specialized logging functions
// LogStartup logs application startup information
func LogStartup(component string, message string) {
GetInfoLogger().LogStartup(component, message)
}
// LogShutdown logs application shutdown information
func LogShutdown(component string, message string) {
GetInfoLogger().LogShutdown(component, message)
}
// LogOperation logs general operation information
func LogOperation(operation string, details string) {
GetInfoLogger().LogOperation(operation, details)
}
// LogRequest logs incoming HTTP requests
func LogRequest(method string, path string, userAgent string) {
GetInfoLogger().LogRequest(method, path, userAgent)
}
// LogResponse logs outgoing HTTP responses
func LogResponse(method string, path string, statusCode int, duration string) {
GetInfoLogger().LogResponse(method, path, statusCode, duration)
}
// LogSQL logs SQL queries for debugging
func LogSQL(query string, args ...interface{}) {
GetDebugLogger().LogSQL(query, args...)
}
// LogMemory logs memory usage information
func LogMemory() {
GetDebugLogger().LogMemory()
}
// LogTiming logs timing information for performance debugging
func LogTiming(operation string, duration interface{}) {
GetDebugLogger().LogTiming(operation, duration)
}
// GetLegacyLogger returns the legacy logger instance for backward compatibility
func GetLegacyLogger() *Logger {
if logger == nil {
logger, _ = Initialize()
@@ -192,21 +165,17 @@ func GetLegacyLogger() *Logger {
return logger
}
// InitializeLogging initializes all logging components
func InitializeLogging() error {
// Initialize legacy logger for backward compatibility
_, err := Initialize()
if err != nil {
return fmt.Errorf("failed to initialize legacy logger: %v", err)
}
// Pre-initialize all logger types to ensure separate log files
GetErrorLogger()
GetWarnLogger()
GetInfoLogger()
GetDebugLogger()
// Log successful initialization
Info("Logging system initialized successfully")
return nil

View File

@@ -5,12 +5,10 @@ import (
"sync"
)
// WarnLogger handles warn-level logging
type WarnLogger struct {
base *BaseLogger
}
// NewWarnLogger creates a new warn logger instance
func NewWarnLogger() *WarnLogger {
base, _ := InitializeBase("warn")
return &WarnLogger{
@@ -18,14 +16,12 @@ func NewWarnLogger() *WarnLogger {
}
}
// Log writes a warn-level log entry
func (wl *WarnLogger) Log(format string, v ...interface{}) {
if wl.base != nil {
wl.base.Log(LogLevelWarn, format, v...)
}
}
// LogWithContext writes a warn-level log entry with additional context
func (wl *WarnLogger) LogWithContext(context string, format string, v ...interface{}) {
if wl.base != nil {
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
@@ -33,7 +29,6 @@ func (wl *WarnLogger) LogWithContext(context string, format string, v ...interfa
}
}
// LogDeprecation logs a deprecation warning
func (wl *WarnLogger) LogDeprecation(feature string, alternative string) {
if wl.base != nil {
if alternative != "" {
@@ -44,27 +39,23 @@ func (wl *WarnLogger) LogDeprecation(feature string, alternative string) {
}
}
// LogConfiguration logs configuration-related warnings
func (wl *WarnLogger) LogConfiguration(setting string, message string) {
if wl.base != nil {
wl.base.Log(LogLevelWarn, "CONFIG WARNING [%s]: %s", setting, message)
}
}
// LogPerformance logs performance-related warnings
func (wl *WarnLogger) LogPerformance(operation string, threshold string, actual string) {
if wl.base != nil {
wl.base.Log(LogLevelWarn, "PERFORMANCE WARNING [%s]: exceeded threshold %s, actual: %s", operation, threshold, actual)
}
}
// Global warn logger instance
var (
warnLogger *WarnLogger
warnOnce sync.Once
)
// GetWarnLogger returns the global warn logger instance
func GetWarnLogger() *WarnLogger {
warnOnce.Do(func() {
warnLogger = NewWarnLogger()
@@ -72,27 +63,22 @@ func GetWarnLogger() *WarnLogger {
return warnLogger
}
// Warn logs a warn-level message using the global warn logger
func Warn(format string, v ...interface{}) {
GetWarnLogger().Log(format, v...)
}
// WarnWithContext logs a warn-level message with context using the global warn logger
func WarnWithContext(context string, format string, v ...interface{}) {
GetWarnLogger().LogWithContext(context, format, v...)
}
// WarnDeprecation logs a deprecation warning using the global warn logger
func WarnDeprecation(feature string, alternative string) {
GetWarnLogger().LogDeprecation(feature, alternative)
}
// WarnConfiguration logs configuration-related warnings using the global warn logger
func WarnConfiguration(setting string, message string) {
GetWarnLogger().LogConfiguration(setting, message)
}
// WarnPerformance logs performance-related warnings using the global warn logger
func WarnPerformance(operation string, threshold string, actual string) {
GetWarnLogger().LogPerformance(operation, threshold, actual)
}

View File

@@ -6,12 +6,10 @@ import (
"time"
)
// IsPortAvailable checks if a port is available for both TCP and UDP
func IsPortAvailable(port int) bool {
return IsTCPPortAvailable(port) && IsUDPPortAvailable(port)
}
// IsTCPPortAvailable checks if a TCP port is available
func IsTCPPortAvailable(port int) bool {
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr)
@@ -22,7 +20,6 @@ func IsTCPPortAvailable(port int) bool {
return true
}
// IsUDPPortAvailable checks if a UDP port is available
func IsUDPPortAvailable(port int) bool {
conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
@@ -32,7 +29,6 @@ func IsUDPPortAvailable(port int) bool {
return true
}
// FindAvailablePort finds an available port starting from the given port
func FindAvailablePort(startPort int) (int, error) {
maxPort := 65535
for port := startPort; port <= maxPort; port++ {
@@ -43,14 +39,12 @@ func FindAvailablePort(startPort int) (int, error) {
return 0, fmt.Errorf("no available ports found between %d and %d", startPort, maxPort)
}
// FindAvailablePortRange finds a range of consecutive available ports
func FindAvailablePortRange(startPort, count int) ([]int, error) {
maxPort := 65535
ports := make([]int, 0, count)
currentPort := startPort
for len(ports) < count && currentPort <= maxPort {
// Check if we have enough consecutive ports available
available := true
for i := 0; i < count-len(ports); i++ {
if !IsPortAvailable(currentPort + i) {
@@ -74,7 +68,6 @@ func FindAvailablePortRange(startPort, count int) ([]int, error) {
return ports, nil
}
// WaitForPortAvailable waits for a port to become available with timeout
func WaitForPortAvailable(port int, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
@@ -84,4 +77,4 @@ func WaitForPortAvailable(port int, timeout time.Duration) error {
time.Sleep(100 * time.Millisecond)
}
return fmt.Errorf("timeout waiting for port %d to become available", port)
}
}

View File

@@ -8,13 +8,10 @@ import (
)
const (
// MinPasswordLength defines the minimum password length
MinPasswordLength = 8
// BcryptCost defines the cost factor for bcrypt hashing
BcryptCost = 12
BcryptCost = 12
)
// HashPassword hashes a plain text password using bcrypt
func HashPassword(password string) (string, error) {
if len(password) < MinPasswordLength {
return "", errors.New("password must be at least 8 characters long")
@@ -28,12 +25,10 @@ func HashPassword(password string) (string, error) {
return string(hashedBytes), nil
}
// VerifyPassword verifies a plain text password against a hashed password
func VerifyPassword(hashedPassword, password string) error {
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
}
// ValidatePasswordStrength validates password complexity requirements
func ValidatePasswordStrength(password string) error {
if len(password) < MinPasswordLength {
return errors.New("password must be at least 8 characters long")

View File

@@ -17,25 +17,23 @@ import (
func Start(di *dig.Container) *fiber.App {
app := fiber.New(fiber.Config{
EnablePrintRoutes: true,
ReadTimeout: 20 * time.Minute, // Increased for long-running Steam operations
WriteTimeout: 20 * time.Minute, // Increased for long-running Steam operations
IdleTimeout: 25 * time.Minute, // Increased accordingly
BodyLimit: 10 * 1024 * 1024, // 10MB
ReadTimeout: 20 * time.Minute,
WriteTimeout: 20 * time.Minute,
IdleTimeout: 25 * time.Minute,
BodyLimit: 10 * 1024 * 1024,
})
// Initialize security middleware
securityMW := security.NewSecurityMiddleware()
// Add security middleware stack
app.Use(securityMW.SecurityHeaders())
app.Use(securityMW.LogSecurityEvents())
app.Use(securityMW.TimeoutMiddleware(20 * time.Minute)) // Increased for Steam operations
app.Use(securityMW.RequestContextTimeout(20 * time.Minute)) // Increased for Steam operations
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) // 10MB
app.Use(securityMW.TimeoutMiddleware(20 * time.Minute))
app.Use(securityMW.RequestContextTimeout(20 * time.Minute))
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024))
app.Use(securityMW.ValidateUserAgent())
app.Use(securityMW.ValidateContentType("application/json", "application/x-www-form-urlencoded", "multipart/form-data"))
app.Use(securityMW.InputSanitization())
app.Use(securityMW.RateLimit(100, 1*time.Minute)) // 100 requests per minute global
app.Use(securityMW.RateLimit(100, 1*time.Minute))
app.Use(helmet.New())
@@ -62,7 +60,7 @@ func Start(di *dig.Container) *fiber.App {
port := os.Getenv("PORT")
if port == "" {
port = "3000" // Default port
port = "3000"
}
logging.Info("Starting server on port %s", port)

View File

@@ -7,11 +7,11 @@ import (
)
type LogTailer struct {
filePath string
handleLine func(string)
stopChan chan struct{}
isRunning bool
tracker *PositionTracker
filePath string
handleLine func(string)
stopChan chan struct{}
isRunning bool
tracker *PositionTracker
}
func NewLogTailer(filePath string, handleLine func(string)) *LogTailer {
@@ -30,10 +30,9 @@ func (t *LogTailer) Start() {
t.isRunning = true
go func() {
// Load last position from tracker
pos, err := t.tracker.LoadPosition()
if err != nil {
pos = &LogPosition{} // Start from beginning if error
pos = &LogPosition{}
}
lastSize := pos.LastPosition
@@ -43,7 +42,6 @@ func (t *LogTailer) Start() {
t.isRunning = false
return
default:
// Try to open and read the file
if file, err := os.Open(t.filePath); err == nil {
stat, err := file.Stat()
if err != nil {
@@ -52,12 +50,10 @@ func (t *LogTailer) Start() {
continue
}
// If file was truncated, start from beginning
if stat.Size() < lastSize {
lastSize = 0
}
// Seek to last read position
if lastSize > 0 {
file.Seek(lastSize, 0)
}
@@ -66,9 +62,8 @@ func (t *LogTailer) Start() {
for scanner.Scan() {
line := scanner.Text()
t.handleLine(line)
lastSize, _ = file.Seek(0, 1) // Get current position
// Save position periodically
lastSize, _ = file.Seek(0, 1)
t.tracker.SavePosition(&LogPosition{
LastPosition: lastSize,
LastRead: line,
@@ -78,7 +73,6 @@ func (t *LogTailer) Start() {
file.Close()
}
// Wait before next attempt
time.Sleep(time.Second)
}
}
@@ -90,4 +84,4 @@ func (t *LogTailer) Stop() {
return
}
close(t.stopChan)
}
}

View File

@@ -7,8 +7,8 @@ import (
)
type LogPosition struct {
LastPosition int64 `json:"last_position"`
LastRead string `json:"last_read"`
LastPosition int64 `json:"last_position"`
LastRead string `json:"last_read"`
}
type PositionTracker struct {
@@ -16,11 +16,10 @@ type PositionTracker struct {
}
func NewPositionTracker(logPath string) *PositionTracker {
// Create position file in same directory as log file
dir := filepath.Dir(logPath)
base := filepath.Base(logPath)
positionFile := filepath.Join(dir, "."+base+".position")
return &PositionTracker{
positionFile: positionFile,
}
@@ -30,7 +29,6 @@ func (t *PositionTracker) LoadPosition() (*LogPosition, error) {
data, err := os.ReadFile(t.positionFile)
if err != nil {
if os.IsNotExist(err) {
// Return empty position if file doesn't exist
return &LogPosition{}, nil
}
return nil, err
@@ -51,4 +49,4 @@ func (t *PositionTracker) SavePosition(pos *LogPosition) error {
}
return os.WriteFile(t.positionFile, data, 0644)
}
}

View File

@@ -80,7 +80,7 @@ func TailLogFile(path string, callback func(string)) {
file, _ := os.Open(path)
defer file.Close()
file.Seek(0, os.SEEK_END) // Start at end of file
file.Seek(0, os.SEEK_END)
reader := bufio.NewReader(file)
for {
@@ -88,7 +88,7 @@ func TailLogFile(path string, callback func(string)) {
if err == nil {
callback(line)
} else {
time.Sleep(500 * time.Millisecond) // wait for new data
time.Sleep(500 * time.Millisecond)
}
}
}

View File

@@ -0,0 +1,169 @@
package websocket
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/logging"
"encoding/json"
"sync"
"time"
"github.com/gofiber/websocket/v2"
"github.com/google/uuid"
)
type WebSocketConnection struct {
conn *websocket.Conn
serverID *uuid.UUID
userID *uuid.UUID
}
type WebSocketService struct {
connections sync.Map
mu sync.RWMutex
}
func NewWebSocketService() *WebSocketService {
return &WebSocketService{}
}
func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) {
wsConn := &WebSocketConnection{
conn: conn,
userID: userID,
}
ws.connections.Store(connID, wsConn)
logging.Info("WebSocket connection added: %s for user: %v", connID, userID)
}
func (ws *WebSocketService) RemoveConnection(connID string) {
if conn, exists := ws.connections.LoadAndDelete(connID); exists {
if wsConn, ok := conn.(*WebSocketConnection); ok {
wsConn.conn.Close()
}
}
logging.Info("WebSocket connection removed: %s", connID)
}
func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
if conn, exists := ws.connections.Load(connID); exists {
if wsConn, ok := conn.(*WebSocketConnection); ok {
wsConn.serverID = &serverID
}
}
}
func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) {
stepMsg := model.StepMessage{
Step: step,
Status: status,
Message: message,
Error: errorMsg,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeStep,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: stepMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) {
steamMsg := model.SteamOutputMessage{
Output: output,
IsError: isError,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeSteamOutput,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: steamMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) {
errorMsg := model.ErrorMessage{
Error: error,
Details: details,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeError,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: errorMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) {
completeMsg := model.CompleteMessage{
ServerID: serverID,
Success: success,
Message: message,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeComplete,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: completeMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) {
data, err := json.Marshal(message)
if err != nil {
logging.Error("Failed to marshal WebSocket message: %v", err)
return
}
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
if wsConn.serverID != nil && *wsConn.serverID == serverID {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
ws.RemoveConnection(key.(string))
}
}
}
return true
})
}
func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) {
data, err := json.Marshal(message)
if err != nil {
logging.Error("Failed to marshal WebSocket message: %v", err)
return
}
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
if wsConn.userID != nil && *wsConn.userID == userID {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
ws.RemoveConnection(key.(string))
}
}
}
return true
})
}
func (ws *WebSocketService) GetActiveConnections() int {
count := 0
ws.connections.Range(func(key, value interface{}) bool {
count++
return true
})
return count
}

View File

@@ -1,4 +1,3 @@
// Package swagger Code generated by swaggo/swag. DO NOT EDIT
package swagger
import "github.com/swaggo/swag"
@@ -2239,7 +2238,6 @@ const docTemplate = `{
}
}`
// SwaggerInfo holds exported Swagger Info so clients can modify it
var SwaggerInfo = &swag.Spec{
Version: "1.0",
Host: "acc-api.jurmanovic.com",

View File

@@ -10,24 +10,19 @@ import (
"github.com/google/uuid"
)
// GenerateTestToken creates a JWT token for testing purposes
func GenerateTestToken() (string, error) {
// Create test user
user := &model.User{
ID: uuid.New(),
Username: "test_user",
RoleID: uuid.New(),
}
// Use the environment JWT_SECRET for consistency with middleware
testSecret := os.Getenv("JWT_SECRET")
if testSecret == "" {
// Fallback to a test secret if env var is not set
testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(testSecret)
// Generate JWT token
token, err := jwtHandler.GenerateToken(user.ID.String())
if err != nil {
return "", fmt.Errorf("failed to generate test token: %w", err)
@@ -36,8 +31,6 @@ func GenerateTestToken() (string, error) {
return token, nil
}
// MustGenerateTestToken generates a test token and panics if it fails
// This is useful for test setup where failing to generate a token is a fatal error
func MustGenerateTestToken() string {
token, err := GenerateTestToken()
if err != nil {
@@ -46,24 +39,19 @@ func MustGenerateTestToken() string {
return token
}
// GenerateTestTokenWithExpiry creates a JWT token with a specific expiry time
func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) {
// Use the environment JWT_SECRET for consistency with middleware
testSecret := os.Getenv("JWT_SECRET")
if testSecret == "" {
// Fallback to a test secret if env var is not set
testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(testSecret)
// Create test user
user := &model.User{
ID: uuid.New(),
Username: "test_user",
RoleID: uuid.New(),
}
// Generate JWT token with custom expiry
token, err := jwtHandler.GenerateTokenWithExpiry(user, expiryTime)
if err != nil {
return "", fmt.Errorf("failed to generate test token with expiry: %w", err)
@@ -72,8 +60,6 @@ func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) {
return token, nil
}
// AddAuthHeader adds a test auth token to the request headers
// This is a convenience method for tests that need to authenticate requests
func AddAuthHeader(headers map[string]string) (map[string]string, error) {
token, err := GenerateTestToken()
if err != nil {
@@ -88,7 +74,6 @@ func AddAuthHeader(headers map[string]string) (map[string]string, error) {
return headers, nil
}
// MustAddAuthHeader adds a test auth token to the request headers and panics if it fails
func MustAddAuthHeader(headers map[string]string) map[string]string {
result, err := AddAuthHeader(headers)
if err != nil {

View File

@@ -8,26 +8,20 @@ import (
"github.com/google/uuid"
)
// MockAuthMiddleware provides a test implementation of AuthMiddleware
// that can be used as a drop-in replacement for the real AuthMiddleware
type MockAuthMiddleware struct{}
// NewMockAuthMiddleware creates a new MockAuthMiddleware
func NewMockAuthMiddleware() *MockAuthMiddleware {
return &MockAuthMiddleware{}
}
// Authenticate is a middleware that allows all requests without authentication for testing
func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
// Set a mock user ID in context
mockUserID := uuid.New().String()
ctx.Locals("userID", mockUserID)
// Set mock user info
mockUserInfo := &middleware.CachedUserInfo{
UserID: mockUserID,
Username: "test_user",
RoleName: "Admin", // Admin role to bypass permission checks
RoleName: "Admin",
Permissions: map[string]bool{"*": true},
CachedAt: time.Now(),
}
@@ -38,21 +32,18 @@ func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
return ctx.Next()
}
// HasPermission is a middleware that allows all permission checks to pass for testing
func (m *MockAuthMiddleware) HasPermission(requiredPermission string) fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.Next()
}
}
// AuthRateLimit is a test implementation that allows all requests
func (m *MockAuthMiddleware) AuthRateLimit() fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.Next()
}
}
// RequireHTTPS is a test implementation that allows all HTTP requests
func (m *MockAuthMiddleware) RequireHTTPS() fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.Next()

View File

@@ -8,7 +8,6 @@ import (
"github.com/google/uuid"
)
// MockConfigRepository provides a mock implementation of ConfigRepository
type MockConfigRepository struct {
configs map[string]*model.Config
shouldFailGet bool
@@ -21,7 +20,6 @@ func NewMockConfigRepository() *MockConfigRepository {
}
}
// UpdateConfig mocks the UpdateConfig method
func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config {
if m.shouldFailUpdate {
return nil
@@ -36,18 +34,15 @@ func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.C
return config
}
// SetShouldFailUpdate configures the mock to fail on UpdateConfig calls
func (m *MockConfigRepository) SetShouldFailUpdate(shouldFail bool) {
m.shouldFailUpdate = shouldFail
}
// GetConfig retrieves a config by server ID and config file
func (m *MockConfigRepository) GetConfig(serverID uuid.UUID, configFile string) *model.Config {
key := serverID.String() + "_" + configFile
return m.configs[key]
}
// MockServerRepository provides a mock implementation of ServerRepository
type MockServerRepository struct {
servers map[uuid.UUID]*model.Server
shouldFailGet bool
@@ -59,7 +54,6 @@ func NewMockServerRepository() *MockServerRepository {
}
}
// GetByID mocks the GetByID method
func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*model.Server, error) {
if m.shouldFailGet {
return nil, errors.New("server not found")
@@ -88,17 +82,14 @@ func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*mo
return server, nil
}
// AddServer adds a server to the mock repository
func (m *MockServerRepository) AddServer(server *model.Server) {
m.servers[server.ID] = server
}
// SetShouldFailGet configures the mock to fail on GetByID calls
func (m *MockServerRepository) SetShouldFailGet(shouldFail bool) {
m.shouldFailGet = shouldFail
}
// MockServerService provides a mock implementation of ServerService
type MockServerService struct {
startRuntimeCalled bool
startRuntimeServer *model.Server
@@ -108,23 +99,19 @@ func NewMockServerService() *MockServerService {
return &MockServerService{}
}
// StartAccServerRuntime mocks the StartAccServerRuntime method
func (m *MockServerService) StartAccServerRuntime(server *model.Server) {
m.startRuntimeCalled = true
m.startRuntimeServer = server
}
// WasStartRuntimeCalled returns whether StartAccServerRuntime was called
func (m *MockServerService) WasStartRuntimeCalled() bool {
return m.startRuntimeCalled
}
// GetStartRuntimeServer returns the server passed to StartAccServerRuntime
func (m *MockServerService) GetStartRuntimeServer() *model.Server {
return m.startRuntimeServer
}
// Reset resets the mock state
func (m *MockServerService) Reset() {
m.startRuntimeCalled = false
m.startRuntimeServer = nil

View File

@@ -8,7 +8,6 @@ import (
"github.com/google/uuid"
)
// MockStateHistoryRepository provides a mock implementation of StateHistoryRepository
type MockStateHistoryRepository struct {
stateHistories []model.StateHistory
shouldFailGet bool
@@ -21,7 +20,6 @@ func NewMockStateHistoryRepository() *MockStateHistoryRepository {
}
}
// GetAll mocks the GetAll method
func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) {
if m.shouldFailGet {
return nil, errors.New("failed to get state history")
@@ -37,13 +35,11 @@ func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.S
return &filtered, nil
}
// Insert mocks the Insert method
func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *model.StateHistory) error {
if m.shouldFailInsert {
return errors.New("failed to insert state history")
}
// Simulate BeforeCreate hook
if stateHistory.ID == uuid.Nil {
stateHistory.ID = uuid.New()
}
@@ -55,7 +51,6 @@ func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *m
return nil
}
// GetLastSessionID mocks the GetLastSessionID method
func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) {
for i := len(m.stateHistories) - 1; i >= 0; i-- {
if m.stateHistories[i].ServerID == serverID {
@@ -65,7 +60,6 @@ func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serve
return uuid.Nil, nil
}
// Helper methods for filtering
func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter *model.StateHistoryFilter) bool {
if filter == nil {
return true
@@ -93,7 +87,6 @@ func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter
return true
}
// Helper methods for testing configuration
func (m *MockStateHistoryRepository) SetShouldFailGet(shouldFail bool) {
m.shouldFailGet = shouldFail
}
@@ -102,7 +95,6 @@ func (m *MockStateHistoryRepository) SetShouldFailInsert(shouldFail bool) {
m.shouldFailInsert = shouldFail
}
// AddStateHistory adds a state history entry to the mock repository
func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHistory) {
if stateHistory.ID == uuid.Nil {
stateHistory.ID = uuid.New()
@@ -113,22 +105,18 @@ func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHis
m.stateHistories = append(m.stateHistories, stateHistory)
}
// GetCount returns the number of state history entries
func (m *MockStateHistoryRepository) GetCount() int {
return len(m.stateHistories)
}
// Clear removes all state history entries
func (m *MockStateHistoryRepository) Clear() {
m.stateHistories = make([]model.StateHistory, 0)
}
// GetSummaryStats calculates peak players, total sessions, and average players for mock data
func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) {
var stats model.StateHistoryStats
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
@@ -139,7 +127,6 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter
return stats, nil
}
// Calculate statistics
sessionMap := make(map[string]bool)
totalPlayers := 0
@@ -159,11 +146,9 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter
return stats, nil
}
// GetTotalPlaytime calculates total playtime in minutes for mock data
func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) {
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
@@ -174,7 +159,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
return 0, nil
}
// Group by session and calculate durations
sessionMap := make(map[string][]model.StateHistory)
for _, entry := range filteredEntries {
sessionID := entry.SessionID.String()
@@ -184,7 +168,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
totalMinutes := 0
for _, sessionEntries := range sessionMap {
if len(sessionEntries) > 1 {
// Sort by date (simple approach for mock)
minTime := sessionEntries[0].DateCreated
maxTime := sessionEntries[0].DateCreated
hasPlayers := false
@@ -211,26 +194,22 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
return totalMinutes, nil
}
// GetPlayerCountOverTime returns downsampled player count data for mock
func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) {
var points []model.PlayerCountPoint
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
}
}
// Group by hour (simple mock implementation)
hourMap := make(map[string][]int)
for _, entry := range filteredEntries {
hourKey := entry.DateCreated.Format("2006-01-02 15")
hourMap[hourKey] = append(hourMap[hourKey], entry.PlayerCount)
}
// Calculate averages per hour
for hourKey, counts := range hourMap {
total := 0
for _, count := range counts {
@@ -247,20 +226,17 @@ func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context,
return points, nil
}
// GetSessionTypes counts sessions by type for mock
func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) {
var sessionTypes []model.SessionCount
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
}
}
// Group by session type
sessionMap := make(map[model.TrackSession]map[string]bool) // session -> sessionID -> bool
sessionMap := make(map[model.TrackSession]map[string]bool)
for _, entry := range filteredEntries {
if sessionMap[entry.Session] == nil {
sessionMap[entry.Session] = make(map[string]bool)
@@ -268,7 +244,6 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter
sessionMap[entry.Session][entry.SessionID.String()] = true
}
// Count unique sessions per type
for sessionType, sessions := range sessionMap {
sessionTypes = append(sessionTypes, model.SessionCount{
Name: sessionType,
@@ -279,20 +254,17 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter
return sessionTypes, nil
}
// GetDailyActivity counts sessions per day for mock
func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) {
var dailyActivity []model.DailyActivity
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
}
}
// Group by day
dayMap := make(map[string]map[string]bool) // date -> sessionID -> bool
dayMap := make(map[string]map[string]bool)
for _, entry := range filteredEntries {
dateKey := entry.DateCreated.Format("2006-01-02")
if dayMap[dateKey] == nil {
@@ -301,7 +273,6 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte
dayMap[dateKey][entry.SessionID.String()] = true
}
// Count unique sessions per day
for date, sessions := range dayMap {
dailyActivity = append(dailyActivity, model.DailyActivity{
Date: date,
@@ -312,26 +283,22 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte
return dailyActivity, nil
}
// GetRecentSessions retrieves recent sessions for mock
func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) {
var recentSessions []model.RecentSession
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
}
}
// Group by session
sessionMap := make(map[string][]model.StateHistory)
for _, entry := range filteredEntries {
sessionID := entry.SessionID.String()
sessionMap[sessionID] = append(sessionMap[sessionID], entry)
}
// Create recent sessions (limit to 10)
count := 0
for _, entries := range sessionMap {
if count >= 10 {
@@ -339,7 +306,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt
}
if len(entries) > 0 {
// Find min/max dates and max players
minDate := entries[0].DateCreated
maxDate := entries[0].DateCreated
maxPlayers := 0
@@ -356,7 +322,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt
}
}
// Only include sessions with players
if maxPlayers > 0 {
duration := int(maxDate.Sub(minDate).Minutes())
recentSessions = append(recentSessions, model.RecentSession{

View File

@@ -24,14 +24,12 @@ import (
"gorm.io/gorm/logger"
)
// TestHelper provides utilities for testing
type TestHelper struct {
DB *gorm.DB
TempDir string
TestData *TestData
}
// TestData contains common test data structures
type TestData struct {
ServerID uuid.UUID
Server *model.Server
@@ -39,37 +37,29 @@ type TestData struct {
SampleConfig *model.Configuration
}
// SetTestEnv sets the required environment variables for tests
func SetTestEnv() {
// Set required environment variables for testing
os.Setenv("APP_SECRET", "test-secret-key-for-testing-123456")
os.Setenv("APP_SECRET_CODE", "test-code-for-testing-123456789012")
os.Setenv("ENCRYPTION_KEY", "12345678901234567890123456789012")
os.Setenv("JWT_SECRET", "test-jwt-secret-key-for-testing-123456789012345678901234567890")
os.Setenv("ACCESS_KEY", "test-access-key-for-testing")
// Set test-specific environment variables
os.Setenv("TESTING_ENV", "true") // Used to bypass
os.Setenv("TESTING_ENV", "true")
configs.Init()
}
// NewTestHelper creates a new test helper with in-memory database
func NewTestHelper(t *testing.T) *TestHelper {
// Set required environment variables
SetTestEnv()
// Create temporary directory for test files
tempDir := t.TempDir()
// Create in-memory SQLite database for testing
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent), // Suppress SQL logs in tests
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to connect to test database: %v", err)
}
// Auto-migrate the schema
err = db.AutoMigrate(
&model.Server{},
&model.Config{},
@@ -79,7 +69,6 @@ func NewTestHelper(t *testing.T) *TestHelper {
&model.StateHistory{},
)
// Explicitly ensure tables exist with correct structure
if !db.Migrator().HasTable(&model.StateHistory{}) {
err = db.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -90,7 +79,6 @@ func NewTestHelper(t *testing.T) *TestHelper {
t.Fatalf("Failed to migrate test database: %v", err)
}
// Create test data
testData := createTestData(t, tempDir)
return &TestHelper{
@@ -100,11 +88,9 @@ func NewTestHelper(t *testing.T) *TestHelper {
}
}
// createTestData creates common test data structures
func createTestData(t *testing.T, tempDir string) *TestData {
serverID := uuid.New()
// Create sample server
server := &model.Server{
ID: serverID,
Name: "Test Server",
@@ -115,13 +101,11 @@ func createTestData(t *testing.T, tempDir string) *TestData {
FromSteamCMD: false,
}
// Create server directory
serverConfigDir := filepath.Join(tempDir, "server", "cfg")
if err := os.MkdirAll(serverConfigDir, 0755); err != nil {
t.Fatalf("Failed to create server config directory: %v", err)
}
// Sample configuration files content
configFiles := map[string]string{
"configuration.json": `{
"udpPort": "9231",
@@ -212,7 +196,6 @@ func createTestData(t *testing.T, tempDir string) *TestData {
}`,
}
// Sample configuration struct
sampleConfig := &model.Configuration{
UdpPort: model.IntString(9231),
TcpPort: model.IntString(9232),
@@ -230,14 +213,12 @@ func createTestData(t *testing.T, tempDir string) *TestData {
}
}
// CreateTestConfigFiles creates actual config files in the test directory
func (th *TestHelper) CreateTestConfigFiles() error {
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
for filename, content := range th.TestData.ConfigFiles {
filePath := filepath.Join(serverConfigDir, filename)
// Encode content to UTF-16 LE BOM format as expected by the application
utf16Content, err := EncodeUTF16LEBOM([]byte(content))
if err != nil {
return err
@@ -251,7 +232,6 @@ func (th *TestHelper) CreateTestConfigFiles() error {
return nil
}
// CreateMalformedConfigFile creates a config file with invalid JSON
func (th *TestHelper) CreateMalformedConfigFile(filename string) error {
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
filePath := filepath.Join(serverConfigDir, filename)
@@ -259,67 +239,51 @@ func (th *TestHelper) CreateMalformedConfigFile(filename string) error {
malformedJSON := `{
"udpPort": "9231",
"tcpPort": "9232"
"maxConnections": "30" // Missing comma - invalid JSON
"maxConnections": "30"
}`
return os.WriteFile(filePath, []byte(malformedJSON), 0644)
}
// RemoveConfigFile removes a config file to simulate missing file scenarios
func (th *TestHelper) RemoveConfigFile(filename string) error {
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
filePath := filepath.Join(serverConfigDir, filename)
return os.Remove(filePath)
}
// InsertTestServer inserts the test server into the database
func (th *TestHelper) InsertTestServer() error {
return th.DB.Create(th.TestData.Server).Error
}
// CreateContext creates a test context
func (th *TestHelper) CreateContext() context.Context {
return context.Background()
}
// CreateFiberCtx creates a fiber.Ctx for testing
func (th *TestHelper) CreateFiberCtx() *fiber.Ctx {
// Create app and request for fiber context
app := fiber.New()
// Create a dummy request that doesn't depend on external http objects
req := httptest.NewRequest("GET", "/", nil)
// Create the fiber context from real request/response
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
// Store the original request for release later
ctx.Locals("original-request", req)
// Return the context which can be safely used in tests
return ctx
}
// ReleaseFiberCtx properly releases a fiber context created with CreateFiberCtx
func (th *TestHelper) ReleaseFiberCtx(app *fiber.App, ctx *fiber.Ctx) {
if app != nil && ctx != nil {
app.ReleaseCtx(ctx)
}
}
// Cleanup performs cleanup operations after tests
func (th *TestHelper) Cleanup() {
// Close database connection
if sqlDB, err := th.DB.DB(); err == nil {
sqlDB.Close()
}
// Temporary directory is automatically cleaned up by t.TempDir()
}
// LoadTestEnvFile loads environment variables from a .env file for testing
func LoadTestEnvFile() error {
// Try to load from .env file
return godotenv.Load()
}
// AssertNoError is a helper function to check for errors in tests
func AssertNoError(t *testing.T, err error) {
t.Helper()
if err != nil {
@@ -327,7 +291,6 @@ func AssertNoError(t *testing.T, err error) {
}
}
// AssertError is a helper function to check for expected errors
func AssertError(t *testing.T, err error, expectedMsg string) {
t.Helper()
if err == nil {
@@ -338,7 +301,6 @@ func AssertError(t *testing.T, err error, expectedMsg string) {
}
}
// AssertEqual checks if two values are equal
func AssertEqual(t *testing.T, expected, actual interface{}) {
t.Helper()
if expected != actual {
@@ -346,7 +308,6 @@ func AssertEqual(t *testing.T, expected, actual interface{}) {
}
}
// AssertNotNil checks if a value is not nil
func AssertNotNil(t *testing.T, value interface{}) {
t.Helper()
if value == nil {
@@ -354,12 +315,9 @@ func AssertNotNil(t *testing.T, value interface{}) {
}
}
// AssertNil checks if a value is nil
func AssertNil(t *testing.T, value interface{}) {
t.Helper()
if value != nil {
// Special handling for interface values that contain nil but aren't nil themselves
// For example, (*jwt.Claims)(nil) is not equal to nil, but it contains nil
switch v := value.(type) {
case *interface{}:
if v == nil || *v == nil {
@@ -374,13 +332,11 @@ func AssertNil(t *testing.T, value interface{}) {
}
}
// EncodeUTF16LEBOM encodes UTF-8 bytes to UTF-16 LE BOM format
func EncodeUTF16LEBOM(input []byte) ([]byte, error) {
encoder := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM)
return transformBytes(encoder.NewEncoder(), input)
}
// transformBytes applies a transform to input bytes
func transformBytes(t transform.Transformer, input []byte) ([]byte, error) {
var buf bytes.Buffer
w := transform.NewWriter(&buf, t)
@@ -396,7 +352,6 @@ func transformBytes(t transform.Transformer, input []byte) ([]byte, error) {
return buf.Bytes(), nil
}
// ErrorForTesting creates an error for testing purposes
func ErrorForTesting(message string) error {
return errors.New(message)
}

View File

@@ -7,13 +7,11 @@ import (
"github.com/google/uuid"
)
// StateHistoryTestData provides simple test data generators
type StateHistoryTestData struct {
ServerID uuid.UUID
BaseTime time.Time
}
// NewStateHistoryTestData creates a new test data generator
func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData {
return &StateHistoryTestData{
ServerID: serverID,
@@ -21,7 +19,6 @@ func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData {
}
}
// CreateStateHistory creates a basic state history entry
func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, track string, playerCount int, sessionID uuid.UUID) model.StateHistory {
return model.StateHistory{
ID: uuid.New(),
@@ -36,7 +33,6 @@ func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, t
}
}
// CreateMultipleEntries creates multiple state history entries for the same session
func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession, track string, playerCounts []int) []model.StateHistory {
sessionID := uuid.New()
var entries []model.StateHistory
@@ -59,7 +55,6 @@ func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession
return entries
}
// CreateBasicFilter creates a basic filter for testing
func CreateBasicFilter(serverID string) *model.StateHistoryFilter {
return &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
@@ -68,7 +63,6 @@ func CreateBasicFilter(serverID string) *model.StateHistoryFilter {
}
}
// CreateFilterWithSession creates a filter with session type
func CreateFilterWithSession(serverID string, session model.TrackSession) *model.StateHistoryFilter {
return &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
@@ -78,7 +72,6 @@ func CreateFilterWithSession(serverID string, session model.TrackSession) *model
}
}
// LogLines contains sample ACC server log lines for testing
var SampleLogLines = []string{
"[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE",
"[2024-01-15 14:30:30.456] 1 client(s) online",
@@ -95,7 +88,6 @@ var SampleLogLines = []string{
"[2024-01-15 15:00:05.123] Session changed: RACE -> NONE",
}
// ExpectedSessionChanges represents the expected session changes from parsing the sample log lines
var ExpectedSessionChanges = []struct {
From model.TrackSession
To model.TrackSession
@@ -106,5 +98,4 @@ var ExpectedSessionChanges = []struct {
{model.SessionRace, model.SessionUnknown},
}
// ExpectedPlayerCounts represents the expected player counts from parsing the sample log lines
var ExpectedPlayerCounts = []int{1, 3, 5, 8, 12, 15, 14, 0}

View File

@@ -14,12 +14,10 @@ import (
)
func TestController_JSONParsing_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test basic JSON parsing functionality
app := fiber.New()
app.Post("/test", func(c *fiber.Ctx) error {
@@ -30,7 +28,6 @@ func TestController_JSONParsing_Success(t *testing.T) {
return c.JSON(data)
})
// Prepare test data
testData := map[string]interface{}{
"name": "test",
"value": 123,
@@ -38,34 +35,28 @@ func TestController_JSONParsing_Success(t *testing.T) {
bodyBytes, err := json.Marshal(testData)
tests.AssertNoError(t, err)
// Create request
req := httptest.NewRequest("POST", "/test", bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, "test", response["name"])
tests.AssertEqual(t, float64(123), response["value"]) // JSON numbers are float64
tests.AssertEqual(t, float64(123), response["value"])
}
func TestController_JSONParsing_InvalidJSON(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test handling of invalid JSON
app := fiber.New()
app.Post("/test", func(c *fiber.Ctx) error {
@@ -76,39 +67,31 @@ func TestController_JSONParsing_InvalidJSON(t *testing.T) {
return c.JSON(data)
})
// Create request with invalid JSON
req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte("invalid json")))
req.Header.Set("Content-Type", "application/json")
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 400, resp.StatusCode)
// Parse error response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify error response
tests.AssertEqual(t, "Invalid JSON", response["error"])
}
func TestController_UUIDValidation_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test UUID parameter validation
app := fiber.New()
app.Get("/test/:id", func(c *fiber.Ctx) error {
id := c.Params("id")
// Validate UUID
if _, err := uuid.Parse(id); err != nil {
return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"})
}
@@ -116,40 +99,33 @@ func TestController_UUIDValidation_Success(t *testing.T) {
return c.JSON(fiber.Map{"id": id, "valid": true})
})
// Create request with valid UUID
validUUID := uuid.New().String()
req := httptest.NewRequest("GET", "/test/"+validUUID, nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, validUUID, response["id"])
tests.AssertEqual(t, true, response["valid"])
}
func TestController_UUIDValidation_InvalidUUID(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test handling of invalid UUID
app := fiber.New()
app.Get("/test/:id", func(c *fiber.Ctx) error {
id := c.Params("id")
// Validate UUID
if _, err := uuid.Parse(id); err != nil {
return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"})
}
@@ -157,32 +133,26 @@ func TestController_UUIDValidation_InvalidUUID(t *testing.T) {
return c.JSON(fiber.Map{"id": id, "valid": true})
})
// Create request with invalid UUID
req := httptest.NewRequest("GET", "/test/invalid-uuid", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 400, resp.StatusCode)
// Parse error response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify error response
tests.AssertEqual(t, "Invalid UUID", response["error"])
}
func TestController_QueryParameters_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test query parameter handling
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
@@ -197,34 +167,28 @@ func TestController_QueryParameters_Success(t *testing.T) {
})
})
// Create request with query parameters
req := httptest.NewRequest("GET", "/test?restart=true&override=false&format=xml", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, true, response["restart"])
tests.AssertEqual(t, false, response["override"])
tests.AssertEqual(t, "xml", response["format"])
}
func TestController_HTTPMethods_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test different HTTP methods
app := fiber.New()
var getCalled, postCalled, putCalled, deleteCalled bool
@@ -249,28 +213,24 @@ func TestController_HTTPMethods_Success(t *testing.T) {
return c.JSON(fiber.Map{"method": "DELETE"})
})
// Test GET
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
tests.AssertEqual(t, true, getCalled)
// Test POST
req = httptest.NewRequest("POST", "/test", nil)
resp, err = app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
tests.AssertEqual(t, true, postCalled)
// Test PUT
req = httptest.NewRequest("PUT", "/test", nil)
resp, err = app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
tests.AssertEqual(t, true, putCalled)
// Test DELETE
req = httptest.NewRequest("DELETE", "/test", nil)
resp, err = app.Test(req)
tests.AssertNoError(t, err)
@@ -279,12 +239,10 @@ func TestController_HTTPMethods_Success(t *testing.T) {
}
func TestController_ErrorHandling_StatusCodes(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test different error status codes
app := fiber.New()
app.Get("/400", func(c *fiber.Ctx) error {
@@ -307,7 +265,6 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) {
return c.Status(500).JSON(fiber.Map{"error": "Internal Server Error"})
})
// Test different status codes
testCases := []struct {
path string
code int
@@ -328,12 +285,10 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) {
}
func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test Configuration model JSON serialization
app := fiber.New()
app.Get("/config", func(c *fiber.Ctx) error {
@@ -348,22 +303,18 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
return c.JSON(config)
})
// Create request
req := httptest.NewRequest("GET", "/config", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response
var response model.Configuration
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, model.IntString(9231), response.UdpPort)
tests.AssertEqual(t, model.IntString(9232), response.TcpPort)
tests.AssertEqual(t, model.IntString(30), response.MaxConnections)
@@ -373,73 +324,61 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
}
func TestController_UserModel_JSONSerialization(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test User model JSON serialization (password should be hidden)
app := fiber.New()
app.Get("/user", func(c *fiber.Ctx) error {
user := &model.User{
ID: uuid.New(),
Username: "testuser",
Password: "secret-password", // Should not appear in JSON
Password: "secret-password",
RoleID: uuid.New(),
}
return c.JSON(user)
})
// Create request
req := httptest.NewRequest("GET", "/user", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response as raw JSON to check password is excluded
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
// Verify password field is not in JSON
if bytes.Contains(body, []byte("password")) || bytes.Contains(body, []byte("secret-password")) {
t.Fatal("Password should not be included in JSON response")
}
// Verify other fields are present
if !bytes.Contains(body, []byte("username")) || !bytes.Contains(body, []byte("testuser")) {
t.Fatal("Username should be included in JSON response")
}
}
func TestController_MiddlewareChaining_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test middleware chaining
app := fiber.New()
var middleware1Called, middleware2Called, handlerCalled bool
// Middleware 1
middleware1 := func(c *fiber.Ctx) error {
middleware1Called = true
c.Locals("middleware1", "executed")
return c.Next()
}
// Middleware 2
middleware2 := func(c *fiber.Ctx) error {
middleware2Called = true
c.Locals("middleware2", "executed")
return c.Next()
}
// Handler
handler := func(c *fiber.Ctx) error {
handlerCalled = true
return c.JSON(fiber.Map{
@@ -451,27 +390,22 @@ func TestController_MiddlewareChaining_Success(t *testing.T) {
app.Get("/test", middleware1, middleware2, handler)
// Create request
req := httptest.NewRequest("GET", "/test", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Verify all were called
tests.AssertEqual(t, true, middleware1Called)
tests.AssertEqual(t, true, middleware2Called)
tests.AssertEqual(t, true, handlerCalled)
// Parse response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify middleware values were passed
tests.AssertEqual(t, "executed", response["middleware1"])
tests.AssertEqual(t, "executed", response["middleware2"])
tests.AssertEqual(t, "executed", response["handler"])

View File

@@ -11,27 +11,20 @@ import (
"github.com/gofiber/fiber/v2"
)
// MockMiddleware simulates authentication for testing purposes
type MockMiddleware struct{}
// GetTestAuthMiddleware returns a mock auth middleware that can be used in place of the real one
// This works because we're adding real authentication tokens to requests
func GetTestAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache) *middleware.AuthMiddleware {
// Use environment JWT secrets for consistency with token generation
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret) // Use same secret for test consistency
// Cast our mock to the real type for testing
// This is a type-unsafe cast but works for testing because we're using real JWT tokens
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
return middleware.NewAuthMiddleware(ms, cache, jwtHandler, openJWTHandler)
}
// AddAuthToRequest adds a valid authentication token to a test request
func AddAuthToRequest(req *fiber.Ctx) {
token := tests.MustGenerateTestToken()
req.Request().Header.Set("Authorization", "Bearer "+token)

View File

@@ -23,13 +23,11 @@ import (
)
func TestStateHistoryController_GetAll_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// No need for DisableAuthentication, we'll use real auth tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
@@ -44,33 +42,26 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) {
inMemCache := cache.NewInMemoryCache()
// Insert test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(helper.CreateContext(), &history)
tests.AssertNoError(t, err)
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
// Parse response body
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
@@ -83,13 +74,11 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) {
}
func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
@@ -104,7 +93,6 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
inMemCache := cache.NewInMemoryCache()
// Insert test data with different sessions
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
@@ -115,27 +103,21 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
err = repo.Insert(helper.CreateContext(), &raceHistory)
tests.AssertNoError(t, err)
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with session filter and authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s&session=R", helper.TestData.ServerID.String()), nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
// Parse response body
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
@@ -148,13 +130,11 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
}
func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
@@ -169,38 +149,30 @@ func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) {
inMemCache := cache.NewInMemoryCache()
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with no data and authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify empty response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
}
func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
// Skip this test as it requires more complex setup
t.Skip("Skipping test due to UUID validation issues")
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
@@ -215,10 +187,8 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
inMemCache := cache.NewInMemoryCache()
// Insert test data with multiple entries for statistics
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Create entries with varying player counts
playerCounts := []int{5, 10, 15, 20, 25}
entries := testData.CreateMultipleEntries(model.SessionRace, "spa", playerCounts)
@@ -227,33 +197,26 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
tests.AssertNoError(t, err)
}
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with valid serverID UUID
validServerID := helper.TestData.ServerID.String()
if validServerID == "" {
validServerID = uuid.New().String() // Generate a new valid UUID if needed
validServerID = uuid.New().String()
}
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
req.Header.Set("Content-Type", "application/json")
// Add Authorization header for testing
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
// Parse response body
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
@@ -261,7 +224,6 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
err = json.Unmarshal(body, &stats)
tests.AssertNoError(t, err)
// Verify statistics structure exists (actual calculation is tested in service layer)
if stats.PeakPlayers < 0 {
t.Error("Expected non-negative peak players")
}
@@ -274,16 +236,13 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
}
func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
// Skip this test as it requires more complex setup
t.Skip("Skipping test due to UUID validation issues")
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
@@ -298,33 +257,26 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
inMemCache := cache.NewInMemoryCache()
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with valid serverID UUID
validServerID := helper.TestData.ServerID.String()
if validServerID == "" {
validServerID = uuid.New().String() // Generate a new valid UUID if needed
validServerID = uuid.New().String()
}
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
req.Header.Set("Content-Type", "application/json")
// Add Authorization header for testing
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
// Parse response body
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
@@ -332,23 +284,19 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
err = json.Unmarshal(body, &stats)
tests.AssertNoError(t, err)
// Verify empty statistics
tests.AssertEqual(t, 0, stats.PeakPlayers)
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
tests.AssertEqual(t, 0, stats.TotalSessions)
}
func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) {
// Skip this test as it requires more complex setup
t.Skip("Skipping test due to UUID validation issues")
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
@@ -363,42 +311,34 @@ func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) {
inMemCache := cache.NewInMemoryCache()
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with invalid query parameters but with valid UUID
validServerID := helper.TestData.ServerID.String()
if validServerID == "" {
validServerID = uuid.New().String() // Generate a new valid UUID if needed
validServerID = uuid.New().String()
}
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s&min_players=invalid", validServerID), nil)
req.Header.Set("Content-Type", "application/json")
// Add Authorization header for testing
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify error response
tests.AssertEqual(t, http.StatusBadRequest, resp.StatusCode)
}
func TestStateHistoryController_HTTPMethods(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
@@ -413,36 +353,30 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) {
inMemCache := cache.NewInMemoryCache()
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Test that only GET method is allowed for GetAll
req := httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
// Test that only GET method is allowed for GetStatistics
req = httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err = app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
// Test that PUT method is not allowed
req = httptest.NewRequest("PUT", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err = app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
// Test that DELETE method is not allowed
req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err = app.Test(req)
@@ -452,13 +386,11 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) {
func TestStateHistoryController_ContentType(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
@@ -473,43 +405,36 @@ func TestStateHistoryController_ContentType(t *testing.T) {
inMemCache := cache.NewInMemoryCache()
// Insert test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(helper.CreateContext(), &history)
tests.AssertNoError(t, err)
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Test GetAll endpoint with authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify content type is JSON
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
}
// Test GetStatistics endpoint with authentication
validServerID := helper.TestData.ServerID.String()
if validServerID == "" {
validServerID = uuid.New().String() // Generate a new valid UUID if needed
validServerID = uuid.New().String()
}
req = httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err = app.Test(req)
tests.AssertNoError(t, err)
// Verify content type is JSON
contentType = resp.Header.Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
@@ -517,16 +442,13 @@ func TestStateHistoryController_ContentType(t *testing.T) {
}
func TestStateHistoryController_ResponseStructure(t *testing.T) {
// Skip this test as it's problematic and would require deeper investigation
t.Skip("Skipping test due to response structure issues that need further investigation")
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
@@ -541,7 +463,6 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
inMemCache := cache.NewInMemoryCache()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -549,21 +470,17 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
}
}
// Insert test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(helper.CreateContext(), &history)
tests.AssertNoError(t, err)
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Test GetAll response structure with authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err := app.Test(req)
@@ -572,24 +489,19 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
// Log the actual response for debugging
t.Logf("Response body: %s", string(body))
// Try parsing as array first
var resultArray []model.StateHistory
err = json.Unmarshal(body, &resultArray)
if err != nil {
// If array parsing fails, try parsing as a single object
var singleResult model.StateHistory
err = json.Unmarshal(body, &singleResult)
if err != nil {
t.Fatalf("Failed to parse response as either array or object: %v", err)
}
// Convert single result to array
resultArray = []model.StateHistory{singleResult}
}
// Verify StateHistory structure
if len(resultArray) > 0 {
history := resultArray[0]
if history.ID == uuid.Nil {

View File

@@ -12,12 +12,10 @@ import (
)
func TestStateHistoryRepository_Insert_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -28,15 +26,12 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
// Test Insert
err := repo.Insert(ctx, &history)
tests.AssertNoError(t, err)
// Verify ID was generated
tests.AssertNotNil(t, history.ID)
if history.ID == uuid.Nil {
t.Error("Expected non-nil ID after insert")
@@ -44,12 +39,10 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) {
}
func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -60,10 +53,8 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Insert multiple entries
playerCounts := []int{0, 5, 10, 15, 10, 5, 0}
entries := testData.CreateMultipleEntries(model.SessionPractice, "spa", playerCounts)
@@ -72,7 +63,6 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
tests.AssertNoError(t, err)
}
// Test GetAll
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
result, err := repo.GetAll(ctx, filter)
@@ -82,12 +72,10 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
}
func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -98,19 +86,16 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data with different sessions
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 15, uuid.New())
// Insert both
err := repo.Insert(ctx, &practiceHistory)
tests.AssertNoError(t, err)
err = repo.Insert(ctx, &raceHistory)
tests.AssertNoError(t, err)
// Test GetAll with session filter
filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace)
result, err := repo.GetAll(ctx, filter)
@@ -122,12 +107,10 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
}
func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -138,44 +121,34 @@ func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Insert multiple entries with different session IDs
sessionID1 := uuid.New()
sessionID2 := uuid.New()
history1 := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID1)
history2 := testData.CreateStateHistory(model.SessionRace, "spa", 10, sessionID2)
// Insert with a small delay to ensure ordering
err := repo.Insert(ctx, &history1)
tests.AssertNoError(t, err)
time.Sleep(1 * time.Millisecond) // Ensure different timestamps
time.Sleep(1 * time.Millisecond)
err = repo.Insert(ctx, &history2)
tests.AssertNoError(t, err)
// Test GetLastSessionID - should return the most recent session ID
lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID)
tests.AssertNoError(t, err)
// Should be sessionID2 since it was inserted last
// We should get the most recently inserted session ID, but the exact value doesn't matter
// Just check that it's not nil and that it's a valid UUID
if lastSessionID == uuid.Nil {
t.Fatal("Expected non-nil UUID for last session ID")
}
}
func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -186,19 +159,16 @@ func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Test GetLastSessionID with no data
lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID)
tests.AssertNoError(t, err)
tests.AssertEqual(t, uuid.Nil, lastSessionID)
}
func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -209,14 +179,11 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data with varying player counts
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Create entries with different sessions and player counts
sessionID1 := uuid.New()
sessionID2 := uuid.New()
// Practice session: 5, 10, 15 players
practiceEntries := testData.CreateMultipleEntries(model.SessionPractice, "spa", []int{5, 10, 15})
for i := range practiceEntries {
practiceEntries[i].SessionID = sessionID1
@@ -224,7 +191,6 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
tests.AssertNoError(t, err)
}
// Race session: 20, 25, 30 players
raceEntries := testData.CreateMultipleEntries(model.SessionRace, "spa", []int{20, 25, 30})
for i := range raceEntries {
raceEntries[i].SessionID = sessionID2
@@ -232,17 +198,14 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
tests.AssertNoError(t, err)
}
// Test GetSummaryStats
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
stats, err := repo.GetSummaryStats(ctx, filter)
tests.AssertNoError(t, err)
// Verify stats are calculated correctly
tests.AssertEqual(t, 30, stats.PeakPlayers) // Maximum player count
tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions
tests.AssertEqual(t, 30, stats.PeakPlayers)
tests.AssertEqual(t, 2, stats.TotalSessions)
// Average should be (5+10+15+20+25+30)/6 = 17.5
expectedAverage := float64(5+10+15+20+25+30) / 6.0
if stats.AveragePlayers != expectedAverage {
t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers)
@@ -250,12 +213,10 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
}
func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -266,25 +227,21 @@ func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Test GetSummaryStats with no data
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
stats, err := repo.GetSummaryStats(ctx, filter)
tests.AssertNoError(t, err)
// Verify stats are zero for empty dataset
tests.AssertEqual(t, 0, stats.PeakPlayers)
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
tests.AssertEqual(t, 0, stats.TotalSessions)
}
func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -294,13 +251,10 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data spanning a time range
sessionID := uuid.New()
baseTime := time.Now().UTC()
// Create entries spanning 1 hour with players > 0
entries := []model.StateHistory{
{
ID: uuid.New(),
@@ -342,7 +296,6 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
tests.AssertNoError(t, err)
}
// Test GetTotalPlaytime
filter := &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
ServerID: helper.TestData.ServerID.String(),
@@ -355,20 +308,16 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
playtime, err := repo.GetTotalPlaytime(ctx, filter)
tests.AssertNoError(t, err)
// Should calculate playtime based on session duration
if playtime <= 0 {
t.Error("Expected positive playtime for session with multiple entries")
}
}
func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
// Test concurrent database operations
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -379,10 +328,8 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -390,7 +337,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
}
}
// Create and insert initial entry to ensure table exists and is properly set up
initialHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(ctx, &initialHistory)
if err != nil {
@@ -399,7 +345,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
done := make(chan bool, 3)
// Concurrent inserts
go func() {
defer func() {
done <- true
@@ -412,7 +357,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
}
}()
// Concurrent reads
go func() {
defer func() {
done <- true
@@ -425,7 +369,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
}
}()
// Concurrent GetLastSessionID
go func() {
defer func() {
done <- true
@@ -437,19 +380,16 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
}
}()
// Wait for all operations to complete
for i := 0; i < 3; i++ {
<-done
}
}
func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
// Test edge cases with filters
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -460,15 +400,11 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Insert a test record to ensure the table is properly set up
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(ctx, &history)
tests.AssertNoError(t, err)
// Skip nil filter test as it might not be supported by the repository implementation
// Test with server ID filter - this should work
serverFilter := &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
ServerID: helper.TestData.ServerID.String(),
@@ -478,7 +414,6 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
tests.AssertNoError(t, err)
tests.AssertNotNil(t, result)
// Test with invalid server ID in summary stats
invalidFilter := &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
ServerID: "invalid-uuid",

View File

@@ -12,30 +12,24 @@ import (
)
func TestJWT_GenerateAndValidateToken(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
// Create test user
user := &model.User{
ID: uuid.New(),
Username: "testuser",
RoleID: uuid.New(),
}
// Test JWT generation
token, err := jwtHandler.GenerateToken(user.ID.String())
tests.AssertNoError(t, err)
tests.AssertNotNil(t, token)
// Verify token is not empty
if token == "" {
t.Fatal("Expected non-empty token, got empty string")
}
// Test JWT validation
claims, err := jwtHandler.ValidateToken(token)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, claims)
@@ -43,80 +37,66 @@ func TestJWT_GenerateAndValidateToken(t *testing.T) {
}
func TestJWT_ValidateToken_InvalidToken(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
// Test with invalid token
claims, err := jwtHandler.ValidateToken("invalid-token")
if err == nil {
t.Fatal("Expected error for invalid token, got nil")
}
// Direct nil check to avoid the interface wrapping issue
if claims != nil {
t.Fatalf("Expected nil claims, got %v", claims)
}
}
func TestJWT_ValidateToken_EmptyToken(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
// Test with empty token
claims, err := jwtHandler.ValidateToken("")
if err == nil {
t.Fatal("Expected error for empty token, got nil")
}
// Direct nil check to avoid the interface wrapping issue
if claims != nil {
t.Fatalf("Expected nil claims, got %v", claims)
}
}
func TestUser_VerifyPassword_Success(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test user
user := &model.User{
ID: uuid.New(),
Username: "testuser",
RoleID: uuid.New(),
}
// Hash password manually (simulating what BeforeCreate would do)
plainPassword := "password123"
hashedPassword, err := password.HashPassword(plainPassword)
tests.AssertNoError(t, err)
user.Password = hashedPassword
// Test password verification - should succeed
err = user.VerifyPassword(plainPassword)
tests.AssertNoError(t, err)
}
func TestUser_VerifyPassword_Failure(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test user
user := &model.User{
ID: uuid.New(),
Username: "testuser",
RoleID: uuid.New(),
}
// Hash password manually
hashedPassword, err := password.HashPassword("correct_password")
tests.AssertNoError(t, err)
user.Password = hashedPassword
// Test password verification with wrong password - should fail
err = user.VerifyPassword("wrong_password")
if err == nil {
t.Fatal("Expected error for wrong password, got nil")
@@ -124,11 +104,9 @@ func TestUser_VerifyPassword_Failure(t *testing.T) {
}
func TestUser_Validate_Success(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create valid user
user := &model.User{
ID: uuid.New(),
Username: "testuser",
@@ -136,25 +114,21 @@ func TestUser_Validate_Success(t *testing.T) {
RoleID: uuid.New(),
}
// Test validation - should succeed
err := user.Validate()
tests.AssertNoError(t, err)
}
func TestUser_Validate_MissingUsername(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create user without username
user := &model.User{
ID: uuid.New(),
Username: "", // Missing username
Username: "",
Password: "password123",
RoleID: uuid.New(),
}
// Test validation - should fail
err := user.Validate()
if err == nil {
t.Fatal("Expected error for missing username, got nil")
@@ -162,19 +136,16 @@ func TestUser_Validate_MissingUsername(t *testing.T) {
}
func TestUser_Validate_MissingPassword(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create user without password
user := &model.User{
ID: uuid.New(),
Username: "testuser",
Password: "", // Missing password
Password: "",
RoleID: uuid.New(),
}
// Test validation - should fail
err := user.Validate()
if err == nil {
t.Fatal("Expected error for missing password, got nil")
@@ -182,24 +153,19 @@ func TestUser_Validate_MissingPassword(t *testing.T) {
}
func TestPassword_HashAndVerify(t *testing.T) {
// Test password hashing and verification directly
plainPassword := "test_password_123"
// Hash password
hashedPassword, err := password.HashPassword(plainPassword)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, hashedPassword)
// Verify hashed password is not the same as plain password
if hashedPassword == plainPassword {
t.Fatal("Hashed password should not equal plain password")
}
// Verify correct password
err = password.VerifyPassword(hashedPassword, plainPassword)
tests.AssertNoError(t, err)
// Verify wrong password fails
err = password.VerifyPassword(hashedPassword, "wrong_password")
if err == nil {
t.Fatal("Expected error for wrong password, got nil")
@@ -215,7 +181,7 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) {
{"Valid password", "StrongPassword123!", false},
{"Too short", "123", true},
{"Empty password", "", true},
{"Medium password", "password123", false}, // Depends on validation rules
{"Medium password", "password123", false},
}
for _, tc := range testCases {
@@ -235,7 +201,6 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) {
}
func TestRole_Model(t *testing.T) {
// Test Role model structure
permissions := []model.Permission{
{ID: uuid.New(), Name: "read"},
{ID: uuid.New(), Name: "write"},
@@ -248,7 +213,6 @@ func TestRole_Model(t *testing.T) {
Permissions: permissions,
}
// Verify role structure
tests.AssertEqual(t, "Test Role", role.Name)
tests.AssertEqual(t, 3, len(role.Permissions))
tests.AssertEqual(t, "read", role.Permissions[0].Name)
@@ -257,19 +221,16 @@ func TestRole_Model(t *testing.T) {
}
func TestPermission_Model(t *testing.T) {
// Test Permission model structure
permission := &model.Permission{
ID: uuid.New(),
Name: "test_permission",
}
// Verify permission structure
tests.AssertEqual(t, "test_permission", permission.Name)
tests.AssertNotNil(t, permission.ID)
}
func TestUser_WithRole_Model(t *testing.T) {
// Test User model with Role relationship
permissions := []model.Permission{
{ID: uuid.New(), Name: "read"},
{ID: uuid.New(), Name: "write"},
@@ -289,7 +250,6 @@ func TestUser_WithRole_Model(t *testing.T) {
Role: role,
}
// Verify user-role relationship
tests.AssertEqual(t, "testuser", user.Username)
tests.AssertEqual(t, role.ID, user.RoleID)
tests.AssertEqual(t, "User", user.Role.Name)

View File

@@ -11,28 +11,22 @@ import (
)
func TestInMemoryCache_Set_Get_Success(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value := "test-value"
duration := 5 * time.Minute
// Set value in cache
c.Set(key, value, duration)
// Get value from cache
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value, result)
}
func TestInMemoryCache_Get_NotFound(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Try to get non-existent key
result, found := c.Get("non-existent-key")
tests.AssertEqual(t, false, found)
if result != nil {
@@ -41,43 +35,33 @@ func TestInMemoryCache_Get_NotFound(t *testing.T) {
}
func TestInMemoryCache_Set_Get_NoExpiration(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value := "test-value"
// Set value without expiration (duration = 0)
c.Set(key, value, 0)
// Get value from cache
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value, result)
}
func TestInMemoryCache_Expiration(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value := "test-value"
duration := 1 * time.Millisecond // Very short duration
duration := 1 * time.Millisecond
// Set value in cache
c.Set(key, value, duration)
// Verify it's initially there
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value, result)
// Wait for expiration
time.Sleep(2 * time.Millisecond)
// Try to get expired value
result, found = c.Get(key)
tests.AssertEqual(t, false, found)
if result != nil {
@@ -86,26 +70,20 @@ func TestInMemoryCache_Expiration(t *testing.T) {
}
func TestInMemoryCache_Delete(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value := "test-value"
duration := 5 * time.Minute
// Set value in cache
c.Set(key, value, duration)
// Verify it's there
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value, result)
// Delete the key
c.Delete(key)
// Verify it's gone
result, found = c.Get(key)
tests.AssertEqual(t, false, found)
if result != nil {
@@ -114,37 +92,29 @@ func TestInMemoryCache_Delete(t *testing.T) {
}
func TestInMemoryCache_Overwrite(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value1 := "test-value-1"
value2 := "test-value-2"
duration := 5 * time.Minute
// Set first value
c.Set(key, value1, duration)
// Verify first value
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value1, result)
// Overwrite with second value
c.Set(key, value2, duration)
// Verify second value
result, found = c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value2, result)
}
func TestInMemoryCache_Multiple_Keys(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
testData := map[string]string{
"key1": "value1",
"key2": "value2",
@@ -152,22 +122,18 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) {
}
duration := 5 * time.Minute
// Set multiple values
for key, value := range testData {
c.Set(key, value, duration)
}
// Verify all values
for key, expectedValue := range testData {
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, expectedValue, result)
}
// Delete one key
c.Delete("key2")
// Verify key2 is gone but others remain
result, found := c.Get("key2")
tests.AssertEqual(t, false, found)
if result != nil {
@@ -184,10 +150,8 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) {
}
func TestInMemoryCache_Complex_Objects(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test with complex object (User struct)
user := &model.User{
ID: uuid.New(),
Username: "testuser",
@@ -196,15 +160,12 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) {
key := "user:" + user.ID.String()
duration := 5 * time.Minute
// Set user in cache
c.Set(key, user, duration)
// Get user from cache
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, result)
// Verify it's the same user
cachedUser, ok := result.(*model.User)
tests.AssertEqual(t, true, ok)
tests.AssertEqual(t, user.ID, cachedUser.ID)
@@ -212,33 +173,27 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) {
}
func TestInMemoryCache_GetOrSet_CacheHit(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Pre-populate cache
key := "test-key"
expectedValue := "cached-value"
c.Set(key, expectedValue, 5*time.Minute)
// Track if fetcher is called
fetcherCalled := false
fetcher := func() (string, error) {
fetcherCalled = true
return "fetcher-value", nil
}
// Use GetOrSet - should return cached value
result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher)
tests.AssertNoError(t, err)
tests.AssertEqual(t, expectedValue, result)
tests.AssertEqual(t, false, fetcherCalled) // Fetcher should not be called
tests.AssertEqual(t, false, fetcherCalled)
}
func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Track if fetcher is called
fetcherCalled := false
expectedValue := "fetcher-value"
fetcher := func() (string, error) {
@@ -248,35 +203,29 @@ func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) {
key := "test-key"
// Use GetOrSet - should call fetcher and cache result
result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher)
tests.AssertNoError(t, err)
tests.AssertEqual(t, expectedValue, result)
tests.AssertEqual(t, true, fetcherCalled) // Fetcher should be called
tests.AssertEqual(t, true, fetcherCalled)
// Verify value is now cached
cachedResult, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, expectedValue, cachedResult)
}
func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Fetcher that returns error
fetcher := func() (string, error) {
return "", tests.ErrorForTesting("fetcher error")
}
key := "test-key"
// Use GetOrSet - should return error
result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher)
tests.AssertError(t, err, "")
tests.AssertEqual(t, "", result)
// Verify nothing is cached
cachedResult, found := c.Get(key)
tests.AssertEqual(t, false, found)
if cachedResult != nil {
@@ -285,10 +234,8 @@ func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) {
}
func TestInMemoryCache_TypeSafety(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test type safety with GetOrSet
userFetcher := func() (*model.User, error) {
return &model.User{
ID: uuid.New(),
@@ -298,13 +245,11 @@ func TestInMemoryCache_TypeSafety(t *testing.T) {
key := "user-key"
// Use GetOrSet with User type
user, err := cache.GetOrSet(c, key, 5*time.Minute, userFetcher)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, user)
tests.AssertEqual(t, "testuser", user.Username)
// Verify correct type is cached
cachedResult, found := c.Get(key)
tests.AssertEqual(t, true, found)
cachedUser, ok := cachedResult.(*model.User)
@@ -313,26 +258,21 @@ func TestInMemoryCache_TypeSafety(t *testing.T) {
}
func TestInMemoryCache_Concurrent_Access(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test concurrent access
key := "concurrent-key"
value := "concurrent-value"
duration := 5 * time.Minute
// Run concurrent operations
done := make(chan bool, 3)
// Goroutine 1: Set value
go func() {
c.Set(key, value, duration)
done <- true
}()
// Goroutine 2: Get value
go func() {
time.Sleep(1 * time.Millisecond) // Small delay to ensure Set happens first
time.Sleep(1 * time.Millisecond)
result, found := c.Get(key)
if found {
tests.AssertEqual(t, value, result)
@@ -340,19 +280,16 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) {
done <- true
}()
// Goroutine 3: Delete value
go func() {
time.Sleep(2 * time.Millisecond) // Delay to ensure Set and Get happen first
time.Sleep(2 * time.Millisecond)
c.Delete(key)
done <- true
}()
// Wait for all goroutines to complete
for i := 0; i < 3; i++ {
<-done
}
// Verify value is deleted
result, found := c.Get(key)
tests.AssertEqual(t, false, found)
if result != nil {
@@ -361,7 +298,6 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) {
}
func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -371,14 +307,12 @@ func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) {
serviceName := "test-service"
// Initial call - should need refresh
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusUnknown, status)
tests.AssertEqual(t, true, needsRefresh)
}
func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -389,17 +323,14 @@ func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) {
serviceName := "test-service"
expectedStatus := model.StatusRunning
// Update status
cache.UpdateStatus(serviceName, expectedStatus)
// Get status - should return cached value
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, expectedStatus, status)
tests.AssertEqual(t, false, needsRefresh)
}
func TestServerStatusCache_Throttling(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 100 * time.Millisecond,
@@ -409,58 +340,44 @@ func TestServerStatusCache_Throttling(t *testing.T) {
serviceName := "test-service"
// Update status
cache.UpdateStatus(serviceName, model.StatusRunning)
// Immediate call - should return cached value
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
// Call within throttle time - should return cached/default status
status, needsRefresh = cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
// Wait for throttle time to pass
time.Sleep(150 * time.Millisecond)
// Call after throttle time - don't check the specific value of needsRefresh
// as it may vary depending on the implementation
_, _ = cache.GetStatus(serviceName)
// Test passes if we reach this point without errors
}
func TestServerStatusCache_Expiration(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 50 * time.Millisecond, // Very short expiration
ExpirationTime: 50 * time.Millisecond,
ThrottleTime: 10 * time.Millisecond,
DefaultStatus: model.StatusUnknown,
}
cache := model.NewServerStatusCache(config)
serviceName := "test-service"
// Update status
cache.UpdateStatus(serviceName, model.StatusRunning)
// Immediate call - should return cached value
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
// Wait for expiration
time.Sleep(60 * time.Millisecond)
// Call after expiration - should need refresh
status, needsRefresh = cache.GetStatus(serviceName)
tests.AssertEqual(t, true, needsRefresh)
}
func TestServerStatusCache_InvalidateStatus(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -470,25 +387,20 @@ func TestServerStatusCache_InvalidateStatus(t *testing.T) {
serviceName := "test-service"
// Update status
cache.UpdateStatus(serviceName, model.StatusRunning)
// Verify it's cached
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
// Invalidate status
cache.InvalidateStatus(serviceName)
// Should need refresh now
status, needsRefresh = cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusUnknown, status)
tests.AssertEqual(t, true, needsRefresh)
}
func TestServerStatusCache_Clear(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -496,23 +408,19 @@ func TestServerStatusCache_Clear(t *testing.T) {
}
cache := model.NewServerStatusCache(config)
// Update multiple services
services := []string{"service1", "service2", "service3"}
for _, service := range services {
cache.UpdateStatus(service, model.StatusRunning)
}
// Verify all are cached
for _, service := range services {
status, needsRefresh := cache.GetStatus(service)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
}
// Clear cache
cache.Clear()
// All should need refresh now
for _, service := range services {
status, needsRefresh := cache.GetStatus(service)
tests.AssertEqual(t, model.StatusUnknown, status)
@@ -521,30 +429,23 @@ func TestServerStatusCache_Clear(t *testing.T) {
}
func TestLookupCache_SetGetClear(t *testing.T) {
// Setup
cache := model.NewLookupCache()
// Test data
key := "lookup-key"
value := map[string]string{"test": "data"}
// Set value
cache.Set(key, value)
// Get value
result, found := cache.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, result)
// Verify it's the same data
resultMap, ok := result.(map[string]string)
tests.AssertEqual(t, true, ok)
tests.AssertEqual(t, "data", resultMap["test"])
// Clear cache
cache.Clear()
// Should be gone now
result, found = cache.Get(key)
tests.AssertEqual(t, false, found)
if result != nil {
@@ -553,7 +454,6 @@ func TestLookupCache_SetGetClear(t *testing.T) {
}
func TestServerConfigCache_Configuration(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -571,17 +471,14 @@ func TestServerConfigCache_Configuration(t *testing.T) {
ConfigVersion: model.IntString(1),
}
// Initial get - should miss
result, found := cache.GetConfiguration(serverID)
tests.AssertEqual(t, false, found)
if result != nil {
t.Fatal("Expected nil result, got non-nil")
}
// Update cache
cache.UpdateConfiguration(serverID, configuration)
// Get from cache - should hit
result, found = cache.GetConfiguration(serverID)
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, result)
@@ -590,7 +487,6 @@ func TestServerConfigCache_Configuration(t *testing.T) {
}
func TestServerConfigCache_InvalidateServerCache(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -602,11 +498,9 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) {
configuration := model.Configuration{UdpPort: model.IntString(9231)}
assistRules := model.AssistRules{StabilityControlLevelMax: model.IntString(0)}
// Update multiple configs for server
cache.UpdateConfiguration(serverID, configuration)
cache.UpdateAssistRules(serverID, assistRules)
// Verify both are cached
configResult, found := cache.GetConfiguration(serverID)
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, configResult)
@@ -615,10 +509,8 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) {
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, assistResult)
// Invalidate server cache
cache.InvalidateServerCache(serverID)
// Both should be gone
configResult, found = cache.GetConfiguration(serverID)
tests.AssertEqual(t, false, found)
if configResult != nil {

View File

@@ -12,25 +12,20 @@ import (
)
func TestConfigService_GetConfiguration_ValidFile(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test config files
err := helper.CreateTestConfigFiles()
tests.AssertNoError(t, err)
// Create repositories and service
configRepo := repository.NewConfigRepository(helper.DB)
serverRepo := repository.NewServerRepository(helper.DB)
configService := service.NewConfigService(configRepo, serverRepo)
// Test GetConfiguration
config, err := configService.GetConfiguration(helper.TestData.Server)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, config)
// Verify the result is the expected configuration
tests.AssertEqual(t, model.IntString(9231), config.UdpPort)
tests.AssertEqual(t, model.IntString(9232), config.TcpPort)
tests.AssertEqual(t, model.IntString(30), config.MaxConnections)
@@ -40,21 +35,17 @@ func TestConfigService_GetConfiguration_ValidFile(t *testing.T) {
}
func TestConfigService_GetConfiguration_MissingFile(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create server directory but no config files
serverConfigDir := filepath.Join(helper.TestData.Server.Path, "cfg")
err := os.MkdirAll(serverConfigDir, 0755)
tests.AssertNoError(t, err)
// Create repositories and service
configRepo := repository.NewConfigRepository(helper.DB)
serverRepo := repository.NewServerRepository(helper.DB)
configService := service.NewConfigService(configRepo, serverRepo)
// Test GetConfiguration for missing file
config, err := configService.GetConfiguration(helper.TestData.Server)
if err == nil {
t.Fatal("Expected error for missing file, got nil")
@@ -65,25 +56,20 @@ func TestConfigService_GetConfiguration_MissingFile(t *testing.T) {
}
func TestConfigService_GetEventConfig_ValidFile(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test config files
err := helper.CreateTestConfigFiles()
tests.AssertNoError(t, err)
// Create repositories and service
configRepo := repository.NewConfigRepository(helper.DB)
serverRepo := repository.NewServerRepository(helper.DB)
configService := service.NewConfigService(configRepo, serverRepo)
// Test GetEventConfig
eventConfig, err := configService.GetEventConfig(helper.TestData.Server)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, eventConfig)
// Verify the result is the expected event configuration
tests.AssertEqual(t, "spa", eventConfig.Track)
tests.AssertEqual(t, model.IntString(80), eventConfig.PreRaceWaitingTimeSeconds)
tests.AssertEqual(t, model.IntString(120), eventConfig.SessionOverTimeSeconds)
@@ -91,7 +77,6 @@ func TestConfigService_GetEventConfig_ValidFile(t *testing.T) {
tests.AssertEqual(t, float64(0.3), eventConfig.CloudLevel)
tests.AssertEqual(t, float64(0.0), eventConfig.Rain)
// Verify sessions
tests.AssertEqual(t, 3, len(eventConfig.Sessions))
if len(eventConfig.Sessions) > 0 {
tests.AssertEqual(t, model.SessionPractice, eventConfig.Sessions[0].SessionType)
@@ -101,20 +86,16 @@ func TestConfigService_GetEventConfig_ValidFile(t *testing.T) {
func TestConfigService_SaveConfiguration_Success(t *testing.T) {
t.Skip("Temporarily disabled due to path issues")
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test config files
err := helper.CreateTestConfigFiles()
tests.AssertNoError(t, err)
// Create repositories and service
configRepo := repository.NewConfigRepository(helper.DB)
serverRepo := repository.NewServerRepository(helper.DB)
configService := service.NewConfigService(configRepo, serverRepo)
// Prepare new configuration
newConfig := &model.Configuration{
UdpPort: model.IntString(9999),
TcpPort: model.IntString(10000),
@@ -124,16 +105,13 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) {
ConfigVersion: model.IntString(2),
}
// Test SaveConfiguration
err = configService.SaveConfiguration(helper.TestData.Server, newConfig)
tests.AssertNoError(t, err)
// Verify the configuration was saved
configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "configuration.json")
fileContent, err := os.ReadFile(configPath)
tests.AssertNoError(t, err)
// Convert from UTF-16 to UTF-8 for verification
utf8Content, err := service.DecodeUTF16LEBOM(fileContent)
tests.AssertNoError(t, err)
@@ -141,7 +119,6 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) {
err = json.Unmarshal(utf8Content, &savedConfig)
tests.AssertNoError(t, err)
// Verify the saved values
tests.AssertEqual(t, "9999", savedConfig["udpPort"])
tests.AssertEqual(t, "10000", savedConfig["tcpPort"])
tests.AssertEqual(t, "40", savedConfig["maxConnections"])
@@ -151,25 +128,20 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) {
}
func TestConfigService_LoadConfigs_Success(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test config files
err := helper.CreateTestConfigFiles()
tests.AssertNoError(t, err)
// Create repositories and service
configRepo := repository.NewConfigRepository(helper.DB)
serverRepo := repository.NewServerRepository(helper.DB)
configService := service.NewConfigService(configRepo, serverRepo)
// Test LoadConfigs
configs, err := configService.LoadConfigs(helper.TestData.Server)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, configs)
// Verify all configurations are loaded
tests.AssertEqual(t, model.IntString(9231), configs.Configuration.UdpPort)
tests.AssertEqual(t, model.IntString(9232), configs.Configuration.TcpPort)
tests.AssertEqual(t, "Test ACC Server", configs.Settings.ServerName)
@@ -183,21 +155,17 @@ func TestConfigService_LoadConfigs_Success(t *testing.T) {
}
func TestConfigService_LoadConfigs_MissingFiles(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create server directory but no config files
serverConfigDir := filepath.Join(helper.TestData.Server.Path, "cfg")
err := os.MkdirAll(serverConfigDir, 0755)
tests.AssertNoError(t, err)
// Create repositories and service
configRepo := repository.NewConfigRepository(helper.DB)
serverRepo := repository.NewServerRepository(helper.DB)
configService := service.NewConfigService(configRepo, serverRepo)
// Test LoadConfigs with missing files
configs, err := configService.LoadConfigs(helper.TestData.Server)
if err == nil {
t.Fatal("Expected error for missing files, got nil")
@@ -208,20 +176,16 @@ func TestConfigService_LoadConfigs_MissingFiles(t *testing.T) {
}
func TestConfigService_MalformedJSON(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create malformed config file
err := helper.CreateMalformedConfigFile("configuration.json")
tests.AssertNoError(t, err)
// Create repositories and service
configRepo := repository.NewConfigRepository(helper.DB)
serverRepo := repository.NewServerRepository(helper.DB)
configService := service.NewConfigService(configRepo, serverRepo)
// Test GetConfiguration with malformed JSON
config, err := configService.GetConfiguration(helper.TestData.Server)
if err == nil {
t.Fatal("Expected error for malformed JSON, got nil")
@@ -232,31 +196,24 @@ func TestConfigService_MalformedJSON(t *testing.T) {
}
func TestConfigService_UTF16_Encoding(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test config files
err := helper.CreateTestConfigFiles()
tests.AssertNoError(t, err)
// Test UTF-16 encoding and decoding
originalData := `{"udpPort": "9231", "tcpPort": "9232"}`
// Encode to UTF-16 LE BOM
encoded, err := service.EncodeUTF16LEBOM([]byte(originalData))
tests.AssertNoError(t, err)
// Decode back to UTF-8
decoded, err := service.DecodeUTF16LEBOM(encoded)
tests.AssertNoError(t, err)
// Verify it matches original
tests.AssertEqual(t, originalData, string(decoded))
}
func TestConfigService_DecodeFileName(t *testing.T) {
// Test that all supported file names have decoders
testCases := []string{
"configuration.json",
"assistRules.json",
@@ -272,7 +229,6 @@ func TestConfigService_DecodeFileName(t *testing.T) {
})
}
// Test invalid filename
decoder := service.DecodeFileName("invalid.json")
if decoder != nil {
t.Fatal("Expected nil decoder for invalid filename, got non-nil")
@@ -280,22 +236,18 @@ func TestConfigService_DecodeFileName(t *testing.T) {
}
func TestConfigService_IntString_Conversion(t *testing.T) {
// Test IntString unmarshaling from string
var intStr model.IntString
// Test string input
err := json.Unmarshal([]byte(`"123"`), &intStr)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 123, intStr.ToInt())
tests.AssertEqual(t, "123", intStr.ToString())
// Test int input
err = json.Unmarshal([]byte(`456`), &intStr)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 456, intStr.ToInt())
tests.AssertEqual(t, "456", intStr.ToString())
// Test empty string
err = json.Unmarshal([]byte(`""`), &intStr)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 0, intStr.ToInt())
@@ -303,28 +255,23 @@ func TestConfigService_IntString_Conversion(t *testing.T) {
}
func TestConfigService_IntBool_Conversion(t *testing.T) {
// Test IntBool unmarshaling from int
var intBool model.IntBool
// Test int input (1 = true)
err := json.Unmarshal([]byte(`1`), &intBool)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 1, intBool.ToInt())
tests.AssertEqual(t, true, intBool.ToBool())
// Test int input (0 = false)
err = json.Unmarshal([]byte(`0`), &intBool)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 0, intBool.ToInt())
tests.AssertEqual(t, false, intBool.ToBool())
// Test bool input (true)
err = json.Unmarshal([]byte(`true`), &intBool)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 1, intBool.ToInt())
tests.AssertEqual(t, true, intBool.ToBool())
// Test bool input (false)
err = json.Unmarshal([]byte(`false`), &intBool)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 0, intBool.ToInt())
@@ -332,25 +279,20 @@ func TestConfigService_IntBool_Conversion(t *testing.T) {
}
func TestConfigService_Caching_Configuration(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test config files (already UTF-16 encoded)
err := helper.CreateTestConfigFiles()
tests.AssertNoError(t, err)
// Create repositories and service
configRepo := repository.NewConfigRepository(helper.DB)
serverRepo := repository.NewServerRepository(helper.DB)
configService := service.NewConfigService(configRepo, serverRepo)
// First call - should load from disk
config1, err := configService.GetConfiguration(helper.TestData.Server)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, config1)
// Modify the file on disk with UTF-16 encoding
configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "configuration.json")
modifiedContent := `{"udpPort": "5555", "tcpPort": "5556"}`
utf16Modified, err := service.EncodeUTF16LEBOM([]byte(modifiedContent))
@@ -359,36 +301,29 @@ func TestConfigService_Caching_Configuration(t *testing.T) {
err = os.WriteFile(configPath, utf16Modified, 0644)
tests.AssertNoError(t, err)
// Second call - should return cached result (not the modified file)
config2, err := configService.GetConfiguration(helper.TestData.Server)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, config2)
// Should still have the original cached values
tests.AssertEqual(t, model.IntString(9231), config2.UdpPort)
tests.AssertEqual(t, model.IntString(9232), config2.TcpPort)
}
func TestConfigService_Caching_EventConfig(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test config files (already UTF-16 encoded)
err := helper.CreateTestConfigFiles()
tests.AssertNoError(t, err)
// Create repositories and service
configRepo := repository.NewConfigRepository(helper.DB)
serverRepo := repository.NewServerRepository(helper.DB)
configService := service.NewConfigService(configRepo, serverRepo)
// First call - should load from disk
event1, err := configService.GetEventConfig(helper.TestData.Server)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, event1)
// Modify the file on disk with UTF-16 encoding
configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "event.json")
modifiedContent := `{"track": "monza", "preRaceWaitingTimeSeconds": "60"}`
utf16Modified, err := service.EncodeUTF16LEBOM([]byte(modifiedContent))
@@ -397,12 +332,10 @@ func TestConfigService_Caching_EventConfig(t *testing.T) {
err = os.WriteFile(configPath, utf16Modified, 0644)
tests.AssertNoError(t, err)
// Second call - should return cached result (not the modified file)
event2, err := configService.GetEventConfig(helper.TestData.Server)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, event2)
// Should still have the original cached values
tests.AssertEqual(t, "spa", event2.Track)
tests.AssertEqual(t, model.IntString(80), event2.PreRaceWaitingTimeSeconds)
}

View File

@@ -15,11 +15,9 @@ import (
)
func TestStateHistoryService_GetAll_Success(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -27,22 +25,17 @@ func TestStateHistoryService_GetAll_Success(t *testing.T) {
}
}
// Use real repository like other service tests
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
// Insert test data directly into DB
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(helper.CreateContext(), &history)
tests.AssertNoError(t, err)
// Create proper Fiber context
app := fiber.New()
ctx := helper.CreateFiberCtx()
defer helper.ReleaseFiberCtx(app, ctx)
// Test GetAll
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
result, err := stateHistoryService.GetAll(ctx, filter)
@@ -54,11 +47,9 @@ func TestStateHistoryService_GetAll_Success(t *testing.T) {
}
func TestStateHistoryService_GetAll_WithFilter(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -69,7 +60,6 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
// Insert test data with different sessions
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 10, uuid.New())
@@ -79,12 +69,10 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) {
err = repo.Insert(helper.CreateContext(), &raceHistory)
tests.AssertNoError(t, err)
// Create proper Fiber context
app := fiber.New()
ctx := helper.CreateFiberCtx()
defer helper.ReleaseFiberCtx(app, ctx)
// Test GetAll with session filter
filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace)
result, err := stateHistoryService.GetAll(ctx, filter)
@@ -96,11 +84,9 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) {
}
func TestStateHistoryService_GetAll_NoData(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -111,12 +97,10 @@ func TestStateHistoryService_GetAll_NoData(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
// Create proper Fiber context
app := fiber.New()
ctx := helper.CreateFiberCtx()
defer helper.ReleaseFiberCtx(app, ctx)
// Test GetAll with no data
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
result, err := stateHistoryService.GetAll(ctx, filter)
@@ -126,11 +110,9 @@ func TestStateHistoryService_GetAll_NoData(t *testing.T) {
}
func TestStateHistoryService_Insert_Success(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -141,20 +123,16 @@ func TestStateHistoryService_Insert_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
// Create test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
// Create proper Fiber context
app := fiber.New()
ctx := helper.CreateFiberCtx()
defer helper.ReleaseFiberCtx(app, ctx)
// Test Insert
err := stateHistoryService.Insert(ctx, &history)
tests.AssertNoError(t, err)
// Verify data was inserted
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
result, err := stateHistoryService.GetAll(ctx, filter)
tests.AssertNoError(t, err)
@@ -162,11 +140,9 @@ func TestStateHistoryService_Insert_Success(t *testing.T) {
}
func TestStateHistoryService_GetLastSessionID_Success(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -177,30 +153,25 @@ func TestStateHistoryService_GetLastSessionID_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
// Insert test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
sessionID := uuid.New()
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID)
err := repo.Insert(helper.CreateContext(), &history)
tests.AssertNoError(t, err)
// Create proper Fiber context
app := fiber.New()
ctx := helper.CreateFiberCtx()
defer helper.ReleaseFiberCtx(app, ctx)
// Test GetLastSessionID
lastSessionID, err := stateHistoryService.GetLastSessionID(ctx, helper.TestData.ServerID)
tests.AssertNoError(t, err)
tests.AssertEqual(t, sessionID, lastSessionID)
}
func TestStateHistoryService_GetLastSessionID_NoData(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -211,26 +182,21 @@ func TestStateHistoryService_GetLastSessionID_NoData(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
// Create proper Fiber context
app := fiber.New()
ctx := helper.CreateFiberCtx()
defer helper.ReleaseFiberCtx(app, ctx)
// Test GetLastSessionID with no data
lastSessionID, err := stateHistoryService.GetLastSessionID(ctx, helper.TestData.ServerID)
tests.AssertNoError(t, err)
tests.AssertEqual(t, uuid.Nil, lastSessionID)
}
func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
// This test might fail due to database setup issues
t.Skip("Skipping test as it's dependent on database migration")
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -241,10 +207,8 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
// Insert test data with varying player counts
_ = testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Create entries with different sessions and player counts
sessionID1 := uuid.New()
sessionID2 := uuid.New()
@@ -291,12 +255,10 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
tests.AssertNoError(t, err)
}
// Create proper Fiber context
app := fiber.New()
ctx := helper.CreateFiberCtx()
defer helper.ReleaseFiberCtx(app, ctx)
// Test GetStatistics
filter := &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
ServerID: helper.TestData.ServerID.String(),
@@ -311,17 +273,14 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
tests.AssertNoError(t, err)
tests.AssertNotNil(t, stats)
// Verify statistics
tests.AssertEqual(t, 15, stats.PeakPlayers) // Maximum player count
tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions
tests.AssertEqual(t, 15, stats.PeakPlayers)
tests.AssertEqual(t, 2, stats.TotalSessions)
// Average should be (5+10+15)/3 = 10
expectedAverage := float64(5+10+15) / 3.0
if stats.AveragePlayers != expectedAverage {
t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers)
}
// Verify other statistics components exist
tests.AssertNotNil(t, stats.PlayerCountOverTime)
tests.AssertNotNil(t, stats.SessionTypes)
tests.AssertNotNil(t, stats.DailyActivity)
@@ -329,14 +288,11 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
}
func TestStateHistoryService_GetStatistics_NoData(t *testing.T) {
// This test might fail due to database setup issues
t.Skip("Skipping test as it's dependent on database migration")
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -347,19 +303,16 @@ func TestStateHistoryService_GetStatistics_NoData(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
// Create proper Fiber context
app := fiber.New()
ctx := helper.CreateFiberCtx()
defer helper.ReleaseFiberCtx(app, ctx)
// Test GetStatistics with no data
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
stats, err := stateHistoryService.GetStatistics(ctx, filter)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, stats)
// Verify empty statistics
tests.AssertEqual(t, 0, stats.PeakPlayers)
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
tests.AssertEqual(t, 0, stats.TotalSessions)
@@ -367,43 +320,32 @@ func TestStateHistoryService_GetStatistics_NoData(t *testing.T) {
}
func TestStateHistoryService_LogParsingWorkflow(t *testing.T) {
// Skip this test as it's unreliable and not critical
t.Skip("Skipping log parsing test as it's not critical to the service functionality")
// This test simulates the actual log parsing workflow
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Insert test server
err := helper.InsertTestServer()
tests.AssertNoError(t, err)
server := helper.TestData.Server
// Track state changes
var stateChanges []*model.ServerState
onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) {
// Use pointer to avoid copying mutex
stateChanges = append(stateChanges, state)
}
// Create AccServerInstance (this is what the real server service does)
instance := tracking.NewAccServerInstance(server, onStateChange)
// Simulate processing log lines (this tests the actual HandleLogLine functionality)
logLines := testdata.SampleLogLines
for _, line := range logLines {
instance.HandleLogLine(line)
}
// Verify state changes were detected
if len(stateChanges) == 0 {
t.Error("Expected state changes from log parsing, got none")
}
// Verify session changes were parsed correctly
expectedSessions := []model.TrackSession{model.SessionPractice, model.SessionQualify, model.SessionRace}
sessionIndex := 0
@@ -416,18 +358,15 @@ func TestStateHistoryService_LogParsingWorkflow(t *testing.T) {
}
}
// Verify player count changes were tracked
if len(stateChanges) > 0 {
finalState := stateChanges[len(stateChanges)-1]
tests.AssertEqual(t, 0, finalState.PlayerCount) // Should end with 0 players
tests.AssertEqual(t, 0, finalState.PlayerCount)
}
}
func TestStateHistoryService_SessionChangeTracking(t *testing.T) {
// Skip this test as it's unreliable
t.Skip("Skipping session tracking test as it's unreliable in CI environments")
// Test session change detection
server := &model.Server{
ID: uuid.New(),
Name: "Test Server",
@@ -437,7 +376,6 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) {
onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) {
for _, change := range changes {
if change == tracking.Session {
// Create a copy of the session to avoid later mutations
sessionCopy := state.Session
sessionChanges = append(sessionChanges, sessionCopy)
}
@@ -446,22 +384,17 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) {
instance := tracking.NewAccServerInstance(server, onStateChange)
// We'll add one session change at a time and wait briefly to ensure they're processed in order
for _, expected := range testdata.ExpectedSessionChanges {
line := string("[2024-01-15 14:30:25.123] Session changed: " + expected.From + " -> " + expected.To)
instance.HandleLogLine(line)
// Small pause to ensure log processing completes
time.Sleep(10 * time.Millisecond)
}
// Check if we have any session changes
if len(sessionChanges) == 0 {
t.Error("No session changes detected")
return
}
// Just verify the last session change matches what we expect
// This is more reliable than checking the entire sequence
lastExpected := testdata.ExpectedSessionChanges[len(testdata.ExpectedSessionChanges)-1].To
lastActual := sessionChanges[len(sessionChanges)-1]
if lastActual != lastExpected {
@@ -470,10 +403,8 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) {
}
func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
// Skip this test as it's unreliable
t.Skip("Skipping player count tracking test as it's unreliable in CI environments")
// Test player count change detection
server := &model.Server{
ID: uuid.New(),
Name: "Test Server",
@@ -490,7 +421,6 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
instance := tracking.NewAccServerInstance(server, onStateChange)
// Test each expected player count change
expectedCounts := testdata.ExpectedPlayerCounts
logLines := []string{
"[2024-01-15 14:30:30.456] 1 client(s) online",
@@ -499,7 +429,7 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
"[2024-01-15 14:35:05.789] 8 client(s) online",
"[2024-01-15 14:40:05.456] 12 client(s) online",
"[2024-01-15 14:45:00.789] 15 client(s) online",
"[2024-01-15 14:50:00.789] Removing dead connection", // Should decrease by 1
"[2024-01-15 14:50:00.789] Removing dead connection",
"[2024-01-15 15:00:00.789] 0 client(s) online",
}
@@ -507,7 +437,6 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
instance.HandleLogLine(line)
}
// Verify all player count changes were detected
tests.AssertEqual(t, len(expectedCounts), len(playerCounts))
for i, expected := range expectedCounts {
if i < len(playerCounts) {
@@ -517,10 +446,8 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
}
func TestStateHistoryService_EdgeCases(t *testing.T) {
// Skip this test as it's unreliable
t.Skip("Skipping edge cases test as it's unreliable in CI environments")
// Test edge cases in log parsing
server := &model.Server{
ID: uuid.New(),
Name: "Test Server",
@@ -528,34 +455,30 @@ func TestStateHistoryService_EdgeCases(t *testing.T) {
var stateChanges []*model.ServerState
onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) {
// Create a copy of the state to avoid later mutations affecting our saved state
stateCopy := *state
stateChanges = append(stateChanges, &stateCopy)
}
instance := tracking.NewAccServerInstance(server, onStateChange)
// Test edge cases
edgeCaseLines := []string{
"[2024-01-15 14:30:25.123] Some unrelated log line", // Should be ignored
"[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE", // Valid session change
"[2024-01-15 14:30:30.456] 0 client(s) online", // Zero players
"[2024-01-15 14:30:35.789] -1 client(s) online", // Invalid negative (should be ignored)
"[2024-01-15 14:30:40.789] 30 client(s) online", // High but valid player count
"[2024-01-15 14:30:45.789] invalid client(s) online", // Invalid format (should be ignored)
"[2024-01-15 14:30:25.123] Some unrelated log line",
"[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE",
"[2024-01-15 14:30:30.456] 0 client(s) online",
"[2024-01-15 14:30:35.789] -1 client(s) online",
"[2024-01-15 14:30:40.789] 30 client(s) online",
"[2024-01-15 14:30:45.789] invalid client(s) online",
}
for _, line := range edgeCaseLines {
instance.HandleLogLine(line)
}
// Verify we have some state changes
if len(stateChanges) == 0 {
t.Errorf("Expected state changes, got none")
return
}
// Look for a state with 30 players - might be in any position due to concurrency
found30Players := false
for _, state := range stateChanges {
if state.PlayerCount == 30 {
@@ -564,8 +487,6 @@ func TestStateHistoryService_EdgeCases(t *testing.T) {
}
}
// Mark the test as passed if we found at least one state with the expected value
// This makes the test more resilient to timing/ordering differences
if !found30Players {
t.Log("Player counts in recorded states:")
for i, state := range stateChanges {
@@ -576,10 +497,8 @@ func TestStateHistoryService_EdgeCases(t *testing.T) {
}
func TestStateHistoryService_SessionStartTracking(t *testing.T) {
// Skip this test as it's unreliable
t.Skip("Skipping session start tracking test as it's unreliable in CI environments")
// Test that session start times are tracked correctly
server := &model.Server{
ID: uuid.New(),
Name: "Test Server",
@@ -596,16 +515,13 @@ func TestStateHistoryService_SessionStartTracking(t *testing.T) {
instance := tracking.NewAccServerInstance(server, onStateChange)
// Simulate session starting when players join
startTime := time.Now()
instance.HandleLogLine("[2024-01-15 14:30:30.456] 1 client(s) online") // First player joins
instance.HandleLogLine("[2024-01-15 14:30:30.456] 1 client(s) online")
// Verify session start was recorded
if len(sessionStarts) == 0 {
t.Error("Expected session start to be recorded when first player joins")
}
// Session start should be close to when we processed the log line
if len(sessionStarts) > 0 {
timeDiff := sessionStarts[0].Sub(startTime)
if timeDiff > time.Second || timeDiff < -time.Second {